from __future__ import annotations
from typing_extensions import override
import torch
from .config import Config, no_default
from .model import Model
from ..util.rope import RopeSettings, RopeStyle
from ..modules import RMSNorm, Embedding, TransformerBlock, Attention, BlockSparseMLP, Linear
from ..modules.attn import prepare_for_attn

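# Architecture definition for Qwen3 MoE causal language models
# ("Qwen3MoeForCausalLM"): Qwen3MoeConfig reads the attention, MoE, norm and
# RoPE hyperparameters from the checkpoint config, and Qwen3MoeModel builds the
# module list (token embedding, transformer blocks with block-sparse expert
# MLPs, final norm and LM head).
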
class Qwen3MoeConfig(Config):
    arch_string = "Qwen3MoeForCausalLM"

    def __init__(
        self,
        directory: str,
        **kwargs,
    ):
        super().__init__(
            directory,
            Qwen3MoeModel,
            **kwargs
        )

        # Attention params
        self.head_dim = self.read_cfg(int, "head_dim", None)
        self.hidden_size = self.read_cfg(int, "hidden_size", no_default)
        self.num_q_heads = self.read_cfg(int, "num_attention_heads", no_default)
        self.num_kv_heads = self.read_cfg(int, "num_key_value_heads", self.num_q_heads)

        if not self.head_dim:
            self.head_dim = self.hidden_size // self.num_q_heads

        # MLP params
        self.assert_cfg(str, "hidden_act", "silu", True)
        self.assert_cfg(bool, "norm_topk_prob", True, True)
        self.moe_intermediate_size = self.read_cfg(int, "moe_intermediate_size", no_default)
        self.num_experts = self.read_cfg(int, "num_experts", no_default)
        self.num_experts_per_tok = self.read_cfg(int, "num_experts_per_tok", no_default)

        # Norms
        self.rms_norm_eps = self.read_cfg(float, "rms_norm_eps", no_default)

        # Layers
        self.num_hidden_layers = self.read_cfg(int, "num_hidden_layers", no_default)
        self.tie_word_embeddings = self.read_cfg(bool, "tie_word_embeddings", False)

        # RoPE
        self.rope_settings = self.read_rope_settings_default(RopeStyle.NEOX)

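# For reference, the keys read above come from the checkpoint's config.json.
# A hypothetical example (illustrative placeholder values only; consult the
# actual file shipped with a given Qwen3 MoE checkpoint):
#
#   {
#     "architectures": ["Qwen3MoeForCausalLM"],
#     "hidden_size": 2048,
#     "head_dim": 128,
#     "num_attention_heads": 32,
#     "num_key_value_heads": 4,
#     "hidden_act": "silu",
#     "norm_topk_prob": true,
#     "moe_intermediate_size": 768,
#     "num_experts": 128,
#     "num_experts_per_tok": 8,
#     "rms_norm_eps": 1e-06,
#     "num_hidden_layers": 48,
#     "tie_word_embeddings": false
#   }
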
class Qwen3MoeModel(Model):
    config_class = Qwen3MoeConfig

    def __init__(
        self,
        config: Qwen3MoeConfig,
        **kwargs
    ):
        super().__init__(config, **kwargs)

        self.modules += [
            Embedding(
                config = config,
                key = "model.embed_tokens",
                vocab_size = config.vocab_size,
                hidden_size = config.hidden_size,
            )
        ]

        self.first_block_idx = len(self.modules)

        self.modules += [
            TransformerBlock(
                config = config,
                key = f"model.layers.{idx}",
                attn_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.input_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                ),
                attn = Attention(
                    config = config,
                    key = f"model.layers.{idx}.self_attn",
                    layer_idx = idx,
                    hidden_size = config.hidden_size,
                    head_dim = config.head_dim,
                    num_q_heads = config.num_q_heads,
                    num_kv_heads = config.num_kv_heads,
                    rope_settings = config.rope_settings,
                    sm_scale = None,
                    key_q = "q_proj",
                    key_k = "k_proj",
                    key_v = "v_proj",
                    key_o = "o_proj",
                    qmap = "block.attn",
                    q_norm = RMSNorm(
                        config = config,
                        key = f"model.layers.{idx}.self_attn.q_norm",
                        rms_norm_eps = config.rms_norm_eps,
                    ),
                    k_norm = RMSNorm(
                        config = config,
                        key = f"model.layers.{idx}.self_attn.k_norm",
                        rms_norm_eps = config.rms_norm_eps,
                    ),
                ),
                mlp_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.post_attention_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                ),
                mlp = BlockSparseMLP(
                    config = config,
                    key = f"model.layers.{idx}.mlp",
                    hidden_size = config.hidden_size,
                    intermediate_size = config.moe_intermediate_size,
                    num_experts = self.config.num_experts,
                    num_experts_per_tok = self.config.num_experts_per_tok,
                    key_up = "experts.{expert_idx}.up_proj",
                    key_gate = "experts.{expert_idx}.gate_proj",
                    key_down = "experts.{expert_idx}.down_proj",
                    key_routing_gate = "gate",
                    qmap = "block.mlp",
                    interm_dtype = torch.half,
                    out_dtype = torch.float,
                ),
            )
            for idx in range(config.num_hidden_layers)
        ]

        self.last_kv_module_idx = len(self.modules) - 1

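        # If embeddings are tied and the checkpoint stores no separate lm_head
        # tensor, load the output head from the embedding weights instead.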
        head_alt_key = None
        if config.tie_word_embeddings and not self.config.stc.has_tensor("lm_head"):
            head_alt_key = "model.embed_tokens"

        self.modules += [
            RMSNorm(
                config = config,
                key = "model.norm",
                rms_norm_eps = config.rms_norm_eps,
                out_dtype = torch.half,
            ),
            Linear(
                config = config,
                key = "lm_head",
                qbits_key = "head_bits",
                alt_key = head_alt_key,
                in_features = config.hidden_size,
                out_features = config.vocab_size,
                qmap = "block",
                caps = {"logits_output": True}
            )
        ]

        self.logit_layer_idx = len(self.modules) - 1

        # Activate all experts during H capture pass in quantization
        self.calibration_all_experts = True

    @override
    def prepare_inputs(self, input_ids: torch.Tensor, params: dict) -> torch.Tensor:
        params["input_ids"] = input_ids
        input_ids = prepare_for_attn(input_ids, params)
        return input_ids
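
# Minimal usage sketch (hypothetical path; the classes above are constructed
# directly here, though the surrounding library may provide higher-level
# loader helpers that should be preferred where available):
#
#   config = Qwen3MoeConfig("/path/to/qwen3-moe-checkpoint")
#   model = Qwen3MoeModel(config)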