diff --git a/src/llmcompressor/modeling/config.py b/src/llmcompressor/modeling/config.py new file mode 100644 index 000000000..bf9e77b09 --- /dev/null +++ b/src/llmcompressor/modeling/config.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel, model_validator + + +class CalibrationConfig(BaseModel): + moe_calibrate_all_experts: bool + moe_calibrate_gated_acts: bool + + @model_validator(mode="after") + def validate_config(self): + + if not self.moe_calibrate_gated_acts and not self.moe_calibrate_all_experts: + raise NotImplementedError( + "At least one of moe_calibrate_gated_acts or " + "moe_calibrate_all_experts must be set to True." + ) + + return self diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py index 287b343bd..42eda362d 100644 --- a/src/llmcompressor/modeling/deepseek_v3.py +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -4,19 +4,31 @@ DeepseekV3MoE as OriginalDeepseekV3MoE, ) +from llmcompressor.modeling.config import CalibrationConfig + class DeepseekV3MoECalibrate(torch.nn.Module): """ Patched DeepseekV3MoE which sends all tokens to all experts for calibration """ - def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE): + def __init__( + self, + config: DeepseekV3Config, + original: OriginalDeepseekV3MoE, + calib_config: CalibrationConfig, + ): super().__init__() self.config = config self.experts = original.experts self.gate = original.gate self.shared_experts = original.shared_experts + self.calib_config = calib_config + + if not calib_config.moe_calibrate_gated_acts: + self.gate.top_k = self.gate.n_routed_experts # ungate experts + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states orig_shape = hidden_states.shape @@ -35,13 +47,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: mask = expert_mask[expert_idx] token_indices, weight_indices = torch.where(mask) - expert_weights = topk_weights[token_indices, weight_indices] - expert_input = hidden_states[token_indices] - expert_output = expert(expert_input) - weighted_output = expert_output * expert_weights.unsqueeze(-1) + has_tokens = token_indices.numel() > 0 + + if self.calib_config.moe_calibrate_all_experts or has_tokens: + # calibrate one expert + expert_weights = topk_weights[token_indices, weight_indices] + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + weighted_output = expert_output * expert_weights.unsqueeze(-1) - if token_indices.numel() > 0: - final_hidden_states.index_add_(0, token_indices, weighted_output) + if has_tokens and self.calib_config.moe_calibrate_gated_acts: + # expert contributes to output activations + final_hidden_states.index_add_(0, token_indices, weighted_output) # End MoE hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape) @@ -49,5 +66,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE): - return DeepseekV3MoECalibrate(config=config, original=module) +def replace( + config: DeepseekV3Config, + module: OriginalDeepseekV3MoE, + calib_config: CalibrationConfig, +): + return DeepseekV3MoECalibrate( + config=config, original=module, calib_config=calib_config + ) diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py index a33833ea1..90bacae67 100644 --- a/src/llmcompressor/modeling/llama4.py +++ b/src/llmcompressor/modeling/llama4.py @@ -11,11 +11,17 @@ Llama4TextMoe, 
) +from llmcompressor.modeling.config import CalibrationConfig from llmcompressor.utils.dev import skip_weights_initialize class SequentialLlama4TextMoe(torch.nn.Module): - def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe): + def __init__( + self, + config: Llama4TextConfig, + original: Llama4TextMoe, + calib_config: CalibrationConfig, + ): super().__init__() self.top_k = config.num_experts_per_tok self.hidden_dim = config.hidden_size @@ -24,6 +30,8 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe): self.router = original.router self.shared_expert = original.shared_expert + self.calib_config = calib_config + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]: hidden_states = hidden_states.reshape(-1, self.hidden_dim) router_logits = self.router(hidden_states) @@ -39,7 +47,15 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tens out = self.shared_expert(hidden_states) for i in range(self.num_experts): - out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1) + score = router_scores[i] + has_tokens = torch.any(score > 0) + + if self.calib_config.moe_calibrate_all_experts or has_tokens: + expert_output = self.experts[i](hidden_states) + + if has_tokens and self.calib_config.moe_calibrate_gated_acts: + # expert contributes to output activations + out += expert_output * score.unsqueeze(-1) return out, router_scores @@ -64,5 +80,9 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts): self[i].down_proj.weight.data = down.t().clone().contiguous() -def replace(config: Llama4Config, module: Llama4TextMoe): - return SequentialLlama4TextMoe(config=config.get_text_config(), original=module) +def replace( + config: Llama4Config, module: Llama4TextMoe, calib_config: CalibrationConfig +): + return SequentialLlama4TextMoe( + config=config.get_text_config(), original=module, calib_config=calib_config + ) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index cb61f5fad..fb9fca4e9 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,6 +1,7 @@ from compressed_tensors.utils import replace_module from transformers import PreTrainedModel +from llmcompressor.modeling.config import CalibrationConfig from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3 from llmcompressor.modeling.llama4 import replace as replace_llama4 from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE @@ -8,6 +9,7 @@ __all__ = ["replace_modules_for_calibration"] + # ---------------------- module replacements; permanent ------------------------- replacements = { "DeepseekV3MoE": replace_deepseekv3, @@ -15,11 +17,23 @@ } -def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel: +def replace_modules_for_calibration( + model: PreTrainedModel, + moe_calibrate_all_experts: bool = True, + moe_calibrate_gated_acts: bool = True, +) -> PreTrainedModel: + + calib_config = CalibrationConfig( + moe_calibrate_all_experts=moe_calibrate_all_experts, + moe_calibrate_gated_acts=moe_calibrate_gated_acts, + ) + for name, module in model.named_modules(): cls_name = module.__class__.__name__ if cls_name in replacements: - new_module = replacements[cls_name](config=model.config, module=module) + new_module = replacements[cls_name]( + config=model.config, module=module, calib_config=calib_config + ) replace_module(model, name, new_module) return model @@ -28,7 +42,7 @@ def 
replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel: # ------------------- module replacements; during calibration -------------------- -def update_qwen3_moe(model, stack): +def update_qwen3_moe(model, stack, calib_config): for module in model.modules(): cls_name = module.__class__.__name__ if cls_name == "Qwen3MoeDecoderLayer": @@ -37,7 +51,11 @@ def update_qwen3_moe(model, stack): patch_attr( module, "mlp", - replace_Qwen3MoE(config=model.config, module=module.mlp), + replace_Qwen3MoE( + config=model.config, + module=module.mlp, + calib_config=calib_config, + ), ) ) @@ -47,9 +65,19 @@ def update_qwen3_moe(model, stack): } -def moe_calibration_context(model: PreTrainedModel, stack): +def moe_calibration_context( + model: PreTrainedModel, + stack, + moe_calibrate_all_experts: bool = True, + moe_calibrate_gated_acts: bool = True, +): + calib_config = CalibrationConfig( + moe_calibrate_all_experts=moe_calibrate_all_experts, + moe_calibrate_gated_acts=moe_calibrate_gated_acts, + ) + # Temporarily updates the MoE modules within the context # Once the context exists, parameter updates persist cls_name = model.__class__.__name__ if cls_name in moe_context: - moe_context.get(cls_name)(model, stack) + moe_context.get(cls_name)(model, stack, calib_config) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index fcd5d9925..583a0bf6e 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -20,36 +20,61 @@ Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock, ) +from llmcompressor.modeling.config import CalibrationConfig + class Qwen3MoeSparseMoeBlock(torch.nn.Module): def __init__( - self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock + self, + config: Qwen3MoeConfig, + original: OriginalQwen3MoeSparseMoeBlock, + calib_config: CalibrationConfig, ): super().__init__() self.num_experts = config.num_experts self.top_k = config.top_k self.norm_topk_prob = config.norm_topk_prob + self.calib_config = calib_config # gating self.gate = original.gate self.experts = original.experts + if not self.calib_config.moe_calibrate_gated_acts: + self.gate.top_k = self.num_experts # ungate experts + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) router_logits = self.gate(hidden_states) - routing_weights = torch.nn.functional.softmax( - router_logits, dim=1, dtype=torch.float - ) - routing_weights, selected_experts = torch.topk( - routing_weights, self.top_k, dim=-1 - ) - if self.norm_topk_prob: # only diff with mixtral sparse moe block! - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) + if self.calib_config.moe_calibrate_gated_acts: + routing_weights = torch.nn.functional.softmax( + router_logits, dim=1, dtype=torch.float + ) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + # only diff with mixtral sparse moe block! 
+ if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + else: + # ungate experts + selected_experts = torch.arange( + self.num_experts, device=hidden_states.device + ) + selected_experts = selected_experts.unsqueeze(0).expand( + hidden_states.shape[0], -1 + ) + routing_weights = ( + torch.ones_like(selected_experts, dtype=hidden_states.dtype) + / self.num_experts + ) + final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, @@ -65,17 +90,25 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: for expert_idx in range(len(self.experts)): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - expert_output = expert_layer(current_state) - current_hidden_states = expert_output * routing_weights[top_x, idx, None] - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. - final_hidden_states.index_add_( - 0, top_x, current_hidden_states.to(hidden_states.dtype) - ) + + has_tokens = idx.numel() > 0 + + if self.calib_config.moe_calibrate_all_experts or has_tokens: + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + expert_output = expert_layer(current_state) + current_hidden_states = ( + expert_output * routing_weights[top_x, idx, None] + ) + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + if has_tokens and self.calib_config.moe_calibrate_gated_acts: + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) final_hidden_states = final_hidden_states.reshape( batch_size, sequence_length, hidden_dim @@ -83,5 +116,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits -def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock): - return Qwen3MoeSparseMoeBlock(config=config, original=module) +def replace( + config: Qwen3MoeConfig, + module: OriginalQwen3MoeSparseMoeBlock, + calib_config: CalibrationConfig, +): + return Qwen3MoeSparseMoeBlock( + config=config, original=module, calib_config=calib_config + ) diff --git a/tests/llmcompressor/modeling/test_calib_config.py b/tests/llmcompressor/modeling/test_calib_config.py new file mode 100644 index 000000000..1ed8f8ab3 --- /dev/null +++ b/tests/llmcompressor/modeling/test_calib_config.py @@ -0,0 +1,16 @@ +from unittest.mock import MagicMock + +import pytest + +from llmcompressor.modeling.prepare import replace_modules_for_calibration + + +def test_calib_config(): + model = MagicMock() + with pytest.raises(NotImplementedError) as exc_info: + replace_modules_for_calibration(model, False, False) + + assert str(exc_info.value) == ( + "At least one of moe_calibrate_gated_acts or " + "moe_calibrate_all_experts must be set to True." 
+    )
diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py
new file mode 100644
index 000000000..000eab55a
--- /dev/null
+++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py
@@ -0,0 +1,51 @@
+from functools import partial
+
+import pytest
+import torch
+from transformers import AutoModelForCausalLM
+
+from llmcompressor.modeling.deepseek_v3 import DeepseekV3MoECalibrate
+from llmcompressor.modeling.prepare import replace_modules_for_calibration
+from llmcompressor.utils.dev import skip_weights_download
+
+
+@pytest.mark.parametrize("model_stub", ["unsloth/DeepSeek-R1-0528-BF16"])
+def test_calib_replace_deepseekv3moe_all_experts(model_stub):
+    with skip_weights_download():
+        model = AutoModelForCausalLM.from_pretrained(model_stub)
+
+    replace_modules_for_calibration(
+        model, moe_calibrate_gated_acts=False, moe_calibrate_all_experts=True
+    )
+
+    # Find a Deepseek MoE layer
+    moe_layer = None
+    for _, module in model.named_modules():
+        if isinstance(module, DeepseekV3MoECalibrate):
+            moe_layer = module
+            break
+
+    assert moe_layer is not None
+
+    num_experts = len(moe_layer.experts)
+    expert_triggered = [False for _ in range(num_experts)]
+
+    # Define the hook function
+    def hook_fn(i, module, input, output):
+        expert_triggered[i] = True
+
+    # Attach hooks using functools.partial to bind each index
+    for i, expert in enumerate(moe_layer.experts):
+        expert.register_forward_hook(partial(hook_fn, i))
+
+    # Create dummy input tensor that simulates hidden_states
+    hidden_dim = model.config.hidden_size
+    batch, seq_len = 4, 32
+    sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32)
+
+    # Forward through the MoE layer directly
+    with torch.no_grad():
+        _ = moe_layer(sample)
+
+    # Assert all experts are used
+    assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}"
diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py
new file mode 100644
index 000000000..98db84a18
--- /dev/null
+++ b/tests/llmcompressor/modeling/test_calib_llama4.py
@@ -0,0 +1,54 @@
+from functools import partial
+
+import pytest
+import torch
+from transformers import Llama4ForConditionalGeneration
+
+from llmcompressor.modeling.llama4 import SequentialLlama4TextMoe
+from llmcompressor.modeling.prepare import replace_modules_for_calibration
+from llmcompressor.utils.dev import skip_weights_download
+
+
+@pytest.mark.skip("not fully tested yet")
+@pytest.mark.parametrize("model_stub", ["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
+def test_calib_replace_llama4_moe_all_experts(model_stub):
+    with skip_weights_download(Llama4ForConditionalGeneration):
+        model = Llama4ForConditionalGeneration.from_pretrained(
+            model_stub, torch_dtype="auto"
+        )
+
+    replace_modules_for_calibration(
+        model, moe_calibrate_gated_acts=False, moe_calibrate_all_experts=True
+    )
+
+    # Find a Llama4 MoE layer
+    moe_layer = None
+    for _, module in model.named_modules():
+        if isinstance(module, SequentialLlama4TextMoe):
+            moe_layer = module
+            break
+
+    assert moe_layer is not None
+
+    num_experts = len(moe_layer.experts)
+    expert_triggered = [False for _ in range(num_experts)]
+
+    # Define the hook function
+    def hook_fn(i, module, input, output):
+        expert_triggered[i] = True
+
+    # Attach hooks using functools.partial to bind each index
+    for i, expert in enumerate(moe_layer.experts):
+        expert.register_forward_hook(partial(hook_fn, i))
+
+    # Create dummy input tensor that simulates 
hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) + + # Assert all experts are used + assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" diff --git a/tests/llmcompressor/modeling/test_calib_qwen3.py b/tests/llmcompressor/modeling/test_calib_qwen3.py new file mode 100644 index 000000000..2b97ec675 --- /dev/null +++ b/tests/llmcompressor/modeling/test_calib_qwen3.py @@ -0,0 +1,60 @@ +import contextlib +from functools import partial + +import pytest +import torch +from transformers import AutoModelForCausalLM + +from llmcompressor.modeling.prepare import moe_calibration_context +from llmcompressor.modeling.qwen3_moe import Qwen3MoeSparseMoeBlock +from llmcompressor.utils.dev import skip_weights_download +from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context + + +@pytest.mark.parametrize("model_stub", ["Qwen/Qwen3-30B-A3B"]) +def test_calib_replace_qwen3moe_all_experts(model_stub): + with skip_weights_download(): + model = AutoModelForCausalLM.from_pretrained(model_stub) + + # Qwen3MoE layer replacement is temporary within the context + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(DisableQuantization(model)) + + moe_calibration_context( + model, stack, moe_calibrate_gated_acts=False, moe_calibrate_all_experts=True + ) + + # Find one MoE layer + moe_layer = None + for name, module in model.named_modules(): + if isinstance(module, Qwen3MoeSparseMoeBlock): + moe_layer = module + break + + assert moe_layer is not None + + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] + + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True + + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) + + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) + + # Assert all experts are used + assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}"
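
Usage sketch for the new flags (illustrative only, not part of the patch): replace_modules_for_calibration performs the permanent swap used for the DeepseekV3 and Llama4 MoE blocks, while moe_calibration_context patches Qwen3 MoE blocks only for the lifetime of an exit stack. The model stub and the dummy forward pass below are assumptions for the example, not requirements of the API.

    import contextlib

    import torch
    from transformers import AutoModelForCausalLM

    from llmcompressor.modeling.prepare import (
        moe_calibration_context,
        replace_modules_for_calibration,
    )
    from llmcompressor.utils.helpers import calibration_forward_context

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")

    # Permanent replacement (applies to DeepseekV3 / Llama4 MoE modules): run every
    # expert on every token and keep router-gated output activations. Both flags
    # default to True.
    model = replace_modules_for_calibration(
        model,
        moe_calibrate_all_experts=True,
        moe_calibrate_gated_acts=True,
    )

    # Temporary replacement (applies to Qwen3 MoE modules): layers are only patched
    # while the exit stack is open; passing both flags as False raises
    # NotImplementedError from CalibrationConfig.
    with contextlib.ExitStack() as stack:
        stack.enter_context(calibration_forward_context(model))
        moe_calibration_context(
            model,
            stack,
            moe_calibrate_all_experts=True,
            moe_calibrate_gated_acts=False,  # ungate experts during calibration
        )
        with torch.no_grad():
            # calibration forward passes go here; dummy input for illustration
            model(torch.zeros(1, 8, dtype=torch.long))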