
Commit ebfefc4

Support Cohere2 architecture
1 parent d815f5f commit ebfefc4

4 files changed: +49 -9 lines changed

exllamav2/architecture.py

Lines changed: 22 additions & 0 deletions
@@ -515,6 +515,28 @@ class Params:
             self.lm.parallel_decoder_blocks = True
             self.lm.requires_bos = True

+        # Cohere 2
+
+        if arch_string == "Cohere2ForCausalLM":
+            arch_recognized = True
+            self.lm.layer_keys += \
+                layer_keys_cohere_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_llama_mlp
+            self.lm.expect_keys += \
+                expect_keys_gemma
+            self.lm.keys.update({
+                "norm_eps": "layer_norm_eps",
+                "lm_head": "model.embed_tokens",
+                "norm_1": ".input_layernorm",
+                "norm_2": None,
+            })
+            self.lm.norm = "layernorm"
+            self.lm.rope_style = RopeStyle.GPTJ
+            self.lm.parallel_decoder_blocks = True
+            self.lm.requires_bos = True
+            self.lm.alternating_swa = True
+
         # DBRX

         if arch_string == "DbrxForCausalLM":
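
The new entry reuses the existing Cohere norm / Llama attention / Llama MLP key layouts and maps "lm_head" to "model.embed_tokens": the checkpoint carries no separate output projection, and logits are computed against the tied input embedding matrix. A minimal sketch of that tied-embedding idea, with placeholder dimensions and names that are not exllamav2 internals:

# Illustrative sketch of tied input/output embeddings ("lm_head" -> "model.embed_tokens").
# Dimensions and names are placeholders, not exllamav2 API.
import torch

vocab_size, hidden_size = 256, 64
embed = torch.nn.Embedding(vocab_size, hidden_size)

def logits_from_hidden(h: torch.Tensor) -> torch.Tensor:
    # No separate lm_head tensor: project hidden states against the embedding matrix itself.
    return h @ embed.weight.T

h = torch.randn(1, 8, hidden_size)
print(logits_from_hidden(h).shape)   # torch.Size([1, 8, 256])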

exllamav2/config.py

Lines changed: 2 additions & 0 deletions
@@ -115,6 +115,7 @@ class ExLlamaV2Config:
     final_logit_softcapping: float | None
     attn_logit_softcapping: float | None
     sliding_window: int
+    sliding_window_pattern: int
     norm_head: int | None
     l3_rope_factor: float | None
     l3_rope_low_freq_factor: float | None
@@ -347,6 +348,7 @@ def prepare(self, no_tensors: bool = False):
         self.original_max_seq_len = self.max_seq_len

         self.sliding_window = read(read_config, int, ["sliding_window", "sliding_window_size"], 0, opt_subkey = "text_config")
+        self.sliding_window_pattern = read(read_config, int, ["sliding_window_pattern"], 1)

         rs = read(read_config, dict, "rope_scaling", None)
         if rs:
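
The config loader now also picks up sliding_window_pattern from the model's config.json, defaulting to 1 when the key is absent. A small sketch of what is being read; the values are examples in the Cohere2 style, not pinned to any particular checkpoint, and dict.get stands in for exllamav2's read() helper:

# Hypothetical config fragment and the values the two read() calls above would yield.
read_config = {
    "architectures": ["Cohere2ForCausalLM"],
    "sliding_window": 4096,          # window size for the local-attention layers
    "sliding_window_pattern": 4,     # every 4th layer falls back to global attention
}

sliding_window = read_config.get("sliding_window", 0)                  # -> 4096
sliding_window_pattern = read_config.get("sliding_window_pattern", 1)  # -> 4 (default 1)
print(sliding_window, sliding_window_pattern)

Note that the default of 1 makes (layer_idx + 1) % sliding_window_pattern equal 0 for every layer, so an alternating-SWA architecture whose config omits the pattern ends up with global attention throughout (see the model.py change below).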

exllamav2/model.py

Lines changed: 9 additions & 7 deletions
@@ -106,16 +106,18 @@ def __init__(
         for layer_idx in range(cfg.num_hidden_layers):

             layer_key = cfg.arch.lm_prefix + f"model.layers.{layer_idx}"
+
+            if cfg.arch.lm.alternating_swa:
+                swa = cfg.sliding_window if (layer_idx + 1) % cfg.sliding_window_pattern != 0 else 0
+            elif cfg.arch.lm.swa:
+                swa = cfg.sliding_window
+            else:
+                swa = 0
+
             if cfg.arch.lm.parallel_decoder_blocks:
-                pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx)
+                pd = ExLlamaV2ParallelDecoder(self, layer_key, layer_idx, sliding_window = swa)
                 self.modules += [pd]
             else:
-                if cfg.arch.lm.alternating_swa:
-                    swa = cfg.sliding_window if not bool(layer_idx % 2) else 0
-                elif cfg.arch.lm.swa:
-                    swa = cfg.sliding_window
-                else:
-                    swa = 0
                 attn = ExLlamaV2Attention(self, layer_key, layer_idx, sliding_window = swa)
                 if cfg.arch.lm.is_moe: mlp = ExLlamaV2MoEMLP(self, layer_key, layer_idx)
                 else: mlp = ExLlamaV2MLP(self, layer_key, layer_idx)
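
The sliding-window decision now happens once per layer, before the parallel/serial branch, and replaces the old even/odd alternation: a layer uses the sliding window unless (layer_idx + 1) is a multiple of sliding_window_pattern, so with a pattern of 4 the 0-based layers 3, 7, 11, ... run with full attention. A standalone check of the rule (illustrative values, not exllamav2 code):

# New rule: every pattern-th layer gets global attention (swa == 0), the rest use the window.
sliding_window, pattern = 4096, 4

new_rule = [sliding_window if (layer_idx + 1) % pattern != 0 else 0 for layer_idx in range(8)]
print(new_rule)   # [4096, 4096, 4096, 0, 4096, 4096, 4096, 0]

# Old rule (removed above) only alternated on even/odd layer index:
old_rule = [sliding_window if not bool(layer_idx % 2) else 0 for layer_idx in range(8)]
print(old_rule)   # [4096, 0, 4096, 0, 4096, 0, 4096, 0]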

exllamav2/parallel_decoder.py

Lines changed: 16 additions & 2 deletions
@@ -29,6 +29,7 @@ def __init__(
         model: ExLlamaV2,
         key: str,
         layer_idx: int,
+        sliding_window: int = 0,
         archparams = None
     ):
         super().__init__(model, key, archparams)
@@ -42,8 +43,21 @@ def __init__(
         elif self.archparams.norm == "rmsnorm":
             self.input_layernorm = ExLlamaV2RMSNorm(model, key + self.archparams.keys["norm_1"])

-        self.attn = ExLlamaV2Attention(model, key, layer_idx, has_norm = False, has_residual = False)
-        self.mlp = ExLlamaV2MLP(model, key, layer_idx, has_norm = False, has_residual = False)
+        self.attn = ExLlamaV2Attention(
+            model,
+            key,
+            layer_idx,
+            has_norm = False,
+            has_residual = False,
+            sliding_window = sliding_window
+        )
+        self.mlp = ExLlamaV2MLP(
+            model,
+            key,
+            layer_idx,
+            has_norm = False,
+            has_residual = False
+        )

         self.submodules = self.attn.submodules + self.mlp.submodules
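
ExLlamaV2ParallelDecoder now accepts a sliding_window argument and forwards it only to its attention sub-module; the MLP construction is unchanged. For context, a parallel decoder block feeds one shared pre-norm into both attention and MLP and sums their outputs into the residual, which is why the per-layer window for Cohere2 has to be threaded through this wrapper rather than attached to the attention module directly. A plain-PyTorch sketch of that dataflow, with stand-in modules rather than exllamav2's quantized ones:

# Minimal parallel-block sketch: one shared norm, attention and MLP in parallel,
# both added to the residual. attn_fn / mlp_fn are placeholders, not exllamav2 modules.
import torch

def parallel_block(x, norm, attn_fn, mlp_fn):
    h = norm(x)                        # single input_layernorm shared by both branches
    return x + attn_fn(h) + mlp_fn(h)  # residual plus both branch outputs

norm = torch.nn.LayerNorm(64)
attn_fn = torch.nn.Linear(64, 64)      # stand-in for (sliding-window) attention
mlp_fn = torch.nn.Linear(64, 64)       # stand-in for the MLP
x = torch.randn(1, 10, 64)
print(parallel_block(x, norm, attn_fn, mlp_fn).shape)   # torch.Size([1, 10, 64])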
