Commit 40d8b22

qcdipankar, tv-karthikeya, and ochougul committed

Onboarding Qwen3VlMoe (#590)

The Onboarding of Qwen3VlMoe

---------

Signed-off-by: Dipankar Sarkar <quic_dipankar@quicinc.com>
Signed-off-by: Dipankar Sarkar <dipankar@qti.qualcomm.com>
Signed-off-by: vtirumal <vtirumal@qti.qualcomm.com>
Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com>
Co-authored-by: vtirumal <vtirumal@qti.qualcomm.com>
Co-authored-by: Onkar Chougule <168134249+ochougul@users.noreply.github.com>
1 parent 5b3ac38 commit 40d8b22

File tree

9 files changed: +279 −68 lines changed

QEfficient/generation/embedding_handler.py

Lines changed: 8 additions & 1 deletion
@@ -252,14 +252,21 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -
 
         # Process image and text
         inputs = self._processor(images=image, text=prompt, return_tensors="pt")
-
         if (
             hasattr(self._qeff_model.model.config, "model_type")
             and self._qeff_model.model.config.model_type == "qwen2_5_vl"
         ):
             inputs = self._qeff_model.model.prepare_inputs_for_generation(
                 inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0]
             )
+
+        if (
+            hasattr(self._qeff_model.model.config, "model_type")
+            and self._qeff_model.model.config.model_type == "qwen3_vl_moe"
+        ):
+            inputs = self._qeff_model.model.prepare_inputs_for_generation(
+                inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0]
+            )
 
         # Convert to float32 if needed
         if "pixel_values" in inputs:

QEfficient/generation/vlm_generation.py

Lines changed: 32 additions & 4 deletions
@@ -149,6 +149,9 @@ def __init__(
         self.is_qwen2_5_vl = (
             hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl"
         )
+        self.is_qwen3_vl_moe = (
+            hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen3_vl_moe"
+        )
         self.qeff_model = qeff_model
         self.processor = processor
         self.tokenizer = tokenizer
@@ -256,9 +259,10 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):
             outputs, position_ids, generation_len = self.run_prefill(
                 next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1)
             )
-
             if self.is_qwen2_5_vl:
                 _ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id)
+            elif self.is_qwen3_vl_moe:
+                _ = self.update_decode_inputs_qwen3_vl_moe(outputs, position_ids, generation_len, decode_batch_id)
             else:
                 _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)
@@ -283,6 +287,27 @@ def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len,
         self.generation_len[decode_batch_id or slice(None)] = generation_len
         return next_token_id
 
+    def update_decode_inputs_qwen3_vl_moe(self, outputs, position_ids, generation_len, decode_batch_id=None):
+        """
+        Updates the decode input with the generated values.
+        Args:
+            outputs (dict): The outputs of the model.
+            position_ids (array): The position IDs.
+            generation_len (int): The generation length.
+            decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None.
+
+        Returns:
+            next_token_id (array): The next token ID.
+        """
+        next_token_id = self._fetch_next_token_id(outputs)
+
+        # Store the generated values.
+        self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id
+        self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1)
+        self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1)
+        self.generation_len[decode_batch_id or slice(None)] = generation_len
+        return next_token_id
+
     def _execute_chunked_prefill(
         self,
         lang_inputs: Dict[str, np.ndarray],
@@ -583,12 +608,12 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,
         self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length)
         if self.is_qwen2_5_vl:
             self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64)
-
+        if self.is_qwen3_vl_moe:
+            self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64)
         # Create prompt queue
         prompt_queue = deque(vision_prompts)
 
         start = perf_counter()
-
         # Pre-process ALL vision inputs and cache them
         logger.info("Pre-processing all vision inputs...")
         for batch_id in range(min(self.full_batch_size, len(vision_prompts))):
@@ -610,7 +635,6 @@
 
         # Reset prompt queue for prefill
         prompt_queue = deque(vision_prompts)
-
         self.batch_index = None
 
         # Run prefill for all inputs using cached vision
@@ -696,6 +720,10 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation
                     self.update_decode_inputs_qwen2_5_vl(
                         outputs, position_ids_decode, generation_len_final, decode_batch_id
                     )
+                elif self.is_qwen3_vl_moe:
+                    self.update_decode_inputs_qwen3_vl_moe(
+                        outputs, position_ids_decode, generation_len_final, decode_batch_id
+                    )
                 else:
                     self.update_decode_input(outputs, position_ids_decode, generation_len_final, decode_batch_id)
             else:
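One subtlety shared by both update helpers: indexing with `decode_batch_id or slice(None)` treats None and 0 the same way, because `0 or slice(None)` evaluates to `slice(None)`. A standalone illustration:

import numpy as np

# One row per decode slot; batch id 0 falls through to "all rows".
decode_input_ids = np.zeros((4, 1), dtype=np.int64)
decode_input_ids[0 or slice(None)] = 7  # batch id 0 writes every row
decode_input_ids[2 or slice(None)] = 9  # batch id 2 writes only row 2
print(decode_input_ids.ravel())         # [7 7 9 7]

An explicit `slice(None) if decode_batch_id is None else decode_batch_id` would avoid the ambiguity.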

QEfficient/transformers/cache_utils.py

Lines changed: 0 additions & 2 deletions
@@ -327,11 +327,9 @@ def __init__(
         **kwargs,
     ):
         # Remove layer_classes if present to avoid duplicate argument
-        # breakpoint()
         kwargs.pop("layers", None)
         from transformers.cache_utils import Cache  # Import here to avoid circular import
 
-        # breakpoint()
         layers = []
         # If a config is passed, use it to infer the layer types and initialize accordingly
         if len(layers) == 0:

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 3 additions & 2 deletions
@@ -178,6 +178,7 @@
     Qwen3VLMoeTextDecoderLayer,
     Qwen3VLMoeTextModel,
     Qwen3VLMoeTextRMSNorm,
+    Qwen3VLMoeTextSparseMoeBlock,
     Qwen3VLMoeVisionAttention,
     Qwen3VLMoeVisionModel,
 )
@@ -407,7 +408,7 @@
     QEffQwen3VLMoeTextAttention,
     QEffQwen3VLMoeTextDecoderLayer,
     QEffQwen3VLMoeTextModel,
-    # QEffQwen3VLMoeTextSparseMoeBlock,
+    QEffQwen3VLMoeTextSparseMoeBlock,
     QEffQwen3VLMoeVisionAttention,
     QEffQwen3VLMoeVisionModel,
 )
@@ -612,7 +613,7 @@ class KVCacheTransform(ModuleMappingTransform):
         Qwen3VLMoeVisionAttention: QEffQwen3VLMoeVisionAttention,
         Qwen3VLMoeVisionModel: QEffQwen3VLMoeVisionModel,
         Qwen3VLMoeTextModel: QEffQwen3VLMoeTextModel,
-        # Qwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock,
+        Qwen3VLMoeTextSparseMoeBlock: QEffQwen3VLMoeTextSparseMoeBlock,
         # Grok1
         # Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel,
         # Starcoder2
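With these entries un-commented, KVCacheTransform now routes Qwen3VLMoeTextSparseMoeBlock through the QEff implementation at transform time. A simplified sketch of the class-swap idea behind a module-mapping transform (a stand-in, not QEfficient's exact implementation):

import torch.nn as nn

def apply_module_mapping(model: nn.Module, mapping: dict) -> nn.Module:
    # Reassign the class of each module whose type appears in the mapping,
    # so the replacement's forward() runs without copying any weights.
    for module in model.modules():
        if type(module) in mapping:
            module.__class__ = mapping[type(module)]
    return model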

QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py

Lines changed: 1 addition & 1 deletion
@@ -591,7 +591,7 @@ def forward(
         key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
 
-        kv_seq_len = key_states.shape[-2]
+        # kv_seq_len = key_states.shape[-2]
         kv_seq_len = past_key_value.get_seq_length(self.layer_idx, cache_position)
         past_seen_tokens = past_key_value.get_seq_length() if past_key_value is not None else 0
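The dropped assignment was misleading in this setting: with a pre-allocated KV cache, key_states carries only the current chunk, so its sequence dimension says nothing about how many tokens are already cached. Toy numbers:

chunk_len = 1        # decode step processes one new token
already_cached = 37  # tokens written to the cache by earlier steps

kv_seq_len_from_tensor = chunk_len                  # key_states.shape[-2]
kv_seq_len_from_cache = already_cached + chunk_len  # what the cache reports
assert kv_seq_len_from_tensor != kv_seq_len_from_cache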

QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py

Lines changed: 24 additions & 55 deletions
@@ -711,23 +711,29 @@ class QEffQwen3VLMoeTextSparseMoeBlock(Qwen3VLMoeTextSparseMoeBlock):
     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         B, S, H = hidden_states.shape
         T = B * S
-        hidden_states = hidden_states.view(T, H)
-        router_logits = self.gate(hidden_states)  # [T, E]
-        prob = F.softmax(router_logits, -1, dtype=torch.float)
-        top_w, top_i = torch.topk(prob, self.top_k, -1)
-        top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype)
-        gate_proj_up_w = self.experts.gate_up_proj.requires_grad_(False)[top_i.flatten()]
-        down_proj_w = self.experts.down_proj.requires_grad_(False)[top_i.flatten()]
-
-        expert_in = hidden_states.unsqueeze(1).expand(-1, self.top_k, -1).contiguous().view(-1, 1, H)
-        gate_up = torch.bmm(expert_in, gate_proj_up_w)
-        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+        x = hidden_states.view(T, H)
+
+        router_logits = self.gate(x)
+        prob = F.softmax(router_logits, dim=-1, dtype=torch.float)
+        top_w, top_i = torch.topk(prob, self.top_k, dim=-1)
+        top_w = top_w / top_w.sum(dim=1, keepdim=True)
+        top_w = top_w.to(x.dtype)
+        idx = top_i.reshape(-1)
+        w_up = self.experts.gate_up_proj.index_select(0, idx)
+        w_dn = self.experts.down_proj.index_select(0, idx)
+
+        xk = x.unsqueeze(1).expand(-1, self.top_k, -1).contiguous()
+        xk = xk.view(-1, 1, H)
+        gate_up = torch.bmm(xk, w_up)
+        I2 = gate_up.size(-1)
+        half = I2 // 2
+        gate, up = gate_up[..., :half], gate_up[..., half:]
         intermediate = up * self.experts.act_fn(gate)
-        experts_out = torch.bmm(intermediate, down_proj_w)
-        experts_out = experts_out.view(B * S, self.top_k, H)
-        experts_out = experts_out * top_w.unsqueeze(-1)
-        experts_out = experts_out.sum(dim=1)
-        return experts_out.view(B, S, H), router_logits
+        experts_out = torch.bmm(intermediate, w_dn)
+        experts_out = experts_out.view(T, self.top_k, H) * top_w.unsqueeze(-1)
+        experts_out = experts_out.sum(dim=1).view(B, S, H)
+
+        return experts_out, router_logits
 
 
 class QEffQwen3VLMoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):
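For reference, a self-contained sketch of the routing-plus-gathered-bmm pattern the new forward uses, with made-up sizes and random weights (assuming, as the code above does, fused gate/up expert weights of shape [E, H, 2*INTER] split into contiguous halves, down weights of shape [E, INTER, H], and a SiLU activation):

import torch
import torch.nn.functional as F

T, H, INTER, E, top_k = 6, 8, 16, 4, 2    # tokens, hidden, intermediate, experts, top-k
x = torch.randn(T, H)
router_w = torch.randn(E, H)              # router ("gate") weight
gate_up_proj = torch.randn(E, H, 2 * INTER)
down_proj = torch.randn(E, INTER, H)

router_logits = x @ router_w.t()                        # [T, E]
prob = F.softmax(router_logits, dim=-1, dtype=torch.float)
top_w, top_i = torch.topk(prob, top_k, dim=-1)          # [T, top_k]
top_w = (top_w / top_w.sum(dim=1, keepdim=True)).to(x.dtype)

idx = top_i.reshape(-1)                                 # [T*top_k]
w_up = gate_up_proj.index_select(0, idx)                # [T*top_k, H, 2*INTER]
w_dn = down_proj.index_select(0, idx)                   # [T*top_k, INTER, H]

xk = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, 1, H)
gate_up = torch.bmm(xk, w_up)                           # [T*top_k, 1, 2*INTER]
gate, up = gate_up[..., :INTER], gate_up[..., INTER:]   # contiguous split
out = torch.bmm(up * F.silu(gate), w_dn)                # [T*top_k, 1, H]
out = (out.view(T, top_k, H) * top_w.unsqueeze(-1)).sum(dim=1)
print(out.shape)  # torch.Size([6, 8])

Gathering per-token expert weights with index_select keeps every tensor shape static, which is generally friendlier to ahead-of-time compilation than a data-dependent loop over experts.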
@@ -737,44 +743,6 @@ def get_qeff_vision_encoder(self):
     def get_qeff_language_decoder(self):
         return QEffQwen3VLDecoderWrapper(self)
 
-    # def forward(
-    #     self,
-    #     input_ids,
-    #     position_ids,
-    #     past_key_values,
-    #     pixel_values: Optional[torch.FloatTensor] = None,
-    #     image_idx: Optional[torch.LongTensor] = None,
-    #     comp_ctx_lengths: Optional[List[int]] = None,
-    #     batch_index: Optional[torch.LongTensor] = None,
-    #     image_grid_thw=None,
-    # ):
-    #     image_embeds = self.model.visual(pixel_values, grid_thw=image_grid_thw)[0]
-    #     bs = image_grid_thw.shape[0]
-    #     split_size = torch.floor_divide(torch.tensor(image_embeds.size(0)), bs)
-
-    #     inputs_embeds = self.model.get_input_embeddings()(input_ids)
-    #     B, N, C = inputs_embeds.shape
-    #     selected = input_ids == self.model.config.image_token_id
-    #     indices1 = selected.to(torch.int64).cumsum(1) - 1
-    #     indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1)
-    #     indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1)
-    #     image_features_expanded = image_embeds.reshape(-1, C).unsqueeze(0)[indices0, indices1]
-    #     image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds)
-    #     inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds)
-    #     outputs = self.language_model(
-    #         inputs_embeds=inputs_embeds,
-    #         position_ids=position_ids,
-    #         past_key_values=past_key_values,
-    #         comp_ctx_lengths=comp_ctx_lengths,
-    #         batch_index=batch_index,
-    #         use_cache=True,
-    #     )
-    #     logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True)
-    #     hidden_states = outputs.last_hidden_state[torch.arange(position_ids[0].shape[0]).view(-1, 1), logit_index]
-    #     logits = self.lm_head(hidden_states)
-    #     image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0)
-    #     return logits, image_embeds, image_idx, outputs.past_key_values
-
     def get_dummy_inputs(
         self,
         comp_ctx_lengths: Optional[List[int]] = None,
@@ -1036,7 +1004,7 @@ def get_onnx_dynamic_axes(
         lang_dynamic_axes = {
             "input_ids": {0: "batch_size", 1: "seq_len"},
             "position_ids": {1: "batch_size", 2: "seq_len"},
-            "vision_embeds": {0: "batch_size", 1: "vision_size"},
+            "vision_embeds": {0: "vision_batch_size", 1: "vision_size"},
         }
 
         for i in range(num_layers):
@@ -1102,6 +1070,7 @@ def prepare_inputs_for_generation(self, inputs, prefill_seq_len=128, batch_size=
         inputs["position_ids"] = F.pad(
             inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1
         )
+        inputs.pop("image_grid_thw", None)
         return inputs
 
     def get_inputs_info(self):

examples/qwen3_vl.py

Lines changed: 8 additions & 3 deletions
@@ -18,6 +18,7 @@
 qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
     model_id, attn_implementation="eager", kv_offload=True, config=config
 )
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
 processor = AutoProcessor.from_pretrained(model_id)
 ### use skip_vision=True if you want to run only text, else False ###
 skip_vision = False
@@ -60,7 +61,7 @@
     return_tensors="pt",
 )
 inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size)
-streamer = TextStreamer(processor.tokenizer)
+streamer = TextStreamer(tokenizer)
 output = qeff_model.generate(inputs=inputs, generation_len=100)
 print(output.generated_ids)
 print(processor.tokenizer.batch_decode(output.generated_ids))
@@ -77,6 +78,8 @@
     num_devices=4,
     height=354,
     width=536,
+    # height=1024,
+    # width=1024,
     mxfp6_matmul=True,
     mxint8_kv_cache=True,
     aic_enable_depth_first=True,
@@ -85,7 +88,9 @@
 
 ### IMAGE + TEXT ###
 image_url = "https://picsum.photos/id/237/536/354"
-# image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
+# image_url = (
+#     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
+# )
 
 image = Image.open(requests.get(image_url, stream=True).raw)
@@ -122,7 +127,7 @@
     return_tensors="pt",
 )
 inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size)
-streamer = TextStreamer(processor.tokenizer)
+streamer = TextStreamer(tokenizer)
 output = qeff_model.generate(inputs=inputs, generation_len=100)
 print(output.generated_ids)
 print(processor.tokenizer.batch_decode(output.generated_ids))
