
Commit de19cbc

Add GLM4 architecture
1 parent b148bb4 commit de19cbc

2 files changed: 72 additions & 1 deletion

examples/chat_prompts.py

Lines changed: 45 additions & 0 deletions

@@ -674,6 +674,50 @@ def print_extra_newline(self):
         return True


+class PromptFormat_glm(PromptFormat):
+    description = "GLM4"
+
+    def __init__(self):
+        super().__init__()
+        pass
+
+    def default_system_prompt(self):
+        return \
+            f"""You are a helpful AI assistant."""
+
+    def first_prompt(self, sysprompt):
+        r = """[gMASK]<sop>"""
+        if sysprompt:
+            r += \
+                """<|system|>\n""" + \
+                """<|system_prompt|>"""
+        r += \
+            """<|user|>\n""" + \
+            """<|user_prompt|>""" + \
+            """<|assistant|>\n"""
+        return r
+
+    def subs_prompt(self):
+        return \
+            """<|user|>\n""" + \
+            """<|user_prompt|>""" + \
+            """<|assistant|>\n"""
+
+    def stop_conditions(self, tokenizer):
+        return \
+            [tokenizer.eos_token_id,
+             tokenizer.single_id("<|user|>"),
+             """<|user|>""",
+             ]
+
+    def encoding_options(self):
+        return True, False, True
+
+    def print_extra_newline(self):
+        return True
+
+
+
 prompt_formats = \
 {
     "raw": PromptFormat_raw,
@@ -693,4 +737,5 @@ def print_extra_newline(self):
     "phi3": PromptFormat_phi3,
     "granite": PromptFormat_granite,
     "granite3": PromptFormat_granite3,
+    "glm": PromptFormat_glm
 }
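For context, the template strings above expand as follows on a first turn. The sketch below is not part of the commit: the placeholder substitution normally happens in the chat example, and substitute() here is only a hypothetical stand-in that fills <|system_prompt|> and <|user_prompt|> by hand.

# Illustration only (not from the commit): expand the GLM4 first-turn template
# manually. substitute() is a hypothetical helper used just for this demo; the
# chat example is what actually performs this replacement.

def substitute(template: str, sysprompt: str, user_prompt: str) -> str:
    return (template
            .replace("<|system_prompt|>", sysprompt)
            .replace("<|user_prompt|>", user_prompt))

# Same string PromptFormat_glm.first_prompt() builds when a system prompt is set
first_turn = ("[gMASK]<sop>"
              "<|system|>\n<|system_prompt|>"
              "<|user|>\n<|user_prompt|>"
              "<|assistant|>\n")

print(substitute(first_turn,
                 "You are a helpful AI assistant.",
                 "Hello!"))
# Prints:
# [gMASK]<sop><|system|>
# You are a helpful AI assistant.<|user|>
# Hello!<|assistant|>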

exllamav2/architecture.py

Lines changed: 27 additions & 1 deletion

@@ -16,6 +16,10 @@
                                    ["post_feedforward_layernorm"]]
 layer_keys_internlm2_norms = [["attention_norm"],
                               ["ffn_norm"]]
+layer_keys_glm4_norms = [["input_layernorm"],
+                         ["post_self_attn_layernorm"],
+                         ["post_attention_layernorm"],
+                         ["post_mlp_layernorm"]]
 layer_keys_llama_attn = [["self_attn.q_proj"],
                          ["self_attn.k_proj"],
                          ["self_attn.v_proj"],
@@ -808,6 +812,28 @@ class Params:
             self.lm.expect_keys += \
                 expect_keys_llama

+        # GLM4
+
+        if arch_string == "Glm4ForCausalLM":
+            arch_recognized = True
+            self.lm.layer_keys += \
+                layer_keys_glm4_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_phi3_mlp
+            self.lm.expect_keys += \
+                expect_keys_llama
+            self.lm.supports_tp = True
+            self.lm.rope_style = RopeStyle.GPTJ
+            self.lm.keys.update({
+                "fused_mlp_12": "gate_up_proj",
+                "lm_head": "model.embed_tokens",
+                "norm_1": ".input_layernorm",
+                "norm_1_post": ".post_self_attn_layernorm",
+                "norm_2": ".post_attention_layernorm",
+                "norm_2_post": ".post_mlp_layernorm",
+            })
+            self.lm.attention_bias_qkv = True
+
         # Llama (default + fallback)

         if arch_string != "LlamaForCausalLM" and not arch_recognized:
@@ -825,7 +851,7 @@ class Params:

         # Arch overrides

-        if read_config.get("attention_bias", False):
+        if read_config.get("attention_bias", False) and not (self.lm.attention_bias_qkv or self.lm.attention_bias_o):
             self.lm.attention_bias_qkv = True
             self.lm.attention_bias_o = True
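Read together, the new entry describes GLM4 as a Llama-style decoder with two additional per-block layernorms (post-attention and post-MLP, on top of the usual input and pre-MLP norms), Q/K/V projections that carry a bias, a fused gate/up MLP tensor ("fused_mlp_12": "gate_up_proj"), GPT-J-style RoPE, and an output head read from the embedding weights ("lm_head": "model.embed_tokens"). The extra guard on the attention_bias override keeps the config flag from also switching on an output-projection bias for architectures, like this one, that already declare QKV-only bias. As a rough sketch, assuming the conventional model.layers.<i> checkpoint prefix and the o_proj / down_proj entries that are not visible in this hunk, the per-layer tensors this mapping expects would look like:

# Sketch only: per-layer tensor names implied by layer_keys_glm4_norms plus the
# llama attention keys and the fused gate_up projection. The "model.layers.{i}"
# prefix and the o_proj / down_proj entries are assumptions based on common HF
# checkpoints, not taken from this diff.

def glm4_layer_keys(i: int) -> list[str]:
    p = f"model.layers.{i}"
    return [
        f"{p}.input_layernorm",            # norm_1: before attention
        f"{p}.post_self_attn_layernorm",   # norm_1_post: after attention
        f"{p}.post_attention_layernorm",   # norm_2: before the MLP
        f"{p}.post_mlp_layernorm",         # norm_2_post: after the MLP
        f"{p}.self_attn.q_proj",           # Q/K/V carry a bias (attention_bias_qkv)
        f"{p}.self_attn.k_proj",
        f"{p}.self_attn.v_proj",
        f"{p}.self_attn.o_proj",           # assumed, completing the llama key set
        f"{p}.mlp.gate_up_proj",           # fused gate/up ("fused_mlp_12")
        f"{p}.mlp.down_proj",              # assumed, phi3-style fused MLP
    ]

for name in glm4_layer_keys(0):
    print(name)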