Commit 43589c0

Implement XLAShardedTensor.redistribute and test (#9529)
1 parent 2889f69 commit 43589c0

File tree: 5 files changed, +307 -1 lines


test/neuron/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -256,6 +256,7 @@ function run_xla_op_tests3 {
   #run_test "$_TEST_DIR/spmd/test_dtensor_integration2.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_device "$_TEST_DIR/spmd/test_xla_dtensor_spec_conv.py"
+  run_test_multi_device "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   #run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_train_spmd_linear_model.py"

test/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -256,6 +256,7 @@ function run_xla_op_tests3 {
   run_test_multi_devices_without_func "$_TEST_DIR/spmd/test_dtensor_integration3.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
   run_test_multi_devices "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
+  run_test_multi_devices "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
   run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
   run_test "$_TEST_DIR/spmd/test_spmd_parameter_wrapping.py"
   run_test "$_TEST_DIR/spmd/test_mp_input_sharding.py"
test/spmd/test_dtensor_redistribute.py

Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
import sys
import unittest
import torch
from torch.distributed.tensor.placement_types import Shard, Replicate, Partial
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
import torch_xla
import numpy as np
import test_xla_sharding_base
from absl.testing import parameterized


class DTensorRedistributeTest(test_xla_sharding_base.XlaShardingTest,
                              parameterized.TestCase):

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    xr.use_spmd()

  def _verify_sharding_spec(self, tensor, expected_devices=None):
    """Verify tensor sharding spec after mark_step"""
    torch_xla.sync()
    sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(tensor)
    if expected_devices:
      self.assertIn(expected_devices, sharding_spec)
    return sharding_spec

  # Test tensor shapes: 0D, 1D, 2D, 3D
  @parameterized.parameters(
      ((), ()),  # 0D scalar
      ((8,), (0,)),  # 1D
      ((8, 16), (0, None)),  # 2D
      ((4, 8, 16), (0, None, None))  # 3D
  )
  @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices")
  def test_tensor_shapes(self, shape, partition_spec):
    device_count = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(device_count), (device_count,))

    if len(shape) == 0:
      tensor = torch.tensor(1.0).to('xla')
      placements = [Replicate()]
      expected_spec = ()
    else:
      tensor = torch.randn(shape).to('xla')
      sharded_tensor = xs.mark_sharding(tensor, mesh, partition_spec)
      placements = [Shard(0)]
      expected_spec = partition_spec

    redistributed = sharded_tensor.redistribute(mesh, placements)
    self.assertEqual(redistributed.partition_spec, expected_spec)

    # Convert partition spec to expected devices pattern
    devices_pattern = [
        str(device_count) if spec == 0 else '1' for spec in expected_spec
    ]
    expected_devices = f"devices=[{','.join(devices_pattern)}]"

    # Skip HLO verification for 4D tensors due to XLA optimization issues
    if len(shape) < 4:
      self._verify_sharding_spec(redistributed.global_tensor,
                                 expected_devices)

  # Test tensor dtypes: bf16, f32, int32
  @parameterized.parameters(torch.bfloat16, torch.float32, torch.int32)
  @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices")
  def test_tensor_dtypes(self, dtype):
    device_count = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(device_count), (device_count,))

    if dtype == torch.int32:
      tensor = torch.randint(0, 100, (8, 16), dtype=dtype).to('xla')
    else:
      tensor = torch.randn(8, 16, dtype=dtype).to('xla')

    sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None))
    placements = [Shard(0)]

    redistributed = sharded_tensor.redistribute(mesh, placements)
    self.assertEqual(redistributed.partition_spec, (0, None))
    self.assertEqual(redistributed.global_tensor.dtype, dtype)

    # Verify HLO sharding
    expected_devices = f"devices=[{device_count},1]"
    self._verify_sharding_spec(redistributed.global_tensor, expected_devices)

  # Test device mesh dimensions: 1D, 2D
  @unittest.skipIf(xr.global_runtime_device_count() < 4, "Need ≥4 devices")
  def test_device_mesh_dimensions(self):
    device_count = xr.global_runtime_device_count()

    # 1D mesh
    mesh_1d = xs.Mesh(np.arange(device_count), (device_count,))
    tensor = torch.randn(8, 16).to('xla')
    sharded_tensor = xs.mark_sharding(tensor, mesh_1d, (0, None))

    redistributed = sharded_tensor.redistribute(mesh_1d, [Shard(1)])
    self.assertEqual(redistributed.partition_spec, (None, 0))

    # Verify HLO sharding for 1D mesh
    expected_devices = f"devices=[1,{device_count}]"
    self._verify_sharding_spec(redistributed.global_tensor, expected_devices)

    # 2D mesh
    if device_count >= 4 and device_count % 2 == 0:
      mesh_2d = xs.Mesh(np.arange(device_count), (2, device_count // 2))
      tensor_2d = torch.randn(8, 16).to('xla')
      sharded_tensor = xs.mark_sharding(tensor_2d, mesh_2d, (0, None))

      redistributed = sharded_tensor.redistribute(
          mesh_2d, [Replicate(), Shard(1)])
      self.assertEqual(redistributed.partition_spec, (None, 1))

      # Verify HLO sharding for 2D mesh
      expected_devices = f"devices=[1,{device_count // 2},{device_count // 2}]"
      self._verify_sharding_spec(redistributed.global_tensor, expected_devices)

  # Test placement types: Replicate, Shard
  @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices")
  def test_placement_types(self):
    device_count = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(device_count), (device_count,))
    tensor = torch.randn(8, 16).to('xla')

    # Test Replicate
    sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None))
    redistributed = sharded_tensor.redistribute(mesh, [Replicate()])
    self.assertEqual(redistributed.partition_spec, (None, None))

    # Verify HLO sharding for replicated
    self._verify_sharding_spec(redistributed.global_tensor, "replicated")

    # Test Shard on different dimensions
    for dim in [0, 1]:
      with self.subTest(shard_dim=dim):
        redistributed = sharded_tensor.redistribute(mesh, [Shard(dim)])
        expected_spec = [None, None]
        expected_spec[dim] = 0
        self.assertEqual(redistributed.partition_spec, tuple(expected_spec))

        # Verify HLO sharding
        devices_pattern = [
            str(device_count) if i == dim else '1' for i in range(2)
        ]
        expected_devices = f"devices=[{','.join(devices_pattern)}]"
        self._verify_sharding_spec(redistributed.global_tensor,
                                   expected_devices)

  # Test error cases with invalid inputs
  @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices")
  def test_invalid_inputs(self):
    device_count = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(device_count), (device_count,))
    tensor = torch.randn(8, 16).to('xla')
    sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None))

    # Test invalid shard dimension (tensor only has dims 0,1 but asking for dim 2)
    with self.assertRaises(IndexError):
      sharded_tensor.redistribute(mesh, [Shard(2)])

    # Test mismatched placements length (1D mesh expects 1 placement, not 2)
    with self.assertRaises(ValueError):
      sharded_tensor.redistribute(mesh, [Shard(0), Shard(1)])

    # Test Partial placement - should raise error about not being implemented
    with self.assertRaises(NotImplementedError):
      sharded_tensor.redistribute(mesh, [Partial()])

  # Test sharding propagation through operations
  @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices")
  def test_sharding_propagation(self):
    device_count = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(device_count), (device_count,))

    # Unary ops
    tensor = torch.randn(8, 16).to('xla')
    sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None))
    redistributed = sharded_tensor.redistribute(mesh, [Shard(0)])

    relu_result = torch.relu(redistributed.global_tensor)
    self.assertEqual(relu_result.shape, (8, 16))
    self.assertTrue(torch.all(relu_result >= 0))

    # Binary ops
    tensor2 = torch.randn(8, 16).to('xla')
    sharded_tensor2 = xs.mark_sharding(tensor2, mesh, (0, None))
    redistributed2 = sharded_tensor2.redistribute(mesh, [Shard(0)])

    add_result = redistributed.global_tensor + redistributed2.global_tensor
    mul_result = redistributed.global_tensor * redistributed2.global_tensor

    # Verify operation results
    self.assertEqual(add_result.shape, (8, 16))
    self.assertEqual(mul_result.shape, (8, 16))

    # Verify operations work correctly
    self.assertTrue(
        torch.allclose(
            add_result,
            redistributed.global_tensor + redistributed2.global_tensor))
    self.assertTrue(
        torch.allclose(
            mul_result,
            redistributed.global_tensor * redistributed2.global_tensor))

  # Test comprehensive redistribute scenarios
  @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices")
  def test_comprehensive_redistribute(self):
    device_count = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(device_count), (device_count,))

    tensor = torch.randn(8, 16).to('xla')
    sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None))

    # Test all placement combinations for 1D mesh
    placement_types = [Replicate(), Shard(0), Shard(1)]

    for placement in placement_types:
      with self.subTest(placement=placement):
        placements = [placement]

        if isinstance(placement, Shard):
          expected_spec = [None] * 2
          expected_spec[placement.dim] = 0
          expected_spec = tuple(expected_spec)
        else:
          expected_spec = (None, None)

        redistributed = sharded_tensor.redistribute(mesh, placements)
        self.assertEqual(redistributed.partition_spec, expected_spec)

        # Verify HLO sharding
        if isinstance(placement, Shard):
          devices_pattern = [
              str(device_count) if i == placement.dim else '1' for i in range(2)
          ]
          expected_devices = f"devices=[{','.join(devices_pattern)}]"
        else:
          expected_devices = "replicated"
        self._verify_sharding_spec(redistributed.global_tensor,
                                   expected_devices)

  # Test async redistribute
  @unittest.skipIf(xr.global_runtime_device_count() < 4, "Need ≥4 devices")
  def test_async_redistribute(self):
    device_count = xr.global_runtime_device_count()
    mesh_shape = (2, device_count // 2)
    mesh = xs.Mesh(np.arange(device_count), mesh_shape)

    tensor = torch.randn(8, 16).to('xla')
    sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None))

    # Test async redistribute
    placements = [Replicate(), Shard(0)]
    redistributed = sharded_tensor.redistribute(mesh, placements, async_op=True)
    self.assertEqual(redistributed.partition_spec, (1, None))

    # Verify async operation creates different tensor object
    self.assertIsNot(redistributed.global_tensor, sharded_tensor.global_tensor)

    # Verify HLO sharding for async redistribute (XLA generates more complex pattern)
    expected_devices = f"devices=[2,1,{device_count // 2}]"
    self._verify_sharding_spec(redistributed.global_tensor, expected_devices)


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)
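
The assertions above reduce to a fixed mapping from DTensor placements to an XLA partition_spec, which is the conversion loop added to XLAShardedTensor.redistribute in the xla_sharded_tensor.py diff further below. A minimal standalone sketch of that mapping follows; the helper name placements_to_partition_spec is hypothetical and not part of this commit.

from torch.distributed.tensor.placement_types import Shard, Replicate

def placements_to_partition_spec(placements, tensor_rank):
  # Mirrors the conversion in the new redistribute method: each Shard(dim)
  # maps that tensor dimension to the index of the mesh axis it is placed on;
  # all other tensor dimensions stay None (replicated).
  spec = [None] * tensor_rank
  for mesh_dim, placement in enumerate(placements):
    if isinstance(placement, Shard):
      spec[placement.dim] = mesh_dim
  return tuple(spec)

# Cases exercised by the tests above for a rank-2 tensor on a 1-D mesh:
assert placements_to_partition_spec([Replicate()], 2) == (None, None)
assert placements_to_partition_spec([Shard(0)], 2) == (0, None)
assert placements_to_partition_spec([Shard(1)], 2) == (None, 0)
# 2-D mesh cases (test_device_mesh_dimensions and test_async_redistribute):
assert placements_to_partition_spec([Replicate(), Shard(1)], 2) == (None, 1)
assert placements_to_partition_spec([Replicate(), Shard(0)], 2) == (1, None)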

test/tpu/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ run_test "$_TEST_DIR/spmd/test_xla_auto_sharding.py"
 run_test "$_TEST_DIR/spmd/test_fsdp_v2.py"
 run_test "$_TEST_DIR/spmd/test_dtensor_convert_mesh.py"
 run_test "$_TEST_DIR/spmd/test_xla_dtensor_spec_conversion.py"
+run_test "$_TEST_DIR/spmd/test_dtensor_redistribute.py"
 run_test "$_TEST_DIR/test_gradient_accumulation.py"
 XLA_EXPERIMENTAL=nonzero:masked_select:nms run_test "$_TEST_DIR/ds/test_dynamic_shape_models.py" -v
 run_test "$_TEST_DIR/test_autocast.py"

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 35 additions & 1 deletion
@@ -9,7 +9,7 @@
 import torch_xla.runtime as xr
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.device_mesh import DeviceMesh
-from torch.distributed.tensor.placement_types import Placement, Shard, Replicate
+from torch.distributed.tensor.placement_types import Placement, Shard, Replicate, Partial
 from torch.utils._pytree import tree_map_only

@@ -264,6 +264,40 @@ def invalidate_spec_cache(self):
     """Invalidate the cached DTensorSpec."""
     self._cached_spec = None
 
+  def redistribute(self, device_mesh, placements, *, async_op: bool = False):
+    # Validate inputs
+    if len(placements) != len(device_mesh.mesh_shape):
+      raise ValueError(
+          f"Number of placements ({len(placements)}) must match mesh dimensions ({len(device_mesh.mesh_shape)})"
+      )
+
+    # Check for unsupported placement types
+    for placement in placements:
+      if isinstance(placement, Partial):
+        raise NotImplementedError(
+            "Partial placement is not yet implemented and may have unexpected behavior. "
+            "Use Shard or Replicate placements instead.")
+
+    # Convert placements to partition spec
+    partition_spec = [None] * len(self.global_tensor.shape)
+    for mesh_dim, placement in enumerate(placements):
+      if isinstance(placement, Shard):
+        if placement.dim >= len(self.global_tensor.shape):
+          raise IndexError(
+              f"Shard dimension {placement.dim} is out of bounds for tensor with {len(self.global_tensor.shape)} dimensions"
+          )
+        partition_spec[placement.dim] = mesh_dim
+
+    result_tensor = self.global_tensor.clone(
+    ) if async_op else self.global_tensor
+    op_sharding = device_mesh.get_op_sharding(tuple(partition_spec))
+    torch_xla._XLAC._xla_annotate_custom_sharding(result_tensor, op_sharding)
+
+    return XLAShardedTensor(
+        result_tensor,
+        mesh_shape=device_mesh.mesh_shape,
+        partition_spec=tuple(partition_spec))
+
   @classmethod
   def __torch_function__(cls, func, types, args=(), kwargs=None):
     return super().__torch_function__(func, types, args, kwargs)
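
For reference, a minimal end-to-end usage sketch of the method added above, assembled from calls that appear in this commit's test file. It is a sketch rather than part of the commit, and it assumes an SPMD-capable XLA runtime with at least two devices.

import numpy as np
import torch
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch.distributed.tensor.placement_types import Shard, Replicate

xr.use_spmd()  # enable SPMD mode, as the test's setUpClass does

device_count = xr.global_runtime_device_count()  # assumed >= 2
mesh = xs.Mesh(np.arange(device_count), (device_count,))

# Shard dim 0 of an (8, 16) tensor across the 1-D mesh.
t = torch.randn(8, 16).to('xla')
sharded = xs.mark_sharding(t, mesh, (0, None))

# Move to a fully replicated layout.
replicated = sharded.redistribute(mesh, [Replicate()])
print(replicated.partition_spec)  # (None, None)

# Reshard along dim 1 instead; async_op=True clones global_tensor first.
resharded = sharded.redistribute(mesh, [Shard(1)], async_op=True)
print(resharded.partition_spec)  # (None, 0)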
