
Commit 4dfffd0

fix patch

1 parent 12b7296 commit 4dfffd0
File tree: 1 file changed, +96 -83 lines changed


onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 96 additions & 83 deletions

@@ -1484,92 +1484,105 @@ def forward(
         return attn_output
 
 
-class patched_Qwen3MoeSparseMoeBlock(torch.nn.Module):
-    _PATCHES_ = ["forward", "_forward_expert_loop"]
-    _PATCHED_CLASS_ = transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
+try:
+    import transformers.models.qwen3_moe
 
-    def _forward_expert_loop(
-        self,
-        final_hidden_states,
-        expert_mask_idx,
-        hidden_states,
-        routing_weights,
-        expert_idx: int,
-    ):
-        # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
-        idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
-        hidden_dim = hidden_states.shape[-1]
-        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-        expert_current_state = self.experts[expert_idx](current_state)
-        current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
-        return final_hidden_states.index_add(
-            0, top_x, current_hidden_states.to(hidden_states.dtype)
-        )
+    patch_qwen3 = True
+except ImportError:
+    patch_qwen3 = False
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        """ """
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits: (batch * sequence_length, n_experts)
-        router_logits = self.gate(hidden_states)
-
-        routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        # we cast back to the input dtype
-        routing_weights = routing_weights.to(hidden_states.dtype)
-
-        final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
+if patch_qwen3:
+
+    class patched_Qwen3MoeSparseMoeBlock(torch.nn.Module):
+        _PATCHES_ = ["forward", "_forward_expert_loop"]
+        _PATCHED_CLASS_ = (
+            transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
         )
 
-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(
-            selected_experts, num_classes=self.num_experts
-        ).permute(2, 1, 0)
-
-        # Loop over all available experts in the model
-        # and perform the computation on each expert
-        expert_sum = expert_mask.sum(dim=(-1, -2))
-        # expert_hit = torch.greater(expert_sum, 0).nonzero()
-        # for expert_idx in expert_hit:
-        for expert_idx in range(self.num_experts):
-            expert_mask_idx = expert_mask[expert_idx].squeeze(0)
-            final_hidden_states = torch.cond(
-                (expert_sum[expert_idx] > 0).item(),
-                lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop(  # noqa: E501
-                    final_hidden_states,
-                    expert_mask,
-                    hidden_states,
-                    routing_weights,
-                    expert_idx=_i,
-                ),
-                lambda final_hidden_states, *args: final_hidden_states.clone(),
-                [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
+        def _forward_expert_loop(
+            self,
+            final_hidden_states,
+            expert_mask_idx,
+            hidden_states,
+            routing_weights,
+            expert_idx: int,
+        ):
+            # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
+            idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
+            hidden_dim = hidden_states.shape[-1]
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            expert_current_state = self.experts[expert_idx](current_state)
+            current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
+            return final_hidden_states.index_add(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
             )
 
-            # if expert_sum[expert_idx] > 0:
-            #     idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            # current_hidden_states = (
-            #     expert_layer(current_state) * routing_weights[top_x, idx, None]
-            # )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            # final_hidden_states.index_add_(
-            #     0, top_x, current_hidden_states.to(hidden_states.dtype)
-            # )
-
-        final_hidden_states = final_hidden_states.reshape(
-            batch_size, sequence_length, hidden_dim
-        )
-        return final_hidden_states, router_logits
+        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+            """ """
+            batch_size, sequence_length, hidden_dim = hidden_states.shape
+            hidden_states = hidden_states.view(-1, hidden_dim)
+            # router_logits: (batch * sequence_length, n_experts)
+            router_logits = self.gate(hidden_states)
+
+            routing_weights = torch.nn.functional.softmax(
+                router_logits, dim=1, dtype=torch.float
+            )
+            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+            if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
+                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+            # we cast back to the input dtype
+            routing_weights = routing_weights.to(hidden_states.dtype)
+
+            final_hidden_states = torch.zeros(
+                (batch_size * sequence_length, hidden_dim),
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+            )
+
+            # One hot encode the selected experts to create an expert mask
+            # this will be used to easily index which expert is going to be sollicitated
+            expert_mask = torch.nn.functional.one_hot(
+                selected_experts, num_classes=self.num_experts
+            ).permute(2, 1, 0)
+
+            # Loop over all available experts in the model
+            # and perform the computation on each expert
+            expert_sum = expert_mask.sum(dim=(-1, -2))
+            # expert_hit = torch.greater(expert_sum, 0).nonzero()
+            # for expert_idx in expert_hit:
+            for expert_idx in range(self.num_experts):
+                expert_mask_idx = expert_mask[expert_idx].squeeze(0)
+                final_hidden_states = torch.cond(
+                    (expert_sum[expert_idx] > 0).item(),
+                    lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop(  # noqa: E501
+                        final_hidden_states,
+                        expert_mask,
+                        hidden_states,
+                        routing_weights,
+                        expert_idx=_i,
+                    ),
+                    lambda final_hidden_states, *args: final_hidden_states.clone(),
+                    [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
+                )
+
+                # if expert_sum[expert_idx] > 0:
+                #     idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+
+                # Index the correct hidden states and compute the expert hidden state for
+                # the current expert. We need to make sure to multiply the output hidden
+                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+                # current_hidden_states = (
+                #     expert_layer(current_state) * routing_weights[top_x, idx, None]
+                # )
+
+                # However `index_add_` only support torch tensors for indexing so we'll use
+                # the `top_x` tensor here.
+                # final_hidden_states.index_add_(
+                #     0, top_x, current_hidden_states.to(hidden_states.dtype)
+                # )
+
+            final_hidden_states = final_hidden_states.reshape(
+                batch_size, sequence_length, hidden_dim
+            )
+            return final_hidden_states, router_logits
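
The new code combines two patterns worth isolating: the try/except import guard added by this commit, which only defines patched_Qwen3MoeSparseMoeBlock when the installed transformers version actually ships a qwen3_moe module, and the torch.cond dispatch already present in the patched forward, which replaces a data-dependent Python branch (if an expert received tokens) so the same graph structure can be traced either way. Below is a minimal, self-contained sketch of both patterns; the toy tensors and the accumulate/keep/maybe_accumulate helpers are hypothetical illustrations written for this note, not onnx_diagnostic or transformers APIs, and the expert MLP is omitted.

    import torch

    # Guarded import, mirroring the diff: enable the patch only when the
    # installed transformers version provides the qwen3_moe module.
    try:
        import transformers.models.qwen3_moe  # noqa: F401

        patch_qwen3 = True
    except ImportError:
        patch_qwen3 = False


    # torch.cond dispatch, condensed from _forward_expert_loop: add one
    # expert's contribution only when that expert received tokens, without
    # an eager data-dependent `if`.
    def accumulate(final, mask_idx, hidden, weights):
        # mask_idx: (top_k, tokens) rows of the one-hot mask for one expert
        idx, top_x = torch.nonzero(mask_idx, as_tuple=True)
        current = hidden[top_x] * weights[top_x, idx, None]  # expert MLP omitted
        return final.index_add(0, top_x, current.to(hidden.dtype))


    def keep(final, *args):
        return final.clone()


    def maybe_accumulate(final, mask_idx, hidden, weights):
        hit = (mask_idx.sum() > 0).item()  # Python bool predicate, as in the patch
        return torch.cond(hit, accumulate, keep, [final, mask_idx, hidden, weights])


    # Toy usage: 5 tokens, hidden size 4, top_k = 2; the expert sees token 2 only.
    tokens, hidden_dim, top_k = 5, 4, 2
    hidden = torch.randn(tokens, hidden_dim)
    weights = torch.rand(tokens, top_k)
    mask_idx = torch.zeros(top_k, tokens, dtype=torch.int64)
    mask_idx[0, 2] = 1
    out = maybe_accumulate(torch.zeros_like(hidden), mask_idx, hidden, weights)

Keeping both branches as callables passed to torch.cond is what lets the patched forward loop over every expert with a fixed structure, which appears to be the point of this patch file for torch.export; the import guard simply makes the whole block a no-op on transformers versions without Qwen3 MoE.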
