Commit a5caa42

Add Qwen3MoeForCausalLM

1 parent c6aa83f

2 files changed: +168 -0 lines

exllamav3/models/architectures.py
Lines changed: 2 additions & 0 deletions

@@ -2,6 +2,7 @@
 from .mistral import MistralModel
 from .qwen2 import Qwen2Model
 from .qwen3 import Qwen3Model
+from .qwen3_moe import Qwen3MoeModel
 from .phi3 import Phi3Model
 from .gemma2 import Gemma2Model
 from .decilm import DeciLMModel
@@ -19,6 +20,7 @@
     MistralModel,
     Qwen2Model,
     Qwen3Model,
+    Qwen3MoeModel,
     Phi3Model,
     Gemma2Model,
     DeciLMModel,
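
Registering the new architecture only requires importing the model class and appending it to this list; each model's Config subclass carries an arch_string (here "Qwen3MoeForCausalLM") that corresponds to the "architectures" field of an HF-style config.json. The sketch below shows how such a lookup could work under that assumption; the helper find_model_class and the MODEL_CLASSES stand-in are illustrative, not exllamav3's actual API.

import json, os
from exllamav3.models.qwen3_moe import Qwen3MoeModel

# Stand-in for the list maintained in architectures.py
MODEL_CLASSES = [Qwen3MoeModel]

def find_model_class(directory: str):
    # Read the HF-style config.json and match its "architectures" entry
    # against each registered model's Config.arch_string.
    with open(os.path.join(directory, "config.json")) as f:
        arch_strings = json.load(f).get("architectures", [])
    for cls in MODEL_CLASSES:
        if cls.config_class.arch_string in arch_strings:
            return cls
    raise ValueError(f"No model class registered for {arch_strings}")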

exllamav3/models/qwen3_moe.py (new file)
Lines changed: 166 additions & 0 deletions

from __future__ import annotations
from typing_extensions import override
import torch
from .config import Config, no_default
from .model import Model
from ..util.rope import RopeSettings, RopeStyle
from ..modules import RMSNorm, Embedding, TransformerBlock, Attention, BlockSparseMLP, Linear
from ..modules.attn import prepare_for_attn

class Qwen3MoeConfig(Config):
    arch_string = "Qwen3MoeForCausalLM"

    def __init__(
        self,
        directory: str,
        **kwargs,
    ):
        super().__init__(
            directory,
            Qwen3MoeModel,
            **kwargs
        )

        # Attention params
        self.head_dim = self.read_cfg(int, "head_dim", None)
        self.hidden_size = self.read_cfg(int, "hidden_size", no_default)
        self.num_q_heads = self.read_cfg(int, "num_attention_heads", no_default)
        self.num_kv_heads = self.read_cfg(int, "num_key_value_heads", self.num_q_heads)

        if not self.head_dim:
            self.head_dim = self.hidden_size // self.num_q_heads

        # MLP params
        self.assert_cfg(str, "hidden_act", "silu", True)
        self.assert_cfg(bool, "norm_topk_prob", True, True)
        self.moe_intermediate_size = self.read_cfg(int, "moe_intermediate_size", no_default)
        self.num_experts = self.read_cfg(int, "num_experts", no_default)
        self.num_experts_per_tok = self.read_cfg(int, "num_experts_per_tok", no_default)

        # Norms
        self.rms_norm_eps = self.read_cfg(float, "rms_norm_eps", no_default)

        # Layers
        self.num_hidden_layers = self.read_cfg(int, "num_hidden_layers", no_default)
        self.tie_word_embeddings = self.read_cfg(bool, "tie_word_embeddings", False)

        # RoPE
        self.rope_settings = self.read_rope_settings_default(RopeStyle.NEOX)


class Qwen3MoeModel(Model):
    config_class = Qwen3MoeConfig

    def __init__(
        self,
        config: Qwen3MoeConfig,
        **kwargs
    ):
        super().__init__(config, **kwargs)

        self.modules += [
            Embedding(
                config = config,
                key = "model.embed_tokens",
                vocab_size = config.vocab_size,
                hidden_size = config.hidden_size,
            )
        ]

        self.first_block_idx = len(self.modules)

        self.modules += [
            TransformerBlock(
                config = config,
                key = f"model.layers.{idx}",
                attn_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.input_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                ),
                attn = Attention(
                    config = config,
                    key = f"model.layers.{idx}.self_attn",
                    layer_idx = idx,
                    hidden_size = config.hidden_size,
                    head_dim = config.head_dim,
                    num_q_heads = config.num_q_heads,
                    num_kv_heads = config.num_kv_heads,
                    rope_settings = config.rope_settings,
                    sm_scale = None,
                    key_q = "q_proj",
                    key_k = "k_proj",
                    key_v = "v_proj",
                    key_o = "o_proj",
                    qmap = "block.attn",
                    q_norm = RMSNorm(
                        config = config,
                        key = f"model.layers.{idx}.self_attn.q_norm",
                        rms_norm_eps = config.rms_norm_eps,
                    ),
                    k_norm = RMSNorm(
                        config = config,
                        key = f"model.layers.{idx}.self_attn.k_norm",
                        rms_norm_eps = config.rms_norm_eps,
                    ),
                ),
                mlp_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.post_attention_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                ),
                mlp = BlockSparseMLP(
                    config = config,
                    key = f"model.layers.{idx}.mlp",
                    hidden_size = config.hidden_size,
                    intermediate_size = config.moe_intermediate_size,
                    num_experts = self.config.num_experts,
                    num_experts_per_tok = self.config.num_experts_per_tok,
                    key_up = "experts.{expert_idx}.up_proj",
                    key_gate = "experts.{expert_idx}.gate_proj",
                    key_down = "experts.{expert_idx}.down_proj",
                    key_routing_gate = "gate",
                    qmap = "block.mlp",
                    interm_dtype = torch.half,
                    out_dtype = torch.float,
                ),
            )
            for idx in range(config.num_hidden_layers)
        ]

        self.last_kv_module_idx = len(self.modules) - 1

        head_alt_key = None
        if config.tie_word_embeddings and not self.config.stc.has_tensor("lm_head"):
            head_alt_key = "model.embed_tokens"

        self.modules += [
            RMSNorm(
                config = config,
                key = "model.norm",
                rms_norm_eps = config.rms_norm_eps,
                out_dtype = torch.half,
            ),
            Linear(
                config = config,
                key = "lm_head",
                qbits_key = "head_bits",
                alt_key = head_alt_key,
                in_features = config.hidden_size,
                out_features = config.vocab_size,
                qmap = "block",
                caps = {"logits_output": True}
            )
        ]

        self.logit_layer_idx = len(self.modules) - 1

        # Activate all experts during H capture pass in quantization
        self.calibration_all_experts = True


    @override
    def prepare_inputs(self, input_ids: torch.Tensor, params: dict) -> torch.Tensor:
        params["input_ids"] = input_ids
        input_ids = prepare_for_attn(input_ids, params)
        return input_ids
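
The config asserts norm_topk_prob = True and passes num_experts, num_experts_per_tok, and the router key "gate" to BlockSparseMLP. As a plain-PyTorch illustration of the top-k routing those fields imply (a sketch only, not exllamav3's BlockSparseMLP implementation; the names moe_forward, gate_w, and experts are hypothetical):

import torch
import torch.nn.functional as F

def moe_forward(x, gate_w, experts, num_experts_per_tok):
    # x: [tokens, hidden], gate_w: [num_experts, hidden]
    # experts: list of per-expert MLP callables (e.g. SwiGLU blocks)
    logits = x @ gate_w.t()                                   # [tokens, num_experts]
    probs = F.softmax(logits, dim = -1, dtype = torch.float)
    topk_p, topk_i = probs.topk(num_experts_per_tok, dim = -1)
    topk_p = topk_p / topk_p.sum(dim = -1, keepdim = True)    # norm_topk_prob = True
    out = torch.zeros_like(x, dtype = torch.float)            # out_dtype = torch.float
    for e, expert in enumerate(experts):
        token_idx, k_idx = (topk_i == e).nonzero(as_tuple = True)
        if token_idx.numel() == 0:
            continue
        w = topk_p[token_idx, k_idx].unsqueeze(-1)
        out[token_idx] += w * expert(x[token_idx]).float()
    return out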
