Commit 61c1efc

[MoE] Add conditional expert calibration
Change Purpose:
- Improve MoE calibration support by adding configuration-based expert execution

Change Details:
- Add a standalone `CalibrationConfig` class in llmcompressor/modeling/config.py
- Add conditional expert execution based on:
  - `moe_calibrate_all_experts`: if True, all experts run for every token; if False, only routed experts run
  - `moe_calibrate_gated_acts`: if True, routed experts contribute to the final output; if False, expert activations are computed but excluded from the final output
- Add a unit test verifying that all experts are triggered during MoE calibration
1 parent aeb4b79 commit 61c1efc

File tree: 8 files changed, +335 -43 lines

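For context, the two flags introduced by this commit are plumbed through `replace_modules_for_calibration`; a hedged usage sketch (the model id and dtype below are illustrative, not part of the commit):

    import torch
    from transformers import AutoModelForCausalLM

    from llmcompressor.modeling.prepare import replace_modules_for_calibration

    model = AutoModelForCausalLM.from_pretrained(
        "deepseek-ai/DeepSeek-V3", torch_dtype=torch.bfloat16
    )

    # Default calibration behaviour: every expert runs on every token, while the
    # final activations remain gated exactly as in normal inference.
    model = replace_modules_for_calibration(
        model,
        moe_calibrate_all_experts=True,
        moe_calibrate_gated_acts=True,
    )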

src/llmcompressor/modeling/config.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+from pydantic import BaseModel, model_validator
+
+
+class CalibrationConfig(BaseModel):
+    moe_calibrate_all_experts: bool
+    moe_calibrate_gated_acts: bool
+
+    @model_validator(mode="after")
+    def validate_config(self):
+
+        if not self.moe_calibrate_gated_acts and not self.moe_calibrate_all_experts:
+            raise NotImplementedError(
+                "Using all experts for activations without calibrating all experts is not supported. "
+                "Please set moe_calibrate_gated_acts=True or moe_calibrate_all_experts=True."
+            )
+
+        return self
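
The validator above rejects exactly one flag combination. A minimal sketch of the behaviour (illustrative, not part of the diff); note that pydantic may surface the error directly or wrapped in a ValidationError:

    from llmcompressor.modeling.config import CalibrationConfig

    # Valid: all experts run and the output stays gated (the default calibration mode).
    CalibrationConfig(moe_calibrate_all_experts=True, moe_calibrate_gated_acts=True)

    # Rejected: ungated activations while not calibrating all experts.
    try:
        CalibrationConfig(
            moe_calibrate_all_experts=False, moe_calibrate_gated_acts=False
        )
    except Exception as err:
        print(type(err).__name__)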

src/llmcompressor/modeling/deepseek_v3.py

Lines changed: 32 additions & 9 deletions
@@ -4,19 +4,31 @@
     DeepseekV3MoE as OriginalDeepseekV3MoE,
 )
 
+from llmcompressor.modeling.config import CalibrationConfig
+
 
 class DeepseekV3MoECalibrate(torch.nn.Module):
     """
     Patched DeepseekV3MoE which sends all tokens to all experts for calibration
     """
 
-    def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE):
+    def __init__(
+        self,
+        config: DeepseekV3Config,
+        original: OriginalDeepseekV3MoE,
+        calib_config: CalibrationConfig,
+    ):
         super().__init__()
         self.config = config
         self.experts = original.experts
         self.gate = original.gate
         self.shared_experts = original.shared_experts
 
+        self.calib_config = calib_config
+
+        if not calib_config.moe_calibrate_gated_acts:
+            self.gate.top_k = self.gate.n_routed_experts  # ungate experts
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         residuals = hidden_states
         orig_shape = hidden_states.shape
@@ -35,19 +47,30 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             mask = expert_mask[expert_idx]
             token_indices, weight_indices = torch.where(mask)
 
-            expert_weights = topk_weights[token_indices, weight_indices]
-            expert_input = hidden_states[token_indices]
-            expert_output = expert(expert_input)
-            weighted_output = expert_output * expert_weights.unsqueeze(-1)
+            has_tokens = token_indices.numel() > 0
+
+            if self.calib_config.moe_calibrate_all_experts or has_tokens:
+                # calibrate one expert
+                expert_weights = topk_weights[token_indices, weight_indices]
+                expert_input = hidden_states[token_indices]
+                expert_output = expert(expert_input)
+                weighted_output = expert_output * expert_weights.unsqueeze(-1)
 
-            if token_indices.numel() > 0:
-                final_hidden_states.index_add_(0, token_indices, weighted_output)
+                if has_tokens:
+                    # expert contributes to output activations
+                    final_hidden_states.index_add_(0, token_indices, weighted_output)
         # End MoE
 
         hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
         hidden_states = hidden_states + self.shared_experts(residuals)
         return hidden_states
 
 
-def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE):
-    return DeepseekV3MoECalibrate(config=config, original=module)
+def replace(
+    config: DeepseekV3Config,
+    module: OriginalDeepseekV3MoE,
+    calib_config: CalibrationConfig,
+):
+    return DeepseekV3MoECalibrate(
+        config=config, original=module, calib_config=calib_config
+    )
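
The pattern above is shared by all three patched modules: when `moe_calibrate_all_experts` is set, each expert executes even if no tokens were routed to it (its input is then an empty batch, which still exercises calibration hooks), while `index_add_` only accumulates tokens that were actually routed, so the module output matches normal gated routing. A standalone toy of that accumulation step (plain torch, illustrative shapes only):

    import torch

    hidden = torch.randn(6, 4)                    # 6 tokens, hidden size 4
    token_indices = torch.tensor([1, 3])          # tokens routed to this expert
    weights = torch.tensor([[0.7], [0.2]])        # their routing weights
    expert = torch.nn.Linear(4, 4)

    with torch.no_grad():
        # The expert always runs during calibration ...
        weighted_output = expert(hidden[token_indices]) * weights

        # ... but only routed tokens are written into the final output.
        final = torch.zeros_like(hidden)
        if token_indices.numel() > 0:
            final.index_add_(0, token_indices, weighted_output)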

src/llmcompressor/modeling/llama4.py

Lines changed: 24 additions & 4 deletions
@@ -11,11 +11,17 @@
     Llama4TextMoe,
 )
 
+from llmcompressor.modeling.config import CalibrationConfig
 from llmcompressor.utils.dev import skip_weights_initialize
 
 
 class SequentialLlama4TextMoe(torch.nn.Module):
-    def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):
+    def __init__(
+        self,
+        config: Llama4TextConfig,
+        original: Llama4TextMoe,
+        calib_config: CalibrationConfig,
+    ):
         super().__init__()
         self.top_k = config.num_experts_per_tok
         self.hidden_dim = config.hidden_size
@@ -24,6 +30,8 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):
         self.router = original.router
         self.shared_expert = original.shared_expert
 
+        self.calib_config = calib_config
+
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]:
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
         router_logits = self.router(hidden_states)
@@ -39,7 +47,15 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tens
 
         out = self.shared_expert(hidden_states)
         for i in range(self.num_experts):
-            out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
+            score = router_scores[i]
+            has_tokens = torch.any(score > 0)
+
+            if self.calib_config.moe_calibrate_all_experts or has_tokens:
+                expert_output = self.experts[i](hidden_states)
+
+                if has_tokens:
+                    # expert contributes to output activations
+                    out += expert_output * score.unsqueeze(-1)
 
         return out, router_scores
 
@@ -64,5 +80,9 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts):
             self[i].down_proj.weight.data = down.t().clone().contiguous()
 
 
-def replace(config: Llama4Config, module: Llama4TextMoe):
-    return SequentialLlama4TextMoe(config=config.get_text_config(), original=module)
+def replace(
+    config: Llama4Config, module: Llama4TextMoe, calib_config: CalibrationConfig
+):
+    return SequentialLlama4TextMoe(
+        config=config.get_text_config(), original=module, calib_config=calib_config
+    )
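
In the Llama4 module the router scores are dense (one score per token per expert, zero for tokens the router did not select), so `torch.any(score > 0)` plays the role that `token_indices.numel() > 0` plays elsewhere. A small illustrative toy of that check (plain torch, shapes and values are made up):

    import torch

    hidden = torch.randn(5, 8)                        # 5 tokens, hidden size 8
    score = torch.tensor([0.0, 0.6, 0.0, 0.0, 0.3])   # dense per-token scores for one expert
    expert = torch.nn.Linear(8, 8)
    moe_calibrate_all_experts = True

    out = torch.zeros_like(hidden)
    has_tokens = torch.any(score > 0)

    with torch.no_grad():
        if moe_calibrate_all_experts or has_tokens:
            expert_output = expert(hidden)            # expert sees every token during calibration
            if has_tokens:
                out += expert_output * score.unsqueeze(-1)  # only routed tokens affect the output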

src/llmcompressor/modeling/prepare.py

Lines changed: 34 additions & 6 deletions
@@ -1,25 +1,39 @@
 from compressed_tensors.utils import replace_module
 from transformers import PreTrainedModel
 
+from llmcompressor.modeling.config import CalibrationConfig
 from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
 from llmcompressor.modeling.llama4 import replace as replace_llama4
 from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
 from llmcompressor.utils.helpers import patch_attr
 
 __all__ = ["replace_modules_for_calibration"]
 
+
 # ---------------------- module replacements; permanent -------------------------
 replacements = {
     "DeepseekV3MoE": replace_deepseekv3,
     "Llama4TextMoe": replace_llama4,
 }
 
 
-def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
+def replace_modules_for_calibration(
+    model: PreTrainedModel,
+    moe_calibrate_all_experts: bool = True,
+    moe_calibrate_gated_acts: bool = True,
+) -> PreTrainedModel:
+
+    calib_config = CalibrationConfig(
+        moe_calibrate_all_experts=moe_calibrate_all_experts,
+        moe_calibrate_gated_acts=moe_calibrate_gated_acts,
+    )
+
     for name, module in model.named_modules():
         cls_name = module.__class__.__name__
         if cls_name in replacements:
-            new_module = replacements[cls_name](config=model.config, module=module)
+            new_module = replacements[cls_name](
+                config=model.config, module=module, calib_config=calib_config
+            )
             replace_module(model, name, new_module)
 
     return model
@@ -28,7 +42,7 @@ def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
 # ------------------- module replacements; during calibration --------------------
 
 
-def update_qwen3_moe(model, stack):
+def update_qwen3_moe(model, stack, calib_config):
     for module in model.modules():
         cls_name = module.__class__.__name__
         if cls_name == "Qwen3MoeDecoderLayer":
@@ -37,7 +51,11 @@ def update_qwen3_moe(model, stack):
                 patch_attr(
                     module,
                     "mlp",
-                    replace_Qwen3MoE(config=model.config, module=module.mlp),
+                    replace_Qwen3MoE(
+                        config=model.config,
+                        module=module.mlp,
+                        calib_config=calib_config,
+                    ),
                 )
             )
 
@@ -47,9 +65,19 @@ def update_qwen3_moe(model, stack):
 }
 
 
-def moe_calibration_context(model: PreTrainedModel, stack):
+def moe_calibration_context(
+    model: PreTrainedModel,
+    stack,
+    moe_calibrate_all_experts: bool = True,
+    moe_calibrate_gated_acts: bool = True,
+):
+    calib_config = CalibrationConfig(
+        moe_calibrate_all_experts=moe_calibrate_all_experts,
+        moe_calibrate_gated_acts=moe_calibrate_gated_acts,
+    )
+
     # Temporarily updates the MoE modules within the context
     # Once the context exists, parameter updates persist
     cls_name = model.__class__.__name__
     if cls_name in moe_context:
-        moe_context.get(cls_name)(model, stack)
+        moe_context.get(cls_name)(model, stack, calib_config)
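
The temporary path takes an exit stack so the Qwen3 patches are reverted when calibration finishes; a hedged usage sketch (assumes `model` is an already-loaded Qwen3-MoE `PreTrainedModel`):

    import contextlib

    from llmcompressor.modeling.prepare import moe_calibration_context

    with contextlib.ExitStack() as stack:
        moe_calibration_context(
            model,
            stack,
            moe_calibrate_all_experts=True,
            moe_calibrate_gated_acts=True,
        )
        # Run calibration forward passes here; the patched MoE blocks are
        # restored when the ExitStack closes.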

src/llmcompressor/modeling/qwen3_moe.py

Lines changed: 63 additions & 24 deletions
@@ -20,36 +20,61 @@
     Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
 )
 
+from llmcompressor.modeling.config import CalibrationConfig
+
 
 class Qwen3MoeSparseMoeBlock(torch.nn.Module):
     def __init__(
-        self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock
+        self,
+        config: Qwen3MoeConfig,
+        original: OriginalQwen3MoeSparseMoeBlock,
+        calib_config: CalibrationConfig,
     ):
         super().__init__()
         self.num_experts = config.num_experts
         self.top_k = config.top_k
         self.norm_topk_prob = config.norm_topk_prob
+        self.calib_config = calib_config
 
         # gating
         self.gate = original.gate
         self.experts = original.experts
 
+        if not self.calib_config.moe_calibrate_gated_acts:
+            self.gate.top_k = self.num_experts  # ungate experts
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
         # router_logits: (batch * sequence_length, n_experts)
         router_logits = self.gate(hidden_states)
 
-        routing_weights = torch.nn.functional.softmax(
-            router_logits, dim=1, dtype=torch.float
-        )
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
-        )
-        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
-            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        # we cast back to the input dtype
-        routing_weights = routing_weights.to(hidden_states.dtype)
+        if self.calib_config.moe_calibrate_gated_acts:
+            routing_weights = torch.nn.functional.softmax(
+                router_logits, dim=1, dtype=torch.float
+            )
+            routing_weights, selected_experts = torch.topk(
+                routing_weights, self.top_k, dim=-1
+            )
+            # only diff with mixtral sparse moe block!
+            if self.norm_topk_prob:
+                routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+            # we cast back to the input dtype
+            routing_weights = routing_weights.to(hidden_states.dtype)
+
+        else:
+            # ungate experts
+            selected_experts = torch.arange(
+                self.num_experts, device=hidden_states.device
+            )
+            selected_experts = selected_experts.unsqueeze(0).expand(
+                hidden_states.shape[0], -1
+            )
+            routing_weights = (
+                torch.ones_like(selected_experts, dtype=hidden_states.dtype)
+                / self.num_experts
+            )
+
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim),
             dtype=hidden_states.dtype,
@@ -65,23 +90,37 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         for expert_idx in range(len(self.experts)):
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
-            # Index the correct hidden states and compute the expert hidden state for
-            # the current expert. We need to make sure to multiply the output hidden
-            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
-            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
-            expert_output = expert_layer(current_state)
-            current_hidden_states = expert_output * routing_weights[top_x, idx, None]
-            # However `index_add_` only support torch tensors for indexing so we'll use
-            # the `top_x` tensor here.
-            final_hidden_states.index_add_(
-                0, top_x, current_hidden_states.to(hidden_states.dtype)
-            )
+
+            has_tokens = idx.numel() > 0
+
+            if self.calib_config.moe_calibrate_all_experts or has_tokens:
+                # Index the correct hidden states and compute the expert hidden state for
+                # the current expert. We need to make sure to multiply the output hidden
+                # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+                current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
+                expert_output = expert_layer(current_state)
+                current_hidden_states = (
+                    expert_output * routing_weights[top_x, idx, None]
+                )
+
                # However `index_add_` only support torch tensors for indexing so we'll use
                # the `top_x` tensor here.
+                if has_tokens:
+                    final_hidden_states.index_add_(
+                        0, top_x, current_hidden_states.to(hidden_states.dtype)
+                    )
 
         final_hidden_states = final_hidden_states.reshape(
             batch_size, sequence_length, hidden_dim
         )
         return final_hidden_states, router_logits
 
 
-def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock):
-    return Qwen3MoeSparseMoeBlock(config=config, original=module)
+def replace(
+    config: Qwen3MoeConfig,
+    module: OriginalQwen3MoeSparseMoeBlock,
+    calib_config: CalibrationConfig,
+):
+    return Qwen3MoeSparseMoeBlock(
+        config=config, original=module, calib_config=calib_config
    )
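
The unit test referenced in the commit message does not appear in this view. A minimal sketch of how such a check could look, using forward hooks to record which experts actually ran (the checkpoint id and module-name matching below are illustrative assumptions, not the commit's test):

    import contextlib

    import torch
    from transformers import AutoModelForCausalLM

    from llmcompressor.modeling.prepare import moe_calibration_context

    # A small Qwen3-MoE checkpoint would be used in practice; the id is illustrative.
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")

    registered, triggered = set(), set()

    def make_hook(name):
        def hook(module, inputs, output):
            triggered.add(name)
        return hook

    with contextlib.ExitStack() as stack:
        moe_calibration_context(model, stack, moe_calibrate_all_experts=True)

        # Hook every expert inside the (now patched) MoE blocks.
        for name, module in model.named_modules():
            if name.endswith("mlp.experts") and isinstance(module, torch.nn.ModuleList):
                for i, expert in enumerate(module):
                    key = f"{name}.{i}"
                    registered.add(key)
                    expert.register_forward_hook(make_hook(key))

        sample = torch.randint(0, model.config.vocab_size, (1, 8))
        with torch.no_grad():
            model(sample)

    # With moe_calibrate_all_experts=True every hooked expert should have run.
    assert triggered == registered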
