diff --git a/dist_run.py b/dist_run.py
index e077ea9ca..b7d317ca3 100644
--- a/dist_run.py
+++ b/dist_run.py
@@ -125,18 +125,39 @@ def _load_model_weights(
         distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
         device (torch.device): The device to load the weights onto.
         model_config (ModelArgs): The model config.
-        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
+        chpt_from (str): The checkpoint format to load the weights from.
+
+            Valid chpt_from values: "hf", "tc", or "tc:<distribution>"
+            to load from a specific directory, e.g.
+            "tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo".
     """
-    if chpt_from == "hf":
+    str_lst = chpt_from.split(":")
+    assert len(str_lst) == 1 or len(str_lst) == 2, "Invalid --chpt_from format"
+    # Get the checkpoint format, e.g. "hf" or "tc"
+    chpt_format = str_lst[0]
+    # If user also specified a checkpoint, such as
+    # `meta-llama/Meta-Llama-3-8B-Instruct-int8_wo`
+    chpt_distribution = str_lst[1] if len(str_lst) == 2 else distribution
+
+    stage_state_dict = stage_module.state_dict()
+    if chpt_format == "hf":
         # This format stands for: index file + multiple binary files
-        load_weights_from_hf_format(stage_module, distribution, device, model_config)
-    elif chpt_from == "torchchat":
+        stage_state_dict = load_weights_from_hf_format(
+            stage_state_dict, chpt_distribution, device, model_config
+        )
+    elif chpt_format == "tc":
         # This format stands for:
         # single binary file, OR
         # multiple binary files without index files.
-        load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
+        stage_state_dict = load_weights_from_torchchat_format(
+            stage_state_dict, chpt_distribution, device, model_config
+        )
     else:
-        raise ValueError(f"Unknown checkpoint format: {chpt_from}")
+        raise ValueError(f"Unknown checkpoint format: {chpt_format}")
+
+    # Fill state dict into stage module
+    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
+    logger.info(f"Successfully loaded weights into stage module")
 
 
 def _encode_strings(
@@ -589,9 +610,9 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
     parser.add_argument(
         "--chpt-from",
         type=str,
-        default="hf",  # TODO: change to torchchat once we support it well
-        help="Checkpoint format to load from",
-        choices=["hf", "torchchat"],
+        default="tc",
+        help="Checkpoint to load from, e.g. `hf` or `tc`, or "
+        "`tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo`",
     )
 
     args = parser.parse_args()
diff --git a/torchchat/distributed/checkpoint_utils.py b/torchchat/distributed/checkpoint_utils.py
index a20373fcb..8e58649fc 100644
--- a/torchchat/distributed/checkpoint_utils.py
+++ b/torchchat/distributed/checkpoint_utils.py
@@ -118,14 +118,12 @@ def remap_weight_keys(dictionary):
 
 
 def load_weights_per_map(
-    stage_module: Module,
+    stage_state_dict,
     weight_map: Dict[str, str],
     file_location: str,
     new_to_old_keymap: Dict[str, str],
     device: torch.device,
     is_safetensor: bool,
-    purge_model_prefix: bool = True,
-    ignore_cache_layers: bool = True,
     model_config: Optional[Dict] = None,
 ) -> Tuple[int, int]:
     """
@@ -138,18 +136,11 @@ def load_weights_per_map(
         new_to_old_keymap (Dict[str, str]): Mapping of new parameter names to old ones.
         device (torch.device): The device to load tensors onto.
         is_safetensor (bool): Whether the files are safetensors.
-        purge_model_prefix (bool): Whether to remove 'model.' prefix from keys.
-        ignore_cache_layers (bool): Whether to ignore cache layers when reporting missing keys.
         model_config (Optional[Dict]): Model configuration.
 
     Returns:
         Tuple[int, int]: Number of updated weights and number of missing weights.
     """
-    stage_state_dict = stage_module.state_dict()
-    if purge_model_prefix:
-        stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
-        weight_map = purge_fqn_prefix(weight_map, "model.")
-
     needed_files = get_needed_files(stage_state_dict, weight_map)
     updated_states: Set[str] = set()
 
@@ -175,27 +166,9 @@ def load_weights_per_map(
             logger.error(f"Error during checkpoint processing:")
             raise e
 
-    missing_keys = handle_missing_keys(
-        stage_state_dict, updated_states, ignore_cache_layers
+    check_for_missing_keys(
+        stage_state_dict, updated_states, ignore_cache_layers=True
     )
-    # log_loading_status(missing_keys, updated_states)
-    if missing_keys:
-        logger.warning(
-            f"Partially updated state dict. Missing {len(missing_keys)} keys: {missing_keys}"
-        )
-    else:
-        logger.info("Fully updated state dict.")
-
-    logger.info(f"Loading {len(updated_states)} weights into stage dict")
-    # precount, premap = record_module_dtypes(stage_module)
-    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
-    # postcount, postmap = record_module_dtypes(stage_module)
-    # logger.info(f"{precount=}, {postcount=}")
-    # logger.info(f"{premap=}, {postmap=}")
-
-    logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
-
-    return len(updated_states), len(missing_keys)
 
 
 # TODO: clean this up together with `purge_fqn_prefix` when we switch
@@ -287,14 +260,15 @@ def update_state_dict(
         checkpoint_tensor = checkpoint[old_param]
         model_tensor = state_dict[param]
 
-        if "wq" in param:
-            checkpoint_tensor = permute_weight_to_attn_heads(
-                checkpoint_tensor, num_heads, head_dim, dim
-            )
-        elif "wk" in param:
-            checkpoint_tensor = permute_weight_to_attn_heads(
-                checkpoint_tensor, num_local_heads, head_dim, dim
-            )
+        if new_to_old_keymap is not None:
+            if "wq" in param:
+                checkpoint_tensor = permute_weight_to_attn_heads(
+                    checkpoint_tensor, num_heads, head_dim, dim
+                )
+            elif "wk" in param:
+                checkpoint_tensor = permute_weight_to_attn_heads(
+                    checkpoint_tensor, num_local_heads, head_dim, dim
+                )
 
         # Move checkpoint tensor to desired device
         checkpoint_tensor = checkpoint_tensor.to(device)
@@ -324,10 +298,10 @@ def clean_cache_keys(input_set: Set[str]) -> Set[str]:
     }
 
 
-def handle_missing_keys(
+def check_for_missing_keys(
     state_dict: Dict[str, torch.Tensor],
     updated_states: Set[str],
-    ignore_cache_layers: bool,
+    ignore_cache_layers: bool = True,
 ) -> Set[str]:
     """This function handles 'expected' missing keys from the checkpoint update set.
     This is used for ignoring cache, rope freqs, and mask layers that are generated, rather than persisted
@@ -342,7 +316,13 @@ def handle_missing_keys(
             logger.info(
                 f"Ignoring {start_len - after_len} missing cache, freqs, mask layers"
             )
-    return missing_keys
+
+    if len(missing_keys) > 0:
+        from itertools import islice
+        raise RuntimeError(
+            f"Missing {len(missing_keys)} weights, for example: "
+            f"{list(islice(missing_keys, 10))}"
+        )
 
 
 def log_loading_status(missing_keys: Set[str], updated_states: Set[str]):
@@ -355,10 +335,10 @@ def log_loading_status(missing_keys: Set[str], updated_states: Set[str]):
     logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
 
 
-def load_weights_from_hf_format(stage_module, distribution, device, model_config):
+def load_weights_from_hf_format(stage_state_dict, distribution, device, model_config):
     """
     Load the weights from Hugging Face format (index file + multiple safetensor
-    files), and fill into `stage_module`. Model config is needed b/c we permute
+    files), and fill into `stage_state_dict`. Model config is needed b/c we permute
     wq and wk weights based on attn heads.
     """
     # Get the weight map for a given HF model id
@@ -382,9 +362,13 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
     weight_dir = os.path.dirname(index_file)
     logger.info(f"Loading weights from: {weight_dir}")
 
+    # TODO: clean this up together with `purge_fqn_prefix` when we switch
+    stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
+    weight_map = purge_fqn_prefix(weight_map, "model.")
+
     # Load the weights into the stage module
-    num_loaded_weights, num_missing_weights = load_weights_per_map(
-        stage_module,
+    load_weights_per_map(
+        stage_state_dict,
         weight_map,
         weight_dir,
         new_to_old_keymap,
@@ -392,11 +376,7 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
         is_safetensor,
         model_config=model_config,
     )
-    logger.info(
-        f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
-    )
-    if num_missing_weights > 0:
-        raise ValueError(f"Missing {num_missing_weights} weights")
+    return stage_state_dict
 
 
 # HACK: assuming single file for torchchat's converted checkpoints. We should
@@ -406,13 +386,12 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
 # will tell us if there is a single file or a directory.
 TORCHCHCAT_SINGLE_FILE_CHECKPOINT = True
 
-def load_weights_from_torchchat_format(stage_module, distribution, device, model_config):
+def load_weights_from_torchchat_format(stage_state_dict, distribution, device, model_config):
     """
     Load the weights from torchchat format (single binary file), and fill into
     `stage_module`. Model config is needed b/c we permute wq and wk weights
     based on attn heads.
     """
-    stage_state_dict = stage_module.state_dict()
 
     # TODO: clean this up together with `purge_fqn_prefix` when we switch
     stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
@@ -437,6 +416,10 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
         "checkpoint_path": checkpoint_path,
     }
     builder_args = BuilderArgs(**args_dict)
+    logger.info(
+        "Loading checkpoint from: "
+        f"{builder_args.checkpoint_dir or builder_args.checkpoint_path}"
+    )
 
     # Then, load the checkpoint using torchchat util
     checkpoint = _load_checkpoint(builder_args)
@@ -450,6 +433,8 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
         updated_states=updated_states,
     )
 
-    # Fill state dict into stage module
-    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
-    logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
+    check_for_missing_keys(
+        stage_state_dict, updated_states, ignore_cache_layers=True
+    )
+
+    return stage_state_dict