1111import os
1212import json
1313from torch .nn import Module
14- from typing import Dict , Tuple , Set , Optional
14+ from typing import Any , Dict , Tuple , Set , Optional
1515
1616from torch .distributed ._tensor import DTensor
1717from torchchat .distributed .dtensor_utils import convert_to_dtensor
@@ -165,17 +165,18 @@ def load_safetensor_weights(
165165 Returns:
166166 Tuple[int, int]: Number of updated weights and number of missing weights.
167167 """
168- stage_state_dict , weight_map = prepare_state_dict (
169- stage_module , weight_map , purge_model_prefix
170- )
168+ stage_state_dict = stage_module .state_dict ()
169+ stage_state_dict = purge_fqn_prefix (stage_state_dict , "model." )
170+ weight_map = purge_fqn_prefix (weight_map , "model." )
171+
171172 needed_files = get_needed_files (stage_state_dict , weight_map )
172173 updated_states : Set [str ] = set ()
173174
174175 for file in needed_files :
175176 full_path = os .path .join (file_location , file )
176177 # logger.info(f"Loading checkpoint file: {full_path}")
177178 try :
178- checkpoint = load_checkpoint (full_path , "cpu" ) # device)
179+ checkpoint = load_safetensor_file (full_path , "cpu" ) # device)
179180
180181 update_state_dict (
181182 stage_state_dict ,
@@ -215,14 +216,11 @@ def load_safetensor_weights(
215216 return len (updated_states ), len (missing_keys )
216217
217218
def purge_fqn_prefix(
    any_dict: Dict[str, Any],
    prefix: str,
) -> Dict[str, Any]:
    """Return a copy of *any_dict* with *prefix* stripped from every key.

    Generalized replacement for the old ``prepare_state_dict``: it is applied
    to both the module state-dict and the checkpoint weight map so that their
    fully-qualified names (FQNs) line up, e.g. removing a leading ``"model."``.

    Args:
        any_dict: Mapping whose keys may carry the prefix. Values are passed
            through untouched (tensors, file names, or anything else — hence
            the ``Any`` value type; the old ``torch.Tensor`` annotation was
            wrong for the weight-map call site).
        prefix: Exact prefix to strip. Keys that do not start with it are
            kept unchanged (``str.removeprefix`` is a no-op for them).

    Returns:
        A new dict with rewritten keys and the original values; the input
        dict is not mutated.
    """
    return {k.removeprefix(prefix): v for k, v in any_dict.items()}
226224
227225
228226def get_needed_files (
@@ -242,7 +240,7 @@ def get_needed_files(
242240 return needed_files
243241
244242
245- def load_checkpoint (full_path : str , device : torch .device ) -> Dict [str , torch .Tensor ]:
243+ def load_safetensor_file (full_path : str , device : torch .device ) -> Dict [str , torch .Tensor ]:
246244 tensors = {}
247245 with safe_open (full_path , framework = "pt" , device = device ) as f :
248246 for k in f .keys ():
0 commit comments