Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions examples/quantization_w8a8_fp8/qwen3_vl_moe_fp8_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import QuantizationModifier

# NOTE: Qwen3-VL-MoE support is not in transformers<=4.56.2
# you may need to install transformers from source


MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct"

# Load model.
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = replace_modules_for_calibration(model)

# Configure the quantization algorithm and scheme.
# In this case, we:
# * quantize the weights to fp8 with channel-wise quantization
# * quantize the activations to fp8 with dynamic token activations
# NOTE: only datafree quantization is supported for Qwen3-VL-MoE currently
recipe = QuantizationModifier(
targets="Linear",
scheme="FP8_DYNAMIC",
ignore=[
"re:.*lm_head",
"re:visual.*",
"re:model.visual.*",
"re:.*mlp.gate$",
],
)

# Apply quantization.
oneshot(model=model, recipe=recipe)

# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-DYNAMIC"
model.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)
12 changes: 4 additions & 8 deletions src/llmcompressor/modeling/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,8 @@
from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
from llmcompressor.modeling.llama4 import replace as replace_llama4
from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE

try:
from llmcompressor.modeling.qwen3_next_moe import replace as replace_Qwen3NextMoE
except ImportError:
replace_Qwen3NextMoE = None
from llmcompressor.modeling.qwen3_next_moe import replace as replace_Qwen3NextMoE
from llmcompressor.modeling.qwen3_vl_moe import replace as replace_Qwen3VLMoE
from llmcompressor.utils.helpers import patch_attr

__all__ = ["replace_modules_for_calibration"]
Expand All @@ -18,6 +15,7 @@
replacements = {
"DeepseekV3MoE": replace_deepseekv3,
"Llama4TextMoe": replace_llama4,
"Qwen3VLMoeTextSparseMoeBlock": replace_Qwen3VLMoE,
}


Expand Down Expand Up @@ -81,11 +79,9 @@ def update_qwen3_next_moe(model, module, stack, calibrate_all_experts):

moe_context = {
"Qwen3MoeForCausalLM": update_qwen3_moe,
"Qwen3NextForCausalLM": update_qwen3_next_moe,
}

if replace_Qwen3NextMoE is not None:
moe_context["Qwen3NextForCausalLM"] = update_qwen3_next_moe


def moe_calibration_context(
model: PreTrainedModel,
Expand Down
14 changes: 5 additions & 9 deletions src/llmcompressor/modeling/qwen3_next_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,13 @@
# limitations under the License.

import torch
from transformers.models import Qwen3NextConfig
from transformers.models.qwen3_next.modeling_qwen3_next import (
Qwen3NextSparseMoeBlock as OriginalQwen3NextMoeSparseMoeBlock,
)


class Qwen3NextSparseMoeBlock(torch.nn.Module):
def __init__(
self,
config: Qwen3NextConfig,
original: OriginalQwen3NextMoeSparseMoeBlock,
config,
original,
calibrate_all_experts: bool,
):
super().__init__()
Expand Down Expand Up @@ -109,9 +105,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


def replace(
config: Qwen3NextConfig,
module: OriginalQwen3NextMoeSparseMoeBlock,
calibrate_all_experts: bool,
config,
module,
calibrate_all_experts,
):
return Qwen3NextSparseMoeBlock(
config=config, original=module, calibrate_all_experts=calibrate_all_experts
Expand Down
56 changes: 56 additions & 0 deletions src/llmcompressor/modeling/qwen3_vl_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch

from llmcompressor.utils.dev import skip_weights_initialize


class LinearQwen3VLMoeTextSparseMoeBlock(torch.nn.Module):
def __init__(self, config, original):
super().__init__()
self.hidden_size = config.hidden_size
self.num_experts = config.num_experts
self.gate = wrap_gate(original.gate)
self.experts = SequentialQwen3VLMoeTextExperts(config, original.experts)


class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList):
def __init__(self, config, original):
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextMLP,
)

self.num_experts = original.gate_up_proj.shape[0]
with skip_weights_initialize():
super().__init__(
[Qwen3VLMoeTextMLP(config) for _ in range(self.num_experts)]
)

intermediate_size = original.down_proj.shape[1]

for i in range(self.num_experts):
gate_up = original.gate_up_proj[i]
down = original.down_proj[i]

gate_proj = gate_up[:, :intermediate_size]
up_proj = gate_up[:, intermediate_size:]

self[i].gate_proj.weight.data = gate_proj.t().clone().contiguous()
self[i].up_proj.weight.data = up_proj.t().clone().contiguous()
self[i].down_proj.weight.data = down.t().clone().contiguous()


def wrap_gate(gate):
# temporary workaround until ct supports ignores of Linear instances
linear_gate = torch.nn.Linear(gate.in_features, gate.out_features)
linear_gate.weight.data.copy_(gate.weight.data)
setattr(linear_gate, "hidden_size", gate.hidden_size)
setattr(linear_gate, "top_k", gate.top_k)
setattr(linear_gate, "forward", gate.forward)
del gate
return linear_gate


def replace(config, module, calibrate_all_experts=False):
return LinearQwen3VLMoeTextSparseMoeBlock(
config=config.get_text_config(),
original=module,
)