[Calibration] Add MoE Calibration Context (#1596)
# Summary

- Introduce a `moe_calibration_context` which, during calibration, replaces MoE blocks with custom modules that are needed to properly calibrate all experts requiring data.
- The context can optionally be enabled through a new `calibrate_moe_context` argument; when set to `True`, the context is applied for the duration of calibration.
- Modules are replaced with new definitions registered in `prepare.py` (shared with `replace_modules_for_calibration`).
- This enables a second pathway for calibrating MoEs and other models that require updates to their modules to be compatible with llm-compressor (see the usage sketch below):
  1. Replacing modules temporarily, during calibration only
  2. Replacing modules permanently (as done by `replace_modules_for_calibration`, previously called `prepare_for_calibration`)
- Similar to `replace_modules_for_calibration`, a dictionary defining the replacements has been added: `moe_context`.

# Testing

- Tested with a `Qwen/Qwen3-30B-A3B` NVFP4 example and added the example to the examples folder as well.

# Next Steps

- Definitions for the updated MoE modules are currently hardcoded; we want to expand them and add additional parameters for more control over the MoE forward pass, such as the parameters proposed in #1593. This is especially important if we find that a certain configuration results in optimal calibration.
- We may find it easier to refactor the calibration args into their own pydantic model rather than putting everything under the dataset args.
1 parent 2c70cb0 commit 0123644
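As a rough usage sketch of the two pathways described in the summary (not part of the diff): `model`, `ds`, and `recipe` below are placeholders for a loaded model, calibration dataset, and quantization recipe, shown concretely in the examples further down.

from llmcompressor import oneshot
from llmcompressor.modeling import replace_modules_for_calibration

# Pathway 2: permanent replacement (formerly `prepare_for_calibration`), as in the
# Llama4 and DeepSeek-R1 examples below; the swapped MoE modules remain in the model.
model = replace_modules_for_calibration(model)
oneshot(model=model, dataset=ds, recipe=recipe)

# Pathway 1: temporary replacement, as in the new Qwen3-30B-A3B example below;
# MoE modules are swapped only for the duration of calibration.
oneshot(model=model, dataset=ds, recipe=recipe, calibrate_moe_context=True)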

File tree

13 files changed, +264 -20 lines changed

examples/multimodal_vision/llama4_example.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from transformers import Llama4ForConditionalGeneration, Llama4Processor
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import prepare_for_calibration
+from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import GPTQModifier
 
 # Select model and load it.
@@ -14,7 +14,7 @@
 # This change allows compatibility with vllm.
 # To apply your own custom module for experimentation, consider updating
 # `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
-model = prepare_for_calibration(model)
+model = replace_modules_for_calibration(model)
 
 DATASET_ID = "neuralmagic/calibration"
 NUM_CALIBRATION_SAMPLES = 512

examples/quantization_w4a4_fp4/llama4_example.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 from transformers import Llama4ForConditionalGeneration, Llama4Processor
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import prepare_for_calibration
+from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import QuantizationModifier
 
 # Select model and load it.
@@ -14,7 +14,7 @@
 # This change allows compatibility with vllm.
 # To apply your own custom module for experimentation, consider updating
 # `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
-model = prepare_for_calibration(model)
+model = replace_modules_for_calibration(model)
 
 DATASET_ID = "neuralmagic/calibration"
 NUM_CALIBRATION_SAMPLES = 20
New example file (Qwen3-30B-A3B NVFP4)

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

MODEL_ID = "Qwen/Qwen3-30B-A3B"

# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples
NUM_CALIBRATION_SAMPLES = 200
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm and scheme.
# In this case, we:
#   * quantize the weights to fp4 with per-group-16 scales via PTQ
#   * calibrate a global_scale for activations, which will be used to
#     quantize activations to fp4 on the fly
recipe = QuantizationModifier(
    targets="Linear", scheme="NVFP4", ignore=["lm_head", "re:.*mlp.gate$"]
)

# Apply quantization.
# We set `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock`
# modules during calibration.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    calibrate_moe_context=True,
)


print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to("cuda")
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")


# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)

examples/quantizing_moe/deepseek_r1_example.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 from datasets import load_dataset
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
-from llmcompressor.modeling import prepare_for_calibration
+from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import GPTQModifier
 from llmcompressor.transformers import oneshot
 
@@ -20,7 +20,7 @@
     model_id, torch_dtype="auto", config=config
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = prepare_for_calibration(model)
+model = replace_modules_for_calibration(model)
 
 # Select calibration dataset.
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"

src/llmcompressor/args/dataset_arguments.py

Lines changed: 10 additions & 0 deletions
@@ -117,6 +117,16 @@ class DatasetArguments(CustomDatasetArguments):
         default=512,
         metadata={"help": "Number of samples to use for one-shot calibration"},
     )
+    calibrate_moe_context: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to enable the MoE calibration context for the given "
+            "model during calibration. This usually involves updating all MoE "
+            "modules in the model for the duration of calibration. See moe_context "
+            "under modeling/prepare.py for a list of supported MoEs and their "
+            "updated module definitions"
+        },
+    )
     shuffle_calibration_samples: Optional[bool] = field(
         default=True,
         metadata={

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 6 additions & 1 deletion
@@ -189,7 +189,11 @@ def apply_recipe_modifiers(
         user_pipeline = self.dataset_args.pipeline
         modifiers = session.lifecycle.recipe.modifiers
         pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline)
-        pipeline(self.model, calibration_dataloader, self.dataset_args)
+        pipeline(
+            self.model,
+            calibration_dataloader,
+            self.dataset_args,
+        )
 
         session.finalize()
 
@@ -227,6 +231,7 @@ def oneshot(
     overwrite_cache: bool = False,
     preprocessing_num_workers: Optional[int] = None,
     min_tokens_per_module: Optional[float] = None,
+    calibrate_moe_context: bool = False,
     # Miscellaneous arguments
     output_dir: Optional[str] = None,
     log_dir: Optional[str] = "sparse_logs",

src/llmcompressor/modeling/deepseek_v3.py

Lines changed: 5 additions & 5 deletions
@@ -1,16 +1,16 @@
 import torch
 from transformers.models.deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config
-from transformers.models.deepseek_v3.modeling_deepseek_v3 import DeepseekV3MoE
-
-__all__ = ["DeepseekV3MoECalibrate"]
+from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
+    DeepseekV3MoE as OriginalDeepseekV3MoE,
+)
 
 
 class DeepseekV3MoECalibrate(torch.nn.Module):
     """
     Patched DeepseekV3MoE which sends all tokens to all experts for calibration
     """
 
-    def __init__(self, config: DeepseekV3Config, original: DeepseekV3MoE):
+    def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE):
         super().__init__()
         self.config = config
         self.experts = original.experts
@@ -49,5 +49,5 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
-def replace(config: DeepseekV3Config, module: DeepseekV3MoE):
+def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE):
     return DeepseekV3MoECalibrate(config=config, original=module)

src/llmcompressor/modeling/llama4.py

Lines changed: 0 additions & 2 deletions
@@ -11,8 +11,6 @@
 
 from llmcompressor.utils.dev import skip_weights_initialize
 
-__all__ = ["SequentialLlama4TextMoe"]
-
 
 class SequentialLlama4TextMoe(torch.nn.Module):
     def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe):

src/llmcompressor/modeling/prepare.py

Lines changed: 35 additions & 2 deletions
@@ -3,20 +3,53 @@
 
 from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
 from llmcompressor.modeling.llama4 import replace as replace_llama4
+from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
+from llmcompressor.utils.helpers import patch_attr
 
-__all__ = ["prepare_for_calibration"]
+__all__ = ["replace_modules_for_calibration"]
 
+# ---------------------- module replacements; permanent -------------------------
 replacements = {
     "DeepseekV3MoE": replace_deepseekv3,
     "Llama4TextMoe": replace_llama4,
 }
 
 
-def prepare_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
+def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
     for name, module in model.named_modules():
         cls_name = module.__class__.__name__
         if cls_name in replacements:
             new_module = replacements[cls_name](config=model.config, module=module)
             replace_module(model, name, new_module)
 
     return model
+
+
+# ------------------- module replacements; during calibration --------------------
+
+
+def update_qwen3_moe(model, stack):
+    for module in model.modules():
+        cls_name = module.__class__.__name__
+        if cls_name == "Qwen3MoeDecoderLayer":
+            # Optionally update the model.config to pass in other arguments
+            stack.enter_context(
+                patch_attr(
+                    module,
+                    "mlp",
+                    replace_Qwen3MoE(config=model.config, module=module.mlp),
+                )
+            )
+
+
+moe_context = {
+    "Qwen3MoeForCausalLM": update_qwen3_moe,
+}
+
+
+def moe_calibration_context(model: PreTrainedModel, stack):
+    # Temporarily updates the MoE modules within the context.
+    # Once the context exits, parameter updates persist.
+    cls_name = model.__class__.__name__
+    if cls_name in moe_context:
+        moe_context.get(cls_name)(model, stack)
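For context on the `stack` argument above: `moe_calibration_context` registers each `patch_attr` swap on a caller-provided `contextlib.ExitStack`, so the original MoE modules are restored when the stack closes. A minimal sketch of the assumed calling pattern (the calibration loop is a placeholder, not the actual pipeline code):

from contextlib import ExitStack

from llmcompressor.modeling.prepare import moe_calibration_context

# `model` is a loaded Qwen3MoeForCausalLM (see the example above).
with ExitStack() as stack:
    moe_calibration_context(model, stack)  # swaps Qwen3MoeSparseMoeBlock modules

    # ... run calibration forward passes here ...

# On exit, the original modules are restored; since the calibration blocks reuse the
# original `gate` and `experts`, parameter updates made during calibration persist.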
src/llmcompressor/modeling/qwen3_moe.py (new file; path inferred from the import in prepare.py)

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
# coding=utf-8
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers.models import Qwen3MoeConfig
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
    Qwen3MoeSparseMoeBlock as OriginalQwen3MoeSparseMoeBlock,
)


class Qwen3MoeSparseMoeBlock(torch.nn.Module):
    def __init__(
        self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock
    ):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.top_k
        self.norm_topk_prob = config.norm_topk_prob

        # gating
        self.gate = original.gate
        self.experts = original.experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = torch.nn.functional.softmax(
            router_logits, dim=1, dtype=torch.float
        )
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # One-hot encode the selected experts to create an expert mask;
        # this will be used to easily index which expert is going to be solicited
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        for expert_idx in range(len(self.experts)):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))
            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            expert_output = expert_layer(current_state)
            current_hidden_states = expert_output * routing_weights[top_x, idx, None]
            # However `index_add_` only supports torch tensors for indexing, so we'll
            # use the `top_x` tensor here.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(hidden_states.dtype)
            )

        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        )
        return final_hidden_states, router_logits


def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock):
    return Qwen3MoeSparseMoeBlock(config=config, original=module)
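Note that the loop above visits every expert on every batch (experts the router did not select still receive a forward call with zero tokens), rather than skipping unselected experts; this is what makes the block calibration-friendly. A hedged sketch of swapping a single block by hand (the `model.model.layers[i].mlp` layout is an assumption based on the standard `Qwen3MoeForCausalLM` module hierarchy; `update_qwen3_moe` in prepare.py does this for every decoder layer via `patch_attr`):

from llmcompressor.modeling.qwen3_moe import replace as replace_qwen3_moe

# `model` is a loaded Qwen3MoeForCausalLM.
layer = model.model.layers[0]
layer.mlp = replace_qwen3_moe(config=model.config, module=layer.mlp)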
