39 changes: 30 additions & 9 deletions dist_run.py
@@ -125,18 +125,39 @@ def _load_model_weights(
distribution (str): The distribution name, e.g. "meta-llama/Meta-Llama-3-8B-Instruct".
device (torch.device): The device to load the weights onto.
model_config (ModelArgs): The model config.
chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
chpt_from (str): The checkpoint format to load the weights from.
Contributor:
as in prev comments, pls switch to 'chkpt' for abbreviating checkpoint. It reads as 'chapter_from' when using 'chpt_from'.


Valid chpt_from values: "hf", "tc", or a format plus a specific
checkpoint to load, e.g.
"tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo"
"""
if chpt_from == "hf":
str_lst = chpt_from.split(":")
assert len(str_lst) == 1 or len(str_lst) == 2, "Invalid --chpt_from format"
# Get the checkpoint format, e.g. "hf" or "tc"
chpt_format = str_lst[0]
# If user also specified a checkpoint, such as
# `meta-llama/Meta-Llama-3-8B-Instruct-int8_wo`
chpt_distribution = str_lst[1] if len(str_lst) == 2 else distribution

stage_state_dict = stage_module.state_dict()
if chpt_format == "hf":
# This format stands for: index file + multiple binary files
load_weights_from_hf_format(stage_module, distribution, device, model_config)
elif chpt_from == "torchchat":
stage_state_dict = load_weights_from_hf_format(
stage_state_dict, chpt_distribution, device, model_config
)
elif chpt_format == "tc":
# This format stands for:
# single binary file, OR
# multiple binary files without index files.
load_weights_from_torchchat_format(stage_module, distribution, device, model_config)
stage_state_dict = load_weights_from_torchchat_format(
stage_state_dict, chpt_distribution, device, model_config
)
else:
raise ValueError(f"Unknown checkpoint format: {chpt_from}")
raise ValueError(f"Unknown checkpoint format: {chpt_format}")

# Fill state dict into stage module
stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
logger.info(f"Successfully loaded weights into stage module")


def _encode_strings(
@@ -589,9 +610,9 @@ def get_example_ins_outs(seqlen: int) -> Tuple[torch.Tensor, torch.Tensor]:
parser.add_argument(
"--chpt-from",
type=str,
default="hf", # TODO: change to torchchat once we support it well
help="Checkpoint format to load from",
choices=["hf", "torchchat"],
default="tc",
help="Checkpoint to load from, e.g. `hf` or `tc`, or "
"`tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo`",
)
args = parser.parse_args()

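For reference, the new `--chpt-from` flag accepts either a bare format ("hf" or "tc") or a format plus a specific checkpoint id separated by a colon. Below is a minimal sketch of that parsing convention, mirroring the dist_run.py change above; the helper name `parse_chpt_from` and the example ids are illustrative only.

```python
def parse_chpt_from(chpt_from: str, default_distribution: str):
    """Split a --chpt-from value into (format, distribution).

    "hf"      -> ("hf", default_distribution)
    "tc"      -> ("tc", default_distribution)
    "tc:<id>" -> ("tc", "<id>")
    """
    parts = chpt_from.split(":")
    assert len(parts) in (1, 2), "Invalid --chpt_from format"
    chpt_format = parts[0]
    chpt_distribution = parts[1] if len(parts) == 2 else default_distribution
    return chpt_format, chpt_distribution


# Example: load a quantized torchchat checkpoint instead of the default one.
fmt, dist = parse_chpt_from(
    "tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo",
    "meta-llama/Meta-Llama-3-8B-Instruct",
)
assert fmt == "tc" and dist.endswith("int8_wo")
```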
95 changes: 40 additions & 55 deletions torchchat/distributed/checkpoint_utils.py
@@ -118,14 +118,12 @@ def remap_weight_keys(dictionary):


def load_weights_per_map(
stage_module: Module,
stage_state_dict,
Contributor:
typing is missing?

weight_map: Dict[str, str],
file_location: str,
new_to_old_keymap: Dict[str, str],
device: torch.device,
is_safetensor: bool,
purge_model_prefix: bool = True,
ignore_cache_layers: bool = True,
model_config: Optional[Dict] = None,
) -> Tuple[int, int]:
"""
@@ -138,18 +136,11 @@ def load_weights_per_map(
new_to_old_keymap (Dict[str, str]): Mapping of new parameter names to old ones.
device (torch.device): The device to load tensors onto.
is_safetensor (bool): Whether the files are safetensors.
purge_model_prefix (bool): Whether to remove 'model.' prefix from keys.
ignore_cache_layers (bool): Whether to ignore cache layers when reporting missing keys.
model_config (Optional[Dict]): Model configuration.

Returns:
Tuple[int, int]: Number of updated weights and number of missing weights.
"""
stage_state_dict = stage_module.state_dict()
if purge_model_prefix:
stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
weight_map = purge_fqn_prefix(weight_map, "model.")

needed_files = get_needed_files(stage_state_dict, weight_map)
updated_states: Set[str] = set()

@@ -175,27 +166,9 @@ def load_weights_per_map(
logger.error(f"Error during checkpoint processing:")
raise e

missing_keys = handle_missing_keys(
stage_state_dict, updated_states, ignore_cache_layers
check_for_missing_keys(
stage_state_dict, updated_states, ignore_cache_layers=True
)
# log_loading_status(missing_keys, updated_states)
if missing_keys:
logger.warning(
f"Partially updated state dict. Missing {len(missing_keys)} keys: {missing_keys}"
)
else:
logger.info("Fully updated state dict.")

logger.info(f"Loading {len(updated_states)} weights into stage dict")
# precount, premap = record_module_dtypes(stage_module)
stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
# postcount, postmap = record_module_dtypes(stage_module)
# logger.info(f"{precount=}, {postcount=}")
# logger.info(f"{premap=}, {postmap=}")

logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")

return len(updated_states), len(missing_keys)


# TODO: clean this up together with `purge_fqn_prefix` when we switch
@@ -287,14 +260,15 @@ def update_state_dict(
checkpoint_tensor = checkpoint[old_param]
model_tensor = state_dict[param]

if "wq" in param:
checkpoint_tensor = permute_weight_to_attn_heads(
checkpoint_tensor, num_heads, head_dim, dim
)
elif "wk" in param:
checkpoint_tensor = permute_weight_to_attn_heads(
checkpoint_tensor, num_local_heads, head_dim, dim
)
if new_to_old_keymap is not None:
if "wq" in param:
Contributor:
so this was the issue! A flag indicating the checkpoint type seems more robust, and we could then map type to whether permuting is needed. Deciding to permute solely based on an FQN seems brittle for future checkpoints: generically, having 'wq' in a name doesn't inherently imply anything about the need to permute.

Contributor (Author):
I am using if new_to_old_keymap is not None as the flag for now.
If we are using original HF checkpoints, then we will need this new_to_old_keymap. Otherwise, if we use TC checkpoints, we don't.

checkpoint_tensor = permute_weight_to_attn_heads(
checkpoint_tensor, num_heads, head_dim, dim
)
elif "wk" in param:
checkpoint_tensor = permute_weight_to_attn_heads(
checkpoint_tensor, num_local_heads, head_dim, dim
)

# Move checkpoint tensor to desired device
checkpoint_tensor = checkpoint_tensor.to(device)
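The permutation discussed in the comments above converts wq/wk weights from Hugging Face's rotary layout into the layout torchchat's attention expects, and (per the author's reply) it is applied only when `new_to_old_keymap` is present, i.e. when loading original HF checkpoints. The helper itself is not shown in this diff; the following is a plausible sketch of `permute_weight_to_attn_heads`, assuming the same convention as the usual HF-to-native Llama weight conversion, not the repository's actual implementation.

```python
import torch

def permute_weight_to_attn_heads(w: torch.Tensor, n_heads: int, head_dim: int, dim: int) -> torch.Tensor:
    # Regroup rows so each head's rotary (real, imag) halves sit next to
    # each other, matching the non-HF attention layout.
    return (
        w.view(n_heads, 2, head_dim // 2, dim)
        .transpose(1, 2)
        .reshape(n_heads * head_dim, dim)
    )

# wq uses the full head count; wk uses the (possibly smaller) KV-head count,
# mirroring the two branches above.
wq = torch.randn(32 * 128, 4096)
wq_permuted = permute_weight_to_attn_heads(wq, 32, 128, 4096)
```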
@@ -324,10 +298,10 @@ def clean_cache_keys(input_set: Set[str]) -> Set[str]:
}


def handle_missing_keys(
def check_for_missing_keys(
state_dict: Dict[str, torch.Tensor],
updated_states: Set[str],
ignore_cache_layers: bool,
ignore_cache_layers: bool = True,
) -> Set[str]:
"""This function handles 'expected' missing keys from the checkpoint update set.
This is used for ignoring cache, rope freqs, and mask layers that are generated, rather than persisted
@@ -342,7 +316,13 @@ def handle_missing_keys(
logger.info(
f"Ignoring {start_len - after_len} missing cache, freqs, mask layers"
)
return missing_keys

if len(missing_keys) > 0:
from itertools import islice
raise RuntimeError(
f"Missing {len(missing_keys)} weights, for example: "
f"{list(islice(missing_keys, 10))}"
)
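`clean_cache_keys` (only its signature appears in this diff) is what makes the `ignore_cache_layers` path work: it drops parameters that are generated at load time rather than persisted in the checkpoint. A rough sketch follows, under the assumption that such keys are identified by substring; the exact markers are a guess, not taken from this PR.

```python
from typing import Set

def clean_cache_keys(input_set: Set[str]) -> Set[str]:
    # Keys that are rebuilt at load time (KV caches, rotary freqs, masks)
    # are not expected to appear in a checkpoint, so drop them from the
    # missing-key report.
    generated_markers = ("kv_cache", "freqs_cis", "causal_mask")
    return {
        key
        for key in input_set
        if not any(marker in key for marker in generated_markers)
    }
```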


def log_loading_status(missing_keys: Set[str], updated_states: Set[str]):
@@ -355,10 +335,10 @@ def log_loading_status(missing_keys: Set[str], updated_states: Set[str]):
logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")


def load_weights_from_hf_format(stage_module, distribution, device, model_config):
def load_weights_from_hf_format(stage_state_dict, distribution, device, model_config):
"""
Load the weights from Hugging Face format (index file + multiple safetensor
files), and fill into `stage_module`. Model config is needed b/c we permute
files), and fill into `stage_state_dict`. Model config is needed b/c we permute
wq and wk weights based on attn heads.
"""
# Get the weight map for a given HF model id
@@ -382,21 +362,21 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
weight_dir = os.path.dirname(index_file)
logger.info(f"Loading weights from: {weight_dir}")

# TODO: clean this up together with `purge_fqn_prefix` when we switch
stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")
weight_map = purge_fqn_prefix(weight_map, "model.")

# Load the weights into the stage module
num_loaded_weights, num_missing_weights = load_weights_per_map(
stage_module,
load_weights_per_map(
stage_state_dict,
weight_map,
weight_dir,
new_to_old_keymap,
device,
is_safetensor,
model_config=model_config,
)
logger.info(
f"Success - Loaded {num_loaded_weights} weights, {num_missing_weights} missing weights"
)
if num_missing_weights > 0:
raise ValueError(f"Missing {num_missing_weights} weights")
return stage_state_dict
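The "index file + multiple binary files" layout mentioned above is the standard Hugging Face sharded-checkpoint format: an index JSON (model.safetensors.index.json, or pytorch_model.bin.index.json for .bin shards) whose `weight_map` maps each parameter name to the shard file holding it. Here is a minimal sketch of reading that map; the helper name is illustrative and the repository's own lookup code may differ.

```python
import json
import os
from typing import Dict

def read_hf_weight_map(weight_dir: str) -> Dict[str, str]:
    # The index file looks like:
    # {"metadata": {...},
    #  "weight_map": {"model.layers.0.self_attn.q_proj.weight":
    #                 "model-00001-of-00004.safetensors", ...}}
    index_file = os.path.join(weight_dir, "model.safetensors.index.json")
    with open(index_file) as f:
        index = json.load(f)
    return index["weight_map"]
```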


# HACK: assuming single file for torchchat's converted checkpoints. We should
@@ -406,13 +386,12 @@ def load_weights_from_hf_format(stage_module, distribution, device, model_config
# will tell us if there is a single file or a directory.
TORCHCHCAT_SINGLE_FILE_CHECKPOINT = True

def load_weights_from_torchchat_format(stage_module, distribution, device, model_config):
def load_weights_from_torchchat_format(stage_state_dict, distribution, device, model_config):
"""
Load the weights from torchchat format (single binary file), and fill into
`stage_module`. Model config is needed b/c we permute wq and wk weights
based on attn heads.
"""
stage_state_dict = stage_module.state_dict()
# TODO: clean this up together with `purge_fqn_prefix` when we switch
stage_state_dict = purge_fqn_prefix(stage_state_dict, "model.")

@@ -437,6 +416,10 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
"checkpoint_path": checkpoint_path,
}
builder_args = BuilderArgs(**args_dict)
logger.info(
"Loading checkpoint from: "
f"{builder_args.checkpoint_dir or builder_args.checkpoint_path}"
)
# Then, load the checkpoint using torchchat util
checkpoint = _load_checkpoint(builder_args)

@@ -450,6 +433,8 @@ def load_weights_from_torchchat_format(stage_module, distribution, device, model
updated_states=updated_states,
)

# Fill state dict into stage module
stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
logger.info(f"Successfully loaded {len(updated_states)} weights into stage module")
check_for_missing_keys(
stage_state_dict, updated_states, ignore_cache_layers=True
)

return stage_state_dict
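With both loaders now filling and returning a plain state dict, the caller is responsible for assigning the result back into the stage module. The condensed sketch below restates the new dist_run.py flow from the first file in this PR; the wrapper function and its argument handling are illustrative, not part of the diff.

```python
import torch
from torch.nn import Module
from torchchat.distributed.checkpoint_utils import (
    load_weights_from_hf_format,
    load_weights_from_torchchat_format,
)

def fill_stage_weights(
    stage_module: Module,
    chpt_format: str,
    chpt_distribution: str,
    device: torch.device,
    model_config,
) -> None:
    """Load checkpoint weights into `stage_module` via a plain state dict."""
    stage_state_dict = stage_module.state_dict()

    if chpt_format == "hf":
        stage_state_dict = load_weights_from_hf_format(
            stage_state_dict, chpt_distribution, device, model_config
        )
    elif chpt_format == "tc":
        stage_state_dict = load_weights_from_torchchat_format(
            stage_state_dict, chpt_distribution, device, model_config
        )
    else:
        raise ValueError(f"Unknown checkpoint format: {chpt_format}")

    # assign=True adopts the loaded tensors directly, which is typically
    # needed when the stage module was initialized on the meta device.
    stage_module.load_state_dict(stage_state_dict, strict=False, assign=True)
```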