@@ -118,14 +118,12 @@ def remap_weight_keys(dictionary):


 def load_weights_per_map(
-    stage_module: Module,
+    stage_state_dict,
     weight_map: Dict[str, str],
     file_location: str,
     new_to_old_keymap: Dict[str, str],
     device: torch.device,
     is_safetensor: bool,
-    purge_model_prefix: bool = True,
-    ignore_cache_layers: bool = True,
     model_config: Optional[Dict] = None,
 ) -> Tuple[int, int]:
     """
@@ -138,18 +136,11 @@ def load_weights_per_map(
         new_to_old_keymap (Dict[str, str]): Mapping of new parameter names to old ones.
         device (torch.device): The device to load tensors onto.
         is_safetensor (bool): Whether the files are safetensors.
-        purge_model_prefix (bool): Whether to remove 'model.' prefix from keys.
-        ignore_cache_layers (bool): Whether to ignore cache layers when reporting missing keys.
         model_config (Optional[Dict]): Model configuration.

     Returns:
         Tuple[int, int]: Number of updated weights and number of missing weights.
     """
-    stage_state_dict = stage_module.state_dict()
-    if purge_model_prefix:
-        stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
-        weight_map = purge_fqn_prefix(weight_map, "model.")
-
     needed_files = get_needed_files(stage_state_dict, weight_map)
     updated_states: Set[str] = set()

@@ -175,27 +166,9 @@ def load_weights_per_map(
         logger.error(f"Error during checkpoint processing:")
         raise e

-    missing_keys = handle_missing_keys(
-        stage_state_dict, updated_states, ignore_cache_layers
+    check_for_missing_keys(
+        stage_state_dict, updated_states, ignore_cache_layers=True
     )
-    # log_loading_status(missing_keys, updated_states)
-    if missing_keys:
-        logger.warning(
-            f"Partially updated state dict. Missing {len(missing_keys)} keys: {missing_keys}"
-        )
-    else:
-        logger.info("Fully updated state dict.")
-
-    logger.info(f"Loading {len(updated_states)} weights into stage dict")
-    # precount, premap = record_module_dtypes(stage_module)
-    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
-    # postcount, postmap = record_module_dtypes(stage_module)
-    # logger.info(f"{precount=}, {postcount=}")
-    # logger.info(f"{premap=}, {postmap=}")
-
-    logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
-
-    return len(updated_states), len(missing_keys)


 # TODO: clean this up together with `purge_fqn_prefix` when we switch
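With this hunk, `load_weights_per_map` only fills the state dict it is handed and delegates the coverage check to `check_for_missing_keys`; the prefix purge and the final `load_state_dict` now live with the caller. A minimal caller-side sketch, assuming a `stage_module` like the one the removed lines used (in this commit the prefix purge actually moves into `load_weights_from_hf_format` further down):

```python
# Sketch only: mirrors the responsibilities the removed lines used to cover.
stage_state_dict = stage_module.state_dict()
stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
weight_map = purge_fqn_prefix(weight_map, "model.")

load_weights_per_map(
    stage_state_dict,      # filled in place
    weight_map,
    weight_dir,
    new_to_old_keymap,
    device,
    is_safetensor,
    model_config=model_config,
)  # raises RuntimeError via check_for_missing_keys if real weights are absent

# Writing the filled dict back into the module is now the caller's job.
stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
```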
@@ -287,14 +260,15 @@ def update_state_dict(
         checkpoint_tensor = checkpoint[old_param]
         model_tensor = state_dict[param]

-        if "wq" in param:
-            checkpoint_tensor = permute_weight_to_attn_heads(
-                checkpoint_tensor, num_heads, head_dim, dim
-            )
-        elif "wk" in param:
-            checkpoint_tensor = permute_weight_to_attn_heads(
-                checkpoint_tensor, num_local_heads, head_dim, dim
-            )
+        if new_to_old_keymap is not None:
+            if "wq" in param:
+                checkpoint_tensor = permute_weight_to_attn_heads(
+                    checkpoint_tensor, num_heads, head_dim, dim
+                )
+            elif "wk" in param:
+                checkpoint_tensor = permute_weight_to_attn_heads(
+                    checkpoint_tensor, num_local_heads, head_dim, dim
+                )

         # Move checkpoint tensor to desired device
         checkpoint_tensor = checkpoint_tensor.to(device)
@@ -324,10 +298,10 @@ def clean_cache_keys(input_set: Set[str]) -> Set[str]:
     }


-def handle_missing_keys(
+def check_for_missing_keys(
     state_dict: Dict[str, torch.Tensor],
     updated_states: Set[str],
-    ignore_cache_layers: bool,
+    ignore_cache_layers: bool = True,
 ) -> Set[str]:
     """This function handles 'expected' missing keys from the checkpoint update set.
     This is used for ignoring cache, rope freqs, and mask layers that are generated, rather than persisted
@@ -342,7 +316,13 @@ def handle_missing_keys(
         logger.info(
             f"Ignoring {start_len - after_len} missing cache, freqs, mask layers"
         )
-    return missing_keys
+
+    if len(missing_keys) > 0:
+        from itertools import islice
+        raise RuntimeError(
+            f"Missing {len(missing_keys)} weights, for example: "
+            f"{list(islice(missing_keys, 10))}"
+        )


 def log_loading_status(missing_keys: Set[str], updated_states: Set[str]):
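The rename from `handle_missing_keys` to `check_for_missing_keys` also changes the contract: instead of returning the leftover set for the caller to inspect, it now fails fast. A hedged usage sketch:

```python
# After update_state_dict() has filled `stage_state_dict` and recorded the keys
# it touched in `updated_states`, verify coverage. With ignore_cache_layers=True,
# generated buffers (kv cache, rope freqs, masks) are dropped from the missing
# set first; any remaining untouched key raises RuntimeError with up to 10 examples.
check_for_missing_keys(stage_state_dict, updated_states, ignore_cache_layers=True)
```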
@@ -355,10 +335,10 @@ def log_loading_status(missing_keys: Set[str], updated_states: Set[str]):
     logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")


-def load_weights_from_hf_format(stage_module, distribution, device, model_config):
+def load_weights_from_hf_format(stage_state_dict, distribution, device, model_config):
     """
     Load the weights from Hugging Face format (index file + multiple safetensor
-    files), and fill into `stage_module`. Model config is needed b/c we permute
+    files), and fill into `stage_state_dict`. Model config is needed b/c we permute
     wq and wk weights based on attn heads.
     """
     # Get the weight map for a given HF model id
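For background, the index file this loader resolves is the standard Hugging Face `model.safetensors.index.json`, whose `weight_map` maps each parameter name to the shard that stores it; `get_needed_files` then keeps only the shards holding this stage's parameters. A rough sketch with illustrative file and parameter names (not taken from this repo):

```python
import json

# Typical sharded-checkpoint index layout:
# {"weight_map": {"model.layers.0.attention.wq.weight": "model-00001-of-00002.safetensors", ...}}
with open("model.safetensors.index.json") as f:
    hf_weight_map = json.load(f)["weight_map"]

# A pipeline stage only needs the shards that contain its own parameters.
stage_param_names = {"model.layers.0.attention.wq.weight"}  # hypothetical
needed_files = {hf_weight_map[name] for name in stage_param_names if name in hf_weight_map}
```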
@@ -382,21 +362,21 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
     weight_dir = os.path.dirname(index_file)
     logger.info(f"Loading weights from: {weight_dir}")

+    # TODO: clean this up together with `purge_fqn_prefix` when we switch
+    stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
+    weight_map = purge_fqn_prefix(weight_map, "model.")
+
     # Load the weights into the stage module
-    num_loaded_weights, num_missing_weights = load_weights_per_map(
-        stage_module,
+    load_weights_per_map(
+        stage_state_dict,
         weight_map,
         weight_dir,
         new_to_old_keymap,
         device,
         is_safetensor,
         model_config=model_config,
     )
-    logger.info(
-        f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
-    )
-    if num_missing_weights > 0:
-        raise ValueError(f"Missing {num_missing_weights} weights")
+    return stage_state_dict


 # HACK: assuming single file for torchchat's converted checkpoints. We should
@@ -406,13 +386,12 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
 # will tell us if there is a single file or a directory.
 TORCHCHCAT_SINGLE_FILE_CHECKPOINT = True

-def load_weights_from_torchchat_format(stage_module, distribution, device, model_config):
+def load_weights_from_torchchat_format(stage_state_dict, distribution, device, model_config):
     """
     Load the weights from torchchat format (single binary file), and fill into
     `stage_module`. Model config is needed b/c we permute wq and wk weights
     based on attn heads.
     """
-    stage_state_dict = stage_module.state_dict()
     # TODO: clean this up together with `purge_fqn_prefix` when we switch
     stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")

@@ -437,6 +416,10 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
         "checkpoint_path": checkpoint_path,
     }
     builder_args = BuilderArgs(**args_dict)
+    logger.info(
+        "Loading checkpoint from: "
+        f"{builder_args.checkpoint_dir or builder_args.checkpoint_path}"
+    )
     # Then, load the checkpoint using torchchat util
     checkpoint = _load_checkpoint(builder_args)

@@ -450,6 +433,8 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
         updated_states=updated_states,
     )

-    # Fill state dict into stage module
-    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
-    logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
+    check_for_missing_keys(
+        stage_state_dict, updated_states, ignore_cache_layers=True
+    )
+
+    return stage_state_dict
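Both loaders now take and return a plain state dict, so the module round trip that the removed lines performed belongs to the call site. A sketch of that wiring, assuming a `stage_module` on the consumer side (not part of this commit):

```python
# Build the stage's state dict, let the loader fill it, then assign it back.
state_dict = load_weights_from_torchchat_format(
    stage_module.state_dict(), distribution, device, model_config
)
# The same pattern applies to load_weights_from_hf_format.
stage_module.load_state_dict(state_dict, strict=False, assign=True)
```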