Skip to content

Commit 1d428f9

Browse files
SehyoScarletBlue
authored andcommitted
feat: add Qwen3.5 MoE calibration module for quantization
Qwen3.5 MoE (Qwen3_5MoeSparseMoeBlock) stores expert weights as fused 3D tensors. Add CalibrationQwen3_5MoeSparseMoeBlock which unfuses these into individual Qwen3_5MoeMLP modules with nn.Linear layers, enabling proper NVFP4 quantization. Uses is_permanent=True so the unfused structure persists through quantization and saving.
1 parent 36c30ee commit 1d428f9

File tree

3 files changed

+250
-1
lines changed

3 files changed

+250
-1
lines changed
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from compressed_tensors.offload import dispatch_model
2+
from datasets import load_dataset
3+
from transformers import AutoModelForCausalLM, AutoTokenizer
4+
5+
from llmcompressor import oneshot
6+
from llmcompressor.modifiers.quantization import QuantizationModifier
7+
8+
MODEL_ID = "Qwen/Qwen3.5-397B-A17B"
9+
10+
# Load model.
11+
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto")
12+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
13+
14+
15+
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
16+
DATASET_SPLIT = "train_sft"
17+
18+
# Select number of samples
19+
NUM_CALIBRATION_SAMPLES = 20
20+
MAX_SEQUENCE_LENGTH = 2048
21+
22+
# Load dataset and preprocess.
23+
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
24+
ds = ds.shuffle(seed=42)
25+
26+
27+
def preprocess(example):
28+
return {
29+
"text": tokenizer.apply_chat_template(
30+
example["messages"],
31+
tokenize=False,
32+
)
33+
}
34+
35+
36+
ds = ds.map(preprocess)
37+
38+
39+
# Tokenize inputs.
40+
def tokenize(sample):
41+
return tokenizer(
42+
sample["text"],
43+
padding=False,
44+
max_length=MAX_SEQUENCE_LENGTH,
45+
truncation=True,
46+
add_special_tokens=False,
47+
)
48+
49+
50+
ds = ds.map(tokenize, remove_columns=ds.column_names)
51+
52+
# Configure the quantization algorithm and scheme.
53+
# In this case, we:
54+
# * quantize the weights to fp4 with per group 16 via ptq
55+
# * calibrate a global_scale for activations, which will be used to
56+
# quantize activations to fp4 on the fly
57+
recipe = QuantizationModifier(
58+
targets="Linear",
59+
scheme="NVFP4",
60+
ignore=[
61+
"lm_head",
62+
"re:.*mlp.gate$",
63+
"re:.*mlp.shared_expert_gate$",
64+
"re:.*linear_attn.*",
65+
],
66+
)
67+
68+
# Apply quantization.
69+
# MoE calibration is now handled automatically by the pipeline.
70+
# We set `moe_calibrate_all_experts` to True to ensure all experts receive
71+
# calibration data. This temporarily updates the model definition to use
72+
# `CalibrationQwen3_5MoeSparseMoeBlock` (from `llmcompressor.modeling.qwen3_5_moe`)
73+
# which replaces the original `Qwen3_5MoeSparseMoeBlock` class.
74+
# This unfuses the 3D expert parameters into individual nn.Linear modules
75+
# so they can be targeted by quantization.
76+
# Feel free to update the definition under
77+
# llm-compressor/src/llmcompressor/modeling/qwen3_5_moe.py to play around with
78+
# this behavior and evaluate its impact on quantization performance.
79+
oneshot(
80+
model=model,
81+
dataset=ds,
82+
recipe=recipe,
83+
max_seq_length=MAX_SEQUENCE_LENGTH,
84+
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
85+
moe_calibrate_all_experts=True,
86+
)
87+
88+
89+
print("\n\n")
90+
print("========== SAMPLE GENERATION ==============")
91+
dispatch_model(model)
92+
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
93+
model.device
94+
)
95+
output = model.generate(input_ids, max_new_tokens=100)
96+
print(tokenizer.decode(output[0]))
97+
print("==========================================\n\n")
98+
99+
100+
# Save to disk in compressed-tensors format.
101+
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-NVFP4"
102+
model.save_pretrained(SAVE_DIR, save_compressed=True)
103+
tokenizer.save_pretrained(SAVE_DIR)

src/llmcompressor/modeling/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
from .glm4_moe import CalibrationGlm4MoeMoE # noqa: F401
1515
from .llama4 import SequentialLlama4TextMoe # noqa: F401
1616
from .qwen3_moe import CalibrationQwen3MoeSparseMoeBlock # noqa: F401
17+
from .qwen3_5_moe import CalibrationQwen3_5MoeSparseMoeBlock # noqa: F401
1718
from .qwen3_vl_moe import CalibrateQwen3VLMoeTextSparseMoeBlock # noqa: F401
1819
from .qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock # noqa: F401
19-
# TODO: add granite4, Qwen3Next
20+
# TODO: add granite4
2021

