@@ -59,86 +59,106 @@ def get_q_module(job, module):
 @torch.inference_mode()
 def compile_model(job, save_fn, model):

+    cfg = model.config
     out_dict = {}
     current_size = 0
     file_index = 1
     index = 0
     shard_bytes = job["shard_size"] * 1024 ** 2

-    while index < len(model.modules):
+    extra_tensors = []
+    if cfg.arch.mmp_prefix:
+        extra_tensors += [k for k in cfg.tensor_file_map.keys() if k.startswith(cfg.arch.mmp_prefix)]
+    if cfg.arch.vt_prefix:
+        extra_tensors += [k for k in cfg.tensor_file_map.keys() if k.startswith(cfg.arch.vt_prefix)]
+    extra_tensors_size = 0

-        module = model.modules[index]
+    while index < len(model.modules) or len(extra_tensors):

-        if isinstance(module, ExLlamaV2Embedding):
+        if index < len(model.modules):

-            d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)
+            module = model.modules[index]

-        if isinstance(module, ExLlamaV2PosEmbedding):
+            if isinstance(module, ExLlamaV2Embedding):

-            d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)

-        if isinstance(module, ExLlamaV2Attention):
+            if isinstance(module, ExLlamaV2PosEmbedding):

-            d = get_f_module(job, module.pre_layernorm)
-            if d: out_dict.update(d); current_size += _dsize(d)
-            d = get_f_module(job, module.post_layernorm)
-            if d: out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.q_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.k_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.v_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.o_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)

-        if isinstance(module, ExLlamaV2MLP):
+            if isinstance(module, ExLlamaV2Attention):

-            has_gate = model.config.arch.lm.mlp_gate
-            d = get_f_module(job, module.pre_layernorm)
-            if d: out_dict.update(d); current_size += _dsize(d)
-            d = get_f_module(job, module.post_layernorm)
-            if d: out_dict.update(d); current_size += _dsize(d)
-            if has_gate: d = get_q_module(job, module.gate_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.up_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.down_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module.pre_layernorm)
+                if d: out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module.post_layernorm)
+                if d: out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.q_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.k_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.v_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.o_proj); out_dict.update(d); current_size += _dsize(d)

-        if isinstance(module, ExLlamaV2MoEMLP):
+            if isinstance(module, ExLlamaV2MLP):

-            d = get_f_module(job, module.post_attention_layernorm); out_dict.update(d); current_size += _dsize(d)
-            d = get_f_module(job, module.gate); out_dict.update(d); current_size += _dsize(d)
-            for i in range(model.config.num_experts):
-                d = get_q_module(job, module.w1[i]); out_dict.update(d); current_size += _dsize(d)
-                d = get_q_module(job, module.w3[i]); out_dict.update(d); current_size += _dsize(d)
-                d = get_q_module(job, module.w2[i]); out_dict.update(d); current_size += _dsize(d)
+                has_gate = model.config.arch.lm.mlp_gate
+                d = get_f_module(job, module.pre_layernorm)
+                if d: out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module.post_layernorm)
+                if d: out_dict.update(d); current_size += _dsize(d)
+                if has_gate: d = get_q_module(job, module.gate_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.up_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.down_proj); out_dict.update(d); current_size += _dsize(d)

-        if isinstance(module, ExLlamaV2ParallelDecoder):
+            if isinstance(module, ExLlamaV2MoEMLP):

-            has_gate = model.config.arch.lm.mlp_gate
-            has_qk_norm = model.config.use_qk_norm
-            d = get_f_module(job, module.input_layernorm); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.attn.q_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.attn.k_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.attn.v_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.attn.o_proj); out_dict.update(d); current_size += _dsize(d)
-            if has_qk_norm:
-                d = get_f_module(job, module.attn.q_norm); out_dict.update(d); current_size += _dsize(d)
-                d = get_f_module(job, module.attn.k_norm); out_dict.update(d); current_size += _dsize(d)
-            if has_gate:
-                d = get_q_module(job, module.mlp.gate_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.mlp.up_proj); out_dict.update(d); current_size += _dsize(d)
-            d = get_q_module(job, module.mlp.down_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module.post_attention_layernorm); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module.gate); out_dict.update(d); current_size += _dsize(d)
+                for i in range(model.config.num_experts):
+                    d = get_q_module(job, module.w1[i]); out_dict.update(d); current_size += _dsize(d)
+                    d = get_q_module(job, module.w3[i]); out_dict.update(d); current_size += _dsize(d)
+                    d = get_q_module(job, module.w2[i]); out_dict.update(d); current_size += _dsize(d)

-        if isinstance(module, ExLlamaV2RMSNorm) or isinstance(module, ExLlamaV2LayerNorm):
+            if isinstance(module, ExLlamaV2ParallelDecoder):

-            d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)
+                has_gate = model.config.arch.lm.mlp_gate
+                has_qk_norm = model.config.use_qk_norm
+                d = get_f_module(job, module.input_layernorm); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.attn.q_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.attn.k_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.attn.v_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.attn.o_proj); out_dict.update(d); current_size += _dsize(d)
+                if has_qk_norm:
+                    d = get_f_module(job, module.attn.q_norm); out_dict.update(d); current_size += _dsize(d)
+                    d = get_f_module(job, module.attn.k_norm); out_dict.update(d); current_size += _dsize(d)
+                if has_gate:
+                    d = get_q_module(job, module.mlp.gate_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.mlp.up_proj); out_dict.update(d); current_size += _dsize(d)
+                d = get_q_module(job, module.mlp.down_proj); out_dict.update(d); current_size += _dsize(d)

-        if isinstance(module, ExLlamaV2Linear):
+            if isinstance(module, ExLlamaV2RMSNorm) or isinstance(module, ExLlamaV2LayerNorm):

-            assert module.key == "lm_head"
-            d = get_q_module(job, module); out_dict.update(d); current_size += _dsize(d)
+                d = get_f_module(job, module); out_dict.update(d); current_size += _dsize(d)

-        index += 1
+            if isinstance(module, ExLlamaV2Linear):
+
+                assert module.key == cfg.arch.lm_prefix + "lm_head"
+                d = get_q_module(job, module); out_dict.update(d); current_size += _dsize(d)
+
+            index += 1
+
+        else:
+
+            key = extra_tensors[0]
+            extra_tensors = extra_tensors[1:]
+            file = cfg.tensor_file_map[key]
+            with safe_open(file, framework = "pt") as f:
+                tensor = f.get_tensor(key)
+            out_dict.update({key: tensor})
+            extra_tensors_size += _tsize(tensor)

         # Save shard

-        if current_size > shard_bytes or index == len(model.modules):
+        if current_size > shard_bytes or (index == len(model.modules) and len(extra_tensors) == 0):

             print_stage(job, "Compiling", index, len(model.modules))

@@ -175,7 +195,7 @@ def compile_model(job, save_fn, model):

             out_dict = dont_save_dict

-            if index == len(model.modules) and len(out_dict) > 0:
+            if index == len(model.modules) and len(extra_tensors) == 0 and len(out_dict) > 0:
                 save_dict = dont_save_dict
                 dont_save_dict = {}
                 continue
@@ -203,6 +223,9 @@ def compile_model(job, save_fn, model):
             filesize = os.path.getsize(final_filename) // (1024 ** 2)
             print(f" -- {final_filename} ({filesize:,} MB)")

+    if extra_tensors_size:
+        print(f" -- Tensors copied (MM components): {extra_tensors_size // (1024 ** 2):,} MB")
+
     # Copy all non-tensor files from the model's directory if compiling a full model

     if job["compile_full"] is not None:
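Note on the new `else:` branch above: tensor names under the multimodal projector (`mmp_prefix`) and vision tower (`vt_prefix`) prefixes are never quantized; they are read back verbatim from the source .safetensors shards, interleaved into the output shards alongside the quantized language-model modules, and their size is tracked separately via `extra_tensors_size`. Below is a rough standalone sketch of that pass-through idea, assuming only the `safetensors` and `torch` packages; the helper name, the `prefixes` argument and the `tensor_file_map` layout are illustrative and not part of this patch.

from safetensors import safe_open

def collect_extra_tensors(tensor_file_map: dict, prefixes: list) -> dict:
    # Copy non-quantized tensors (e.g. vision tower / multimodal projector) unchanged.
    out, copied_bytes = {}, 0
    keys = [k for k in tensor_file_map if any(p and k.startswith(p) for p in prefixes)]
    for key in keys:
        with safe_open(tensor_file_map[key], framework = "pt") as f:
            tensor = f.get_tensor(key)  # loaded as-is, no quantization applied
        out[key] = tensor
        copied_bytes += tensor.numel() * tensor.element_size()
    print(f" -- copied {copied_bytes // (1024 ** 2):,} MB of multimodal tensors")
    return out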