2 files changed: +33 −0

test/distributed/pipelining

@@ -58,6 +58,7 @@
 class MockPipelineStage(_PipelineStageBase):
     def __init__(self, *args, **kwargs):
         # Mock the necessary attributes
+        self.submod = None
         self.num_stages = kwargs.get("num_stages", 1)
         self.group_size = kwargs.get("group_size", 1)
         self.group_rank = kwargs.get("group_rank", 0)
@@ -197,6 +198,28 @@ def test_schedule_with_single_stage(self, ScheduleClass):

         torch.distributed.destroy_process_group()

+    def test_zero_bubble_schedule_errors_with_compile(self):
+        """
+        Test that zero bubble schedules raise an error when used with torch.compile.
+        """
+        store = FakeStore()
+        torch.distributed.init_process_group(
+            backend="fake", rank=0, world_size=1, store=store
+        )
+        n_stages = 1
+        device = torch.device("cpu")
+        model = MultiMLP(8, n_layers=n_stages)
+        # Compile the full model before constructing the pipeline stage
+        compiled_model = torch.compile(model)
+        stage = PipelineStage(
+            compiled_model,
+            0,
+            n_stages,
+            device,
+        )
+        with self.assertRaises(RuntimeError):
+            ScheduleInterleavedZeroBubble([stage], 2)
+

 instantiate_parametrized_tests(ScheduleTest)

torch/distributed/pipelining

@@ -24,6 +24,7 @@

 import torch
 import torch.distributed as dist
+from torch._dynamo import OptimizedModule
 from torch.distributed.fsdp import FSDPModule, UnshardHandle
 from torch.profiler import record_function

@@ -2020,6 +2021,15 @@ def __init__(
         kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
         output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     ):
+        # TODO: Zero Bubble schedules are not yet supported with torch.compile,
+        # so reject compiled stage modules for now.
+        for stage in stages:
+            if isinstance(stage.submod, OptimizedModule):
+                raise RuntimeError(
+                    "The Zero Bubble schedule is not supported with "
+                    "stage modules that have been compiled with torch.compile"
+                )
+
         self.pp_group_size = stages[0].group_size
         super().__init__(
             stages=stages,
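
For reference, torch.compile wraps an nn.Module in torch._dynamo.OptimizedModule, which is the type the new isinstance check inspects on each stage's submod. A minimal sketch of that behavior, using a toy nn.Linear module that is not part of this change:

import torch
import torch.nn as nn
from torch._dynamo import OptimizedModule

# torch.compile returns an OptimizedModule wrapper around the original module;
# the guard added in the Zero Bubble schedule constructor rejects such wrappers.
eager = nn.Linear(8, 8)
compiled = torch.compile(eager)

print(isinstance(eager, OptimizedModule))     # False -> accepted by the schedule
print(isinstance(compiled, OptimizedModule))  # True  -> would raise RuntimeError in the schedule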