[Distributed] Fix correctness issue in TC load path #1276
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1276
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 3808576 with merge base 766bee9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
def load_weights_per_map(
    stage_module: Module,
    stage_state_dict,
typing is missing?
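For illustration, a minimal sketch of what the missing annotation could look like, assuming the state dict maps fully qualified parameter names to tensors. The value type and the trimmed-down signature are assumptions, not the PR's actual code.

```python
from typing import Dict

import torch
from torch.nn import Module


def load_weights_per_map(
    stage_module: Module,
    stage_state_dict: Dict[str, torch.Tensor],  # assumed value type
) -> None:
    """Load the mapped weights into stage_module (body omitted; the real function takes more parameters)."""
    ...
```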
    checkpoint_tensor, num_local_heads, head_dim, dim
)
if new_to_old_keymap is not None:
    if "wq" in param:
so this was the issue! It seems like a flag for the checkpoint type would be more robust, and then the type could be mapped to whether permuting is needed? It seems brittle if, for future checkpoints, we decide to permute or not solely based on an fqn (i.e. generically, having 'wq' in a name doesn't actually imply anything inherently about the need to permute)?
I am using if new_to_old_keymap is not None as the flag for now.
If we are using original HF checkpoints, then we need this new_to_old_keymap. Otherwise, if we use TC checkpoints, we don't.
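To make the discussion concrete, here is a minimal sketch of the decision being described: permute only when an HF checkpoint is being loaded (signalled by new_to_old_keymap being present), using the usual interleaved-rotary permutation. The helper name permute_for_rope, its exact reshape, and maybe_permute itself are assumptions for illustration, not the PR's actual code.

```python
import torch


def permute_for_rope(w: torch.Tensor, num_heads: int, head_dim: int, dim: int) -> torch.Tensor:
    # Assumed helper: convert an HF-layout attention projection to the
    # interleaved (torchchat/Meta-style) layout expected by the model.
    return (
        w.view(num_heads, 2, head_dim // 2, dim)
        .transpose(1, 2)
        .reshape(num_heads * head_dim, dim)
    )


def maybe_permute(
    param: str,
    checkpoint_tensor: torch.Tensor,
    new_to_old_keymap,
    num_local_heads: int,
    head_dim: int,
    dim: int,
) -> torch.Tensor:
    # HF checkpoints come with a key remap and need the permutation;
    # torchchat (TC) checkpoints were permuted at save time, so they are
    # returned unchanged.
    if new_to_old_keymap is not None and "wq" in param:
        return permute_for_rope(checkpoint_tensor, num_local_heads, head_dim, dim)
    return checkpoint_tensor
```

(In practice wk is typically handled the same way, with its own head count.)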
    device (torch.device): The device to load the weights onto.
    model_config (ModelArgs): The model config.
    chpt_from (str): The checkpoint format to load the weights from, e.g. "torchchat" or "hf".
    chpt_from (str): The checkpoint format to load the weights from.
as in prev comments, pls switch to 'chkpt' for abbreviating checkpoint. It reads as 'chapter_from' when using 'chpt_from'.
lessw2020 left a comment
thanks for fixing the permute issue.
looks good - left a comment about the robustness of relying solely on 'wq' in a param name as the determining factor for whether to permute... seems better to map it to the checkpoint type directly, for future robustness.
torchchat checkpoints are already permuted before saving, so we don't need to permute them again when loading.
Also added support for:
--chpt-from tc:<model_dir>

Example: tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo (int8_wo stands for a quantized checkpoint)

This is a dev feature for experimenting with composability between distributed and quantized tensors; it is not user-facing yet.
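A rough sketch of how the tc:<model_dir> prefix could be split out of the --chpt-from value; the function name and return convention are assumptions, not the PR's actual code.

```python
from typing import Optional, Tuple


def parse_chpt_from(value: str) -> Tuple[str, Optional[str]]:
    # "tc:meta-llama/Meta-Llama-3-8B-Instruct-int8_wo" -> ("tc", "meta-llama/Meta-Llama-3-8B-Instruct-int8_wo")
    # "hf"                                             -> ("hf", None)
    if value.startswith("tc:"):
        return "tc", value[len("tc:"):]
    return value, None
```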