2122
from .fuse import *
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
import torch
6+
import torch.nn.functional as F
7+
8+
from llmcompressor.modeling.moe_context import MoECalibrationModule
9+
from llmcompressor.utils.dev import skip_weights_initialize
10+
11+
if TYPE_CHECKING:
12+
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
13+
Qwen3_5MoeSparseMoeBlock,
14+
)
15+
16+
17+
@MoECalibrationModule.register("Qwen3_5MoeSparseMoeBlock")
18+
class CalibrationQwen3_5MoeSparseMoeBlock(MoECalibrationModule):
19+
"""
20+
Calibration version of Qwen3_5MoeSparseMoeBlock that unfuses 3D expert
21+
parameters into individual MLP modules (nn.Linear) so they can be
22+
individually quantized. Sends all tokens to all experts during calibration.
23+
24+
is_permanent = True because the unfused structure must persist for
25+
quantization to target the individual nn.Linear expert weights.
26+
"""
27+
28+
is_permanent = True
29+
30+
def __init__(
31+
self,
32+
original: Qwen3_5MoeSparseMoeBlock,
33+
config,
34+
calibrate_all_experts: bool = True,
35+
):
36+
super().__init__()
37+
text_config = getattr(config, "text_config", config)
38+
39+
self.num_experts = text_config.num_experts
40+
self.top_k = text_config.num_experts_per_tok
41+
self.hidden_size = text_config.hidden_size
42+
43+
self.calibrate_all_experts = calibrate_all_experts
44+
self.gate = original.gate
45+
self.shared_expert = original.shared_expert
46+
self.shared_expert_gate = original.shared_expert_gate
47+
self.experts = SequentialQwen3_5MoeExperts(text_config, original.experts)
48+
49+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
50+
batch_size, sequence_length, hidden_dim = hidden_states.shape
51+
hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
52+
53+
# router: returns (router_logits, router_scores, router_indices)
54+
_, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
55+
56+
# expert mask: (num_experts, top_k, num_tokens)
57+
expert_mask = F.one_hot(
58+
selected_experts, num_classes=self.num_experts
59+
).permute(2, 1, 0)
60+
61+
final_hidden_states = torch.zeros(
62+
(batch_size * sequence_length, hidden_dim),
63+
dtype=hidden_states.dtype,
64+
device=hidden_states.device,
65+
)
66+
67+
for expert_idx, expert_layer in enumerate(self.experts):
68+
idx, token_idx = torch.where(expert_mask[expert_idx].squeeze(0))
69+
70+
if self.calibrate_all_experts:
71+
expert_out = expert_layer(hidden_states_reshaped)[token_idx]
72+
else:
73+
expert_out = expert_layer(hidden_states_reshaped[token_idx])
74+
75+
if len(token_idx) > 0:
76+
current_hidden_states = (
77+
expert_out * routing_weights[token_idx, idx, None]
78+
)
79+
final_hidden_states.index_add_(
80+
0,
81+
token_idx,
82+
current_hidden_states.to(hidden_states.dtype),
83+
)
84+
85+
# shared expert
86+
shared_expert_output = self.shared_expert(hidden_states_reshaped)
87+
shared_expert_output = (
88+
F.sigmoid(self.shared_expert_gate(hidden_states_reshaped))
89+
* shared_expert_output
90+
)
91+
final_hidden_states = final_hidden_states + shared_expert_output
92+
93+
final_hidden_states = final_hidden_states.reshape(
94+
batch_size, sequence_length, hidden_dim
95+
)
96+
return final_hidden_states
97+
98+
def restore(self, original: torch.nn.Module) -> torch.nn.Module:
99+
return self
100+
101+
102+
class SequentialQwen3_5MoeExperts(torch.nn.ModuleList):
103+
"""
104+
Unfuses 3D expert parameter tensors into individual Qwen3_5MoeMLP modules
105+
so that each expert's weights are nn.Linear and can be targeted by
106+
quantization with targets="Linear".
107+
"""
108+
109+
def __init__(self, config, original):
110+
from compressed_tensors.offload import disable_onloading
111+
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
112+
Qwen3_5MoeMLP,
113+
)
114+
115+
self.num_experts = config.num_experts
116+
intermediate_size = config.moe_intermediate_size
117+
118+
with skip_weights_initialize():
119+
super().__init__(
120+
[
121+
Qwen3_5MoeMLP(config, intermediate_size=intermediate_size)
122+
for _ in range(self.num_experts)
123+
]
124+
)
125+
126+
# Access expert weights on CPU to avoid GPU OOM.
127+
# disable_onloading() makes OffloadCache return the offloaded (CPU)
128+
# values directly instead of onloading to GPU.
129+
with disable_onloading():
130+
gate_up_data = original.gate_up_proj.data # [num_experts, 2*inter, hidden]
131+
down_data = original.down_proj.data # [num_experts, hidden, inter]
132+
133+
for i in range(self.num_experts):
134+
gate_up = gate_up_data[i] # [2*intermediate, hidden]
135+
down = down_data[i] # [hidden, intermediate]
136+
137+
# gate_up_proj stores [gate; up] stacked along dim 0
138+
# nn.Linear weight is [out_features, in_features]
139+
self[i].gate_proj.weight.data = (
140+
gate_up[:intermediate_size, :].clone().contiguous()
141+
)
142+
self[i].up_proj.weight.data = (
143+
gate_up[intermediate_size:, :].clone().contiguous()
144+
)
145+
self[i].down_proj.weight.data = down.clone().contiguous()

0 commit comments

Comments
 (0)