Skip to content

Commit e205446

Browse files
authored
Merge branch 'quic:main' into qwen3_vl
2 parents 1725b12 + e8e5c43 commit e205446

File tree

10 files changed

+107
-101
lines changed

10 files changed

+107
-101
lines changed

QEfficient/diffusers/models/transformers/transformer_flux.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
# SPDX-License-Identifier: BSD-3-Clause
55
#
66
# ----------------------------------------------------------------------------
7-
from typing import Any, Dict, Optional, Tuple, Union
7+
from typing import Any, Dict, Optional, Tuple, Type, Union
88

99
import numpy as np
1010
import torch
11+
import torch.nn as nn
1112
from diffusers.models.modeling_outputs import Transformer2DModelOutput
1213
from diffusers.models.transformers.transformer_flux import (
1314
FluxAttention,
@@ -221,6 +222,15 @@ def forward(
221222

222223

223224
class QEffFluxTransformer2DModel(FluxTransformer2DModel):
225+
def get_submodules_for_export(self) -> Type[nn.Module]:
226+
"""
227+
Return the set of classes used as the repeated layers across the model for subfunction extraction.
228+
Notes:
229+
This method should return the *class objects* (not instances).
230+
Downstream code can use this to find/build subfunctions for repeated blocks.
231+
"""
232+
return {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock}
233+
224234
def forward(
225235
self,
226236
hidden_states: torch.Tensor,

QEfficient/diffusers/models/transformers/transformer_wan.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,17 @@
1313
and combined QKV-blocking.
1414
"""
1515

16-
from typing import Any, Dict, List, Optional, Tuple, Union
16+
from typing import Any, Dict, List, Optional, Tuple, Type, Union
1717

1818
import torch
19+
import torch.nn as nn
1920
from diffusers.loaders.peft import _SET_ADAPTER_SCALE_FN_MAPPING
2021
from diffusers.models.modeling_outputs import Transformer2DModelOutput
2122
from diffusers.models.transformers.transformer_wan import (
2223
WanAttention,
2324
WanAttnProcessor,
2425
WanTransformer3DModel,
26+
WanTransformerBlock,
2527
_get_qkv_projections,
2628
)
2729
from diffusers.utils import set_weights_and_activate_adapters
@@ -289,3 +291,78 @@ def forward(
289291
return (output,)
290292

291293
return Transformer2DModelOutput(sample=output)
294+
295+
296+
class QEffWanUnifiedWrapper(nn.Module):
297+
"""
298+
A wrapper class that combines WAN high and low noise transformers into a single unified transformer.
299+
300+
This wrapper dynamically selects between high and low noise transformers based on the timestep shape
301+
in the ONNX graph during inference. This approach enables efficient deployment of both transformer
302+
variants in a single model.
303+
304+
Attributes:
305+
transformer_high (nn.Module): The high noise transformer component
306+
transformer_low (nn.Module): The low noise transformer component
307+
config: Configuration shared between both transformers (from high noise transformer)
308+
"""
309+
310+
def __init__(self, transformer_high, transformer_low):
311+
super().__init__()
312+
self.transformer_high = transformer_high
313+
self.transformer_low = transformer_low
314+
# Both high and low noise transformers share the same configuration
315+
self.config = transformer_high.config
316+
317+
def get_submodules_for_export(self) -> Type[nn.Module]:
318+
"""
319+
Return the set of classes used as the repeated layers across the model for subfunction extraction.
320+
Notes:
321+
This method should return the *class objects* (not instances).
322+
Downstream code can use this to find/build subfunctions for repeated blocks.
323+
"""
324+
return {WanTransformerBlock}
325+
326+
def forward(
327+
self,
328+
hidden_states,
329+
encoder_hidden_states,
330+
rotary_emb,
331+
temb,
332+
timestep_proj,
333+
tsp,
334+
attention_kwargs=None,
335+
return_dict=False,
336+
):
337+
# Condition based on timestep shape
338+
is_high_noise = tsp.shape[0] == torch.tensor(1)
339+
340+
high_hs = hidden_states.detach()
341+
ehs = encoder_hidden_states.detach()
342+
rhs = rotary_emb.detach()
343+
ths = temb.detach()
344+
projhs = timestep_proj.detach()
345+
346+
noise_pred_high = self.transformer_high(
347+
hidden_states=high_hs,
348+
encoder_hidden_states=ehs,
349+
rotary_emb=rhs,
350+
temb=ths,
351+
timestep_proj=projhs,
352+
attention_kwargs=attention_kwargs,
353+
return_dict=return_dict,
354+
)[0]
355+
356+
noise_pred_low = self.transformer_low(
357+
hidden_states=hidden_states,
358+
encoder_hidden_states=encoder_hidden_states,
359+
rotary_emb=rotary_emb,
360+
temb=temb,
361+
timestep_proj=timestep_proj,
362+
attention_kwargs=attention_kwargs,
363+
return_dict=return_dict,
364+
)[0]
365+
366+
# Select based on timestep condition
367+
noise_pred = torch.where(is_high_noise, noise_pred_high, noise_pred_low)
368+
return noise_pred

QEfficient/diffusers/pipelines/pipeline_module.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99

1010
import torch
1111
import torch.nn as nn
12-
from diffusers.models.transformers.transformer_wan import WanTransformerBlock
1312

1413
from QEfficient.base.modeling_qeff import QEFFBaseModel
1514
from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
@@ -18,10 +17,6 @@
1817
CustomOpsTransform,
1918
NormalizationTransform,
2019
)
21-
from QEfficient.diffusers.models.transformers.transformer_flux import (
22-
QEffFluxSingleTransformerBlock,
23-
QEffFluxTransformerBlock,
24-
)
2520
from QEfficient.transformers.models.pytorch_transforms import (
2621
T5ModelTransform,
2722
)
@@ -475,7 +470,6 @@ def export(
475470
output_names: List[str],
476471
dynamic_axes: Dict,
477472
export_dir: str = None,
478-
export_kwargs: Dict = {},
479473
use_onnx_subfunctions: bool = False,
480474
) -> str:
481475
"""
@@ -486,30 +480,22 @@ def export(
486480
output_names (List[str]): Names of model outputs
487481
dynamic_axes (Dict): Specification of dynamic dimensions
488482
export_dir (str, optional): Directory to save ONNX model
489-
export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions)
490483
use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
491484
for better modularity and potential optimization
492485
493486
Returns:
494487
str: Path to the exported ONNX model
495488
"""
496489

497-
if use_onnx_subfunctions:
498-
export_kwargs = {
499-
"export_modules_as_functions": {QEffFluxTransformerBlock, QEffFluxSingleTransformerBlock},
500-
"use_onnx_subfunctions": True,
501-
}
502-
503490
# Sort _use_default_values in config to ensure consistent hash generation during export
504491
self.model.config["_use_default_values"].sort()
505-
506492
return self._export(
507493
example_inputs=inputs,
508494
output_names=output_names,
509495
dynamic_axes=dynamic_axes,
510496
export_dir=export_dir,
497+
use_onnx_subfunctions=use_onnx_subfunctions,
511498
offload_pt_weights=False, # As weights are needed with AdaLN changes
512-
**export_kwargs,
513499
)
514500

515501
def compile(self, specializations: List[Dict], **compiler_options) -> None:
@@ -631,7 +617,6 @@ def export(
631617
output_names: List[str],
632618
dynamic_axes: Dict,
633619
export_dir: str = None,
634-
export_kwargs: Dict = {},
635620
use_onnx_subfunctions: bool = False,
636621
) -> str:
637622
"""Export the Wan transformer model to ONNX format.
@@ -641,22 +626,19 @@ def export(
641626
output_names (List[str]): Names of model outputs
642627
dynamic_axes (Dict): Specification of dynamic dimensions
643628
export_dir (str, optional): Directory to save ONNX model
644-
export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions)
645629
use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
646630
for better modularity and potential optimization
647631
Returns:
648632
str: Path to the exported ONNX model
649633
"""
650-
if use_onnx_subfunctions:
651-
export_kwargs = {"export_modules_as_functions": {WanTransformerBlock}, "use_onnx_subfunctions": True}
652634

653635
return self._export(
654636
example_inputs=inputs,
655637
output_names=output_names,
656638
dynamic_axes=dynamic_axes,
657639
export_dir=export_dir,
658640
offload_pt_weights=True,
659-
**export_kwargs,
641+
use_onnx_subfunctions=use_onnx_subfunctions,
660642
)
661643

662644
def compile(self, specializations, **compiler_options) -> None:

QEfficient/diffusers/pipelines/pipeline_utils.py

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313

1414
import numpy as np
1515
import PIL.Image
16-
import torch
17-
import torch.nn as nn
1816
from tqdm import tqdm
1917

2018
from QEfficient.utils._utils import load_json
@@ -297,69 +295,3 @@ def __repr__(self):
297295
# List of module name that require special handling during export
298296
# when use_onnx_subfunctions is enabled
299297
ONNX_SUBFUNCTION_MODULE = ["transformer"]
300-
301-
302-
class QEffWanUnifiedWrapper(nn.Module):
303-
"""
304-
A wrapper class that combines WAN high and low noise transformers into a single unified transformer.
305-
306-
This wrapper dynamically selects between high and low noise transformers based on the timestep shape
307-
in the ONNX graph during inference. This approach enables efficient deployment of both transformer
308-
variants in a single model.
309-
310-
Attributes:
311-
transformer_high (nn.Module): The high noise transformer component
312-
transformer_low (nn.Module): The low noise transformer component
313-
config: Configuration shared between both transformers (from high noise transformer)
314-
"""
315-
316-
def __init__(self, transformer_high, transformer_low):
317-
super().__init__()
318-
self.transformer_high = transformer_high
319-
self.transformer_low = transformer_low
320-
# Both high and low noise transformers share the same configuration
321-
self.config = transformer_high.config
322-
323-
def forward(
324-
self,
325-
hidden_states,
326-
encoder_hidden_states,
327-
rotary_emb,
328-
temb,
329-
timestep_proj,
330-
tsp,
331-
attention_kwargs=None,
332-
return_dict=False,
333-
):
334-
# Condition based on timestep shape
335-
is_high_noise = tsp.shape[0] == torch.tensor(1)
336-
337-
high_hs = hidden_states.detach()
338-
ehs = encoder_hidden_states.detach()
339-
rhs = rotary_emb.detach()
340-
ths = temb.detach()
341-
projhs = timestep_proj.detach()
342-
343-
noise_pred_high = self.transformer_high(
344-
hidden_states=high_hs,
345-
encoder_hidden_states=ehs,
346-
rotary_emb=rhs,
347-
temb=ths,
348-
timestep_proj=projhs,
349-
attention_kwargs=attention_kwargs,
350-
return_dict=return_dict,
351-
)[0]
352-
353-
noise_pred_low = self.transformer_low(
354-
hidden_states=hidden_states,
355-
encoder_hidden_states=encoder_hidden_states,
356-
rotary_emb=rotary_emb,
357-
temb=temb,
358-
timestep_proj=timestep_proj,
359-
attention_kwargs=attention_kwargs,
360-
return_dict=return_dict,
361-
)[0]
362-
363-
# Select based on timestep condition
364-
noise_pred = torch.where(is_high_noise, noise_pred_high, noise_pred_low)
365-
return noise_pred

QEfficient/diffusers/pipelines/wan/pipeline_wan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@
2323
from diffusers import WanPipeline
2424
from tqdm import tqdm
2525

26+
from QEfficient.diffusers.models.transformers.transformer_wan import QEffWanUnifiedWrapper
2627
from QEfficient.diffusers.pipelines.pipeline_module import QEffVAE, QEffWanUnifiedTransformer
2728
from QEfficient.diffusers.pipelines.pipeline_utils import (
2829
ONNX_SUBFUNCTION_MODULE,
2930
ModulePerf,
3031
QEffPipelineOutput,
31-
QEffWanUnifiedWrapper,
3232
calculate_latent_dimensions_with_frames,
3333
compile_modules_parallel,
3434
compile_modules_sequential,

QEfficient/utils/export_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,6 @@ def _setup_onnx_subfunctions(qeff_model, args, kwargs):
179179
qeff_model._onnx_transforms.append(RenameFunctionOutputsTransform)
180180
qeff_model._onnx_transforms.append(CustomOpTransform)
181181

182-
# TODO: Handle this in the modelling class QEFFTransformersBase,remove from here. Refer diffusers implementation
183182
submodule_classes = qeff_model.model.get_submodules_for_export()
184183
if submodule_classes:
185184
kwargs["export_modules_as_functions"] = submodule_classes

QEfficient/utils/torch_patches.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _track_module_attributes_forward_hook(module, input, output):
4040
onnx_attrs = getattr(module, attr_name)
4141
delattr(module, attr_name)
4242
try:
43+
onnx_attrs = {} # HACK: to reduce export time # TODO: study behaviour across models
4344
_C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
4445
except Exception:
4546
logger.warning("Failed to track ONNX scope attributes, Skipping this step.")

examples/diffusers/wan/wan_lightning.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def load_wan_lora(path: str):
5252
generator=torch.manual_seed(0),
5353
height=480,
5454
width=832,
55-
use_onnx_subfunctions=False,
55+
use_onnx_subfunctions=True,
5656
parallel_compile=True,
5757
)
5858
frames = output.images[0]

tests/diffusers/flux_test_config.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
"height": 256,
44
"width": 256,
55
"num_transformer_layers": 2,
6-
"num_single_layers": 2,
7-
"use_onnx_subfunctions": false
6+
"num_single_layers": 2
87
},
98
"mad_validation": {
109
"tolerances": {
@@ -21,7 +20,8 @@
2120
"max_sequence_length": 256,
2221
"validate_gen_img": true,
2322
"min_image_variance": 1.0,
24-
"custom_config_path": null
23+
"custom_config_path": null,
24+
"use_onnx_subfunctions": true
2525
},
2626
"validation_checks": {
2727
"image_generation": true,

tests/diffusers/test_flux.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def flux_pipeline_call_with_mad_validation(
5656
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
5757
max_sequence_length: int = 512,
5858
custom_config_path: Optional[str] = None,
59+
use_onnx_subfunctions: bool = False,
5960
parallel_compile: bool = False,
6061
mad_tolerances: Dict[str, float] = None,
6162
):
@@ -72,7 +73,13 @@ def flux_pipeline_call_with_mad_validation(
7273
device = "cpu"
7374

7475
# Step 1: Load configuration, compile models
75-
pipeline.compile(compile_config=custom_config_path, parallel=parallel_compile, height=height, width=width)
76+
pipeline.compile(
77+
compile_config=custom_config_path,
78+
parallel=parallel_compile,
79+
use_onnx_subfunctions=use_onnx_subfunctions,
80+
height=height,
81+
width=width,
82+
)
7683

7784
# Validate all inputs
7885
pipeline.model.check_inputs(
@@ -307,10 +314,7 @@ def flux_pipeline():
307314
"""Setup compiled Flux pipeline for testing"""
308315
config = INITIAL_TEST_CONFIG["model_setup"]
309316

310-
pipeline = QEffFluxPipeline.from_pretrained(
311-
"black-forest-labs/FLUX.1-schnell",
312-
use_onnx_subfunctions=config["use_onnx_subfunctions"],
313-
)
317+
pipeline = QEffFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
314318

315319
# Reduce to 2 layers for testing
316320
original_blocks = pipeline.transformer.model.transformer_blocks
@@ -382,6 +386,7 @@ def test_flux_pipeline(flux_pipeline):
382386
custom_config_path=CONFIG_PATH,
383387
generator=generator,
384388
mad_tolerances=config["mad_validation"]["tolerances"],
389+
use_onnx_subfunctions=config["pipeline_params"]["use_onnx_subfunctions"],
385390
parallel_compile=True,
386391
return_dict=True,
387392
)

0 commit comments

Comments
 (0)