@@ -1032,7 +1032,8 @@ def patched_modeling_marian_eager_attention_forward(
 
 
 class common_RotaryEmbedding(torch.nn.Module):
-    @torch.no_grad()
+    # This may cause some issues.
+    # @torch.no_grad()
     @patched_dynamic_rope_update
     def forward(self, x, position_ids):
         inv_freq_expanded = (
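
Note on the hunk above, not part of the commit: with `@torch.no_grad()` commented out, the rotary-embedding computation is recorded by autograd unless the caller disables gradients itself. The toy module below (illustrative names only, not the patched class) sketches the behavioural difference and how a caller can restore the previous no-grad behaviour with a context manager.

import torch


class ToyRotary(torch.nn.Module):
    # mirrors the patched class: forward is no longer wrapped in @torch.no_grad()
    def forward(self, x):
        return x * 2.0


toy = ToyRotary()
x = torch.ones(3, requires_grad=True)

y = toy(x)                  # gradients are tracked now that the decorator is gone
assert y.requires_grad

with torch.no_grad():       # a caller can opt back into the old no-grad behaviour
    y = toy(x)
assert not y.requires_grad
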
@@ -1482,3 +1483,109 @@ def forward(
         attn_output = attn_output.reshape(seq_length, -1)
         attn_output = self.proj(attn_output)
         return attn_output
+
+
+try:
+    import transformers.models.qwen3_moe
+
+    patch_qwen3 = True
+except ImportError:
+    patch_qwen3 = False
+
+if patch_qwen3:
+
+    class patched_Qwen3MoeSparseMoeBlock(torch.nn.Module):
+        _PATCHES_ = ["forward", "_forward_expert_loop"]
+        _PATCHED_CLASS_ = (
+            transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
+        )
+
+        def _forward_expert_loop(
+            self,
+            final_hidden_states,
+            expert_mask_idx,
+            hidden_states,
+            routing_weights,
+            expert_idx: int,
+        ):
+            # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
+            idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
+            hidden_dim = hidden_states.shape[-1]
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            expert_current_state = self.experts[expert_idx](current_state)
+            current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
+            return final_hidden_states.index_add(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
+            )
+
+        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+            """ """
+            batch_size, sequence_length, hidden_dim = hidden_states.shape
+            hidden_states = hidden_states.view(-1, hidden_dim)
+            # router_logits: (batch * sequence_length, n_experts)
+            router_logits = self.gate(hidden_states)
+
+            routing_weights = torch.nn.functional.softmax(
+                router_logits, dim=1, dtype=torch.float
+            )
+            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+            if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
+                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+            # we cast back to the input dtype
+            routing_weights = routing_weights.to(hidden_states.dtype)
+
+            final_hidden_states = torch.zeros(
+                (batch_size * sequence_length, hidden_dim),
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+            )
+
+            # One hot encode the selected experts to create an expert mask
+            # this will be used to easily index which expert is going to be sollicitated
+            expert_mask = torch.nn.functional.one_hot(
+                selected_experts, num_classes=self.num_experts
+            ).permute(2, 1, 0)
+
+            # Loop over all available experts in the model
+            # and perform the computation on each expert
+            expert_sum = expert_mask.sum(dim=(-1, -2))
+            # expert_hit = torch.greater(expert_sum, 0).nonzero()
+            # for expert_idx in expert_hit:
+            for expert_idx in range(self.num_experts):
+                # initial code has a squeeze but it is not possible to do that.
+                # expert_mask_idx = expert_mask[expert_idx].squeeze(0)
+                expert_mask_idx = expert_mask[expert_idx]
+                final_hidden_states = torch.cond(
+                    (expert_sum[expert_idx] > 0).item(),
+                    lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop(  # noqa: E501
+                        final_hidden_states,
+                        expert_mask,
+                        hidden_states,
+                        routing_weights,
+                        expert_idx=_i,
+                    ),
+                    lambda final_hidden_states, *args: final_hidden_states.clone(),
+                    [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
+                )
+
+                # if expert_sum[expert_idx] > 0:
+                #     idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+
+                # Index the correct hidden states and compute the expert hidden state for
+                # the current expert. We need to make sure to multiply the output hidden
+                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+                # current_hidden_states = (
+                #     expert_layer(current_state) * routing_weights[top_x, idx, None]
+                # )
+
+                # However `index_add_` only support torch tensors for indexing so we'll use
+                # the `top_x` tensor here.
+                # final_hidden_states.index_add_(
+                #     0, top_x, current_hidden_states.to(hidden_states.dtype)
+                # )
+
+            final_hidden_states = final_hidden_states.reshape(
+                batch_size, sequence_length, hidden_dim
+            )
+            return final_hidden_states, router_logits
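
The patched forward replaces the original block's data-dependent `if expert_sum[expert_idx] > 0:` branch with `torch.cond`, whose two branches must be side-effect free and return tensors with matching shape and dtype; that is presumably why the expert update uses the out-of-place `index_add` and the skip branch returns `final_hidden_states.clone()`. Below is a minimal standalone sketch of the same pattern, with toy shapes and names of my own, assuming a PyTorch version that exposes `torch.cond` (the same requirement as the patch itself).

import torch


def add_rows(acc, idx, values):
    # functional update: returns a new tensor, mirroring final_hidden_states.index_add(...)
    return acc.index_add(0, idx, values)


def skip(acc, idx, values):
    # same output shape/dtype as the other branch, no in-place mutation
    return acc.clone()


acc = torch.zeros(4, 2)
idx = torch.tensor([0, 2])
values = torch.ones(2, 2)
hit = values.sum() > 0  # data-dependent predicate, like expert_sum[expert_idx] > 0

acc = torch.cond(hit, add_rows, skip, [acc, idx, values])
print(acc)  # rows 0 and 2 are filled with ones, the others stay zero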