@@ -39,20 +39,14 @@ def convert_hf_checkpoint(
     config = TransformerArgs.from_params(config_args)
     print(f"Model config {config.__dict__}")
 
-    # Load the json file containing weight mapping
+    # Find all candidate weight mapping index files
     model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    if "mistral" not in model_name:
-        assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
     # Llama 3 has a consolidated model and tokenizer.
     # Otherwise raise an error.
-    if not model_map_json.is_file():
+    if not model_map_json_matches:
         consolidated_pth = model_dir / "original" / "consolidated.00.pth"
         tokenizer_pth = model_dir / "original" / "tokenizer.model"
         if consolidated_pth.is_file() and tokenizer_pth.is_file():
@@ -69,11 +63,29 @@ def convert_hf_checkpoint(
             return
         else:
             raise RuntimeError(
-                f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}"
+                f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}"
             )
 
-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
+    # Load the json file(s) containing weight mapping
+    #
+    # NOTE: If there are multiple index files, there are two possibilities:
+    #   1. The files could be mapped to different weight format files (e.g. .bin
+    #      vs .safetensors)
+    #   2. The files could be split subsets of the mappings that need to be
+    #      merged
+    #
+    # In either case, we can simply keep the mappings where the target file is
+    # valid in the model dir.
+    bin_files = {}
+    for weight_map_file in model_map_json_matches:
+        with open(weight_map_file, "r") as handle:
+            weight_map = json.load(handle)
+        valid_mappings = {
+            k: model_dir / v
+            for (k, v) in weight_map.get("weight_map", {}).items()
+            if (model_dir / v).is_file()
+        }
+        bin_files.update(valid_mappings)
 
     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
@@ -97,7 +109,6 @@ def convert_hf_checkpoint(
         "model.norm.weight": "norm.weight",
         "lm_head.weight": "output.weight",
     }
-    bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
         return (
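Outside the diff, here is a minimal, self-contained sketch (not part of the commit) of how the new merge behaves when a model directory ships more than one index file. The shard and index file names are hypothetical; the point is that only mappings whose target shard actually exists on disk survive the merge.

# Illustrative only: hypothetical shard/index names, not taken from the commit.
import json
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    model_dir = Path(tmp)
    # Two index files, as Hugging Face writes for sharded checkpoints.
    (model_dir / "pytorch_model.bin.index.json").write_text(json.dumps(
        {"weight_map": {"model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin"}}
    ))
    (model_dir / "model.safetensors.index.json").write_text(json.dumps(
        {"weight_map": {"model.embed_tokens.weight": "model-00001-of-00002.safetensors"}}
    ))
    # Only the safetensors shard was actually downloaded.
    (model_dir / "model-00001-of-00002.safetensors").touch()

    bin_files = {}
    for weight_map_file in model_dir.glob("*.index.json"):
        weight_map = json.loads(weight_map_file.read_text())
        bin_files.update({
            k: model_dir / v
            for k, v in weight_map.get("weight_map", {}).items()
            if (model_dir / v).is_file()
        })

    # The .bin mapping is dropped because its target file is missing;
    # only the safetensors entry remains.
    print(bin_files)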