Skip to content

Commit 2d6fa60

Browse files
vbaddi and Ann Kuruvilla
authored and committed
fix(embedding): restore default FP16Clip transform for automodel (#881)
This PR restores FP16ClipTransform for embedding models (`QEFFAutoModel`) in the default (non-proxy) path, while preserving existing proxy-gated behavior for other model categories. ### What changed - Added per-model support for always-on ONNX transforms in proxy configuration. - Set embedding models to always keep FP16ClipTransform enabled by default. - Embedding accuracy on HW depends on FP16 clipping, so clip must remain enabled for embedding even when proxy is disabled. ### Tests verified - `python -m pytest -q tests/unit_test/models/test_model_quickcheck.py -k "test_text_embedding_fp16_clip_transform_and_export"` cc: @anujgupt-github @quic-rishinr @quic-hemagnih --------- Signed-off-by: vbaddi <vbaddi@qti.qualcomm.com>
1 parent 40ce1a7 commit 2d6fa60

File tree

4 files changed

+39
-21
lines changed

4 files changed

+39
-21
lines changed

QEfficient/transformers/modeling_utils.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ def _configure_proxy_for_model(instance: "QEFFBaseModel", enable_proxy: bool) ->
204204
"""
205205
Configure per-instance transform lists based on proxy mode.
206206
207-
By default, clip/split ONNX transforms are disabled for production exports.
208-
They are only enabled when proxy flow is explicitly requested.
207+
Keep class-defined ONNX transforms by default.
208+
Proxy flow appends additional proxy-only transforms.
209209
"""
210210
instance._pytorch_transforms = list(instance._pytorch_transforms)
211211
instance._onnx_transforms = list(instance._onnx_transforms)
@@ -218,9 +218,6 @@ def _configure_proxy_for_model(instance: "QEFFBaseModel", enable_proxy: bool) ->
218218
if transform not in instance._onnx_transforms:
219219
instance._onnx_transforms.append(transform)
220220
logger.info("Proxy Model Enabled for QEfficient Model")
221-
return
222-
223-
instance._onnx_transforms = [t for t in instance._onnx_transforms if t not in _PROXY_ONLY_ONNX_TRANSFORMS]
224221

225222

226223
# Define a transformers layers to QEff layers dictionary

QEfficient/transformers/models/modeling_auto.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
import QEfficient
3131
from QEfficient.base.modeling_qeff import QEFFBaseModel
32-
from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
32+
from QEfficient.base.onnx_transforms import FP16ClipTransform
3333
from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
3434
from QEfficient.generation.cloud_infer import QAICInferenceSession
3535
from QEfficient.generation.text_generation_inference import (
@@ -229,7 +229,7 @@ class QEFFAutoModel(QEFFTransformersBase):
229229

230230
_hf_auto_class = AutoModel
231231
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
232-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
232+
_onnx_transforms = [FP16ClipTransform]
233233

234234
def __init__(self, model: nn.Module, pooling=None, **kwargs):
235235
"""
@@ -617,7 +617,7 @@ class QEFFAutoModelForSequenceClassification(QEFFTransformersBase):
617617

618618
_hf_auto_class = AutoModelForSequenceClassification
619619
_pytorch_transforms = [CustomOpsTransform, TextClassificationTransform]
620-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
620+
_onnx_transforms = []
621621

622622
def __init__(self, model: nn.Module, **kwargs):
623623
"""
@@ -859,7 +859,7 @@ class QEffVisionEncoderForTextImageToTextModel(QEFFBaseModel):
859859
KVCacheTransform,
860860
KVCacheExternalModuleMapperTransform,
861861
]
862-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
862+
_onnx_transforms = []
863863

864864
def __init__(self, model: nn.modules, **kwargs):
865865
"""
@@ -998,7 +998,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
998998
VlmKVOffloadTransform,
999999
SplitGateUpWeightsTransform,
10001000
]
1001-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
1001+
_onnx_transforms = []
10021002

10031003
def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
10041004
"""
@@ -1874,7 +1874,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal
18741874
VlmNoKVOffloadTransform,
18751875
SplitGateUpWeightsTransform,
18761876
]
1877-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
1877+
_onnx_transforms = []
18781878

18791879
def __init__(
18801880
self,
@@ -2626,7 +2626,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
26262626
KVCacheExternalModuleMapperTransform,
26272627
]
26282628

2629-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
2629+
_onnx_transforms = []
26302630

