|
35 | 35 | import torch._dynamo.testing |
36 | 36 | import torch._dynamo.utils |
37 | 37 | import torch._functorch.config |
| 38 | +import torch.distributed as dist |
38 | 39 | import torch.library |
39 | 40 | import torch.utils._pytree as pytree |
40 | 41 | from torch import nn |
41 | 42 | from torch._dynamo.debug_utils import same_two_models |
42 | 43 | from torch._dynamo.testing import CompileCounter, rand_strided, same, skipIfPy312 |
43 | 44 | from torch._inductor.utils import fresh_inductor_cache |
44 | 45 | from torch.nn import functional as F |
45 | | -from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION |
| 46 | +from torch.testing._internal.common_cuda import ( |
| 47 | +    PLATFORM_SUPPORTS_FLASH_ATTENTION,
| 48 | +    TEST_CUDA,
| 49 | +) |
46 | 50 | from torch.testing._internal.common_utils import ( |
47 | 51 |     disable_translation_validation_if_dynamic_shapes,
48 | 52 |     instantiate_parametrized_tests,
@@ -6506,6 +6510,80 @@ def fn(obj): |
6506 | 6510 |         opt_fn = torch.compile(fn, backend="eager")
6507 | 6511 |         self.assertEqual(fn(typing.Any), opt_fn(typing.Any))
6508 | 6512 |
|
| 6513 | +    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
| 6514 | +    @unittest.skipIf(not dist.is_available(), "test requires distributed")
| 6515 | +    def test_ddp_checkpoint(self):
| 6516 | +        # https://github.com/pytorch/pytorch/issues/144035
| 6517 | +        DIM = 256
| 6518 | +        SEQ_LEN = 32
| 6519 | +
| 6520 | +        @torch.compile(backend="eager", fullgraph=True)
| 6521 | +        def mlp_forward(x, w1, w2, b1, b2):
| 6522 | +            y = F.linear(x, w1, b1)
| 6523 | +            y = F.relu(y)
| 6524 | +            y = F.linear(y, w2, b2)
| 6525 | +            return y
| 6526 | +
| 6527 | +        class MLP(nn.Module):
| 6528 | +            def __init__(
| 6529 | +                self,
| 6530 | +                in_features: int,
| 6531 | +                hidden_features: int,
| 6532 | +                out_features: int,
| 6533 | +            ):
| 6534 | +                super().__init__()
| 6535 | +                self.w_in = nn.Parameter(torch.randn(hidden_features, in_features))
| 6536 | +                self.w_out = nn.Parameter(torch.randn(out_features, hidden_features))
| 6537 | +                self.b_in = nn.Parameter(torch.randn(hidden_features))
| 6538 | +                self.b_out = nn.Parameter(torch.randn(out_features))
| 6539 | +
| 6540 | +            def forward(self, x):
| 6541 | +                result = torch.utils.checkpoint.checkpoint(
| 6542 | +                    mlp_forward,
| 6543 | +                    x,
| 6544 | +                    self.w_in,
| 6545 | +                    self.w_out,
| 6546 | +                    self.b_in,
| 6547 | +                    self.b_out,
| 6548 | +                    use_reentrant=False,
| 6549 | +                )
| 6550 | +                assert isinstance(result, torch.Tensor)
| 6551 | +                return result
| 6552 | +
| 6553 | +        x = torch.randn(100, SEQ_LEN, DIM)
| 6554 | +        y = torch.zeros(100)
| 6555 | +        dataset = torch.utils.data.TensorDataset(x, y)
| 6556 | +        dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
| 6557 | +        model = MLP(DIM, 4 * DIM, DIM)
| 6558 | +
| 6559 | +        try:
| 6560 | +            # MASTER_ADDR/MASTER_PORT are required to initialize the process group for the DDP wrapper
| 6561 | +            prior_master_addr = os.environ.get("MASTER_ADDR", None)
| 6562 | +            prior_master_port = os.environ.get("MASTER_PORT", None)
| 6563 | +            os.environ["MASTER_ADDR"] = "localhost"
| 6564 | +            os.environ["MASTER_PORT"] = "12355"
| 6565 | +            dist.init_process_group(backend="nccl", world_size=1, rank=0)
| 6566 | +            model = model.to("cuda")
| 6567 | +            model = nn.parallel.DistributedDataParallel(model)
| 6568 | +
| 6569 | +            for batch in dataloader:
| 6570 | +                x, y = batch
| 6571 | +                x = x.to("cuda")
| 6572 | +                output = model(x)
| 6573 | +                loss = output.sum()
| 6574 | +                loss.backward()
| 6575 | +        finally:
| 6576 | +            dist.destroy_process_group()
| 6577 | +            if prior_master_addr is not None:
| 6578 | +                os.environ["MASTER_ADDR"] = prior_master_addr
| 6579 | +            else:
| 6580 | +                del os.environ["MASTER_ADDR"]
| 6581 | +
| 6582 | +            if prior_master_port is not None:
| 6583 | +                os.environ["MASTER_PORT"] = prior_master_port
| 6584 | +            else:
| 6585 | +                del os.environ["MASTER_PORT"]
| 6586 | + |
6509 | 6587 |
|
6510 | 6588 | instantiate_parametrized_tests(ReproTests) |
6511 | 6589 |
|
|