@@ -3,7 +3,6 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import glob
 import json
 import os
 import re
@@ -42,12 +41,7 @@ def convert_hf_checkpoint(
4241 print (f"Model config { config .__dict__ } " )
4342
4443 # Load the json file containing weight mapping
45- model_map_json_matches = [Path (m ) for m in glob .glob (str (model_dir / "*.index.json" ))]
46- assert len (model_map_json_matches ) <= 1 , "Found multiple weight mapping files"
47- if len (model_map_json_matches ):
48- model_map_json = model_map_json_matches [0 ]
49- else :
50- model_map_json = model_dir / "pytorch_model.bin.index.json"
44+ model_map_json = model_dir / "pytorch_model.bin.index.json"
5145
5246 # If there is no weight mapping, check for a consolidated model and
5347 # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
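
Note on this hunk: the removed glob also matched `model.safetensors.index.json`, while the replacement pins the sharded-checkpoint index to the classic PyTorch name. A minimal sketch of how such an index file drives the shard list (the directory and shard names below are hypothetical, not from a specific checkpoint):

```python
import json
from pathlib import Path

model_dir = Path("checkpoints/some-org/some-model")  # hypothetical path
model_map_json = model_dir / "pytorch_model.bin.index.json"

with open(model_map_json) as f:
    bin_index = json.load(f)

# "weight_map" maps each parameter name to the shard file that stores it,
# e.g. {"model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin", ...},
# so this set is exactly the shards the converter must read.
bin_files = {model_dir / shard for shard in bin_index["weight_map"].values()}
```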
@@ -62,9 +56,10 @@ def convert_hf_checkpoint(
                 str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
             )
             del loaded_result  # No longer needed
-            print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
-            os.rename(consolidated_pth, model_dir / "model.pth")
-            os.rename(tokenizer_pth, model_dir / "tokenizer.model")
+            print(f"Symlinking checkpoint to {model_dir / 'model.pth'}.")
+            consolidated_pth = os.path.realpath(consolidated_pth)
+            os.symlink(consolidated_pth, model_dir / "model.pth")
+            os.symlink(tokenizer_pth, model_dir / "tokenizer.model")
             print("Done.")
             return
         else:
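
Resolving `consolidated_pth` with `os.path.realpath` before linking matters when the checkpoint file is itself a symlink (as in a Hugging Face hub cache), so `model.pth` points at the real blob rather than at another link. One caveat the hunk does not address: `os.symlink` raises `FileExistsError` on re-runs. A sketch of a re-runnable variant; the replace-if-stale behavior is an assumption here, not part of this commit:

```python
import os
from pathlib import Path

def symlink_into(src: Path, dst: Path) -> None:
    # Resolve chains of links first, mirroring the realpath call above.
    target = Path(os.path.realpath(src))
    # Assumption: replacing a stale link is acceptable; the committed
    # code would instead fail if dst already exists.
    if dst.is_symlink() or dst.exists():
        dst.unlink()
    os.symlink(target, dst)
```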
@@ -81,17 +76,10 @@ def convert_hf_checkpoint(
8176 "model.layers.{}.self_attn.k_proj.weight" : "layers.{}.attention.wk.weight" ,
8277 "model.layers.{}.self_attn.v_proj.weight" : "layers.{}.attention.wv.weight" ,
8378 "model.layers.{}.self_attn.o_proj.weight" : "layers.{}.attention.wo.weight" ,
84- "model.layers.{}.self_attn.q_proj.bias" : "layers.{}.attention.wq.bias" ,
85- "model.layers.{}.self_attn.k_proj.bias" : "layers.{}.attention.wk.bias" ,
86- "model.layers.{}.self_attn.v_proj.bias" : "layers.{}.attention.wv.bias" ,
87- "model.layers.{}.self_attn.o_proj.bias" : "layers.{}.attention.wo.bias" ,
8879 "model.layers.{}.self_attn.rotary_emb.inv_freq" : None ,
8980 "model.layers.{}.mlp.gate_proj.weight" : "layers.{}.feed_forward.w1.weight" ,
9081 "model.layers.{}.mlp.up_proj.weight" : "layers.{}.feed_forward.w3.weight" ,
9182 "model.layers.{}.mlp.down_proj.weight" : "layers.{}.feed_forward.w2.weight" ,
92- "model.layers.{}.mlp.gate_proj.bias" : "layers.{}.feed_forward.w1.bias" ,
93- "model.layers.{}.mlp.up_proj.bias" : "layers.{}.feed_forward.w3.bias" ,
94- "model.layers.{}.mlp.down_proj.bias" : "layers.{}.feed_forward.w2.bias" ,
9583 "model.layers.{}.input_layernorm.weight" : "layers.{}.attention_norm.weight" ,
9684 "model.layers.{}.post_attention_layernorm.weight" : "layers.{}.ffn_norm.weight" ,
9785 "model.norm.weight" : "norm.weight" ,
@@ -100,43 +88,19 @@ def convert_hf_checkpoint(
     bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
+        dim = config.dim
         return (
-            w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
+            w.view(n_heads, 2, config.head_dim // 2, dim)
             .transpose(1, 2)
-            .reshape(w.shape)
+            .reshape(config.head_dim * n_heads, dim)
         )
 
     merged_result = {}
     for file in sorted(bin_files):
-
-        # The state_dict can be loaded from either a torch zip file or
-        # safetensors. We take our best guess from the name and try all
-        # possibilities
-        load_pt_mmap = lambda: torch.load(
+        state_dict = torch.load(
             str(file), map_location="cpu", mmap=True, weights_only=True
         )
-        load_pt_no_mmap = lambda: torch.load(
-            str(file), map_location="cpu", mmap=False, weights_only=True
-        )
-        def load_safetensors():
-            import safetensors.torch
-            with open(file, "rb") as handle:
-                return safetensors.torch.load(handle.read())
-        if "safetensors" in str(file):
-            loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
-        else:
-            loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]
-
-        state_dict = None
-        for loader in loaders:
-            try:
-                state_dict = loader()
-                break
-            except Exception:
-                continue
-        assert state_dict is not None, f"Unable to load tensors from {file}"
         merged_result.update(state_dict)
-
     final_result = {}
     for key, value in merged_result.items():
         if "layers" in key:
@@ -152,18 +116,16 @@ def load_safetensors():
         final_result[new_key] = value
 
     for key in tuple(final_result.keys()):
-        if "wq.weight" in key or "wq.bias" in key:
-            wk_key = key.replace("wq", "wk")
-            wv_key = key.replace("wq", "wv")
+        if "wq" in key:
             q = final_result[key]
-            k = final_result[wk_key]
-            v = final_result[wv_key]
+            k = final_result[key.replace("wq", "wk")]
+            v = final_result[key.replace("wq", "wv")]
             q = permute(q, config.n_heads)
             k = permute(k, config.n_local_heads)
             final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
             del final_result[key]
-            del final_result[wk_key]
-            del final_result[wv_key]
+            del final_result[key.replace("wq", "wk")]
+            del final_result[key.replace("wq", "wv")]
     print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
     torch.save(final_result, model_dir / "model.pth")
     print("Done.")
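
The fused `wqkv` tensor concatenates the (permuted) Q, K, and V projections along the output dimension, so the model can run one matmul instead of three. How it is consumed is outside this diff, but a sketch of the inverse split, assuming the model slices on head counts:

```python
import torch

def split_wqkv(wqkv, n_heads, n_local_heads, head_dim):
    # Sizes follow directly from the torch.cat([q, k, v]) order above;
    # n_local_heads < n_heads corresponds to grouped-query attention.
    q_size = n_heads * head_dim
    kv_size = n_local_heads * head_dim
    q, k, v = torch.split(wqkv, [q_size, kv_size, kv_size], dim=0)
    return q, k, v
```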
@@ -184,10 +146,10 @@ def convert_hf_checkpoint_to_tune(
     consolidated_pth = model_dir / "original" / "consolidated.pth"
     tokenizer_pth = model_dir / "original" / "tokenizer.model"
     if consolidated_pth.is_file() and tokenizer_pth.is_file():
-        print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
-        os.rename(consolidated_pth, model_dir / "model.pth")
-        print(f"Moving tokenizer to {model_dir / 'tokenizer.model'}.")
-        os.rename(tokenizer_pth, model_dir / "tokenizer.model")
+        print(f"Creating symlink from {consolidated_pth} to {model_dir / 'model.pth'}.")
+        os.symlink(consolidated_pth, model_dir / "model.pth")
+        print(f"Creating symlink from {tokenizer_pth} to {model_dir / 'tokenizer.model'}.")
+        os.symlink(tokenizer_pth, model_dir / "tokenizer.model")
         print("Done.")
     else:
         raise RuntimeError(f"Could not find {consolidated_pth}")