Skip to content

Commit a9f113c

Browse files
vbaddi and quic-rishinr
authored and committed
refactor(transforms): simplify proxy gating with per-class default onnx transforms
Signed-off-by: vbaddi <vbaddi@qti.qualcomm.com>
1 parent 3bbc765 commit a9f113c

File tree

2 files changed

+12
-17
lines changed

2 files changed

+12
-17
lines changed

QEfficient/transformers/modeling_utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,24 +204,20 @@ def _configure_proxy_for_model(instance: "QEFFBaseModel", enable_proxy: bool) ->
204204
"""
205205
Configure per-instance transform lists based on proxy mode.
206206
207-
By default, clip/split ONNX transforms are disabled for production exports.
208-
They are only enabled when proxy flow is explicitly requested.
207+
Keep class-defined ONNX transforms by default.
208+
Proxy flow appends additional proxy-only transforms.
209209
"""
210210
instance._pytorch_transforms = list(instance._pytorch_transforms)
211211
instance._onnx_transforms = list(instance._onnx_transforms)
212212
instance._enable_proxy = enable_proxy
213-
proxy_only_onnx_transforms = tuple(getattr(instance, "_proxy_only_onnx_transforms", _PROXY_ONLY_ONNX_TRANSFORMS))
214213

215214
if enable_proxy:
216215
if QeffProxyModuleTransform not in instance._pytorch_transforms:
217216
instance._pytorch_transforms.append(QeffProxyModuleTransform)
218-
for transform in proxy_only_onnx_transforms:
217+
for transform in _PROXY_ONLY_ONNX_TRANSFORMS:
219218
if transform not in instance._onnx_transforms:
220219
instance._onnx_transforms.append(transform)
221220
logger.info("Proxy Model Enabled for QEfficient Model")
222-
return
223-
224-
instance._onnx_transforms = [t for t in instance._onnx_transforms if t not in proxy_only_onnx_transforms]
225221

226222

227223
# Define a transformers layers to QEff layers dictionary

QEfficient/transformers/models/modeling_auto.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
import QEfficient
3131
from QEfficient.base.modeling_qeff import QEFFBaseModel
32-
from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
32+
from QEfficient.base.onnx_transforms import FP16ClipTransform
3333
from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
3434
from QEfficient.generation.cloud_infer import QAICInferenceSession
3535
from QEfficient.generation.text_generation_inference import (
@@ -229,8 +229,7 @@ class QEFFAutoModel(QEFFTransformersBase):
229229

230230
_hf_auto_class = AutoModel
231231
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
232-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
233-
_proxy_only_onnx_transforms = (SplitTensorsTransform,)
232+
_onnx_transforms = [FP16ClipTransform]
234233

235234
def __init__(self, model: nn.Module, pooling=None, **kwargs):
236235
"""
@@ -618,7 +617,7 @@ class QEFFAutoModelForSequenceClassification(QEFFTransformersBase):
618617

619618
_hf_auto_class = AutoModelForSequenceClassification
620619
_pytorch_transforms = [CustomOpsTransform, TextClassificationTransform]
621-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
620+
_onnx_transforms = []
622621

623622
def __init__(self, model: nn.Module, **kwargs):
624623
"""
@@ -860,7 +859,7 @@ class QEffVisionEncoderForTextImageToTextModel(QEFFBaseModel):
860859
KVCacheTransform,
861860
KVCacheExternalModuleMapperTransform,
862861
]
863-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
862+
_onnx_transforms = []
864863

865864
def __init__(self, model: nn.modules, **kwargs):
866865
"""
@@ -999,7 +998,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
999998
VlmKVOffloadTransform,
1000999
SplitGateUpWeightsTransform,
10011000
]
1002-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
1001+
_onnx_transforms = []
10031002

10041003
def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
10051004
"""
@@ -1875,7 +1874,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal
18751874
VlmNoKVOffloadTransform,
18761875
SplitGateUpWeightsTransform,
18771876
]
1878-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
1877+
_onnx_transforms = []
18791878

18801879
def __init__(
18811880
self,
@@ -2627,7 +2626,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
26272626
KVCacheExternalModuleMapperTransform,
26282627
]
26292628

2630-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
2629+
_onnx_transforms = []
26312630

26322631
def prefill(
26332632
self,
@@ -3576,7 +3575,7 @@ class QEFFAutoModelForSpeechSeq2Seq(QEFFTransformersBase, MultimodalUtilityMixin
35763575

35773576
_hf_auto_class = AutoModelForSpeechSeq2Seq
35783577
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, KVCacheTransform]
3579-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
3578+
_onnx_transforms = []
35803579

35813580
def __init__(self, model: nn.Module, **kwargs):
35823581
"""
@@ -3935,7 +3934,7 @@ class QEFFAutoModelForCTC(QEFFTransformersBase):
39353934

39363935
_hf_auto_class = AutoModelForCTC
39373936
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
3938-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
3937+
_onnx_transforms = []
39393938

39403939
def __init__(self, model: nn.Module, **kwargs):
39413940
super().__init__(model, **kwargs)

0 commit comments

Comments (0)