Commit fae9429

[Qwen3Next] Add calibration support and NVFP4 Example (#1889)
SUMMARY:
- Add calibration support for Qwen3-Next
- Add an NVFP4 example
- Update the MoE calibration context to support the model

TEST PLAN:

For Qwen3-30B-A3B-NVFP4:

```
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8825|±  |0.0089|
|     |       |strict-match    |     5|exact_match|↑  |0.8802|±  |0.0089|
```

Qwen3-Next: >96% recovery

Base:

```
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8514|±  |0.0098|
|     |       |strict-match    |     5|exact_match|↑  |0.8089|±  |0.0108|
```

NVFP4:

```
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8400|±  |0.0101|
|     |       |strict-match    |     5|exact_match|↑  |0.7733|±  |0.0115|
```
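Tables in this format come from lm-evaluation-harness. The commit does not record the exact eval invocation; below is a minimal reproduction sketch using the harness's Python API, assuming the quantized checkpoint was saved locally as `Qwen3-Next-80B-A3B-Instruct-NVFP4` (the backend and `model_args` are illustrative):

```python
# Hypothetical gsm8k eval sketch; the exact command is not part of this commit.
import lm_eval

results = lm_eval.simple_evaluate(
    model="hf",  # assumed HuggingFace backend; a vLLM backend may also work
    model_args="pretrained=Qwen3-Next-80B-A3B-Instruct-NVFP4,dtype=auto",
    tasks=["gsm8k"],
    num_fewshot=5,
)
print(results["results"]["gsm8k"])
```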
1 parent 640147b · commit fae9429

4 files changed: +267 −18 lines changed
New file: NVFP4 calibration example for Qwen3-Next-80B-A3B-Instruct

Lines changed: 101 additions & 0 deletions
```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation

# NOTE: Qwen3-Next-80B-A3B-Instruct support is not in transformers<=4.56.2;
# you may need to install transformers from source.

MODEL_ID = "Qwen/Qwen3-Next-80B-A3B-Instruct"

# Load model.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples.
NUM_CALIBRATION_SAMPLES = 20
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm and scheme.
# In this case, we:
#   * quantize the weights to fp4 (group size 16) via PTQ
#   * calibrate a global_scale for activations, which will be used to
#     quantize activations to fp4 on the fly
recipe = QuantizationModifier(
    targets="Linear",
    scheme="NVFP4",
    ignore=[
        "lm_head",
        "re:.*mlp.gate$",
        "re:.*mlp.shared_expert_gate$",
        "re:.*linear_attn.*",
    ],
)

# Apply quantization.
# We set `calibrate_moe_context` to True to update all `Qwen3NextSparseMoeBlock`
# modules during calibration.
# Feel free to update the definition under
# `llm-compressor/src/llmcompressor/modeling/qwen3_next_moe.py` to play around
# with this behaviour and evaluate its impact on quantization performance.
oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    calibrate_moe_context=True,
)


print("\n\n")
print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
    model.device
)
output = model.generate(input_ids, max_new_tokens=100)
print(tokenizer.decode(output[0]))
print("==========================================\n\n")


# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
```
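
Once saved, the compressed checkpoint can be served for inference. A minimal sketch follows, assuming a vLLM build with compressed-tensors NVFP4 support and FP4-capable hardware; the model path matches `SAVE_DIR` above, and nothing in this block is part of the commit:

```python
# Hypothetical inference sketch; assumes vLLM supports this NVFP4 checkpoint.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen3-Next-80B-A3B-Instruct-NVFP4")
params = SamplingParams(max_tokens=64, temperature=0.0)
outputs = llm.generate(["Hello my name is"], params)
print(outputs[0].outputs[0].text)
```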

src/llmcompressor/args/dataset_arguments.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -201,6 +201,7 @@ class DatasetArguments(CustomDatasetArguments):
             "_prepare_4d_causal_attention_mask",
             "_prepare_fsmt_decoder_inputs",
             "_prepare_4d_causal_attention_mask_with_cache_position",
+            "_update_linear_attn_mask",
         ],
         metadata={
             "help": "List of functions to ignore during tracing, either "
```

src/llmcompressor/modeling/prepare.py

Lines changed: 47 additions & 18 deletions

```diff
@@ -5,6 +5,11 @@
 from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
 from llmcompressor.modeling.llama4 import replace as replace_llama4
 from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
+
+try:
+    from llmcompressor.modeling.qwen3_next_moe import replace as replace_Qwen3NextMoE
+except ImportError:
+    replace_Qwen3NextMoE = None
 from llmcompressor.utils.helpers import patch_attr

 __all__ = ["replace_modules_for_calibration"]
@@ -36,28 +41,51 @@ def replace_modules_for_calibration(
 # ------------------- module replacements; during calibration --------------------


-def update_qwen3_moe(model, stack, calibrate_all_experts):
-    for module in model.modules():
-        cls_name = module.__class__.__name__
-        if cls_name == "Qwen3MoeDecoderLayer":
-            # Optionally update the model.config to pass in other arguments
-            stack.enter_context(
-                patch_attr(
-                    module,
-                    "mlp",
-                    replace_Qwen3MoE(
-                        config=model.config,
-                        module=module.mlp,
-                        calibrate_all_experts=calibrate_all_experts,
-                    ),
-                )
-            )
+def update_qwen3_moe(model, module, stack, calibrate_all_experts):
+    cls_name = module.__class__.__name__
+    if (
+        cls_name == "Qwen3MoeDecoderLayer"
+        and module.mlp.__class__.__name__ == "Qwen3MoeSparseMoeBlock"
+    ):
+        stack.enter_context(
+            patch_attr(
+                module,
+                "mlp",
+                replace_Qwen3MoE(
+                    config=model.config,
+                    module=module.mlp,
+                    calibrate_all_experts=calibrate_all_experts,
+                ),
+            )
+        )
+
+
+def update_qwen3_next_moe(model, module, stack, calibrate_all_experts):
+    cls_name = module.__class__.__name__
+    if (
+        cls_name == "Qwen3NextDecoderLayer"
+        and module.mlp.__class__.__name__ == "Qwen3NextSparseMoeBlock"
+    ):
+        stack.enter_context(
+            patch_attr(
+                module,
+                "mlp",
+                replace_Qwen3NextMoE(
+                    config=model.config,
+                    module=module.mlp,
+                    calibrate_all_experts=calibrate_all_experts,
+                ),
+            )
+        )


 moe_context = {
     "Qwen3MoeForCausalLM": update_qwen3_moe,
 }

+if replace_Qwen3NextMoE is not None:
+    moe_context["Qwen3NextForCausalLM"] = update_qwen3_next_moe
+

 def moe_calibration_context(
     model: PreTrainedModel,
@@ -66,6 +94,7 @@ def moe_calibration_context(
 ):
     # Temporarily updates the MoE modules within the context
     # Once the context exits, parameter updates persist
-    cls_name = model.__class__.__name__
-    if cls_name in moe_context:
-        moe_context.get(cls_name)(model, stack, calibrate_all_experts)
+    model_name = model.__class__.__name__
+    if model_name in moe_context:
+        for module in model.modules():
+            moe_context[model_name](model, module, stack, calibrate_all_experts)
```
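
The refactor moves the `model.modules()` walk out of each `update_*` hook and into `moe_calibration_context`, which now dispatches per module. The temporary-swap mechanics are unchanged: `patch_attr` is entered on an `ExitStack`, so every patched `mlp` is restored when the stack unwinds. A minimal self-contained sketch of that pattern (this `patch_attr` is an illustrative stand-in, not the one from `llmcompressor.utils.helpers`):

```python
# Illustrative sketch of the patch-within-ExitStack pattern used above.
# This `patch_attr` is a stand-in, not llmcompressor's implementation.
from contextlib import ExitStack, contextmanager


@contextmanager
def patch_attr(obj, name, value):
    original = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield
    finally:
        setattr(obj, name, original)  # restore when the context exits


class Layer:
    mlp = "OriginalMoeBlock"


layer = Layer()
with ExitStack() as stack:
    stack.enter_context(patch_attr(layer, "mlp", "CalibrationMoeBlock"))
    assert layer.mlp == "CalibrationMoeBlock"  # patched during calibration
assert layer.mlp == "OriginalMoeBlock"  # restored afterwards
```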
src/llmcompressor/modeling/qwen3_next_moe.py

Lines changed: 118 additions & 0 deletions

```python
# coding=utf-8
# Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers.models import Qwen3NextConfig
from transformers.models.qwen3_next.modeling_qwen3_next import (
    Qwen3NextSparseMoeBlock as OriginalQwen3NextMoeSparseMoeBlock,
)


class Qwen3NextSparseMoeBlock(torch.nn.Module):
    def __init__(
        self,
        config: Qwen3NextConfig,
        original: OriginalQwen3NextMoeSparseMoeBlock,
        calibrate_all_experts: bool,
    ):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.top_k
        self.norm_topk_prob = config.norm_topk_prob

        # gating
        self.calibrate_all_experts = calibrate_all_experts
        self.gate = original.gate
        self.experts = original.experts

        self.shared_expert = original.shared_expert
        self.shared_expert_gate = original.shared_expert_gate

    def forward(
        self, hidden_states: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = torch.nn.functional.softmax(
            router_logits, dim=1, dtype=torch.float
        )
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # One-hot encode the selected experts to create an expert mask;
        # this will be used to easily index which expert is going to be
        # solicited
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the
        # computation on each expert
        # expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx, expert_layer in enumerate(self.experts):
            idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0))

            if self.calibrate_all_experts:
                expert_out = expert_layer(hidden_states)[top_x]
            else:
                expert_out = expert_layer(hidden_states[top_x])

            # Index the correct hidden states and compute the expert hidden
            # state for the current expert. We need to make sure to multiply
            # the output hidden states by `routing_weights` on the
            # corresponding tokens (top-1 and top-2)
            if len(top_x) > 0:
                current_hidden_states = expert_out * routing_weights[top_x, idx, None]
                final_hidden_states.index_add_(
                    0,
                    top_x,
                    current_hidden_states.to(hidden_states.dtype),
                )

        shared_expert_output = self.shared_expert(hidden_states)
        shared_expert_output = (
            torch.nn.functional.sigmoid(self.shared_expert_gate(hidden_states))
            * shared_expert_output
        )

        final_hidden_states = final_hidden_states + shared_expert_output
        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        )
        return final_hidden_states, router_logits


def replace(
    config: Qwen3NextConfig,
    module: OriginalQwen3NextMoeSparseMoeBlock,
    calibrate_all_experts: bool,
):
    return Qwen3NextSparseMoeBlock(
        config=config, original=module, calibrate_all_experts=calibrate_all_experts
    )
```
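
To make the routing bookkeeping concrete, here is a small standalone sketch of the one-hot expert mask and `torch.where` indexing used in the loop above (toy sizes, not part of the commit). Note the role of `calibrate_all_experts`: every expert still runs on the full `hidden_states`, so quantization observers see all tokens, but only the rows selected in `top_x` are added to the output, leaving the block numerically equivalent to standard routing.

```python
# Toy demo of the expert-mask indexing used in Qwen3NextSparseMoeBlock.forward.
# All sizes and tensors here are illustrative.
import torch

num_tokens, num_experts, top_k = 5, 4, 2

# One (token, slot) -> expert assignment, as torch.topk would produce
selected_experts = torch.tensor([[0, 2], [1, 2], [0, 3], [2, 3], [1, 0]])

# (num_tokens, top_k, num_experts) -> (num_experts, top_k, num_tokens)
expert_mask = torch.nn.functional.one_hot(
    selected_experts, num_classes=num_experts
).permute(2, 1, 0)

for expert_idx in range(num_experts):
    # idx: which top-k slot routed to this expert; top_x: which token rows
    idx, top_x = torch.where(expert_mask[expert_idx])
    print(f"expert {expert_idx}: tokens {top_x.tolist()} via slots {idx.tolist()}")
    # with calibrate_all_experts=True the expert would run on every token,
    # then only rows `top_x` would be added into the output
```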
