|
| 1 | +import sys |
| 2 | +import unittest |
| 3 | +import torch |
| 4 | +from torch.distributed.tensor.placement_types import Shard, Replicate, Partial |
| 5 | +import torch_xla.runtime as xr |
| 6 | +import torch_xla.distributed.spmd as xs |
| 7 | +import torch_xla |
| 8 | +import numpy as np |
| 9 | +import test_xla_sharding_base |
| 10 | +from absl.testing import parameterized |
| 11 | + |
| 12 | + |
| 13 | +class DTensorRedistributeTest(test_xla_sharding_base.XlaShardingTest, |
| 14 | + parameterized.TestCase): |
| 15 | + |
| 16 | + @classmethod |
| 17 | + def setUpClass(cls): |
| 18 | + super().setUpClass() |
| 19 | + xr.use_spmd() |
| 20 | + |
| 21 | + def _verify_sharding_spec(self, tensor, expected_devices=None): |
| 22 | + """Verify tensor sharding spec after mark_step""" |
| 23 | + torch_xla.sync() |
| 24 | + sharding_spec = torch_xla._XLAC._get_xla_sharding_spec(tensor) |
| 25 | + if expected_devices: |
| 26 | + self.assertIn(expected_devices, sharding_spec) |
| 27 | + return sharding_spec |
| 28 | + |
| 29 | + # Test tensor shapes: 0D, 1D, 2D, 3D |
| 30 | + @parameterized.parameters( |
| 31 | + ((), ()), # 0D scalar |
| 32 | + ((8,), (0,)), # 1D |
| 33 | + ((8, 16), (0, None)), # 2D |
| 34 | + ((4, 8, 16), (0, None, None)) # 3D |
| 35 | + ) |
| 36 | + @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices") |
| 37 | + def test_tensor_shapes(self, shape, partition_spec): |
| 38 | + device_count = xr.global_runtime_device_count() |
| 39 | + mesh = xs.Mesh(np.arange(device_count), (device_count,)) |
| 40 | + |
| 41 | + if len(shape) == 0: |
| 42 | + tensor = torch.tensor(1.0).to('xla') |
| 43 | + placements = [Replicate()] |
| 44 | + expected_spec = () |
| 45 | + else: |
| 46 | + tensor = torch.randn(shape).to('xla') |
| 47 | + sharded_tensor = xs.mark_sharding(tensor, mesh, partition_spec) |
| 48 | + placements = [Shard(0)] |
| 49 | + expected_spec = partition_spec |
| 50 | + |
| 51 | + redistributed = sharded_tensor.redistribute(mesh, placements) |
| 52 | + self.assertEqual(redistributed.partition_spec, expected_spec) |
| 53 | + |
| 54 | + # Convert partition spec to expected devices pattern |
| 55 | + devices_pattern = [ |
| 56 | + str(device_count) if spec == 0 else '1' for spec in expected_spec |
| 57 | + ] |
| 58 | + expected_devices = f"devices=[{','.join(devices_pattern)}]" |
| 59 | + |
| 60 | + # Skip HLO verification for 4D tensors due to XLA optimization issues |
| 61 | + if len(shape) < 4: |
| 62 | + self._verify_sharding_spec(redistributed.global_tensor, |
| 63 | + expected_devices) |
| 64 | + |
| 65 | + # Test tensor dtypes: bf16, f32, int32 |
| 66 | + @parameterized.parameters(torch.bfloat16, torch.float32, torch.int32) |
| 67 | + @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices") |
| 68 | + def test_tensor_dtypes(self, dtype): |
| 69 | + device_count = xr.global_runtime_device_count() |
| 70 | + mesh = xs.Mesh(np.arange(device_count), (device_count,)) |
| 71 | + |
| 72 | + if dtype == torch.int32: |
| 73 | + tensor = torch.randint(0, 100, (8, 16), dtype=dtype).to('xla') |
| 74 | + else: |
| 75 | + tensor = torch.randn(8, 16, dtype=dtype).to('xla') |
| 76 | + |
| 77 | + sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None)) |
| 78 | + placements = [Shard(0)] |
| 79 | + |
| 80 | + redistributed = sharded_tensor.redistribute(mesh, placements) |
| 81 | + self.assertEqual(redistributed.partition_spec, (0, None)) |
| 82 | + self.assertEqual(redistributed.global_tensor.dtype, dtype) |
| 83 | + |
| 84 | + # Verify HLO sharding |
| 85 | + expected_devices = f"devices=[{device_count},1]" |
| 86 | + self._verify_sharding_spec(redistributed.global_tensor, expected_devices) |
| 87 | + |
| 88 | + # Test device mesh dimensions: 1D, 2D |
| 89 | + @unittest.skipIf(xr.global_runtime_device_count() < 4, "Need ≥4 devices") |
| 90 | + def test_device_mesh_dimensions(self): |
| 91 | + device_count = xr.global_runtime_device_count() |
| 92 | + |
| 93 | + # 1D mesh |
| 94 | + mesh_1d = xs.Mesh(np.arange(device_count), (device_count,)) |
| 95 | + tensor = torch.randn(8, 16).to('xla') |
| 96 | + sharded_tensor = xs.mark_sharding(tensor, mesh_1d, (0, None)) |
| 97 | + |
| 98 | + redistributed = sharded_tensor.redistribute(mesh_1d, [Shard(1)]) |
| 99 | + self.assertEqual(redistributed.partition_spec, (None, 0)) |
| 100 | + |
| 101 | + # Verify HLO sharding for 1D mesh |
| 102 | + expected_devices = f"devices=[1,{device_count}]" |
| 103 | + self._verify_sharding_spec(redistributed.global_tensor, expected_devices) |
| 104 | + |
| 105 | + # 2D mesh |
| 106 | + if device_count >= 4 and device_count % 2 == 0: |
| 107 | + mesh_2d = xs.Mesh(np.arange(device_count), (2, device_count // 2)) |
| 108 | + tensor_2d = torch.randn(8, 16).to('xla') |
| 109 | + sharded_tensor = xs.mark_sharding(tensor_2d, mesh_2d, (0, None)) |
| 110 | + |
| 111 | + redistributed = sharded_tensor.redistribute( |
| 112 | + mesh_2d, [Replicate(), Shard(1)]) |
| 113 | + self.assertEqual(redistributed.partition_spec, (None, 1)) |
| 114 | + |
| 115 | + # Verify HLO sharding for 2D mesh |
| 116 | + expected_devices = f"devices=[1,{device_count // 2},{device_count // 2}]" |
| 117 | + self._verify_sharding_spec(redistributed.global_tensor, expected_devices) |
| 118 | + |
| 119 | + # Test placement types: Replicate, Shard |
| 120 | + @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices") |
| 121 | + def test_placement_types(self): |
| 122 | + device_count = xr.global_runtime_device_count() |
| 123 | + mesh = xs.Mesh(np.arange(device_count), (device_count,)) |
| 124 | + tensor = torch.randn(8, 16).to('xla') |
| 125 | + |
| 126 | + # Test Replicate |
| 127 | + sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None)) |
| 128 | + redistributed = sharded_tensor.redistribute(mesh, [Replicate()]) |
| 129 | + self.assertEqual(redistributed.partition_spec, (None, None)) |
| 130 | + |
| 131 | + # Verify HLO sharding for replicated |
| 132 | + self._verify_sharding_spec(redistributed.global_tensor, "replicated") |
| 133 | + |
| 134 | + # Test Shard on different dimensions |
| 135 | + for dim in [0, 1]: |
| 136 | + with self.subTest(shard_dim=dim): |
| 137 | + redistributed = sharded_tensor.redistribute(mesh, [Shard(dim)]) |
| 138 | + expected_spec = [None, None] |
| 139 | + expected_spec[dim] = 0 |
| 140 | + self.assertEqual(redistributed.partition_spec, tuple(expected_spec)) |
| 141 | + |
| 142 | + # Verify HLO sharding |
| 143 | + devices_pattern = [ |
| 144 | + str(device_count) if i == dim else '1' for i in range(2) |
| 145 | + ] |
| 146 | + expected_devices = f"devices=[{','.join(devices_pattern)}]" |
| 147 | + self._verify_sharding_spec(redistributed.global_tensor, |
| 148 | + expected_devices) |
| 149 | + |
| 150 | + # Test error cases with invalid inputs |
| 151 | + @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices") |
| 152 | + def test_invalid_inputs(self): |
| 153 | + device_count = xr.global_runtime_device_count() |
| 154 | + mesh = xs.Mesh(np.arange(device_count), (device_count,)) |
| 155 | + tensor = torch.randn(8, 16).to('xla') |
| 156 | + sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None)) |
| 157 | + |
| 158 | + # Test invalid shard dimension (tensor only has dims 0,1 but asking for dim 2) |
| 159 | + with self.assertRaises(IndexError): |
| 160 | + sharded_tensor.redistribute(mesh, [Shard(2)]) |
| 161 | + |
| 162 | + # Test mismatched placements length (1D mesh expects 1 placement, not 2) |
| 163 | + with self.assertRaises(ValueError): |
| 164 | + sharded_tensor.redistribute(mesh, [Shard(0), Shard(1)]) |
| 165 | + |
| 166 | + # Test Partial placement - should raise error about not being implemented |
| 167 | + with self.assertRaises(NotImplementedError): |
| 168 | + sharded_tensor.redistribute(mesh, [Partial()]) |
| 169 | + |
| 170 | + # Test sharding propagation through operations |
| 171 | + @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices") |
| 172 | + def test_sharding_propagation(self): |
| 173 | + device_count = xr.global_runtime_device_count() |
| 174 | + mesh = xs.Mesh(np.arange(device_count), (device_count,)) |
| 175 | + |
| 176 | + # Unary ops |
| 177 | + tensor = torch.randn(8, 16).to('xla') |
| 178 | + sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None)) |
| 179 | + redistributed = sharded_tensor.redistribute(mesh, [Shard(0)]) |
| 180 | + |
| 181 | + relu_result = torch.relu(redistributed.global_tensor) |
| 182 | + self.assertEqual(relu_result.shape, (8, 16)) |
| 183 | + self.assertTrue(torch.all(relu_result >= 0)) |
| 184 | + |
| 185 | + # Binary ops |
| 186 | + tensor2 = torch.randn(8, 16).to('xla') |
| 187 | + sharded_tensor2 = xs.mark_sharding(tensor2, mesh, (0, None)) |
| 188 | + redistributed2 = sharded_tensor2.redistribute(mesh, [Shard(0)]) |
| 189 | + |
| 190 | + add_result = redistributed.global_tensor + redistributed2.global_tensor |
| 191 | + mul_result = redistributed.global_tensor * redistributed2.global_tensor |
| 192 | + |
| 193 | + # Verify operation results |
| 194 | + self.assertEqual(add_result.shape, (8, 16)) |
| 195 | + self.assertEqual(mul_result.shape, (8, 16)) |
| 196 | + |
| 197 | + # Verify operations work correctly |
| 198 | + self.assertTrue( |
| 199 | + torch.allclose( |
| 200 | + add_result, |
| 201 | + redistributed.global_tensor + redistributed2.global_tensor)) |
| 202 | + self.assertTrue( |
| 203 | + torch.allclose( |
| 204 | + mul_result, |
| 205 | + redistributed.global_tensor * redistributed2.global_tensor)) |
| 206 | + |
| 207 | + # Test comprehensive redistribute scenarios |
| 208 | + @unittest.skipIf(xr.global_runtime_device_count() < 2, "Need ≥2 devices") |
| 209 | + def test_comprehensive_redistribute(self): |
| 210 | + device_count = xr.global_runtime_device_count() |
| 211 | + mesh = xs.Mesh(np.arange(device_count), (device_count,)) |
| 212 | + |
| 213 | + tensor = torch.randn(8, 16).to('xla') |
| 214 | + sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None)) |
| 215 | + |
| 216 | + # Test all placement combinations for 1D mesh |
| 217 | + placement_types = [Replicate(), Shard(0), Shard(1)] |
| 218 | + |
| 219 | + for placement in placement_types: |
| 220 | + with self.subTest(placement=placement): |
| 221 | + placements = [placement] |
| 222 | + |
| 223 | + if isinstance(placement, Shard): |
| 224 | + expected_spec = [None] * 2 |
| 225 | + expected_spec[placement.dim] = 0 |
| 226 | + expected_spec = tuple(expected_spec) |
| 227 | + else: |
| 228 | + expected_spec = (None, None) |
| 229 | + |
| 230 | + redistributed = sharded_tensor.redistribute(mesh, placements) |
| 231 | + self.assertEqual(redistributed.partition_spec, expected_spec) |
| 232 | + |
| 233 | + # Verify HLO sharding |
| 234 | + if isinstance(placement, Shard): |
| 235 | + devices_pattern = [ |
| 236 | + str(device_count) if i == placement.dim else '1' for i in range(2) |
| 237 | + ] |
| 238 | + expected_devices = f"devices=[{','.join(devices_pattern)}]" |
| 239 | + else: |
| 240 | + expected_devices = "replicated" |
| 241 | + self._verify_sharding_spec(redistributed.global_tensor, |
| 242 | + expected_devices) |
| 243 | + |
| 244 | + # Test async redistribute |
| 245 | + @unittest.skipIf(xr.global_runtime_device_count() < 4, "Need ≥4 devices") |
| 246 | + def test_async_redistribute(self): |
| 247 | + device_count = xr.global_runtime_device_count() |
| 248 | + mesh_shape = (2, device_count // 2) |
| 249 | + mesh = xs.Mesh(np.arange(device_count), mesh_shape) |
| 250 | + |
| 251 | + tensor = torch.randn(8, 16).to('xla') |
| 252 | + sharded_tensor = xs.mark_sharding(tensor, mesh, (0, None)) |
| 253 | + |
| 254 | + # Test async redistribute |
| 255 | + placements = [Replicate(), Shard(0)] |
| 256 | + redistributed = sharded_tensor.redistribute(mesh, placements, async_op=True) |
| 257 | + self.assertEqual(redistributed.partition_spec, (1, None)) |
| 258 | + |
| 259 | + # Verify async operation creates different tensor object |
| 260 | + self.assertIsNot(redistributed.global_tensor, sharded_tensor.global_tensor) |
| 261 | + |
| 262 | + # Verify HLO sharding for async redistribute (XLA generates more complex pattern) |
| 263 | + expected_devices = f"devices=[2,1,{device_count // 2}]" |
| 264 | + self._verify_sharding_spec(redistributed.global_tensor, expected_devices) |
| 265 | + |
| 266 | + |
| 267 | +if __name__ == '__main__': |
| 268 | + test = unittest.main() |
| 269 | + sys.exit(0 if test.result.wasSuccessful() else 1) |
0 commit comments