@@ -166,8 +166,9 @@ def load_safetensor_weights(
         Tuple[int, int]: Number of updated weights and number of missing weights.
     """
     stage_state_dict = stage_module.state_dict()
-    stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
-    weight_map = purge_fqn_prefix(weight_map, "model.")
+    if purge_model_prefix:
+        stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
+        weight_map = purge_fqn_prefix(weight_map, "model.")
 
     needed_files = get_needed_files(stage_state_dict, weight_map)
     updated_states: Set[str] = set()
@@ -181,9 +182,7 @@ def load_safetensor_weights(
             update_state_dict(
                 stage_state_dict,
                 checkpoint,
-                weight_map,
                 new_to_old_keymap,
-                file,
                 updated_states,
                 device,
                 model_config,
@@ -216,10 +215,13 @@ def load_safetensor_weights(
     return len(updated_states), len(missing_keys)
 
 
+# TODO: clean this up together with `purge_fqn_prefix` when we switch
+# from creating Transformer to creating model
 def purge_fqn_prefix(
     any_dict: Dict[str, Any],
     prefix: str,
-) -> Dict[str, torch.Tensor]:
+) -> Dict[str, Any]:
+    """Remove a prefix from all keys in a dictionary."""
     return {k.removeprefix(prefix): v for k, v in any_dict.items()}
 
 
@@ -262,64 +264,52 @@ def permute_weight_to_attn_heads(w, n_heads, head_dim, model_dim):
 def update_state_dict(
     state_dict: Dict[str, torch.Tensor],
     checkpoint: Dict[str, torch.Tensor],
-    weight_map: Dict[str, str],
     new_to_old_keymap: Dict[str, str],
-    file: str,
     updated_states: Set[str],
     device: torch.device,
     model_config: Optional[Dict] = None,
 ):
-    count_dtensors_loaded = 0
     # for handling attn head permuting
     num_heads = model_config.n_heads
     dim = model_config.dim
     num_local_heads = model_config.n_local_heads
     head_dim = model_config.head_dim
 
-    for param, file_with_param in weight_map.items():
-        if file_with_param == file and param in state_dict:
-            model_param = (
-                "output.weight" if param == "output.weight" else f"model.{param}"
+    for param in state_dict.keys():
+        # TODO: clean this up together with `purge_fqn_prefix` when we switch
+        # from creating Transformer to creating model
+        model_param = (
+            "output.weight" if param == "output.weight" else f"model.{param}"
+        )
+        old_param = new_to_old_keymap.get(model_param)
+
+        if old_param not in checkpoint:
+            # Maybe this param is in other files
+            continue
+
+        checkpoint_tensor = checkpoint[old_param]
+        model_tensor = state_dict[param]
+
+        if "wq" in param:
+            checkpoint_tensor = permute_weight_to_attn_heads(
+                checkpoint_tensor, num_heads, head_dim, dim
+            )
+        elif "wk" in param:
+            checkpoint_tensor = permute_weight_to_attn_heads(
+                checkpoint_tensor, num_local_heads, head_dim, dim
             )
-            old_param = new_to_old_keymap.get(model_param)
-
-            if old_param not in checkpoint:
-                logger.warning(f"Missing {old_param} in checkpoint")
-                continue
-
-            checkpoint_tensor = checkpoint[old_param]
-            model_tensor = state_dict[param]
-
-            if "wq" in param:
-                checkpoint_tensor = permute_weight_to_attn_heads(
-                    checkpoint_tensor, num_heads, head_dim, dim
-                )
-            elif "wk" in param:
-                checkpoint_tensor = permute_weight_to_attn_heads(
-                    checkpoint_tensor, num_local_heads, head_dim, dim
-                )
-
-            # Move checkpoint tensor to desired device
-            checkpoint_tensor = checkpoint_tensor.to(device)
-
-            # here we need to check if the tensor is a DTensor and if so, adjust the
-            # shape and placement to match the model DTensor.
-            if isinstance(model_tensor, DTensor):
-                state_dict[param] = convert_to_dtensor(checkpoint_tensor, model_tensor)
-                count_dtensors_loaded += 1
-            else:
-                # regular tensor, just update directly
-                state_dict[param] = checkpoint_tensor
-
-            # ensure matching dtypes
-            state_dict[param] = state_dict[param].to(checkpoint_tensor.dtype)
-
-            assert state_dict[param].dtype == checkpoint_tensor.dtype
-
-            # log_tensor_info(param, state_dict[param])
-            # logger.info(f"Loaded {param} from {file}")
-            updated_states.add(param)
-    # logger.info(f"Count of loaded DTensors: {count_dtensors_loaded}")
+
+        # Move checkpoint tensor to desired device
+        checkpoint_tensor = checkpoint_tensor.to(device)
+
+        # here we need to check if the tensor is a DTensor and if so, adjust the
+        # shape and placement to match the model DTensor.
+        if isinstance(model_tensor, DTensor):
+            checkpoint_tensor = convert_to_dtensor(checkpoint_tensor, model_tensor)
+
+        # Update model state dict with checkpoint tensor
+        state_dict[param] = checkpoint_tensor
+        updated_states.add(param)
 
 
 def format_tensor_info(tensor: torch.Tensor) -> str:
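
For readers skimming the diff, the sketch below mirrors the data flow the new `update_state_dict` relies on: the stage state dict loses its "model." prefix up front, then each stage parameter is mapped back through `new_to_old_keymap` to find its checkpoint key. Iterating over `state_dict.keys()` instead of the per-file `weight_map` means a key absent from the currently open file is simply skipped, since it may be found in a later file. The key names and tensor shapes are hypothetical, and the DTensor conversion and wq/wk attention-head permutation steps are omitted.

# Illustrative sketch only: hypothetical keys and shapes; not the real loader.
import torch


def purge_fqn_prefix(any_dict, prefix):
    """Remove a prefix from all keys in a dictionary."""
    return {k.removeprefix(prefix): v for k, v in any_dict.items()}


# Stage module keys carry a "model." prefix that is stripped before matching.
stage_state_dict = purge_fqn_prefix(
    {"model.layers.0.attention.wo.weight": torch.zeros(8, 8)}, "model."
)
new_to_old_keymap = {
    "model.layers.0.attention.wo.weight": "layers.0.attention.wo.weight"
}
checkpoint = {"layers.0.attention.wo.weight": torch.ones(8, 8)}

updated_states = set()
for param in stage_state_dict.keys():
    # Re-add the prefix to consult the keymap; "output.weight" is the one exception.
    model_param = "output.weight" if param == "output.weight" else f"model.{param}"
    old_param = new_to_old_keymap.get(model_param)
    if old_param not in checkpoint:
        continue  # the weight may live in another checkpoint file
    stage_state_dict[param] = checkpoint[old_param]
    updated_states.add(param)

print(updated_states)  # {'layers.0.attention.wo.weight'}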