
Commit 3ffcc74

Support Index architecture
1 parent c91893a commit 3ffcc74

4 files changed: +39 −2 lines

exllamav2/architecture.py

Lines changed: 21 additions & 0 deletions
@@ -569,6 +569,27 @@ def __init__(self, arch_string, read_config):
             self.rope_style = RopeStyle.NEOX
             self.fused_qkv_altpack = True
 
+        # Index
+
+        if arch_string == "IndexForCausalLM":
+            arch_recognized = True
+            self.layer_keys += \
+                layer_keys_llama_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_llama_mlp
+            self.expect_keys += \
+                expect_keys_llama
+            self.norm_eps_key = "rms_norm_eps"
+            self.mlp_key_gate = ".mlp.gate_proj"
+            self.mlp_key_up = ".mlp.up_proj"
+            self.mlp_key_down = ".mlp.down_proj"
+            self.lm_head_key = "lm_head"
+            self.norm_key_1 = ".input_layernorm"
+            self.norm_key_2 = ".post_attention_layernorm"
+            self.mlp_act_func = "silu"
+            self.norm = "rmsnorm"
+            self.rope_style = RopeStyle.NEOX
+
         # Llama (default + fallback)
 
         if arch_string != "LlamaForCausalLM" and not arch_recognized:

exllamav2/config.py

Lines changed: 5 additions & 0 deletions
@@ -104,6 +104,7 @@ class ExLlamaV2Config:
     final_logit_softcapping: float | None
     attn_logit_softcapping: float | None
     sliding_window: int
+    norm_head: int | None
 
     checkpoint_fused_mlp: bool
 
@@ -251,6 +252,10 @@ def prepare(self, no_tensors: bool = False):
         self.attn_logit_softcapping = read(read_config, float, "attn_logit_softcapping", None)
         self.final_logit_softcapping = read(read_config, float, "final_logit_softcapping", None)
 
+        # Normalize weights in head layer
+
+        self.norm_head = read(read_config, int, "norm_head", None)
+
         # Positional embeddings
 
         self.rotary_embedding_base = read(read_config, float, ["rope_theta", "attn_config->rope_theta"], 10000.0)
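Note (illustrative, not from the commit): judging by the call above, norm_head is read straight from the parsed config.json as an int and defaults to None when the key is absent. A standalone sketch of that behaviour with made-up config values:

# Standalone sketch; the `read` helper's exact behaviour is assumed from the
# call site above: fetch "norm_head" from config.json, defaulting to None.
import json

read_config = json.loads('{"architectures": ["IndexForCausalLM"], "norm_head": 1}')  # made-up example

norm_head = read_config.get("norm_head", None)
if norm_head is not None:
    norm_head = int(norm_head)

print(norm_head)  # 1 -> the head weights get normalized at load time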

exllamav2/linear.py

Lines changed: 11 additions & 1 deletion
@@ -54,7 +54,8 @@ def __init__(self,
                  f_beg: int = None,
                  f_end: int = None,
                  is_sub_module: bool = True,
-                 altpack_qkv: bool = False):
+                 altpack_qkv: bool = False,
+                 normalize_unq: bool = False):
         super().__init__(model, key)
 
         self.is_sub_module = is_sub_module
@@ -89,6 +90,7 @@ def __init__(self,
         self.altpack_qkv = altpack_qkv
 
         self.assumed_footprint = in_features * (out_features + self.padding) * 2 + 128
+        self.normalize_unq = normalize_unq
 
 
     @torch.inference_mode
@@ -125,6 +127,8 @@ def load(self,
 
         elif isinstance(w, nn.Parameter):
             assert not self.has_bias, self.key + " has no bias tensor but bias is expected"
+            if self.normalize_unq:
+                w = self.normalize(w)
             if self.padding > 0: w = nn.Parameter(F.pad(w.data, (0, 0, 0, self.padding)).contiguous())
             if not self.model.config.load_in_q4 or not ".layers." in self.key:
                 self.linear = nn.Linear(self.in_features, self.out_features, self.has_bias, device = "meta", dtype = torch.float16)
@@ -138,6 +142,8 @@ def load(self,
 
         elif isinstance(w, tuple):
             assert self.has_bias, self.key + " has bias tensor but bias is not expected"
+            if self.normalize_unq:
+                w = self.normalize(w[0]), w[1]
             ww = w[0]
             wb = w[1]
             if self.padding > 0:
@@ -154,6 +160,10 @@ def load(self,
                 self.fp16_bias = wb
 
 
+    def normalize(self, w: torch.Tensor):
+        return nn.functional.normalize(w)
+
+
     def matrix_shape(self):
 
         return self.in_features, self.out_features
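Note (not part of the diff): the new normalize() hook forwards to torch.nn.functional.normalize, whose defaults are p=2 and dim=1, so for a 2-D [vocab_size, hidden_size] head weight every row is rescaled to unit L2 norm. A self-contained toy check:

# Toy demonstration of what normalize() does to an unquantized head weight.
# F.normalize defaults to p=2, dim=1: each row ends up with unit L2 norm.
import torch
import torch.nn.functional as F

w = torch.randn(8, 16)         # stand-in for a [vocab_size, hidden_size] lm_head weight
w_normed = F.normalize(w)      # same call as ExLlamaV2Linear.normalize above

print(w_normed.norm(dim = 1))  # every entry is ~1.0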

exllamav2/model.py

Lines changed: 2 additions & 1 deletion
@@ -250,7 +250,8 @@ def __init__(self, config: ExLlamaV2Config, lazy_load = False):
                                False,
                                max_out_len = self.config.max_output_len,
                                prescale = self.config.logit_scale,
-                               is_sub_module = False)
+                               is_sub_module = False,
+                               normalize_unq = bool(self.config.norm_head))
         if self.config.arch.lm_head_key != "lm_head":
             head.alt_key = self.config.arch.lm_head_key
         self.modules += [head]
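Note (not part of the diff): normalize_unq is gated through bool(self.config.norm_head), and norm_head defaults to None, so checkpoints whose config.json has no "norm_head" entry keep the old behaviour. A quick check of that coercion:

# bool() coercion used above: missing (None) or zero leaves normalization off,
# any non-zero value turns it on.
for norm_head in (None, 0, 1):
    print(repr(norm_head), "->", bool(norm_head))
# None -> False, 0 -> False, 1 -> True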
