from __future__ import annotations
from typing_extensions import override
import torch
from .config import Config, no_default
from .model import Model
from ..util.rope import RopeSettings, RopeStyle
from ..modules import RMSNorm, Embedding, TransformerBlock, Attention, BlockSparseMLP, GatedMLP, Linear
from ..modules.attn import prepare_for_attn

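# Model definition for the Dots1 architecture ("Dots1ForCausalLM"): a MoE transformer with
# DeepSeek-V3-style expert routing, dense leading layers, shared experts and QK-norm attention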
class Dots1Config(Config):
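    # Matches the "architectures" entry in the model's config.json; used to select this
    # config/model class when a checkpoint is loaded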
    arch_string = "Dots1ForCausalLM"

    def __init__(
        self,
        directory: str,
        **kwargs,
    ):
        super().__init__(
            directory,
            {"text": Dots1Model},
            **kwargs
        )

        # Attention params
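        # head_dim may be missing from config.json; fall back to hidden_size // num_attention_heads below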
        self.head_dim = self.read_cfg(int, "head_dim", None)
        self.hidden_size = self.read_cfg(int, "hidden_size", no_default)
        self.num_q_heads = self.read_cfg(int, "num_attention_heads", no_default)
        self.num_kv_heads = self.read_cfg(int, "num_key_value_heads", self.num_q_heads)

        if not self.head_dim:
            self.head_dim = self.hidden_size // self.num_q_heads

        # MLP params
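        # DeepSeek-V3-style MoE: "noaux_tc" scoring with normalized top-k routing weights and a
        # routed scaling factor, plus shared experts folded into a single GatedMLP. Layers below
        # first_k_dense_replace use a plain dense GatedMLP instead of the sparse MoE block.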
        self.assert_cfg(str, "hidden_act", "silu", True)
        self.assert_cfg(str, "scoring_func", "noaux_tc", True)
        self.assert_cfg(bool, "norm_topk_prob", True, True)
        self.intermediate_size = self.read_cfg(int, "intermediate_size", no_default)
        self.moe_intermediate_size = self.read_cfg(int, "moe_intermediate_size", no_default)
        self.num_shared_experts = self.read_cfg(int, "n_shared_experts", 1)
        self.num_experts = self.read_cfg(int, "n_routed_experts", 128)
        self.num_experts_per_tok = self.read_cfg(int, "num_experts_per_tok", 8)
        self.first_k_dense_replace = self.read_cfg(int, "first_k_dense_replace", 3)
        self.routed_scaling_factor = self.read_cfg(float, "routed_scaling_factor", 2.5)

        # Norms
        self.rms_norm_eps = self.read_cfg(float, "rms_norm_eps", no_default)

        # Layers
        self.num_hidden_layers = self.read_cfg(int, "num_hidden_layers", no_default)
        self.tie_word_embeddings = self.read_cfg(bool, "tie_word_embeddings", False)

        # RoPE
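        # NEOX-style (rotate-half) rotary embeddings; remaining RoPE parameters presumably come
        # from the checkpoint config via read_rope_settings_default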
        self.rope_settings = self.read_rope_settings_default(RopeStyle.NEOX)


class Dots1Model(Model):
    config_class = Dots1Config

    def __init__(
        self,
        config: Dots1Config,
        **kwargs
    ):
        super().__init__(config, **kwargs)

        self.modules += [
            Embedding(
                config = config,
                key = "model.embed_tokens",
                vocab_size = config.vocab_size,
                hidden_size = config.hidden_size,
            )
        ]

        self.first_block_idx = len(self.modules)

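        # One TransformerBlock per hidden layer: RMSNorm -> attention (with QK-norm) -> RMSNorm -> MLP.
        # Layers below first_k_dense_replace use a dense GatedMLP; the remaining layers use a
        # BlockSparseMLP with DeepSeek-V3 routing plus a shared-expert GatedMLP.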
        self.modules += [
            TransformerBlock(
                config = config,
                key = f"model.layers.{idx}",
                attn_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.input_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                ),
                attn = Attention(
                    config = config,
                    key = f"model.layers.{idx}.self_attn",
                    layer_idx = idx,
                    hidden_size = config.hidden_size,
                    head_dim = config.head_dim,
                    num_q_heads = config.num_q_heads,
                    num_kv_heads = config.num_kv_heads,
                    rope_settings = config.rope_settings,
                    sm_scale = None,
                    key_q = "q_proj",
                    key_k = "k_proj",
                    key_v = "v_proj",
                    key_o = "o_proj",
                    qmap = "block.attn",
                    q_norm = RMSNorm(
                        config = config,
                        key = f"model.layers.{idx}.self_attn.q_norm",
                        rms_norm_eps = config.rms_norm_eps,
                    ),
                    k_norm = RMSNorm(
                        config = config,
                        key = f"model.layers.{idx}.self_attn.k_norm",
                        rms_norm_eps = config.rms_norm_eps,
                    ),
                    out_dtype = torch.float
                ),
                mlp_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.post_attention_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                ),
                mlp = (
                    GatedMLP(
                        config = config,
                        key = f"model.layers.{idx}.mlp",
                        hidden_size = config.hidden_size,
                        intermediate_size = config.intermediate_size,
                        key_up = "up_proj",
                        key_gate = "gate_proj",
                        key_down = "down_proj",
                        qmap = "block.mlp",
                        interm_dtype = torch.half,
                        out_dtype = torch.float,
                    )
                    if idx < config.first_k_dense_replace else
                    BlockSparseMLP(
                        config = config,
                        key = f"model.layers.{idx}.mlp",
                        hidden_size = config.hidden_size,
                        intermediate_size = config.moe_intermediate_size,
                        num_experts = config.num_experts,
                        num_experts_per_tok = config.num_experts_per_tok,
                        key_up = "experts.{expert_idx}.up_proj",
                        key_gate = "experts.{expert_idx}.gate_proj",
                        key_down = "experts.{expert_idx}.down_proj",
                        key_routing_gate = "gate",
                        qmap = "block.mlp",
                        interm_dtype = torch.half,
                        out_dtype = torch.float,
                        deepseekv3_routing = True,
                        routed_scaling_factor = config.routed_scaling_factor,
                        n_group = 1,
                        topk_group = 1,
                        shared_experts = GatedMLP(
                            config = config,
                            key = f"model.layers.{idx}.mlp.shared_experts",
                            hidden_size = config.hidden_size,
                            intermediate_size = config.moe_intermediate_size * config.num_shared_experts,
                            key_up = "up_proj",
                            key_gate = "gate_proj",
                            key_down = "down_proj",
                            qmap = "block.mlp",
                            interm_dtype = torch.half,
                            out_dtype = torch.float,
                        ),
                    )
                )
            )
            for idx in range(config.num_hidden_layers)
        ]

        # TODO: The first attn.o_proj is irregular and breaks quantization. For now, skip quantizing it
        self.modules[self.first_block_idx].attn.o_proj.qmap = None

        self.last_kv_module_idx = len(self.modules) - 1

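        # If embeddings are tied and the checkpoint ships no separate lm_head tensor, load the
        # output projection from the embedding weights instead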
        head_alt_key = None
        if config.tie_word_embeddings and not self.config.stc.has_tensor("lm_head"):
            head_alt_key = "model.embed_tokens"

        self.modules += [
            RMSNorm(
                config = config,
                key = "model.norm",
                rms_norm_eps = config.rms_norm_eps,
                out_dtype = torch.half,
            ),
            Linear(
                config = config,
                key = "lm_head",
                qbits_key = "head_bits",
                alt_key = head_alt_key,
                in_features = config.hidden_size,
                out_features = config.vocab_size,
                qmap = "block",
                caps = {"logits_output": True}
            )
        ]

        self.logit_layer_idx = len(self.modules) - 1

        # Activate all experts during H capture pass in quantization
        self.calibration_all_experts = True


    @override
    def prepare_inputs(self, input_ids: torch.Tensor, params: dict) -> torch.Tensor:
        params["input_ids"] = input_ids
        input_ids = prepare_for_attn(input_ids, params)
        return input_ids


    @override
    def default_chat_prompt(self, prompt: str, system_prompt: str = None) -> str:
        p = ""
        if system_prompt:
            p += f"<|im_start|>system\n"
            p += f"{system_prompt}<|im_end|>\n"
        p += f"<|im_start|>user\n"
        p += f"{prompt}<|im_end|>\n"
        p += f"<|im_start|>assistant\n"
        return p
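
    # For example, default_chat_prompt("Hello", system_prompt = "Be brief") returns:
    #
    #   <|im_start|>system
    #   Be brief<|im_end|>
    #   <|im_start|>user
    #   Hello<|im_end|>
    #   <|im_start|>assistant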