@@ -1484,92 +1484,105 @@ def forward(
         return attn_output


-class patched_Qwen3MoeSparseMoeBlock(torch.nn.Module):
-    _PATCHES_ = ["forward", "_forward_expert_loop"]
-    _PATCHED_CLASS_ = transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
+try:
+    import transformers.models.qwen3_moe

-    def _forward_expert_loop(
-        self,
-        final_hidden_states,
-        expert_mask_idx,
-        hidden_states,
-        routing_weights,
-        expert_idx: int,
-    ):
-        # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
-        idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
-        hidden_dim = hidden_states.shape[-1]
-        current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-        expert_current_state = self.experts[expert_idx](current_state)
-        current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
-        return final_hidden_states.index_add(
-            0, top_x, current_hidden_states.to(hidden_states.dtype)
-        )
+    patch_qwen3 = True
+except ImportError:
+    patch_qwen3 = False

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        """ """
-        batch_size, sequence_length, hidden_dim = hidden_states.shape
-        hidden_states = hidden_states.view(-1, hidden_dim)
-        # router_logits: (batch * sequence_length, n_experts)
-        router_logits = self.gate(hidden_states)
-
-        routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
-        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        # we cast back to the input dtype
-        routing_weights = routing_weights.to(hidden_states.dtype)
-
-        final_hidden_states = torch.zeros(
-            (batch_size * sequence_length, hidden_dim),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
+if patch_qwen3:
+
+    class patched_Qwen3MoeSparseMoeBlock(torch.nn.Module):
+        _PATCHES_ = ["forward", "_forward_expert_loop"]
+        _PATCHED_CLASS_ = (
+            transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
         )

-        # One hot encode the selected experts to create an expert mask
-        # this will be used to easily index which expert is going to be sollicitated
-        expert_mask = torch.nn.functional.one_hot(
-            selected_experts, num_classes=self.num_experts
-        ).permute(2, 1, 0)
-
-        # Loop over all available experts in the model
-        # and perform the computation on each expert
-        expert_sum = expert_mask.sum(dim=(-1, -2))
-        # expert_hit = torch.greater(expert_sum, 0).nonzero()
-        # for expert_idx in expert_hit:
-        for expert_idx in range(self.num_experts):
-            expert_mask_idx = expert_mask[expert_idx].squeeze(0)
-            final_hidden_states = torch.cond(
-                (expert_sum[expert_idx] > 0).item(),
-                lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop(  # noqa: E501
-                    final_hidden_states,
-                    expert_mask,
-                    hidden_states,
-                    routing_weights,
-                    expert_idx=_i,
-                ),
-                lambda final_hidden_states, *args: final_hidden_states.clone(),
-                [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
+        def _forward_expert_loop(
+            self,
+            final_hidden_states,
+            expert_mask_idx,
+            hidden_states,
+            routing_weights,
+            expert_idx: int,
+        ):
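+            # Gather the tokens routed to this expert, run them through the expert MLP,
+            # and accumulate the routing-weighted results into `final_hidden_states`.
+            # `idx` is the top-k slot and `top_x` the flattened token index of each
+            # token assigned to `expert_idx`.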
+            # idx, top_x = torch.where(expert_mask_idx.squeeze(0))
+            idx, top_x = torch.nonzero(expert_mask_idx, as_tuple=True)
+            hidden_dim = hidden_states.shape[-1]
+            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+            expert_current_state = self.experts[expert_idx](current_state)
+            current_hidden_states = expert_current_state * routing_weights[top_x, idx, None]
+            return final_hidden_states.index_add(
+                0, top_x, current_hidden_states.to(hidden_states.dtype)
             )

-            # if expert_sum[expert_idx] > 0:
-            # idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            # current_hidden_states = (
-            #     expert_layer(current_state) * routing_weights[top_x, idx, None]
-            # )
-
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            # final_hidden_states.index_add_(
-            #     0, top_x, current_hidden_states.to(hidden_states.dtype)
-            # )
-
-        final_hidden_states = final_hidden_states.reshape(
-            batch_size, sequence_length, hidden_dim
-        )
-        return final_hidden_states, router_logits
+        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+            """ """
+            batch_size, sequence_length, hidden_dim = hidden_states.shape
+            hidden_states = hidden_states.view(-1, hidden_dim)
+            # router_logits: (batch * sequence_length, n_experts)
+            router_logits = self.gate(hidden_states)
+
+            routing_weights = torch.nn.functional.softmax(
+                router_logits, dim=1, dtype=torch.float
+            )
+            routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+            if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
+                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+            # we cast back to the input dtype
+            routing_weights = routing_weights.to(hidden_states.dtype)
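+            # routing_weights: (batch * sequence_length, top_k), aligned with
+            # selected_experts, which holds the chosen expert index for each slot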
+
+            final_hidden_states = torch.zeros(
+                (batch_size * sequence_length, hidden_dim),
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+            )
+
+            # One hot encode the selected experts to create an expert mask
+            # this will be used to easily index which expert is going to be sollicitated
+            expert_mask = torch.nn.functional.one_hot(
+                selected_experts, num_classes=self.num_experts
+            ).permute(2, 1, 0)
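+            # expert_mask: (num_experts, top_k, batch * sequence_length); entry
+            # [e, k, t] is 1 when token t selected expert e in its k-th slot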
+
+            # Loop over all available experts in the model
+            # and perform the computation on each expert
+            expert_sum = expert_mask.sum(dim=(-1, -2))
+            # expert_hit = torch.greater(expert_sum, 0).nonzero()
+            # for expert_idx in expert_hit:
+            for expert_idx in range(self.num_experts):
+                expert_mask_idx = expert_mask[expert_idx].squeeze(0)
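+                # expert_sum[expert_idx] counts how many tokens hit this expert.
+                # The data-dependent branch is expressed with torch.cond, presumably
+                # to keep it traceable for torch.export: the true branch runs the
+                # expert, the false branch is a no-op clone. The `_i=expert_idx`
+                # default binds the loop index at lambda definition time instead of
+                # capturing it late.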
+                final_hidden_states = torch.cond(
+                    (expert_sum[expert_idx] > 0).item(),
+                    lambda final_hidden_states, expert_mask, hidden_states, routing_weights, _i=expert_idx: self._forward_expert_loop(  # noqa: E501
+                        final_hidden_states,
+                        expert_mask,
+                        hidden_states,
+                        routing_weights,
+                        expert_idx=_i,
+                    ),
+                    lambda final_hidden_states, *args: final_hidden_states.clone(),
+                    [final_hidden_states, expert_mask_idx, hidden_states, routing_weights],
+                )
+
+                # if expert_sum[expert_idx] > 0:
+                # idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
+
+                # Index the correct hidden states and compute the expert hidden state for
+                # the current expert. We need to make sure to multiply the output hidden
+                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                # current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+                # current_hidden_states = (
+                #     expert_layer(current_state) * routing_weights[top_x, idx, None]
+                # )
+
+                # However `index_add_` only support torch tensors for indexing so we'll use
+                # the `top_x` tensor here.
+                # final_hidden_states.index_add_(
+                #     0, top_x, current_hidden_states.to(hidden_states.dtype)
+                # )
+
+            final_hidden_states = final_hidden_states.reshape(
+                batch_size, sequence_length, hidden_dim
+            )
+            return final_hidden_states, router_logits