Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions configs/quantization/video_gen/wan2_2_t2v/awq_w_a.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# AWQ W4A4 quantization config for Wan2.2-T2V (MoE text-to-video, Diffusers layout).
base:
    seed: &seed 42
model:
    type: Wan2T2V
    # NOTE(review): machine-specific absolute path — point this at your local
    # Wan2.2-T2V-A14B-Diffusers checkout before running.
    path: /mnt/lm_data_afs/wangzining/charles/lab/llmc/models/Wan2.2-T2V-A14B-Diffusers
    torch_dtype: auto
calib:
    # Calibration prompts are read from `path`; no download is attempted.
    name: t2v
    download: False
    path: ./assets/wan_t2v/calib/
    # Number of denoising steps captured per calibration prompt.
    sample_steps: 20
    bs: 1
    target_height: 480
    target_width: 832
    num_frames: 81
    guidance_scale: 5.0
    seed: *seed
eval:
    # Generate videos both after the transform pass and after fake quantization.
    eval_pos: [transformed, fake_quant]
    type: video_gen
    name: t2v
    download: False
    path: ./assets/wan_t2v/calib/
    bs: 1
    target_height: 480
    target_width: 832
    num_frames: 81
    guidance_scale: 5.0
    output_video_path: ./output_videos_awq/
quant:
    video_gen:
        method: Awq
        # W4A4: 4-bit symmetric per-channel weights (group_size -1 = whole channel),
        # 4-bit symmetric per-token activations.
        weight:
            bit: 4
            symmetric: True
            granularity: per_channel
            group_size: -1
        act:
            bit: 4
            symmetric: True
            granularity: per_token
        special:
            trans: True
            trans_version: v2
            weight_clip: True
            clip_sym: True
save:
    save_lightx2v: True
    save_path: ../lightx2v/wan2_2_t2v_awq_w_a/x2v/
45 changes: 26 additions & 19 deletions llmc/eval/eval_video_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self, model, config):
self.target_width = self.eval_cfg.get('target_width', 832)
self.num_frames = self.eval_cfg.get('num_frames', 81)
self.guidance_scale = self.eval_cfg.get('guidance_scale', 5.0)
self.guidance_scale_2 = self.eval_cfg.get('guidance_scale_2', None)
self.fps = self.eval_cfg.get('fps', 15)

@torch.no_grad()
Expand Down Expand Up @@ -56,14 +57,17 @@ def t2v_eval(self, model, testenc, bs, eval_pos):
assert bs == 1, 'Only support eval bs=1'

