Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/llmcompressor/modeling/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3
from llmcompressor.modeling.llama4 import replace as replace_llama4
from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE
from llmcompressor.modeling.qwen3_vl_moe import replace as replace_Qwen3VLMoE
from llmcompressor.utils.helpers import patch_attr

__all__ = ["replace_modules_for_calibration"]
Expand All @@ -13,6 +13,7 @@
replacements = {
"DeepseekV3MoE": replace_deepseekv3,
"Llama4TextMoe": replace_llama4,
"Qwen3VLMoeTextSparseMoeBlock": replace_Qwen3VLMoE,
}


Expand Down
38 changes: 38 additions & 0 deletions src/llmcompressor/modeling/qwen3_vl_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import torch
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextMLP,
)
from llmcompressor.utils.dev import skip_weights_initialize

class LinearQwen3VLMoeTextSparseMoeBlock(torch.nn.Module):
    """Calibration-friendly replacement for ``Qwen3VLMoeTextSparseMoeBlock``.

    The stock HF block keeps every expert's weights fused into 3D tensors,
    which per-expert quantization/calibration hooks cannot observe.  This
    wrapper reuses the original router (``gate``) and unpacks the fused
    experts into a ``ModuleList`` of independent ``Qwen3VLMoeTextMLP``s.

    Fix over the original: the original class defined no ``forward``, so the
    replaced module raised as soon as the model was run.  A standard top-k
    MoE routing forward is added below.
    """

    def __init__(self, config, original):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts
        # Needed by forward(); Qwen3-VL-MoE text configs expose the routing
        # fan-out as ``num_experts_per_tok``.
        # NOTE(review): confirm the attribute name against the installed
        # transformers version.
        self.top_k = config.num_experts_per_tok
        # Reuse the trained router unchanged.
        self.gate = original.gate
        # Unpack fused expert weights into individually hookable MLPs.
        self.experts = SequentialQwen3VLMoeTextExperts(config, original.experts)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its top-k experts and combine their outputs.

        Args:
            hidden_states: tensor of shape (batch, seq_len, hidden_size).

        Returns:
            Tensor of the same shape as ``hidden_states``.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_dim)

        router_logits = self.gate(hidden_states)
        # Softmax in float32 for numerical stability, as HF MoE blocks do.
        routing_weights = torch.nn.functional.softmax(
            router_logits, dim=-1, dtype=torch.float
        )
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        # Renormalize the selected weights so they sum to one per token.
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros_like(hidden_states)
        # (num_experts, top_k, num_tokens) mask of which tokens hit which expert.
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        for expert_idx in range(self.num_experts):
            idx, top_x = torch.where(expert_mask[expert_idx])
            if top_x.numel() == 0:
                # No token routed to this expert for this batch.
                continue
            current_state = hidden_states[top_x]
            current_hidden = (
                self.experts[expert_idx](current_state)
                * routing_weights[top_x, idx, None]
            )
            final_hidden_states.index_add_(
                0, top_x, current_hidden.to(hidden_states.dtype)
            )

        return final_hidden_states.reshape(batch_size, seq_len, hidden_dim)

class SequentialQwen3VLMoeTextExperts(torch.nn.ModuleList):
    """Unpack fused 3D expert weights into a list of per-expert MLP modules.

    ``original`` is the stock HF experts module whose parameters live in two
    stacked tensors (layouts derived from the slicing/transposes below):

    * ``gate_up_proj`` — (num_experts, hidden, 2 * intermediate); the gate and
      up projections concatenated along the last axis
    * ``down_proj`` — (num_experts, intermediate, hidden)
    """

    def __init__(self, config, original):
        self.num_experts = original.gate_up_proj.shape[0]
        # The fresh MLPs are immediately overwritten with the original
        # weights, so skip their (pointless) random initialization.
        with skip_weights_initialize():
            super().__init__(
                [Qwen3VLMoeTextMLP(config) for _ in range(self.num_experts)]
            )

        intermediate_size = original.down_proj.shape[1]
        for expert_index, expert in enumerate(self):
            fused = original.gate_up_proj[expert_index]
            # Split the fused (hidden, 2*intermediate) tensor into its gate
            # and up halves, then transpose each into nn.Linear's (out, in)
            # weight layout.
            expert.gate_proj.weight.data = (
                fused[:, :intermediate_size].t().clone().contiguous()
            )
            expert.up_proj.weight.data = (
                fused[:, intermediate_size:].t().clone().contiguous()
            )
            expert.down_proj.weight.data = (
                original.down_proj[expert_index].t().clone().contiguous()
            )

def replace(config, module):
    """Build a calibration-friendly MoE block from a fused HF one.

    Args:
        config: top-level model config; only its text sub-config is used.
        module: the original ``Qwen3VLMoeTextSparseMoeBlock`` being replaced.

    Returns:
        A ``LinearQwen3VLMoeTextSparseMoeBlock`` wrapping ``module``'s weights.
    """
    text_config = config.get_text_config()
    return LinearQwen3VLMoeTextSparseMoeBlock(config=text_config, original=module)
Loading