66
77from llmcompressor .modeling .moe_context import MoECalibrationModule
88from llmcompressor .utils .dev import skip_weights_initialize
9+ import torch .nn .functional as F
910
1011
1112@MoECalibrationModule .register ("Qwen3_5MoeSparseMoeBlock" )
@@ -32,6 +33,55 @@ def __init__(
3233 self .shared_expert_gate = original .shared_expert_gate
3334 self .gate = original .gate
3435 self .experts = SequentialQwen3VLMoeTextExperts (text_config , original .experts )
36+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Apply the routed experts and the gated shared expert to every token.

    The input is flattened to one row per token, each token is dispatched to
    the experts chosen by ``self.gate``, the weighted expert outputs are
    summed per token, and the sigmoid-gated shared-expert output is added.

    When ``self.calibrate_all_experts`` is true, every expert is run on all
    tokens, but only the outputs of the tokens routed to that expert
    contribute to the result. Otherwise each expert only receives the tokens
    routed to it.

    :param hidden_states: 3-D tensor unpacked as
        ``(batch_size, sequence_length, hidden_dim)``
    :return: tensor of shape ``(batch_size, sequence_length, hidden_dim)``
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    # reshape (rather than view) so non-contiguous inputs are flattened
    # instead of raising; for contiguous inputs this is still a view
    hidden_states_reshaped = hidden_states.reshape(-1, hidden_dim)

    # router: returns (router_logits, router_scores, router_indices)
    _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)

    # expert mask: (num_experts, top_k, num_tokens)
    expert_mask = F.one_hot(
        selected_experts, num_classes=self.num_experts
    ).permute(2, 1, 0)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )

    for expert_idx, expert_layer in enumerate(self.experts):
        # idx: position within the top-k selection, token_idx: token row
        idx, token_idx = torch.where(expert_mask[expert_idx])

        # the expert is invoked even when no token was routed to it
        if self.calibrate_all_experts:
            expert_out = expert_layer(hidden_states_reshaped)[token_idx]
        else:
            expert_out = expert_layer(hidden_states_reshaped[token_idx])

        if len(token_idx) > 0:
            # scale each routed token's output by its routing weight
            current_hidden_states = (
                expert_out * routing_weights[token_idx, idx, None]
            )
            # accumulate: a token may be routed to several experts
            final_hidden_states.index_add_(
                0,
                token_idx,
                current_hidden_states.to(hidden_states.dtype),
            )

    # shared expert
    shared_expert_output = self.shared_expert(hidden_states_reshaped)
    shared_expert_output = (
        F.sigmoid(self.shared_expert_gate(hidden_states_reshaped))
        * shared_expert_output
    )
    final_hidden_states = final_hidden_states + shared_expert_output

    final_hidden_states = final_hidden_states.reshape(
        batch_size, sequence_length, hidden_dim
    )
    return final_hidden_states
3585
def restore(self, original: torch.nn.Module) -> torch.nn.Module:
    """Return ``original`` as-is; no state is copied back into it.

    :param original: module passed in by the caller
    :return: the same ``original`` object
    """
    restored = original
    return restored
@@ -42,6 +92,7 @@ def __init__(self, config, original):
4292 from transformers .models .qwen3_5_moe .modeling_qwen3_5_moe import (
4393 Qwen3_5MoeMLP ,
4494 )
95+ from compressed_tensors .offload import disable_onloading
4596
4697 self .num_experts = original .gate_up_proj .shape [0 ]
4798 with skip_weights_initialize ():
@@ -56,9 +107,13 @@ def __init__(self, config, original):
56107
57108 intermediate_size = original .down_proj .shape [- 1 ]
58109
110+ with disable_onloading ():
111+ gate_up_data = original .gate_up_proj .data # [num_experts, 2*inter, hidden]
112+ down_data = original .down_proj .data # [num_experts, hidden, inter]
113+
59114 for i in range (self .num_experts ):
60- gate_up = original . gate_up_proj [i ]
61- down = original . down_proj [i ]
115+ gate_up = gate_up_data [i ]
116+ down = down_data [i ]
62117
63118 gate_proj = gate_up [:intermediate_size , :]
64119 up_proj = gate_up [intermediate_size :, :]
0 commit comments