[MoE] Add conditional expert calibration #1701

Draft
wants to merge 1 commit into main
17 changes: 17 additions & 0 deletions src/llmcompressor/modeling/config.py
@@ -0,0 +1,17 @@
from pydantic import BaseModel, model_validator


class CalibrationConfig(BaseModel):
Collaborator:
nit: What do you think about renaming to MoECalibrationConfig?

moe_calibrate_all_experts: bool
Collaborator:
Could we add more information in this config class around what these flags do for future readers, so it's clear which flag should be set for which mode?

I was thinking something like:

  | all_experts | gated_acts | Behavior                                                               |
  |-------------|------------|------------------------------------------------------------------------|
  | True        | True       | All experts run, routed experts contribute to output (current default) |
  | True        | False      | All experts run for calibration, but outputs ignored                   |
  | False       | True       | Only routed experts run and contribute (standard inference)            |
  | False       | False      | Invalid configuration (raises error)                                   |
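
For example, the class docstring could spell this out directly. A rough sketch only (wording and layout are just a suggestion; the field names are the ones from this PR):

from pydantic import BaseModel


class CalibrationConfig(BaseModel):
    """
    Controls how MoE experts are exercised during calibration.

    moe_calibrate_all_experts / moe_calibrate_gated_acts:
      True  / True  -> all experts run, routed experts contribute to the output (current default)
      True  / False -> all experts run for calibration, but their outputs are ignored
      False / True  -> only routed experts run and contribute (standard inference)
      False / False -> invalid configuration (the validator raises an error)
    """

    moe_calibrate_all_experts: bool
    moe_calibrate_gated_acts: bool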

moe_calibrate_gated_acts: bool
Collaborator:
nit: Consider renaming to something like use_gated_outputs. The current name suggests it's about "calibrating gated activations," but it actually controls whether expert outputs contribute to the final result.


@model_validator(mode="after")
def validate_config(self):

if not self.moe_calibrate_gated_acts and not self.moe_calibrate_all_experts:
raise NotImplementedError(
"At least one of moe_calibrate_gated_acts or "
"moe_calibrate_all_experts must be set to True."
)

return self
41 changes: 32 additions & 9 deletions src/llmcompressor/modeling/deepseek_v3.py
@@ -4,19 +4,31 @@
DeepseekV3MoE as OriginalDeepseekV3MoE,
)

from llmcompressor.modeling.config import CalibrationConfig


class DeepseekV3MoECalibrate(torch.nn.Module):
"""
Patched DeepseekV3MoE which sends all tokens to all experts for calibration
"""

def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE):
def __init__(
self,
config: DeepseekV3Config,
original: OriginalDeepseekV3MoE,
calib_config: CalibrationConfig,
):
super().__init__()
self.config = config
self.experts = original.experts
self.gate = original.gate
self.shared_experts = original.shared_experts

self.calib_config = calib_config

if not calib_config.moe_calibrate_gated_acts:
self.gate.top_k = self.gate.n_routed_experts # ungate experts

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residuals = hidden_states
orig_shape = hidden_states.shape
@@ -35,19 +47,30 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
mask = expert_mask[expert_idx]
token_indices, weight_indices = torch.where(mask)

expert_weights = topk_weights[token_indices, weight_indices]
expert_input = hidden_states[token_indices]
expert_output = expert(expert_input)
weighted_output = expert_output * expert_weights.unsqueeze(-1)
has_tokens = token_indices.numel() > 0

if self.calib_config.moe_calibrate_all_experts or has_tokens:
# calibrate one expert
expert_weights = topk_weights[token_indices, weight_indices]
expert_input = hidden_states[token_indices]
expert_output = expert(expert_input)
weighted_output = expert_output * expert_weights.unsqueeze(-1)

if token_indices.numel() > 0:
final_hidden_states.index_add_(0, token_indices, weighted_output)
if has_tokens and self.calib_config.moe_calibrate_gated_acts:
# expert contributes to output activations
final_hidden_states.index_add_(0, token_indices, weighted_output)
# End MoE

hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape)
hidden_states = hidden_states + self.shared_experts(residuals)
return hidden_states


def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE):
return DeepseekV3MoECalibrate(config=config, original=module)
def replace(
config: DeepseekV3Config,
module: OriginalDeepseekV3MoE,
calib_config: CalibrationConfig,
):
return DeepseekV3MoECalibrate(
config=config, original=module, calib_config=calib_config
)
28 changes: 24 additions & 4 deletions src/llmcompressor/modeling/llama4.py
@@ -11,11 +11,17 @@
Llama4TextMoe,
)

from llmcompressor.modeling.config import CalibrationConfig
from llmcompressor.utils.dev import skip_weights_initialize


class SequentialLlama4TextMoe(torch.nn.Module):
def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):
def __init__(
self,
config: Llama4TextConfig,
original: Llama4TextMoe,
calib_config: CalibrationConfig,
):
super().__init__()
self.top_k = config.num_experts_per_tok
self.hidden_dim = config.hidden_size
@@ -24,6 +30,8 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):
self.router = original.router
self.shared_expert = original.shared_expert

self.calib_config = calib_config

def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]:
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = self.router(hidden_states)
@@ -39,7 +47,15 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]:

out = self.shared_expert(hidden_states)
for i in range(self.num_experts):
out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
score = router_scores[i]
has_tokens = torch.any(score > 0)

if self.calib_config.moe_calibrate_all_experts or has_tokens:
expert_output = self.experts[i](hidden_states)

if has_tokens and self.calib_config.moe_calibrate_gated_acts:
# expert contributes to output activations
out += expert_output * score.unsqueeze(-1)

return out, router_scores

@@ -64,5 +80,9 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts):
self[i].down_proj.weight.data = down.t().clone().contiguous()


def replace(config: Llama4Config, module: Llama4TextMoe):
return SequentialLlama4TextMoe(config=config.get_text_config(), original=module)
def replace(
config: Llama4Config, module: Llama4TextMoe, calib_config: CalibrationConfig
):
return SequentialLlama4TextMoe(
config=config.get_text_config(), original=module, calib_config=calib_config
)
40 changes: 34 additions & 6 deletions src/llmcompressor/modeling/prepare.py
@@ -1,25 +1,39 @@
from compressed_tensors.utils import replace_module
from transformers import PreTrainedModel

from llmcompressor.modeling.config import CalibrationConfig
from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
from llmcompressor.modeling.llama4 import replace as replace_llama4
from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
from llmcompressor.utils.helpers import patch_attr

__all__ = ["replace_modules_for_calibration"]


# ---------------------- module replacements; permanent -------------------------
replacements = {
"DeepseekV3MoE": replace_deepseekv3,
"Llama4TextMoe": replace_llama4,
}


def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
def replace_modules_for_calibration(
model: PreTrainedModel,
moe_calibrate_all_experts: bool = True,
moe_calibrate_gated_acts: bool = True,
) -> PreTrainedModel:

calib_config = CalibrationConfig(
moe_calibrate_all_experts=moe_calibrate_all_experts,
moe_calibrate_gated_acts=moe_calibrate_gated_acts,
)

for name, module in model.named_modules():
cls_name = module.__class__.__name__
if cls_name in replacements:
new_module = replacements[cls_name](config=model.config, module=module)
new_module = replacements[cls_name](
config=model.config, module=module, calib_config=calib_config
)
replace_module(model, name, new_module)

return model
@@ -28,7 +42,7 @@ def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
# ------------------- module replacements; during calibration --------------------


def update_qwen3_moe(model, stack):
def update_qwen3_moe(model, stack, calib_config):
for module in model.modules():
cls_name = module.__class__.__name__
if cls_name == "Qwen3MoeDecoderLayer":
@@ -37,7 +51,11 @@ def update_qwen3_moe(model, stack):
patch_attr(
module,
"mlp",
replace_Qwen3MoE(config=model.config, module=module.mlp),
replace_Qwen3MoE(
config=model.config,
module=module.mlp,
calib_config=calib_config,
),
)
)

@@ -47,9 +65,19 @@ def update_qwen3_moe(model, stack):
}


def moe_calibration_context(model: PreTrainedModel, stack):
def moe_calibration_context(
model: PreTrainedModel,
stack,
moe_calibrate_all_experts: bool = True,
moe_calibrate_gated_acts: bool = True,
):
calib_config = CalibrationConfig(
moe_calibrate_all_experts=moe_calibrate_all_experts,
moe_calibrate_gated_acts=moe_calibrate_gated_acts,
)

# Temporarily updates the MoE modules within the context
# Once the context exits, parameter updates persist
cls_name = model.__class__.__name__
if cls_name in moe_context:
moe_context.get(cls_name)(model, stack)
moe_context.get(cls_name)(model, stack, calib_config)
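
For illustration only (not part of this diff), a minimal sketch of how the two entry points could be driven once this lands; the checkpoint name and the calibration loop below are placeholders:

from contextlib import ExitStack

from transformers import AutoModelForCausalLM

from llmcompressor.modeling.prepare import (
    moe_calibration_context,
    replace_modules_for_calibration,
)

# placeholder checkpoint; any architecture listed in `replacements` / `moe_context` applies
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V3")

# permanent replacement (DeepseekV3MoE / Llama4TextMoe): every expert runs,
# and only routed experts contribute to the output activations
model = replace_modules_for_calibration(
    model,
    moe_calibrate_all_experts=True,
    moe_calibrate_gated_acts=True,
)

# temporary replacement (Qwen3-MoE): modules stay patched only while the
# ExitStack is open; here all experts run but their outputs are ignored
with ExitStack() as stack:
    moe_calibration_context(
        model,
        stack,
        moe_calibrate_all_experts=True,
        moe_calibrate_gated_acts=False,
    )
    # run calibration forward passes here
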
87 changes: 63 additions & 24 deletions src/llmcompressor/modeling/qwen3_moe.py
@@ -20,36 +20,61 @@
Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
)

from llmcompressor.modeling.config import CalibrationConfig


class Qwen3MoeSparseMoeBlock(torch.nn.Module):
def __init__(
self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock
self,
config: Qwen3MoeConfig,
original: OriginalQwen3MoeSparseMoeBlock,
calib_config: CalibrationConfig,
):
super().__init__()
self.num_experts = config.num_experts
self.top_k = config.top_k
self.norm_topk_prob = config.norm_topk_prob
self.calib_config = calib_config

# gating
self.gate = original.gate
self.experts = original.experts

if not self.calib_config.moe_calibrate_gated_acts:
self.gate.top_k = self.num_experts # ungate experts

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)

routing_weights = torch.nn.functional.softmax(
router_logits, dim=1, dtype=torch.float
)
routing_weights, selected_experts = torch.topk(
routing_weights, self.top_k, dim=-1
)
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
if self.calib_config.moe_calibrate_gated_acts:
routing_weights = torch.nn.functional.softmax(
router_logits, dim=1, dtype=torch.float
)
routing_weights, selected_experts = torch.topk(
routing_weights, self.top_k, dim=-1
)
# only diff with mixtral sparse moe block!
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)

else:
# ungate experts
selected_experts = torch.arange(
self.num_experts, device=hidden_states.device
)
selected_experts = selected_experts.unsqueeze(0).expand(
hidden_states.shape[0], -1
)
routing_weights = (
torch.ones_like(selected_experts, dtype=hidden_states.dtype)
/ self.num_experts
)

final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype,
@@ -65,23 +90,37 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
for expert_idx in range(len(self.experts)):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
expert_output = expert_layer(current_state)
current_hidden_states = expert_output * routing_weights[top_x, idx, None]
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(hidden_states.dtype)
)

has_tokens = idx.numel() > 0

if self.calib_config.moe_calibrate_all_experts or has_tokens:
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
expert_output = expert_layer(current_state)
current_hidden_states = (
expert_output * routing_weights[top_x, idx, None]
)

# However, `index_add_` only supports torch tensors for indexing, so we'll use
# the `top_x` tensor here.
if has_tokens and self.calib_config.moe_calibrate_gated_acts:
final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(hidden_states.dtype)
)

final_hidden_states = final_hidden_states.reshape(
batch_size, sequence_length, hidden_dim
)
return final_hidden_states, router_logits


def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock):
return Qwen3MoeSparseMoeBlock(config=config, original=module)
def replace(
config: Qwen3MoeConfig,
module: OriginalQwen3MoeSparseMoeBlock,
calib_config: CalibrationConfig,
):
return Qwen3MoeSparseMoeBlock(
config=config, original=module, calib_config=calib_config
)
16 changes: 16 additions & 0 deletions tests/llmcompressor/modeling/test_calib_config.py
@@ -0,0 +1,16 @@
from unittest.mock import MagicMock

import pytest

from llmcompressor.modeling.prepare import replace_modules_for_calibration


def test_calib_config():
model = MagicMock()
with pytest.raises(NotImplementedError) as exc_info:
replace_modules_for_calibration(model, False, False)

assert str(exc_info.value) == (
"At least one of moe_calibrate_gated_acts or "
"moe_calibrate_all_experts must be set to True."
)