diff --git a/torchchat/distributed/checkpoint_utils.py b/torchchat/distributed/checkpoint_utils.py
index cf3206e4e..a20373fcb 100644
--- a/torchchat/distributed/checkpoint_utils.py
+++ b/torchchat/distributed/checkpoint_utils.py
@@ -19,7 +19,8 @@ from torchchat.cli.builder import BuilderArgs, _load_checkpoint
 
 
 
-_DEFAULT_SAFETENSOR_FILE_NAME = "model.safetensors.index.json"
+_DEFAULT_SAFETENSOR_INDEX = "model.safetensors.index.json"
+_DEFAULT_BIN_INDEX = "pytorch_model.bin.index.json"
 _CONFIG_NAME = "config.json"
 
 
@@ -81,31 +82,6 @@ def get_hf_path_from_model_id(model_id: str) -> str:
     return file_location
 
 
-def get_hf_weight_map_and_path(
-    model_id: str,
-) -> Tuple[Dict[str, str], str, Dict[str, str]]:
-    """Get the weight map for a given HF model id and also the cache path for loading the weights"""
-    index_file = cached_file(model_id, _DEFAULT_SAFETENSOR_FILE_NAME)
-    if not os.path.exists(index_file):
-        raise FileNotFoundError(
-            f"Weight index file for {model_id} does not exist in HF cache."
-        )
-    logger.info(
-        f"Loading weight map from: {index_file}"
-    )
-    weight_map = read_weights_from_json(index_file)
-    if weight_map is None:
-        raise ValueError(f"Weight map not found in config file {index_file}")
-    weight_map, new_to_old_keymap = remap_weight_keys(weight_map)
-    weight_path = os.path.dirname(index_file)
-    if not os.path.exists(weight_path):
-        raise FileNotFoundError(f"Weight path {weight_path} does not exist")
-    logger.info(
-        f"Loading weights from: {weight_path}"
-    )
-    return weight_map, weight_path, new_to_old_keymap
-
-
 def remap_weight_keys(dictionary):
     """Remap the keys of a dictionary to match the expected format of the tune model."""
     # hf_key : dist_model_key
@@ -141,12 +117,13 @@ def remap_weight_keys(dictionary):
     return new_dict, key_mapping
 
 
-def load_safetensor_weights(
+def load_weights_per_map(
     stage_module: Module,
     weight_map: Dict[str, str],
     file_location: str,
     new_to_old_keymap: Dict[str, str],
-    device: torch.device = "cuda",
+    device: torch.device,
+    is_safetensor: bool,
     purge_model_prefix: bool = True,
     ignore_cache_layers: bool = True,
     model_config: Optional[Dict] = None,
@@ -160,6 +137,7 @@ def load_safetensor_weights(
         file_location (str): Directory containing the weight files.
        new_to_old_keymap (Dict[str, str]): Mapping of new parameter names to old ones.
         device (torch.device): The device to load tensors onto.
+        is_safetensor (bool): Whether the files are safetensors.
         purge_model_prefix (bool): Whether to remove 'model.' prefix from keys.
         ignore_cache_layers (bool): Whether to ignore cache layers when reporting missing keys.
         model_config (Optional[Dict]): Model configuration.
@@ -178,9 +156,13 @@ def load_safetensor_weights(
     for file in needed_files:
         full_path = os.path.join(file_location, file)
         # logger.info(f"Loading checkpoint file: {full_path}")
-        try:
-            checkpoint = load_safetensor_file(full_path, "cpu")  # device)
+        # TODO: directly load to device
+        if is_safetensor:
+            checkpoint = load_safetensor_file(full_path)
+        else:
+            checkpoint = torch.load(full_path, mmap=True, weights_only=True)
 
+        try:
             update_state_dict(
                 stage_state_dict,
                 checkpoint,
@@ -189,10 +171,9 @@ def load_safetensor_weights(
                 new_to_old_keymap=new_to_old_keymap,
                 updated_states=updated_states,
             )
-        except FileNotFoundError:
-            logger.error(f"File not found: {full_path}")
         except Exception as e:
-            logger.error(f"Error during checkpoint processing of {full_path}: {str(e)}")
+            logger.error(f"Error during checkpoint processing:")
+            raise e
 
     missing_keys = handle_missing_keys(
         stage_state_dict, updated_states, ignore_cache_layers
@@ -244,12 +225,14 @@ def get_needed_files(
     return needed_files
 
 
-def load_safetensor_file(full_path: str, device: torch.device) -> Dict[str, torch.Tensor]:
+def load_safetensor_file(
+    full_path: str,
+    device: str = "cpu",
+) -> Dict[str, torch.Tensor]:
     tensors = {}
     with safe_open(full_path, framework="pt", device=device) as f:
         for k in f.keys():
             tensors[k] = f.get_tensor(k)
-    logger.info(f"Loaded {len(tensors)} tensors from {full_path}")
     return tensors
 
 
@@ -378,15 +361,35 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
     files), and fill into `stage_module`. Model config is needed b/c we permute
     wq and wk weights based on attn heads.
     """
+    # Get the weight map for a given HF model id
+    try:
+        index_file = cached_file(distribution, _DEFAULT_SAFETENSOR_INDEX)
+        is_safetensor = True
+    except:
+        index_file = cached_file(distribution, _DEFAULT_BIN_INDEX)
+        is_safetensor = False
+    logger.info(f"Loading weight map from: {index_file}")
+
+    # Read the weight map from the index file
+    weight_map = read_weights_from_json(index_file)
+    if weight_map is None:
+        raise ValueError(f"Weight map not found in config file {index_file}")
+
+    # Remap the FQNs to the FQNs in HF checkpoints
+    weight_map, new_to_old_keymap = remap_weight_keys(weight_map)
 
-    weight_map, weight_path, key_map = get_hf_weight_map_and_path(distribution)
+    # Get the dir containing the weight files
+    weight_dir = os.path.dirname(index_file)
+    logger.info(f"Loading weights from: {weight_dir}")
 
-    num_loaded_weights, num_missing_weights = load_safetensor_weights(
+    # Load the weights into the stage module
+    num_loaded_weights, num_missing_weights = load_weights_per_map(
        stage_module,
         weight_map,
-        weight_path,
-        key_map,
+        weight_dir,
+        new_to_old_keymap,
         device,
+        is_safetensor,
         model_config=model_config,
    )
    logger.info(
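
Note (illustrative, not part of the patch): the standalone sketch below mirrors the loading flow this change adopts, preferring the safetensors index and falling back to the .bin index via cached_file, then loading each shard with safe_open or torch.load. The helper names resolve_index and load_all_shards are hypothetical, and the sketch assumes the standard HF index layout with a top-level "weight_map" field instead of torchchat's read_weights_from_json and remap_weight_keys helpers.

import json
import os

import torch
from safetensors import safe_open
from transformers.utils import cached_file

_SAFETENSOR_INDEX = "model.safetensors.index.json"
_BIN_INDEX = "pytorch_model.bin.index.json"


def resolve_index(distribution: str):
    """Return (index_file, is_safetensor), preferring safetensors over .bin."""
    try:
        return cached_file(distribution, _SAFETENSOR_INDEX), True
    except Exception:
        return cached_file(distribution, _BIN_INDEX), False


def load_all_shards(distribution: str):
    """Load every shard listed in the index into one CPU state dict."""
    index_file, is_safetensor = resolve_index(distribution)
    with open(index_file) as f:
        # Standard HF index layout: {"weight_map": {param_name: shard_file, ...}}
        weight_map = json.load(f)["weight_map"]
    weight_dir = os.path.dirname(index_file)

    state_dict = {}
    for shard in sorted(set(weight_map.values())):
        full_path = os.path.join(weight_dir, shard)
        if is_safetensor:
            with safe_open(full_path, framework="pt", device="cpu") as reader:
                state_dict.update({k: reader.get_tensor(k) for k in reader.keys()})
        else:
            state_dict.update(torch.load(full_path, mmap=True, weights_only=True))
    return state_dict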