Skip to content

Commit 15da5c6

Browse files
committed
updatE
1 parent 27e6f07 commit 15da5c6

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

examples/quantization_w8a8_fp8/qwen3_vl_moe_fp8_block_example.py renamed to examples/quantization_w8a8_fp8/qwen3_vl_moe_fp8_example.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,29 @@
11

2-
from transformers import Qwen3VLMoeForConditionalGeneration
2+
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
33

44
from llmcompressor import oneshot
55
from llmcompressor.modeling import replace_modules_for_calibration
66
from llmcompressor.modifiers.quantization import QuantizationModifier
77

88
# NOTE: Qwen3-VL-MoE support is not in transformers<=4.56.2
9-
# you may need to install transformes from source
9+
# you may need to install transformers from source
1010

11-
MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct"
1211

12+
MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct"
1313

1414
# Load model.
1515
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto")
16+
processor = AutoProcessor.from_pretrained(MODEL_ID)
1617
model = replace_modules_for_calibration(model)
18+
1719
# Configure the quantization algorithm and scheme.
1820
# In this case, we:
19-
# * quantize the weights to fp8 with block size 128 via ptq
20-
# * quantize the activations to fp8 with dynamic group activations
21+
# * quantize the weights to fp8 with channel-wise quantization
22+
# * quantize the activations to fp8 with dynamic token activations
23+
# NOTE: only datafree quantization is supported for Qwen3-VL-MoE currently
2124
recipe = QuantizationModifier(
2225
targets="Linear",
23-
scheme="FP8_BLOCK",
26+
scheme="FP8_DYNAMIC",
2427
ignore=[
2528
"re:.*lm_head",
2629
"re:visual.*",
@@ -33,5 +36,6 @@
3336
oneshot(model=model, recipe=recipe)
3437

3538
# Save to disk in compressed-tensors format.
36-
SAVE_DIR = "/proving-grounds/engine/hub_cache/Qwen3-VL-235B-A22B-Instruct" + "-FP8-BLOCK"
39+
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-DYNAMIC"
3740
model.save_pretrained(SAVE_DIR)
41+
processor.save_pretrained(SAVE_DIR)

src/llmcompressor/modeling/qwen3_vl_moe.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def __init__(self, config, original):
1111
super().__init__()
1212
self.hidden_size = config.hidden_size
1313
self.num_experts = config.num_experts
14-
self.gate = original.gate
14+
self.gate = wrap_gate(original.gate)
1515
self.experts = SequentialQwen3VLMoeTextExperts(config, original.experts)
1616

1717
class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList):
@@ -33,7 +33,19 @@ def __init__(self, config, original):
3333
self[i].up_proj.weight.data = up_proj.t().clone().contiguous()
3434
self[i].down_proj.weight.data = down.t().clone().contiguous()
3535

36-
def replace(config, module, calibrate_all_experts):
36+
37+
def wrap_gate(gate):
    """Swap an MoE router gate module for an equivalent ``torch.nn.Linear``.

    Temporary workaround until compressed-tensors supports ignoring Linear
    *instances* (rather than classes): exposing the gate as a plain Linear
    lets the quantization recipe match / ignore it by type.

    The returned module delegates ``forward`` to the original gate, so
    runtime behavior is unchanged; only the module's class (and therefore
    how it is matched during quantization and serialization) differs.

    :param gate: original gate module; must expose ``in_features``,
        ``out_features``, ``weight``, ``hidden_size`` and ``top_k``
    :return: ``torch.nn.Linear`` carrying the gate's weight and attributes
    """
    # Match the source weight's dtype/device so the copied parameter is not
    # silently cast to fp32 on CPU when the model was loaded in bf16 / on GPU.
    # bias=False: only the weight is copied from the gate, so a bias parameter
    # would be left randomly initialized and then written into the saved
    # checkpoint (Qwen3 MoE router gates appear biasless — TODO confirm).
    linear_gate = torch.nn.Linear(
        gate.in_features,
        gate.out_features,
        bias=False,
        dtype=gate.weight.dtype,
        device=gate.weight.device,
    )
    linear_gate.weight.data.copy_(gate.weight.data)
    # Preserve attributes the surrounding MoE block reads off the gate.
    linear_gate.hidden_size = gate.hidden_size
    linear_gate.top_k = gate.top_k
    # Delegate to the original forward so numerics are identical.
    # NOTE(review): the bound method keeps `gate` alive, so the `del` below
    # only drops the local name — it does not free the original module.
    linear_gate.forward = gate.forward
    del gate
    return linear_gate
46+
47+
48+
def replace(config, module, calibrate_all_experts=False):
3749
return LinearQwen3VLMoeTextSparseMoeBlock(
3850
config=config.get_text_config(),
3951
original=module,

0 commit comments

Comments
 (0)