1919from torchchat .cli .builder import BuilderArgs , _load_checkpoint
2020
2121
22- _DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
22+ _DEFAULT_SAFETENSOR_INDEX = "model.safetensors.index.json"
23+ _DEFAULT_BIN_INDEX = "pytorch_model.bin.index.json"
2324_CONFIG_NAME = "config.json"
2425
2526
@@ -81,31 +82,6 @@ def get_hf_path_from_model_id(model_id: str) -> str:
8182 return file_location
8283
8384
84- def get_hf_weight_map_and_path (
85- model_id : str ,
86- ) -> Tuple [Dict [str , str ], str , Dict [str , str ]]:
87- """Get the weight map for a given HF model id and also the cache path for loading the weights"""
88- index_file = cached_file (model_id , _DEFAULT_SAFETENSOR_FILE_NAME )
89- if not os .path .exists (index_file ):
90- raise FileNotFoundError (
91- f"Weight index file for { model_id } does not exist in HF cache."
92- )
93- logger .info (
94- f"Loading weight map from: { index_file } "
95- )
96- weight_map = read_weights_from_json (index_file )
97- if weight_map is None :
98- raise ValueError (f"Weight map not found in config file { index_file } " )
99- weight_map , new_to_old_keymap = remap_weight_keys (weight_map )
100- weight_path = os .path .dirname (index_file )
101- if not os .path .exists (weight_path ):
102- raise FileNotFoundError (f"Weight path { weight_path } does not exist" )
103- logger .info (
104- f"Loading weights from: { weight_path } "
105- )
106- return weight_map , weight_path , new_to_old_keymap
107-
108-
10985def remap_weight_keys (dictionary ):
11086 """Remap the keys of a dictionary to match the expected format of the tune model."""
11187 # hf_key : dist_model_key
@@ -141,12 +117,13 @@ def remap_weight_keys(dictionary):
141117 return new_dict , key_mapping
142118
143119
144- def load_safetensor_weights (
120+ def load_weights_per_map (
145121 stage_module : Module ,
146122 weight_map : Dict [str , str ],
147123 file_location : str ,
148124 new_to_old_keymap : Dict [str , str ],
149- device : torch .device = "cuda" ,
125+ device : torch .device ,
126+ is_safetensor : bool ,
150127 purge_model_prefix : bool = True ,
151128 ignore_cache_layers : bool = True ,
152129 model_config : Optional [Dict ] = None ,
@@ -160,6 +137,7 @@ def load_safetensor_weights(
160137 file_location (str): Directory containing the weight files.
161138 new_to_old_keymap (Dict[str, str]): Mapping of new parameter names to old ones.
162139 device (torch.device): The device to load tensors onto.
140+ is_safetensor (bool): Whether the files are safetensors.
163141 purge_model_prefix (bool): Whether to remove 'model.' prefix from keys.
164142 ignore_cache_layers (bool): Whether to ignore cache layers when reporting missing keys.
165143 model_config (Optional[Dict]): Model configuration.
@@ -178,9 +156,13 @@ def load_safetensor_weights(
178156 for file in needed_files :
179157 full_path = os .path .join (file_location , file )
180158 # logger.info(f"Loading checkpoint file: {full_path}")
181- try :
182- checkpoint = load_safetensor_file (full_path , "cpu" ) # device)
159+ # TODO: directly load to device
160+ if is_safetensor :
161+ checkpoint = load_safetensor_file (full_path )
162+ else :
163+ checkpoint = torch .load (full_path , mmap = True , weights_only = True )
183164
165+ try :
184166 update_state_dict (
185167 stage_state_dict ,
186168 checkpoint ,
@@ -189,10 +171,9 @@ def load_safetensor_weights(
189171 new_to_old_keymap = new_to_old_keymap ,
190172 updated_states = updated_states ,
191173 )
192- except FileNotFoundError :
193- logger .error (f"File not found: { full_path } " )
194174 except Exception as e :
195- logger .error (f"Error during checkpoint processing of { full_path } : { str (e )} " )
175+ logger .error (f"Error during checkpoint processing:" )
176+ raise e
196177
197178 missing_keys = handle_missing_keys (
198179 stage_state_dict , updated_states , ignore_cache_layers
@@ -244,12 +225,14 @@ def get_needed_files(
244225 return needed_files
245226
246227
def load_safetensor_file(
    full_path: str,
    device: str = "cpu",
) -> Dict[str, torch.Tensor]:
    """Load every tensor from one safetensors file into a name->tensor dict.

    Args:
        full_path: Path of the ``.safetensors`` file to read.
        device: Device string passed to ``safe_open`` (defaults to ``"cpu"``).

    Returns:
        Dict mapping each tensor name in the file to its loaded tensor.
    """
    # safe_open is a context manager over the file; each tensor is fetched
    # individually by key via get_tensor.
    with safe_open(full_path, framework="pt", device=device) as reader:
        return {name: reader.get_tensor(name) for name in reader.keys()}
254237
255238
@@ -378,15 +361,35 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
378361 files), and fill into `stage_module`. Model config is needed b/c we permute
379362 wq and wk weights based on attn heads.
380363 """
364+ # Get the weight map for a given HF model id
365+ try :
366+ index_file = cached_file (distribution , _DEFAULT_SAFETENSOR_INDEX )
367+ is_safetensor = True
368+ except :
369+ index_file = cached_file (distribution , _DEFAULT_BIN_INDEX )
370+ is_safetensor = False
371+ logger .info (f"Loading weight map from: { index_file } " )
372+
373+ # Read the weight map from the index file
374+ weight_map = read_weights_from_json (index_file )
375+ if weight_map is None :
376+ raise ValueError (f"Weight map not found in config file { index_file } " )
377+
378+ # Remap the FQNs to the FQNs in HF checkpoints
379+ weight_map , new_to_old_keymap = remap_weight_keys (weight_map )
381380
382- weight_map , weight_path , key_map = get_hf_weight_map_and_path (distribution )
381+ # Get the dir containing the weight files
382+ weight_dir = os .path .dirname (index_file )
383+ logger .info (f"Loading weights from: { weight_dir } " )
383384
384- num_loaded_weights , num_missing_weights = load_safetensor_weights (
385+ # Load the weights into the stage module
386+ num_loaded_weights , num_missing_weights = load_weights_per_map (
385387 stage_module ,
386388 weight_map ,
387- weight_path ,
388- key_map ,
389+ weight_dir ,
390+ new_to_old_keymap ,
389391 device ,
392+ is_safetensor ,
390393 model_config = model_config ,
391394 )
392395 logger .info (
0 commit comments