
Commit 9032320

Patch for Qwen3
1 parent a51448e commit 9032320

3 files changed: +108 −0 lines changed

onnx_diagnostic/helpers/config_helper.py

Lines changed: 6 additions & 0 deletions
@@ -119,4 +119,10 @@ def default_num_hidden_layers():
     It is lower when the unit tests are running
     when ``UNITTEST_GOING=1``.
     """
+    import torch
+
+    if torch.cuda.is_available():
+        capa = torch.cuda.get_device_capability(0)
+        if capa[0] < 9:
+            return 2
     return 2 if os.environ.get("UNITTEST_GOING", "0") == "1" else 4
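
The added branch lowers the default number of hidden layers to 2 on CUDA devices with a compute capability below 9.0 (e.g. Ampere or Ada, anything short of Hopper), presumably to keep the default test models small on such GPUs. A minimal sketch, separate from the patch, of the capability query the check relies on:

import torch

if torch.cuda.is_available():
    # Returns a (major, minor) pair, e.g. (8, 6) for an Ampere RTX 3090
    # or (9, 0) for an H100; the patch keys off the major version only.
    major, minor = torch.cuda.get_device_capability(0)
    print(f"device 0 compute capability: {major}.{minor}")
else:
    print("no CUDA device; the helper falls through to the UNITTEST_GOING check")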

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 91 additions & 0 deletions
@@ -1482,3 +1482,94 @@ def forward(
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
         return attn_output
+
+
+class patched_Qwen3MoeSparseMoeBlock(torch.nn.Module):
+    _PATCHES_ = ["forward", "_forward_expert_loop"]
+    _PATCHED_CLASS_ = transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
+
+    def _forward_expert_loop(
+        self,
+        final_hidden_states,
+        expert_mask_idx,
+        hidden_states,
+        routing_weights,
+        expert_idx: int,
+    ):
+        # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
+        idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
+        hidden_dim = hidden_states.shape[-1]
+        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+        expert_current_state = self.experts[expert_idx](current_state)
+        current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
+        return final_hidden_states.index_add(
+            0, top_x, current_hidden_states.to(hidden_states.dtype)
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One-hot encode the selected experts to create an expert mask;
+        # this will be used to easily index which expert is going to be solicited.
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)
+
+        # Loop over all available experts in the model
+        # and perform the computation on each expert.
+        expert_sum = expert_mask.sum(dim=(-1, -2))
+        # expert_hit = torch.greater(expert_sum, 0).nonzero()
+        # for expert_idx in expert_hit:
+        for expert_idx in range(self.num_experts):
+            expert_mask_idx = expert_mask[expert_idx].squeeze(0)
+            final_hidden_states = torch.cond(
+                (expert_sum[expert_idx] > 0).item(),
+                lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop(  # noqa: E501
+                    final_hidden_states,
+                    expert_mask,
+                    hidden_states,
+                    routing_weights,
+                    expert_idx=_i,
+                ),
+                lambda final_hidden_states, *args: final_hidden_states.clone(),
+                [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
+            )
+
+            # if expert_sum[expert_idx] > 0:
+            #     idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2).
+            # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            # current_hidden_states = (
+            #     expert_layer(current_state) * routing_weights[top_x, idx, None]
+            # )
+
+            # However `index_add_` only supports torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            # final_hidden_states.index_add_(
+            #     0, top_x, current_hidden_states.to(hidden_states.dtype)
+            # )
+
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+        return final_hidden_states, router_logits
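
This is the core of the patch: the upstream Qwen3MoeSparseMoeBlock skips experts that received no tokens with a plain Python `if`, a data-dependent branch that `torch.export` cannot trace. The patched forward replaces it with `torch.cond`, which captures both branches in the exported graph, and swaps the in-place `index_add_` for the functional `index_add` so each branch returns a fresh tensor instead of mutating an operand. A minimal sketch of the same pattern, with `MaybeScale` a made-up module for illustration:

import torch

class MaybeScale(torch.nn.Module):
    def forward(self, x):
        # Both branches must accept the same operands and return tensors of
        # matching shape and dtype; torch.cond traces both of them.
        return torch.cond(
            x.sum() > 0,          # data-dependent predicate (boolean tensor)
            lambda t: t * 2.0,    # true branch: returns a new tensor
            lambda t: t.clone(),  # false branch: clone() because a branch
                                  # may not return an alias of its input
            [x],
        )

ep = torch.export.export(MaybeScale(), (torch.randn(4),))
print(ep.graph)

Both branches appear in the exported graph, so the choice is deferred to runtime rather than burned in at trace time; this is why the patch can keep the per-expert loop while remaining exportable.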

onnx_diagnostic/torch_models/validate.py

Lines changed: 11 additions & 0 deletions
@@ -1,6 +1,7 @@
 import datetime
 import inspect
 import os
+import pprint
 import sys
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import time
@@ -467,6 +468,16 @@ def validate_model(
             f"inputs2 is True but second set is missing in data for "
             f"model id {model_id!r}: {sorted(data)}"
         )
+    if dump_folder:
+        with open(os.path.join(dump_folder, "model_config.txt"), "w") as f:
+            f.write(f"model_id: {model_id}\n------\n")
+            f.write(
+                pprint.pformat(
+                    data["configuration"]
+                    if type(data["configuration"]) is dict
+                    else data["configuration"].to_dict()
+                )
+            )
 
     if exporter == "modelbuilder":
         # Models used with ModelBuilder do not like batch size > 1.
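
The validation change persists the model configuration alongside the other dump artifacts, normalizing it to a dict first. A small sketch of the same normalization, where `config_as_text` is a hypothetical helper (a `transformers` config object exposes `to_dict()`; a plain dict is formatted directly):

import pprint

def config_as_text(configuration):
    # Hypothetical helper mirroring the dumped format: plain dicts are
    # formatted as-is, anything else is assumed to expose to_dict().
    d = configuration if type(configuration) is dict else configuration.to_dict()
    return pprint.pformat(d)

print(config_as_text({"model_type": "qwen3_moe", "num_hidden_layers": 2}))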
