Skip to content

Commit 6434697

Browse files
committed
Adding disagg mode support to Qwen3Moe (#682)
**Adding disagg support to Qwen3Moe** > Config used PL =128 CL=128*3 <img width="726" height="1077" alt="image" src="https://github.com/user-attachments/assets/7b9afa00-8505-4df5-9a91-68b55e89b416" /> --------- Signed-off-by: Dipankar Sarkar <dipankar@qti.qualcomm.com>
1 parent 27847e1 commit 6434697

File tree

5 files changed

+192
-46
lines changed

5 files changed

+192
-46
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2991,18 +2991,22 @@ def export(
29912991
self.model.config, fbs if self.continuous_batching else bs, seq_len
29922992
)
29932993
enable_chunking = kwargs.get("enable_chunking", False)
2994-
2995-
# TODO: move this to a DA Serving utility class
29962994
if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH:
29972995
if prefill_only:
2998-
if self.continuous_batching and not enable_chunking:
2999-
raise NotImplementedError("Can't enable prefix-caching without chunking")
2996+
if not enable_chunking and self.continuous_batching:
2997+
raise NotImplementedError(
2998+
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
2999+
)
30003000
self.prefill(enable=True, enable_chunking=enable_chunking)
30013001
self.hash_params.pop("retain_full_kv", None)
30023002
seq_len = self.get_seq_len_and_handle_specialized_prefill_model(
30033003
prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking
30043004
)
3005-
kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len
3005+
kv_cache_shape[2] = (
3006+
seq_len + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0)
3007+
if enable_chunking
3008+
else seq_len
3009+
)
30063010
else:
30073011
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
30083012
self.hash_params.pop("prefill_only", None)
@@ -3011,7 +3015,9 @@ def export(
30113015
self.hash_params.pop("ENABLE_OPT_SWA", None)
30123016
self.hash_params.pop("chunking", None)
30133017
if kwargs.get("retain_full_kv", False):
3014-
kv_cache_shape[2] = seq_len + self.model.config.sliding_window
3018+
kv_cache_shape[2] = seq_len + (
3019+
self.model.config.sliding_window if self.model.config.sliding_window is not None else 0
3020+
)
30153021
self.hash_params["retain_full_kv"] = True
30163022

30173023
example_inputs = {

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,7 @@
446446
QEffQwen3Model,
447447
)
448448
from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import (
449+
QEffPrefillChunkedQwen3MoeSparseMoeBlock,
449450
QEffQwen3MoeAttention,
450451
QEffQwen3MoeDecoderLayer,
451452
QEffQwen3MoeForCausalLM,
@@ -728,20 +729,26 @@ class PrefillOnlyTransform(ModuleMappingTransform):
728729

729730
class PrefillOnlyChunkedTransform(ModuleMappingTransform):
730731
_module_mapping = {
732+
# GPT_OSS
731733
QEffGptOssModel: QEffPrefillOnlyGptOssModel,
732734
QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
733735
QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP,
734736
QEffQwen3VLMoeTextSparseMoeBlock: QEffPrefillChunkedQwen3VLMoeTextSparseMoeBlock,
737+
# Qwen3Moe
738+
QEffQwen3MoeSparseMoeBlock: QEffPrefillChunkedQwen3MoeSparseMoeBlock,
735739
}
736740

737741

738742
class RevertPrefillKeepAttentionTransform(ModuleMappingTransform):
739743
_module_mapping = {
744+
# GPT_OSS
740745
QEffGptOssModel: QEffPrefillOnlyGptOssModel,
741746
QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
742747
QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
743748
QEffPrefillOnlyGptOssMLP: QEffGptOssMLP,
744749
QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP,
750+
# Qwen3Moe
751+
QEffPrefillChunkedQwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock,
745752
}
746753

747754

QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ def eager_attention_forward(
104104
key_states = repeat_kv(key, module.num_key_value_groups)
105105

106106
value_states = repeat_kv(value, module.num_key_value_groups)
107-
108107
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
109108
if attention_mask is not None:
110109
attn_weights = torch.where(
@@ -118,53 +117,50 @@ def eager_attention_forward(
118117
return attn_output, attn_weights
119118

120119

121-
class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
122-
def __qeff_init__(self):
123-
self.gate_proj_w = []
124-
self.up_proj_w = []
125-
self.down_proj_w = []
126-
with torch.no_grad():
127-
for e in range(self.num_experts):
128-
self.gate_proj_w.append(self.experts[e].gate_proj.weight.T)
129-
self.up_proj_w.append(self.experts[e].up_proj.weight.T)
130-
self.down_proj_w.append(self.experts[e].down_proj.weight.T)
131-
self.gate_proj_w = torch.stack(self.gate_proj_w)
132-
self.up_proj_w = torch.stack(self.up_proj_w)
133-
self.down_proj_w = torch.stack(self.down_proj_w)
134-
135-
def alt_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
120+
class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
121+
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
136122
B, S, H = hidden_states.shape
137123
T = B * S
138124
x = hidden_states.view(T, H)
139-
140125
router_logits = self.gate(x) # [T, E]
141126
prob = F.softmax(router_logits, -1, dtype=torch.float)
142127
top_w, top_i = torch.topk(prob, self.top_k, -1)
143128
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
144129
top_w /= top_w.sum(-1, keepdim=True)
145-
top_w = top_w.to(x.dtype)
130+
top_w = top_w.to(hidden_states.dtype)
146131
masked_logits = torch.zeros_like(router_logits)
147132
masked_logits.scatter_(1, top_i, top_w)
148-
149133
# Routing weights for each expert [T, E]
150134
routing_weights = masked_logits
151-
152135
# ────────────────── allocate the output tensor ─────
153136
expert_out = x.new_zeros((T, H)) # accumulation buffer
154-
155137
# ───────────────────────── Expert computation loop ─────────────────────────────
156138
for e in range(self.num_experts):
157139
routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1]
158-
W_g, W_u = self.experts[e].gate_proj, self.experts[e].up_proj # [H, I], [H, I]
159-
W_d = self.experts[e].down_proj # [I, H]
160-
gate = W_g(x) # [T, I]
161-
up = W_u(x) # [T, I]
162-
down = W_d(up * self.experts[e].act_fn(gate)) # [T, H]
163-
164-
masked_down = torch.where(routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out))
140+
W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T # [H, I], [H, I]
141+
W_d = self.experts[e].down_proj.weight.T # [I, H]
142+
gate = x @ W_g # [T, I]
143+
up = x @ W_u # [T, I]
144+
down = (up * self.experts[e].act_fn(gate)) @ W_d # [T, H]
145+
masked_down = down * routing_weight
165146
expert_out += masked_down
166147
return expert_out.view(B, S, H), router_logits
167148

149+
150+
class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock):
151+
def __qeff_init__(self):
152+
self.gate_proj_w = []
153+
self.up_proj_w = []
154+
self.down_proj_w = []
155+
with torch.no_grad():
156+
for e in range(self.num_experts):
157+
self.gate_proj_w.append(self.experts[e].gate_proj.weight.T)
158+
self.up_proj_w.append(self.experts[e].up_proj.weight.T)
159+
self.down_proj_w.append(self.experts[e].down_proj.weight.T)
160+
self.gate_proj_w = torch.stack(self.gate_proj_w)
161+
self.up_proj_w = torch.stack(self.up_proj_w)
162+
self.down_proj_w = torch.stack(self.down_proj_w)
163+
168164
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
169165
B, S, H = hidden_states.shape
170166
T = B * S
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
# -----------------------------------------------------------------------------
2+
#
3+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+
# SPDX-License-Identifier: BSD-3-Clause
5+
#
6+
# -----------------------------------------------------------------------------
7+
8+
import time
9+
10+
import numpy as np
11+
import torch
12+
from transformers import AutoConfig, AutoTokenizer
13+
14+
from QEfficient import QEFFAutoModelForCausalLM
15+
from QEfficient.generation.cloud_infer import QAICInferenceSession
16+
17+
model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507" # weights are not required to convert to fp32
18+
prompt = """
19+
Explain quantum computing in simple terms.
20+
"""
21+
config = AutoConfig.from_pretrained(model_id)
22+
tokenizer = AutoTokenizer.from_pretrained(model_id)
23+
PREFILL_SEQ_LEN = 128
24+
CTX_LEN = 128 * 3
25+
26+
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
27+
decode_qpc_path = qeff_model.compile(
28+
prefill_seq_len=1,
29+
ctx_len=CTX_LEN,
30+
num_cores=16,
31+
mxfp6_matmul=True,
32+
mxint8_kv_cache=True,
33+
num_devices=1,
34+
mos=1,
35+
aic_enable_depth_first=True,
36+
num_speculative_tokens=None,
37+
offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step
38+
retain_full_kv=True,
39+
)
40+
41+
# Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68
42+
43+
# prefill_qpc_path = ""
44+
45+
prefill_qpc_path = qeff_model.compile(
46+
prefill_seq_len=PREFILL_SEQ_LEN,
47+
ctx_len=CTX_LEN,
48+
num_cores=16,
49+
mxfp6_matmul=True,
50+
mxint8_kv_cache=True,
51+
num_devices=2,
52+
split_retained_state_io=True,
53+
mos=1,
54+
aic_enable_depth_first=True,
55+
num_speculative_tokens=None,
56+
prefill_only=True,
57+
enable_chunking=True,
58+
# use_onnx_subfunctions=True,
59+
)
60+
61+
62+
inputs = tokenizer(prompt, return_tensors="np", padding=True)
63+
position_ids = inputs["attention_mask"].sum(1, keepdims=True)
64+
generation_len = CTX_LEN - position_ids.max()
65+
padded_len = inputs["input_ids"].shape[1]
66+
num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float
67+
padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len
68+
inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
69+
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
70+
inputs.pop("token_type_ids", None)
71+
inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
72+
inputs.pop("past_key_values", None)
73+
inputs = {k: v.detach().numpy() for k, v in inputs.items()}
74+
75+
76+
prefill_session = QAICInferenceSession(prefill_qpc_path)
77+
decode_session = QAICInferenceSession(decode_qpc_path)
78+
79+
all_outputs = []
80+
for i in range(num_chunks):
81+
chunk_inputs = inputs.copy()
82+
chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
83+
chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
84+
ins = time.time()
85+
qpc_out = prefill_session.run(chunk_inputs)
86+
print(f"time for this run={time.time() - ins}")
87+
for i in range(config.num_hidden_layers):
88+
inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
89+
inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]
90+
91+
all_outputs.append(np.argmax(qpc_out["logits"]))
92+
93+
decode_inputs = {
94+
"input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1),
95+
"position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1,
96+
}
97+
for i in range(config.num_hidden_layers):
98+
decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
99+
decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]
100+
101+
st = time.time()
102+
decode_out = decode_session.run(decode_inputs)
103+
print(f"time for first run of decode with KV as input = {time.time() - st} sec\n")
104+
all_outputs.append(np.argmax(decode_out["logits"]))
105+
pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1
106+
loop_decode_inputs = {
107+
"input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
108+
"position_ids": pos_id,
109+
}
110+
111+
for i in range(config.num_hidden_layers):
112+
loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
113+
loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]
114+
115+
st = time.time()
116+
for i in range(generation_len - 2):
117+
decode_out = decode_session.run(loop_decode_inputs)
118+
all_outputs.append(np.argmax(decode_out["logits"]))
119+
pos_id += 1
120+
for i in range(config.num_hidden_layers):
121+
loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
122+
loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]
123+
124+
loop_decode_inputs.update(
125+
{
126+
"input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
127+
"position_ids": pos_id,
128+
}
129+
)
130+
ft = time.time()
131+
132+
print(f"decode tok/sec={(generation_len - 2) / (ft - st)}")
133+
print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}")

