|
| 1 | +import os |
| 2 | +import sys |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor |
| 6 | + |
| 7 | +import torch_xla |
| 8 | +import torch_xla.runtime as xr |
| 9 | + |
| 10 | +import unittest |
| 11 | +import test_xla_sharding_base |
| 12 | + |
| 13 | + |
| 14 | +class XLADTensorSpecConversionTest(test_xla_sharding_base.XlaShardingTest): |
| 15 | + |
| 16 | + @classmethod |
| 17 | + def setUpClass(cls): |
| 18 | + super().setUpClass() |
| 19 | + |
| 20 | + def test_sample_test_case(self): |
| 21 | + world_size = xr.global_runtime_device_count() |
| 22 | + mesh = DeviceMesh("xla", torch.arange(world_size)) |
| 23 | + big_tensor = torch.randn(100000, 88) |
| 24 | + my_dtensor = distribute_tensor(big_tensor, mesh, [Shard(0)]) |
| 25 | + |
| 26 | + assert my_dtensor._spec.mesh.device_type == mesh.device_type |
| 27 | + assert my_dtensor._spec.placements == (Shard(0),) |
| 28 | + |
| 29 | + def test_xla_to_dtensor_spec_conversion(self): |
| 30 | + device_count = xr.global_runtime_device_count() |
| 31 | + mesh = DeviceMesh("xla", list(range(device_count))) |
| 32 | + |
| 33 | + # Test different sharding patterns |
| 34 | + from torch.distributed.tensor.placement_types import Replicate |
| 35 | + test_cases = [ |
| 36 | + (torch.randn(100, 50), [Shard(0)]), |
| 37 | + (torch.randn(100, 50), [Shard(1)]), |
| 38 | + (torch.randn(100, 50, 25), [Shard(0)]), |
| 39 | + (torch.randn(100, 50), [Replicate()]), |
| 40 | + ] |
| 41 | + |
| 42 | + for tensor, placements in test_cases: |
| 43 | + xla_tensor = distribute_tensor(tensor, mesh, placements) |
| 44 | + spec = xla_tensor._spec |
| 45 | + |
| 46 | + assert spec is not None |
| 47 | + assert spec.mesh.device_type == "xla" |
| 48 | + assert spec.tensor_meta.shape == tensor.shape |
| 49 | + assert spec.tensor_meta.dtype == tensor.dtype |
| 50 | + assert len(spec.placements) >= 1 |
| 51 | + assert spec.placements == tuple(placements) |
| 52 | + |
| 53 | + def test_mesh_conversion(self): |
| 54 | + device_count = xr.global_runtime_device_count() |
| 55 | + original_mesh = DeviceMesh("xla", list(range(device_count))) |
| 56 | + tensor = torch.randn(50, 50) |
| 57 | + xla_tensor = distribute_tensor(tensor, original_mesh, [Shard(0)]) |
| 58 | + |
| 59 | + converted_spec = xla_tensor._spec |
| 60 | + |
| 61 | + assert converted_spec.mesh.device_type == "xla" |
| 62 | + assert converted_spec.mesh.size() == device_count |
| 63 | + # assert on mesh dimensions |
| 64 | + assert converted_spec.mesh.shape == original_mesh.shape |
| 65 | + |
| 66 | + def test_spec_caching(self): |
| 67 | + """Test that _spec property caches results for better performance""" |
| 68 | + import time |
| 69 | + device_count = xr.global_runtime_device_count() |
| 70 | + mesh = DeviceMesh("xla", list(range(device_count))) |
| 71 | + tensor = torch.randn(1000, |
| 72 | + 1000) # Large tensor to make spec creation noticeable |
| 73 | + xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)]) |
| 74 | + |
| 75 | + # first access should create and cache the spec |
| 76 | + start_time = time.time() |
| 77 | + spec1 = xla_tensor._spec |
| 78 | + first_access_time = time.time() - start_time |
| 79 | + |
| 80 | + # should be much faster due to caching |
| 81 | + start_time = time.time() |
| 82 | + spec2 = xla_tensor._spec |
| 83 | + second_access_time = time.time() - start_time |
| 84 | + |
| 85 | + assert spec1 is spec2 |
| 86 | + print( |
| 87 | + f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s" |
| 88 | + ) |
| 89 | + assert second_access_time * 10 < first_access_time, \ |
| 90 | + f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s" |
| 91 | + |
| 92 | + def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements): |
| 93 | + """Helper to create tensor and mesh for testing""" |
| 94 | + device_count = xr.global_runtime_device_count() |
| 95 | + if device_count < max(mesh_shape): |
| 96 | + self.skipTest( |
| 97 | + f"Need at least {max(mesh_shape)} devices, got {device_count}") |
| 98 | + |
| 99 | + mesh = DeviceMesh("xla", torch.arange(device_count).reshape(mesh_shape)) |
| 100 | + tensor = torch.randn(*tensor_shape) |
| 101 | + return distribute_tensor(tensor, mesh, placements), mesh |
| 102 | + |
| 103 | + def test_multi_dim_sharding_spec(self): |
| 104 | + """Test _spec for multi-dimensional sharding""" |
| 105 | + device_count = xr.global_runtime_device_count() |
| 106 | + if device_count < 4: |
| 107 | + self.skipTest("Need at least 4 devices for 2D mesh") |
| 108 | + |
| 109 | + mesh_shape = (2, device_count // 2) |
| 110 | + xla_tensor, mesh = self._create_test_tensor_and_mesh( |
| 111 | + (100, 50), mesh_shape, [Shard(0), Shard(1)]) |
| 112 | + spec = xla_tensor._spec |
| 113 | + |
| 114 | + assert len(spec.placements) == 2 |
| 115 | + assert spec.mesh.ndim == 2 |
| 116 | + |
| 117 | + def test_tensor_operations_preserve_spec(self): |
| 118 | + """Test that tensor operations preserve sharding metadata""" |
| 119 | + xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,), |
| 120 | + [Shard(0)]) |
| 121 | + |
| 122 | + result_add = xla_tensor + 1 |
| 123 | + result_mul = xla_tensor * 2 |
| 124 | + result_relu = torch.relu(xla_tensor) |
| 125 | + |
| 126 | + for result in [result_add, result_mul, result_relu]: |
| 127 | + assert hasattr(result, '_spec') |
| 128 | + assert result._spec.mesh.device_type == "xla" |
| 129 | + |
| 130 | + def test_mixed_placement_spec(self): |
| 131 | + """Test _spec for tensors with mixed shard/replicate placements""" |
| 132 | + from torch.distributed.tensor.placement_types import Replicate |
| 133 | + device_count = xr.global_runtime_device_count() |
| 134 | + if device_count < 4: |
| 135 | + self.skipTest("Need at least 4 devices for 2D mesh") |
| 136 | + |
| 137 | + mesh_shape = (2, device_count // 2) |
| 138 | + xla_tensor, mesh = self._create_test_tensor_and_mesh( |
| 139 | + (100, 50), mesh_shape, [Shard(0), Replicate()]) |
| 140 | + spec = xla_tensor._spec |
| 141 | + |
| 142 | + assert len(spec.placements) == 2 |
| 143 | + assert isinstance(spec.placements[0], Shard) |
| 144 | + assert isinstance(spec.placements[1], Replicate) |
| 145 | + |
| 146 | + |
| 147 | +if __name__ == '__main__': |
| 148 | + test = unittest.main() |
| 149 | + sys.exit(0 if test.result.wasSuccessful() else 1) |
0 commit comments