Skip to content

Commit 3347901

Browse files
committed
Add Qwen3.5-MoE NVFP4 calibration example; switch FP8 example to FP8_DYNAMIC and ignore linear-attention/shared-expert modules; add calibration forward pass for the Qwen3.5 MoE block
1 parent 41ccbe1 commit 3347901

File tree

3 files changed

+138
-11
lines changed

3 files changed

+138
-11
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
from transformers import AutoProcessor, Qwen3_5MoeForConditionalGeneration

from datasets import load_dataset

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Local snapshot of Qwen3.5-397B-A17B; replace with a hub ID to run elsewhere.
MODEL_ID = "/raid/engine/dsikka/models--Qwen--Qwen3.5-397B-A17B/snapshots/7cad2bae11cb49ca79f7d6a0954de2e2756f4e27"

# Load model.
model = Qwen3_5MoeForConditionalGeneration.from_pretrained(MODEL_ID, dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)

# NVFP4 quantization of all Linear layers, except modules whose precision is
# sensitive or that are not plain projections: lm_head, the vision tower,
# router gates, embeddings, the shared expert (and its gate), and the
# linear-attention blocks.
recipe = QuantizationModifier(
    targets="Linear",
    scheme="NVFP4",
    ignore=[
        "re:.*lm_head",
        "re:visual.*",
        "re:model.visual.*",
        "re:.*mlp.gate$",
        "re:.*embed_tokens$",
        "re:.*shared_expert_gate$",
        "re:.*mlp\\.shared_expert$",
        "re:.*linear_attn.*",
    ],
)

DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"

# Select number of samples
NUM_CALIBRATION_SAMPLES = 20
MAX_SEQUENCE_LENGTH = 2048

# Load dataset and preprocess. Slice the split server-side so only the
# calibration subset is downloaded/processed.
ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]")
ds = ds.shuffle(seed=42)
def preprocess(example):
    """Render one chat example's ``messages`` into a single prompt string.

    Applies the processor's chat template without tokenizing; tokenization
    happens in a later map step.
    """
    return {
        "text": processor.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)
# Tokenize inputs.
def tokenize(sample):
    """Tokenize a rendered prompt, truncated to MAX_SEQUENCE_LENGTH.

    add_special_tokens=False — presumably because the chat template already
    inserted the special tokens; TODO confirm against the processor config.
    """
    return processor(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


# Replace the raw columns with the tokenized inputs only.
ds = ds.map(tokenize, remove_columns=ds.column_names)
# Apply quantization. moe_calibrate_all_experts=True routes every calibration
# token through every expert so all expert weights are observed during
# calibration.
oneshot(
    model=model,
    recipe=recipe,
    dataset=ds,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    moe_calibrate_all_experts=True,
)

# Save to disk in compressed-tensors format.
SAVE_DIR = "/raid/engine/dsikka/" + "Qwen3.5-397B-A17B" + "-NVFP4"
model.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)

examples/quantization_w8a8_fp8/qwen3_5_moe.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
# NOTE: only datafree quantization is supported for Qwen3-VL-MoE currently
# FP8_DYNAMIC quantization of all Linear layers; skip lm_head, the vision
# tower, router gates, embeddings, the shared expert (and its gate), and the
# linear-attention blocks.
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=[
        "re:.*lm_head",
        "re:visual.*",
        "re:model.visual.*",
        "re:.*mlp.gate$",
        "re:.*embed_tokens$",
        "re:.*shared_expert_gate$",
        "re:.*mlp\\.shared_expert$",
        "re:.*linear_attn.*",
    ],
)

# Apply quantization.
oneshot(model=model, recipe=recipe)

# Save to disk in compressed-tensors format.
SAVE_DIR = "/raid/engine/dsikka/" + "Qwen3.5-397B-A17B" + "-FP8-Dynamic-NoLinearAttn"
model.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)

src/llmcompressor/modeling/qwen3_5_vl_moe.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from llmcompressor.modeling.moe_context import MoECalibrationModule
88
from llmcompressor.utils.dev import skip_weights_initialize
9+
import torch.nn.functional as F
910

1011

1112
@MoECalibrationModule.register("Qwen3_5MoeSparseMoeBlock")
@@ -32,6 +33,55 @@ def __init__(
3233
self.shared_expert_gate = original.shared_expert_gate
3334
self.gate = original.gate
3435
self.experts = SequentialQwen3VLMoeTextExperts(text_config, original.experts)
36+
37+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Route tokens through the MoE experts, plus the shared-expert path.

    Args:
        hidden_states: tensor of shape (batch, seq, hidden).

    Returns:
        Tensor of the same (batch, seq, hidden) shape.

    When ``self.calibrate_all_experts`` is set, every expert runs on every
    token so calibration observes all expert activations, but only the
    routed tokens' outputs are kept — the numerical result matches normal
    routing.
    """
    batch_size, sequence_length, hidden_dim = hidden_states.shape
    flat_states = hidden_states.view(-1, hidden_dim)

    # router: returns (router_logits, router_scores, router_indices)
    _, routing_weights, selected_experts = self.gate(flat_states)

    # One-hot over experts, permuted to (num_experts, top_k, num_tokens).
    expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )

    for expert_idx, expert_layer in enumerate(self.experts):
        # top_k_pos: which top-k slot chose this expert; token_idx: which tokens.
        top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

        if self.calibrate_all_experts:
            # Run ALL tokens through the expert (for calibration coverage),
            # then keep only the outputs of the tokens routed to it.
            expert_out = expert_layer(flat_states)[token_idx]
        else:
            expert_out = expert_layer(flat_states[token_idx])

        if len(token_idx) > 0:
            weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
            # Accumulate per-token contributions; a token may appear for
            # several experts (top_k > 1), so use index_add_ rather than
            # assignment.
            final_hidden_states.index_add_(
                0,
                token_idx,
                weighted.to(hidden_states.dtype),
            )

    # Shared expert path, scaled by a sigmoid gate.
    shared_expert_output = self.shared_expert(flat_states)
    shared_expert_output = (
        torch.sigmoid(self.shared_expert_gate(flat_states)) * shared_expert_output
    )
    final_hidden_states = final_hidden_states + shared_expert_output

    return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
3585

3686
def restore(self, original: torch.nn.Module) -> torch.nn.Module:
    """Return the original module unchanged — no teardown is needed after
    calibration for this module (the original is handed back as-is)."""
    return original
@@ -42,6 +92,7 @@ def __init__(self, config, original):
4292
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
4393
Qwen3_5MoeMLP,
4494
)
95+
from compressed_tensors.offload import disable_onloading
4596

4697
self.num_experts = original.gate_up_proj.shape[0]
4798
with skip_weights_initialize():
@@ -56,9 +107,13 @@ def __init__(self, config, original):
56107

57108
intermediate_size = original.down_proj.shape[-1]
58109

110+
with disable_onloading():
111+
gate_up_data = original.gate_up_proj.data # [num_experts, 2*inter, hidden]
112+
down_data = original.down_proj.data # [num_experts, hidden, inter]
113+
59114
for i in range(self.num_experts):
60-
gate_up = original.gate_up_proj[i]
61-
down = original.down_proj[i]
115+
gate_up = gate_up_data[i]
116+
down = down_data[i]
62117

63118
gate_proj = gate_up[:intermediate_size, :]
64119
up_proj = gate_up[intermediate_size:, :]

0 commit comments

Comments
 (0)