Skip to content

Commit 60005ac

Browse files
vbaddiquic-rishinr
andauthored
nit(transforms): Fix ONNX export reuse and expand causal LM subfunction compile coverage (quic#873)
This PR tightens the ONNX export path in modeling_qeff and updates quickcheck coverage. **Changes:** - Validate cached ONNX before reuse and re-export if the cached file is invalid. - Use instance-level transform names for hashing/metadata so export cache reflects the active transform set. - Only pass onnx_base_dir when ONNX transforms actually need external tensor data (FP16ClipTransform / SplitTensorsTransform), avoiding unnecessary tensor materialization in the default path. - Keep the lightweight onnx_transforms guard so external data is only loaded when a base dir is provided. **Tests**: - Removed the use_dynamo quickcheck case. - Added .compile(..., use_onnx_subfunctions=True) mocked compile-path coverage for all causal-LM quickcheck models. - Full quickcheck validation passed: - python -m pytest -q tests/test_model_quickcheck.py -n auto --------- Signed-off-by: vbaddi <vbaddi@qti.qualcomm.com> Signed-off-by: Rishin Raj <rishinr@qti.qualcomm.com> Co-authored-by: vbaddi <vbaddi@qti.qualcomm.com> Co-authored-by: Rishin Raj <rishinr@qti.qualcomm.com>
1 parent 95524d9 commit 60005ac

File tree

7 files changed

+105
-487
lines changed

7 files changed

+105
-487
lines changed

.github/workflows/quickcheck.yml

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
name: Quickcheck
2+
3+
on:
4+
pull_request:
5+
workflow_dispatch:
6+
7+
concurrency:
8+
group: quickcheck-${{ github.event.pull_request.number || github.ref }}
9+
cancel-in-progress: true
10+
11+
jobs:
12+
quickcheck:
13+
runs-on: ubuntu-latest
14+
timeout-minutes: 90
15+
steps:
16+
- name: Checkout Repo
17+
uses: actions/checkout@v4
18+
19+
- name: Setup Python
20+
uses: actions/setup-python@v5
21+
with:
22+
python-version: "3.10"
23+
cache: "pip"
24+
25+
- name: Install Dependencies
26+
run: |
27+
python -m pip install --upgrade pip
28+
python -m pip install -e .[test]
29+
python -m pip install pytest-xdist
30+
31+
- name: Run Quickcheck
32+
run: python -m pytest -q tests/unit_test/models/test_model_quickcheck.py -n auto

QEfficient/base/modeling_qeff.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
import onnx
1919
import torch
2020

21-
from QEfficient.base.onnx_transforms import BaseOnnxTransform, OnnxTransformPipeline
21+
from QEfficient.base.onnx_transforms import (
22+
BaseOnnxTransform,
23+
FP16ClipTransform,
24+
OnnxTransformPipeline,
25+
SplitTensorsTransform,
26+
)
2227
from QEfficient.base.pytorch_transforms import PytorchTransform
2328
from QEfficient.compile.qnn_compiler import compile as qnn_compile
2429
from QEfficient.generation.cloud_infer import QAICInferenceSession
@@ -49,9 +54,8 @@ class QEFFBaseModel(ABC):
4954
_pytorch_transforms: List[PytorchTransform]
5055
_onnx_transforms = [BaseOnnxTransform]
5156

52-
@classmethod
53-
def _transform_names(cls) -> List[str]:
54-
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]
57+
def _transform_names(self) -> List[str]:
58+
return [x.__name__ for x in self._pytorch_transforms + self._onnx_transforms]
5559

