Skip to content

Commit d59e11e

Browse files
ochougul and mamtsing authored
removed duplication of mdp_json_path in compilation command (#706)
Needed for passing custom config via vllm. --------- Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com> Signed-off-by: Mamta Singh <mamtsing@qti.qualcomm.com> Co-authored-by: Mamta Singh <mamtsing@qti.qualcomm.com>
1 parent 64eed68 commit d59e11e

File tree

4 files changed

+369
-24
lines changed

4 files changed

+369
-24
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -438,9 +438,6 @@ def _compile(
438438
+ [f"-m={onnx_path}"]
439439
)
440440

441-
if mdp_ts_json_path := compiler_options.pop("mdp_load_partition_config", None):
442-
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
443-
444441
for key, value in compiler_options.items():
445442
option = "-" + key.replace("_", "-")
446443
if isinstance(value, bool):
@@ -449,20 +446,22 @@ def _compile(
449446
continue
450447
command.append(f"{option}={value}")
451448

449+
if use_onnx_subfunctions:
450+
logger.info("Using ONNX subfunctions for compilation.")
451+
command.append("-sub-functions")
452+
452453
# Create a dummy mdp_ts_json if mdp-load-partition-config not provided and num_devices > 1
453-
if mdp_ts_json_path is not None:
454+
if mdp_ts_json_path := compiler_options.pop("mdp_load_partition_config", None):
454455
mdp_ts_json = load_json(str(mdp_ts_json_path))
455456
elif mdp_ts_num_devices > 1:
456457
mdp_ts_json = generate_mdp_partition_config(
457458
mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES)
458459
)
460+
mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
461+
create_json(str(mdp_ts_json_path), mdp_ts_json)
459462
else:
460463
mdp_ts_json = None
461464

462-
if use_onnx_subfunctions:
463-
logger.info("Using ONNX subfunctions for compilation.")
464-
command.append("-sub-functions")
465-
466465
compile_hash_params = {
467466
"command": command,
468467
"specializations": specializations,
@@ -485,12 +484,6 @@ def _compile(
485484
# Probably compilation failure last time, delete directory to start over
486485
shutil.rmtree(qpc_path)
487486

488-
# write the MDP partition config file if not provided
489-
if mdp_ts_json is not None:
490-
mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
491-
create_json(str(mdp_ts_json_path), mdp_ts_json)
492-
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
493-
494487
# Write specializations.json file
495488
if specializations is not None:
496489
specializations_json = compile_dir / "specializations.json"
@@ -500,6 +493,9 @@ def _compile(
500493
create_json(str(specializations_json), specializations_data)
501494
command.append(f"-network-specialization-config={specializations_json}")
502495

496+
if mdp_ts_json_path is not None:
497+
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
498+
503499
# Write custom_io.yaml file
504500
if custom_io is not None:
505501
custom_io_yaml = compile_dir / "custom_io.yaml"

QEfficient/transformers/models/modeling_auto.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2741,10 +2741,12 @@ def build_prefill_specialization(
27412741
Dict[str, Union[int, str]]
27422742
A dictionary defining the prefill specialization.
27432743
"""
2744-
if prefill_seq_len == 1 and self.continuous_batching:
2744+
if not self.continuous_batching:
2745+
exec_batch_size = batch_size
2746+
elif prefill_seq_len == 1:
27452747
exec_batch_size = full_batch_size
27462748
else:
2747-
exec_batch_size = 1 if self.continuous_batching else batch_size
2749+
exec_batch_size = 1
27482750

27492751
if hasattr(self.model, "get_specializations"):
27502752
spec = self.model.get_specializations(
@@ -2755,7 +2757,7 @@ def build_prefill_specialization(
27552757
)[0]
27562758
else:
27572759
spec = {
2758-
"batch_size": 1 if self.continuous_batching else batch_size,
2760+
"batch_size": exec_batch_size,
27592761
"seq_len": prefill_seq_len,
27602762
"ctx_len": ctx_len,
27612763
}
@@ -2766,8 +2768,9 @@ def build_prefill_specialization(
27662768
spec["full_batch_size"] = kv_cache_batch_size
27672769
else:
27682770
spec["batch_size"] = kv_cache_batch_size
2771+
# TODO: remove this; not required
27692772
if full_batch_size:
2770-
spec["full_batch_exec_size"] = full_batch_size
2773+
spec["full_batch_exec_size"] = exec_batch_size
27712774
return {k: v for k, v in spec.items() if v is not None}
27722775

27732776
def build_decode_specialization(
@@ -2805,9 +2808,6 @@ def build_decode_specialization(
28052808
A dictionary defining the decode specialization, or None if it would be a duplicate
28062809
of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching).
28072810
"""
2808-
if prefill_seq_len == 1 and not self.continuous_batching:
2809-
return None # Avoid duplication with prefill
2810-
28112811
if hasattr(self.model, "get_specializations"):
28122812
spec = self.model.get_specializations(
28132813
batch_size=full_batch_size if self.continuous_batching else batch_size,
@@ -3025,7 +3025,7 @@ def compile(
30253025
)
30263026
)
30273027

3028-
if prefill_only is None or not prefill_only:
3028+
if (prefill_only is None or not prefill_only) and prefill_seq_len != 1:
30293029
if self.comp_ctx_lengths_decode is not None:
30303030
# Adding elements from self.comp_ctx_lengths_decode to decode_specialization
30313031
for i in range(0, len(self.comp_ctx_lengths_decode)):
@@ -3054,6 +3054,8 @@ def compile(
30543054
if decode_spec:
30553055
specializations.append(decode_spec)
30563056

3057+
if kw_spec := compiler_options.pop("specializations", None):
3058+
specializations = kw_spec
30573059
# --- Compilation ---
30583060
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
30593061
custom_io = {}

tests/transformers/models/test_causal_lm_models.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
153153
config: Optional[AutoConfig] = None,
154154
pytorch_hf_tokens: Optional[list] = None,
155155
qaic_config: Optional[dict] = None,
156+
retain_full_kv: Optional[bool] = None,
156157
):
157158
"""
158159
Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
@@ -211,6 +212,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
211212
prefill_only=prefill_only,
212213
enable_qnn=enable_qnn,
213214
qnn_config=qnn_config,
215+
retain_full_kv=retain_full_kv,
214216
)
215217
exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
216218
cloud_ai_100_tokens = exec_info.generated_ids[0][
@@ -260,17 +262,38 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
260262
if not get_available_device_id():
261263
pytest.skip("No available devices to run model on Cloud AI 100")
262264

265+
compiler_options = {}
266+
if prompt_len == 1:
267+
prefill_spec = {
268+
"batch_size": batch_size,
269+
"seq_len": 1,
270+
"ctx_len": ctx_len,
271+
"full_batch_size": full_batch_size,
272+
"sliding_window": 128,
273+
}
274+
decode_spec = {
275+
"batch_size": full_batch_size,
276+
"seq_len": 1,
277+
"ctx_len": ctx_len,
278+
"full_batch_size": full_batch_size,
279+
"sliding_window": 128,
280+
}
281+
compiler_options = {"specializations": [prefill_spec, decode_spec]}
282+
263283
# TODO: add prefill_only tests
264284
qpc_path = qeff_model.compile(
265285
prefill_seq_len=prompt_len,
266286
ctx_len=ctx_len,
267287
num_cores=14,
268288
mxfp6=False,
269289
aic_enable_depth_first=False,
290+
batch_size=batch_size,
270291
full_batch_size=full_batch_size,
271292
num_speculative_tokens=num_speculative_tokens,
272293
enable_qnn=enable_qnn,
273294
qnn_config=qnn_config,
295+
retain_full_kv=retain_full_kv,
296+
**compiler_options,
274297
)
275298
exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
276299

@@ -370,6 +393,24 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
370393
)
371394

372395

396+
@pytest.mark.nightly
397+
@pytest.mark.on_qaic
398+
@pytest.mark.parametrize("retain_full_kv", [True, False])
399+
def test_causal_lm_gpt_oss_pytorch_vs_kv_vs_ort_vs_ai100_pl1(retain_full_kv):
400+
"""
401+
Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
402+
``Mandatory`` Args:
403+
:model_name (str): Hugging Face Model Card name, Example: ``gpt2``
404+
"""
405+
model_name = "openai/gpt-oss-20b"
406+
n_layer = get_custom_n_layers(model_name)
407+
prompt_len = 1
408+
409+
check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
410+
model_name=model_name, n_layer=n_layer, prompt_len=prompt_len, retain_full_kv=retain_full_kv
411+
)
412+
413+
373414
@pytest.mark.on_qaic
374415
@pytest.mark.regular
375416
@pytest.mark.qnn

0 commit comments

Comments (0)