Skip to content

Commit 64eed68

Browse files
authored
General disagg fix for prefill-only model (#693)
Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com>
1 parent e5a3497 commit 64eed68

File tree

5 files changed

+63
-57
lines changed

5 files changed

+63
-57
lines changed

QEfficient/base/modeling_qeff.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
6060
super().__init__()
6161
self.model = model
6262
self.hash_params = create_model_params(self, **kwargs)
63-
self.prefill_onnx_path: Optional[str] = None
6463
self.onnx_path: Optional[str] = None
6564
self.qpc_path: Optional[str] = None
6665
self.qpc_session: Optional[QAICInferenceSession] = None
@@ -240,10 +239,7 @@ def _export(
240239

241240
# Return early if ONNX already exists
242241
if onnx_path.is_file():
243-
if prefill_only:
244-
self.prefill_onnx_path = onnx_path
245-
else:
246-
self.onnx_path = onnx_path
242+
self.onnx_path = onnx_path
247243
return onnx_path
248244

249245
# check if the model is in meta state or weights are offloaded
@@ -322,10 +318,7 @@ def _export(
322318
finally:
323319
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)
324320

325-
if prefill_only:
326-
self.prefill_onnx_path = onnx_path
327-
else:
328-
self.onnx_path = onnx_path
321+
self.onnx_path = onnx_path
329322
return onnx_path
330323

331324
def get_onnx_path(
@@ -342,21 +335,18 @@ def get_onnx_path(
342335
"use_onnx_subfunctions": use_onnx_subfunctions,
343336
"retain_full_kv": retain_full_kv,
344337
}
338+
345339
if prefill_only:
346-
if self.prefill_onnx_path is None:
347-
kwargs.update(
348-
{
349-
"prefill_only": prefill_only,
350-
"prefill_seq_len": specializations[0].get("seq_len"),
351-
"enable_chunking": enable_chunking,
352-
}
353-
)
354-
self.export(**kwargs)
355-
return self.prefill_onnx_path
356-
else:
357-
if self.onnx_path is None:
358-
self.export(**kwargs)
359-
return self.onnx_path
340+
kwargs.update(
341+
{
342+
"prefill_only": prefill_only,
343+
"prefill_seq_len": specializations[0].get("seq_len"),
344+
"enable_chunking": enable_chunking,
345+
}
346+
)
347+
348+
self.export(**kwargs)
349+
return self.onnx_path
360350

361351
@dump_qconfig
362352
def _compile(
@@ -404,6 +394,8 @@ def _compile(
404394
onnx_path = Path(
405395
onnx_path
406396
if onnx_path
397+
else self.onnx_path
398+
if self.onnx_path
407399
else self.get_onnx_path(
408400
prefill_only,
409401
enable_chunking,

QEfficient/transformers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@
189189
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}
190190

191191
# Architectures that have separate modeling classes written specifically for the prefill-only model
192-
SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"}
192+
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss"}
193193

194194
# Define a transformers layers to QEff layers dictionary
195195
# While onboarding new models make sure to add the new layer maps to this dictionary.

QEfficient/transformers/models/modeling_auto.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from QEfficient.generation.vlm_generation import VisionLanguageGeneration
4141
from QEfficient.transformers.modeling_utils import (
4242
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH,
43-
SPECIALIZED_PREFILL_ONLY_MODEL_ARCH,
43+
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH,
4444
)
4545
from QEfficient.transformers.models.pytorch_transforms import (
4646
BlockedKVAttentionTransform,
@@ -2522,15 +2522,18 @@ def get_seq_len_and_handle_specialized_prefill_model(
25222522

25232523
num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None)
25242524
if num_q_blocks is None:
2525-
block_size = 256
2526-
if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128:
2525+
if (
2526+
prefill_seq_len is None
2527+
or prefill_seq_len % constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE != 0
2528+
or prefill_seq_len < constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE
2529+
):
25272530
raise ValueError(
2528-
f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. "
2531+
f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE}. "
25292532
f"Or set `NUM_Q_BLOCKS` ENV variable"
25302533
f"Received: prefill_seq_len={prefill_seq_len}"
25312534
)
25322535

2533-
num_q_blocks = prefill_seq_len // block_size
2536+
num_q_blocks = prefill_seq_len // constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE
25342537
logger.warning(
25352538
f"Setting NUM_Q_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_Q_BLOCKS` to override"
25362539
)
@@ -2588,31 +2591,28 @@ def export(
25882591
self.model.config, fbs if self.continuous_batching else bs, seq_len
25892592
)
25902593
enable_chunking = kwargs.get("enable_chunking", False)
2591-
if prefill_only:
2592-
if not enable_chunking and self.continuous_batching:
2593-
raise NotImplementedError(
2594-
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
2595-
)
2596-
self.prefill(enable=True, enable_chunking=enable_chunking)
2597-
self.hash_params.pop("retain_full_kv", None)
2598-
seq_len = (
2599-
self.get_seq_len_and_handle_specialized_prefill_model(
2594+
2595+
# TODO: move this to a disaggregated (DA) serving utility class
2596+
if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH:
2597+
if prefill_only:
2598+
if self.continuous_batching and not enable_chunking:
2599+
raise NotImplementedError("Can't enable prefix-caching without chunking")
2600+
self.prefill(enable=True, enable_chunking=enable_chunking)
2601+
self.hash_params.pop("retain_full_kv", None)
2602+
seq_len = self.get_seq_len_and_handle_specialized_prefill_model(
26002603
prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking
26012604
)
2602-
if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
2603-
else seq_len
2604-
)
2605-
kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
2606-
else:
2607-
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
2608-
self.hash_params.pop("prefill_only", None)
2609-
self.hash_params.pop("NUM_Q_BLOCKS", None)
2610-
self.hash_params.pop("NUM_FFN_BLOCKS", None)
2611-
self.hash_params.pop("ENABLE_OPT_SWA", None)
2612-
self.hash_params.pop("chunking", None)
2613-
if kwargs.get("retain_full_kv", False):
2614-
kv_cache_shape[2] = seq_len + self.model.config.sliding_window
2615-
self.hash_params["retain_full_kv"] = True
2605+
kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
2606+
else:
2607+
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
2608+
self.hash_params.pop("prefill_only", None)
2609+
self.hash_params.pop("NUM_Q_BLOCKS", None)
2610+
self.hash_params.pop("NUM_FFN_BLOCKS", None)
2611+
self.hash_params.pop("ENABLE_OPT_SWA", None)
2612+
self.hash_params.pop("chunking", None)
2613+
if kwargs.get("retain_full_kv", False):
2614+
kv_cache_shape[2] = seq_len + self.model.config.sliding_window
2615+
self.hash_params["retain_full_kv"] = True
26162616

26172617
example_inputs = {
26182618
"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
@@ -2942,7 +2942,6 @@ def compile(
29422942
if prefill_only is None or not prefill_only:
29432943
if self.continuous_batching and full_batch_size is None:
29442944
raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
2945-
29462945
else:
29472946
if self.continuous_batching and kv_cache_batch_size is None and full_batch_size is None:
29482947
raise ValueError(

QEfficient/utils/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def get_models_dir():
178178
CCL_MAX_ELEMENTS_LISTS = 5
179179
CCL_START_CTX_LEN = 4096
180180

181+
# used for gpt-oss prefill-only model Q-blocking
182+
GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256
183+
181184

182185
class Constants:
183186
# Export Constants.

tests/transformers/test_causal_lm.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,17 @@ def test_causal_lm_export_and_hash(config, cb, tmp_path):
158158

159159

160160
@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
161-
@pytest.mark.parametrize("subfunc", [False, True], ids=["False", "True"])
161+
@pytest.mark.parametrize("subfunc", [False, True], ids=["non-subfunc", "subfunc"])
162+
@pytest.mark.parametrize("prefill_only", [False, True], ids=["pref+decode", "prefill-only"])
162163
@pytest.mark.parametrize("config", configs, ids=config_ids)
163-
def test_causal_lm_hash_creation(config, cb, subfunc, tmp_path):
164+
def test_causal_lm_hash_creation(config, cb, subfunc, prefill_only, tmp_path):
165+
if config.model_type == "gpt_oss" and prefill_only:
166+
pytest.skip(
167+
"gpt_oss prefill_only mode has different logic to create hash as we have two different ONNX for prefill/decode for this model for disagg serving"
168+
)
164169
model = AutoModelForCausalLM.from_config(config, **model_kwargs)
165170
qeff_model = QEFFAutoModelForCausalLM(model, cb)
166-
qeff_model.export(tmp_path, use_onnx_subfunctions=subfunc)
171+
qeff_model.export(tmp_path, use_onnx_subfunctions=subfunc, prefill_only=prefill_only)
167172
hash_params = {}
168173
hash_params["config"] = qeff_model.model.config.to_diff_dict()
169174
hash_params["peft_config"] = None
@@ -251,12 +256,19 @@ def tmp_cache(tmp_path, monkeypatch):
251256
yield tmp_path
252257

253258

259+
@pytest.mark.parametrize("prefill_only", [False, True], ids=["pref+decode", "prefill_only"])
254260
@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
255261
@pytest.mark.parametrize("config", configs, ids=config_ids)
256-
def test_causal_lm_compile(config, cb, tmp_cache):
262+
def test_causal_lm_compile(config, cb, prefill_only, tmp_cache):
263+
if config.model_type == "gpt_oss":
264+
pytest.skip(
265+
"gpt_oss prefill_only mode has different logic to create hash as we have two different ONNX for prefill/decode for this model for disagg serving"
266+
)
257267
model = AutoModelForCausalLM.from_config(config, **model_kwargs)
258268
qeff_model = QEFFAutoModelForCausalLM(model, cb)
259269
compile_params = {"prefill_seq_len": 8, "ctx_len": 16}
270+
if prefill_only:
271+
compile_params["prefill_only"] = True
260272
if cb:
261273
compile_params["full_batch_size"] = 32
262274
compile_params["batch_size"] = 8

0 commit comments

Comments
 (0)