
Commit b098be8

refactor: DTensor inheritance for XLAShardedTensor (#9576)
Change XLAShardedTensor to inherit from DTensor instead of torch.Tensor, as discussed in #9418.
1 parent f8b44e2 commit b098be8
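In practice, the change means an XLAShardedTensor now satisfies isinstance checks against DTensor, so DTensor-aware code paths accept it without unwrapping. A minimal sketch of that behavior (not part of this commit; assumes torch_xla is installed and an XLA device is available):

import torch
from torch.distributed.tensor import DTensor
from torch_xla.distributed.spmd import XLAShardedTensor

t = torch.randn(4, 4).to('xla')
xst = XLAShardedTensor(t)

# Previously False (XLAShardedTensor subclassed torch.Tensor directly);
# with this commit it is True.
print(isinstance(xst, DTensor))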

File tree

5 files changed: +44 -2 lines changed


test/neuron/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -257,6 +257,7 @@ function run_xla_op_tests3 {
   run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_xla_sharded_tensor.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -238,6 +238,7 @@ function run_xla_op_tests3 {
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
+  run_test_multi_devices "$_TEST_DIR/spmd/test_xla_sharded_tensor.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
test/spmd/test_xla_sharded_tensor.py

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+import sys
+import unittest
+import test_xla_sharding_base
+from torch.distributed.tensor import DTensor
+from torch_xla.distributed.spmd import XLAShardedTensor
+
+import torch
+
+
+class XlaShardedTensorTest(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+
+  def test_xlashardedtensor_is_dtensor(self):
+    """Test that XLAShardedTensor is a subclass of DTensor."""
+    xt = torch.randn(128, 128).to('xla')
+    xla_tensor = XLAShardedTensor(xt)
+    self.assertIsInstance(xla_tensor, DTensor)
+
+  def test_xlashardedtensor_gradient(self):
+    """Test accessing gradients of an XLAShardedTensor (triggers __torch_function__)."""
+    xt = torch.randn(128, 128).to('xla')
+    xla_tensor = XLAShardedTensor(xt, requires_grad=True)
+    result = xla_tensor.sum()
+    result.backward()
+
+    # this should trigger __torch_function__
+    grad = xla_tensor.grad
+
+    self.assertIsNotNone(grad)
+    self.assertEqual(grad.shape, xla_tensor.shape)
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
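A design note on the gradient test: for tensor subclasses, access to properties such as Tensor.grad is routed through __torch_function__, so reading xla_tensor.grad exercises the new super(DTensor, cls) dispatch path in XLAShardedTensor.__torch_function__ (see the final diff below); the assertion would fail if that dispatch were broken.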

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
 run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
+run_test "$_TEST_DIR/spmd/test_xla_sharded_tensor.py"
 run_test "$_TEST_DIR/test_gradient_accumulation.py"
 XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
 run_test "$_TEST_DIR/test_autocast.py"

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 3 additions & 2 deletions
@@ -11,6 +11,7 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.placement_types import Placement, Shard, Replicate, Partial
 from torch.utils._pytree import tree_map_only
+from torch.distributed.tensor import DTensor


 @dataclass
@@ -63,7 +64,7 @@ def no_dispatch() -> Iterator[None]:
   del guard


-class XLAShardedTensor(torch.Tensor):
+class XLAShardedTensor(DTensor):
   """
   A wrapper around `torch.Tensor` with sharding annotation
   for XLA SPMD auto-sharding. The wrapped tensors are unwrapped
@@ -300,4 +301,4 @@ def redistribute(self, device_mesh, placements, *, async_op: bool = False):

   @classmethod
   def __torch_function__(cls, func, types, args=(), kwargs=None):
-    return super().__torch_function__(func, types, args, kwargs)
+    return super(DTensor, cls).__torch_function__(func, types, args, kwargs)
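A note on the __torch_function__ change: super(DTensor, cls) starts the method lookup after DTensor in XLAShardedTensor's MRO, so the call falls through to torch.Tensor's default __torch_function__ rather than any DTensor-level handling. A toy sketch of that mechanism (hypothetical classes, not the real DTensor hierarchy):

class Base:

  @classmethod
  def handler(cls):
    return 'Base'


class Middle(Base):

  @classmethod
  def handler(cls):
    return 'Middle'


class Leaf(Middle):

  @classmethod
  def handler(cls):
    # A plain super() here would resolve to Middle.handler; super(Middle, cls)
    # starts the MRO search after Middle, so Base.handler runs instead.
    return super(Middle, cls).handler()


assert Leaf.handler() == 'Base'

This keeps XLAShardedTensor's operator dispatch behaving as it did when it subclassed torch.Tensor directly, while preserving the new isinstance relationship to DTensor.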
