Commit 9af66d8

aws-cphHoomaaan authored and committed
Implement XLAShardedTensor._spec and test
1 parent ca47198 · commit 9af66d8

File tree

3 files changed: 72 additions & 163 deletions

test/spmd/test_xla_dtensor_spec_conversion.py

Lines changed: 31 additions & 117 deletions
@@ -3,12 +3,9 @@
 
 import torch
 from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
-from torch.distributed.tensor.placement_types import Replicate
 
 import torch_xla
 import torch_xla.runtime as xr
-from torch_xla.distributed.spmd import XLAShardedTensor
-from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor
 
 import unittest
 import test_xla_sharding_base
@@ -34,6 +31,7 @@ def test_xla_to_dtensor_spec_conversion(self):
     mesh = DeviceMesh("xla", list(range(device_count)))
 
     # Test different sharding patterns
+    from torch.distributed.tensor.placement_types import Replicate
     test_cases = [
         (torch.randn(100, 50), [Shard(0)]),
         (torch.randn(100, 50), [Shard(1)]),
@@ -66,20 +64,30 @@ def test_mesh_conversion(self):
     assert converted_spec.mesh.shape == original_mesh.shape
 
   def test_spec_caching(self):
-    """Test that _spec property caches results
-    """
+    """Test that _spec property caches results for better performance"""
+    import time
     device_count = xr.global_runtime_device_count()
     mesh = DeviceMesh("xla", list(range(device_count)))
-    tensor = torch.randn(100, 100)
+    tensor = torch.randn(1000,
+                         1000)  # Large tensor to make spec creation noticeable
     xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
 
+    # first access should create and cache the spec
+    start_time = time.time()
     spec1 = xla_tensor._spec
+    first_access_time = time.time() - start_time
 
-    assert xla_tensor._cached_spec is not None
-    assert xla_tensor._cached_spec is spec1
-
+    # should be much faster due to caching
+    start_time = time.time()
     spec2 = xla_tensor._spec
+    second_access_time = time.time() - start_time
+
     assert spec1 is spec2
+    print(
+        f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s"
+    )
+    assert second_access_time * 10 < first_access_time, \
+        f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"
 
   def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
     """Helper to create tensor and mesh for testing"""
@@ -106,8 +114,22 @@ def test_multi_dim_sharding_spec(self):
     assert len(spec.placements) == 2
     assert spec.mesh.ndim == 2
 
+  def test_tensor_operations_preserve_spec(self):
+    """Test that tensor operations preserve sharding metadata"""
+    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,),
+                                                         [Shard(0)])
+
+    result_add = xla_tensor + 1
+    result_mul = xla_tensor * 2
+    result_relu = torch.relu(xla_tensor)
+
+    for result in [result_add, result_mul, result_relu]:
+      assert hasattr(result, '_spec')
+      assert result._spec.mesh.device_type == "xla"
+
   def test_mixed_placement_spec(self):
     """Test _spec for tensors with mixed shard/replicate placements"""
+    from torch.distributed.tensor.placement_types import Replicate
     device_count = xr.global_runtime_device_count()
     if device_count < 4:
       self.skipTest("Need at least 4 devices for 2D mesh")
@@ -121,114 +143,6 @@ def test_mixed_placement_spec(self):
     assert isinstance(spec.placements[0], Shard)
     assert isinstance(spec.placements[1], Replicate)
 
-  def test_sharding_info_acquisition(self):
-    """Test that non-XLAShardedTensor can acquire sharding information
-
-    Tests case of 'elem is not an XLAShardedTensor but there exists
-    sharding information we want to acquire'
-    """
-
-    device_count = xr.global_runtime_device_count()
-    mesh_shape = (device_count,)
-    partition_spec = (0, None)
-
-    regular_tensor = torch.randn(100, 50).to('xla')
-
-    sharded_tensor = wrap_as_sharded_tensor(
-        regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec)
-
-    # Verify the tensor acquired the sharding information
-    assert isinstance(sharded_tensor, XLAShardedTensor)
-    assert sharded_tensor.mesh_shape == mesh_shape
-    assert sharded_tensor.partition_spec == partition_spec
-
-  def test_resharding_logic(self):
-    """
-    Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t.
-    """
-
-    device_count = xr.global_runtime_device_count()
-    if device_count < 4:
-      self.skipTest("Need at least 4 devices for resharding test")
-
-    # Initial sharding
-    initial_mesh_shape = (device_count,)
-    initial_partition_spec = (0, None)
-    new_mesh_shape = (2, device_count // 2)
-    new_partition_spec = (0, 1)
-
-    # Create tensor and verify resharding
-    tensor = torch.randn(100, 50).to('xla')
-    sharded_tensor = wrap_as_sharded_tensor(
-        tensor,
-        mesh_shape=initial_mesh_shape,
-        partition_spec=initial_partition_spec)
-    initial_spec = sharded_tensor._spec
-
-    resharded_tensor = wrap_as_sharded_tensor(
-        sharded_tensor,
-        mesh_shape=new_mesh_shape,
-        partition_spec=new_partition_spec)
-
-    # Verify resharding worked and cache was invalidated
-    assert resharded_tensor.mesh_shape == new_mesh_shape
-    assert resharded_tensor.partition_spec == new_partition_spec
-    assert resharded_tensor._spec is not initial_spec
-
-  def test_spec_invalidation_on_resharding(self):
-    """Tests cases where the cached spec may become outdated.
-    """
-
-    device_count = xr.global_runtime_device_count()
-    if device_count < 4:
-      self.skipTest("Need at least 4 devices for resharding test")
-
-    tensor = torch.randn(100, 50).to('xla')
-    initial_mesh_shape = (device_count,)
-    initial_partition_spec = (0, None)
-    new_mesh_shape = (2, device_count // 2)
-    new_partition_spec = (0, 1)
-
-    sharded_tensor = wrap_as_sharded_tensor(
-        tensor,
-        mesh_shape=initial_mesh_shape,
-        partition_spec=initial_partition_spec)
-    initial_spec = sharded_tensor._spec
-    assert sharded_tensor._cached_spec is not None
-
-    # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache
-    resharded_tensor = wrap_as_sharded_tensor(
-        sharded_tensor,
-        mesh_shape=new_mesh_shape,
-        partition_spec=initial_partition_spec)
-    assert resharded_tensor._spec is not initial_spec
-    assert resharded_tensor._spec.mesh.shape == new_mesh_shape
-
-    initial_spec = resharded_tensor._spec
-    resharded_tensor = wrap_as_sharded_tensor(
-        resharded_tensor,
-        mesh_shape=new_mesh_shape,
-        partition_spec=new_partition_spec)
-    assert resharded_tensor._spec is not initial_spec
-    assert resharded_tensor._spec.placements[1].dim == 1
-
-  def test_auto_wrapped_tensor_spec_failure(self):
-    """Test that auto-wrapped tensors fail when accessing _spec property.
-
-    Auto-wrapped tensors are created through operations that trigger __torch_dispatch__
-    but don't yet have access to the sharding propagation done through open xla,
-    causing ._spec to fail.
-    """
-    device_count = xr.global_runtime_device_count()
-    mesh = DeviceMesh("xla", torch.arange(device_count))
-    tensor = torch.randn(4, 4)
-    sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
-
-    auto_wrapped = sharded_tensor + sharded_tensor
-
-    with self.assertRaises(ValueError):
-      _ = auto_wrapped._spec
-
 
 if __name__ == '__main__':
   test = unittest.main()
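
For reference, here is a minimal sketch (not part of the commit) of the flow the tests above exercise, assuming an XLA runtime with at least one device; the inspect_spec helper name is illustrative only.

import torch
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor

import torch_xla.runtime as xr


def inspect_spec():
  # Build a 1-D "xla" device mesh over all available devices.
  device_count = xr.global_runtime_device_count()
  mesh = DeviceMesh("xla", list(range(device_count)))

  # distribute_tensor on an "xla" mesh yields an XLAShardedTensor.
  sharded = distribute_tensor(torch.randn(100, 50), mesh, [Shard(0)])

  # _spec (added in this commit) exposes a DTensorSpec built from the
  # tensor's mesh_shape and partition_spec.
  spec = sharded._spec
  print(spec.mesh.device_type, spec.placements)


if __name__ == "__main__":
  inspect_spec()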

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 35 additions & 27 deletions
@@ -10,7 +10,6 @@
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.placement_types import Shard, Replicate
-from torch.utils._pytree import tree_map_only
 
 
 @dataclass
@@ -116,13 +115,11 @@ def __new__(cls,
                            device=elem.device,
                            requires_grad=kwargs.get("requires_grad", False))
     r.global_tensor = elem.detach() if r.requires_grad else elem
-
-    # Initialize mesh, partition, and spec information
-    r.mesh_shape = mesh_shape or (elem.mesh_shape if isinstance(
-        elem, XLAShardedTensor) else None)
-    r.partition_spec = partition_spec or (elem.partition_spec if isinstance(
-        elem, XLAShardedTensor) else None)
-    r._cached_spec = None
+    # Store mesh and partition information for DTensor compatibility
+    if mesh_shape is not None:
+      r.mesh_shape = mesh_shape
+    if partition_spec is not None:
+      r.partition_spec = partition_spec
     return r
 
   # Shards on the devices are materialized/available after the lazy
@@ -179,7 +176,27 @@ def unwrap(elem):
       return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem
 
     def wrap(elem):
-      return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem
+      if isinstance(elem,
+                    torch.Tensor) and not isinstance(elem, XLAShardedTensor):
+        # Try to get mesh/partition info from any XLAShardedTensor in args
+        mesh_shape = None
+        partition_spec = None
+
+        def find_sharded_info(x):
+          nonlocal mesh_shape, partition_spec
+          if isinstance(x, XLAShardedTensor):
+            if hasattr(x, 'mesh_shape') and x.mesh_shape:
+              mesh_shape = x.mesh_shape
+            if hasattr(x, 'partition_spec') and x.partition_spec:
+              partition_spec = x.partition_spec
+
+        tree_map(find_sharded_info, args)
+        if kwargs:
+          tree_map(find_sharded_info, kwargs)
+
+        return XLAShardedTensor(
+            elem, mesh_shape=mesh_shape, partition_spec=partition_spec)
+      return elem
 
     # no_dispatch is only needed if you use enable_python_mode.
     # It prevents infinite recursion.
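
The wrap() change above scans the op's inputs with a pytree map to pick up sharding metadata from any XLAShardedTensor argument. Below is a standalone sketch of that scanning pattern only (an assumed simplification, not the library code; Tagged is a made-up stand-in for XLAShardedTensor):

from dataclasses import dataclass

from torch.utils._pytree import tree_map


@dataclass
class Tagged:
  value: int
  mesh_shape: tuple = None


def first_mesh_shape(args):
  """Return the first non-empty mesh_shape found anywhere in a nested structure."""
  found = None

  def visit(x):
    nonlocal found
    if found is None and isinstance(x, Tagged) and x.mesh_shape:
      found = x.mesh_shape
    return x

  tree_map(visit, args)  # walks lists/tuples/dicts, calling visit on each leaf
  return found


args = ([Tagged(1), Tagged(2, mesh_shape=(4,))], {"k": Tagged(3)})
assert first_mesh_shape(args) == (4,)
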
@@ -195,26 +212,25 @@ def _spec(self):
     Convert XLA sharding information to DTensorSpec for DTensor interface compatibility.
     """
     # Return cached spec if available
-    if self._cached_spec is not None:
+    if hasattr(self, '_cached_spec'):
       return self._cached_spec
 
     # use existing mesh_shape
-    if self.mesh_shape is not None:
+    if hasattr(self, 'mesh_shape') and self.mesh_shape:
+      import torch_xla.runtime as xr
       device_count = xr.global_runtime_device_count()
       device_list = list(range(device_count))
       mesh = DeviceMesh("xla",
                         torch.tensor(device_list).reshape(self.mesh_shape))
     else:
-      raise ValueError(
-          "mesh_shape must be specified to create DTensorSpec. "
-          "If this tensor was created through torch operations, it may be auto-wrapped. "
-          "Use wrap_as_sharded_tensor() to set mesh_shape before accessing _spec. "
-      )
+      raise ValueError("mesh_shape must be specified to create DTensorSpec")
 
     # use existing partition_spec
-    if self.partition_spec is not None:
+    if hasattr(self, 'partition_spec') and self.partition_spec:
       placements = []
-      for mesh_dim in range(len(self.mesh_shape)):
+      for mesh_dim in range(
+          len(self.mesh_shape
+             ) if hasattr(self, 'mesh_shape') and self.mesh_shape else 1):
         # find tensor dimension sharded on this mesh dimension
         tensor_dim = None
         for t_dim, m_dim in enumerate(self.partition_spec):
@@ -224,11 +240,7 @@ def _spec(self):
       placements.append(
          Shard(tensor_dim) if tensor_dim is not None else Replicate())
     else:
-      raise ValueError(
-          "partition_spec must be specified to create DTensorSpec. "
-          "If this tensor was created through torch operations, it may be auto-wrapped. "
-          "Use wrap_as_sharded_tensor() to set partition_spec before accessing _spec. "
-      )
+      raise ValueError("partition_spec must be specified to create DTensorSpec")
 
     # tensor metadata
     tensor_meta = TensorMeta(
@@ -241,10 +253,6 @@ def _spec(self):
         mesh=mesh, placements=tuple(placements), tensor_meta=tensor_meta)
     return self._cached_spec
 
-  def invalidate_spec_cache(self):
-    """Invalidate the cached DTensorSpec."""
-    self._cached_spec = None
-
   @classmethod
   def __torch_function__(cls, func, types, args=(), kwargs=None):
     return super().__torch_function__(func, types, args, kwargs)
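
The _spec changes above map (mesh_shape, partition_spec) to DTensor placements: each mesh dimension becomes Shard(tensor_dim) if some tensor dimension is assigned to it, otherwise Replicate(). Here is a self-contained sketch of that mapping under the same convention (an illustrative helper, not the library code):

from torch.distributed.tensor.placement_types import Replicate, Shard


def placements_from_partition_spec(mesh_shape, partition_spec):
  placements = []
  for mesh_dim in range(len(mesh_shape)):
    # Find the tensor dimension (if any) sharded on this mesh dimension.
    tensor_dim = next(
        (t_dim for t_dim, m_dim in enumerate(partition_spec)
         if m_dim == mesh_dim), None)
    placements.append(
        Shard(tensor_dim) if tensor_dim is not None else Replicate())
  return tuple(placements)


# (0, None) on a 1-D mesh shards tensor dim 0 across the mesh;
# (0, 1) on a 2-D mesh shards dim 0 on mesh axis 0 and dim 1 on mesh axis 1.
assert placements_from_partition_spec((4,), (0, None)) == (Shard(0),)
assert placements_from_partition_spec((2, 2), (0, 1)) == (Shard(0), Shard(1))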

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 6 additions & 19 deletions
@@ -765,27 +765,14 @@ def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor],
                            partition_spec=None) -> XLAShardedTensor:
   # pass along mesh and partition spec information
   if not isinstance(t, XLAShardedTensor):
-    # Create a new XLAShardedTensor
     return XLAShardedTensor(
         t, mesh_shape=mesh_shape, partition_spec=partition_spec)
-
-  # Update existing XLAShardedTensor if needed
-  needs_invalidate = False
-
-  # Always set mesh_shape and partition_spec if provided
-  if mesh_shape is not None:
-    t.mesh_shape = mesh_shape
-    needs_invalidate = True
-
-  if partition_spec is not None:
-    t.partition_spec = partition_spec
-    needs_invalidate = True
-
-  # Invalidate cached spec if resharding occurred
-  if needs_invalidate:
-    t.invalidate_spec_cache()
-
-  return t
+  else:
+    if mesh_shape is not None:
+      t.mesh_shape = mesh_shape
+    if partition_spec is not None:
+      t.partition_spec = partition_spec
+    return t
 
 
 def unwrap_sharded_tensor(
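
A short usage sketch for the updated wrap_as_sharded_tensor, assuming an XLA runtime is available (shapes and values are arbitrary, chosen only for illustration):

import torch

import torch_xla.runtime as xr
from torch_xla.distributed.spmd import XLAShardedTensor
from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor

n = xr.global_runtime_device_count()
t = torch.randn(8, 4).to('xla')

# Plain tensor in: a new XLAShardedTensor is created carrying the metadata.
st = wrap_as_sharded_tensor(t, mesh_shape=(n,), partition_spec=(0, None))
assert isinstance(st, XLAShardedTensor)
assert st.mesh_shape == (n,) and st.partition_spec == (0, None)

# XLAShardedTensor in: the metadata is updated in place and t itself returned.
same = wrap_as_sharded_tensor(st, partition_spec=(None, 0))
assert same is st
assert same.partition_spec == (None, 0)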

0 commit comments
