@@ -1482,3 +1482,94 @@ def forward(
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
         return attn_output
+
+
+class patched_Qwen3MoeSparseMoeBlock(torch.nn.Module):
+    _PATCHES_ = ["forward", "_forward_expert_loop"]
+    _PATCHED_CLASS_ = transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
+
+    def _forward_expert_loop(
+        self,
+        final_hidden_states,
+        expert_mask_idx,
+        hidden_states,
+        routing_weights,
+        expert_idx: int,
+    ):
+        # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
+        idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
+        hidden_dim = hidden_states.shape[-1]
+        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+        expert_current_state = self.experts[expert_idx](current_state)
+        current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
+        return final_hidden_states.index_add(
+            0, top_x, current_hidden_states.to(hidden_states.dtype)
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
+            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim),
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be solicited
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_experts
+        ).permute(2, 1, 0)
+
+        # Loop over all available experts in the model
+        # and perform the computation on each expert
+        expert_sum = expert_mask.sum(dim=(-1, -2))
+        # expert_hit = torch.greater(expert_sum, 0).nonzero()
+        # for expert_idx in expert_hit:
+        for expert_idx in range(self.num_experts):
+            expert_mask_idx = expert_mask[expert_idx].squeeze(0)
+            final_hidden_states = torch.cond(
+                (expert_sum[expert_idx] > 0).item(),
+                lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop(  # noqa: E501
+                    final_hidden_states,
+                    expert_mask,
+                    hidden_states,
+                    routing_weights,
+                    expert_idx=_i,
+                ),
+                lambda final_hidden_states, *args: final_hidden_states.clone(),
+                [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
+            )
+
+            # if expert_sum[expert_idx] > 0:
+            #     idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            #     current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            #     current_hidden_states = (
+            #         expert_layer(current_state) * routing_weights[top_x, idx, None]
+            #     )
+
+            # However `index_add_` only supports torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            #     final_hidden_states.index_add_(
+            #         0, top_x, current_hidden_states.to(hidden_states.dtype)
+            #     )
+
+        final_hidden_states = final_hidden_states.reshape(
+            batch_size, sequence_length, hidden_dim
+        )
+        return final_hidden_states, router_logits
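
For reference, here is a minimal standalone sketch (not part of the diff) of the torch.cond pattern the patched forward relies on. The predicate decides whether an expert contributes, both branches receive the same operands and return a tensor of the same shape and dtype, and the no-op branch returns a clone rather than the input itself, mirroring the patch above. All names in the sketch are illustrative and do not come from the commit or from transformers.

import torch


def add_expert_contribution(acc, mask, x, w):
    # Stand-in for _forward_expert_loop: return an updated accumulator.
    return acc + x * w


def skip_expert(acc, mask, x, w):
    # "Expert not hit" branch: same operands, same output shape and dtype.
    return acc.clone()


acc = torch.zeros(4, 8)
mask = torch.tensor([1, 0, 1, 0])
x = torch.randn(4, 8)
w = torch.full((4, 8), 0.5)

# In eager mode a Python bool predicate (as produced by .item()) simply runs the
# selected branch; under torch.export with scalar outputs captured, the predicate
# stays symbolic and both branches are recorded in the graph.
hit = (mask.sum() > 0).item()
out = torch.cond(hit, add_expert_contribution, skip_expert, [acc, mask, x, w])
print(out.shape)  # torch.Size([4, 8])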