5660
def __init__(self, model: torch.nn.Module, **kwargs) -> None:
5761
super().__init__()
@@ -242,9 +246,7 @@ def _export(
242246
# check if the model is in meta state or weights are offloaded
243247
self._model_offloaded_check()
244248

245-
# Export directly into export_dir so any external data files are retained.
246249
export_dir.mkdir(parents=True, exist_ok=True)
247-
tmp_onnx_path = onnx_path
248250

249251
# Create input_names from example_inputs
250252
input_names = []
@@ -274,7 +276,7 @@ def _export(
274276
torch.onnx.export(
275277
self.model,
276278
(example_inputs,),
277-
str(tmp_onnx_path),
279+
str(onnx_path),
278280
input_names=input_names,
279281
output_names=output_names,
280282
dynamic_axes=dynamic_axes,
@@ -283,11 +285,13 @@ def _export(
283285
)
284286
logger.info("PyTorch export successful")
285287
_ = self._offload_model_weights(offload_pt_weights)
286-
model = onnx.load(tmp_onnx_path, load_external_data=False)
288+
model = onnx.load(onnx_path, load_external_data=False)
287289

288-
# Clear temporary references
290+
needs_external_tensor_data = any(
291+
transform in self._onnx_transforms for transform in (FP16ClipTransform, SplitTensorsTransform)
292+
)
289293
transform_kwargs = {
290-
"onnx_base_dir": str(export_dir),
294+
"onnx_base_dir": str(export_dir) if needs_external_tensor_data else None,
291295
"model_name": self.model_name,
292296
}
293297
if onnx_transform_kwargs is not None:
@@ -302,7 +306,9 @@ def _export(
302306
)
303307
logger.info("ONNX transforms applied")
304308

305-
onnx.save(model, onnx_path)
309+
onnx_path_tmp = onnx_path.with_suffix(onnx_path.suffix + ".tmp")
310+
onnx.save(model, onnx_path_tmp)
311+
onnx_path_tmp.replace(onnx_path)
306312
del model
307313
gc.collect()
308314
logger.info("Transformed ONNX saved")

QEfficient/base/onnx_transforms.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import logging
99
import os
10-
import warnings
1110
from concurrent.futures import ThreadPoolExecutor, as_completed
1211
from typing import Any, Dict, List, Optional, Tuple, Type
1312

@@ -106,16 +105,27 @@ class CustomOpTransform(BaseOnnxTransform):
106105
@classmethod
107106
def apply(cls, model: ModelProto) -> bool:
108107
op_applied = False
108+
109+
# Register with PyTorch ONNX exporter (for export time)
109110
for op_name, (func_class, _) in cls._custom_ops.items():
110111
if hasattr(func_class, "symbolic"):
111112
torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, ONNX_EXPORT_OPSET)
112113

114+
used_op_types = {node.op_type for node in model.graph.node}
115+
for function_proto in model.functions:
116+
used_op_types.update(node.op_type for node in function_proto.node)
117+
118+
# Add function prototypes to model
113119
existing = {f.name for f in model.functions}
114-
for _, onnxscript_func in cls._custom_ops.values():
120+
121+
for func_name, onnxscript_func in cls._custom_ops.values():
115122
proto = onnxscript_func.to_function_proto()
123+
if proto.name not in used_op_types:
124+
continue
116125
if proto.name not in existing:
117126
model.functions.append(proto)
118127
op_applied = True
128+
119129
return op_applied
120130

121131

@@ -202,8 +212,6 @@ class OnnxTransformPipeline(BaseOnnxTransform):
202212
"""Pipeline to apply multiple ONNX transformations in sequence."""
203213

204214
def __init__(self, transforms: List[Type[BaseOnnxTransform]]):
205-
if not transforms:
206-
warnings.warn("Transform list is empty. No transformations will be applied.")
207215
self.transforms = transforms
208216

209217
def apply(
@@ -228,7 +236,8 @@ def apply(
228236
do_split = SplitTensorsTransform in requested
229237
fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max
230238
file_num_tracker = {"num": 0, "size": 0}
231-
external_data_helper.load_external_data_for_model(model, onnx_base_dir)
239+
if onnx_base_dir is not None:
240+
external_data_helper.load_external_data_for_model(model, onnx_base_dir)
232241

233242
if do_fp16 or do_split:
234243
for tensor in external_data_helper._get_all_tensors(model):

scripts/Jenkinsfile

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ pipeline {
4141
mkdir -p $PWD/Non_cli_qaic &&
4242
export TOKENIZERS_PARALLELISM=false &&
4343
export QEFF_HOME=$PWD/Non_cli_qaic &&
44-
pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm --ignore tests/transformers/models/image_text_to_text -n 4 --junitxml=tests/tests_log1.xml --durations=10 &&
44+
pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm --ignore tests/transformers/models/image_text_to_text --ignore tests/unit_test -n 4 --junitxml=tests/tests_log1.xml --durations=10 &&
4545
junitparser merge tests/tests_log1.xml tests/tests_log.xml &&
4646
deactivate"
4747
'''
@@ -50,15 +50,15 @@ pipeline {
5050
}
5151
stage('QAIC LLM Tests') {
5252
steps {
53-
timeout(time: 120, unit: 'MINUTES') {
53+
timeout(time: 180, unit: 'MINUTES') {
5454
sh '''
5555
sudo docker exec ${BUILD_TAG} bash -c "
5656
cd /efficient-transformers &&
5757
. preflight_qeff/bin/activate &&
5858
mkdir -p $PWD/Non_qaic_llm &&
5959
export TOKENIZERS_PARALLELISM=false &&
6060
export QEFF_HOME=$PWD/Non_qaic_llm &&
61-
pytest tests -m '(not cli) and (on_qaic) and (llm_model) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --junitxml=tests/tests_log2.xml --durations=10 &&
61+
pytest tests -m '(not cli) and (on_qaic) and (llm_model) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log2.xml --durations=10 &&
6262
junitparser merge tests/tests_log2.xml tests/tests_log.xml &&
6363
deactivate"
6464
'''
@@ -75,7 +75,7 @@ pipeline {
7575
mkdir -p $PWD/Non_qaic_feature &&
7676
export TOKENIZERS_PARALLELISM=false &&
7777
export QEFF_HOME=$PWD/Non_qaic_feature &&
78-
pytest tests -m '(not cli) and (on_qaic) and (feature) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --junitxml=tests/tests_log2_feature.xml --durations=10 &&
78+
pytest tests -m '(not cli) and (on_qaic) and (feature) and (not nightly) and (not multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log2_feature.xml --durations=10 &&
7979
junitparser merge tests/tests_log2_feature.xml tests/tests_log.xml &&
8080
deactivate"
8181
'''
@@ -94,7 +94,7 @@ pipeline {
9494
mkdir -p $PWD/Non_cli_qaic_multimodal &&
9595
export TOKENIZERS_PARALLELISM=false &&
9696
export QEFF_HOME=$PWD/Non_cli_qaic_multimodal &&
97-
pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --junitxml=tests/tests_log6.xml --durations=10 &&
97+
pytest tests -m '(not cli) and (on_qaic) and (multimodal) and (not qnn) and (not finetune) and (not diffusion_models)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log6.xml --durations=10 &&
9898
junitparser merge tests/tests_log6.xml tests/tests_log.xml &&
9999
deactivate"
100100
'''
@@ -112,7 +112,7 @@ pipeline {
112112
export TOKENIZERS_PARALLELISM=false &&
113113
export QEFF_HOME=$PWD/Non_cli_qaic_diffusion &&
114114
export HF_HUB_CACHE=/huggingface_hub &&
115-
pytest tests -m '(not cli) and (on_qaic) and (diffusion_models) and (not wan) and (not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log_diffusion.xml --durations=10 &&
115+
pytest tests -m '(not cli) and (on_qaic) and (diffusion_models) and (not wan) and (not qnn) and (not finetune)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log_diffusion.xml --durations=10 &&
116116
junitparser merge tests/tests_log_diffusion.xml tests/tests_log.xml &&
117117
deactivate"
118118
'''
@@ -131,7 +131,7 @@ pipeline {
131131
mkdir -p $PWD/cli &&
132132
export TOKENIZERS_PARALLELISM=false &&
133133
export QEFF_HOME=$PWD/cli &&
134-
pytest tests -m '(cli and not qnn) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log3.xml --durations=10 &&
134+
pytest tests -m '(cli and not qnn) and (not finetune)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log3.xml --durations=10 &&
135135
junitparser merge tests/tests_log3.xml tests/tests_log.xml &&
136136
deactivate"
137137
'''
@@ -209,7 +209,7 @@ pipeline {
209209
mkdir -p $PWD/cli_qaic_finetuning &&
210210
export TOKENIZERS_PARALLELISM=false &&
211211
export QEFF_HOME=$PWD/cli_qaic_finetuning &&
212-
pytest tests -m '(cli) and (on_qaic) and (not qnn) and (not multimodal) and (finetune)' --ignore tests/vllm --junitxml=tests/tests_log_finetune.xml --durations=10 &&
212+
pytest tests -m '(cli) and (on_qaic) and (not qnn) and (not multimodal) and (finetune)' --ignore tests/vllm --ignore tests/unit_test --junitxml=tests/tests_log_finetune.xml --durations=10 &&
213213
junitparser merge tests/tests_log_finetune.xml tests/tests_log.xml &&
214214
deactivate"
215215
'''

0 commit comments

Comments
 (0)