Enable non-dim-0 FSDP sharding of MoE experts when ep=1#2668
Open
aws-ritikadm wants to merge 2 commits intopytorch:mainfrom
Open
Enable non-dim-0 FSDP sharding of MoE experts when ep=1#2668aws-ritikadm wants to merge 2 commits intopytorch:mainfrom
aws-ritikadm wants to merge 2 commits intopytorch:mainfrom
Conversation
tianyu-l
approved these changes
Mar 23, 2026
Contributor
tianyu-l
left a comment
There was a problem hiding this comment.
thanks for the fix, please address nit comments
| experts_fsdp_config = fsdp_config.copy() | ||
| experts_fsdp_config["mesh"] = edp_mesh | ||
| assert edp_mesh is not None | ||
| fsdp_size = edp_mesh["efsdp"].size() * ep_degree |
Contributor
There was a problem hiding this comment.
Suggested change
| fsdp_size = edp_mesh["efsdp"].size() * ep_degree | |
| efsdp_ep_size = edp_mesh["efsdp"].size() * ep_degree |
| fsdp_size = edp_mesh["efsdp"].size() * ep_degree | ||
| else: | ||
| experts_fsdp_config = fsdp_config | ||
| fsdp_size = fsdp_config["mesh"].size() |
Contributor
There was a problem hiding this comment.
Suggested change
| fsdp_size = fsdp_config["mesh"].size() | |
| efsdp_ep_size = fsdp_config["mesh"].size() |
| edp_mesh["efsdp"].size() * ep_degree | ||
| > transformer_block.moe.experts.num_experts | ||
| ): | ||
| if fsdp_size > transformer_block.moe.experts.num_experts: |
Contributor
There was a problem hiding this comment.
Suggested change
| if fsdp_size > transformer_block.moe.experts.num_experts: | |
| if efsdp_ep_size > transformer_block.moe.experts.num_experts: |
| # inefficiency due to padding, so we shard on dim-1 (hidden_dim) instead. | ||
| if transformer_block.moe_enabled: | ||
| if ep_degree > 1: | ||
| experts_fsdp_config = fsdp_config.copy() |
Contributor
There was a problem hiding this comment.
can also change this to efsdp_config to be concise and consistent, but up to you.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Previously, routed experts in MoE layers were only separately wrapped with
fully_shardwhenep_degree > 1. Whenep_degree == 1, experts were sharded only as part of the outer TransformerBlock FSDP group, which meant theShard(1)placement optimization (sharding on hidden_dim instead of num_experts) was never applied.This PR extends the separate expert FSDP wrapping to also apply when
ep_degree == 1. When the FSDP degree exceedsnum_experts, experts are sharded on dim 1 (hidden_dim) to avoid padding inefficiency from dim-0 sharding — the same optimization that was already in place forep > 1.Validation
The three tests: