
Commit b8c830c

BlockSparseMLP: Add DS3/Dots routing
1 parent 13fa9ca commit b8c830c

File tree

1 file changed: +111 -6 lines changed

exllamav3/modules/block_sparse_mlp.py

Lines changed: 111 additions & 6 deletions
@@ -23,6 +23,11 @@ class RoutingCFG:
     router_logits_bsz1: torch.Tensor
     routing_weights_bsz1: torch.Tensor
     selected_experts_bsz1: torch.Tensor
+    e_score_correction_bias: torch.Tensor | None
+    routed_scaling_factor: float | None
+    n_group: int | None
+    topk_group: int | None
+

 def routing(bsz, cfg, y, params):
     activate_all_experts = params.get("activate_all_experts")
@@ -50,6 +55,75 @@ def routing(bsz, cfg, y, params):
     return selected_experts, routing_weights


+# TODO: Optimize (for DS3)
+def routing_ds3(bsz, cfg, y, params):
+    activate_all_experts = params.get("activate_all_experts")
+    router_logits = torch.matmul(y, cfg.gate_tensor)
+
+    scores = router_logits.sigmoid()
+    scores_for_choice = scores.view(-1, cfg.num_experts) + cfg.e_score_correction_bias.unsqueeze(0)
+    group_scores = (
+        scores_for_choice.view(-1, cfg.n_group, cfg.num_experts // cfg.n_group)
+        .topk(2, dim = -1)[0]
+        .sum(dim = -1)
+    )
+    group_idx = torch.topk(group_scores, k = cfg.topk_group, dim = -1, sorted = False)[1]
+    group_mask = torch.zeros_like(group_scores)
+    group_mask.scatter_(1, group_idx, 1)
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(-1, cfg.n_group, cfg.num_experts // cfg.n_group)
+        .reshape(-1, cfg.num_experts)
+    )
+    scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)
+
+    topk_indices = torch.topk(
+        scores_for_choice,
+        k = cfg.num_experts if activate_all_experts else cfg.num_experts_per_tok,
+        dim = -1,
+        sorted = False
+    )[1]
+    topk_weights = scores.gather(1, topk_indices)
+    denominator = topk_weights.sum(dim = -1, keepdim = True) + 1e-20
+    topk_weights /= denominator
+    topk_weights = topk_weights * cfg.routed_scaling_factor
+    return topk_indices, topk_weights
+
+
+def routing_dots(bsz, cfg, y, params):
+    activate_all_experts = params.get("activate_all_experts")
+
+    if bsz == 1 and not activate_all_experts:
+        torch.matmul(y, cfg.gate_tensor, out = cfg.router_logits_bsz1)
+        cfg.router_logits_bsz1 += cfg.e_score_correction_bias
+        torch.topk(
+            cfg.router_logits_bsz1,
+            cfg.num_experts_per_tok,
+            dim = -1,
+            out = (cfg.routing_weights_bsz1, cfg.selected_experts_bsz1),
+            sorted = False
+        )
+        # TODO: Custom kernel for sigmoid normalization
+        cfg.routing_weights_bsz1.sigmoid_()
+        factor = cfg.routed_scaling_factor / (cfg.routing_weights_bsz1.sum(dim = -1, keepdim = True) + 1e-20)
+        cfg.routing_weights_bsz1 *= factor
+        return cfg.selected_experts_bsz1, cfg.routing_weights_bsz1
+
+    else:
+        router_logits = torch.matmul(y, cfg.gate_tensor)
+        router_logits += cfg.e_score_correction_bias
+        routing_weights, selected_experts = torch.topk(
+            router_logits,
+            cfg.num_experts if activate_all_experts else cfg.num_experts_per_tok,
+            dim = -1
+        )
+        # TODO: Custom kernel for sigmoid normalization
+        routing_weights.sigmoid_()
+        factor = cfg.routed_scaling_factor / (routing_weights.sum(dim = -1, keepdim = True) + 1e-20)
+        routing_weights *= factor
+        return selected_experts, routing_weights
+
+
 @dataclass
 class ExpertsCFG:
     yh: torch.Tensor
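
For context, the group-limited selection that routing_ds3 implements can be illustrated on toy tensors. The sizes and values below (8 experts, 4 groups, top-2 groups, top-2 experts, a zero correction bias) are arbitrary examples for the sketch, not values taken from the module:

import torch

# Toy sketch of DeepSeek-V3-style group-limited routing; all sizes are arbitrary
num_experts, n_group, topk_group, top_k = 8, 4, 2, 2
scores = torch.rand(1, num_experts)                 # sigmoid router scores for one token
bias = torch.zeros(num_experts)                     # stand-in for e_score_correction_bias

# Rank groups by the sum of their two best (bias-corrected) expert scores
scores_for_choice = scores + bias.unsqueeze(0)
group_scores = (
    scores_for_choice.view(-1, n_group, num_experts // n_group)
    .topk(2, dim = -1)[0]
    .sum(dim = -1)
)
group_idx = torch.topk(group_scores, k = topk_group, dim = -1, sorted = False)[1]

# Mask out every expert that does not belong to a surviving group
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)
score_mask = group_mask.unsqueeze(-1).expand(-1, n_group, num_experts // n_group).reshape(-1, num_experts)
masked_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)

# Pick the top-k experts among the surviving groups; weights come from the
# uncorrected scores and are renormalized to sum to 1 per token
topk_indices = torch.topk(masked_scores, k = top_k, dim = -1, sorted = False)[1]
topk_weights = scores.gather(1, topk_indices)
topk_weights = topk_weights / (topk_weights.sum(dim = -1, keepdim = True) + 1e-20)
print(topk_indices, topk_weights)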
@@ -77,6 +151,10 @@ def __init__(
         out_dtype: torch.dtype = None,
         activation_fn: str = "silu",
         interm_dtype: torch.dtype = None,
+        deepseekv3_routing: bool = False,
+        routed_scaling_factor: float | None = None,
+        n_group: int | None = None,
+        topk_group: int | None = None,
         shared_experts: MLP | GatedMLP | None = None
     ):
         super().__init__(config, key, None)
@@ -89,6 +167,11 @@ def __init__(
         self.num_experts_per_tok = num_experts_per_tok
         self.hidden_size = hidden_size

+        self.deepseekv3_routing = deepseekv3_routing
+        self.routed_scaling_factor = routed_scaling_factor
+        self.n_group = n_group
+        self.topk_group = topk_group
+
         self.routing_gate = Linear(
             config = config,
             key = f"{key}.{key_routing_gate}",
@@ -152,6 +235,8 @@ def __init__(
         self.routing_cfg = None
         self.experts_cfg = None

+        self.e_score_correction_bias = None
+
         self.shared_experts = shared_experts
         if shared_experts is not None:
             self.register_submodule(shared_experts)
@@ -161,6 +246,9 @@ def __init__(
     def load(self, device: torch.Device, **kwargs):
         super().load(device, **kwargs)

+        self.e_score_correction_bias = \
+            self.config.stc.get_tensor(self.key + ".gate.e_score_correction_bias", self.device, optional = True)
+
         # Test if experts can be fused
         num_exl3_tensors = 0
         num_nonexl3_tensors = 0
@@ -189,7 +277,11 @@ def load(self, device: torch.Device, **kwargs):
             num_experts_per_tok = self.num_experts_per_tok,
             router_logits_bsz1 = router_logits_bsz1,
             routing_weights_bsz1 = routing_weights_bsz1,
-            selected_experts_bsz1 = selected_experts_bsz1
+            selected_experts_bsz1 = selected_experts_bsz1,
+            e_score_correction_bias = self.e_score_correction_bias,
+            routed_scaling_factor = self.routed_scaling_factor,
+            n_group = self.n_group,
+            topk_group = self.topk_group,
         )

         yh = torch.empty(
@@ -231,6 +323,7 @@ def unload(self):
         self.multi_down = None
         self.routing_cfg = None
         self.experts_cfg = None
+        self.e_score_correction_bias = None
         super().unload()

@@ -245,12 +338,18 @@ def forward(
         y = x.view(-1, self.hidden_size)
         bsz = y.shape[0]

-        # selected_experts, routing_weights = routing(bsz, self.routing_cfg, y, params)
-        selected_experts, routing_weights = ext.blocksparse_mlp_routing(bsz, self.routing_cfg, y, params)
+        if self.deepseekv3_routing:
+            if self.n_group == 1 and self.topk_group == 1:
+                selected_experts, routing_weights = routing_dots(bsz, self.routing_cfg, y, params)
+            # else:
+            #     selected_experts, routing_weights = routing_ds3(bsz, self.routing_cfg, y, params)
+        else:
+            # selected_experts, routing_weights = routing(bsz, self.routing_cfg, y, params)
+            selected_experts, routing_weights = ext.blocksparse_mlp_routing(bsz, self.routing_cfg, y, params, False)

         # Torch path
         if bsz > 1 or not self.is_quantized:
-            final_hidden_states = torch.zeros_like(y)
+            final_hidden_states = torch.zeros_like(y, dtype = self.out_dtype)

             expert_mask = torch.nn.functional.one_hot(
                 selected_experts,
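
In the dispatch above, routing_dots covers the n_group == 1 and topk_group == 1 case ("Dots"-style models), while the general grouped routing_ds3 path is left commented out for now. Both new routing functions finish the same way: sigmoid-activated scores for the selected experts are renormalized to unit sum and scaled by routed_scaling_factor. A minimal standalone sketch of that step, with arbitrary example values:

import torch

# Arbitrary example values; routed_scaling_factor would come from the model config
routed_scaling_factor = 2.5
topk_logits = torch.tensor([[1.2, 0.4, -0.3]])      # logits of the selected experts

weights = topk_logits.sigmoid()
weights = weights * routed_scaling_factor / (weights.sum(dim = -1, keepdim = True) + 1e-20)
# Each row of weights now sums to routed_scaling_factor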
@@ -338,8 +437,14 @@ def mlp(exp_i, xc):
             )

             final_hidden_states = cfg.out_d[:1, ...]
-            return final_hidden_states.view(x.shape)
         final_hidden_states = final_hidden_states.view(x.shape)
         if self.shared_experts:
             final_hidden_states += self.shared_experts.forward(x, params)
-        return final_hidden_states
+        return final_hidden_states
+
+    @override
+    def get_tensors(self):
+        t = super().get_tensors()
+        if self.e_score_correction_bias is not None:
+            t[f"{self.key}.gate.e_score_correction_bias"] = self.e_score_correction_bias.contiguous()
+        return t
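
Since the correction bias is loaded as a plain optional tensor rather than through a submodule, the new get_tensors override re-exports it so it round-trips when the model is serialized again. A hypothetical sketch of the resulting key (the module key and bias size below are made up for illustration):

import torch

key = "model.layers.0.mlp"                          # hypothetical module key
e_score_correction_bias = torch.zeros(64)           # hypothetical per-expert bias

tensors = {}
if e_score_correction_bias is not None:
    tensors[f"{key}.gate.e_score_correction_bias"] = e_score_correction_bias.contiguous()
# -> {"model.layers.0.mlp.gate.e_score_correction_bias": tensor([...])}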
