99
1010import torch
1111import torch .nn as nn
12- from diffusers .models .transformers .transformer_wan import WanTransformerBlock
1312
1413from QEfficient .base .modeling_qeff import QEFFBaseModel
1514from QEfficient .base .onnx_transforms import FP16ClipTransform , SplitTensorsTransform
1817 CustomOpsTransform ,
1918 NormalizationTransform ,
2019)
21- from QEfficient .diffusers .models .transformers .transformer_flux import (
22- QEffFluxSingleTransformerBlock ,
23- QEffFluxTransformerBlock ,
24- )
2520from QEfficient .transformers .models .pytorch_transforms import (
2621 T5ModelTransform ,
2722)
@@ -475,7 +470,6 @@ def export(
475470 output_names : List [str ],
476471 dynamic_axes : Dict ,
477472 export_dir : str = None ,
478- export_kwargs : Dict = {},
479473 use_onnx_subfunctions : bool = False ,
480474 ) -> str :
481475 """
@@ -486,30 +480,22 @@ def export(
486480 output_names (List[str]): Names of model outputs
487481 dynamic_axes (Dict): Specification of dynamic dimensions
488482 export_dir (str, optional): Directory to save ONNX model
489- export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions)
490483 use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
491484 for better modularity and potential optimization
492485
493486 Returns:
494487 str: Path to the exported ONNX model
495488 """
496489
497- if use_onnx_subfunctions :
498- export_kwargs = {
499- "export_modules_as_functions" : {QEffFluxTransformerBlock , QEffFluxSingleTransformerBlock },
500- "use_onnx_subfunctions" : True ,
501- }
502-
503490 # Sort _use_default_values in config to ensure consistent hash generation during export
504491 self .model .config ["_use_default_values" ].sort ()
505-
506492 return self ._export (
507493 example_inputs = inputs ,
508494 output_names = output_names ,
509495 dynamic_axes = dynamic_axes ,
510496 export_dir = export_dir ,
497+ use_onnx_subfunctions = use_onnx_subfunctions ,
511498 offload_pt_weights = False , # As weights are needed with AdaLN changes
512- ** export_kwargs ,
513499 )
514500
515501 def compile (self , specializations : List [Dict ], ** compiler_options ) -> None :
@@ -631,7 +617,6 @@ def export(
631617 output_names : List [str ],
632618 dynamic_axes : Dict ,
633619 export_dir : str = None ,
634- export_kwargs : Dict = {},
635620 use_onnx_subfunctions : bool = False ,
636621 ) -> str :
637622 """Export the Wan transformer model to ONNX format.
@@ -641,22 +626,19 @@ def export(
641626 output_names (List[str]): Names of model outputs
642627 dynamic_axes (Dict): Specification of dynamic dimensions
643628 export_dir (str, optional): Directory to save ONNX model
644- export_kwargs (Dict, optional): Additional export arguments (e.g., export_modules_as_functions)
645629 use_onnx_subfunctions (bool): Whether to export transformer blocks as ONNX functions
646630 for better modularity and potential optimization
647631 Returns:
648632 str: Path to the exported ONNX model
649633 """
650- if use_onnx_subfunctions :
651- export_kwargs = {"export_modules_as_functions" : {WanTransformerBlock }, "use_onnx_subfunctions" : True }
652634
653635 return self ._export (
654636 example_inputs = inputs ,
655637 output_names = output_names ,
656638 dynamic_axes = dynamic_axes ,
657639 export_dir = export_dir ,
658640 offload_pt_weights = True ,
659- ** export_kwargs ,
641+ use_onnx_subfunctions = use_onnx_subfunctions ,
660642 )
661643
662644 def compile (self , specializations , ** compiler_options ) -> None :
0 commit comments