Commit cf149b8

kylesayrs, dichn, and dsikka authored
[MoE] MoE Calibration with calibrate_all_experts (#1760)
Coauthored with @dichn! ## Purpose ## * Add support for `calibrate_all_experts` option, which sends all tokens to all experts, but still produces the same outputs as if tokens had been gated ## Changes ## * Modify model definitions such that, in the case of `calibrate_all_experts=True` token gating occurs after passing tokens to experts, rather than before ```python3 # `calibrate_all_experts=True` by default model = replace_modules_for_calibration(model, calibrate_all_experts=True) ``` ## Testing ## * Added correctness tests for new model definitions which checks that outputs are exactly the same * Added hook tests to make sure all experts are being sent tokens --------- Signed-off-by: Kyle Sayers <[email protected]> Co-authored-by: Di Chen <[email protected]> Co-authored-by: Dipika Sikka <[email protected]>
1 parent 9765e2b commit cf149b8

File tree

10 files changed: +449 −93 lines changed

src/llmcompressor/modeling/deepseek_v3.py

Lines changed: 35 additions & 13 deletions
```diff
@@ -10,12 +10,18 @@ class DeepseekV3MoECalibrate(torch.nn.Module):
     Patched DeepseekV3MoE which sends all tokens to all experts for calibration
     """

-    def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE):
+    def __init__(
+        self,
+        config: DeepseekV3Config,
+        original: OriginalDeepseekV3MoE,
+        calibrate_all_experts: bool,
+    ):
         super().__init__()
         self.config = config
         self.experts = original.experts
         self.gate = original.gate
         self.shared_experts = original.shared_experts
+        self.calibrate_all_experts = calibrate_all_experts

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         residuals = hidden_states
@@ -30,24 +36,40 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         )
         expert_mask = expert_mask.permute(2, 0, 1)

-        for expert_idx in range(len(self.experts)):
-            expert = self.experts[expert_idx]
-            mask = expert_mask[expert_idx]
-            token_indices, weight_indices = torch.where(mask)
+        for expert_idx, expert in enumerate(self.experts):
+            token_indices, weight_indices = torch.where(expert_mask[expert_idx])
+            has_tokens = token_indices.numel() > 0

-            expert_weights = topk_weights[token_indices, weight_indices]
-            expert_input = hidden_states[token_indices]
-            expert_output = expert(expert_input)
-            weighted_output = expert_output * expert_weights.unsqueeze(-1)
+            if self.calibrate_all_experts:
+                expert_input = hidden_states
+                expert_output = expert(expert_input)

-            if token_indices.numel() > 0:
-                final_hidden_states.index_add_(0, token_indices, weighted_output)
+                if has_tokens:
+                    expert_weights = topk_weights[token_indices, weight_indices]
+                    routed_output = expert_output[
+                        token_indices
+                    ] * expert_weights.unsqueeze(-1)
+                    final_hidden_states.index_add_(0, token_indices, routed_output)
+            else:
+                # Normal MoE: only process tokens routed to this expert
+                if has_tokens:
+                    expert_input = hidden_states[token_indices]
+                    expert_output = expert(expert_input)
+                    expert_weights = topk_weights[token_indices, weight_indices]
+                    routed_output = expert_output * expert_weights.unsqueeze(-1)
+                    final_hidden_states.index_add_(0, token_indices, routed_output)
         # End MoE

         hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
         hidden_states = hidden_states + self.shared_experts(residuals)
         return hidden_states


-def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE):
-    return DeepseekV3MoECalibrate(config=config, original=module)
+def replace(
+    config: DeepseekV3Config,
+    module: OriginalDeepseekV3MoE,
+    calibrate_all_experts: bool,
+):
+    return DeepseekV3MoECalibrate(
+        config=config, original=module, calibrate_all_experts=calibrate_all_experts
+    )
```
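Why this is output-preserving: for per-token experts, `expert(hidden_states)[token_indices]` computes the same rows as `expert(hidden_states[token_indices])`, so gating after the expert call changes which tokens each expert *sees*, not what the layer returns. A minimal sketch of the invariant the correctness tests check, using toy tensors rather than the real DeepseekV3 modules (all shapes and names here are illustrative):

```python3
import torch

# Toy check: gating after the expert call selects the same rows the gated
# path computes, so both paths sum to identical outputs.
torch.manual_seed(0)
num_tokens, hidden_dim, num_experts, top_k = 8, 4, 3, 2
experts = [torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)]
hidden_states = torch.randn(num_tokens, hidden_dim)
topk_weights = torch.rand(num_tokens, top_k)
topk_ids = torch.stack(
    [torch.randperm(num_experts)[:top_k] for _ in range(num_tokens)]
)
# (num_experts, num_tokens, top_k) routing mask, as in the patched forward
expert_mask = torch.nn.functional.one_hot(topk_ids, num_experts).permute(2, 0, 1)

gated = torch.zeros_like(hidden_states)
calibrated = torch.zeros_like(hidden_states)
for expert_idx, expert in enumerate(experts):
    token_indices, weight_indices = torch.where(expert_mask[expert_idx])
    weights = topk_weights[token_indices, weight_indices].unsqueeze(-1)
    # normal path: run only routed tokens
    gated.index_add_(0, token_indices, expert(hidden_states[token_indices]) * weights)
    # calibrate_all_experts path: run every token, gate afterwards
    calibrated.index_add_(0, token_indices, expert(hidden_states)[token_indices] * weights)

assert torch.allclose(gated, calibrated, atol=1e-6)
```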
src/llmcompressor/modeling/llama4.py

Lines changed: 42 additions & 22 deletions

```diff
@@ -1,8 +1,6 @@
 from typing import Tuple

 import torch
-import transformers
-from packaging import version
 from transformers.models.llama4.configuration_llama4 import (
     Llama4Config,
     Llama4TextConfig,
@@ -17,39 +15,57 @@

 class SequentialLlama4TextMoe(torch.nn.Module):
-    def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        original: Llama4TextMoe,
+        calibrate_all_experts: bool,
+    ):
         super().__init__()
         self.top_k = config.num_experts_per_tok
         self.hidden_dim = config.hidden_size
         self.num_experts = config.num_local_experts
+
         self.experts = SequentialLlama4TextExperts(config, original.experts)
         self.router = original.router
         self.shared_expert = original.shared_expert
+        self.calibrate_all_experts = calibrate_all_experts

     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = self.router(hidden_states)
+        router_outputs = self.router(hidden_states)
+
         # support transformers 4.53 and greater
-        if isinstance(router_logits, tuple):
-            router_logits = router_logits[-1]
+        if isinstance(router_outputs, tuple):
+            router_scores, router_logits = router_outputs
+        else:
+            router_logits = router_outputs
+            router_top_value, router_indices = torch.topk(
+                router_logits, self.top_k, dim=1
+            )
+            router_scores = torch.full_like(router_logits, float("-inf")).scatter_(
+                1, router_indices, router_top_value
+            )
+            router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
-
-        router_scores = (
-            torch.full_like(router_logits, float("-inf"))
-            .scatter_(1, router_indices, router_top_value)
-            .transpose(0, 1)
-        )
-        router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
+        out = self.shared_expert(hidden_states)
+        for expert_index in range(self.num_experts):
+            top_token_mask = router_scores[:, expert_index] > 0

-        out = self.shared_expert(hidden_states)
-        for i in range(self.num_experts):
-            out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
+            if self.calibrate_all_experts:
+                # Run all tokens for calibration
+                expert_out = self.experts[expert_index](hidden_states)[top_token_mask]
+            else:
+                expert_out = self.experts[expert_index](hidden_states[top_token_mask])

-        if version.parse(transformers.__version__) >= version.parse("4.54.0"):
-            return out, router_logits
-        else:
-            return out, router_scores
+            # Only top-k tokens contribute to the final output
+            if top_token_mask.any():
+                expert_score = router_scores[top_token_mask, expert_index].unsqueeze(-1)
+                out[top_token_mask] += expert_out * expert_score
+
+        return out, router_scores
@@ -72,5 +88,9 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts):
         self[i].down_proj.weight.data = down.t().clone().contiguous()


-def replace(config: Llama4Config, module: Llama4TextMoe):
-    return SequentialLlama4TextMoe(config=config.get_text_config(), original=module)
+def replace(config: Llama4Config, module: Llama4TextMoe, calibrate_all_experts: bool):
+    return SequentialLlama4TextMoe(
+        config=config.get_text_config(),
+        original=module,
+        calibrate_all_experts=calibrate_all_experts,
+    )
```
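For the pre-4.53 fallback path, scattering the top-k logits onto a `-inf` grid and applying a sigmoid leaves exactly the non-selected entries at zero, which is what makes `router_scores[:, expert_index] > 0` a valid routing mask. A small self-contained sketch of that property (toy shapes, not the actual Llama4 router):

```python3
import torch

torch.manual_seed(0)
num_tokens, num_experts, top_k = 4, 6, 2
router_logits = torch.randn(num_tokens, num_experts)

# scatter top-k logits onto a -inf grid, then sigmoid
top_value, top_index = torch.topk(router_logits, top_k, dim=1)
router_scores = torch.full_like(router_logits, float("-inf")).scatter_(
    1, top_index, top_value
)
router_scores = torch.sigmoid(router_scores.float())

# sigmoid(-inf) == 0, so positivity recovers the top-k routing mask exactly
assert ((router_scores > 0).sum(dim=1) == top_k).all()
```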

src/llmcompressor/modeling/prepare.py

Lines changed: 23 additions & 7 deletions
```diff
@@ -1,3 +1,4 @@
+import tqdm
 from compressed_tensors.utils import replace_module
 from transformers import PreTrainedModel

@@ -15,11 +16,18 @@
 }


-def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
-    for name, module in model.named_modules():
+def replace_modules_for_calibration(
+    model: PreTrainedModel,
+    calibrate_all_experts: bool = True,
+) -> PreTrainedModel:
+    for name, module in tqdm.tqdm(list(model.named_modules())):
         cls_name = module.__class__.__name__
         if cls_name in replacements:
-            new_module = replacements[cls_name](config=model.config, module=module)
+            new_module = replacements[cls_name](
+                config=model.config,
+                module=module,
+                calibrate_all_experts=calibrate_all_experts,
+            )
             replace_module(model, name, new_module)

     return model
@@ -28,7 +36,7 @@
 # ------------------- module replacements; during calibration --------------------


-def update_qwen3_moe(model, stack):
+def update_qwen3_moe(model, stack, calibrate_all_experts):
     for module in model.modules():
         cls_name = module.__class__.__name__
         if cls_name == "Qwen3MoeDecoderLayer":
@@ -37,7 +45,11 @@ def update_qwen3_moe(model, stack):
             patch_attr(
                 module,
                 "mlp",
-                replace_Qwen3MoE(config=model.config, module=module.mlp),
+                replace_Qwen3MoE(
+                    config=model.config,
+                    module=module.mlp,
+                    calibrate_all_experts=calibrate_all_experts,
+                ),
             )
         )

@@ -47,9 +59,13 @@
 }


-def moe_calibration_context(model: PreTrainedModel, stack):
+def moe_calibration_context(
+    model: PreTrainedModel,
+    stack,
+    calibrate_all_experts: bool = False,
+):
     # Temporarily updates the MoE modules within the context
     # Once the context exits, parameter updates persist
     cls_name = model.__class__.__name__
     if cls_name in moe_context:
-        moe_context.get(cls_name)(model, stack)
+        moe_context.get(cls_name)(model, stack, calibrate_all_experts)
```
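For context, a hedged usage sketch of the two entry points; the import path is inferred from the file locations in this commit, and `model` stands in for any supported MoE checkpoint loaded as a `PreTrainedModel`:

```python3
from contextlib import ExitStack

from llmcompressor.modeling.prepare import (
    moe_calibration_context,
    replace_modules_for_calibration,
)

model = ...  # a loaded MoE PreTrainedModel, e.g. from AutoModelForCausalLM

# Option 1: permanent replacement; `calibrate_all_experts` defaults to True
model = replace_modules_for_calibration(model)

# Option 2: replacements scoped to the lifetime of an ExitStack
with ExitStack() as stack:
    moe_calibration_context(model, stack, calibrate_all_experts=True)
    # ... run calibration forward passes here ...
```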

src/llmcompressor/modeling/qwen3_moe.py

Lines changed: 28 additions & 18 deletions
```diff
@@ -23,14 +23,17 @@

 class Qwen3MoeSparseMoeBlock(torch.nn.Module):
     def __init__(
-        self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock
+        self,
+        config: Qwen3MoeConfig,
+        original: OriginalQwen3MoeSparseMoeBlock,
+        calibrate_all_experts: bool,
     ):
         super().__init__()
         self.num_experts = config.num_experts
-        self.top_k = config.top_k
+        self.top_k = config.num_experts_per_tok
         self.norm_topk_prob = config.norm_topk_prob

-        # gating
+        self.calibrate_all_experts = calibrate_all_experts
         self.gate = original.gate
         self.experts = original.experts

@@ -50,6 +53,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(hidden_states.dtype)
+
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim),
             dtype=hidden_states.dtype,
@@ -62,26 +66,32 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             selected_experts, num_classes=self.num_experts
         ).permute(2, 1, 0)

-        for expert_idx in range(len(self.experts)):
-            expert_layer = self.experts[expert_idx]
+        for expert_idx, expert_layer in enumerate(self.experts):
             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            expert_output = expert_layer(current_state)
-            current_hidden_states = expert_output * routing_weights[top_x, idx, None]
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
+
+            if self.calibrate_all_experts:
+                expert_out = expert_layer(hidden_states)[top_x]
+            else:
+                expert_out = expert_layer(hidden_states[top_x])
+
+            # Only update tokens actually routed to this expert
+            if len(top_x) > 0:
+                current_hidden_states = expert_out * routing_weights[top_x, idx, None]
+                final_hidden_states.index_add_(
+                    0, top_x, current_hidden_states.to(hidden_states.dtype)
+                )

         final_hidden_states = final_hidden_states.reshape(
             batch_size, sequence_length, hidden_dim
         )
         return final_hidden_states, router_logits


-def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock):
-    return Qwen3MoeSparseMoeBlock(config=config, original=module)
+def replace(
+    config: Qwen3MoeConfig,
+    module: OriginalQwen3MoeSparseMoeBlock,
+    calibrate_all_experts: bool,
+):
+    return Qwen3MoeSparseMoeBlock(
+        config=config, original=module, calibrate_all_experts=calibrate_all_experts
+    )
```
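The hook tests mentioned in the description reduce to counting how many tokens reach each expert's forward: with `calibrate_all_experts=True`, every expert should see every token. A self-contained sketch of that check using toy modules (not the actual test code):

```python3
import torch

# Toy expert set; in the real tests these are the model's expert modules.
hidden_dim, num_tokens, num_experts = 4, 8, 3
experts = torch.nn.ModuleList(
    torch.nn.Linear(hidden_dim, hidden_dim) for _ in range(num_experts)
)

token_counts = {}

def make_hook(expert_idx):
    def hook(module, inputs, output):
        # record how many tokens this expert's forward received
        token_counts[expert_idx] = inputs[0].shape[0]
    return hook

for idx, expert in enumerate(experts):
    expert.register_forward_hook(make_hook(idx))

hidden_states = torch.randn(num_tokens, hidden_dim)
for expert in experts:
    # with calibrate_all_experts=True, each expert runs on the full batch
    expert(hidden_states)

assert all(count == num_tokens for count in token_counts.values())
```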
