Commit 5dd06ff

dsikka authored and cajeonrh committed
[Qwen3VLMoe] Add linearized definition and FP8 Quantization Example (#1874)
SUMMARY:
- Updates the MoE layer to use a linearized definition so that we can quantize and run the model in vLLM.
- Wraps the gate layer so that it is properly ignored. This is a hack for now; we will need to do this properly in ct.
- Not adding a forward pass for now; will add a forward pass as a follow-up, but would like this in for the release to enable FP8 quantization.
- Note: requires the latest transformers.

TEST PLAN:
- Produces `/proving-grounds/engine/hub_cache/Qwen3-VL-235B-A22B-Instruct-FP8_DYNAMIC`, which generates coherent generations:

```python
if __name__ == "__main__":
    import torch
    from vllm import LLM, SamplingParams

    prompts = [
        "The Swiss Alps are",
        "Brad Marchand is",
        "The Toronto Maple Leafs are",
    ]

    # Create a sampling params object.
    sampling_params = SamplingParams(
        temperature=0.80, top_p=0.95, max_tokens=40, min_tokens=10
    )
    llm = LLM(
        "/proving-grounds/engine/hub_cache/Qwen3-VL-235B-A22B-Instruct-FP8_DYNAMIC",
        tensor_parallel_size=2,
        max_model_len=4096,
        enforce_eager=True,
    )
    output = llm.generate(prompts, sampling_params)

    for out in output:
        print(out.outputs[0].text)
```

Generations:

```bash
a true paradise for nature lovers and outdoor enthusiasts. With their snow-capped peaks, lush green valleys, and crystal-clear lakes, the Alps offer a stunning backdrop for a wide range of activities. Whether

a prominent figure in the NHL, known for his exceptional performance and leadership. He has won the Art Ross Trophy as the NHL's leading scorer, with 110 points (32 goals and

a professional ice hockey team based in Toronto, Ontario, Canada. They are members of the Atlantic Division in the Eastern Conference of the National Hockey League (NHL). The team was established in 1
```

Signed-off-by: Cassie Jeon <[email protected]>
1 parent 34a4602 commit 5dd06ff

File tree

4 files changed: +105 -17 lines changed
Lines changed: 40 additions & 0 deletions (new file)

```python
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration

from llmcompressor import oneshot
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modifiers.quantization import QuantizationModifier

# NOTE: Qwen3-VL-MoE support is not in transformers<=4.56.2;
# you may need to install transformers from source.


MODEL_ID = "Qwen/Qwen3-VL-235B-A22B-Instruct"

# Load model.
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(MODEL_ID, torch_dtype="auto")
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = replace_modules_for_calibration(model)

# Configure the quantization algorithm and scheme.
# In this case, we:
#   * quantize the weights to fp8 with channel-wise quantization
#   * quantize the activations to fp8 with dynamic per-token quantization
# NOTE: only data-free quantization is supported for Qwen3-VL-MoE currently.
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=[
        "re:.*lm_head",
        "re:visual.*",
        "re:model.visual.*",
        "re:.*mlp.gate$",
    ],
)

# Apply quantization.
oneshot(model=model, recipe=recipe)

# Save to disk in compressed-tensors format.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-DYNAMIC"
model.save_pretrained(SAVE_DIR)
processor.save_pretrained(SAVE_DIR)
```
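For orientation before the diffs below: `replace_modules_for_calibration` is driven by the class-name-keyed `replacements` table added in `prepare.py`. A minimal sketch, assuming only what the diff shows (a dict mapping module class names to replacer callables); the walk-and-swap loop and function name here are illustrative, not the library's implementation:

```python
import torch


def swap_matching_modules(model: torch.nn.Module, replacements: dict) -> torch.nn.Module:
    """Illustrative stand-in for replace_modules_for_calibration."""
    for name, module in list(model.named_modules()):
        replace_fn = replacements.get(module.__class__.__name__)
        if replace_fn is None or name == "":
            continue
        # Resolve the parent module and rebind the attribute to the
        # linearized replacement built by the registered callable.
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, replace_fn(model.config, module))
    return model
```

Replacing modules this way keeps the checkpoint's weights but swaps in definitions built from plain `nn.Linear` layers, which is what lets the `QuantizationModifier` above reach them with `targets="Linear"`.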

src/llmcompressor/modeling/prepare.py

Lines changed: 4 additions & 8 deletions
```diff
@@ -5,11 +5,8 @@
 from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
 from llmcompressor.modeling.llama4 import replace as replace_llama4
 from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
-
-try:
-    from llmcompressor.modeling.qwen3_next_moe import replace as replace_Qwen3NextMoE
-except ImportError:
-    replace_Qwen3NextMoE = None
+from llmcompressor.modeling.qwen3_next_moe import replace as replace_Qwen3NextMoE
+from llmcompressor.modeling.qwen3_vl_moe import replace as replace_Qwen3VLMoE
 from llmcompressor.utils.helpers import patch_attr

 __all__ = ["replace_modules_for_calibration"]
@@ -18,6 +15,7 @@
 replacements = {
     "DeepseekV3MoE": replace_deepseekv3,
     "Llama4TextMoe": replace_llama4,
+    "Qwen3VLMoeTextSparseMoeBlock": replace_Qwen3VLMoE,
 }


@@ -81,11 +79,9 @@ def update_qwen3_next_moe(model, module, stack, calibrate_all_experts):

 moe_context = {
     "Qwen3MoeForCausalLM": update_qwen3_moe,
+    "Qwen3NextForCausalLM": update_qwen3_next_moe,
 }

-if replace_Qwen3NextMoE is not None:
-    moe_context["Qwen3NextForCausalLM"] = update_qwen3_next_moe
-

 def moe_calibration_context(
     model: PreTrainedModel,
```
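The `moe_context` table maps top-level architecture names to update functions whose signature appears in the hunk header above (`update_qwen3_next_moe(model, module, stack, calibrate_all_experts)`). A hedged sketch of how such a table could be consumed; the dispatch loop is illustrative, and `moe_context`, `update_qwen3_moe`, and `update_qwen3_next_moe` refer to the names defined in `prepare.py`:

```python
def apply_moe_context(model, stack, calibrate_all_experts: bool) -> None:
    # Dispatch on the top-level model class name, matching the keys of
    # `moe_context` ("Qwen3MoeForCausalLM", "Qwen3NextForCausalLM").
    update_fn = moe_context.get(model.__class__.__name__)
    if update_fn is None:
        return  # architectures without MoE calibration handling are untouched
    for module in model.modules():
        update_fn(model, module, stack, calibrate_all_experts)
```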

src/llmcompressor/modeling/qwen3_next_moe.py

Lines changed: 5 additions & 9 deletions
```diff
@@ -15,17 +15,13 @@
 # limitations under the License.

 import torch
-from transformers.models import Qwen3NextConfig
-from transformers.models.qwen3_next.modeling_qwen3_next import (
-    Qwen3NextSparseMoeBlock as OriginalQwen3NextMoeSparseMoeBlock,
-)


 class Qwen3NextSparseMoeBlock(torch.nn.Module):
     def __init__(
         self,
-        config: Qwen3NextConfig,
-        original: OriginalQwen3NextMoeSparseMoeBlock,
+        config,
+        original,
         calibrate_all_experts: bool,
     ):
         super().__init__()
@@ -109,9 +105,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


 def replace(
-    config: Qwen3NextConfig,
-    module: OriginalQwen3NextMoeSparseMoeBlock,
-    calibrate_all_experts: bool,
+    config,
+    module,
+    calibrate_all_experts,
 ):
     return Qwen3NextSparseMoeBlock(
         config=config, original=module, calibrate_all_experts=calibrate_all_experts
```
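Removing the module-level `Qwen3NextConfig` / `Qwen3NextSparseMoeBlock` imports (and the annotations that required them) is what makes the unconditional import in `prepare.py` safe on older transformers. Where a version-dependent symbol is still needed, the same effect comes from a call-time import, as `qwen3_vl_moe.py` below does for `Qwen3VLMoeTextMLP`; a hedged sketch of that pattern (`build_replacement` is a hypothetical name, not code from this commit):

```python
def build_replacement(config, original, calibrate_all_experts: bool):
    # Resolved only when a Qwen3-Next model is actually being calibrated,
    # so merely importing this file never requires a newer transformers.
    from transformers.models.qwen3_next.modeling_qwen3_next import (
        Qwen3NextSparseMoeBlock as OriginalBlock,
    )

    assert isinstance(original, OriginalBlock)
    return Qwen3NextSparseMoeBlock(
        config=config, original=original, calibrate_all_experts=calibrate_all_experts
    )
```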
src/llmcompressor/modeling/qwen3_vl_moe.py

Lines changed: 56 additions & 0 deletions (new file)

```python
import torch

from llmcompressor.utils.dev import skip_weights_initialize


class LinearQwen3VLMoeTextSparseMoeBlock(torch.nn.Module):
    def __init__(self, config, original):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts
        self.gate = wrap_gate(original.gate)
        self.experts = SequentialQwen3VLMoeTextExperts(config, original.experts)


class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList):
    def __init__(self, config, original):
        from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
            Qwen3VLMoeTextMLP,
        )

        self.num_experts = original.gate_up_proj.shape[0]
        with skip_weights_initialize():
            super().__init__(
                [Qwen3VLMoeTextMLP(config) for _ in range(self.num_experts)]
            )

        intermediate_size = original.down_proj.shape[1]

        for i in range(self.num_experts):
            gate_up = original.gate_up_proj[i]
            down = original.down_proj[i]

            gate_proj = gate_up[:, :intermediate_size]
            up_proj = gate_up[:, intermediate_size:]

            self[i].gate_proj.weight.data = gate_proj.t().clone().contiguous()
            self[i].up_proj.weight.data = up_proj.t().clone().contiguous()
            self[i].down_proj.weight.data = down.t().clone().contiguous()


def wrap_gate(gate):
    # temporary workaround until ct supports ignores of Linear instances
    linear_gate = torch.nn.Linear(gate.in_features, gate.out_features)
    linear_gate.weight.data.copy_(gate.weight.data)
    setattr(linear_gate, "hidden_size", gate.hidden_size)
    setattr(linear_gate, "top_k", gate.top_k)
    setattr(linear_gate, "forward", gate.forward)
    del gate
    return linear_gate


def replace(config, module, calibrate_all_experts=False):
    return LinearQwen3VLMoeTextSparseMoeBlock(
        config=config.get_text_config(),
        original=module,
    )
```
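The unfusing in `SequentialQwen3VLMoeTextExperts` assumes the fused expert tensors are laid out as `gate_up_proj: [num_experts, hidden_size, 2 * intermediate_size]` and `down_proj: [num_experts, intermediate_size, hidden_size]` (consistent with the slicing and `down_proj.shape[1]` above), so each per-expert slice must be transposed to match `nn.Linear`'s `[out_features, in_features]` weight layout. A small self-contained sanity check of that bookkeeping with toy sizes:

```python
import torch

num_experts, hidden_size, intermediate_size = 4, 8, 16
gate_up_proj = torch.randn(num_experts, hidden_size, 2 * intermediate_size)
down_proj = torch.randn(num_experts, intermediate_size, hidden_size)

gate_up = gate_up_proj[0]
gate_w = gate_up[:, :intermediate_size].t().contiguous()  # -> [intermediate, hidden]
up_w = gate_up[:, intermediate_size:].t().contiguous()    # -> [intermediate, hidden]
down_w = down_proj[0].t().contiguous()                    # -> [hidden, intermediate]

# nn.Linear(in_features, out_features).weight has shape [out_features, in_features]
assert gate_w.shape == torch.nn.Linear(hidden_size, intermediate_size).weight.shape
assert up_w.shape == torch.nn.Linear(hidden_size, intermediate_size).weight.shape
assert down_w.shape == torch.nn.Linear(intermediate_size, hidden_size).weight.shape
```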
