Skip to content

Commit ea23498

Browse files
Wan I2V support (quic#788)
Support for the Wan image-to-video model. Model card: "Wan-AI/Wan2.2-I2V-A14B-Diffusers" --------- Signed-off-by: vtirumal <vtirumal@qti.qualcomm.com>
1 parent f668b40 commit ea23498

21 files changed

+2771
-21
lines changed

QEfficient/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from QEfficient.compile.compile_helper import compile
3232
from QEfficient.diffusers.pipelines.flux.pipeline_flux import QEffFluxPipeline
3333
from QEfficient.diffusers.pipelines.wan.pipeline_wan import QEffWanPipeline
34+
from QEfficient.diffusers.pipelines.wan.pipeline_wan_i2v import QEffWanImageToVideoPipeline
3435
from QEfficient.exporter.export_hf_to_cloud_ai_100 import qualcomm_efficient_converter
3536
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
3637
from QEfficient.peft import QEffAutoPeftModelForCausalLM
@@ -59,6 +60,7 @@
5960
"QEFFCommonLoader",
6061
"QEffFluxPipeline",
6162
"QEffWanPipeline",
63+
"QEffWanImageToVideoPipeline",
6264
]
6365

6466

QEfficient/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
#
66
# -----------------------------------------------------------------------------
77

