@@ -100,11 +100,10 @@ def convert_hf_checkpoint(
     bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
-        dim = config.dim
         return (
-            w.view(n_heads, 2, config.head_dim // 2, dim)
+            w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
             .transpose(1, 2)
-            .reshape(config.head_dim * n_heads, dim)
+            .reshape(w.shape)
         )
 
     merged_result = {}
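
The change above generalizes `permute` so it no longer assumes a 2-D weight of shape `(head_dim * n_heads, dim)`: the trailing dimensions now come from `w.shape[1:]` and the result is reshaped back to `w.shape`, so 1-D bias tensors go through the same reordering as weights. A minimal sketch of the idea, with `head_dim` passed explicitly instead of read from `config` (a simplification for the example):

```python
import torch

def permute(w, n_heads, head_dim):
    # Trailing dims are taken from w itself, so the same code handles both
    # (out_features, in_features) weights and (out_features,) biases.
    return (
        w.view(n_heads, 2, head_dim // 2, *w.shape[1:])
        .transpose(1, 2)
        .reshape(w.shape)
    )

n_heads, head_dim, dim = 4, 8, 32
wq_weight = torch.randn(n_heads * head_dim, dim)  # 2-D weight
wq_bias = torch.randn(n_heads * head_dim)         # 1-D bias

assert permute(wq_weight, n_heads, head_dim).shape == wq_weight.shape
assert permute(wq_bias, n_heads, head_dim).shape == wq_bias.shape
```
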
@@ -137,6 +136,7 @@ def load_safetensors():
                 continue
         assert state_dict is not None, f"Unable to load tensors from {file}"
         merged_result.update(state_dict)
+
     final_result = {}
     for key, value in merged_result.items():
         if "layers" in key:
@@ -150,16 +150,20 @@ def load_safetensors():
         final_result[new_key] = value
 
     for key in tuple(final_result.keys()):
-        if "wq.weight" in key:
+        if "wq.weight" in key or "wq.bias" in key:
+            wk_key = key.replace("wq", "wk")
+            wv_key = key.replace("wq", "wv")
             q = final_result[key]
-            k = final_result[key.replace("wq", "wk")]
-            v = final_result[key.replace("wq", "wv")]
+            k = final_result[wk_key]
+            v = final_result[wv_key]
+            print(key)
             q = permute(q, config.n_heads)
+            print(wk_key)
             k = permute(k, config.n_local_heads)
             final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
             del final_result[key]
-            del final_result[key.replace("wq", "wk")]
-            del final_result[key.replace("wq", "wv")]
+            del final_result[wk_key]
+            del final_result[wv_key]
     print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
     torch.save(final_result, model_dir / "model.pth")
     print("Done.")
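
Taken together with the `permute` change, the fusion loop now also matches `wq.bias` keys, so attention biases get concatenated into the fused `wqkv` tensor alongside the weights. A small self-contained sketch of the resulting behavior, using made-up shapes and a miniature state dict whose key names only mirror the `attention.wq/wk/wv` convention:

```python
import torch

n_heads, n_local_heads, head_dim, dim = 4, 2, 8, 32

def permute(w, n_heads):
    # Same generalized permute as above; head_dim is a module-level value here.
    return (
        w.view(n_heads, 2, head_dim // 2, *w.shape[1:])
        .transpose(1, 2)
        .reshape(w.shape)
    )

final_result = {
    "layers.0.attention.wq.weight": torch.randn(n_heads * head_dim, dim),
    "layers.0.attention.wk.weight": torch.randn(n_local_heads * head_dim, dim),
    "layers.0.attention.wv.weight": torch.randn(n_local_heads * head_dim, dim),
    "layers.0.attention.wq.bias": torch.randn(n_heads * head_dim),
    "layers.0.attention.wk.bias": torch.randn(n_local_heads * head_dim),
    "layers.0.attention.wv.bias": torch.randn(n_local_heads * head_dim),
}

# Fuse q/k/v into a single wqkv tensor, for weights and biases alike.
for key in tuple(final_result.keys()):
    if "wq.weight" in key or "wq.bias" in key:
        wk_key = key.replace("wq", "wk")
        wv_key = key.replace("wq", "wv")
        q = permute(final_result[key], n_heads)
        k = permute(final_result[wk_key], n_local_heads)
        v = final_result[wv_key]
        final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
        del final_result[key], final_result[wk_key], final_result[wv_key]

assert final_result["layers.0.attention.wqkv.weight"].shape == (
    (n_heads + 2 * n_local_heads) * head_dim, dim)
assert final_result["layers.0.attention.wqkv.bias"].shape == (
    (n_heads + 2 * n_local_heads) * head_dim,)
```
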