
[FSDP2] enable per-param mesh FSDP2 for MoE#2281

Open
weifengpy wants to merge 2 commits into pytorch:main from weifengpy:per-param-mesh

Conversation

@weifengpy
Contributor

@weifengpy weifengpy commented Jan 28, 2026

command: NGPU=8 MODULE=deepseek_v3 CONFIG=deepseek_v3_16b ./run_train.sh --training.steps 20 --parallelism.expert-parallel-degree 4

FSDP2 now supports per-param mesh: pytorch/pytorch#173509

This PR applies fully_shard on each transformer_block, sharding expert params on edp_mesh and all other params on dp_mesh. The FSDPModule schedules 2 all-gathers sequentially: the 1st for the transformer block's non-expert params, the 2nd for the experts.

def _shard_placement_fn(param: nn.Parameter) -> ShardPlacementResult:
    if param in expert_params:
        # Expert parameters: use Shard(1) on edp_mesh
        return ShardPlacementResult(
            placement=Shard(1), mesh_info=edp_mesh_info
        )
    else:
        # Non-expert parameters: use Shard(0) on dp_mesh
        return ShardPlacementResult(
            placement=Shard(0), mesh_info=dp_mesh_info
        )
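For readers without the pytorch PR applied, the routing decision above can be modeled without torch. Below is a minimal, dependency-free sketch: the `Shard` / `ShardPlacementResult` classes are simplified stand-ins for the real types in pytorch/pytorch#173509 (which carry a `Placement` and an FSDP mesh-info object, not strings), and the param names and mesh labels are made-up examples.

```python
from dataclasses import dataclass

# Simplified stand-ins for the real FSDP2 types from pytorch/pytorch#173509;
# the actual classes live under torch.distributed.fsdp.
@dataclass(frozen=True)
class Shard:
    dim: int

@dataclass(frozen=True)
class ShardPlacementResult:
    placement: Shard
    mesh_info: str  # the real code carries a mesh-info object, not a string

def make_shard_placement_fn(expert_param_names: set[str]):
    """Return a per-param routing fn: expert params -> Shard(1) on edp_mesh,
    everything else -> Shard(0) on dp_mesh."""
    def _shard_placement_fn(param_name: str) -> ShardPlacementResult:
        if param_name in expert_param_names:
            return ShardPlacementResult(placement=Shard(1), mesh_info="edp_mesh")
        return ShardPlacementResult(placement=Shard(0), mesh_info="dp_mesh")
    return _shard_placement_fn

fn = make_shard_placement_fn({"moe.experts.w1", "moe.experts.w2"})
print(fn("moe.experts.w1"))  # experts route to Shard(1) on edp_mesh
print(fn("attention.wq"))    # everything else to Shard(0) on dp_mesh
```

In the real API the callback receives an `nn.Parameter` (hence the `param in expert_params` membership test in the PR body) rather than a name string; the string keys here are only to keep the sketch self-contained.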

This makes it possible to apply torch.compile on each transformer_block. I didn't enable per-block compile yet because there is still a gap in torch.compile + AC + MoE: #2341

AG order in forward is exactly the same before and after this change.

AG order in backward is different, but better:

Explicit Backward AllGather Order
  layers.7       @ 118.83ms   (attention/ffn params)
  layers.6       @ 121.52ms   (attention/ffn params)
  layers.6.moe   @ 122.04ms   (MoE expert params)
  layers.7.moe   @ 125.81ms   (MoE expert params)  ← delayed!

Per-param Backward AllGather Order
  layers.7       @ 114.30ms   (first FSDP unit)
  layers.7       @ 115.14ms   (second FSDP unit, includes MoE)
  layers.6       @ 117.42ms   (first FSDP unit)
  layers.6       @ 117.89ms   (second FSDP unit, includes MoE)

Numerics remain bitwise identical with and without this change:

Loss Comparison
  ┌──────┬───────────────┬───────────────┬───────┐
  │ Step │ Old (0d93c63) │ New (e1c47c8) │ Match │
  ├──────┼───────────────┼───────────────┼───────┤
  │ 1    │ 8.01151657    │ 8.01151657    │ ✓     │
  │ 5    │ 3.85572004    │ 3.85572004    │ ✓     │
  │ 10   │ 3.15517211    │ 3.15517211    │ ✓     │
  │ 15   │ 3.07873583    │ 3.07873583    │ ✓     │
  │ 20   │ 2.92206621    │ 2.92206621    │ ✓     │
  │ 25   │ 2.89102936    │ 2.89102936    │ ✓     │
  │ 30   │ 2.81201696    │ 2.81201696    │ ✓     │
  │ 35   │ 2.84123349    │ 2.84123349    │ ✓     │
  │ 40   │ 2.76206398    │ 2.76206398    │ ✓     │
  │ 45   │ 2.82969308    │ 2.82969308    │ ✓     │
  │ 50   │ 2.77560568    │ 2.77560568    │ ✓     │
  │ 55   │ 2.75578761    │ 2.75578761    │ ✓     │
  │ 60   │ 2.75143075    │ 2.75143075    │ ✓     │
  │ 65   │ 2.74203372    │ 2.74203372    │ ✓     │
  │ 70   │ 2.71638918    │ 2.71638918    │ ✓     │
  │ 75   │ 2.74999237    │ 2.74999237    │ ✓     │
  │ 80   │ 2.75584078    │ 2.75584078    │ ✓     │
  │ 85   │ 2.74837303    │ 2.74837303    │ ✓     │
  │ 90   │ 2.72101045    │ 2.72101045    │ ✓     │
  │ 95   │ 2.73645735    │ 2.73645735    │ ✓     │
  │ 100  │ 2.70604038    │ 2.70604038    │ ✓     │
  └──────┴───────────────┴───────────────┴───────┘
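The bitwise comparison above can be reproduced mechanically by diffing the loss columns of two training logs as strings. A small sketch, assuming each run emits lines like `step: 10  loss: 3.15517211` (this log format is an assumption for illustration, not torchtitan's exact output):

```python
import re

# Assumed log line shape: "step: <n>  loss: <value>"
STEP_LOSS = re.compile(r"step:\s*(\d+)\s+loss:\s*([\d.]+)")

def parse_losses(log_text: str) -> dict[int, str]:
    """Map step -> loss string; compare strings rather than floats so the
    check is bitwise, not approximate."""
    return {int(step): loss for step, loss in STEP_LOSS.findall(log_text)}

def losses_match(old_log: str, new_log: str) -> bool:
    """True iff every step present in both logs has an identical loss string."""
    old, new = parse_losses(old_log), parse_losses(new_log)
    common = old.keys() & new.keys()
    return bool(common) and all(old[s] == new[s] for s in common)

old = "step: 1  loss: 8.01151657\nstep: 5  loss: 3.85572004"
new = "step: 1  loss: 8.01151657\nstep: 5  loss: 3.85572004"
print(losses_match(old, new))  # True
```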

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jan 28, 2026
@weifengpy weifengpy marked this pull request as draft January 28, 2026 20:44
@weifengpy weifengpy changed the title [FSDP2] enable per-param mesh FSDP2 for MoE and per-layer compile [WIP][FSDP2] enable per-param mesh FSDP2 for MoE and per-layer compile Jan 28, 2026
@weifengpy weifengpy force-pushed the per-param-mesh branch 6 times, most recently from 3c36e53 to 3bf7e27 Compare February 7, 2026 01:47
@weifengpy weifengpy changed the title [WIP][FSDP2] enable per-param mesh FSDP2 for MoE and per-layer compile [FSDP2] enable per-param mesh FSDP2 for MoE and per-layer compile Feb 7, 2026
@weifengpy weifengpy changed the title [FSDP2] enable per-param mesh FSDP2 for MoE and per-layer compile [FSDP2] enable per-param mesh FSDP2 for MoE Feb 7, 2026
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 9, 2026
This PR applies fully_shard on each transformer_block, sharding expert params on edp_mesh and all other params on dp_mesh. The FSDPModule schedules 2 all-gathers sequentially: the 1st for the transformer block's non-expert params, the 2nd for the experts.

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing FSDP2 callsites won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 10, 2026
This PR applies fully_shard on each transformer_block, sharding expert params on edp_mesh and all other params on dp_mesh. The FSDPModule schedules 2 all-gathers sequentially: the 1st for the transformer block's non-expert params, the 2nd for the experts.

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing FSDP2 callsites won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward compatibility:
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: no usages of _fsdp_param_group (singular). Safe.
* torchao: no usages of _fsdp_param_group (singular). Safe.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
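The "no usages" claims above are the kind of check a grep can do; one subtlety is rejecting the plural `_fsdp_param_groups` while matching the singular. A sketch of one way to scan a source tree (the scanner function and its use of `*.py` globbing are illustrative assumptions, not a tool from this PR):

```python
import re
from pathlib import Path

# \b after "group" rejects the plural "_fsdp_param_groups".
PATTERN = re.compile(r"_fsdp_param_group\b")

def find_usages(root: str) -> list[tuple[str, int, str]]:
    """Return (file, line_no, line) for each use of the singular attribute
    in Python files under root."""
    hits = []
    for path in Path(root).rglob("*.py"):
        for i, line in enumerate(path.read_text(errors="ignore").splitlines(), 1):
            if PATTERN.search(line):
                hits.append((str(path), i, line.strip()))
    return hits

# The word boundary does the filtering: singular matches, plural does not.
print(bool(PATTERN.search("group = module._fsdp_param_group")))      # True
print(bool(PATTERN.search("for g in state._fsdp_param_groups:")))   # False
```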
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 10, 2026
this PR applies fully_shard on transformer_block, sharding experts on edp_mesh, and other params on dp_mesh. FSDPModule schedule 2 all-gather sequentially: 1st on transformer blocks, 2nd on experts

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing fsdp2 callsite won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward-compatiblibility
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                              
* torchao: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                                 




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 10, 2026
this PR applies fully_shard on transformer_block, sharding experts on edp_mesh, and other params on dp_mesh. FSDPModule schedule 2 all-gather sequentially: 1st on transformer blocks, 2nd on experts

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing fsdp2 callsite won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward-compatiblibility
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                              
* torchao: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                                 




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 10, 2026
this PR applies fully_shard on transformer_block, sharding experts on edp_mesh, and other params on dp_mesh. FSDPModule schedule 2 all-gather sequentially: 1st on transformer blocks, 2nd on experts

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing fsdp2 callsite won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward-compatiblibility
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                              
* torchao: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                                 




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 10, 2026
this PR applies fully_shard on transformer_block, sharding experts on edp_mesh, and other params on dp_mesh. FSDPModule schedule 2 all-gather sequentially: 1st on transformer blocks, 2nd on experts

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing fsdp2 callsite won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward-compatiblibility
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                              
* torchao: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                                 




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 10, 2026
this PR applies fully_shard on transformer_block, sharding experts on edp_mesh, and other params on dp_mesh. FSDPModule schedule 2 all-gather sequentially: 1st on transformer blocks, 2nd on experts

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing fsdp2 callsite won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward-compatiblibility
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                              
* torchao: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                                 




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 10, 2026
this PR applies fully_shard on transformer_block, sharding experts on edp_mesh, and other params on dp_mesh. FSDPModule schedule 2 all-gather sequentially: 1st on transformer blocks, 2nd on experts

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing fsdp2 callsite won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward-compatiblibility
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                              
* torchao: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                                 




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 10, 2026
this PR applies fully_shard on transformer_block, sharding experts on edp_mesh, and other params on dp_mesh. FSDPModule schedule 2 all-gather sequentially: 1st on transformer blocks, 2nd on experts

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing fsdp2 callsite won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward-compatiblibility
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                              
* torchao: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                                 




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 10, 2026
this PR applies fully_shard on transformer_block, sharding experts on edp_mesh, and other params on dp_mesh. FSDPModule schedule 2 all-gather sequentially: 1st on transformer blocks, 2nd on experts

see torchtitan for AG/RS schedules and numeric experiments: pytorch/torchtitan#2281

existing fsdp2 callsite won't be affected because _shard_placement_fn -> ShardPlacementResult is a new code path

checked backward-compatiblibility
* pytorch: fsdp2_mem_tracker.py is affected, but only if people use it with per-param mesh. I don't think it's a hard blocker
* torchtitan: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                              
* torchao: No usages of _fsdp_param_group (singular). Safe.                                                                                                                                                                                 




cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx kadeng chauhang amjames Lucaskabela jataylo

[ghstack-poisoned]
weifengpy added a commit to pytorch/pytorch that referenced this pull request Feb 10, 2026
@weifengpy
Copy link
Contributor Author

those CI errors are because we need to land pytorch/pytorch#173509 first

Copy link
Contributor

@tianyu-l tianyu-l left a comment


SGTM, is there a plan for solving #2341?

Also, not sure if it's blocked by the issue, but we should modify apply_compile to remove the previous fine-grained compilation workaround code.

@xmfan
Copy link
Member

xmfan commented Feb 10, 2026

@tianyu-l We're still gonna need the fine-grained workarounds at least for mxfp8, due to #2250 (comment). Until that issue is fixed (dynamo bwd tracing of autograd functions), we will trace the wrong bwd graph.

@tianyu-l
Copy link
Contributor

@xmfan oh, I didn't know. I marked it as high priority for now.

@weifengpy
Copy link
Contributor Author

Also, not sure if it's blocked by the issue, but we should modify apply_compile to remove the previous workaround fine-grained compilation code.

I didn't change apply_compile because of #2341. #2250 (comment) is new to me

Copy link
Contributor

@tianyu-l tianyu-l left a comment


sgtm


Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

Status: Todo