8+
from typing import Optional
9+
810
import torch
911
from diffusers.models.autoencoders.autoencoder_kl_wan import (
12+
AutoencoderKLWan,
1013
WanDecoder3d,
1114
WanEncoder3d,
1215
WanResample,
@@ -16,8 +19,6 @@
1619

1720
CACHE_T = 2
1821

19-
modes = []
20-
2122
# Used max(0, x.shape[2] - CACHE_T) instead of CACHE_T because x.shape[2] is either 1 or 4,
2223
# and CACHE_T = 2. This ensures the value never goes negative
2324

@@ -58,7 +59,6 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
5859
x = x.reshape(b, c, t * 2, h, w)
5960
t = x.shape[2]
6061
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
61-
modes.append(self.mode)
6262
x = self.resample(x)
6363
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
6464

@@ -198,3 +198,48 @@ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
198198
else:
199199
x = self.conv_out(x)
200200
return x
201+
202+
203+
class QEffAutoencoderKLWan(AutoencoderKLWan):
204+
def encode(self, x: torch.Tensor) -> torch.Tensor:
205+
r"""
206+
Encode a batch of images into latents.
207+
208+
Args:
209+
x (`torch.Tensor`): Input batch of images.
210+
"""
211+
if self.use_slicing and x.shape[0] > 1:
212+
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
213+
h = torch.cat(encoded_slices)
214+
else:
215+
h = self._encode(x)
216+
return h
217+
218+
def forward(
219+
self,
220+
image: Optional[torch.Tensor] = None,
221+
latent_sample: Optional[torch.Tensor] = None,
222+
return_dict: bool = True,
223+
) -> torch.Tensor:
224+
r"""
225+
Forward pass through the VAE autoencoder with dual-mode functionality.
226+
This method automatically determines whether to perform encoding or decoding based on the provided inputs:
227+
- If `image` is provided, performs encoding (image → latent space)
228+
- If `latent_sample` is provided, performs decoding (latent space → image)
229+
230+
Args:
231+
image (`torch.Tensor`, *optional*): Input image tensor to encode into latent space.
232+
latent_sample (`torch.Tensor`, *optional*): input latent tensor to decode back to image space.
233+
If provided, `image` should be None.
234+
return_dict (`bool`, *optional*, defaults to `True`):
235+
Whether to return a dictionary with structured output or a raw tensor.
236+
Only applies to decoding operations.
237+
Returns:
238+
`torch.Tensor`:
239+
- If encoding: Latent representation of the input image
240+
- If decoding: Reconstructed image/video from latent representation
241+
"""
242+
if image is not None:
243+
return self.encode(image)
244+
else:
245+
return self.decode(latent_sample, return_dict)

QEfficient/diffusers/models/pytorch_transforms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# -----------------------------------------------------------------------------
77

88
from diffusers.models.autoencoders.autoencoder_kl_wan import (
9+
AutoencoderKLWan,
910
WanDecoder3d,
1011
WanEncoder3d,
1112
WanResample,
@@ -25,6 +26,7 @@
2526
from QEfficient.base.pytorch_transforms import ModuleMappingTransform
2627
from QEfficient.customop.rms_norm import CustomRMSNormAIC
2728
from QEfficient.diffusers.models.autoencoders.autoencoder_kl_wan import (
29+
QEffAutoencoderKLWan,
2830
QEffWanDecoder3d,
2931
QEffWanEncoder3d,
3032
QEffWanResample,
@@ -66,6 +68,7 @@ class AttentionTransform(ModuleMappingTransform):
6668
WanAttnProcessor: QEffWanAttnProcessor,
6769
WanAttention: QEffWanAttention,
6870
WanTransformer3DModel: QEffWanTransformer3DModel,
71+
AutoencoderKLWan: QEffAutoencoderKLWan,
6972
WanDecoder3d: QEffWanDecoder3d,
7073
WanEncoder3d: QEffWanEncoder3d,
7174
WanResidualBlock: QEffWanResidualBlock,
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# -----------------------------------------------------------------------------
2+
#
3+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# ----------------------------------------------------------------------------

QEfficient/diffusers/pipelines/configs/npi_wan_i2v_vae_encoder.yaml

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
{
2+
"description": "Default configuration for Wan image-to-video pipeline with unified transformer (model_type: 1 for high noise; model_type:2 for low noise)",
3+
"modules": {
4+
"vae_encoder":
5+
{
6+
"specializations":
7+
{
8+
"batch_size": 1,
9+
"num_channels": 16
10+
},
11+
"compilation":
12+
{
13+
"onnx_path": null,
14+
"compile_dir": null,
15+
"mdp_ts_num_devices": 8,
16+
"mxfp6_matmul": false,
17+
"convert_to_fp16": true,
18+
"aic_num_cores": 16,
19+
"aic-enable-depth-first": true,
20+
"compile_only":true,
21+
"mos": 1,
22+
"mdts_mos": 1,
23+
"node_precision_info" : "QEfficient/diffusers/pipelines/configs/npi_wan_i2v_vae_encoder.yaml"
24+
},
25+
"execute":
26+
{
27+
"device_ids": null,
28+
"qpc_path" : null
29+
}
30+
},
31+
"transformer": {
32+
"specializations": [
33+
{
34+
"batch_size": "1",
35+
"num_channels": "36",
36+
"steps": "1",
37+
"sequence_length": "512",
38+
"model_type": 1
39+
},
40+
{
41+
"batch_size": "1",
42+
"num_channels": "36",
43+
"steps": "1",
44+
"sequence_length": "512",
45+
"model_type": 2
46+
}
47+
],
48+
"compilation": {
49+
"onnx_path": null,
50+
"compile_dir": null,
51+
"mdp_ts_num_devices": 16,
52+
"mxfp6_matmul": true,
53+
"convert_to_fp16": true,
54+
"compile_only":true,
55+
"aic_num_cores": 16,
56+
"mos": 1,
57+
"mdts_mos": 1
58+
},
59+
"execute": {
60+
"device_ids": null,
61+
"qpc_path" : null
62+
}
63+
},
64+
"vae_decoder":
65+
{
66+
"specializations":
67+
{
68+
"batch_size": 1,
69+
"num_channels": 16
70+
},
71+
"compilation":
72+
{
73+
"onnx_path": null,
74+
"compile_dir": null,
75+
"mdp_ts_num_devices": 8,
76+
"mxfp6_matmul": false,
77+
"convert_to_fp16": true,
78+
"aic_num_cores": 16,
79+
"aic-enable-depth-first": true,
80+
"compile_only":true,
81+
"mos": 1,
82+
"mdts_mos": 1
83+
},
84+
"execute":
85+
{
86+
"device_ids": null,
87+
"qpc_path" : null
88+
}
89+
}
90+
91+
}
92+
}

QEfficient/diffusers/pipelines/pipeline_module.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,8 @@ def __init__(self, model: nn.Module, type: str) -> None:
247247
"""
248248
super().__init__(model)
249249
self.model = model
250-
251-
# To have different hashing for encoder/decoder
252-
self.model.config["type"] = type
250+
self.type = type
251+
# TODO: add vae type in hash file
253252

254253
def get_onnx_params(self, latent_height: int = 32, latent_width: int = 32) -> Tuple[Dict, Dict, List[str]]:
255254
"""
@@ -282,6 +281,43 @@ def get_onnx_params(self, latent_height: int = 32, latent_width: int = 32) -> Tu
282281

283282
return example_inputs, dynamic_axes, output_names
284283

284+
def get_img_encoder_onnx_params(self) -> Tuple[Dict, Dict, List[str]]:
285+
"""
286+
Generate ONNX export configuration for the VAE Encoder.
287+
288+
Returns:
289+
Tuple containing:
290+
- example_inputs (Dict): Sample inputs for ONNX export
291+
- dynamic_axes (Dict): Specification of dynamic dimensions
292+
- output_names (List[str]): Names of model outputs
293+
"""
294+
bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
295+
num_frames = constants.WAN_ONNX_EXPORT_FRAMES
296+
height = constants.WAN_ONNX_EXPORT_HEIGHT_45P
297+
width = constants.WAN_ONNX_EXPORT_WIDTH_45P
298+
example_inputs = {
299+
"image": torch.randn(
300+
bs,
301+
3, # channels
302+
num_frames,
303+
height,
304+
width,
305+
),
306+
}
307+
output_names = ["latents"]
308+
# All dimensions except channels can be dynamic
309+
dynamic_axes = {
310+
"image": {
311+
0: "batch_size",
312+
# 1: "num_channels",
313+
2: "num_frames",
314+
3: "height",
315+
4: "width",
316+
},
317+
}
318+
319+
return example_inputs, dynamic_axes, output_names
320+
285321
def get_video_onnx_params(self) -> Tuple[Dict, Dict, List[str]]:
286322
"""
287323
Generate ONNX export configuration for the VAE decoder.
@@ -298,8 +334,8 @@ def get_video_onnx_params(self) -> Tuple[Dict, Dict, List[str]]:
298334
"""
299335
bs = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
300336
latent_frames = constants.WAN_ONNX_EXPORT_LATENT_FRAMES
301-
latent_height = constants.WAN_ONNX_EXPORT_LATENT_HEIGHT_180P
302-
latent_width = constants.WAN_ONNX_EXPORT_LATENT_WIDTH_180P
337+
latent_height = constants.WAN_ONNX_EXPORT_LATENT_HEIGHT_45P
338+
latent_width = constants.WAN_ONNX_EXPORT_LATENT_WIDTH_45P
303339

304340
# VAE decoder takes latent representation as input
305341
example_inputs = {
@@ -568,8 +604,8 @@ def get_onnx_params(self):
568604
batch_size,
569605
self.model.config.in_channels,
570606
constants.WAN_ONNX_EXPORT_LATENT_FRAMES,
571-
constants.WAN_ONNX_EXPORT_LATENT_HEIGHT_180P,
572-
constants.WAN_ONNX_EXPORT_LATENT_WIDTH_180P,
607+
constants.WAN_ONNX_EXPORT_LATENT_HEIGHT_45P,
608+
constants.WAN_ONNX_EXPORT_LATENT_WIDTH_45P,
573609
dtype=torch.float32,
574610
),
575611
# encoder_hidden_states = [BS, seq len , text dim]
@@ -578,7 +614,7 @@ def get_onnx_params(self):
578614
),
579615
# Rotary position embeddings: [2, context_length, 1, rotary_dim]; 2 is from tuple of cos, sin freqs
580616
"rotary_emb": torch.randn(
581-
2, constants.WAN_ONNX_EXPORT_CL_180P, 1, constants.WAN_ONNX_EXPORT_ROTARY_DIM, dtype=torch.float32
617+
2, constants.WAN_ONNX_EXPORT_CL_45P, 1, constants.WAN_ONNX_EXPORT_ROTARY_DIM, dtype=torch.float32
582618
),
583619
# Timestep embeddings: [batch_size=1, embedding_dim]
584620
"temb": torch.randn(batch_size, constants.WAN_TEXT_EMBED_DIM, dtype=torch.float32),

QEfficient/diffusers/pipelines/pipeline_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,16 @@ def set_execute_params(cls):
131131
)
132132

133133

134+
def update_npi_path(cls, npi_full_path, module_name):
135+
"""To Set NPI for path in compilation config"""
136+
if module_name in cls.custom_config["modules"]:
137+
# Check if the NPI file exists
138+
if not os.path.exists(npi_full_path):
139+
raise FileNotFoundError(f"Node precision info file not found: {npi_full_path}")
140+
141+
cls.custom_config["modules"][module_name]["compilation"]["node_precision_info"] = npi_full_path
142+
143+
134144
def compile_modules_parallel(
135145
modules: Dict[str, Any],
136146
config: Dict[str, Any],

QEfficient/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,8 @@ def compile(
260260
self,
261261
compile_config: Optional[str] = None,
262262
parallel: bool = False,
263-
height: int = constants.WAN_ONNX_EXPORT_HEIGHT_180P,
264-
width: int = constants.WAN_ONNX_EXPORT_WIDTH_180P,
263+
height: int = constants.WAN_ONNX_EXPORT_HEIGHT_45P,
264+
width: int = constants.WAN_ONNX_EXPORT_WIDTH_45P,
265265
num_frames: int = constants.WAN_ONNX_EXPORT_FRAMES,
266266
use_onnx_subfunctions: bool = False,
267267
) -> str:

0 commit comments

Comments (0)