
Commit d0f5df8

xmfan authored and pytorchmergebot committed
[ca] add test_dtensor_compile.py to compiled autograd tests (pytorch#144107)
More than half of the tests in this file exercise autograd; current pass rate under compiled autograd is 19/26.

Pull Request resolved: pytorch#144107
Approved by: https://github.com/zou3519, https://github.com/bdhirsh, https://github.com/jansel
1 parent fcf9dc3 commit d0f5df8

File tree

2 files changed: +32 −2 lines


test/distributed/_tensor/test_dtensor_compile.py

Lines changed: 19 additions & 2 deletions
@@ -86,14 +86,18 @@ def extract_graph(fx_g, _, graph_cell):
 
 class TestDTensorCompile(torch._dynamo.test_case.TestCase):
     def setUp(self):
-        super().setUp()
+        super(
+            type(self), self
+        ).setUp()  # use explicit params for compiled autograd test wrapping
         fake_store = FakeStore()
         dist.init_process_group(
             "fake", store=fake_store, rank=0, world_size=self.world_size
         )
 
     def tearDown(self):
-        super().tearDown()
+        super(
+            type(self), self
+        ).tearDown()  # use explicit params for compiled autograd test wrapping
         dist.destroy_process_group()
 
     @property
@@ -104,6 +108,19 @@ def device_type(self) -> str:
     def world_size(self) -> int:
         return 2
 
+    def test_dtensor_basic(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        @torch.compile(backend="aot_eager", fullgraph=True)
+        def fn(x):
+            return x * x + 2
+
+        param = torch.randn(4, 4, requires_grad=True)
+        x = DTensor.from_local(param, mesh, [Shard(0)], run_check=False)
+
+        res = fn(x)
+        res.to_local().sum().backward()
+
     def test_placement_compile(self):
         def fn(x):
             a = 0
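
Why the setUp/tearDown change above is needed: the compiled autograd suite re-creates this test class dynamically (via wrap_test_class in test/inductor/test_compiled_autograd.py, shown in the next file). Zero-argument super() is bound to the class where the method was originally defined, so once the method is copied into a freshly built class that is not a subclass of TestDTensorCompile, super().setUp() raises a TypeError. Below is a minimal sketch of that failure mode, assuming the wrapper rebuilds the class from the original's bases; the class names are hypothetical and for illustration only.

import unittest


class Original(unittest.TestCase):
    def setUp(self):
        # Zero-argument super() behaves like super(Original, self): it is bound
        # to the class where this method was defined.
        super().setUp()

    def test_something(self):
        pass


# Rebuild the class the way a test wrapper might: same bases, copied methods.
Wrapped = type("OriginalWrapped", Original.__bases__, dict(Original.__dict__))

# unittest would call setUp() on a Wrapped instance. Wrapped is not a subclass
# of Original, so the zero-argument super() above raises:
#     TypeError: super(type, obj): obj must be an instance or subtype of type
# Writing super(type(self), self).setUp() instead resolves the MRO from the
# instance's actual class, so the copied method also works inside Wrapped.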

test/inductor/test_compiled_autograd.py

Lines changed: 13 additions & 0 deletions
@@ -3594,6 +3594,14 @@ def wrap_test_class(orig_cls):
     "test_invalid_gradients",  # can't give autograd error due to inaccurate output metadata of lifted backward
     "test_autograd_node_isinstance",  # backward ctx is a fake cls and not directly a Node instance
     "test_backward_hook_relative_ordering",  # compiled autograd collects breadth first, and module backward hook not supported
+    # Category: Subclasses
+    "test_dtensor_basic",
+    "test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent",
+    "test_dtensor_different_gradient_placement",
+    "test_dtensor_noncontiguous_output",
+    "test_dtensor_partial_placement_graph_output",
+    "test_tp_compile_comm_reordering",
+    "test_unwrap_async_collective_tensor_tangent",
     # Uncategorized
 }
 
@@ -3606,6 +3614,11 @@ def wrap_test_class(orig_cls):
 
 TestAutogradWithCompiledAutograd = wrap_test_class(test_autograd.TestAutograd)
 TestCustomOpWithCompiledAutograd = wrap_test_class(test_custom_ops.TestCustomOp)
+if torch.distributed.is_available() and HAS_CUDA:
+    test_dtensor = load_test_module("distributed/_tensor/test_dtensor_compile")
+    TestDTensorCompileWithCompiledAutograd = wrap_test_class(
+        test_dtensor.TestDTensorCompile
+    )
 
 if __name__ == "__main__":
     if HAS_CPU:
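
For context, wrap_test_class (defined earlier in test_compiled_autograd.py and not shown in this diff) re-registers each test method so it runs with compiled autograd enabled, marking the names collected in the known-failures set above. The snippet below is a simplified sketch of that wrapping pattern under stated assumptions, not the actual implementation; _enable_compiled_autograd and wrap_test_class_sketch are hypothetical names.

import functools
import unittest

import torch
import torch._dynamo.compiled_autograd


def _enable_compiled_autograd(fn):
    # Hypothetical helper: run the original test body with compiled autograd
    # enabled, compiling the traced backward graph with an eager-style backend.
    @functools.wraps(fn)
    def wrapper(self):
        with torch._dynamo.compiled_autograd.enable(
            torch.compile(backend="aot_eager")
        ):
            return fn(self)

    return wrapper


def wrap_test_class_sketch(orig_cls, known_failures=frozenset()):
    # Copy every test method, marking known failures and wrapping the rest.
    dct = dict(orig_cls.__dict__)
    for name, value in list(dct.items()):
        if name.startswith("test_") and callable(value):
            if name in known_failures:
                dct[name] = unittest.expectedFailure(value)
            else:
                dct[name] = _enable_compiled_autograd(value)
    # The new class copies the original's bases instead of subclassing the
    # original, which is why setUp/tearDown in the first file switch to
    # super(type(self), self).
    return type(orig_cls.__name__ + "WithCompiledAutograd", orig_cls.__bases__, dct)

With the hook added in this diff, the wrapped DTensor tests should be selectable from the compiled autograd suite, e.g. python test/inductor/test_compiled_autograd.py -k TestDTensorCompileWithCompiledAutograd, assuming a CUDA build with distributed support per the guard above.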
