Skip to content

Commit 32f30c0

Browse files
Fixed torch patch for subfunction with VLMs (quic#750)
Signed-off-by: abhishek-singh591 <sabhis@qti.qualcomm.com>
1 parent 0ffa4ea commit 32f30c0

File tree

5 files changed

+30
-24
lines changed

5 files changed

+30
-24
lines changed

QEfficient/peft/auto.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,8 @@ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
289289

290290
return self._export(
291291
example_inputs,
292-
output_names,
293-
dynamic_axes,
292+
output_names=output_names,
293+
dynamic_axes=dynamic_axes,
294294
do_constant_folding=False, # To avoid merging adapter weights with base weights
295295
onnx_transform_kwargs={"adapter_name": self.model.active_adapter},
296296
export_dir=export_dir,

QEfficient/peft/lora/auto.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,8 @@ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
384384

385385
return self._export(
386386
example_inputs,
387-
output_names,
388-
dynamic_axes,
387+
output_names=output_names,
388+
dynamic_axes=dynamic_axes,
389389
export_dir=export_dir,
390390
**kwargs,
391391
)

QEfficient/transformers/models/modeling_auto.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,8 @@ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
344344

345345
return self._export(
346346
example_inputs,
347-
output_names,
348-
dynamic_axes,
347+
output_names=output_names,
348+
dynamic_axes=dynamic_axes,
349349
export_dir=export_dir,
350350
use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
351351
)
@@ -623,8 +623,8 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
623623
"""
624624
return self._export(
625625
inputs,
626-
output_names,
627-
dynamic_axes,
626+
output_names=output_names,
627+
dynamic_axes=dynamic_axes,
628628
export_dir=export_dir,
629629
offload_pt_weights=offload_pt_weights,
630630
use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
@@ -768,8 +768,8 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
768768
"""
769769
return self._export(
770770
inputs,
771-
output_names,
772-
dynamic_axes,
771+
output_names=output_names,
772+
dynamic_axes=dynamic_axes,
773773
export_dir=export_dir,
774774
offload_pt_weights=offload_pt_weights,
775775
use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
@@ -1708,8 +1708,8 @@ def export(
17081708
output_names = self.model.get_output_names()
17091709
return self._export(
17101710
inputs,
1711-
output_names,
1712-
dynamic_axes,
1711+
output_names=output_names,
1712+
dynamic_axes=dynamic_axes,
17131713
export_dir=export_dir,
17141714
use_onnx_subfunctions=use_onnx_subfunctions,
17151715
)
@@ -2706,8 +2706,8 @@ def export(
27062706
)
27072707
return self._export(
27082708
example_inputs,
2709-
output_names,
2710-
dynamic_axes,
2709+
output_names=output_names,
2710+
dynamic_axes=dynamic_axes,
27112711
export_dir=export_dir,
27122712
use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
27132713
offload_pt_weights=kwargs.get("offload_pt_weights", True),
@@ -3300,8 +3300,8 @@ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
33003300
output_names = self.model.get_output_names()
33013301
return self._export(
33023302
inputs,
3303-
output_names,
3304-
dynamic_axes,
3303+
output_names=output_names,
3304+
dynamic_axes=dynamic_axes,
33053305
export_dir=export_dir,
33063306
use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
33073307
)
@@ -3676,8 +3676,8 @@ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
36763676

36773677
return self._export(
36783678
example_inputs,
3679-
output_names,
3680-
dynamic_axes,
3679+
output_names=output_names,
3680+
dynamic_axes=dynamic_axes,
36813681
export_dir=export_dir,
36823682
use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False),
36833683
)

QEfficient/utils/export_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,15 +161,18 @@ def _setup_onnx_subfunctions(qeff_model, args, kwargs):
161161
# Apply torch patches for subfunction support
162162
apply_torch_patches()
163163
InvalidIndexProvider.SUBFUNC_ENABLED = True
164+
164165
# Transform output names for subfunction compatibility
165166
if "output_names" in kwargs:
166167
kwargs["output_names"] = [
167168
re.sub("_RetainedState", "_InternalRetainedState", name) for name in kwargs["output_names"]
168169
]
169170
else:
170-
args = list(args)
171-
args[1] = [re.sub("_RetainedState", "_InternalRetainedState", name) for name in args[1]]
172-
args = tuple(args)
171+
warnings.warn(
172+
"ONNX subfunctions are enabled, but no retained-state output names were found to rewrite. "
173+
"Ensure `output_names` includes key/value retained states if subfunction compatibility is required."
174+
)
175+
173176
# Add subfunction-specific ONNX transforms
174177
qeff_model._onnx_transforms.append(RenameFunctionOutputsTransform)
175178
qeff_model._onnx_transforms.append(CustomOpTransform)

QEfficient/utils/torch_patches.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import torch.onnx.utils as onnx_utils
1212
from torch import _C
1313

14+
from QEfficient.utils.logging_utils import logger
15+
1416
# Store original references before patching
1517
_original_setup_trace_module_map = onnx_utils._setup_trace_module_map
1618
_original_get_module_attributes = getattr(onnx_utils, "_get_module_attributes", None)
@@ -37,9 +39,10 @@ def _track_module_attributes_forward_hook(module, input, output):
3739
if hasattr(module, attr_name):
3840
onnx_attrs = getattr(module, attr_name)
3941
delattr(module, attr_name)
40-
# FIX: use empty dict to avoid type mismatch
41-
onnx_attrs = {}
42-
_C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
42+
try:
43+
_C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
44+
except Exception as e:
45+
logger.warning(f"Failed to track ONNX scope attributes: {e}. Skipping this step.")
4346

4447
for m in model.modules():
4548
m.register_forward_hook(_track_module_attributes_forward_hook)

0 commit comments

Comments
 (0)