1 change: 1 addition & 0 deletions torchtitan/models/deepseek_v3/README.md
@@ -61,6 +61,7 @@ python scripts/checkpoint_conversion/convert_from_hf.py <hf_checkpoints_dir> <dc
Some limitations:
1. It can't be used to convert HF checkpoints on the fly using GPU DTensor, because sharding and quantized blocks may not align well, causing silent numerical incorrectness.
2. It can't be used for weight sync to generate a state dict of bf16 because fake quantization to fp8 is applied.
3. When converting GroupedExperts weights to HF's separate per-expert weights on the fly, `torch.split()` causes huge GPU memory usage. This is because torchtitan's GroupedExperts weight has shape `(num_experts, dim1, dim2)` and FSDP shards it on dim-0 by default. When `to_hf()` calls `torch.split()` on dim-0, this incurs an all-gather and materializes replicated expert memory, as sketched below.
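Below is a minimal sketch (not from torchtitan) of the problem described above; it assumes a one-dimensional mesh, illustrative sizes, and a launch via torchrun with one GPU per rank:

```python
# Minimal sketch of splitting a dim-0-sharded GroupedExperts weight.
# Assumes torchrun with one GPU per rank; sizes are illustrative only.
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

num_experts, dim1, dim2 = 8, 2048, 1408  # illustrative, not DeepSeek-V3's sizes
full = torch.randn(num_experts, dim1, dim2, device="cuda")

# FSDP-style default placement: shard the stacked expert weight on dim 0.
grouped = distribute_tensor(full, mesh, placements=[Shard(0)])

# Splitting along the sharded dim cannot be done locally, so each rank ends up
# materializing the full expert stack, i.e. an all-gather of `grouped`.
per_expert = torch.split(grouped, 1, dim=0)

dist.destroy_process_group()
```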
Contributor
I thought more about this. Even if FSDP shards on dim-1, EP will shard on dim-0 anyway. So the problem still exists. Let's discuss next week.

Contributor
Can we perform a redistribute() before split() to ensure the expert parameter is sharded on dim-1? This redistributed, dim-1 sharded parameter will be used exclusively by the split().
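For illustration, a rough sketch of this suggestion (not from the PR; it assumes the public DTensor `redistribute` API and a `grouped` DTensor sharded on dim 0, as in the sketch further up):

```python
from torch.distributed.tensor import Shard

# Reshard the stacked expert weight from Shard(0) to Shard(1) first...
resharded = grouped.redistribute(placements=[Shard(1)])  # one collective here

# ...so that splitting on dim 0 becomes purely local slicing, with no all-gather.
per_expert = torch.split(resharded, 1, dim=0)
```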

Contributor
With EP it's sharded on dim-0 anyway. Performing this redistribute means at least one comm in to_hf and at least one comm in from_hf.
If both EP and FSDP dim-0 sharding are used, we'll have strided sharding, whose redistribute algorithm today may not be efficient or even correct.

Contributor @fegin (Aug 18, 2025)
The redistribution algorithm should be correct, but whether it is going to be efficient is debatable. I think it will be more efficient than an all-gather, as less communication is incurred, even if it is not optimal.

There should be no extra comm in from_hf, as DCP.load will handle the resharding, but that resharding can be slow for sure.
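As a side note on that last point, a minimal sketch of the usual DCP load flow, where the in-memory sharded state dict drives the resharding (the `model` object and checkpoint path are placeholders, not torchtitan code):

```python
import torch.distributed.checkpoint as dcp

# The in-memory state dict holds DTensors, so it already describes the target
# sharding; dcp.load reads the checkpoint and reshards into that layout.
state_dict = {"model": model.state_dict()}  # `model` is a placeholder
dcp.load(state_dict, checkpoint_id="/path/to/dcp_checkpoint")  # placeholder path
model.load_state_dict(state_dict["model"])
```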


## To be added
- Parallelism
3 changes: 3 additions & 0 deletions torchtitan/models/deepseek_v3/model/state_dict_adapter.py
@@ -158,6 +158,9 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
new_key = new_abstract_key.format(layer_num, expert_num)
hf_state_dict[new_key] = split_values[expert_num].squeeze()

# Remove the GroupedExperts' weight from the state_dict to free memory
del value
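For context, a rough reconstruction of the loop these fragments come from; only the names visible in the diff are real, and the surrounding structure is an assumption rather than the actual adapter code:

```python
# Hypothetical reconstruction for illustration only; not the real adapter code.
# `value` is the stacked GroupedExperts weight of shape (num_experts, dim1, dim2).
split_values = torch.split(value, 1, dim=0)  # one (1, dim1, dim2) slice per expert
for expert_num in range(num_experts):
    new_key = new_abstract_key.format(layer_num, expert_num)
    hf_state_dict[new_key] = split_values[expert_num].squeeze()

# Remove the GroupedExperts' weight from the state_dict to free memory
# (the thread below discusses when dropping it is safe).
del value
```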
Contributor
I think for loading a checkpoint synchronously, this sounds fine.
But for saving, after calling to_hf we may still need the original weights for the next training steps.

Contributor Author
I see, that's a valid concern. If a user periodically saves a checkpoint in HF format, this would be an issue. I checked checkpoint.py, and it only supports last_save_in_hf in _save_last_step; we don't support saving in HF format in between.

Contributor
The adapter is independent of checkpoint.py in torchtitan. In RL weight sync, it will be called without checkpointing.


elif "layers" in key:
abstract_key = re.sub(r"(\d+)", "{}", key, count=1)
layer_num = re.search(r"\d+", key).group(0)