
Commit 1e0ea41

aws-cphHoomaaan authored and committed
Removed auto wrapping sharding propagation, added cached spec invalidation
1 parent e1e1742 commit 1e0ea41

File tree

3 files changed: +130 -37 lines changed

test/spmd/test_xla_dtensor_spec_conversion.py

Lines changed: 107 additions & 31 deletions
@@ -3,9 +3,12 @@
 
 import torch
 from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
+from torch.distributed.tensor.placement_types import Replicate
 
 import torch_xla
 import torch_xla.runtime as xr
+from torch_xla.distributed.spmd import XLAShardedTensor
+from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor
 
 import unittest
 import test_xla_sharding_base
@@ -31,7 +34,6 @@ def test_xla_to_dtensor_spec_conversion(self):
     mesh = DeviceMesh("xla", list(range(device_count)))
 
     # Test different sharding patterns
-    from torch.distributed.tensor.placement_types import Replicate
     test_cases = [
         (torch.randn(100, 50), [Shard(0)]),
         (torch.randn(100, 50), [Shard(1)]),
@@ -64,30 +66,27 @@ def test_mesh_conversion(self):
     assert converted_spec.mesh.shape == original_mesh.shape
 
   def test_spec_caching(self):
-    """Test that _spec property caches results for better performance"""
-    import time
+    """Test that _spec property caches results
+
+    Addresses PR comment: "These sorts of tests that rely on the wall clock often lead to
+    annoying flakes in my experience. I think it's sufficient to just test that
+    self._cached_spec has a permanent value after the first call."
+    """
     device_count = xr.global_runtime_device_count()
     mesh = DeviceMesh("xla", list(range(device_count)))
-    tensor = torch.randn(1000,
-                         1000)  # Large tensor to make spec creation noticeable
+    tensor = torch.randn(100, 100)
     xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
 
-    # first access should create and cache the spec
-    start_time = time.time()
+    # First access should create and cache the spec
     spec1 = xla_tensor._spec
-    first_access_time = time.time() - start_time
 
-    # should be much faster due to caching
-    start_time = time.time()
-    spec2 = xla_tensor._spec
-    second_access_time = time.time() - start_time
+    # Verify the spec is cached
+    assert xla_tensor._cached_spec is not None
+    assert xla_tensor._cached_spec is spec1
 
+    # Second access should return the cached spec
+    spec2 = xla_tensor._spec
     assert spec1 is spec2
-    print(
-        f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s"
-    )
-    assert second_access_time * 10 < first_access_time, \
-        f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"
 
   def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
     """Helper to create tensor and mesh for testing"""
@@ -114,22 +113,8 @@ def test_multi_dim_sharding_spec(self):
     assert len(spec.placements) == 2
     assert spec.mesh.ndim == 2
 
-  def test_tensor_operations_preserve_spec(self):
-    """Test that tensor operations preserve sharding metadata"""
-    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,),
-                                                         [Shard(0)])
-
-    result_add = xla_tensor + 1
-    result_mul = xla_tensor * 2
-    result_relu = torch.relu(xla_tensor)
-
-    for result in [result_add, result_mul, result_relu]:
-      assert hasattr(result, '_spec')
-      assert result._spec.mesh.device_type == "xla"
-
   def test_mixed_placement_spec(self):
     """Test _spec for tensors with mixed shard/replicate placements"""
-    from torch.distributed.tensor.placement_types import Replicate
     device_count = xr.global_runtime_device_count()
     if device_count < 4:
       self.skipTest("Need at least 4 devices for 2D mesh")
@@ -143,6 +128,97 @@ def test_mixed_placement_spec(self):
     assert isinstance(spec.placements[0], Shard)
     assert isinstance(spec.placements[1], Replicate)
 
+  def test_sharding_info_acquisition(self):
+    """Test that non-XLAShardedTensor can acquire sharding information
+
+    Tests case of 'elem is not an XLAShardedTensor but there exists
+    sharding information we want to acquire'
+    """
+
+    device_count = xr.global_runtime_device_count()
+    mesh_shape = (device_count,)
+    partition_spec = (0, None)
+
+    regular_tensor = torch.randn(100, 50).to('xla')
+
+    sharded_tensor = wrap_as_sharded_tensor(
+        regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec)
+
+    # Verify the tensor acquired the sharding information
+    assert isinstance(sharded_tensor, XLAShardedTensor)
+    assert sharded_tensor.mesh_shape == mesh_shape
+    assert sharded_tensor.partition_spec == partition_spec
+
+  def test_resharding_logic(self):
+    """
+    Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t.
+    """
+
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for resharding test")
+
+    # Initial sharding
+    initial_mesh_shape = (device_count,)
+    initial_partition_spec = (0, None)
+    new_mesh_shape = (2, device_count // 2)
+    new_partition_spec = (0, 1)
+
+    # Create tensor and verify resharding
+    tensor = torch.randn(100, 50).to('xla')
+    sharded_tensor = wrap_as_sharded_tensor(
+        tensor,
+        mesh_shape=initial_mesh_shape,
+        partition_spec=initial_partition_spec)
+    initial_spec = sharded_tensor._spec
+
+    resharded_tensor = wrap_as_sharded_tensor(
+        sharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=new_partition_spec)
+
+    # Verify resharding worked and cache was invalidated
+    assert resharded_tensor.mesh_shape == new_mesh_shape
+    assert resharded_tensor.partition_spec == new_partition_spec
+    assert resharded_tensor._spec is not initial_spec
+
+  def test_spec_invalidation_on_resharding(self):
+    """Tests cases where the cached spec may become outdated.
+    """
+
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for resharding test")
+
+    tensor = torch.randn(100, 50).to('xla')
+    initial_mesh_shape = (device_count,)
+    initial_partition_spec = (0, None)
+    new_mesh_shape = (2, device_count // 2)
+    new_partition_spec = (0, 1)
+
+    sharded_tensor = wrap_as_sharded_tensor(
+        tensor,
+        mesh_shape=initial_mesh_shape,
+        partition_spec=initial_partition_spec)
+    initial_spec = sharded_tensor._spec
+    assert sharded_tensor._cached_spec is not None
+
+    # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache
+    resharded_tensor = wrap_as_sharded_tensor(
+        sharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=initial_partition_spec)
+    assert resharded_tensor._spec is not initial_spec
+    assert resharded_tensor._spec.mesh.shape == new_mesh_shape
+
+    initial_spec = resharded_tensor._spec
+    resharded_tensor = wrap_as_sharded_tensor(
+        resharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=new_partition_spec)
+    assert resharded_tensor._spec is not initial_spec
+    assert resharded_tensor._spec.placements[1].dim == 1
+
 
 if __name__ == '__main__':
   test = unittest.main()
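The new caching tests only exercise the externally visible contract of _spec: the first access stores the result in _cached_spec, later accesses return the same object, and invalidate_spec_cache() forces a rebuild. The XLAShardedTensor internals are not shown in this commit's diff, so the snippet below is only a minimal standalone sketch of that contract, using a hypothetical SpecCachingTensor stand-in rather than the real class.

# Minimal sketch of the caching contract asserted by the tests above.
# SpecCachingTensor is a hypothetical stand-in, not the real XLAShardedTensor.
class SpecCachingTensor:

  def __init__(self, mesh_shape, partition_spec):
    self.mesh_shape = mesh_shape
    self.partition_spec = partition_spec
    self._cached_spec = None

  @property
  def _spec(self):
    # Build the spec once, then hand back the same object on later accesses.
    if self._cached_spec is None:
      self._cached_spec = ("spec", self.mesh_shape, self.partition_spec)
    return self._cached_spec

  def invalidate_spec_cache(self):
    # Called after mesh_shape / partition_spec change so _spec gets rebuilt.
    self._cached_spec = None


t = SpecCachingTensor(mesh_shape=(4,), partition_spec=(0, None))
first = t._spec
assert t._cached_spec is first and t._spec is first  # cached after first access
t.mesh_shape = (2, 2)
t.invalidate_spec_cache()
assert t._spec is not first  # rebuilt after invalidation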

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 4 additions & 0 deletions
@@ -9,7 +9,11 @@
 import torch_xla.runtime as xr
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.device_mesh import DeviceMesh
+<<<<<<< HEAD
 from torch.distributed.tensor.placement_types import Placement, Shard, Replicate, Partial
+=======
+from torch.distributed.tensor.placement_types import Shard, Replicate
+>>>>>>> 566959e10 (Removed auto wrapping sharding propagation, added cached spec invalidation)
 from torch.utils._pytree import tree_map_only
 
 
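For orientation, the Shard / Replicate placement types imported above are what an XLA partition spec gets translated into when the DTensor spec is built. The helper below is a hypothetical illustration of that mapping, not code from this file, under the convention the new tests rely on: partition_spec[i] names the mesh axis that shards tensor dimension i, and mesh axes not referenced stay replicated (so partition_spec=(0, 1) on a 2-D mesh yields placements[1].dim == 1).

from torch.distributed.tensor.placement_types import Replicate, Shard


def placements_from_partition_spec(partition_spec, mesh_ndim):
  # Hypothetical helper: derive per-mesh-axis placements from an XLA partition spec.
  placements = [Replicate() for _ in range(mesh_ndim)]
  for tensor_dim, mesh_axis in enumerate(partition_spec):
    if mesh_axis is not None:
      placements[mesh_axis] = Shard(tensor_dim)
  return placements


# partition_spec=(0, 1) on a 2-D mesh -> [Shard(0), Shard(1)]
p = placements_from_partition_spec((0, 1), mesh_ndim=2)
assert isinstance(p[1], Shard) and p[1].dim == 1
# partition_spec=(0, None) on a 1-D mesh -> [Shard(0)]
assert isinstance(placements_from_partition_spec((0, None), mesh_ndim=1)[0], Shard)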
torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 19 additions & 6 deletions
@@ -767,14 +767,27 @@ def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor],
                            partition_spec=None) -> XLAShardedTensor:
   # pass along mesh and partition spec information
   if not isinstance(t, XLAShardedTensor):
+    # Create a new XLAShardedTensor
     return XLAShardedTensor(
         t, mesh_shape=mesh_shape, partition_spec=partition_spec)
-  else:
-    if mesh_shape is not None:
-      t.mesh_shape = mesh_shape
-    if partition_spec is not None:
-      t.partition_spec = partition_spec
-    return t
+
+  # Update existing XLAShardedTensor if needed
+  needs_invalidate = False
+
+  # Always set mesh_shape and partition_spec if provided
+  if mesh_shape is not None:
+    t.mesh_shape = mesh_shape
+    needs_invalidate = True
+
+  if partition_spec is not None:
+    t.partition_spec = partition_spec
+    needs_invalidate = True
+
+  # Invalidate cached spec if resharding occurred
+  if needs_invalidate:
+    t.invalidate_spec_cache()
+
+  return t
 
 
 def unwrap_sharded_tensor(
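Putting the pieces together, the reworked wrap_as_sharded_tensor creates a new XLAShardedTensor for plain tensors and, for tensors that are already sharded, updates mesh_shape / partition_spec in place and invalidates the cached spec. The snippet below is a usage sketch rather than part of the commit; it assumes an XLA runtime with at least four devices, mirroring the skip condition in the new tests.

import torch
import torch_xla.runtime as xr
from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor

device_count = xr.global_runtime_device_count()

# Wrapping a plain tensor creates a new XLAShardedTensor carrying the sharding info.
t = wrap_as_sharded_tensor(
    torch.randn(100, 50).to('xla'),
    mesh_shape=(device_count,),
    partition_spec=(0, None))
spec_before = t._spec  # built on first access and cached in t._cached_spec

# Re-wrapping an existing XLAShardedTensor updates it in place; since mesh_shape
# and partition_spec changed, the cached spec is invalidated and rebuilt lazily.
t = wrap_as_sharded_tensor(
    t, mesh_shape=(2, device_count // 2), partition_spec=(0, 1))
assert t._spec is not spec_before
assert t._spec.placements[1].dim == 1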
