
Commit 9631d1a

H-Huang authored and pytorchmergebot committed
[pipelining] throw error with ZB and compile (pytorch#143599)
Zero bubble will SIGSEGV when operating on a `torch.compile`'d model, so raising this error while I am still investigating the cause / design for a fix.
Pull Request resolved: pytorch#143599
Approved by: https://github.com/wconstab
1 parent 3797143 commit 9631d1a
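
For illustration, a minimal sketch of how the new guard surfaces to a user, adapted from the test added in this commit; the single-rank "fake" process group and the toy Sequential module are stand-ins for a real pipeline setup, not part of the original change:

import torch
from torch.distributed.pipelining import PipelineStage, ScheduleInterleavedZeroBubble
from torch.testing._internal.distributed.fake_pg import FakeStore

# A single-rank fake process group is enough to construct a stage locally.
torch.distributed.init_process_group(
    backend="fake", rank=0, world_size=1, store=FakeStore()
)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())  # stand-in stage module
compiled_model = torch.compile(model)  # wraps it in a torch._dynamo.OptimizedModule

stage = PipelineStage(compiled_model, 0, 1, torch.device("cpu"))

# With this commit, the schedule constructor raises a RuntimeError up front
# instead of hitting a SIGSEGV later during execution.
try:
    ScheduleInterleavedZeroBubble([stage], 2)
except RuntimeError as err:
    print(err)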

File tree

test/distributed/pipelining/test_schedule.py
torch/distributed/pipelining/schedules.py

2 files changed: +33 -0 lines changed


test/distributed/pipelining/test_schedule.py

Lines changed: 23 additions & 0 deletions
@@ -58,6 +58,7 @@
 class MockPipelineStage(_PipelineStageBase):
     def __init__(self, *args, **kwargs):
         # Mock the necessary attributes
+        self.submod = None
         self.num_stages = kwargs.get("num_stages", 1)
         self.group_size = kwargs.get("group_size", 1)
         self.group_rank = kwargs.get("group_rank", 0)
@@ -197,6 +198,28 @@ def test_schedule_with_single_stage(self, ScheduleClass):
 
         torch.distributed.destroy_process_group()
 
+    def test_zero_bubble_schedule_errors_with_compile(self):
+        """
+        Test that zero bubble schedules raise an error when used with torch.compile.
+        """
+        store = FakeStore()
+        torch.distributed.init_process_group(
+            backend="fake", rank=0, world_size=1, store=store
+        )
+        n_stages = 1
+        device = torch.device("cpu")
+        model = MultiMLP(8, n_layers=n_stages)
+        # full_mod
+        compiled_model = torch.compile(model)
+        stage = PipelineStage(
+            compiled_model,
+            0,
+            n_stages,
+            device,
+        )
+        with self.assertRaises(RuntimeError):
+            ScheduleInterleavedZeroBubble([stage], 2)
+
 
 instantiate_parametrized_tests(ScheduleTest)

torch/distributed/pipelining/schedules.py

Lines changed: 10 additions & 0 deletions
@@ -24,6 +24,7 @@
 
 import torch
 import torch.distributed as dist
+from torch._dynamo import OptimizedModule
 from torch.distributed.fsdp import FSDPModule, UnshardHandle
 from torch.profiler import record_function
 
@@ -2020,6 +2021,15 @@ def __init__(
         kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None,
         output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
     ):
+        # TODO: we don't support Zero Bubble with torch.compile so we
+        # should disable it for now
+        for stage in stages:
+            if isinstance(stage.submod, OptimizedModule):
+                raise RuntimeError(
+                    "The Zero Bubble schedule is not supported with \
+                    stage modules that have used torch.compile"
+                )
+
         self.pp_group_size = stages[0].group_size
         super().__init__(
             stages=stages,
