[Distributed] Fix correctness issue in TC load path #1276
@@ -118,14 +118,12 @@ def remap_weight_keys(dictionary):
def load_weights_per_map(
-     stage_module: Module,
+     stage_state_dict,
Contributor: typing is missing?
    weight_map: Dict[str, str],
    file_location: str,
    new_to_old_keymap: Dict[str, str],
    device: torch.device,
    is_safetensor: bool,
-     purge_model_prefix: bool = True,
-     ignore_cache_layers: bool = True,
    model_config: Optional[Dict] = None,
) -> Tuple[int, int]:
    """
@@ -138,18 +136,11 @@ def load_weights_per_map(
        new_to_old_keymap (Dict[str, str]): Mapping of new parameter names to old ones.
        device (torch.device): The device to load tensors onto.
        is_safetensor (bool): Whether the files are safetensors.
-         purge_model_prefix (bool): Whether to remove 'model.' prefix from keys.
-         ignore_cache_layers (bool): Whether to ignore cache layers when reporting missing keys.
        model_config (Optional[Dict]): Model configuration.

    Returns:
        Tuple[int, int]: Number of updated weights and number of missing weights.
    """
-     stage_state_dict = stage_module.state_dict()
-     if purge_model_prefix:
-         stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
-         weight_map = purge_fqn_prefix(weight_map, "model.")

    needed_files = get_needed_files(stage_state_dict, weight_map)
    updated_states: Set[str] = set()
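On the typing question raised in the review comment above: the new `stage_state_dict` parameter could carry an explicit annotation. A minimal sketch, with the parameter list inferred from the updated call site later in this diff; the `Dict[str, torch.Tensor]` type is an assumption based on the dict ultimately being passed to `load_state_dict`:

```python
from typing import Dict, Optional, Tuple

import torch

# Hypothetical alias; the value type is assumed from how the dict is used
# with Module.load_state_dict(..., assign=True) elsewhere in this file.
StageStateDict = Dict[str, torch.Tensor]


def load_weights_per_map(
    stage_state_dict: StageStateDict,  # annotation suggested by the review comment
    weight_map: Dict[str, str],
    file_location: str,
    new_to_old_keymap: Dict[str, str],
    device: torch.device,
    is_safetensor: bool,
    model_config: Optional[Dict] = None,
) -> Tuple[int, int]:
    """Signature sketch only; body elided."""
    ...
```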
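`get_needed_files` is not shown in this diff; presumably it groups the parameters this stage actually owns by the checkpoint shard that stores them, so each file is opened only once. A sketch under that assumption (the real helper may return a different structure):

```python
from typing import Dict, Set

import torch


def get_needed_files_sketch(
    state_dict: Dict[str, torch.Tensor], weight_map: Dict[str, str]
) -> Dict[str, Set[str]]:
    """Map each needed checkpoint file to the set of FQNs it should supply."""
    needed_files: Dict[str, Set[str]] = {}
    for fqn in state_dict:
        filename = weight_map.get(fqn)
        if filename is not None:
            needed_files.setdefault(filename, set()).add(fqn)
    return needed_files
```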
@@ -175,27 +166,9 @@ def load_weights_per_map(
        logger.error(f"Error during checkpoint processing:")
        raise e

-     missing_keys = handle_missing_keys(
-         stage_state_dict, updated_states, ignore_cache_layers
+     check_for_missing_keys(
+         stage_state_dict, updated_states, ignore_cache_layers=True
    )
-     # log_loading_status(missing_keys, updated_states)
-     if missing_keys:
-         logger.warning(
-             f"Partially updated state dict. Missing {len(missing_keys)} keys: {missing_keys}"
-         )
-     else:
-         logger.info("Fully updated state dict.")

-     logger.info(f"Loading {len(updated_states)} weights into stage dict")
-     # precount, premap = record_module_dtypes(stage_module)
-     stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
-     # postcount, postmap = record_module_dtypes(stage_module)
-     # logger.info(f"{precount=}, {postcount=}")
-     # logger.info(f"{premap=}, {postmap=}")

    logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")

-     return len(updated_states), len(missing_keys)


# TODO: clean this up together with `purge_fqn_prefix` when we switch
@@ -287,14 +260,15 @@ def update_state_dict(
        checkpoint_tensor = checkpoint[old_param]
        model_tensor = state_dict[param]

-         if "wq" in param:
-             checkpoint_tensor = permute_weight_to_attn_heads(
-                 checkpoint_tensor, num_heads, head_dim, dim
-             )
-         elif "wk" in param:
-             checkpoint_tensor = permute_weight_to_attn_heads(
-                 checkpoint_tensor, num_local_heads, head_dim, dim
-             )
+         if new_to_old_keymap is not None:
+             if "wq" in param:
Contributor: so this was the issue! It seems like a flag indicating the type of checkpoint would be more robust, and then map checkpoint type to whether permuting is needed? It seems brittle if, for future checkpoints, we permute or not based solely on an FQN (generically, having 'wq' in a name doesn't inherently imply anything about the need to permute).

Contributor (Author): I am using
+                 checkpoint_tensor = permute_weight_to_attn_heads(
+                     checkpoint_tensor, num_heads, head_dim, dim
+                 )
+             elif "wk" in param:
+                 checkpoint_tensor = permute_weight_to_attn_heads(
+                     checkpoint_tensor, num_local_heads, head_dim, dim
+                 )

        # Move checkpoint tensor to desired device
        checkpoint_tensor = checkpoint_tensor.to(device)
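For context on the permute discussion above: Hugging Face Llama checkpoints store `wq`/`wk` with a different per-head rotary layout than Meta-style attention, which is why a head permutation is needed on the HF path at all. A sketch of such a transform, mirroring the well-known HF conversion-script permutation; the repo's actual `permute_weight_to_attn_heads` may differ in details:

```python
import torch


def permute_weight_to_attn_heads_sketch(
    w: torch.Tensor, n_heads: int, head_dim: int, dim: int
) -> torch.Tensor:
    """Reorder the rows of a (n_heads * head_dim, dim) projection weight so the
    per-head rotary (real, imag) halves are interleaved the way this model's
    attention expects. Sketch only; requires an even head_dim.
    """
    return (
        w.view(n_heads, 2, head_dim // 2, dim)
        .transpose(1, 2)
        .reshape(n_heads * head_dim, dim)
    )
```

In the new code this step is guarded by `new_to_old_keymap is not None`, so checkpoints that do not go through the HF key remapping are left unpermuted, which appears to be the correctness issue the PR title refers to.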
@@ -324,10 +298,10 @@ def clean_cache_keys(input_set: Set[str]) -> Set[str]:
    }


- def handle_missing_keys(
+ def check_for_missing_keys(
    state_dict: Dict[str, torch.Tensor],
    updated_states: Set[str],
-     ignore_cache_layers: bool,
+     ignore_cache_layers: bool = True,
) -> Set[str]:
    """This function handles 'expected' missing keys from the checkpoint update set.
    This is used for ignoring cache, rope freqs, and mask layers that are generated, rather than persisted
@@ -342,7 +316,13 @@ def handle_missing_keys(
        logger.info(
            f"Ignoring {start_len - after_len} missing cache, freqs, mask layers"
        )
-     return missing_keys

+     if len(missing_keys) > 0:
+         from itertools import islice
+         raise RuntimeError(
+             f"Missing {len(missing_keys)} weights, for example: "
+             f"{list(islice(missing_keys, 10))}"
+         )


def log_loading_status(missing_keys: Set[str], updated_states: Set[str]):
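Reading the two hunks above together, the renamed helper now fails fast on genuinely missing weights instead of returning the set for the caller to log. A condensed sketch of the resulting behavior (assuming `clean_cache_keys` filters cache/rope-freqs/mask entries, as the docstring and log message suggest; logging and the exact return value may differ from the repo's version):

```python
from itertools import islice
from typing import Dict, Set

import torch


def check_for_missing_keys_sketch(
    state_dict: Dict[str, torch.Tensor],
    updated_states: Set[str],
    ignore_cache_layers: bool = True,
) -> None:
    """Raise if keys that should come from the checkpoint were not updated;
    generated layers (kv-cache, rope freqs, masks) can be ignored."""
    missing_keys = set(state_dict.keys()) - updated_states
    if ignore_cache_layers:
        # Stand-in for clean_cache_keys(): drop keys that are generated at
        # runtime rather than persisted in the checkpoint.
        missing_keys = {
            k
            for k in missing_keys
            if not any(s in k for s in ("cache", "freqs", "mask"))
        }
    if missing_keys:
        # Surface only a handful of examples so the error stays readable.
        raise RuntimeError(
            f"Missing {len(missing_keys)} weights, for example: "
            f"{list(islice(missing_keys, 10))}"
        )
```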
@@ -355,10 +335,10 @@ def log_loading_status(missing_keys: Set[str], updated_states: Set[str]):
    logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")


- def load_weights_from_hf_format(stage_module, distribution, device, model_config):
+ def load_weights_from_hf_format(stage_state_dict, distribution, device, model_config):
    """
    Load the weights from Hugging Face format (index file + multiple safetensor
-     files), and fill into `stage_module`. Model config is needed b/c we permute
+     files), and fill into `stage_state_dict`. Model config is needed b/c we permute
    wq and wk weights based on attn heads.
    """
    # Get the weight map for a given HF model id
@@ -382,21 +362,21 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
    weight_dir = os.path.dirname(index_file)
    logger.info(f"Loading weights from: {weight_dir}")

+     # TODO: clean this up together with `purge_fqn_prefix` when we switch
+     stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
+     weight_map = purge_fqn_prefix(weight_map, "model.")

    # Load the weights into the stage module
-     num_loaded_weights, num_missing_weights = load_weights_per_map(
-         stage_module,
+     load_weights_per_map(
+         stage_state_dict,
        weight_map,
        weight_dir,
        new_to_old_keymap,
        device,
        is_safetensor,
        model_config=model_config,
    )
-     logger.info(
-         f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
-     )
-     if num_missing_weights > 0:
-         raise ValueError(f"Missing {num_missing_weights} weights")
+     return stage_state_dict

# HACK: assuming single file for torchchat's converted checkpoints. We should
@@ -406,13 +386,12 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
# will tell us if there is a single file or a directory.
TORCHCHCAT_SINGLE_FILE_CHECKPOINT = True


- def load_weights_from_torchchat_format(stage_module, distribution, device, model_config):
+ def load_weights_from_torchchat_format(stage_state_dict, distribution, device, model_config):
    """
    Load the weights from torchchat format (single binary file), and fill into
    `stage_module`. Model config is needed b/c we permute wq and wk weights
    based on attn heads.
    """
-     stage_state_dict = stage_module.state_dict()
    # TODO: clean this up together with `purge_fqn_prefix` when we switch
    stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
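Both loaders now strip the `model.` prefix up front, on the state dict and (for the HF path) on the weight map. `purge_fqn_prefix` itself is not shown in this diff; a minimal sketch of what it presumably does, hence the TODO above about cleaning it up later:

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def purge_fqn_prefix_sketch(mapping: Dict[str, V], prefix: str) -> Dict[str, V]:
    """Strip a leading prefix (e.g. "model.") from every key that carries it,
    leaving other keys untouched. Works for both a state dict (str -> Tensor)
    and a weight map (str -> filename)."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in mapping.items()
    }
```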
@@ -437,6 +416,10 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
| "checkpoint_path": checkpoint_path, | ||
| } | ||
| builder_args = BuilderArgs(**args_dict) | ||
| logger.info( | ||
| "Loading checkpoint from: " | ||
| f"{builder_args.checkpoint_dir or builder_args.checkpoint_path}" | ||
| ) | ||
| # Then, load the checkpoint using torchchat util | ||
| checkpoint = _load_checkpoint(builder_args) | ||
@@ -450,6 +433,8 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
        updated_states=updated_states,
    )

-     # Fill state dict into stage module
-     stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
-     logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
+     check_for_missing_keys(
+         stage_state_dict, updated_states, ignore_cache_layers=True
+     )

+     return stage_state_dict
as in prev comments, pls switch to 'chkpt' for abbreviating checkpoint. It reads as 'chapter_from' when using 'chpt_from'.
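For overall context, the refactor changes both loaders from filling a module in place to taking and returning a plain state dict, leaving the `load_state_dict` step to the caller. A usage sketch of the new contract; the wrapper name, flag, and parameter types are assumptions rather than the repo's actual call site, and the two loaders are assumed to be imported from the module this diff touches:

```python
import torch
from torch.nn import Module


def load_stage_weights(
    stage_module: Module,
    distribution: str,
    device: torch.device,
    model_config: dict,
    use_hf_checkpoint: bool,
) -> None:
    """Hypothetical caller-side wrapper illustrating the new loader contract."""
    stage_state_dict = stage_module.state_dict()

    if use_hf_checkpoint:  # assumed flag selecting the checkpoint format
        stage_state_dict = load_weights_from_hf_format(
            stage_state_dict, distribution, device, model_config
        )
    else:
        stage_state_dict = load_weights_from_torchchat_format(
            stage_state_dict, distribution, device, model_config
        )

    # The loaders no longer fill the module themselves; the caller assigns.
    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
```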