Skip to content

Commit 815309e

Browse files
authored
fix(cloud.infer): reduce Qwen3-MoE export OOM risk (#821)
Summary - Keep `use_onnx_subfunctions` disabled by default in `QEfficient.cloud.infer` - Provide explicit opt-in via `--use-onnx-subfunctions` only - Remove `--no-use-onnx-subfunctions` - Update infer unit tests for explicit-enable and default-disabled behavior - Update quick-start and text-generation docs to reflect explicit opt-in behavior Why - Align infer behavior with reviewer feedback to keep defaults unchanged and avoid model-specific auto-enable behavior. Fixes - Fixes #702 Validation - `python -m py_compile QEfficient/cloud/infer.py tests/cloud/test_infer.py` - `ruff check QEfficient/cloud/infer.py tests/cloud/test_infer.py` - `pytest -q tests/cloud/test_infer.py -m "not on_qaic"` (2 passed, 5 deselected) --------- Signed-off-by: jd316 <jd316biswas@gmail.com>
1 parent 3d0d663 commit 815309e

File tree

4 files changed

+72
-8
lines changed

4 files changed

+72
-8
lines changed

QEfficient/cloud/infer.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def main(
139139
qnn_config: Optional[str] = None,
140140
trust_remote_code: Optional[bool] = False,
141141
ccl_enabled: Optional[bool] = False,
142+
use_onnx_subfunctions: bool = False,
142143
**kwargs,
143144
) -> None:
144145
"""
@@ -205,6 +206,8 @@ def main(
205206
Path of the QNN Config parameters file. Default is None.
206207
trust_remote_code : bool, optional
207208
If True, trusts remote code when loading models from HuggingFace. Default is False.
209+
use_onnx_subfunctions : bool, optional
210+
Enables ONNX subfunctions during export and compile. Default is False.
208211
**kwargs :
209212
Additional compiler options passed directly to `qaic-compile`. Any flag supported by
210213
`qaic-compile` can be passed. Parameters are converted to flags as follows:
@@ -231,12 +234,10 @@ def main(
231234
"""
232235
cache_dir = check_and_assign_cache_dir(local_model_dir, cache_dir)
233236

234-
if "--mxfp6" in sys.argv:
235-
if args.mxfp6:
236-
logger.warning("mxfp6 is going to be deprecated in a future release, use -mxfp6_matmul instead.")
237-
if "--mxint8" in sys.argv:
238-
if args.mxint8:
239-
logger.warning("mxint8 is going to be deprecated in a future release, use -mxint8_kv_cache instead.")
237+
if "--mxfp6" in sys.argv and mxfp6:
238+
logger.warning("mxfp6 is going to be deprecated in a future release, use -mxfp6_matmul instead.")
239+
if "--mxint8" in sys.argv and mxint8:
240+
logger.warning("mxint8 is going to be deprecated in a future release, use -mxint8_kv_cache instead.")
240241

241242
qaic_config = {"ccl_enabled": True} if ccl_enabled else None
242243

@@ -280,6 +281,7 @@ def main(
280281
allow_mxint8_mdp_io=allow_mxint8_mdp_io,
281282
enable_qnn=enable_qnn,
282283
qnn_config=qnn_config,
284+
use_onnx_subfunctions=use_onnx_subfunctions,
283285
**kwargs,
284286
)
285287

@@ -382,6 +384,14 @@ def main(
382384
action="store_true",
383385
help="Compress Present/Past KV to MXINT8 using CustomIO config, default is False",
384386
)
387+
parser.add_argument(
388+
"--use-onnx-subfunctions",
389+
"--use_onnx_subfunctions",
390+
dest="use_onnx_subfunctions",
391+
action="store_true",
392+
default=False,
393+
help="Enable ONNX subfunctions during export/compile.",
394+
)
385395
parser.add_argument(
386396
"--num_cores", "--num-cores", type=int, required=True, help="Number of cores to compile on Cloud AI 100"
387397
)

docs/source/quick_start.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,19 @@ This is the single e2e CLI API, which takes `model_card` name as input along wit
111111

112112
* HuggingFace model files Download → Optimize for Cloud AI 100 → Export to `ONNX` → Compile on Cloud AI 100 → [Execute](#execute_api)
113113
* It skips the export/compile stage if `ONNX` or `qpc` files are found. If you use infer a second time with different compilation arguments, it will automatically skip `ONNX` model creation and directly jump to the compile stage.
114+
* ONNX subfunctions can be enabled explicitly using `--use-onnx-subfunctions`.
114115

115116

116117
```bash
117118
# Check out the options using the help
118119
python -m QEfficient.cloud.infer --help
119120
python -m QEfficient.cloud.infer --model_name gpt2 --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 16 --device_group [0] --prompt "My name is" --mos 1 --aic_enable_depth_first
120121
```
122+
123+
```bash
124+
# Optional: explicitly control ONNX subfunction usage
125+
python -m QEfficient.cloud.infer --model_name Qwen/Qwen3-30B-A3B-Instruct-2507 --batch_size 1 --prompt_len 32 --ctx_len 128 --num_cores 16 --device_group [0] --prompt "My name is" --use-onnx-subfunctions
126+
```
121127
If executing for batch size>1,
122128
You can pass input prompts in a single string, separating them with the pipe (|) symbol. Example below
123129

examples/text_generation/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ This example:
115115
- Demonstrates MoE model inference
116116
- Uses sparse expert activation for efficiency
117117
- Works with Qwen, Mixtral, and other MoE models
118+
- Supports explicit ONNX subfunction enablement with `--use-onnx-subfunctions`
118119

119120

120121
## CLI Workflow
@@ -216,6 +217,7 @@ This uses the pre-compiled QPC for fast inference. You can run this multiple tim
216217
| `--device_group` | Device IDs to use | `[0]` | `[0]` or `[0,1,2,3]` |
217218
| `--mxfp6` | Enable MXFP6 quantization | False | Add flag to enable |
218219
| `--mxint8_kv_cache` | Enable MXINT8 KV cache | False | Add flag to enable |
220+
| `--use-onnx-subfunctions` | Enable ONNX subfunctions for export/compile | False | Add flag to enable |
219221
| `--mos` | Memory optimization strategy | 1 | `1` or `2` |
220222
| `--aic_enable_depth_first` | Enable depth-first execution | False | Add flag to enable |
221223

@@ -312,4 +314,3 @@ This script demonstrates:
312314
By default, exported models and QPC files are stored in `~/.cache/qeff_cache`. Customize this with:
313315
- `QEFF_HOME`: Primary cache directory
314316
- `XDG_CACHE_HOME`: Alternative cache location
315-

tests/cloud/test_infer.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,22 @@
55
#
66
# -----------------------------------------------------------------------------
77

8+
from types import SimpleNamespace
9+
810
import pytest
911

1012
import QEfficient
1113
from QEfficient.cloud.infer import main as infer
1214

1315

1416
def check_infer(
15-
mocker, model_name, prompt="My name is", full_batch_size=None, enable_qnn=False, image_url=None, generation_len=20
17+
mocker,
18+
model_name,
19+
prompt="My name is",
20+
full_batch_size=None,
21+
enable_qnn=False,
22+
image_url=None,
23+
generation_len=20,
1624
):
1725
check_and_assign_cache_dir_spy = mocker.spy(QEfficient.cloud.infer, "check_and_assign_cache_dir")
1826
qeff_model_load_spy = mocker.spy(QEfficient.cloud.infer.QEFFCommonLoader, "from_pretrained")
@@ -99,3 +107,42 @@ def test_infer_vlm(mocker):
99107
prompt="Describe the image.",
100108
image_url="https://i.etsystatic.com/8155076/r/il/0825c2/1594869823/il_fullxfull.1594869823_5x0w.jpg",
101109
)
110+
111+
112+
class _DummyQEFFModel:
113+
def __init__(self, architecture):
114+
self.model = SimpleNamespace(config=SimpleNamespace(architectures=[architecture]))
115+
self.compile_kwargs = None
116+
117+
def compile(self, **kwargs):
118+
self.compile_kwargs = kwargs
119+
return "/tmp/qpc"
120+
121+
def generate(self, *args, **kwargs):
122+
return {}
123+
124+
125+
def _run_infer_with_dummy_model(mocker, architecture, **infer_kwargs):
126+
dummy_model = _DummyQEFFModel(architecture=architecture)
127+
mocker.patch.object(QEfficient.cloud.infer, "check_and_assign_cache_dir", return_value="/tmp/cache")
128+
mocker.patch.object(QEfficient.cloud.infer.QEFFCommonLoader, "from_pretrained", return_value=dummy_model)
129+
mocker.patch.object(QEfficient.cloud.infer, "load_hf_tokenizer", return_value=object())
130+
131+
infer(
132+
model_name="dummy/model",
133+
num_cores=16,
134+
prompt=["hello"],
135+
generation_len=1,
136+
**infer_kwargs,
137+
)
138+
return dummy_model
139+
140+
141+
def test_infer_enables_onnx_subfunctions_when_explicitly_set(mocker):
142+
dummy_model = _run_infer_with_dummy_model(mocker, architecture="Qwen3MoeForCausalLM", use_onnx_subfunctions=True)
143+
assert dummy_model.compile_kwargs["use_onnx_subfunctions"] is True
144+
145+
146+
def test_infer_keeps_onnx_subfunctions_disabled_by_default(mocker):
147+
dummy_model = _run_infer_with_dummy_model(mocker, architecture="LlamaForCausalLM")
148+
assert dummy_model.compile_kwargs["use_onnx_subfunctions"] is False

0 commit comments

Comments
 (0)