
Commit e1e1742

aws-cphHoomaaan authored and committed
Implement XLAShardedTensor._spec and test
1 parent f400690 commit e1e1742


3 files changed (+62, -138 lines)

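For orientation, here is a minimal usage sketch of the `_spec` property this commit implements. It mirrors the tests in the diff below; the mesh construction is illustrative and assumes a multi-device XLA runtime.

    import torch
    import torch_xla.runtime as xr
    from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor

    device_count = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(device_count)))

    # distribute_tensor on an "xla" mesh yields an XLAShardedTensor; ._spec
    # exposes the equivalent DTensor sharding metadata (mesh + placements).
    xla_tensor = distribute_tensor(torch.randn(100, 50), mesh, [Shard(0)])
    spec = xla_tensor._spec
    assert spec.mesh.device_type == "xla"
    assert len(spec.placements) == 1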

test/spmd/test_xla_dtensor_spec_conversion.py

Lines changed: 31 additions & 117 deletions
@@ -3,12 +3,9 @@
 
 import torch
 from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
-from torch.distributed.tensor.placement_types import Replicate
 
 import torch_xla
 import torch_xla.runtime as xr
-from torch_xla.distributed.spmd import XLAShardedTensor
-from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor
 
 import unittest
 import test_xla_sharding_base
@@ -34,6 +31,7 @@ def test_xla_to_dtensor_spec_conversion(self):
     mesh = DeviceMesh("xla", list(range(device_count)))
 
     # Test different sharding patterns
+    from torch.distributed.tensor.placement_types import Replicate
     test_cases = [
         (torch.randn(100, 50), [Shard(0)]),
         (torch.randn(100, 50), [Shard(1)]),
@@ -66,20 +64,30 @@ def test_mesh_conversion(self):
     assert converted_spec.mesh.shape == original_mesh.shape
 
   def test_spec_caching(self):
-    """Test that _spec property caches results
-    """
+    """Test that _spec property caches results for better performance"""
+    import time
     device_count = xr.global_runtime_device_count()
     mesh = DeviceMesh("xla", list(range(device_count)))
-    tensor = torch.randn(100, 100)
+    tensor = torch.randn(1000,
+                         1000)  # Large tensor to make spec creation noticeable
     xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
 
+    # first access should create and cache the spec
+    start_time = time.time()
     spec1 = xla_tensor._spec
+    first_access_time = time.time() - start_time
 
-    assert xla_tensor._cached_spec is not None
-    assert xla_tensor._cached_spec is spec1
-
+    # should be much faster due to caching
+    start_time = time.time()
     spec2 = xla_tensor._spec
+    second_access_time = time.time() - start_time
+
     assert spec1 is spec2
+    print(
+        f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s"
+    )
+    assert second_access_time * 10 < first_access_time, \
+        f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"
 
   def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
     """Helper to create tensor and mesh for testing"""
@@ -106,8 +114,22 @@ def test_multi_dim_sharding_spec(self):
     assert len(spec.placements) == 2
     assert spec.mesh.ndim == 2
 
+  def test_tensor_operations_preserve_spec(self):
+    """Test that tensor operations preserve sharding metadata"""
+    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,),
+                                                         [Shard(0)])
+
+    result_add = xla_tensor + 1
+    result_mul = xla_tensor * 2
+    result_relu = torch.relu(xla_tensor)
+
+    for result in [result_add, result_mul, result_relu]:
+      assert hasattr(result, '_spec')
+      assert result._spec.mesh.device_type == "xla"
+
   def test_mixed_placement_spec(self):
     """Test _spec for tensors with mixed shard/replicate placements"""
+    from torch.distributed.tensor.placement_types import Replicate
     device_count = xr.global_runtime_device_count()
     if device_count < 4:
       self.skipTest("Need at least 4 devices for 2D mesh")
@@ -121,114 +143,6 @@ def test_mixed_placement_spec(self):
     assert isinstance(spec.placements[0], Shard)
     assert isinstance(spec.placements[1], Replicate)
 
-  def test_sharding_info_acquisition(self):
-    """Test that non-XLAShardedTensor can acquire sharding information
-
-    Tests case of 'elem is not an XLAShardedTensor but there exists
-    sharding information we want to acquire'
-    """
-
-    device_count = xr.global_runtime_device_count()
-    mesh_shape = (device_count,)
-    partition_spec = (0, None)
-
-    regular_tensor = torch.randn(100, 50).to('xla')
-
-    sharded_tensor = wrap_as_sharded_tensor(
-        regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec)
-
-    # Verify the tensor acquired the sharding information
-    assert isinstance(sharded_tensor, XLAShardedTensor)
-    assert sharded_tensor.mesh_shape == mesh_shape
-    assert sharded_tensor.partition_spec == partition_spec
-
-  def test_resharding_logic(self):
-    """
-    Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t.
-    """
-
-    device_count = xr.global_runtime_device_count()
-    if device_count < 4:
-      self.skipTest("Need at least 4 devices for resharding test")
-
-    # Initial sharding
-    initial_mesh_shape = (device_count,)
-    initial_partition_spec = (0, None)
-    new_mesh_shape = (2, device_count // 2)
-    new_partition_spec = (0, 1)
-
-    # Create tensor and verify resharding
-    tensor = torch.randn(100, 50).to('xla')
-    sharded_tensor = wrap_as_sharded_tensor(
-        tensor,
-        mesh_shape=initial_mesh_shape,
-        partition_spec=initial_partition_spec)
-    initial_spec = sharded_tensor._spec
-
-    resharded_tensor = wrap_as_sharded_tensor(
-        sharded_tensor,
-        mesh_shape=new_mesh_shape,
-        partition_spec=new_partition_spec)
-
-    # Verify resharding worked and cache was invalidated
-    assert resharded_tensor.mesh_shape == new_mesh_shape
-    assert resharded_tensor.partition_spec == new_partition_spec
-    assert resharded_tensor._spec is not initial_spec
-
-  def test_spec_invalidation_on_resharding(self):
-    """Tests cases where the cached spec may become outdated.
-    """
-
-    device_count = xr.global_runtime_device_count()
-    if device_count < 4:
-      self.skipTest("Need at least 4 devices for resharding test")
-
-    tensor = torch.randn(100, 50).to('xla')
-    initial_mesh_shape = (device_count,)
-    initial_partition_spec = (0, None)
-    new_mesh_shape = (2, device_count // 2)
-    new_partition_spec = (0, 1)
-
-    sharded_tensor = wrap_as_sharded_tensor(
-        tensor,
-        mesh_shape=initial_mesh_shape,
-        partition_spec=initial_partition_spec)
-    initial_spec = sharded_tensor._spec
-    assert sharded_tensor._cached_spec is not None
-
-    # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache
-    resharded_tensor = wrap_as_sharded_tensor(
-        sharded_tensor,
-        mesh_shape=new_mesh_shape,
-        partition_spec=initial_partition_spec)
-    assert resharded_tensor._spec is not initial_spec
-    assert resharded_tensor._spec.mesh.shape == new_mesh_shape
-
-    initial_spec = resharded_tensor._spec
-    resharded_tensor = wrap_as_sharded_tensor(
-        resharded_tensor,
-        mesh_shape=new_mesh_shape,
-        partition_spec=new_partition_spec)
-    assert resharded_tensor._spec is not initial_spec
-    assert resharded_tensor._spec.placements[1].dim == 1
-
-  def test_auto_wrapped_tensor_spec_failure(self):
-    """Test that auto-wrapped tensors fail when accessing _spec property.
-
-    Auto-wrapped tensors are created through operations that trigger __torch_dispatch__
-    but don't yet have access to the sharding propagation done through open xla,
-    causing ._spec to fail.
-    """
-    device_count = xr.global_runtime_device_count()
-    mesh = DeviceMesh("xla", torch.arange(device_count))
-    tensor = torch.randn(4, 4)
-    sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
-
-    auto_wrapped = sharded_tensor + sharded_tensor
-
-    with self.assertRaises(ValueError):
-      _ = auto_wrapped._spec
-
 
 if __name__ == '__main__':
   test = unittest.main()

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 25 additions & 2 deletions
@@ -114,7 +114,7 @@ def __new__(cls,
         dtype=elem.dtype,
         layout=elem.layout,
         device=elem.device,
-        requires_grad=kwargs.get("requires_grad", False))
+        requires_grad=kwargs.get("requires_grad", elem.requires_grad))
     r.global_tensor = elem.detach() if r.requires_grad else elem
 
     # Initialize mesh, partition, and spec information
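A hedged sketch of what the changed default means for callers. Constructing XLAShardedTensor directly like this follows how wrap_as_sharded_tensor builds the wrapper elsewhere in this commit; treat it as illustrative rather than canonical usage.

    import torch
    from torch_xla.distributed.spmd import XLAShardedTensor

    src = torch.randn(4, 4, requires_grad=True).to('xla')

    # The wrapper now inherits requires_grad from the wrapped tensor by default
    # instead of defaulting to False; an explicit requires_grad kwarg still wins,
    # since kwargs.get(...) only falls back to elem.requires_grad when unset.
    wrapped = XLAShardedTensor(src)
    assert wrapped.requires_grad == src.requires_grad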
@@ -150,6 +150,29 @@ def load_local_shards_(self, shards: List[XLAShard]):
     # Invalidate cached spec since the global_tensor data has changed
     self.invalidate_spec_cache()
 
+  def to_local(self):
+    """
+    Returns the local representation of the XLAShardedTensor.
+
+    This method returns the global tensor representation, which contains
+    the combined data across all devices. The returned tensor is on the
+    same device as the original XLAShardedTensor. The returned tensor
+    will have the same requires_grad value as the XLAShardedTensor.
+    If the original tensor has gradients, those will be preserved.
+
+    Returns:
+      torch.Tensor: The global tensor representation with appropriate requires_grad setting.
+    """
+
+    # Create a new tensor with the same values of global_tensor
+    result = self.global_tensor.clone()
+    # Since global tensor is detached, add requires_grad and grad values back to the local tensor
+    if self.requires_grad:
+      result.requires_grad = self.requires_grad
+      result.grad = self.grad
+
+    return result
+
   @property
   def sharding_spec(self):
     return torch_xla._XLAC._get_xla_sharding_spec(self.global_tensor)
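A short usage sketch of the new to_local(), grounded in the docstring and body above; the tensor shape and mesh are illustrative and assume a multi-device XLA runtime.

    import torch
    import torch_xla.runtime as xr
    from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor

    device_count = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(device_count)))

    sharded = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

    # to_local() clones the global tensor, so the result is a regular torch.Tensor
    # on the same XLA device, with requires_grad (and .grad, if set) carried over.
    local = sharded.to_local()
    assert local.device == sharded.global_tensor.device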
@@ -300,4 +323,4 @@ def redistribute(self, device_mesh, placements, *, async_op: bool = False):
 
   @classmethod
   def __torch_function__(cls, func, types, args=(), kwargs=None):
-    return super().__torch_function__(func, types, args, kwargs)
+    return super().__torch_function__(func, types, args, kwargs)

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 6 additions & 19 deletions
@@ -767,27 +767,14 @@ def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor],
                            partition_spec=None) -> XLAShardedTensor:
   # pass along mesh and partition spec information
   if not isinstance(t, XLAShardedTensor):
-    # Create a new XLAShardedTensor
     return XLAShardedTensor(
         t, mesh_shape=mesh_shape, partition_spec=partition_spec)
-
-  # Update existing XLAShardedTensor if needed
-  needs_invalidate = False
-
-  # Always set mesh_shape and partition_spec if provided
-  if mesh_shape is not None:
-    t.mesh_shape = mesh_shape
-    needs_invalidate = True
-
-  if partition_spec is not None:
-    t.partition_spec = partition_spec
-    needs_invalidate = True
-
-  # Invalidate cached spec if resharding occurred
-  if needs_invalidate:
-    t.invalidate_spec_cache()
-
-  return t
+  else:
+    if mesh_shape is not None:
+      t.mesh_shape = mesh_shape
+    if partition_spec is not None:
+      t.partition_spec = partition_spec
+    return t
 
 
 def unwrap_sharded_tensor(
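The simplified wrap_as_sharded_tensor keeps two paths: wrap a plain tensor, or update the metadata on an existing XLAShardedTensor and return it. A hedged usage sketch follows; the shapes, mesh_shape, and partition_spec values are illustrative and drawn from the tests removed in this commit.

    import torch
    import torch_xla.runtime as xr
    from torch_xla.distributed.spmd import XLAShardedTensor
    from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor

    device_count = xr.global_runtime_device_count()

    # Plain tensor in: a new XLAShardedTensor is created with the given metadata.
    t = torch.randn(16, 8).to('xla')
    st = wrap_as_sharded_tensor(
        t, mesh_shape=(device_count,), partition_spec=(0, None))
    assert isinstance(st, XLAShardedTensor)

    # Existing XLAShardedTensor in: mesh_shape / partition_spec are updated on
    # the same object, which is then returned.
    st2 = wrap_as_sharded_tensor(st, partition_spec=(None, 0))
    assert st2 is st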