26312631
def prefill(
26322632
self,
@@ -3575,7 +3575,7 @@ class QEFFAutoModelForSpeechSeq2Seq(QEFFTransformersBase, MultimodalUtilityMixin
35753575

35763576
_hf_auto_class = AutoModelForSpeechSeq2Seq
35773577
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, KVCacheTransform]
3578-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
3578+
_onnx_transforms = []
35793579

35803580
def __init__(self, model: nn.Module, **kwargs):
35813581
"""
@@ -3934,7 +3934,7 @@ class QEFFAutoModelForCTC(QEFFTransformersBase):
39343934

39353935
_hf_auto_class = AutoModelForCTC
39363936
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
3937-
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
3937+
_onnx_transforms = []
39383938

39393939
def __init__(self, model: nn.Module, **kwargs):
39403940
super().__init__(model, **kwargs)

tests/transformers/models/test_embedding_models.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ def check_embed_pytorch_vs_ort_vs_ai100(
101101
assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json"))
102102

103103

104-
@pytest.mark.skip(reason="Known issue: AI100 compiled model produces high MAD; needs investigation")
105104
@pytest.mark.on_qaic
106105
@pytest.mark.llm_model
107106
@pytest.mark.parametrize("model", embed_test_models)
@@ -112,7 +111,6 @@ def test_embed_model_pytorch_vs_onnx_vs_ai100(model):
112111
check_embed_pytorch_vs_ort_vs_ai100(model_name=model["model_name"], seq_len=32, n_layer=1)
113112

114113

115-
@pytest.mark.skip(reason="Known issue: AI100 compiled model produces high MAD; needs investigation")
116114
@pytest.mark.on_qaic
117115
@pytest.mark.llm_model
118116
@pytest.mark.parametrize("model", embed_test_models)
@@ -123,7 +121,6 @@ def test_embed_model_pytorch_vs_onnx_vs_ai100_pooling(model):
123121
check_embed_pytorch_vs_ort_vs_ai100(model_name=model["model_name"], seq_len=32, n_layer=1, pooling=model["pooling"])
124122

125123

126-
@pytest.mark.skip(reason="Known issue: AI100 compiled model produces high MAD; needs investigation")
127124
@pytest.mark.on_qaic
128125
@pytest.mark.llm_model
129126
@pytest.mark.parametrize("model", embed_test_models[:1])

tests/unit_test/models/test_model_quickcheck.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from contextlib import contextmanager, redirect_stderr, redirect_stdout
2626
from io import StringIO
2727
from pathlib import Path
28-
from typing import Dict
28+
from typing import Dict, Optional, Set
2929

3030
import numpy as np
3131
import onnx
@@ -189,13 +189,19 @@ def _run_whisper_export_smoke(qeff_model: QEFFAutoModelForSpeechSeq2Seq, out_dir
189189
return onnx_path
190190

191191

192-
def _assert_proxy_only_onnx_transform_policy(qeff_model, enable_proxy: bool) -> None:
192+
def _assert_proxy_only_onnx_transform_policy(
193+
qeff_model, enable_proxy: bool, always_on_transforms: Optional[Set[str]] = None
194+
) -> None:
193195
transform_names = {transform.__name__ for transform in qeff_model._onnx_transforms}
194196
proxy_only_transforms = {"FP16ClipTransform", "SplitTensorsTransform"}
197+
always_on_transforms = always_on_transforms or set()
198+
conditional_proxy_transforms = proxy_only_transforms - always_on_transforms
199+
195200
if enable_proxy:
196201
assert proxy_only_transforms.issubset(transform_names)
197202
else:
198-
assert proxy_only_transforms.isdisjoint(transform_names)
203+
assert conditional_proxy_transforms.isdisjoint(transform_names)
204+
assert always_on_transforms.issubset(transform_names)
199205

200206

201207
def _skip_on_model_fetch_error(exc: Exception, model_id: str) -> None:
@@ -357,6 +363,22 @@ def test_text_embedding_cpu_parity_and_export(tmp_path):
357363
assert np.allclose(hf_outputs, ort_outputs, atol=1e-5)
358364

359365

366+
@pytest.mark.llm_model
367+
def test_text_embedding_fp16_clip_transform_and_export(tmp_path):
368+
tokenizer = AutoTokenizer.from_pretrained(TINY_TEXT_EMBEDDING_MODEL_ID)
369+
qeff_model = QEFFAutoModel.from_pretrained(TINY_TEXT_EMBEDDING_MODEL_ID)
370+
transform_names = {transform.__name__ for transform in qeff_model._onnx_transforms}
371+
372+
assert "FP16ClipTransform" in transform_names
373+
assert "SplitTensorsTransform" not in transform_names
374+
375+
inputs = tokenizer("hello world", return_tensors="pt")
376+
onnx_path = _exported_onnx_path(qeff_model.export(tmp_path / "embedding-ai100"))
377+
ort_outputs = _run_embedding_ort(onnx_path, inputs)
378+
assert ort_outputs.shape[0] == inputs["input_ids"].shape[0]
379+
assert ort_outputs.shape[1] == inputs["input_ids"].shape[1]
380+
381+
360382
@pytest.mark.llm_model
361383
def test_audio_embedding_ctc_cpu_parity_and_export(tmp_path):
362384
processor = AutoTokenizer.from_pretrained(TINY_AUDIO_CTC_MODEL_ID)
@@ -564,7 +586,9 @@ def test_proxy_toggle_onnx_transform_policy_for_embedding():
564586
except Exception as exc:
565587
_skip_on_model_fetch_error(exc, model_id)
566588

567-
_assert_proxy_only_onnx_transform_policy(qeff_default, enable_proxy=False)
589+
_assert_proxy_only_onnx_transform_policy(
590+
qeff_default, enable_proxy=False, always_on_transforms={"FP16ClipTransform"}
591+
)
568592
_assert_proxy_only_onnx_transform_policy(qeff_proxy, enable_proxy=True)
569593

570594

0 commit comments

Comments
 (0)