tests/transformers/models/test_disagg_mode.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616
from QEfficient.generation.cloud_infer import QAICInferenceSession
1717
from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers
1818

19-
model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32
20-
19+
# model id based on blocking support and chunking
20+
model_id_blocking = [
21+
"openai/gpt-oss-20b",
22+
]
23+
model_id_chunking = [
24+
"Qwen/Qwen3-30B-A3B-Instruct-2507",
25+
]
2126
prompt2 = """
2227
Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
2328
@@ -32,7 +37,7 @@
3237

3338
@pytest.mark.on_qaic
3439
@pytest.mark.llm_model
35-
@pytest.mark.parametrize("model_id", [model_id])
40+
@pytest.mark.parametrize("model_id", model_id_blocking)
3641
@pytest.mark.parametrize("prompt", prompts)
3742
def test_disagg_mode_prefill(model_id, prompt):
3843
# Run prefill
@@ -93,7 +98,7 @@ def test_disagg_mode_prefill(model_id, prompt):
9398
)
9499

95100
prefill_session = QAICInferenceSession(prefill_qpc_path)
96-
logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)
101+
logits_out_placeholder = np.zeros((1, 1, config.vocab_size), dtype=np.float32)
97102
prefill_session.set_buffers({"logits": logits_out_placeholder})
98103
inputs.pop("past_key_values")
99104
inputs = {k: v.detach().numpy() for k, v in inputs.items()}
@@ -105,10 +110,9 @@ def test_disagg_mode_prefill(model_id, prompt):
105110
assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 5e-2
106111

107112

108-
@pytest.mark.skip(reason="no way of currently testing this without the assert sdk")
109113
@pytest.mark.on_qaic
110114
@pytest.mark.llm_model
111-
@pytest.mark.parametrize("model_id", [model_id])
115+
@pytest.mark.parametrize("model_id", model_id_chunking)
112116
@pytest.mark.parametrize("prompt", prompts)
113117
def test_disagg_mode_prefill_chunked(model_id, prompt):
114118
# Run prefill
@@ -143,7 +147,7 @@ def test_disagg_mode_prefill_chunked(model_id, prompt):
143147
past_key_values = []
144148
for i in range(config.num_hidden_layers):
145149
cache_len = CTX_LEN
146-
pad_shape = (1, 8, cache_len, 64)
150+
pad_shape = (1, config.num_key_value_heads, cache_len, config.head_dim)
147151
past_key = torch.zeros((pad_shape), dtype=torch.float32)
148152
past_value = torch.zeros((pad_shape), dtype=torch.float32)
149153
pkv = (past_key, past_value)
@@ -178,7 +182,7 @@ def test_disagg_mode_prefill_chunked(model_id, prompt):
178182
prefill_session.skip_buffers(
179183
[x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")]
180184
)
181-
logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)
185+
logits_out_placeholder = np.zeros((1, 1, config.vocab_size), dtype=np.float32)
182186
prefill_session.set_buffers({"logits": logits_out_placeholder})
183187
inputs.pop("past_key_values")
184188
inputs = {k: v.detach().numpy() for k, v in inputs.items()}
@@ -195,7 +199,7 @@ def test_disagg_mode_prefill_chunked(model_id, prompt):
195199

196200

197201
@pytest.mark.on_qaic
198-
@pytest.mark.parametrize("model_id", [model_id])
202+
@pytest.mark.parametrize("model_id", model_id_blocking)
199203
@pytest.mark.parametrize("prompt", [prompt1])
200204
def test_disagg_mode_prefill_only_and_decode_only(model_id, prompt):
201205
# Run prefill for original pytorch model
@@ -300,7 +304,7 @@ def test_disagg_mode_prefill_only_and_decode_only(model_id, prompt):
300304
)
301305

302306
prefill_session = QAICInferenceSession(prefill_qpc_path)
303-
logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)
307+
logits_out_placeholder = np.zeros((1, 1, config.vocab_size), dtype=np.float32)
304308
prefill_session.set_buffers({"logits": logits_out_placeholder})
305309
inputs.pop("past_key_values")
306310
inputs = {k: v.detach().numpy() for k, v in inputs.items()}
@@ -366,7 +370,7 @@ def test_disagg_mode_prefill_only_and_decode_only(model_id, prompt):
366370

367371

368372
@pytest.mark.on_qaic
369-
@pytest.mark.parametrize("model_id", [model_id])
373+
@pytest.mark.parametrize("model_id", model_id_blocking)
370374
@pytest.mark.parametrize("prompt", [prompt1])
371375
def test_disagg_mode_prefix_caching(model_id, prompt):
372376
PREFILL_SEQ_LEN = 128
@@ -445,7 +449,7 @@ def prefix_caching_inference(model_id, prefill_qpc_path, decode_qpc_path, prompt
445449
inputs["batch_index"] = np.array([[decode_batch_id]], dtype=np.int64)
446450

447451
prefill_session = QAICInferenceSession(prefill_qpc_path)
448-
logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)
452+
logits_out_placeholder = np.zeros((1, 1, config.vocab_size), dtype=np.float32)
449453
prefill_session.set_buffers({"logits": logits_out_placeholder})
450454
for i in range(num_chunks):
451455
chunk_inputs = inputs.copy()

0 commit comments

Comments
 (0)