@@ -41,20 +41,14 @@ def convert_hf_checkpoint(
     config = TransformerArgs.from_params(config_args)
     print(f"Model config {config.__dict__}")
 
-    # Load the json file containing weight mapping
+    # Find all candidate weight mapping index files
     model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    if "mistral" not in model_name:
-        assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
     # Llama 3 has a consolidated model and tokenizer.
     # Otherwise raise an error.
-    if not model_map_json.is_file():
+    if not model_map_json_matches:
         consolidated_pth = model_dir / "original" / "consolidated.00.pth"
         tokenizer_pth = model_dir / "original" / "tokenizer.model"
         if consolidated_pth.is_file() and tokenizer_pth.is_file():
@@ -70,11 +64,29 @@ def convert_hf_checkpoint(
             return
         else:
             raise RuntimeError(
-                f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}"
+                f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}"
             )
 
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
+    # Load the json file(s) containing weight mapping
+    #
+    # NOTE: If there are multiple index files, there are two possibilities:
+    #   1. The files could be mapped to different weight format files (e.g. .bin
+    #      vs .safetensors)
+    #   2. The files could be split subsets of the mappings that need to be
+    #      merged
+    #
+    # In either case, we can simply keep the mappings where the target file is
+    # valid in the model dir.
+    bin_files = {}
+    for weight_map_file in model_map_json_matches:
+        with open(weight_map_file, "r") as handle:
+            weight_map = json.load(handle)
+        valid_mappings = {
+            k: model_dir / v
+            for (k, v) in weight_map.get("weight_map", {}).items()
+            if (model_dir / v).is_file()
+        }
+        bin_files.update(valid_mappings)
 
     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
@@ -98,7 +110,6 @@ def convert_hf_checkpoint(
         "model.norm.weight": "norm.weight",
         "lm_head.weight": "output.weight",
     }
-    bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
         return (
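As a rough, self-contained illustration of the merging behaviour introduced above (not part of the patch; the index and shard file names below are hypothetical stand-ins for the usual Hugging Face conventions), the loop keeps only the mappings whose target file actually exists in the model dir and merges them across all index files:

import glob
import json
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    model_dir = Path(tmp)
    # Only the safetensors shard actually exists on disk.
    (model_dir / "model-00001-of-00001.safetensors").touch()
    (model_dir / "model.safetensors.index.json").write_text(
        json.dumps({"weight_map": {"model.norm.weight": "model-00001-of-00001.safetensors"}})
    )
    (model_dir / "pytorch_model.bin.index.json").write_text(
        json.dumps({"weight_map": {"model.norm.weight": "pytorch_model-00001-of-00001.bin"}})
    )

    # Same filtering as the new loop in the patch: drop mappings whose target
    # file is missing, keep the rest from every *.index.json found.
    bin_files = {}
    for weight_map_file in [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]:
        with open(weight_map_file, "r") as handle:
            weight_map = json.load(handle)
        bin_files.update(
            {
                k: model_dir / v
                for (k, v) in weight_map.get("weight_map", {}).items()
                if (model_dir / v).is_file()
            }
        )

    print(bin_files)
    # -> {'model.norm.weight': PosixPath('.../model-00001-of-00001.safetensors')}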