diff --git a/.gitignore b/.gitignore index 896b38a12..24b47eaa5 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ save* .log *.pid *.ipynb* +models/ +output_*HiFloat4/ diff --git a/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml new file mode 100644 index 000000000..1540e8b75 --- /dev/null +++ b/configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml @@ -0,0 +1,53 @@ +base: + seed: &seed 42 +model: + type: Wan2T2V + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-A14B-Diffusers + torch_dtype: auto +calib: + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + sample_steps: 20 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed +eval: + eval_pos: [transformed, fake_quant] + type: video_gen + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_awq/ +quant: + video_gen: + method: Awq + weight: + # quant_type: int-quant + quant_type: hif4 + bit: 4 + symmetric: True + granularity: per_channel + group_size: -1 + act: + # quant_type: int-qu + quant_type: hif4 + bit: 4 + symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True +save: + save_lightx2v: True + save_path: ../lightx2v/wan2_2_t2v_awq_w_a/x2v/ diff --git a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml index 680fab43b..1b1097ad7 100755 --- a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml @@ -2,7 +2,7 @@ base: seed: &seed 42 model: type: WanI2V - path: /path/to/model + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-A14B/ torch_dtype: auto calib: name: i2v @@ -31,12 +31,12 @@ quant: video_gen: method: Awq weight: - bit: 8 + bit: 4 symmetric: True granularity: per_channel group_size: -1 act: - bit: 8 + bit: 4 symmetric: True granularity: per_token special: @@ -46,4 +46,4 @@ quant: clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: ../lightx2v/wan_i2v_awq_w_a/x2v/ diff --git a/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml new file mode 100644 index 000000000..adba728d0 --- /dev/null +++ b/configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8_example.yaml @@ -0,0 +1,57 @@ +# Wan2.1 I2V FP8 量化配置示例 +# 这是一个快速开始的配置文件,请根据实际情况修改路径 + +base: + seed: &seed 42 + +model: + type: WanI2V + path: /path/to/wan2.1-i2v-model # 修改为你的 Wan2.1 I2V 模型路径 + torch_dtype: auto + +calib: + name: i2v + download: False + path: /path/to/calibration/data # 修改为你的校准数据路径 + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed + +eval: + eval_pos: [fake_quant] + type: video_gen + name: i2v + download: False + path: /path/to/eval/data # 修改为你的评估数据路径 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_fp8/ + +quant: + video_gen: + method: SmoothQuant + weight: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_channel + use_qtorch: True + act: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_token + use_qtorch: True + special: + alpha: 0.75 # SmoothQuant 平衡参数,范围 0.5-1.0 + +save: + save_lightx2v: True # 保存为 lightx2v 兼容格式 + save_path: /path/to/save/quantized/model # 修改为你的保存路径 diff --git a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml index 14d05479d..ec6d8714e 100755 --- a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml @@ -2,12 +2,12 @@ base: seed: &seed 42 model: type: WanT2V - path: /path/to/wan_t2v + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.1-T2V-14B-Diffusers torch_dtype: auto calib: name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ sample_steps: 20 bs: 1 target_height: 480 @@ -20,7 +20,7 @@ eval: type: video_gen name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ bs: 1 target_height: 480 target_width: 832 @@ -31,12 +31,12 @@ quant: video_gen: method: Awq weight: - bit: 6 + bit: 4 symmetric: True granularity: per_channel group_size: -1 act: - bit: 6 + bit: 4 symmetric: True granularity: per_token special: @@ -46,4 +46,4 @@ quant: clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: ../lightx2v/wan_t2v_awq_w_a/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml b/configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml new file mode 100755 index 000000000..f140839e3 --- /dev/null +++ b/configs/quantization/video_gen/wan_t2v/awq_w_a_s.yaml @@ -0,0 +1,49 @@ +base: + seed: &seed 42 +model: + type: WanT2V + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.1-T2V-1.3B-Diffusers + torch_dtype: auto +calib: + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + sample_steps: 20 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + seed: *seed +eval: + eval_pos: [transformed, fake_quant] + type: video_gen + name: t2v + download: False + path: ./assets/wan_t2v/calib/ + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 + output_video_path: ./output_videos_awq/ +quant: + video_gen: + method: Awq + weight: + bit: 4 + symmetric: True + granularity: per_channel + group_size: -1 + act: + bit: 4 + symmetric: True + granularity: per_token + special: + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True +save: + save_lightx2v: True + save_path: ../lightx2v/wan_t2v_awq_w_a_s/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml b/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml deleted file mode 100755 index b6a53b0e0..000000000 --- a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml +++ /dev/null @@ -1,32 +0,0 @@ -base: - seed: &seed 42 -model: - type: WanT2V - path: /path/to/wan_t2v - torch_dtype: auto -eval: - eval_pos: [transformed, fake_quant] - type: video_gen - name: t2v - download: False - path: ../assets/wan_t2v/eval/ - bs: 1 - target_height: 480 - target_width: 832 - num_frames: 81 - guidance_scale: 5.0 - output_video_path: ./output_videos_rtn/ -quant: - video_gen: - method: RTN - weight: - bit: 6 - symmetric: True - granularity: per_channel - act: - bit: 6 - symmetric: True - granularity: per_token -save: - save_lightx2v: True - save_path: /path/to/x2v/ diff --git a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml index 7d65f31fc..f76edd294 100755 --- a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml +++ b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml @@ -2,12 +2,12 @@ base: seed: &seed 42 model: type: WanT2V - path: /path/to/wan_t2v + path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-14B-Diffusers torch_dtype: auto calib: name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ sample_steps: 20 bs: 1 target_height: 480 @@ -20,26 +20,30 @@ eval: type: video_gen name: t2v download: False - path: ../assets/wan_t2v/calib/ + path: ./assets/wan_t2v/calib/ bs: 1 target_height: 480 target_width: 832 num_frames: 81 guidance_scale: 5.0 - output_video_path: ./output_videos_sq/ + output_video_path: ./output_videos_awq/ quant: video_gen: - method: SmoothQuant + method: Awq weight: - bit: 6 + bit: 4 symmetric: True granularity: per_channel + group_size: -1 act: - bit: 6 + bit: 4 symmetric: True granularity: per_token special: - alpha: 0.7 + trans: True + trans_version: v2 + weight_clip: True + clip_sym: True save: save_lightx2v: True - save_path: /path/to/x2v/ + save_path: ../lightx2v/wan_t2v_awq_w_a/x2v/ diff --git a/docs/wan2.1_quantization_guide.md b/docs/wan2.1_quantization_guide.md new file mode 100644 index 000000000..eeef5ac63 --- /dev/null +++ b/docs/wan2.1_quantization_guide.md @@ -0,0 +1,288 @@ +# Wan2.1 视频生成模型量化指南 + +## 概述 + +llmc 框架现已全面支持 Wan2.1 系列视频生成模型的量化,并提供真正量化的 INT8/FP8 权重导出,与 lightx2v 推理框架兼容。 + +## 支持的模型类型 + +- **WanI2V**: Image-to-Video (图像到视频) +- **WanT2V**: Text-to-Video (文本到视频) + +## 支持的量化方法 + +### FP8 量化 (推荐) + +**配置文件**: `configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml` + +**特点**: +- 使用 E4M3 FP8 格式 (8-bit 浮点数,4位指数,3位尾数) +- SmoothQuant 算法,平衡权重和激活的量化难度 +- 适合 GPU 推理,性能损失小 + +**量化配置**: +```yaml +quant: + video_gen: + method: SmoothQuant + weight: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_channel + use_qtorch: True + act: + quant_type: float-quant + bit: e4m3 # FP8 E4M3 格式 + symmetric: True + granularity: per_token + use_qtorch: True + special: + alpha: 0.75 # SmoothQuant 平衡参数 +``` + +### INT8 量化 + +#### 1. RTN (Round-to-Nearest) +**配置文件**: `configs/quantization/video_gen/wan_i2v/rtn_w_a.yaml` + +**特点**: +- 最简单的量化方法 +- 直接四舍五入到最近的量化级别 +- 速度快,精度略低 + +#### 2. AWQ (Activation-aware Weight Quantization) +**配置文件**: `configs/quantization/video_gen/wan_i2v/awq_w_a.yaml` + +**特点**: +- 基于激活分布优化权重量化 +- 保护重要通道,减少精度损失 +- 需要校准数据 + +#### 3. SmoothQuant +**配置文件**: `configs/quantization/video_gen/wan_i2v/smoothquant_w_a.yaml` + +**特点**: +- 平衡权重和激活的量化难度 +- 数学上等价于平滑激活异常值 +- 通常提供最佳精度 + +### LoRA 模型量化 + +支持对 LoRA 适配器模型的量化: +- `smoothquant_w_a_int8_lora.yaml` +- `rtn_w_a_lora.yaml` + +## 运行步骤 + +### 1. 准备环境 + +```bash +# 设置 llmc 路径 +export llmc=/path/to/llmc +export PYTHONPATH=$llmc:$PYTHONPATH + +# 设置 GPU +export CUDA_VISIBLE_DEVICES=0 +``` + +### 2. 准备校准数据 + +为 I2V 模型准备校准数据: +``` +assets/wan_i2v/calib/ +├── image_1.jpg +├── image_2.jpg +└── ... +``` + +为 T2V 模型准备校准数据: +``` +assets/wan_t2v/calib/ +├── prompt_1.txt +├── prompt_2.txt +└── ... +``` + +### 3. 修改配置文件 + +编辑对应的 YAML 配置文件,设置: +- `model.path`: Wan2.1 模型路径 +- `calib.path`: 校准数据路径 +- `save.save_path`: 量化模型保存路径 + +**示例 (FP8 量化)**: +```yaml +base: + seed: 42 +model: + type: WanI2V + path: /path/to/wan2.1-i2v-model # 修改为你的模型路径 + torch_dtype: auto +calib: + name: i2v + download: False + path: /path/to/calibration/data # 修改为校准数据路径 + sample_steps: 40 + bs: 1 + target_height: 480 + target_width: 832 + num_frames: 81 + guidance_scale: 5.0 +save: + save_lightx2v: True + save_path: /path/to/save/quantized/model # 修改为保存路径 +``` + +### 4. 运行量化 + +#### 使用脚本运行 (推荐) + +```bash +# 运行 FP8 量化 (I2V) +./run_llmc.sh wan_i2v_fp8 + +# 运行 INT8 RTN 量化 (I2V) +./run_llmc.sh wan_i2v_int8_rtn + +# 运行 INT8 AWQ 量化 (I2V) +./run_llmc.sh wan_i2v_int8_awq + +# 运行 INT8 SmoothQuant 量化 (I2V) +./run_llmc.sh wan_i2v_int8_smoothquant + +# 运行 T2V 模型量化 +./run_llmc.sh wan_t2v_int8_rtn +./run_llmc.sh wan_t2v_int8_awq +./run_llmc.sh wan_t2v_int8_smoothquant +``` + +#### 直接运行命令 + +```bash +torchrun \ +--nnodes 1 \ +--nproc_per_node 1 \ +--rdzv_id $RANDOM \ +--rdzv_backend c10d \ +--rdzv_endpoint 127.0.0.1:29500 \ +${llmc}/llmc/__main__.py \ +--config configs/quantization/video_gen/wan_i2v/smoothquant_w_a_fp8.yaml \ +--task_id my_quant_task +``` + +### 5. 监控进度 + +```bash +# 查看日志 +tail -f wan_i2v_fp8.log + +# 查看进程 +ps aux | grep __main__.py +``` + +### 6. 停止任务 + +```bash +# 使用保存的 PID 文件 +xargs kill -9 < wan_i2v_fp8.pid +``` + +## 配置参数说明 + +### 模型配置 +- `type`: 模型类型 (`WanI2V` 或 `WanT2V`) +- `path`: 模型权重路径 +- `torch_dtype`: 数据类型 (`auto`, `bfloat16`, `float32`) + +### 校准配置 +- `sample_steps`: 采样步数 (通常 20-40) +- `bs`: 批大小 (通常 1,视频生成显存占用大) +- `target_height`: 目标视频高度 (默认 480) +- `target_width`: 目标视频宽度 (默认 832) +- `num_frames`: 视频帧数 (默认 81) +- `guidance_scale`: CFG 引导强度 (默认 5.0) + +### 量化配置 +- `method`: 量化方法 (`RTN`, `Awq`, `SmoothQuant`) +- `weight.bit`: 权重位宽 (8, e4m3) +- `act.bit`: 激活位宽 (8, e4m3) +- `granularity`: 量化粒度 (`per_channel`, `per_token`) +- `special.alpha`: SmoothQuant 平衡参数 (0.5-1.0) + +## 在 lightx2v 中使用量化模型 + +### 1. 配置 lightx2v + +编辑 `lightx2v/configs/quantization/wan_i2v.json`: +```json +{ + "infer_steps": 40, + "target_video_length": 81, + "target_height": 480, + "target_width": 832, + "dit_quantized_ckpt": "/path/to/quantized/model", + "dit_quantized": true, + "dit_quant_scheme": "int8-vllm" +} +``` + +对于 FP8 模型,设置 `"dit_quant_scheme": "fp8"`。 + +### 2. 运行推理 + +```bash +python -m lightx2v.infer \ +--model_cls wan2.1 \ +--task i2v \ +--model_path /path/to/original/model \ +--config_json configs/quantization/wan_i2v.json \ +--prompt "Your prompt here" \ +--image_path /path/to/input/image.jpg \ +--save_result_path output.mp4 +``` + +## 性能建议 + +1. **FP8 vs INT8**: + - FP8: 精度更高,适合对质量要求高的场景 + - INT8: 压缩率更高,适合对速度要求高的场景 + +2. **量化方法选择**: + - 快速原型: RTN + - 平衡精度和速度: SmoothQuant + - 最高精度: AWQ + +3. **校准数据**: + - 使用 10-50 个样本 + - 覆盖典型使用场景 + - I2V: 使用多样化图像 + - T2V: 使用多样化文本描述 + +4. **资源需求**: + - GPU: 建议 24GB+ 显存 + - 校准时间: 30分钟 - 2小时 (取决于数据量) + - 存储空间: 量化后模型约原模型 25-50% 大小 + +## 故障排除 + +### 显存不足 +- 减小 `bs` 到 1 +- 减小 `num_frames` +- 减小 `target_height` 和 `target_width` + +### 量化精度损失过大 +- 尝试 SmoothQuant 方法 +- 增加校准数据数量 +- 调整 `alpha` 参数 (0.5-1.0) + +### lightx2v 兼容性问题 +- 确保使用 `save_lightx2v: True` +- 检查 `dit_quant_scheme` 设置 +- 确认量化模型路径正确 + +## 参考 + +- lightx2v 文档: [lightx2v 项目地址] +- llmc 框架: [llmc 项目地址] +- Wan2.1 模型: [模型地址] diff --git a/llmc/compression/quantization/__init__.py b/llmc/compression/quantization/__init__.py index 2c08343e2..07b4f5967 100644 --- a/llmc/compression/quantization/__init__.py +++ b/llmc/compression/quantization/__init__.py @@ -10,7 +10,7 @@ from .ntweak import NormTweaking from .omniq import OmniQuant from .osplus import OsPlus -from .quant import FloatQuantizer, IntegerQuantizer +from .quant import FloatQuantizer, HiFloat4Quantizer, IntegerQuantizer from .quarot import Quarot from .quik import QUIK from .rtn import RTN diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index 5a2232699..0c3d5474f 100755 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -35,7 +35,12 @@ _TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear, FakeQuantLinear, LlmcActFn, OriginFloatLinear, RotateLinear) -from .quant import FloatQuantizer, IntegerQuantizer, Weight48IntegerQuantizer +from .quant import ( + FloatQuantizer, + HiFloat4Quantizer, + IntegerQuantizer, + Weight48IntegerQuantizer, +) class BaseBlockwiseQuantization(BlockwiseOpt): @@ -157,6 +162,8 @@ def set_quant_config(self): self.weight_quant_module = IntegerQuantizer elif quant_type == 'float-quant': self.weight_quant_module = FloatQuantizer + elif quant_type == 'hif4': + self.weight_quant_module = HiFloat4Quantizer logger.info(f'The used Weight Quant Module is {self.weight_quant_module}') self.wquantizer = self.weight_quant_module(**self.quant_config['weight']) @@ -175,6 +182,13 @@ def set_quant_config(self): self.act_quant_module = IntegerQuantizer elif quant_type == 'float-quant': self.act_quant_module = FloatQuantizer + elif quant_type == 'hif4': + self.act_quant_module = HiFloat4Quantizer + else: + raise ValueError( + f"Unsupported act quant_type: {quant_type}. " + "Supported: int-quant, float-quant, hif4." + ) self.quant_config['act']['tp'] = self.tp self.aquantizer = self.act_quant_module(**self.quant_config['act']) self.act_static = self.quant_config['act'].get('static', False) diff --git a/llmc/compression/quantization/quant.py b/llmc/compression/quantization/quant.py index 2c24c03a8..55cd791a1 100755 --- a/llmc/compression/quantization/quant.py +++ b/llmc/compression/quantization/quant.py @@ -1,4 +1,6 @@ import gc +import os +import sys import torch from loguru import logger @@ -1229,6 +1231,102 @@ def __repr__(self): ) +def _get_hif4_quant_cy(): + """Lazy import HiFloat4 quant_cy (QType, quant_dequant_float) from HiFloat4/hif4_gpu.""" + _repo_root = os.path.dirname( + os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) + ) + _hif4_gpu = os.path.join(_repo_root, 'HiFloat4', 'hif4_gpu') + if _hif4_gpu not in sys.path: + sys.path.insert(0, _hif4_gpu) + try: + from quant_cy import QType, quant_dequant_float + return QType, quant_dequant_float + except Exception as e: + raise ImportError( + 'HiFloat4 4-bit quantization requires the HiFloat4/hif4_gpu package. ' + 'Ensure HiFloat4 is available at repo_root/HiFloat4/hif4_gpu and built.' + ) from e + + +class HiFloat4Quantizer(BaseQuantizer): + """4-bit HiFloat (hif4) simulation quantizer using HiFloat4 quant_dequant_float. + + Uses the HiFloat4 library's quant_dequant_float for block-wise float 4-bit + quantization. No scales/zeros; quantization is done per block along the last dim. + Only supports fake (simulation) quantization; real weight packing is not implemented. + """ + + def __init__(self, bit=4, symmetric=None, granularity=None, **kwargs): + super().__init__(bit, symmetric, granularity, **kwargs) + self.quant_type = 'hif4' + self.q_dim = kwargs.get('hif4_qdim', -1) + self.force_py = kwargs.get('force_py', False) + self.force_fp32 = kwargs.get('force_fp32', True) + self._QType = None + self._quant_dequant_float = None + + def _ensure_hif4(self): + if self._quant_dequant_float is None: + self._QType, self._quant_dequant_float = _get_hif4_quant_cy() + + def fake_quant_act_static(self, act, args={}): + self._ensure_hif4() + org_dtype = act.dtype + qtype = self._QType('hifx4').dim(self.q_dim) + out = self._quant_dequant_float( + act, qtype, force_py=self.force_py, force_fp32=self.force_fp32 + ) + return out.to(org_dtype) + + def fake_quant_act_dynamic(self, act, args={}): + self._ensure_hif4() + org_dtype = act.dtype + qtype = self._QType('hifx4').dim(self.q_dim) + out = self._quant_dequant_float( + act, qtype, force_py=self.force_py, force_fp32=self.force_fp32 + ) + return out.to(org_dtype) + + def fake_quant_weight_static(self, weight, args): + self._ensure_hif4() + org_dtype = weight.dtype + qtype = self._QType('hifx4').dim(self.q_dim) + out = self._quant_dequant_float( + weight, qtype, force_py=self.force_py, force_fp32=self.force_fp32 + ) + return out.to(org_dtype) + + def fake_quant_weight_dynamic(self, weight, args={}): + self._ensure_hif4() + org_dtype = weight.dtype + qtype = self._QType('hifx4').dim(self.q_dim) + out = self._quant_dequant_float( + weight, qtype, force_py=self.force_py, force_fp32=self.force_fp32 + ) + return out.to(org_dtype) + + def real_quant_weight_static(self, weight, args): + raise NotImplementedError( + 'HiFloat4 quantizer is simulation-only (fake quant). ' + 'real_quant_weight is not supported for hif4.' + ) + + def real_quant_weight_dynamic(self, weight, args={}): + raise NotImplementedError( + 'HiFloat4 quantizer is simulation-only (fake quant). ' + 'real_quant_weight is not supported for hif4.' + ) + + def __repr__(self): + return ( + f'HiFloat4Quantizer(quant_type=hif4, q_dim={self.q_dim}, ' + f'force_py={self.force_py}, force_fp32={self.force_fp32})' + ) + + class Weight48IntegerQuantizer(BaseQuantizer): # flake8: noqa def __init__(self, bit, bit4, bit8, **kwargs): diff --git a/llmc/eval/eval_video_generate.py b/llmc/eval/eval_video_generate.py index 0f99ff6c9..726187c0b 100755 --- a/llmc/eval/eval_video_generate.py +++ b/llmc/eval/eval_video_generate.py @@ -23,6 +23,7 @@ def __init__(self, model, config): self.target_width = self.eval_cfg.get('target_width', 832) self.num_frames = self.eval_cfg.get('num_frames', 81) self.guidance_scale = self.eval_cfg.get('guidance_scale', 5.0) + self.guidance_scale_2 = self.eval_cfg.get('guidance_scale_2', None) self.fps = self.eval_cfg.get('fps', 15) @torch.no_grad() @@ -56,14 +57,17 @@ def t2v_eval(self, model, testenc, bs, eval_pos): assert bs == 1, 'Only support eval bs=1' for i, data in enumerate(testenc): - output = model.Pipeline( - prompt=data['prompt'], - negative_prompt=data['negative_prompt'], - height=self.target_height, - width=self.target_width, - num_frames=self.num_frames, - guidance_scale=self.guidance_scale, - ).frames[0] + pipe_kw = { + 'prompt': data['prompt'], + 'negative_prompt': data['negative_prompt'], + 'height': self.target_height, + 'width': self.target_width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if self.guidance_scale_2 is not None: + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + output = model.Pipeline(**pipe_kw).frames[0] export_to_video( output, os.path.join(self.output_video_path, f'{eval_pos}_output_{i}.mp4'), @@ -77,15 +81,18 @@ def i2v_eval(self, model, testenc, bs, eval_pos): for i, data in enumerate(testenc): image, width, height = self.pre_process(model, data['image']) - output = model.Pipeline( - image=image, - prompt=data['prompt'], - negative_prompt=data['negative_prompt'], - height=height, - width=width, - num_frames=self.num_frames, - guidance_scale=self.guidance_scale, - ).frames[0] + pipe_kw = { + 'image': image, + 'prompt': data['prompt'], + 'negative_prompt': data['negative_prompt'], + 'height': height, + 'width': width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if self.guidance_scale_2 is not None: + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + output = model.Pipeline(**pipe_kw).frames[0] export_to_video( output, @@ -98,9 +105,9 @@ def i2v_eval(self, model, testenc, bs, eval_pos): @torch.no_grad() def eval_func(self, model, testenc, bs, eval_pos): assert bs == 1, 'Evaluation only supports batch size = 1.' - assert self.model_type in ['WanT2V', 'WanI2V'], ( + assert self.model_type in ['WanT2V', 'WanI2V', 'Wan2T2V'], ( f"Unsupported model type '{self.model_type}'.\n" - 'Only Wan2.1 video generation models (WanT2V, WanI2V) are supported.' + 'Only Wan video generation models (WanT2V, WanI2V, Wan2T2V) are supported.' ) if self.eval_dataset_name == 't2v': return self.t2v_eval(model, testenc, bs, eval_pos) diff --git a/llmc/models/__init__.py b/llmc/models/__init__.py index 83d746254..7351995df 100755 --- a/llmc/models/__init__.py +++ b/llmc/models/__init__.py @@ -37,3 +37,4 @@ from .vit import Vit from .wan_i2v import WanI2V from .wan_t2v import WanT2V +from .wan2_2_t2v import Wan2T2V diff --git a/llmc/models/base_model.py b/llmc/models/base_model.py index 4d7dda2ae..25393a871 100755 --- a/llmc/models/base_model.py +++ b/llmc/models/base_model.py @@ -119,7 +119,7 @@ def has_bias(self): pass def build_tokenizer(self): - if self.model_type not in ['Vit', 'WanT2V', 'WanI2V']: + if self.model_type not in ['Vit', 'WanT2V', 'WanI2V', 'Wan2T2V']: assert self.tokenizer_mode in ['fast', 'slow'] self.tokenizer = AutoTokenizer.from_pretrained( self.model_path, use_fast=self.tokenizer_mode, trust_remote_code=True @@ -129,7 +129,7 @@ def build_tokenizer(self): if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token else: - self.tokenizer = None + self.tokenizer = None def get_tokenizer(self): return self.tokenizer diff --git a/llmc/models/wan2_2_t2v.py b/llmc/models/wan2_2_t2v.py new file mode 100755 index 000000000..bf603c536 --- /dev/null +++ b/llmc/models/wan2_2_t2v.py @@ -0,0 +1,193 @@ +import inspect +from collections import defaultdict + +import torch +import torch.nn as nn +from diffusers import AutoencoderKLWan, WanPipeline +from loguru import logger + +from llmc.compression.quantization.module_utils import LlmcWanTransformerBlock +from llmc.utils.registry_factory import MODEL_REGISTRY + +from .base_model import BaseModel + + +@MODEL_REGISTRY +class Wan2T2V(BaseModel): + """Wan2.2-T2V with MoE: two experts (high-noise + low-noise), same block structure as Wan2.1.""" + + def __init__(self, config, device_map=None, use_cache=False): + super().__init__(config, device_map, use_cache) + if 'calib' in config: + self.calib_bs = config.calib.bs + self.sample_steps = config.calib.sample_steps + self.target_height = config.calib.get('target_height', 480) + self.target_width = config.calib.get('target_width', 832) + self.num_frames = config.calib.get('num_frames', 81) + self.guidance_scale = config.calib.get('guidance_scale', 5.0) + self.guidance_scale_2 = config.calib.get('guidance_scale_2', 3.0) + else: + self.sample_steps = None + + def build_model(self): + vae = AutoencoderKLWan.from_pretrained( + self.model_path, + subfolder='vae', + torch_dtype=torch.float32, + use_safetensors=True, + ) + # Wan2.2: one pipeline, two transformer experts (transformer + transformer_2). + # Pipeline switches by SNR; both use WanTransformer3DModel with same block layout as Wan2.1. + self.Pipeline = WanPipeline.from_pretrained( + self.model_path, + vae=vae, + torch_dtype=torch.bfloat16, + use_safetensors=True, + ) + self.find_llmc_model() + # Wrap both experts with LlmcWanTransformerBlock (same as Wan2.1 per-block layout). + for block_idx, block in enumerate(self.Pipeline.transformer.blocks): + new_block = LlmcWanTransformerBlock.new(block) + self.Pipeline.transformer.blocks[block_idx] = new_block + if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None: + for block_idx, block in enumerate(self.Pipeline.transformer_2.blocks): + new_block = LlmcWanTransformerBlock.new(block) + self.Pipeline.transformer_2.blocks[block_idx] = new_block + self.blocks = list(self.Pipeline.transformer.blocks) + list( + self.Pipeline.transformer_2.blocks + ) + logger.info( + 'Wan2.2 MoE: both experts wrapped (high-noise + low-noise, 80 blocks total).' + ) + else: + self.blocks = list(self.Pipeline.transformer.blocks) + logger.info('Wan2.2: single transformer wrapped (40 blocks).') + logger.info('Model: %s', self.model) + + def find_llmc_model(self): + self.model = self.Pipeline.transformer + + def find_blocks(self): + self.blocks = self.model.blocks + + def get_catcher(self, first_block_input): + sample_steps = self.sample_steps + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + self.signature = inspect.signature(module.forward) + self.step = 0 + + def forward(self, *args, **kwargs): + params = list(self.signature.parameters.keys()) + for i, arg in enumerate(args): + if i > 0: + kwargs[params[i]] = arg + first_block_input['data'].append(args[0]) + first_block_input['kwargs'].append(kwargs) + self.step += 1 + if self.step == sample_steps: + raise ValueError + else: + return self.module(*args) + + return Catcher + + @torch.no_grad() + def collect_first_block_input(self, calib_data, padding_mask=None): + first_block_input = defaultdict(list) + Catcher = self.get_catcher(first_block_input) + # Install Catcher on the pipeline's first block so forward passes go through it. + first_block = self.Pipeline.transformer.blocks[0] + self.Pipeline.transformer.blocks[0] = Catcher(first_block) + self.Pipeline.to('cuda') + for data in calib_data: + self.Pipeline.transformer.blocks[0].step = 0 + try: + pipe_kw = { + 'prompt': data['prompt'], + 'negative_prompt': data['negative_prompt'], + 'height': self.target_height, + 'width': self.target_width, + 'num_frames': self.num_frames, + 'guidance_scale': self.guidance_scale, + } + if hasattr(self, 'guidance_scale_2'): + pipe_kw['guidance_scale_2'] = self.guidance_scale_2 + self.Pipeline(**pipe_kw) + except ValueError: + pass + + self.first_block_input = first_block_input + assert len(self.first_block_input['data']) > 0, 'Catch input data failed.' + self.n_samples = len(self.first_block_input['data']) + logger.info('Retrieved %s calibration samples for Wan2.2 T2V.', self.n_samples) + self.Pipeline.transformer.blocks[0] = self.Pipeline.transformer.blocks[0].module + self.Pipeline.to('cpu') + + def get_padding_mask(self): + return None + + def has_bias(self): + return True + + def __str__(self): + return '\nWan2.2 MoE Model:\n%s\nTotal params: ~27B (14B active per step)' % ( + str(self.model), + ) + + def get_layernorms_in_block(self, block): + return { + 'affine_norm1': block.affine_norm1, + 'norm2': block.norm2, + 'affine_norm3': block.affine_norm3, + } + + def get_subsets_in_block(self, block): + return [ + { + 'layers': { + 'attn1.to_q': block.attn1.to_q, + 'attn1.to_k': block.attn1.to_k, + 'attn1.to_v': block.attn1.to_v, + }, + 'prev_op': [block.affine_norm1], + 'input': ['attn1.to_q'], + 'inspect': block.attn1, + 'has_kwargs': True, + 'sub_keys': {'rotary_emb': 'rotary_emb'}, + }, + { + 'layers': { + 'attn2.to_q': block.attn2.to_q, + }, + 'prev_op': [block.norm2], + 'input': ['attn2.to_q'], + 'inspect': block.attn2, + 'has_kwargs': True, + 'sub_keys': {'encoder_hidden_states': 'encoder_hidden_states'}, + }, + { + 'layers': { + 'ffn.net.0.proj': block.ffn.net[0].proj, + }, + 'prev_op': [block.affine_norm3], + 'input': ['ffn.net.0.proj'], + 'inspect': block.ffn, + 'has_kwargs': True, + }, + ] + + def find_embed_layers(self): + pass + + def get_embed_layers(self): + pass + + def get_layers_except_blocks(self): + pass + + def skip_layer_name(self): + pass diff --git a/llmc/models/wan_t2v.py b/llmc/models/wan_t2v.py index 885bccda3..ec1f0650c 100755 --- a/llmc/models/wan_t2v.py +++ b/llmc/models/wan_t2v.py @@ -31,10 +31,13 @@ def __init__(self, config, device_map=None, use_cache=False): def build_model(self): vae = AutoencoderKLWan.from_pretrained( - self.model_path, subfolder='vae', torch_dtype=torch.float32 + self.model_path, subfolder='vae', torch_dtype=torch.float32, use_safetensors=True ) + # self.Pipeline = WanPipeline.from_pretrained( + # self.model_path, vae=vae, torch_dtype=torch.bfloat16 + # ) self.Pipeline = WanPipeline.from_pretrained( - self.model_path, vae=vae, torch_dtype=torch.bfloat16 + self.model_path, vae=vae, torch_dtype=torch.bfloat16, use_safetensors=True ) self.find_llmc_model() self.find_blocks() diff --git a/scripts/run_llmc.sh b/scripts/run_llmc.sh index d90877f69..efc4141af 100755 --- a/scripts/run_llmc.sh +++ b/scripts/run_llmc.sh @@ -1,17 +1,20 @@ -#!/bin/bash - -# export CUDA_VISIBLE_DEVICES=0,1 - -llmc=/path/to/llmc +export PATH=/mnt/lm_data_afs/wangzining/charles/miniconda3/envs/llmc/bin:$PATH +export PYTHON=/mnt/lm_data_afs/wangzining/charles/miniconda3/envs/llmc/bin/python +export PIP=/mnt/lm_data_afs/wangzining/charles/miniconda3/envs/llmc/bin/pip +export HF_ENDPOINT=https://hf-mirror.com +cd /mnt/lm_data_afs/wangzining/charles/lab/llmc +# model_name=wan_t2v +model_name=wan2_2_t2v +task_name=awq_w_a +# task_name=awq_w_a_s +log_name=${model_name}_${task_name} +rm -rf ../lightx2v/${log_name}/x2v/lightx2v_quant_model +llmc=. export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=awq_w_only -config=${llmc}/configs/quantization/methods/Awq/awq_w_only.yml - +config=${llmc}/configs/quantization/video_gen/${model_name}/${task_name}.yaml nnodes=1 nproc_per_node=1 - find_unused_port() { while true; do port=$(shuf -i 10000-60000 -n 1) @@ -22,25 +25,15 @@ find_unused_port() { done } UNUSED_PORT=$(find_unused_port) - - MASTER_ADDR=127.0.0.1 MASTER_PORT=$UNUSED_PORT task_id=$UNUSED_PORT -nohup \ + torchrun \ --nnodes $nnodes \ --nproc_per_node $nproc_per_node \ --rdzv_id $task_id \ --rdzv_backend c10d \ --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ -${llmc}/llmc/__main__.py --config $config --task_id $task_id \ -> ${task_name}.log 2>&1 & - -sleep 2 -ps aux | grep '__main__.py' | grep $task_id | awk '{print $2}' > ${task_name}.pid - -# You can kill this program by -# xargs kill -9 < xxx.pid -# xxx.pid is ${task_name}.pid file \ No newline at end of file +${llmc}/llmc/__main__.py --config $config --task_id $task_id |tee ${log_name}.log \ No newline at end of file