for i, data in enumerate(testenc):
output = model.Pipeline(
prompt=data['prompt'],
negative_prompt=data['negative_prompt'],
height=self.target_height,
width=self.target_width,
num_frames=self.num_frames,
guidance_scale=self.guidance_scale,
).frames[0]
pipe_kw = {
'prompt': data['prompt'],
'negative_prompt': data['negative_prompt'],
'height': self.target_height,
'width': self.target_width,
'num_frames': self.num_frames,
'guidance_scale': self.guidance_scale,
}
if self.guidance_scale_2 is not None:
pipe_kw['guidance_scale_2'] = self.guidance_scale_2
output = model.Pipeline(**pipe_kw).frames[0]
export_to_video(
output,
os.path.join(self.output_video_path, f'{eval_pos}_output_{i}.mp4'),
Expand All @@ -77,15 +81,18 @@ def i2v_eval(self, model, testenc, bs, eval_pos):
for i, data in enumerate(testenc):
image, width, height = self.pre_process(model, data['image'])

output = model.Pipeline(
image=image,
prompt=data['prompt'],
negative_prompt=data['negative_prompt'],
height=height,
width=width,
num_frames=self.num_frames,
guidance_scale=self.guidance_scale,
).frames[0]
pipe_kw = {
'image': image,
'prompt': data['prompt'],
'negative_prompt': data['negative_prompt'],
'height': height,
'width': width,
'num_frames': self.num_frames,
'guidance_scale': self.guidance_scale,
}
if self.guidance_scale_2 is not None:
pipe_kw['guidance_scale_2'] = self.guidance_scale_2
output = model.Pipeline(**pipe_kw).frames[0]

export_to_video(
output,
Expand All @@ -98,9 +105,9 @@ def i2v_eval(self, model, testenc, bs, eval_pos):
@torch.no_grad()
def eval_func(self, model, testenc, bs, eval_pos):
assert bs == 1, 'Evaluation only supports batch size = 1.'
assert self.model_type in ['WanT2V', 'WanI2V'], (
assert self.model_type in ['WanT2V', 'WanI2V', 'Wan2T2V'], (
f"Unsupported model type '{self.model_type}'.\n"
'Only Wan2.1 video generation models (WanT2V, WanI2V) are supported.'
'Only Wan video generation models (WanT2V, WanI2V, Wan2T2V) are supported.'
)
if self.eval_dataset_name == 't2v':
return self.t2v_eval(model, testenc, bs, eval_pos)
Expand Down
1 change: 1 addition & 0 deletions llmc/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@
from .vit import Vit
from .wan_i2v import WanI2V
from .wan_t2v import WanT2V
from .wan2_2_t2v import Wan2T2V
4 changes: 2 additions & 2 deletions llmc/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def has_bias(self):
pass

def build_tokenizer(self):
if self.model_type not in ['Vit', 'WanT2V', 'WanI2V']:
if self.model_type not in ['Vit', 'WanT2V', 'WanI2V', 'Wan2T2V']:
assert self.tokenizer_mode in ['fast', 'slow']
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_path, use_fast=self.tokenizer_mode, trust_remote_code=True
Expand All @@ -129,7 +129,7 @@ def build_tokenizer(self):
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
else:
self.tokenizer = None
self.tokenizer = None

def get_tokenizer(self):
return self.tokenizer
Expand Down
193 changes: 193 additions & 0 deletions llmc/models/wan2_2_t2v.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import inspect
from collections import defaultdict

import torch
import torch.nn as nn
from diffusers import AutoencoderKLWan, WanPipeline
from loguru import logger

from llmc.compression.quantization.module_utils import LlmcWanTransformerBlock
from llmc.utils.registry_factory import MODEL_REGISTRY

from .base_model import BaseModel


@MODEL_REGISTRY
class Wan2T2V(BaseModel):
    """Wan2.2-T2V with MoE: two experts (high-noise + low-noise), same block structure as Wan2.1."""

    class _CalibrationDone(Exception):
        """Internal sentinel raised by the calibration Catcher to abort the
        denoising loop once ``sample_steps`` inputs have been captured.

        Using a dedicated exception (instead of a bare ``ValueError``) ensures
        genuine ``ValueError``s raised inside the pipeline are not swallowed.
        """

    def __init__(self, config, device_map=None, use_cache=False):
        """Read calibration hyper-parameters from the ``calib:`` config section.

        When no ``calib`` section is present only ``sample_steps`` is set (to
        ``None``); the other calibration attributes are left undefined and must
        not be read outside the calibration path.
        """
        super().__init__(config, device_map, use_cache)
        if 'calib' in config:
            self.calib_bs = config.calib.bs
            self.sample_steps = config.calib.sample_steps
            self.target_height = config.calib.get('target_height', 480)
            self.target_width = config.calib.get('target_width', 832)
            self.num_frames = config.calib.get('num_frames', 81)
            self.guidance_scale = config.calib.get('guidance_scale', 5.0)
            # Wan2.2 pipelines accept a second guidance scale for the
            # low-noise expert; default mirrors the upstream recommendation.
            self.guidance_scale_2 = config.calib.get('guidance_scale_2', 3.0)
        else:
            self.sample_steps = None

    def build_model(self):
        """Load the Wan2.2 pipeline and wrap every transformer block.

        Each block of both experts is replaced by ``LlmcWanTransformerBlock``
        so downstream quantization passes can hook it; ``self.blocks`` holds
        the concatenated block list (high-noise first, then low-noise).
        """
        # The VAE is kept in float32 for numerical stability of decode.
        vae = AutoencoderKLWan.from_pretrained(
            self.model_path,
            subfolder='vae',
            torch_dtype=torch.float32,
            use_safetensors=True,
        )
        # Wan2.2: one pipeline, two transformer experts (transformer + transformer_2).
        # The pipeline switches experts by SNR; both use WanTransformer3DModel
        # with the same per-block layout as Wan2.1.
        self.Pipeline = WanPipeline.from_pretrained(
            self.model_path,
            vae=vae,
            torch_dtype=torch.bfloat16,
            use_safetensors=True,
        )
        self.find_llmc_model()
        # Wrap both experts with LlmcWanTransformerBlock (same per-block layout as Wan2.1).
        for block_idx, block in enumerate(self.Pipeline.transformer.blocks):
            self.Pipeline.transformer.blocks[block_idx] = LlmcWanTransformerBlock.new(block)
        if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None:
            for block_idx, block in enumerate(self.Pipeline.transformer_2.blocks):
                self.Pipeline.transformer_2.blocks[block_idx] = LlmcWanTransformerBlock.new(block)
            self.blocks = list(self.Pipeline.transformer.blocks) + list(
                self.Pipeline.transformer_2.blocks
            )
            # loguru interpolates extra args with `{}`-style (str.format)
            # placeholders, not printf `%s`; the count is computed rather than
            # hard-coded so non-default checkpoints log correctly.
            logger.info(
                'Wan2.2 MoE: both experts wrapped (high-noise + low-noise, {} blocks total).',
                len(self.blocks),
            )
        else:
            self.blocks = list(self.Pipeline.transformer.blocks)
            logger.info('Wan2.2: single transformer wrapped ({} blocks).', len(self.blocks))
        logger.info('Model: {}', self.model)

    def find_llmc_model(self):
        # The high-noise expert is the model exposed to llmc passes.
        self.model = self.Pipeline.transformer

    def find_blocks(self):
        # NOTE(review): this resets self.blocks to the first expert only,
        # while build_model may have set the combined two-expert list —
        # confirm which view callers of find_blocks expect.
        self.blocks = self.model.blocks

    def get_catcher(self, first_block_input):
        """Build a ``Catcher`` wrapper class for the first transformer block.

        The returned ``nn.Module`` subclass records the block's inputs into
        ``first_block_input`` (positional args beyond the first are normalized
        into kwargs so they can be replayed uniformly) and aborts the pipeline
        with ``_CalibrationDone`` after ``sample_steps`` forward passes.
        """
        sample_steps = self.sample_steps
        done_exc = Wan2T2V._CalibrationDone

        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module
                self.signature = inspect.signature(module.forward)
                self.step = 0

            def forward(self, *args, **kwargs):
                # Map every positional arg beyond hidden_states onto its
                # parameter name so the recorded kwargs are self-contained.
                params = list(self.signature.parameters.keys())
                for i, arg in enumerate(args):
                    if i > 0:
                        kwargs[params[i]] = arg
                # assumes the pipeline always passes hidden_states positionally
                first_block_input['data'].append(args[0])
                first_block_input['kwargs'].append(kwargs)
                self.step += 1
                if self.step == sample_steps:
                    raise done_exc
                # After normalization every non-first positional arg lives in
                # kwargs, so forward only args[0]. (Passing *args again would
                # duplicate them, and the previous `self.module(*args)` call
                # silently dropped the caller's original keyword arguments.)
                return self.module(args[0], **kwargs)

        return Catcher

    @torch.no_grad()
    def collect_first_block_input(self, calib_data, padding_mask=None):
        """Run the pipeline on ``calib_data`` and capture first-block inputs.

        Installs a ``Catcher`` on the pipeline's first block, runs each
        calibration prompt for ``sample_steps`` denoising steps, then restores
        the original block and moves the pipeline back to CPU. Results are
        stored in ``self.first_block_input`` / ``self.n_samples``.
        """
        first_block_input = defaultdict(list)
        Catcher = self.get_catcher(first_block_input)
        # Install the Catcher so forward passes are recorded.
        # NOTE(review): only the high-noise expert (`transformer`) is
        # instrumented; `transformer_2` inputs are not captured here — confirm
        # this is intended for MoE calibration.
        first_block = self.Pipeline.transformer.blocks[0]
        self.Pipeline.transformer.blocks[0] = Catcher(first_block)
        self.Pipeline.to('cuda')
        for data in calib_data:
            # Reset the per-prompt step counter before each generation.
            self.Pipeline.transformer.blocks[0].step = 0
            try:
                pipe_kw = {
                    'prompt': data['prompt'],
                    'negative_prompt': data['negative_prompt'],
                    'height': self.target_height,
                    'width': self.target_width,
                    'num_frames': self.num_frames,
                    'guidance_scale': self.guidance_scale,
                }
                if hasattr(self, 'guidance_scale_2'):
                    pipe_kw['guidance_scale_2'] = self.guidance_scale_2
                self.Pipeline(**pipe_kw)
            except Wan2T2V._CalibrationDone:
                # Expected control flow: the Catcher aborts generation once
                # enough steps have been captured.
                pass

        self.first_block_input = first_block_input
        assert len(self.first_block_input['data']) > 0, 'Catch input data failed.'
        self.n_samples = len(self.first_block_input['data'])
        # loguru uses `{}` placeholders (see build_model).
        logger.info('Retrieved {} calibration samples for Wan2.2 T2V.', self.n_samples)
        # Unwrap the Catcher to restore the original first block, then free GPU memory.
        self.Pipeline.transformer.blocks[0] = self.Pipeline.transformer.blocks[0].module
        self.Pipeline.to('cpu')

    def get_padding_mask(self):
        # Video generation calibration uses no padding mask.
        return None

    def has_bias(self):
        # Wan transformer linear layers carry biases.
        return True

    def __str__(self):
        return '\nWan2.2 MoE Model:\n%s\nTotal params: ~27B (14B active per step)' % (
            str(self.model),
        )

    def get_layernorms_in_block(self, block):
        """Return the norm layers of one wrapped block, keyed by attribute name."""
        return {
            'affine_norm1': block.affine_norm1,
            'norm2': block.norm2,
            'affine_norm3': block.affine_norm3,
        }

    def get_subsets_in_block(self, block):
        """Describe the quantizable linear subsets of one wrapped block.

        Each entry names the target linears, the preceding op whose output
        feeds them (for scale migration), the input capture point, and the
        module to inspect during calibration.
        """
        return [
            {
                # Self-attention QKV projections, fed by affine_norm1.
                'layers': {
                    'attn1.to_q': block.attn1.to_q,
                    'attn1.to_k': block.attn1.to_k,
                    'attn1.to_v': block.attn1.to_v,
                },
                'prev_op': [block.affine_norm1],
                'input': ['attn1.to_q'],
                'inspect': block.attn1,
                'has_kwargs': True,
                'sub_keys': {'rotary_emb': 'rotary_emb'},
            },
            {
                # Cross-attention query projection, fed by norm2.
                'layers': {
                    'attn2.to_q': block.attn2.to_q,
                },
                'prev_op': [block.norm2],
                'input': ['attn2.to_q'],
                'inspect': block.attn2,
                'has_kwargs': True,
                'sub_keys': {'encoder_hidden_states': 'encoder_hidden_states'},
            },
            {
                # Feed-forward input projection, fed by affine_norm3.
                'layers': {
                    'ffn.net.0.proj': block.ffn.net[0].proj,
                },
                'prev_op': [block.affine_norm3],
                'input': ['ffn.net.0.proj'],
                'inspect': block.ffn,
                'has_kwargs': True,
            },
        ]

    def find_embed_layers(self):
        # Not applicable to the diffusion transformer path.
        pass

    def get_embed_layers(self):
        # Not applicable to the diffusion transformer path.
        pass

    def get_layers_except_blocks(self):
        # Only transformer blocks are quantized for this model.
        pass

    def skip_layer_name(self):
        # No layers are explicitly skipped.
        pass