Skip to content

Commit 7411221

Browse files
Merge branch 'release/v1.21.0' into pyhton3.12-update
2 parents 5c40fab + d59e11e commit 7411221

File tree

7 files changed

+432
-81
lines changed

7 files changed

+432
-81
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
6060
super().__init__()
6161
self.model = model
6262
self.hash_params = create_model_params(self, **kwargs)
63-
self.prefill_onnx_path: Optional[str] = None
6463
self.onnx_path: Optional[str] = None
6564
self.qpc_path: Optional[str] = None
6665
self.qpc_session: Optional[QAICInferenceSession] = None
@@ -240,10 +239,7 @@ def _export(
240239

241240
# Return early if ONNX already exists
242241
if onnx_path.is_file():
243-
if prefill_only:
244-
self.prefill_onnx_path = onnx_path
245-
else:
246-
self.onnx_path = onnx_path
242+
self.onnx_path = onnx_path
247243
return onnx_path
248244

249245
# check if the model is in meta state or weights are offloaded
@@ -322,10 +318,7 @@ def _export(
322318
finally:
323319
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
324320

325-
if prefill_only:
326-
self.prefill_onnx_path = onnx_path
327-
else:
328-
self.onnx_path = onnx_path
321+
self.onnx_path = onnx_path
329322
return onnx_path
330323

331324
def get_onnx_path(
@@ -342,21 +335,18 @@ def get_onnx_path(
342335
"use_onnx_subfunctions": use_onnx_subfunctions,
343336
"retain_full_kv": retain_full_kv,
344337
}
338+
345339
if prefill_only:
346-
if self.prefill_onnx_path is None:
347-
kwargs.update(
348-
{
349-
"prefill_only": prefill_only,
350-
"prefill_seq_len": specializations[0].get("seq_len"),
351-
"enable_chunking": enable_chunking,
352-
}
353-
)
354-
self.export(**kwargs)
355-
return self.prefill_onnx_path
356-
else:
357-
if self.onnx_path is None:
358-
self.export(**kwargs)
359-
return self.onnx_path
340+
kwargs.update(
341+
{
342+
"prefill_only": prefill_only,
343+
"prefill_seq_len": specializations[0].get("seq_len"),
344+
"enable_chunking": enable_chunking,
345+
}
346+
)
347+
348+
self.export(**kwargs)
349+
return self.onnx_path
360350

361351
@dump_qconfig
362352
def _compile(
@@ -404,6 +394,8 @@ def _compile(
404394
onnx_path = Path(
405395
onnx_path
406396
if onnx_path
397+
else self.onnx_path
398+
if self.onnx_path
407399
else self.get_onnx_path(
408400
prefill_only,
409401
enable_chunking,
@@ -446,9 +438,6 @@ def _compile(
446438
+ [f"-m={onnx_path}"]
447439
)
448440

449-
if mdp_ts_json_path := compiler_options.pop("mdp_load_partition_config", None):
450-
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
451-
452441
for key, value in compiler_options.items():
453442
option = "-" + key.replace("_", "-")
454443
if isinstance(value, bool):
@@ -457,20 +446,22 @@ def _compile(
457446
continue
458447
command.append(f"{option}={value}")
459448

449+
if use_onnx_subfunctions:
450+
logger.info("Using ONNX subfunctions for compilation.")
451+
command.append("-sub-functions")
452+
460453
# Create a dummy mdp_ts_json if mdp-load-partition-config not provided and num_devices > 1
461-
if mdp_ts_json_path is not None:
454+
if mdp_ts_json_path := compiler_options.pop("mdp_load_partition_config", None):
462455
mdp_ts_json = load_json(str(mdp_ts_json_path))
463456
elif mdp_ts_num_devices > 1:
464457
mdp_ts_json = generate_mdp_partition_config(
465458
mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES)
466459
)
460+
mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
461+
create_json(str(mdp_ts_json_path), mdp_ts_json)
467462
else:
468463
mdp_ts_json = None
469464

470-
if use_onnx_subfunctions:
471-
logger.info("Using ONNX subfunctions for compilation.")
472-
command.append("-sub-functions")
473-
474465
compile_hash_params = {
475466
"command": command,
476467
"specializations": specializations,
@@ -493,12 +484,6 @@ def _compile(
493484
# Probably compilation failure last time, delete directory to start over
494485
shutil.rmtree(qpc_path)
495486

496-
# write the MDP partition config file if not provided
497-
if mdp_ts_json is not None:
498-
mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
499-
create_json(str(mdp_ts_json_path), mdp_ts_json)
500-
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
501-
502487
# Write specializations.json file
503488
if specializations is not None:
504489
specializations_json = compile_dir / "specializations.json"
@@ -508,6 +493,9 @@ def _compile(
508493
create_json(str(specializations_json), specializations_data)
509494
command.append(f"-network-specialization-config={specializations_json}")
510495

496+
if mdp_ts_json_path is not None:
497+
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
498+
511499
# Write custom_io.yaml file
512500
if custom_io is not None:
513501
custom_io_yaml = compile_dir / "custom_io.yaml"

QEfficient/transformers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@
189189
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}
190190

191191
# This is for supporting different modelling classes specially written for prefill-only model
192-
SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"}
192+
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss"}
193193

194194
# Define a transformers layers to QEff layers dictionary
195195
# While onboarding new models make sure to add the new layer maps to this dictionary.

QEfficient/transformers/models/modeling_auto.py

Lines changed: 38 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from QEfficient.generation.vlm_generation import VisionLanguageGeneration
4141
from QEfficient.transformers.modeling_utils import (
4242
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH,
43-
SPECIALIZED_PREFILL_ONLY_MODEL_ARCH,
43+
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH,
4444
)
4545
from QEfficient.transformers.models.pytorch_transforms import (
4646
BlockedKVAttentionTransform,
@@ -2522,15 +2522,18 @@ def get_seq_len_and_handle_specialized_prefill_model(
25222522

25232523
num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None)
25242524
if num_q_blocks is None:
2525-
block_size = 256
2526-
if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128:
2525+
if (
2526+
prefill_seq_len is None
2527+
or prefill_seq_len % constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE != 0
2528+
or prefill_seq_len < constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE
2529+
):
25272530
raise ValueError(
2528-
f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. "
2531+
f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE}. "
25292532
f"Or set `NUM_Q_BLOCKS` ENV variable"
25302533
f"Received: prefill_seq_len={prefill_seq_len}"
25312534
)
25322535

2533-
num_q_blocks = prefill_seq_len // block_size
2536+
num_q_blocks = prefill_seq_len // constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE
25342537
logger.warning(
25352538
f"Setting NUM_Q_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_Q_BLOCKS` to override"
25362539
)
@@ -2588,31 +2591,28 @@ def export(
25882591
self.model.config, fbs if self.continuous_batching else bs, seq_len
25892592
)
25902593
enable_chunking = kwargs.get("enable_chunking", False)
2591-
if prefill_only:
2592-
if not enable_chunking and self.continuous_batching:
2593-
raise NotImplementedError(
2594-
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
2595-
)
2596-
self.prefill(enable=True, enable_chunking=enable_chunking)
2597-
self.hash_params.pop("retain_full_kv", None)
2598-
seq_len = (
2599-
self.get_seq_len_and_handle_specialized_prefill_model(
2594+
2595+
# TODO: move this to a DA Serving utility class
2596+
if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH:
2597+
if prefill_only:
2598+
if self.continuous_batching and not enable_chunking:
2599+
raise NotImplementedError("Can't enable prefix-caching without chunking")
2600+
self.prefill(enable=True, enable_chunking=enable_chunking)
2601+
self.hash_params.pop("retain_full_kv", None)
2602+
seq_len = self.get_seq_len_and_handle_specialized_prefill_model(
26002603
prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking
26012604
)
2602-
if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
2603-
else seq_len
2604-
)
2605-
kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
2606-
else:
2607-
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
2608-
self.hash_params.pop("prefill_only", None)
2609-
self.hash_params.pop("NUM_Q_BLOCKS", None)
2610-
self.hash_params.pop("NUM_FFN_BLOCKS", None)
2611-
self.hash_params.pop("ENABLE_OPT_SWA", None)
2612-
self.hash_params.pop("chunking", None)
2613-
if kwargs.get("retain_full_kv", False):
2614-
kv_cache_shape[2] = seq_len + self.model.config.sliding_window
2615-
self.hash_params["retain_full_kv"] = True
2605+
kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
2606+
else:
2607+
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
2608+
self.hash_params.pop("prefill_only", None)
2609+
self.hash_params.pop("NUM_Q_BLOCKS", None)
2610+
self.hash_params.pop("NUM_FFN_BLOCKS", None)
2611+
self.hash_params.pop("ENABLE_OPT_SWA", None)
2612+
self.hash_params.pop("chunking", None)
2613+
if kwargs.get("retain_full_kv", False):
2614+
kv_cache_shape[2] = seq_len + self.model.config.sliding_window
2615+
self.hash_params["retain_full_kv"] = True
26162616

26172617
example_inputs = {
26182618
"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
@@ -2741,10 +2741,12 @@ def build_prefill_specialization(
27412741
Dict[str, Union[int, str]]
27422742
A dictionary defining the prefill specialization.
27432743
"""
2744-
if prefill_seq_len == 1 and self.continuous_batching:
2744+
if not self.continuous_batching:
2745+
exec_batch_size = batch_size
2746+
elif prefill_seq_len == 1:
27452747
exec_batch_size = full_batch_size
27462748
else:
2747-
exec_batch_size = 1 if self.continuous_batching else batch_size
2749+
exec_batch_size = 1
27482750

27492751
if hasattr(self.model, "get_specializations"):
27502752
spec = self.model.get_specializations(
@@ -2755,7 +2757,7 @@ def build_prefill_specialization(
27552757
)[0]
27562758
else:
27572759
spec = {
2758-
"batch_size": 1 if self.continuous_batching else batch_size,
2760+
"batch_size": exec_batch_size,
27592761
"seq_len": prefill_seq_len,
27602762
"ctx_len": ctx_len,
27612763
}
@@ -2766,8 +2768,9 @@ def build_prefill_specialization(
27662768
spec["full_batch_size"] = kv_cache_batch_size
27672769
else:
27682770
spec["batch_size"] = kv_cache_batch_size
2771+
# TODO: remove this; not required
27692772
if full_batch_size:
2770-
spec["full_batch_exec_size"] = full_batch_size
2773+
spec["full_batch_exec_size"] = exec_batch_size
27712774
return {k: v for k, v in spec.items() if v is not None}
27722775

27732776
def build_decode_specialization(
@@ -2805,9 +2808,6 @@ def build_decode_specialization(
28052808
A dictionary defining the decode specialization, or None if it would be a duplicate
28062809
of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching).
28072810
"""
2808-
if prefill_seq_len == 1 and not self.continuous_batching:
2809-
return None # Avoid duplication with prefill
2810-
28112811
if hasattr(self.model, "get_specializations"):
28122812
spec = self.model.get_specializations(
28132813
batch_size=full_batch_size if self.continuous_batching else batch_size,
@@ -2942,7 +2942,6 @@ def compile(
29422942
if prefill_only is None or not prefill_only:
29432943
if self.continuous_batching and full_batch_size is None:
29442944
raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
2945-
29462945
else:
29472946
if self.continuous_batching and kv_cache_batch_size is None and full_batch_size is None:
29482947
raise ValueError(
@@ -3026,7 +3025,7 @@ def compile(
30263025
)
30273026
)
30283027

3029-
if prefill_only is None or not prefill_only:
3028+
if (prefill_only is None or not prefill_only) and prefill_seq_len != 1:
30303029
if self.comp_ctx_lengths_decode is not None:
30313030
# Adding elements from self.comp_ctx_lengths_decode to decode_specialization
30323031
for i in range(0, len(self.comp_ctx_lengths_decode)):
@@ -3055,6 +3054,8 @@ def compile(
30553054
if decode_spec:
30563055
specializations.append(decode_spec)
30573056

3057+
if kw_spec := compiler_options.pop("specializations", None):
3058+
specializations = kw_spec
30583059
# --- Compilation ---
30593060
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
30603061
custom_io = {}

QEfficient/utils/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def get_models_dir():
178178
CCL_MAX_ELEMENTS_LISTS = 5
179179
CCL_START_CTX_LEN = 4096
180180

181+
# used for gpt-oss prefill-only model Q-blocking
182+
GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256
183+
181184

182185
class Constants:
183186
# Export Constants.

tests/transformers/models/test_causal_lm_models.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
153153
config: Optional[AutoConfig] = None,
154154
pytorch_hf_tokens: Optional[list] = None,
155155
qaic_config: Optional[dict] = None,
156+
retain_full_kv: Optional[bool] = None,
156157
):
157158
"""
158159
Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
@@ -211,6 +212,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
211212
prefill_only=prefill_only,
212213
enable_qnn=enable_qnn,
213214
qnn_config=qnn_config,
215+
retain_full_kv=retain_full_kv,
214216
)
215217
exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
216218
cloud_ai_100_tokens = exec_info.generated_ids[0][
@@ -260,17 +262,38 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
260262
if not get_available_device_id():
261263
pytest.skip("No available devices to run model on Cloud AI 100")
262264

265+
compiler_options = {}
266+
if prompt_len == 1:
267+
prefill_spec = {
268+
"batch_size": batch_size,
269+
"seq_len": 1,
270+
"ctx_len": ctx_len,
271+
"full_batch_size": full_batch_size,
272+
"sliding_window": 128,
273+
}
274+
decode_spec = {
275+
"batch_size": full_batch_size,
276+
"seq_len": 1,
277+
"ctx_len": ctx_len,
278+
"full_batch_size": full_batch_size,
279+
"sliding_window": 128,
280+
}
281+
compiler_options = {"specializations": [prefill_spec, decode_spec]}
282+
263283
# TODO: add prefill_only tests
264284
qpc_path = qeff_model.compile(
265285
prefill_seq_len=prompt_len,
266286
ctx_len=ctx_len,
267287
num_cores=14,
268288
mxfp6=False,
269289
aic_enable_depth_first=False,
290+
batch_size=batch_size,
270291
full_batch_size=full_batch_size,
271292
num_speculative_tokens=num_speculative_tokens,
272293
enable_qnn=enable_qnn,
273294
qnn_config=qnn_config,
295+
retain_full_kv=retain_full_kv,
296+
**compiler_options,
274297
)
275298
exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
276299

@@ -370,6 +393,24 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
370393
)
371394

372395

396+
@pytest.mark.nightly
397+
@pytest.mark.on_qaic
398+
@pytest.mark.parametrize("retain_full_kv", [True, False])
399+
def test_causal_lm_gpt_oss_pytorch_vs_kv_vs_ort_vs_ai100_pl1(retain_full_kv):
400+
"""
401+
Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
402+
``Mandatory`` Args:
403+
:model_name (str): Hugging Face Model Card name, Example: ``gpt2``
404+
"""
405+
model_name = "openai/gpt-oss-20b"
406+
n_layer = get_custom_n_layers(model_name)
407+
prompt_len = 1
408+
409+
check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
410+
model_name=model_name, n_layer=n_layer, prompt_len=prompt_len, retain_full_kv=retain_full_kv
411+
)
412+
413+
373414
@pytest.mark.on_qaic
374415
@pytest.mark.regular
375416
@pytest.mark.qnn

0 commit comments

Comments (0)