
Commit f4969c8

xmfan authored and pytorchmergebot committed
fix torch.compile + ddp + non-reentrant AC pack hook firing count (pytorch#144271)
Fixes pytorch#144035. In order to preserve hook firing semantics, we disabled pack/unpack hooks for torch.compile (pytorch#123196). In DDP under torch.compile, there is one more callsite for which we need to disable hooks.

Pull Request resolved: pytorch#144271
Approved by: https://github.com/bdhirsh, https://github.com/soulitzer
1 parent 861b65f commit f4969c8
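
Background on the semantics the fix preserves: non-reentrant activation checkpointing is built on autograd's saved-tensors pack/unpack hooks, and user code (as in pytorch#144035) may rely on how many times they fire. Below is a minimal, illustrative sketch of those hooks using the public torch.autograd.graph.saved_tensors_hooks API; it is not code from this commit.

import torch

# Count how many times autograd's pack hook fires during a forward pass.
pack_calls = 0

def pack_hook(t):
    global pack_calls
    pack_calls += 1
    return t  # stash the saved tensor unchanged

def unpack_hook(t):
    return t  # hand the stashed tensor back to backward

x = torch.randn(8, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = torch.relu(x).sum()
y.backward()
print(pack_calls)  # relu saves one tensor for backward, so this prints 1

Roughly speaking, checkpoint(..., use_reentrant=False) drives its recompute logic through these same hooks, so an extra compile-time tracing pass that also fires them can skew counts that user code keeps; pytorch#123196 disabled the hooks for the main torch.compile path, and this commit does the same for the DDP optimizer's tracing pass.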

File tree

2 files changed: +81 −2


test/dynamo/test_repros.py

Lines changed: 79 additions & 1 deletion
@@ -35,14 +35,18 @@
 import torch._dynamo.testing
 import torch._dynamo.utils
 import torch._functorch.config
+import torch.distributed as dist
 import torch.library
 import torch.utils._pytree as pytree
 from torch import nn
 from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.testing import CompileCounter, rand_strided, same, skipIfPy312
 from torch._inductor.utils import fresh_inductor_cache
 from torch.nn import functional as F
-from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    TEST_CUDA,
+)
 from torch.testing._internal.common_utils import (
     disable_translation_validation_if_dynamic_shapes,
     instantiate_parametrized_tests,
@@ -6506,6 +6510,80 @@ def fn(obj):
         opt_fn = torch.compile(fn, backend="eager")
         self.assertEqual(fn(typing.Any), opt_fn(typing.Any))

+    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
+    @unittest.skipIf(not dist.is_available(), "test requires distributed")
+    def test_ddp_checkpoint(self):
+        # https://github.com/pytorch/pytorch/issues/144035
+        DIM = 256
+        SEQ_LEN = 32
+
+        @torch.compile(backend="eager", fullgraph=True)
+        def mlp_forward(x, w1, w2, b1, b2):
+            y = F.linear(x, w1, b1)
+            y = F.relu(y)
+            y = F.linear(y, w2, b2)
+            return y
+
+        class MLP(nn.Module):
+            def __init__(
+                self,
+                in_features: int,
+                hidden_features: int,
+                out_features: int,
+            ):
+                super().__init__()
+                self.w_in = nn.Parameter(torch.randn(hidden_features, in_features))
+                self.w_out = nn.Parameter(torch.randn(out_features, hidden_features))
+                self.b_in = nn.Parameter(torch.randn(hidden_features))
+                self.b_out = nn.Parameter(torch.randn(out_features))
+
+            def forward(self, x):
+                result = torch.utils.checkpoint.checkpoint(
+                    mlp_forward,
+                    x,
+                    self.w_in,
+                    self.w_out,
+                    self.b_in,
+                    self.b_out,
+                    use_reentrant=False,
+                )
+                assert isinstance(result, torch.Tensor)
+                return result
+
+        x = torch.randn(100, SEQ_LEN, DIM)
+        y = torch.zeros(100)
+        dataset = torch.utils.data.TensorDataset(x, y)
+        dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
+        model = MLP(DIM, 4 * DIM, DIM)
+
+        try:
+            # required for DDP wrapper initialization
+            prior_master_addr = os.environ.get("MASTER_ADDR", None)
+            prior_master_port = os.environ.get("MASTER_PORT", None)
+            os.environ["MASTER_ADDR"] = "localhost"
+            os.environ["MASTER_PORT"] = "12355"
+            dist.init_process_group(backend="nccl", world_size=1, rank=0)
+            model = model.to("cuda")
+            model = nn.parallel.DistributedDataParallel(model)
+
+            for batch in dataloader:
+                x, y = batch
+                x = x.to("cuda")
+                output = model(x)
+                loss = output.sum()
+                loss.backward()
+        finally:
+            dist.destroy_process_group()
+            if prior_master_addr:
+                os.environ["MASTER_ADDR"] = prior_master_addr
+            else:
+                del os.environ["MASTER_ADDR"]
+
+            if prior_master_port:
+                os.environ["MASTER_PORT"] = prior_master_port
+            else:
+                del os.environ["MASTER_PORT"]
+

 instantiate_parametrized_tests(ReproTests)

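Assuming a CUDA build with NCCL available, the new test can be run locally with something like: pytest test/dynamo/test_repros.py -k test_ddp_checkpoint (the exact invocation may vary by setup). It spins up a single-rank process group on localhost:12355 and restores any pre-existing MASTER_ADDR/MASTER_PORT values on exit.
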
torch/_dynamo/backends/distributed.py

Lines changed: 2 additions & 1 deletion
@@ -525,7 +525,8 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]):
         fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()

         submod_compiler = SubmodCompiler(split_gm, self.backend_compile_fn, fake_mode)
-        submod_compiler.run(*example_inputs)
+        with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
+            submod_compiler.run(*example_inputs)
         split_gm.recompile()

         ddp_graph_log.debug(

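The helper used above, torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing, is a private context manager whose body is not shown in this diff. Purely as an illustration of the pattern (a scoped, restore-on-exit state flip around a tracing pass), a hypothetical sketch could look like the following; the flag and function names here are made up for the example and are not PyTorch internals.

from contextlib import contextmanager

# Hypothetical module-level flag standing in for "ignore user pack/unpack hooks".
_suppress_pack_unpack_hooks = False

@contextmanager
def _suppress_hooks_for_tracing():
    # Illustrative only: shows the scoped-disable pattern, not the actual
    # implementation in torch/_dynamo/utils.py.
    global _suppress_pack_unpack_hooks
    prior = _suppress_pack_unpack_hooks
    _suppress_pack_unpack_hooks = True
    try:
        yield
    finally:
        # Restore whatever was in effect before, even if tracing raised.
        _suppress_pack_unpack_hooks = prior

In the real change, wrapping submod_compiler.run(*example_inputs) means the fake-tensor propagation over each split submodule no longer fires the user's pack/unpack hooks, matching what pytorch#123196 already did for the main torch.compile path.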