
Commit 1023895

[MoE] Add conditional expert calibration
Change Purpose:
- Add `calibrate_all_experts` option to improve MoE calibration

Change Details:
- Add `calibrate_all_experts` flag to MoE layers
- Update `replace_modules_for_calibration` and `moe_calibration_context` to propagate the flag into modules
- Modify expert forward passes:
  * Normal mode (default): compute output only for tokens routed to top-k experts, and combine their weighted results in the final output
  * Calibration mode (`calibrate_all_experts=True`): compute output for all tokens on every expert, but still apply the top-k gating to decide which token outputs contribute to the final result

Testing:
- Add unit test to verify all experts are triggered during MoE calibration
1 parent cbf36fe commit 1023895
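In practice the new flag is consumed through `replace_modules_for_calibration` (or `moe_calibration_context`). A minimal usage sketch, assuming a supported MoE checkpoint can be loaded; the model stub and dummy input below are placeholders for illustration, not part of this commit:

    import torch
    from transformers import AutoModelForCausalLM

    from llmcompressor.modeling.prepare import replace_modules_for_calibration

    # Placeholder checkpoint; any architecture listed in `replacements` works the same way
    model = AutoModelForCausalLM.from_pretrained("unsloth/DeepSeek-R1-0528-BF16")

    # Swap MoE blocks so every expert sees every calibration token, while the
    # final layer outputs still use only the top-k routed experts
    model = replace_modules_for_calibration(model, calibrate_all_experts=True)

    # Calibration forward passes now exercise all experts
    with torch.no_grad():
        model(input_ids=torch.randint(0, 1000, (1, 32)))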

File tree

7 files changed: +278 -37 lines changed


src/llmcompressor/modeling/deepseek_v3.py

Lines changed: 35 additions & 13 deletions
@@ -10,12 +10,18 @@ class DeepseekV3MoECalibrate(torch.nn.Module):
     Patched DeepseekV3MoE which sends all tokens to all experts for calibration
     """
 
-    def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE):
+    def __init__(
+        self,
+        config: DeepseekV3Config,
+        original: OriginalDeepseekV3MoE,
+        calibrate_all_experts: bool,
+    ):
         super().__init__()
         self.config = config
         self.experts = original.experts
         self.gate = original.gate
         self.shared_experts = original.shared_experts
+        self.calibrate_all_experts = calibrate_all_experts
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         residuals = hidden_states
@@ -30,24 +36,40 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         )
         expert_mask = expert_mask.permute(2, 0, 1)
 
-        for expert_idx in range(len(self.experts)):
-            expert = self.experts[expert_idx]
-            mask = expert_mask[expert_idx]
-            token_indices, weight_indices = torch.where(mask)
+        for expert_idx, expert in enumerate(self.experts):
+            token_indices, weight_indices = torch.where(expert_mask[expert_idx])
+            has_tokens = token_indices.numel() > 0
 
-            expert_weights = topk_weights[token_indices, weight_indices]
-            expert_input = hidden_states[token_indices]
-            expert_output = expert(expert_input)
-            weighted_output = expert_output * expert_weights.unsqueeze(-1)
+            if self.calibrate_all_experts:
+                expert_input = hidden_states
+                expert_output = expert(expert_input)
 
-            if token_indices.numel() > 0:
-                final_hidden_states.index_add_(0, token_indices, weighted_output)
+                if has_tokens:
+                    expert_weights = topk_weights[token_indices, weight_indices]
+                    routed_output = expert_output[
+                        token_indices
+                    ] * expert_weights.unsqueeze(-1)
+                    final_hidden_states.index_add_(0, token_indices, routed_output)
+            else:
+                # Normal MoE: only process tokens routed to this expert
+                if has_tokens:
+                    expert_input = hidden_states[token_indices]
+                    expert_output = expert(expert_input)
+                    expert_weights = topk_weights[token_indices, weight_indices]
+                    routed_output = expert_output * expert_weights.unsqueeze(-1)
+                    final_hidden_states.index_add_(0, token_indices, routed_output)
         # End MoE
 
         hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
         hidden_states = hidden_states + self.shared_experts(residuals)
         return hidden_states
 
 
-def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE):
-    return DeepseekV3MoECalibrate(config=config, original=module)
+def replace(
+    config: DeepseekV3Config,
+    module: OriginalDeepseekV3MoE,
+    calibrate_all_experts: bool,
+):
+    return DeepseekV3MoECalibrate(
+        config=config, original=module, calibrate_all_experts=calibrate_all_experts
+    )
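Why the calibration branch does not change the layer's output: running the expert over all tokens and then indexing the routed ones yields the same contribution as running only the routed tokens. A small self-contained sketch with a toy `torch.nn.Linear` standing in for an expert MLP (names and shapes are illustrative only, not from this commit):

    import torch

    torch.manual_seed(0)
    num_tokens, hidden_dim = 6, 4
    hidden_states = torch.randn(num_tokens, hidden_dim)
    expert = torch.nn.Linear(hidden_dim, hidden_dim)  # stand-in for one expert MLP

    # Suppose the router sent tokens 1 and 4 to this expert with these weights
    token_indices = torch.tensor([1, 4])
    expert_weights = torch.tensor([0.7, 0.3])

    # Normal mode: run only the routed tokens
    routed_normal = expert(hidden_states[token_indices]) * expert_weights.unsqueeze(-1)

    # Calibration mode: run every token, then keep only the routed ones
    all_output = expert(hidden_states)
    routed_calibrated = all_output[token_indices] * expert_weights.unsqueeze(-1)

    # Both paths add the same values into final_hidden_states
    assert torch.allclose(routed_normal, routed_calibrated, atol=1e-6)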

src/llmcompressor/modeling/llama4.py

Lines changed: 28 additions & 4 deletions
@@ -17,14 +17,20 @@
 
 
 class SequentialLlama4TextMoe(torch.nn.Module):
-    def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        original: Llama4TextMoe,
+        calibrate_all_experts: bool,
+    ):
         super().__init__()
         self.top_k = config.num_experts_per_tok
         self.hidden_dim = config.hidden_size
         self.num_experts = config.num_local_experts
         self.experts = SequentialLlama4TextExperts(config, original.experts)
         self.router = original.router
         self.shared_expert = original.shared_expert
+        self.calibrate_all_experts = calibrate_all_experts
 
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]:
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
@@ -44,7 +50,21 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]:
 
         out = self.shared_expert(hidden_states)
         for i in range(self.num_experts):
-            out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
+            expert_output = None
+            if self.calibrate_all_experts:
+                # Run all tokens for calibration
+                expert_output = self.experts[i](hidden_states)
+
+            # Only top-k tokens contribute to final output
+            top_token_mask = router_scores[i] > 0
+            if top_token_mask.any():
+                if expert_output is None:
+                    expert_output = self.experts[i](hidden_states[top_token_mask])
+                else:
+                    expert_output = expert_output[top_token_mask]
+                out[top_token_mask] += expert_output * router_scores[
+                    i, top_token_mask
+                ].unsqueeze(-1)
 
         if version.parse(transformers.__version__) >= version.parse("4.54.0"):
             return out, router_logits
@@ -72,5 +92,9 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts):
             self[i].down_proj.weight.data = down.t().clone().contiguous()
 
 
-def replace(config: Llama4Config, module: Llama4TextMoe):
-    return SequentialLlama4TextMoe(config=config.get_text_config(), original=module)
+def replace(config: Llama4Config, module: Llama4TextMoe, calibrate_all_experts: bool):
+    return SequentialLlama4TextMoe(
+        config=config.get_text_config(),
+        original=module,
+        calibrate_all_experts=calibrate_all_experts,
+    )

src/llmcompressor/modeling/prepare.py

Lines changed: 22 additions & 6 deletions
@@ -15,11 +15,19 @@
 }
 
 
-def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
+def replace_modules_for_calibration(
+    model: PreTrainedModel,
+    calibrate_all_experts: bool = False,
+) -> PreTrainedModel:
+
     for name, module in model.named_modules():
         cls_name = module.__class__.__name__
         if cls_name in replacements:
-            new_module = replacements[cls_name](config=model.config, module=module)
+            new_module = replacements[cls_name](
+                config=model.config,
+                module=module,
+                calibrate_all_experts=calibrate_all_experts,
+            )
             replace_module(model, name, new_module)
 
     return model
@@ -28,7 +36,7 @@ def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
 # ------------------- module replacements; during calibration --------------------
 
 
-def update_qwen3_moe(model, stack):
+def update_qwen3_moe(model, stack, calibrate_all_experts):
     for module in model.modules():
         cls_name = module.__class__.__name__
         if cls_name == "Qwen3MoeDecoderLayer":
@@ -37,7 +45,11 @@ def update_qwen3_moe(model, stack):
             patch_attr(
                 module,
                 "mlp",
-                replace_Qwen3MoE(config=model.config, module=module.mlp),
+                replace_Qwen3MoE(
+                    config=model.config,
+                    module=module.mlp,
+                    calibrate_all_experts=calibrate_all_experts,
+                ),
             )
         )
 
@@ -47,9 +59,13 @@ def update_qwen3_moe(model, stack):
 }
 
 
-def moe_calibration_context(model: PreTrainedModel, stack):
+def moe_calibration_context(
+    model: PreTrainedModel,
+    stack,
+    calibrate_all_experts: bool = False,
+):
     # Temporarily updates the MoE modules within the context
     # Once the context exists, parameter updates persist
     cls_name = model.__class__.__name__
     if cls_name in moe_context:
-        moe_context.get(cls_name)(model, stack)
+        moe_context.get(cls_name)(model, stack, calibrate_all_experts)
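The `stack` argument is expected to be an exit-stack style object that `update_qwen3_moe` pushes `patch_attr` contexts onto, so the module swap is reverted when the stack closes. A hedged sketch of how a caller might drive it with `contextlib.ExitStack` (the checkpoint and dummy batch are placeholders, not from this commit):

    from contextlib import ExitStack

    import torch
    from transformers import AutoModelForCausalLM

    from llmcompressor.modeling.prepare import moe_calibration_context

    # Placeholder checkpoint; moe_calibration_context only patches architectures
    # registered in `moe_context` (Qwen3-MoE in this diff)
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")

    with ExitStack() as stack:
        # MoE blocks are patched only while the stack is open
        moe_calibration_context(model, stack, calibrate_all_experts=True)
        with torch.no_grad():
            model(input_ids=torch.randint(0, 1000, (1, 32)))
    # On exit, patch_attr restores the original modules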

src/llmcompressor/modeling/qwen3_moe.py

Lines changed: 35 additions & 14 deletions
@@ -23,12 +23,16 @@
 
 class Qwen3MoeSparseMoeBlock(torch.nn.Module):
     def __init__(
-        self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock
+        self,
+        config: Qwen3MoeConfig,
+        original: OriginalQwen3MoeSparseMoeBlock,
+        calibrate_all_experts: bool,
     ):
         super().__init__()
         self.num_experts = config.num_experts
         self.top_k = config.top_k
         self.norm_topk_prob = config.norm_topk_prob
+        self.calibrate_all_experts = calibrate_all_experts
 
         # gating
         self.gate = original.gate
@@ -64,24 +68,41 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         for expert_idx in range(len(self.experts)):
             expert_layer = self.experts[expert_idx]
+            cached_output = None
+
+            if self.calibrate_all_experts:
+                cached_output = expert_layer(hidden_states)
+
             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            expert_output = expert_layer(current_state)
-            current_hidden_states = expert_output * routing_weights[top_x, idx, None]
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
+            if top_x.numel() > 0:
+                if cached_output is not None:
+                    expert_output = cached_output[top_x]
+                else:
+                    # Index the correct hidden states and compute the expert hidden state for
+                    # the current expert. We need to make sure to multiply the output hidden
+                    # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                    current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+                    expert_output = expert_layer(current_state)
+                current_hidden_states = (
+                    expert_output * routing_weights[top_x, idx, None]
+                )
+                # However `index_add_` only support torch tensors for indexing so we'll use
+                # the `top_x` tensor here.
+                final_hidden_states.index_add_(
+                    0, top_x, current_hidden_states.to(hidden_states.dtype)
+                )
 
         final_hidden_states = final_hidden_states.reshape(
             batch_size, sequence_length, hidden_dim
         )
         return final_hidden_states, router_logits
 
 
-def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock):
-    return Qwen3MoeSparseMoeBlock(config=config, original=module)
+def replace(
+    config: Qwen3MoeConfig,
+    module: OriginalQwen3MoeSparseMoeBlock,
+    calibrate_all_experts: bool,
+):
+    return Qwen3MoeSparseMoeBlock(
+        config=config, original=module, calibrate_all_experts=calibrate_all_experts
+    )

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+from functools import partial
+
+import pytest
+import torch
+from transformers import AutoModelForCausalLM
+
+from llmcompressor.modeling.deepseek_v3 import DeepseekV3MoECalibrate
+from llmcompressor.modeling.prepare import replace_modules_for_calibration
+from llmcompressor.utils.dev import skip_weights_download
+
+
+@pytest.mark.parametrize("model_stub", ["unsloth/DeepSeek-R1-0528-BF16"])
+def test_calib_replace_deepseekv3moe_all_experts(model_stub):
+    with skip_weights_download():
+        model = AutoModelForCausalLM.from_pretrained(model_stub)
+
+    replace_modules_for_calibration(model, calibrate_all_experts=True)
+
+    # Find a Deepseek MoE layer
+    moe_layer = None
+    for _, module in model.named_modules():
+        if isinstance(module, DeepseekV3MoECalibrate):
+            moe_layer = module
+            break
+
+    assert moe_layer is not None
+
+    num_experts = len(moe_layer.experts)
+    expert_triggered = [False for _ in range(num_experts)]
+
+    # Define the hook function
+    def hook_fn(i, module, input, output):
+        expert_triggered[i] = True
+
+    # Attach hooks using functools.partial to bind each index
+    for i, expert in enumerate(moe_layer.experts):
+        expert.register_forward_hook(partial(hook_fn, i))
+
+    # Create dummy input tensor that simulates hidden_states
+    hidden_dim = model.config.hidden_size
+    batch, seq_len = 4, 32
+    sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32)
+
+    # Forward through the MoE layer directly
+    with torch.no_grad():
+        _ = moe_layer(sample)
+
+    # Assert all experts are used
+    assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}"

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+from functools import partial
+
+import pytest
+import torch
+from transformers import Llama4ForConditionalGeneration
+
+from llmcompressor.modeling.llama4 import SequentialLlama4TextMoe
+from llmcompressor.modeling.prepare import replace_modules_for_calibration
+from llmcompressor.utils.dev import skip_weights_download
+
+
+@pytest.mark.parametrize("model_stub", ["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
+def test_calib_replace_llama4_moe_all_experts(model_stub):
+    with skip_weights_download(Llama4ForConditionalGeneration):
+        model = Llama4ForConditionalGeneration.from_pretrained(
+            model_stub, torch_dtype="auto"
+        )
+
+    replace_modules_for_calibration(model, calibrate_all_experts=True)
+
+    # Find a Llama4 MoE layer
+    moe_layer = None
+    for _, module in model.named_modules():
+        if isinstance(module, SequentialLlama4TextMoe):
+            moe_layer = module
+            break
+
+    assert moe_layer is not None
+
+    num_experts = len(moe_layer.experts)
+    expert_triggered = [False for _ in range(num_experts)]
+
+    # Define the hook function
+    def hook_fn(i, module, input, output):
+        expert_triggered[i] = True
+
+    # Attach hooks using functools.partial to bind each index
+    for i, expert in enumerate(moe_layer.experts):
+        expert.register_forward_hook(partial(hook_fn, i))
+
+    # Create dummy input tensor that simulates hidden_states
+    hidden_dim = model.config.hidden_size
+    batch, seq_len = 4, 32
+    sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32)
+
+    # Forward through the MoE layer directly
+    with torch.no_grad():
+        _ = moe_layer(sample)
+
+    # Assert all experts are used
+    assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}"
