from ..modules import RMSNorm, Embedding, TransformerBlock, Attention, GatedMLP, Linear, BlockSparseMLP
from ..modules.attn import prepare_for_attn

+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from .glm4v_moe import Glm4VMoeConfig
+
class Glm4MoeConfig(Config):
    arch_string = "Glm4MoeForCausalLM"

@@ -58,15 +62,16 @@ class Glm4MoeModel(Model):

    def __init__(
        self,
-        config: Glm4MoeConfig,
+        config: Glm4MoeConfig | Glm4VMoeConfig,
+        key_prefix: str = "model",
        **kwargs
    ):
        super().__init__(config, **kwargs)

        self.modules += [
            Embedding(
                config = config,
-                key = "model.embed_tokens",
+                key = f"{key_prefix}.embed_tokens",
                vocab_size = config.vocab_size,
                hidden_size = config.hidden_size,
            )
@@ -77,15 +82,15 @@ def __init__(
            self.modules += [
                TransformerBlock(
                    config = config,
-                    key = f"model.layers.{idx}",
+                    key = f"{key_prefix}.layers.{idx}",
                    attn_norm = RMSNorm(
                        config = config,
-                        key = f"model.layers.{idx}.input_layernorm",
+                        key = f"{key_prefix}.layers.{idx}.input_layernorm",
                        rms_norm_eps = config.rms_norm_eps,
                    ),
                    attn = Attention(
                        config = config,
-                        key = f"model.layers.{idx}.self_attn",
+                        key = f"{key_prefix}.layers.{idx}.self_attn",
                        layer_idx = idx,
                        hidden_size = config.hidden_size,
                        head_dim = config.head_dim,
@@ -100,25 +105,25 @@ def __init__(
                        qmap = "block.attn",
                        q_norm = RMSNorm(
                            config = config,
-                            key = f"model.layers.{idx}.self_attn.q_norm",
+                            key = f"{key_prefix}.layers.{idx}.self_attn.q_norm",
                            rms_norm_eps = config.rms_norm_eps,
                        ) if config.use_qk_norm else None,
                        k_norm = RMSNorm(
                            config = config,
-                            key = f"model.layers.{idx}.self_attn.k_norm",
+                            key = f"{key_prefix}.layers.{idx}.self_attn.k_norm",
                            rms_norm_eps = config.rms_norm_eps,
                        ) if config.use_qk_norm else None,
                        out_dtype = torch.float
                    ),
                    mlp_norm = RMSNorm(
                        config = config,
-                        key = f"model.layers.{idx}.post_attention_layernorm",
+                        key = f"{key_prefix}.layers.{idx}.post_attention_layernorm",
                        rms_norm_eps = config.rms_norm_eps,
                    ),
                    mlp = (
                        GatedMLP(
                            config = config,
-                            key = f"model.layers.{idx}.mlp",
+                            key = f"{key_prefix}.layers.{idx}.mlp",
                            hidden_size = config.hidden_size,
                            intermediate_size = config.intermediate_size,
                            key_up = "up_proj",
@@ -131,7 +136,7 @@ def __init__(
                        if idx < config.first_k_dense_replace else
                        BlockSparseMLP(
                            config = config,
-                            key = f"model.layers.{idx}.mlp",
+                            key = f"{key_prefix}.layers.{idx}.mlp",
                            hidden_size = config.hidden_size,
                            intermediate_size = config.moe_intermediate_size,
                            num_experts = config.num_experts,
@@ -150,7 +155,7 @@ def __init__(
                            topk_group = 1,
                            shared_experts = GatedMLP(
                                config = config,
-                                key = f"model.layers.{idx}.mlp.shared_experts",
+                                key = f"{key_prefix}.layers.{idx}.mlp.shared_experts",
                                hidden_size = config.hidden_size,
                                intermediate_size = config.moe_intermediate_size * config.num_shared_experts,
                                key_up = "up_proj",
@@ -170,12 +175,12 @@ def __init__(

        head_alt_key = None
        if config.tie_word_embeddings and not self.config.stc.has_tensor("lm_head"):
-            head_alt_key = "model.embed_tokens"
+            head_alt_key = f"{key_prefix}.embed_tokens"

        self.modules += [
            RMSNorm(
                config = config,
-                key = "model.norm",
+                key = f"{key_prefix}.norm",
                rms_norm_eps = config.rms_norm_eps,
                out_dtype = torch.half,
            ),
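
The diff keeps the default behavior intact (`key_prefix = "model"`) while letting another model definition reuse the same GLM-4 MoE decoder stack under a different weight namespace. A minimal sketch of how that could look, assuming a hypothetical `Glm4VMoeModel` subclass and assuming the multimodal checkpoint stores its text decoder under a `model.language_model` prefix (both the class layout and the prefix are illustrative, not taken from this diff):

```python
# Illustrative sketch only: the class name, import path, and the
# "model.language_model" prefix are assumptions, not part of this commit.
from .glm4_moe import Glm4MoeModel

class Glm4VMoeModel(Glm4MoeModel):

    def __init__(
        self,
        config,  # would be a Glm4VMoeConfig in practice
        **kwargs
    ):
        # Reuse the GLM-4 MoE decoder as-is, but resolve tensor keys under the
        # (assumed) "model.language_model.*" namespace of the checkpoint.
        super().__init__(config, key_prefix = "model.language_model", **kwargs)
```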