Skip to content

Commit 9328991

Browse files
tv-karthikeyaqc and dipankar
authored and committed
updating Subfn, Prefill only logic in Disagg mode (#820)
Added Support for Subfn for Qwen 3 VL dense, MOE. Updated prefill only logic for disagg mode --------- Signed-off-by: vtirumal <vtirumal@qti.qualcomm.com> Signed-off-by: Dipankar Sarkar <dipankar@qti.qualcomm.com>
1 parent f1a35f0 commit 9328991

File tree

7 files changed

+142
-131
lines changed

7 files changed

+142
-131
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 24 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,16 +1090,16 @@ def export(
10901090
Path to the generated ONNX graph file for the language decoder.
10911091
"""
10921092
if prefill_only:
1093-
if prefill_seq_len > 1:
1094-
if not enable_chunking and self.continuous_batching:
1095-
raise NotImplementedError(
1096-
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
1097-
)
1098-
self.hash_params["prefill_only"] = True
1099-
self.prefill(enable=True, enable_chunking=enable_chunking)
1100-
else:
1101-
self.hash_params["prefill_only"] = False
1102-
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
1093+
assert prefill_seq_len > 1
1094+
if not enable_chunking and self.continuous_batching:
1095+
raise NotImplementedError(
1096+
"Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
1097+
)
1098+
self.hash_params["prefill_only"] = True
1099+
self.prefill(enable=True, enable_chunking=enable_chunking)
1100+
else:
1101+
self.hash_params["prefill_only"] = False
1102+
self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
11031103

11041104
return self._export(
11051105
inputs,
@@ -1286,24 +1286,6 @@ def onnx_path(self):
12861286
"""
12871287
return [self.vision_model.onnx_path, self.lang_model.onnx_path]
12881288

1289-
@property
1290-
def qpc_path(self):
1291-
"""
1292-
Get the QPC paths for the vision and language model components.
1293-
1294-
Returns
1295-
-------
1296-
Union[List[str], str, None]
1297-
A list containing both QPC paths if both are compiled, or just one if only one is,
1298-
or None if neither is compiled.
1299-
"""
1300-
if self.vision_model.qpc_path and self.lang_model.qpc_path:
1301-
return [self.vision_model.qpc_path, self.lang_model.qpc_path]
1302-
elif self.vision_model.qpc_path:
1303-
return self.vision_model.qpc_path
1304-
else:
1305-
return self.lang_model.qpc_path
1306-
13071289
def export(
13081290
self,
13091291
export_dir: Optional[str] = None,
@@ -1416,7 +1398,7 @@ def compile(
14161398
skip_vision: Optional[bool] = False,
14171399
skip_lang: Optional[bool] = False,
14181400
use_onnx_subfunctions: bool = False,
1419-
prefill_only=False,
1401+
prefill_only=None,
14201402
enable_chunking=False,
14211403
**compiler_options,
14221404
) -> str:
@@ -1536,11 +1518,7 @@ def compile(
15361518
if lang_onnx_path:
15371519
self.lang_model.onnx_path = lang_onnx_path
15381520

1539-
if (
1540-
(self.vision_model.onnx_path is None and vision_onnx_path is None)
1541-
or (self.lang_model.onnx_path is None and lang_onnx_path is None)
1542-
or prefill_only
1543-
):
1521+
if vision_onnx_path is None or lang_onnx_path is None:
15441522
self.export(
15451523
use_onnx_subfunctions=use_onnx_subfunctions,
15461524
skip_vision=skip_vision,
@@ -1554,8 +1532,9 @@ def compile(
15541532
compiler_options.pop("continuous_batching", None)
15551533
compiler_options.pop("kv_cache_batch_size", None)
15561534
compiler_options.pop("full_batch_size", None)
1535+
self.qpc_paths = {}
15571536
if not skip_vision:
1558-
self.vision_model._compile(
1537+
vision_qpc_path = self.vision_model._compile(
15591538
compile_dir=compile_dir,
15601539
compile_only=True,
15611540
specializations=specializations["vision"],
@@ -1568,6 +1547,7 @@ def compile(
15681547
use_onnx_subfunctions=use_onnx_subfunctions,
15691548
**compiler_options,
15701549
)
1550+
self.qpc_paths["vision_qpc_path"] = vision_qpc_path
15711551

15721552
# Custom NPI file options
15731553
if hasattr(self.model, "get_npi_file") and "node_precision_info" not in compiler_options:
@@ -1592,16 +1572,17 @@ def compile(
15921572
if ("vision_embeds" in output_name or "deepstack_features" in output_name)
15931573
else kv_cache_dtype
15941574
)
1595-
15961575
if prefill_only:
1597-
if prefill_seq_len > 1:
1598-
specializations = specializations["lang"][:1] # prefill
1599-
else:
1600-
specializations = specializations["lang"][-1:] # decoder
1576+
specializations = specializations["lang"][:1]
1577+
qpc_key = "lang_prefill_qpc_path"
1578+
elif prefill_seq_len == 1:
1579+
specializations = specializations["lang"][-1:]
1580+
qpc_key = "lang_decode_qpc_path"
16011581
else:
16021582
specializations = specializations["lang"]
1583+
qpc_key = "lang_qpc_path"
16031584

1604-
self.lang_model._compile(
1585+
lang_qpc_path = self.lang_model._compile(
16051586
compile_dir=compile_dir,
16061587
compile_only=True,
16071588
retained_state=True,
@@ -1615,9 +1596,8 @@ def compile(
16151596
use_onnx_subfunctions=use_onnx_subfunctions,
16161597
**compiler_options,
16171598
)
1618-
if skip_vision and prefill_only: # for disagg serving
1619-
return self.lang_model.qpc_path
1620-
return self.qpc_path
1599+
self.qpc_paths.update({qpc_key: lang_qpc_path})
1600+
return self.qpc_paths
16211601

16221602
def generate(
16231603
self,

QEfficient/transformers/models/qwen3_vl/modeling_qwen3_vl.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#
66
# -----------------------------------------------------------------------------
77
import math
8-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
8+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
99

1010
import torch
1111
import torch.nn as nn
@@ -562,6 +562,15 @@ def __init__(self, model):
562562
self.model = model
563563
self.model.vision_model = self.model.visual
564564

565+
def get_submodules_for_export(self) -> Type[nn.Module]:
566+
"""
567+
Return the set of class used as the repeated layer across the model for subfunction extraction.
568+
Notes:
569+
This method should return the *class object* (not an instance).
570+
Downstream code can use this to find/build subfunctions for repeated blocks.
571+
"""
572+
return {self.model.visual.blocks[0].__class__}
573+
565574
def forward(self, pixel_values, image_grid_thw):
566575
image_embeds, deepstack_feature_lists = self.model.visual(pixel_values, grid_thw=image_grid_thw)
567576
bs = image_grid_thw.shape[0]
@@ -580,6 +589,15 @@ def __init__(self, model):
580589
self.model = model
581590
self.language_model = self.model.model
582591

592+
def get_submodules_for_export(self) -> Type[nn.Module]:
593+
"""
594+
Return the set of class used as the repeated layer across the model for subfunction extraction.
595+
Notes:
596+
This method should return the *class object* (not an instance).
597+
Downstream code can use this to find/build subfunctions for repeated blocks.
598+
"""
599+
return {QEffQwen3VLTextDecoderLayer}
600+
583601
def forward(
584602
self,
585603
input_ids,

QEfficient/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#
66
# -----------------------------------------------------------------------------
77
import math
8-
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
8+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
99

1010
import torch
1111
import torch.nn as nn
@@ -629,7 +629,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
629629
router_logits = self.gate(x) # [T, E]
630630
prob = F.softmax(router_logits, dim=-1, dtype=hidden_states.dtype)
631631
top_w, top_i = torch.topk(prob, self.top_k, dim=-1) # [T, k], [T, k]
632-
top_w = top_w / top_w.sum(dim=-1, keepdim=True)
632+
top_w = top_w / torch.einsum("bi->b", top_w)[:, None]
633633
top_w = top_w.to(hidden_states.dtype)
634634

635635
# gate_up_proj: [E, H, 2I], down_proj: [E, I, H]
@@ -711,6 +711,15 @@ def __init__(self, model):
711711
self.model = model
712712
self.model.vision_model = self.model.visual
713713

714+
def get_submodules_for_export(self) -> Type[nn.Module]:
715+
"""
716+
Return the set of class used as the repeated layer across the model for subfunction extraction.
717+
Notes:
718+
This method should return the *class object* (not an instance).
719+
Downstream code can use this to find/build subfunctions for repeated blocks.
720+
"""
721+
return {self.model.visual.blocks[0].__class__}
722+
714723
def forward(self, pixel_values, image_grid_thw):
715724
image_embeds, deepstack_feature_lists = self.model.visual(pixel_values, grid_thw=image_grid_thw)
716725
bs = image_grid_thw.shape[0]
@@ -727,7 +736,16 @@ class QEffQwen3VLDecoderWrapper(nn.Module):
727736
def __init__(self, model):
728737
super().__init__()
729738
self.model = model
730-
self.language_model = self.model.model
739+
self.language_model = self.model.model.language_model
740+
741+
def get_submodules_for_export(self) -> Type[nn.Module]:
742+
"""
743+
Return the set of class used as the repeated layer across the model for subfunction extraction.
744+
Notes:
745+
This method should return the *class object* (not an instance).
746+
Downstream code can use this to find/build subfunctions for repeated blocks.
747+
"""
748+
return {QEffQwen3VLMoeTextDecoderLayer}
731749

732750
def forward(
733751
self,
@@ -790,7 +808,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
790808
router_logits = self.gate(x)
791809
prob = F.softmax(router_logits, dim=-1, dtype=torch.float)
792810
top_w, top_i = torch.topk(prob, self.top_k, dim=-1)
793-
top_w = top_w / top_w.sum(dim=1, keepdim=True)
811+
top_w = top_w / torch.einsum("bi->b", top_w)[:, None]
794812
top_w = top_w.to(x.dtype)
795813
idx = top_i.reshape(-1)
796814
w_up = self.experts.gate_up_proj.index_select(0, idx)
@@ -805,9 +823,8 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
805823
intermediate = up * self.experts.act_fn(gate)
806824
experts_out = torch.bmm(intermediate, w_dn)
807825
experts_out = experts_out.view(T, self.top_k, H) * top_w.unsqueeze(-1)
808-
experts_out = experts_out.sum(dim=1).view(B, S, H)
809-
810-
return experts_out, router_logits
826+
experts_out = torch.einsum("bnd->bd", experts_out)
827+
return experts_out.view(B, S, H), router_logits
811828

812829

813830
class QEffQwen3VLMoeForConditionalGeneration(Qwen3VLMoeForConditionalGeneration):

0 commit comments

Comments (0)