@@ -23,6 +23,11 @@ class RoutingCFG:
     router_logits_bsz1: torch.Tensor
     routing_weights_bsz1: torch.Tensor
     selected_experts_bsz1: torch.Tensor
+    e_score_correction_bias: torch.Tensor | None
+    routed_scaling_factor: float | None
+    n_group: int | None
+    topk_group: int | None
+
 
 def routing(bsz, cfg, y, params):
     activate_all_experts = params.get("activate_all_experts")
@@ -50,6 +55,75 @@ def routing(bsz, cfg, y, params):
     return selected_experts, routing_weights
 
 
+# TODO: Optimize (for DS3)
+def routing_ds3(bsz, cfg, y, params):
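+    # DeepSeek-V3 style routing: sigmoid scores with a per-expert correction bias that
+    # only affects expert selection, group-limited top-k, then renormalization and
+    # scaling by routed_scaling_factor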
+    activate_all_experts = params.get("activate_all_experts")
+    router_logits = torch.matmul(y, cfg.gate_tensor)
+
+    scores = router_logits.sigmoid()
+    scores_for_choice = scores.view(-1, cfg.num_experts) + cfg.e_score_correction_bias.unsqueeze(0)
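+    # Keep only the top cfg.topk_group expert groups, ranked by the sum of each group's
+    # two highest bias-corrected scores; experts outside those groups are masked out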
+    group_scores = (
+        scores_for_choice.view(-1, cfg.n_group, cfg.num_experts // cfg.n_group)
+        .topk(2, dim = -1)[0]
+        .sum(dim = -1)
+    )
+    group_idx = torch.topk(group_scores, k = cfg.topk_group, dim = -1, sorted = False)[1]
+    group_mask = torch.zeros_like(group_scores)
+    group_mask.scatter_(1, group_idx, 1)
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(-1, cfg.n_group, cfg.num_experts // cfg.n_group)
+        .reshape(-1, cfg.num_experts)
+    )
+    scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+
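+    # Select the top-k experts from the masked scores, but take the routing weights from
+    # the unbiased sigmoid scores; renormalize and apply the routed scaling factor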
+    topk_indices = torch.topk(
+        scores_for_choice,
+        k = cfg.num_experts if activate_all_experts else cfg.num_experts_per_tok,
+        dim = -1,
+        sorted = False
+    )[1]
+    topk_weights = scores.gather(1, topk_indices)
+    denominator = topk_weights.sum(dim = -1, keepdim = True) + 1e-20
+    topk_weights /= denominator
+    topk_weights = topk_weights * cfg.routed_scaling_factor
+    return topk_indices, topk_weights
+
+
+def routing_dots(bsz, cfg, y, params):
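+    # Simplified DS3-style routing for the single-group case (n_group == topk_group == 1),
+    # where the group masking in routing_ds3 is a no-op; presumably intended for
+    # dots.llm1-style models, hence the name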
+    activate_all_experts = params.get("activate_all_experts")
+
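+    # bsz == 1 fast path: reuse the preallocated buffers in cfg and work in-place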
+    if bsz == 1 and not activate_all_experts:
+        torch.matmul(y, cfg.gate_tensor, out = cfg.router_logits_bsz1)
+        cfg.router_logits_bsz1 += cfg.e_score_correction_bias
+        torch.topk(
+            cfg.router_logits_bsz1,
+            cfg.num_experts_per_tok,
+            dim = -1,
+            out = (cfg.routing_weights_bsz1, cfg.selected_experts_bsz1),
+            sorted = False
+        )
+        # TODO: Custom kernel for sigmoid normalization
+        cfg.routing_weights_bsz1.sigmoid_()
+        factor = cfg.routed_scaling_factor / (cfg.routing_weights_bsz1.sum(dim = -1, keepdim = True) + 1e-20)
+        cfg.routing_weights_bsz1 *= factor
+        return cfg.selected_experts_bsz1, cfg.routing_weights_bsz1
+
+    else:
+        router_logits = torch.matmul(y, cfg.gate_tensor)
+        router_logits += cfg.e_score_correction_bias
+        routing_weights, selected_experts = torch.topk(
+            router_logits,
+            cfg.num_experts if activate_all_experts else cfg.num_experts_per_tok,
+            dim = -1
+        )
+        # TODO: Custom kernel for sigmoid normalization
+        routing_weights.sigmoid_()
+        factor = cfg.routed_scaling_factor / (routing_weights.sum(dim = -1, keepdim = True) + 1e-20)
+        routing_weights *= factor
+        return selected_experts, routing_weights
+
+
 @dataclass
 class ExpertsCFG:
     yh: torch.Tensor
@@ -77,6 +151,10 @@ def __init__(
         out_dtype: torch.dtype = None,
         activation_fn: str = "silu",
         interm_dtype: torch.dtype = None,
+        deepseekv3_routing: bool = False,
+        routed_scaling_factor: float | None = None,
+        n_group: int | None = None,
+        topk_group: int | None = None,
         shared_experts: MLP | GatedMLP | None = None
     ):
         super().__init__(config, key, None)
@@ -89,6 +167,11 @@ def __init__(
         self.num_experts_per_tok = num_experts_per_tok
         self.hidden_size = hidden_size
 
+        self.deepseekv3_routing = deepseekv3_routing
+        self.routed_scaling_factor = routed_scaling_factor
+        self.n_group = n_group
+        self.topk_group = topk_group
+
         self.routing_gate = Linear(
             config = config,
             key = f"{key}.{key_routing_gate}",
@@ -152,6 +235,8 @@ def __init__(
         self.routing_cfg = None
         self.experts_cfg = None
 
+        self.e_score_correction_bias = None
+
         self.shared_experts = shared_experts
         if shared_experts is not None:
             self.register_submodule(shared_experts)
@@ -161,6 +246,9 @@ def __init__(
     def load(self, device: torch.Device, **kwargs):
         super().load(device, **kwargs)
 
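+        # The correction bias is optional, so checkpoints without DS3-style routing still load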
+        self.e_score_correction_bias = \
+            self.config.stc.get_tensor(self.key + ".gate.e_score_correction_bias", self.device, optional = True)
+
         # Test if experts can be fused
         num_exl3_tensors = 0
         num_nonexl3_tensors = 0
@@ -189,7 +277,11 @@ def load(self, device: torch.Device, **kwargs):
            num_experts_per_tok = self.num_experts_per_tok,
            router_logits_bsz1 = router_logits_bsz1,
            routing_weights_bsz1 = routing_weights_bsz1,
-           selected_experts_bsz1 = selected_experts_bsz1
+           selected_experts_bsz1 = selected_experts_bsz1,
+           e_score_correction_bias = self.e_score_correction_bias,
+           routed_scaling_factor = self.routed_scaling_factor,
+           n_group = self.n_group,
+           topk_group = self.topk_group,
         )
 
         yh = torch.empty(
@@ -231,6 +323,7 @@ def unload(self):
         self.multi_down = None
         self.routing_cfg = None
         self.experts_cfg = None
+        self.e_score_correction_bias = None
         super().unload()
 
 
@@ -245,12 +338,18 @@ def forward(
         y = x.view(-1, self.hidden_size)
         bsz = y.shape[0]
 
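+        # Dispatch: DS3-style sigmoid routing (single-group case handled by routing_dots),
+        # otherwise the original routing path via ext.blocksparse_mlp_routing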
-        # selected_experts, routing_weights = routing(bsz, self.routing_cfg, y, params)
-        selected_experts, routing_weights = ext.blocksparse_mlp_routing(bsz, self.routing_cfg, y, params)
+        if self.deepseekv3_routing:
+            if self.n_group == 1 and self.topk_group == 1:
+                selected_experts, routing_weights = routing_dots(bsz, self.routing_cfg, y, params)
+            # else:
+            #     selected_experts, routing_weights = routing_ds3(bsz, self.routing_cfg, y, params)
+        else:
+            # selected_experts, routing_weights = routing(bsz, self.routing_cfg, y, params)
+            selected_experts, routing_weights = ext.blocksparse_mlp_routing(bsz, self.routing_cfg, y, params, False)
 
         # Torch path
         if bsz > 1 or not self.is_quantized:
-            final_hidden_states = torch.zeros_like(y)
+            final_hidden_states = torch.zeros_like(y, dtype = self.out_dtype)
 
             expert_mask = torch.nn.functional.one_hot(
                 selected_experts,
@@ -338,8 +437,14 @@ def mlp(exp_i, xc):
             )
 
             final_hidden_states = cfg.out_d[:1, ...]
-            return final_hidden_states.view(x.shape)
         final_hidden_states = final_hidden_states.view(x.shape)
         if self.shared_experts:
             final_hidden_states += self.shared_experts.forward(x, params)
-        return final_hidden_states
+        return final_hidden_states
+
+    @override
+    def get_tensors(self):
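+        # Re-export the optional correction bias under its original key so it stays with
+        # the module's tensors (e.g. when the model is re-saved)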
+        t = super().get_tensors()
+        if self.e_score_correction_bias is not None:
+            t[f"{self.key}.gate.e_score_correction_bias"] = self.e_score_correction_bias.contiguous()
+        return t