Commit 2d1ce89: Add Dots1ForCausalLM architecture
1 parent: b8c830c

2 files changed (+220, -0)

exllamav3/models/architectures.py (2 additions, 0 deletions)
@@ -1,6 +1,7 @@
 from .cohere import CohereModel
 from .cohere2 import Cohere2Model
 from .decilm import DeciLMModel
+from .dots1 import Dots1Model
 from .gemma2 import Gemma2Model
 from .gemma3 import Gemma3Model, Gemma3TextModel
 from .glm4 import Glm4Model
@@ -23,6 +24,7 @@
     CohereModel,
     Cohere2Model,
     DeciLMModel,
+    Dots1Model,
     Gemma2Model,
     Gemma3Model,
     Gemma3TextModel,
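
The new model class is registered by adding Dots1Model to the architecture list above. As a point of reference only, the self-contained sketch below (hypothetical names, not the actual exllamav3 code) illustrates one way such a list can be resolved to a class via each config's arch_string, so that "Dots1ForCausalLM" in a checkpoint's config.json selects Dots1Model:

# Hypothetical, self-contained sketch: resolving an architecture string to a
# model class from a registration list like the one in architectures.py.
class _DemoDots1Config:
    arch_string = "Dots1ForCausalLM"

class _DemoDots1Model:
    config_class = _DemoDots1Config

ARCHITECTURES = [_DemoDots1Model]  # the real list holds CohereModel, ..., Dots1Model, ...

arch_map = {cls.config_class.arch_string: cls for cls in ARCHITECTURES}
assert arch_map["Dots1ForCausalLM"] is _DemoDots1Model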

exllamav3/models/dots1.py (new file: 218 additions, 0 deletions)
from __future__ import annotations
from typing_extensions import override
import torch
from .config import Config, no_default
from .model import Model
from ..util.rope import RopeSettings, RopeStyle
from ..modules import RMSNorm, Embedding, TransformerBlock, Attention, BlockSparseMLP, GatedMLP, Linear
from ..modules.attn import prepare_for_attn

class Dots1Config(Config):
    arch_string = "Dots1ForCausalLM"

    def __init__(
        self,
        directory: str,
        **kwargs,
    ):
        super().__init__(
            directory,
            {"text": Dots1Model},
            **kwargs
        )

        # Attention params
        self.head_dim = self.read_cfg(int, "head_dim", None)
        self.hidden_size = self.read_cfg(int, "hidden_size", no_default)
        self.num_q_heads = self.read_cfg(int, "num_attention_heads", no_default)
        self.num_kv_heads = self.read_cfg(int, "num_key_value_heads", self.num_q_heads)

        if not self.head_dim:
            self.head_dim = self.hidden_size // self.num_q_heads

        # MLP params
        self.assert_cfg(str, "hidden_act", "silu", True)
        self.assert_cfg(str, "scoring_func", "noaux_tc", True)
        self.assert_cfg(bool, "norm_topk_prob", True, True)
        self.intermediate_size = self.read_cfg(int, "intermediate_size", no_default)
        self.moe_intermediate_size = self.read_cfg(int, "moe_intermediate_size", no_default)
        self.num_shared_experts = self.read_cfg(int, "n_shared_experts", 1)
        self.num_experts = self.read_cfg(int, "n_routed_experts", 128)
        self.num_experts_per_tok = self.read_cfg(int, "num_experts_per_tok", 8)
        self.first_k_dense_replace = self.read_cfg(int, "first_k_dense_replace", 3)
        self.routed_scaling_factor = self.read_cfg(float, "routed_scaling_factor", 2.5)

        # Norms
        self.rms_norm_eps = self.read_cfg(float, "rms_norm_eps", no_default)

        # Layers
        self.num_hidden_layers = self.read_cfg(int, "num_hidden_layers", no_default)
        self.tie_word_embeddings = self.read_cfg(bool, "tie_word_embeddings", False)

        # RoPE
        self.rope_settings = self.read_rope_settings_default(RopeStyle.NEOX)


class Dots1Model(Model):
    config_class = Dots1Config

    def __init__(
        self,
        config: Dots1Config,
        **kwargs
    ):
        super().__init__(config, **kwargs)

        self.modules += [
            Embedding(
                config = config,
                key = "model.embed_tokens",
                vocab_size = config.vocab_size,
                hidden_size = config.hidden_size,
            )
        ]

        self.first_block_idx = len(self.modules)

        self.modules += [
            TransformerBlock(
                config = config,
                key = f"model.layers.{idx}",
                attn_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.input_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                ),
                attn = Attention(
                    config = config,
                    key = f"model.layers.{idx}.self_attn",
                    layer_idx = idx,
                    hidden_size = config.hidden_size,
                    head_dim = config.head_dim,
                    num_q_heads = config.num_q_heads,
                    num_kv_heads = config.num_kv_heads,
                    rope_settings = config.rope_settings,
                    sm_scale = None,
                    key_q = "q_proj",
                    key_k = "k_proj",
                    key_v = "v_proj",
                    key_o = "o_proj",
                    qmap = "block.attn",
                    q_norm = RMSNorm(
                        config = config,
                        key = f"model.layers.{idx}.self_attn.q_norm",
                        rms_norm_eps = config.rms_norm_eps,
                    ),
                    k_norm = RMSNorm(
                        config = config,
                        key = f"model.layers.{idx}.self_attn.k_norm",
                        rms_norm_eps = config.rms_norm_eps,
                    ),
                    out_dtype = torch.float
                ),
                mlp_norm = RMSNorm(
                    config = config,
                    key = f"model.layers.{idx}.post_attention_layernorm",
                    rms_norm_eps = config.rms_norm_eps,
                ),
                mlp = (
                    GatedMLP(
                        config = config,
                        key = f"model.layers.{idx}.mlp",
                        hidden_size = config.hidden_size,
                        intermediate_size = config.intermediate_size,
                        key_up = "up_proj",
                        key_gate = "gate_proj",
                        key_down = "down_proj",
                        qmap = "block.mlp",
                        interm_dtype = torch.half,
                        out_dtype = torch.float,
                    )
                    if idx < config.first_k_dense_replace else
                    BlockSparseMLP(
                        config = config,
                        key = f"model.layers.{idx}.mlp",
                        hidden_size = config.hidden_size,
                        intermediate_size = config.moe_intermediate_size,
                        num_experts = config.num_experts,
                        num_experts_per_tok = config.num_experts_per_tok,
                        key_up = "experts.{expert_idx}.up_proj",
                        key_gate = "experts.{expert_idx}.gate_proj",
                        key_down = "experts.{expert_idx}.down_proj",
                        key_routing_gate = "gate",
                        qmap = "block.mlp",
                        interm_dtype = torch.half,
                        out_dtype = torch.float,
                        deepseekv3_routing = True,
                        routed_scaling_factor = config.routed_scaling_factor,
                        n_group = 1,
                        topk_group = 1,
                        shared_experts = GatedMLP(
                            config = config,
                            key = f"model.layers.{idx}.mlp.shared_experts",
                            hidden_size = config.hidden_size,
                            intermediate_size = config.moe_intermediate_size * config.num_shared_experts,
                            key_up = "up_proj",
                            key_gate = "gate_proj",
                            key_down = "down_proj",
                            qmap = "block.mlp",
                            interm_dtype = torch.half,
                            out_dtype = torch.float,
                        ),
                    )
                )
            )
            for idx in range(config.num_hidden_layers)
        ]

        # TODO: The first attn.o_proj is irregular and breaks quantization. For now, skip quantizing it
        self.modules[self.first_block_idx].attn.o_proj.qmap = None

        self.last_kv_module_idx = len(self.modules) - 1

        head_alt_key = None
        if config.tie_word_embeddings and not self.config.stc.has_tensor("lm_head"):
            head_alt_key = "model.embed_tokens"

        self.modules += [
            RMSNorm(
                config = config,
                key = "model.norm",
                rms_norm_eps = config.rms_norm_eps,
                out_dtype = torch.half,
            ),
            Linear(
                config = config,
                key = "lm_head",
                qbits_key = "head_bits",
                alt_key = head_alt_key,
                in_features = config.hidden_size,
                out_features = config.vocab_size,
                qmap = "block",
                caps = {"logits_output": True}
            )
        ]

        self.logit_layer_idx = len(self.modules) - 1

        # Activate all experts during H capture pass in quantization
        self.calibration_all_experts = True


    @override
    def prepare_inputs(self, input_ids: torch.Tensor, params: dict) -> torch.Tensor:
        params["input_ids"] = input_ids
        input_ids = prepare_for_attn(input_ids, params)
        return input_ids


    @override
    def default_chat_prompt(self, prompt: str, system_prompt: str = None) -> str:
        p = ""
        if system_prompt:
            p += f"<|im_start|>system\n"
            p += f"{system_prompt}<|im_end|>\n"
        p += f"<|im_start|>user\n"
        p += f"{prompt}<|im_end|>\n"
        p += f"<|im_start|>assistant\n"
        return p
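
For the MoE layers, BlockSparseMLP is configured with deepseekv3_routing = True, routed_scaling_factor from the config, and n_group = topk_group = 1, matching the "noaux_tc" scoring function asserted in Dots1Config. The sketch below is my own illustration of the general shape of that routing scheme (sigmoid gate scores, bias-adjusted top-k selection, renormalized top-k weights, scaled by routed_scaling_factor); it is not the code exllamav3 actually runs, and with n_group = 1 the group-limited selection reduces to a plain top-k. The tensor names and the route_tokens helper are assumptions for illustration only:

# Illustrative sketch of DeepSeek-V3-style ("noaux_tc") expert routing, as
# referenced by deepseekv3_routing = True above. NOT the exllamav3 kernel,
# just the general scheme for the n_group = 1 / topk_group = 1 case.
import torch

def route_tokens(
    hidden: torch.Tensor,         # (num_tokens, hidden_size)
    gate_weight: torch.Tensor,    # (num_experts, hidden_size), the "gate" projection
    e_score_bias: torch.Tensor,   # (num_experts,), bias used only for expert selection
    num_experts_per_tok: int = 8,
    routed_scaling_factor: float = 2.5,
):
    # Sigmoid affinity score for every (token, expert) pair
    scores = torch.sigmoid(hidden @ gate_weight.T)                      # (tokens, experts)

    # Select experts using bias-adjusted scores (the bias replaces an auxiliary balancing loss)
    _, topk_idx = torch.topk(scores + e_score_bias, num_experts_per_tok, dim = -1)

    # Gather the unbiased scores of the chosen experts and renormalize them
    # (norm_topk_prob = True in the config)
    topk_w = torch.gather(scores, -1, topk_idx)
    topk_w = topk_w / topk_w.sum(dim = -1, keepdim = True)

    # Scale the routing weights, as controlled by routed_scaling_factor
    return topk_idx, topk_w * routed_scaling_factor

# Example: 4 tokens, hidden size 16, 128 routed experts, 8 active per token
h = torch.randn(4, 16)
w = torch.randn(128, 16)
b = torch.zeros(128)
idx, weights = route_tokens(h, w, b)
print(idx.shape, weights.shape)   # torch.Size([4, 8]) torch.Size([4, 8])

Separately, default_chat_prompt emits a ChatML-style prompt: with no system prompt, default_chat_prompt("Hi") returns "<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\n".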
