[WIP][DSV3] Do not keep a copy of the GroupedExperts weight; free memory in StateDictAdapter #1585
```diff
@@ -158,6 +158,9 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
                     new_key = new_abstract_key.format(layer_num, expert_num)
                     hf_state_dict[new_key] = split_values[expert_num].squeeze()

+                # Remove the GroupedExperts' weight from the state_dict to free memory
+                del value
+
             elif "layers" in key:
                 abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
                 layer_num = re.search(r"\d+", key).group(0)
```

I think for loading the checkpoint synchronously, this sounds fine.

I see, that's a valid concern. If a user periodically saves a checkpoint in HF format, this would be an issue. I checked

The adapter is independent of checkpoint.py in torchtitan. In RL weight sync, it will be called without checkpointing.
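To illustrate the last point, here is a minimal, hypothetical sketch of calling the adapter directly during RL weight sync, outside of torchtitan's checkpoint.py. Only the `to_hf(state_dict)` method comes from the diff above; the `adapter`, `model`, and `send_fn` names and the surrounding function are assumptions for illustration.

```python
from typing import Any


def sync_weights_to_hf(model, adapter, send_fn) -> None:
    """Hypothetical RL weight-sync path that uses the adapter without checkpointing."""
    # Convert the native GroupedExperts layout into per-expert HF tensors.
    hf_state_dict: dict[str, Any] = adapter.to_hf(model.state_dict())
    # Hand each converted tensor to the transport used for weight sync.
    for name, tensor in hf_state_dict.items():
        send_fn(name, tensor)
```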
I thought more about this. Even if FSDP shards on dim-1, EP will shard on dim-0 anyway. So the problem still exists. Let's discuss next week.
Can we perform a redistribute() before split() to ensure the expert parameter is sharded on dim-1? This redistributed, dim-1 sharded parameter would be used exclusively by the split().
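A minimal sketch of this suggestion, assuming the expert weight is a DTensor on a 1-D device mesh and that redistribute/split behave as in current torch.distributed.tensor; the function name and shapes are illustrative, not the PR's code.

```python
import torch
from torch.distributed.tensor import DTensor, Shard


def split_experts_dim1(weight: DTensor) -> list[torch.Tensor]:
    # Reshard the stacked (num_experts, dim_in, dim_out) weight onto dim-1 so
    # the subsequent split along the expert dim does not force a full gather.
    weight = weight.redistribute(weight.device_mesh, [Shard(1)])
    # split(1, dim=0) yields one (1, dim_in_shard, dim_out) chunk per expert.
    return list(weight.split(1, dim=0))
```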
With EP it's sharded on dim-0 anyway. Performing this redistribute means at least 1 comm in to_hf and at least 1 comm in from_hf.
If both EP and FSDP dim-0 sharding is used, we'll have strided sharding whose redistribute algo today may not be efficient or even correct.
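Since the redistribute only pays off when the weight is not already laid out the desired way, a small placement check could gate it. This is a sketch under that assumption; the helper name is hypothetical.

```python
from torch.distributed.tensor import DTensor, Shard


def needs_redistribute(weight) -> bool:
    # True when the expert weight is a DTensor sharded on dim-0 (the EP case),
    # i.e. when splitting per expert would otherwise require communication.
    if not isinstance(weight, DTensor):
        return False
    return any(isinstance(p, Shard) and p.dim == 0 for p in weight.placements)
```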
The redistribution algorithm should be correct, but whether it is going to be efficient is debatable. I think it will be more efficient than an allgather, since less communication is incurred, even if it is not optimal. There should be no extra comm in from_hf, as DCP.load will handle the resharding, but that resharding can be slow for sure.
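For the from_hf path, a rough sketch of relying on DCP to handle resharding at load time, assuming the checkpoint was written with torch.distributed.checkpoint; the directory path and surrounding function are placeholders.

```python
import torch.distributed.checkpoint as dcp


def load_with_dcp_resharding(model, checkpoint_dir: str) -> None:
    # dcp.load reshards saved tensors onto the current DTensor placements of
    # the destination state dict, so the caller issues no explicit redistribute.
    state_dict = model.state_dict()
    dcp.load(state_dict=state_dict, checkpoint_id=checkpoint_dir)
    model.load_state_dict(state_dict)
```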