This repository was archived by the owner on Sep 10, 2025. It is now read-only.

Conversation

@kwen2501 (Contributor) commented on Oct 6, 2024:

torchchat checkpoints are already permuted before saving, so we don't need to permute them again when loading.

Also added support for `--chpt-from tc:<model_dir>`.
Example: `tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo` (`int8_wo` denotes a quantized checkpoint).

This is a dev feature for experimenting with composability between distributed and quantized tensors; it is not user-facing yet.
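To make the behavior concrete, here is a minimal sketch of format-dependent loading. The helper names (`permute_for_rope`, `load_qk_weight`) and the exact permutation are illustrative assumptions, not the actual torchchat code; the point is only that HF checkpoints get the q/k permutation on load while torchchat ("tc") checkpoints are used as saved.

```python
import torch

def permute_for_rope(w: torch.Tensor, n_heads: int, head_dim: int, dim: int) -> torch.Tensor:
    # Illustrative HF -> gpt-fast style permutation of a q/k projection weight
    # of shape (n_heads * head_dim, dim). torchchat checkpoints already have
    # this applied at save time.
    return (
        w.view(n_heads, 2, head_dim // 2, dim)
        .transpose(1, 2)
        .reshape(n_heads * head_dim, dim)
    )

def load_qk_weight(name: str, w: torch.Tensor, chkpt_format: str,
                   n_heads: int, head_dim: int, dim: int) -> torch.Tensor:
    # Only HF checkpoints still need the permutation; "tc" (torchchat)
    # checkpoints are loaded as-is.
    if chkpt_format == "hf" and ("wq" in name or "wk" in name):
        return permute_for_rope(w, n_heads, head_dim, dim)
    return w
```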

@pytorch-bot bot commented on Oct 6, 2024:

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1276
Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures
As of commit 3808576 with merge base 766bee9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 6, 2024.

```python
def load_weights_per_map(
    stage_module: Module,
    stage_state_dict,
```
A reviewer (Contributor) commented on this hunk:
typing is missing?
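For illustration, a fully annotated signature could look like the sketch below; the `Dict[str, torch.Tensor]` annotation is an assumption (the values might also be quantized tensor subclasses), not the typing actually adopted in the PR.

```python
from typing import Dict

import torch
from torch.nn import Module

def load_weights_per_map(
    stage_module: Module,
    stage_state_dict: Dict[str, torch.Tensor],  # assumed type; missing in the diff
) -> None:
    ...
```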

```python
            checkpoint_tensor, num_local_heads, head_dim, dim
        )
        if new_to_old_keymap is not None:
            if "wq" in param:
```
A reviewer (Contributor) commented:

so this was the issue! It seems like a flag for the type of checkpoint would be more robust, mapping checkpoint type to whether permuting is needed. It seems brittle if, for future checkpoints, we permute or not based solely on an FQN (generically, having 'wq' in a name doesn't inherently imply anything about the need to permute)?

@kwen2501 (Contributor, Author) replied:

I am using `if new_to_old_keymap is not None` as the flag for now. If we are using original HF checkpoints, then we need this `new_to_old_keymap`; if we use TC checkpoints, we don't.
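A minimal sketch of that flag-based dispatch, reusing the hypothetical `permute_for_rope` helper from the sketch above (names and signature are illustrative, not the exact torchchat code):

```python
def maybe_permute(param: str, checkpoint_tensor, new_to_old_keymap,
                  num_local_heads: int, head_dim: int, dim: int):
    # new_to_old_keymap is only constructed for original HF checkpoints,
    # so its presence doubles as the "this checkpoint needs permuting" flag.
    if new_to_old_keymap is not None and "wq" in param:
        # wk would be handled the same way, with its own (kv) head count.
        checkpoint_tensor = permute_for_rope(
            checkpoint_tensor, num_local_heads, head_dim, dim
        )
    # torchchat checkpoints (new_to_old_keymap is None) are already permuted.
    return checkpoint_tensor
```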

```python
        device (torch.device): The device to load the weights onto.
        model_config (ModelArgs): The model config.
        chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
```
A reviewer (Contributor) commented:
as in prev comments, pls switch to 'chkpt' for abbreviating checkpoint. It reads as 'chapter_from' when using 'chpt_from'.

@lessw2020 (Contributor) left a review:

Thanks for fixing the permute issue.
Looks good - left a comment about the robustness of relying solely on 'wq' in a param name as the determining factor for whether to permute or not... it seems better to map it to checkpoint type directly, for future robustness.
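One way to express that suggestion, as an illustrative sketch rather than code from the PR: keep an explicit table from checkpoint format to whether the q/k weights need permuting, instead of inferring it from key names or from the presence of a keymap.

```python
# Illustrative only: explicit mapping from checkpoint format to whether
# the q/k projection weights still need the RoPE permutation on load.
NEEDS_PERMUTE = {
    "hf": True,          # original Hugging Face checkpoints
    "torchchat": False,  # already permuted at save time
}

def should_permute(chkpt_format: str, param_name: str) -> bool:
    return NEEDS_PERMUTE[chkpt_format] and ("wq" in param_name or "wk" in param_name)
```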

@kwen2501 merged commit e950f5c into hf_bin on Oct 8, 2024.
52 checks passed