
Commit 6340248

aws-cphHoomaaan authored and committed
Removed auto wrapping sharding propagation, added cached spec invalidation
1 parent e395d44 commit 6340248

File tree

3 files changed: +146 −69 lines changed
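
Taken together, the changes below make XLAShardedTensor._spec a lazily built, explicitly
invalidated cache, and drop the automatic propagation of mesh/partition metadata through
dispatched ops. A minimal sketch of the resulting contract, mirroring the new tests below
(assumes a working multi-device XLA runtime; tensor shapes and mesh layout are illustrative
only):

import torch
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor

import torch_xla
import torch_xla.runtime as xr
from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor

device_count = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(device_count)))

# First _spec access builds the DTensorSpec and stores it in _cached_spec.
t = distribute_tensor(torch.randn(16, 8), mesh, [Shard(0)])
old_spec = t._spec
assert t._cached_spec is old_spec

# Re-wrapping with mesh/partition metadata invalidates the cached spec,
# so the next access rebuilds it from the new information.
t = wrap_as_sharded_tensor(t, mesh_shape=(device_count,), partition_spec=(0, None))
assert t._spec is not old_spec

# Op results are still XLAShardedTensors, but they no longer inherit
# mesh_shape/partition_spec from their inputs; re-wrap them explicitly
# if a DTensor spec is needed.
result = t + 1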

test/spmd/test_xla_dtensor_spec_conversion.py

Lines changed: 107 additions & 31 deletions

@@ -3,9 +3,12 @@

 import torch
 from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
+from torch.distributed.tensor.placement_types import Replicate

 import torch_xla
 import torch_xla.runtime as xr
+from torch_xla.distributed.spmd import XLAShardedTensor
+from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor

 import unittest
 import test_xla_sharding_base

@@ -31,7 +34,6 @@ def test_xla_to_dtensor_spec_conversion(self):
     mesh = DeviceMesh("xla", list(range(device_count)))

     # Test different sharding patterns
-    from torch.distributed.tensor.placement_types import Replicate
     test_cases = [
         (torch.randn(100, 50), [Shard(0)]),
         (torch.randn(100, 50), [Shard(1)]),

@@ -64,30 +66,27 @@ def test_mesh_conversion(self):
     assert converted_spec.mesh.shape == original_mesh.shape

   def test_spec_caching(self):
-    """Test that _spec property caches results for better performance"""
-    import time
+    """Test that _spec property caches results
+
+    Addresses PR comment: "These sorts of tests that rely on the wall clock often lead to
+    annoying flakes in my experience. I think it's sufficient to just test that
+    self._cached_spec has a permanent value after the first call."
+    """
     device_count = xr.global_runtime_device_count()
     mesh = DeviceMesh("xla", list(range(device_count)))
-    tensor = torch.randn(1000,
-                         1000)  # Large tensor to make spec creation noticeable
+    tensor = torch.randn(100, 100)
     xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])

-    # first access should create and cache the spec
-    start_time = time.time()
+    # First access should create and cache the spec
     spec1 = xla_tensor._spec
-    first_access_time = time.time() - start_time

-    # should be much faster due to caching
-    start_time = time.time()
-    spec2 = xla_tensor._spec
-    second_access_time = time.time() - start_time
+    # Verify the spec is cached
+    assert xla_tensor._cached_spec is not None
+    assert xla_tensor._cached_spec is spec1

+    # Second access should return the cached spec
+    spec2 = xla_tensor._spec
     assert spec1 is spec2
-    print(
-        f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s"
-    )
-    assert second_access_time * 10 < first_access_time, \
-        f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"

   def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
     """Helper to create tensor and mesh for testing"""

@@ -114,22 +113,8 @@ def test_multi_dim_sharding_spec(self):
     assert len(spec.placements) == 2
     assert spec.mesh.ndim == 2

-  def test_tensor_operations_preserve_spec(self):
-    """Test that tensor operations preserve sharding metadata"""
-    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,),
-                                                         [Shard(0)])
-
-    result_add = xla_tensor + 1
-    result_mul = xla_tensor * 2
-    result_relu = torch.relu(xla_tensor)
-
-    for result in [result_add, result_mul, result_relu]:
-      assert hasattr(result, '_spec')
-      assert result._spec.mesh.device_type == "xla"
-
   def test_mixed_placement_spec(self):
     """Test _spec for tensors with mixed shard/replicate placements"""
-    from torch.distributed.tensor.placement_types import Replicate
     device_count = xr.global_runtime_device_count()
     if device_count < 4:
       self.skipTest("Need at least 4 devices for 2D mesh")

@@ -143,6 +128,97 @@ def test_mixed_placement_spec(self):
     assert isinstance(spec.placements[0], Shard)
     assert isinstance(spec.placements[1], Replicate)

+  def test_sharding_info_acquisition(self):
+    """Test that non-XLAShardedTensor can acquire sharding information
+
+    Tests case of 'elem is not an XLAShardedTensor but there exists
+    sharding information we want to acquire'
+    """
+
+    device_count = xr.global_runtime_device_count()
+    mesh_shape = (device_count,)
+    partition_spec = (0, None)
+
+    regular_tensor = torch.randn(100, 50).to('xla')
+
+    sharded_tensor = wrap_as_sharded_tensor(
+        regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec)
+
+    # Verify the tensor acquired the sharding information
+    assert isinstance(sharded_tensor, XLAShardedTensor)
+    assert sharded_tensor.mesh_shape == mesh_shape
+    assert sharded_tensor.partition_spec == partition_spec
+
+  def test_resharding_logic(self):
+    """
+    Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t.
+    """
+
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for resharding test")
+
+    # Initial sharding
+    initial_mesh_shape = (device_count,)
+    initial_partition_spec = (0, None)
+    new_mesh_shape = (2, device_count // 2)
+    new_partition_spec = (0, 1)
+
+    # Create tensor and verify resharding
+    tensor = torch.randn(100, 50).to('xla')
+    sharded_tensor = wrap_as_sharded_tensor(
+        tensor,
+        mesh_shape=initial_mesh_shape,
+        partition_spec=initial_partition_spec)
+    initial_spec = sharded_tensor._spec
+
+    resharded_tensor = wrap_as_sharded_tensor(
+        sharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=new_partition_spec)
+
+    # Verify resharding worked and cache was invalidated
+    assert resharded_tensor.mesh_shape == new_mesh_shape
+    assert resharded_tensor.partition_spec == new_partition_spec
+    assert resharded_tensor._spec is not initial_spec
+
+  def test_spec_invalidation_on_resharding(self):
+    """Tests cases where the cached spec may become outdated."""
+
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for resharding test")
+
+    tensor = torch.randn(100, 50).to('xla')
+    initial_mesh_shape = (device_count,)
+    initial_partition_spec = (0, None)
+    new_mesh_shape = (2, device_count // 2)
+    new_partition_spec = (0, 1)
+
+    sharded_tensor = wrap_as_sharded_tensor(
+        tensor,
+        mesh_shape=initial_mesh_shape,
+        partition_spec=initial_partition_spec)
+    initial_spec = sharded_tensor._spec
+    assert sharded_tensor._cached_spec is not None
+
+    # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache
+    resharded_tensor = wrap_as_sharded_tensor(
+        sharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=initial_partition_spec)
+    assert resharded_tensor._spec is not initial_spec
+    assert resharded_tensor._spec.mesh.shape == new_mesh_shape
+
+    initial_spec = resharded_tensor._spec
+    resharded_tensor = wrap_as_sharded_tensor(
+        resharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=new_partition_spec)
+    assert resharded_tensor._spec is not initial_spec
+    assert resharded_tensor._spec.placements[1].dim == 1
+

 if __name__ == '__main__':
   test = unittest.main()

torch_xla/distributed/spmd/xla_sharded_tensor.py

Lines changed: 20 additions & 32 deletions

@@ -10,6 +10,7 @@
 from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.tensor.placement_types import Shard, Replicate
+from torch.utils._pytree import tree_map_only


 @dataclass

@@ -115,11 +116,13 @@ def __new__(cls,
         device=elem.device,
         requires_grad=kwargs.get("requires_grad", False))
     r.global_tensor = elem.detach() if r.requires_grad else elem
-    # Store mesh and partition information for DTensor compatibility
-    if mesh_shape is not None:
-      r.mesh_shape = mesh_shape
-    if partition_spec is not None:
-      r.partition_spec = partition_spec
+
+    # Initialize mesh, partition, and spec information
+    r.mesh_shape = mesh_shape or (elem.mesh_shape if isinstance(
+        elem, XLAShardedTensor) else None)
+    r.partition_spec = partition_spec or (elem.partition_spec if isinstance(
+        elem, XLAShardedTensor) else None)
+    r._cached_spec = None
     return r

   # Shards on the devices are materialized/available after the lazy

@@ -144,6 +147,9 @@ def load_local_shards_(self, shards: List[XLAShard]):
     devices = [s.shard_device for s in shards]
     torch_xla._XLAC._load_local_shards(self.global_tensor, data, devices)

+    # Invalidate cached spec since the global_tensor data has changed
+    self.invalidate_spec_cache()
+
   @property
   def sharding_spec(self):
     return torch_xla._XLAC._get_xla_sharding_spec(self.global_tensor)

@@ -173,27 +179,7 @@ def unwrap(elem):
       return elem.global_tensor if isinstance(elem, XLAShardedTensor) else elem

     def wrap(elem):
-      if isinstance(elem,
-                    torch.Tensor) and not isinstance(elem, XLAShardedTensor):
-        # Try to get mesh/partition info from any XLAShardedTensor in args
-        mesh_shape = None
-        partition_spec = None
-
-        def find_sharded_info(x):
-          nonlocal mesh_shape, partition_spec
-          if isinstance(x, XLAShardedTensor):
-            if hasattr(x, 'mesh_shape') and x.mesh_shape:
-              mesh_shape = x.mesh_shape
-            if hasattr(x, 'partition_spec') and x.partition_spec:
-              partition_spec = x.partition_spec
-
-        tree_map(find_sharded_info, args)
-        if kwargs:
-          tree_map(find_sharded_info, kwargs)
-
-        return XLAShardedTensor(
-            elem, mesh_shape=mesh_shape, partition_spec=partition_spec)
-      return elem
+      return XLAShardedTensor(elem) if isinstance(elem, torch.Tensor) else elem

     # no_dispatch is only needed if you use enable_python_mode.
     # It prevents infinite recursion.

@@ -209,11 +195,11 @@ def _spec(self):
     Convert XLA sharding information to DTensorSpec for DTensor interface compatibility.
     """
     # Return cached spec if available
-    if hasattr(self, '_cached_spec'):
+    if self._cached_spec is not None:
       return self._cached_spec

     # use existing mesh_shape
-    if hasattr(self, 'mesh_shape') and self.mesh_shape:
+    if self.mesh_shape is not None:
       import torch_xla.runtime as xr
       device_count = xr.global_runtime_device_count()
       device_list = list(range(device_count))

@@ -223,11 +209,9 @@ def _spec(self):
       raise ValueError("mesh_shape must be specified to create DTensorSpec")

     # use existing partition_spec
-    if hasattr(self, 'partition_spec') and self.partition_spec:
+    if self.partition_spec is not None:
       placements = []
-      for mesh_dim in range(
-          len(self.mesh_shape
-             ) if hasattr(self, 'mesh_shape') and self.mesh_shape else 1):
+      for mesh_dim in range(len(self.mesh_shape)):
         # find tensor dimension sharded on this mesh dimension
         tensor_dim = None
         for t_dim, m_dim in enumerate(self.partition_spec):

@@ -250,6 +234,10 @@ def _spec(self):
         mesh=mesh, placements=tuple(placements), tensor_meta=tensor_meta)
     return self._cached_spec

+  def invalidate_spec_cache(self):
+    """Invalidate the cached DTensorSpec."""
+    self._cached_spec = None
+
   @classmethod
   def __torch_function__(cls, func, types, args=(), kwargs=None):
     return super().__torch_function__(func, types, args, kwargs)
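
The net effect in this file: __new__ now always initializes mesh_shape, partition_spec,
and _cached_spec, so _spec can use plain "is None" checks instead of hasattr probes, and
invalidate_spec_cache() is the single reset path. A standalone sketch of that caching
pattern follows; the class name and _build_spec() helper are made up here and only stand
in for the real DTensorSpec construction:

class _CachedSpecSketch:
  """Illustrative only; not the real XLAShardedTensor."""

  def __init__(self, mesh_shape=None, partition_spec=None):
    self.mesh_shape = mesh_shape
    self.partition_spec = partition_spec
    self._cached_spec = None  # always present, so no hasattr checks are needed

  @property
  def _spec(self):
    # Build once, then reuse until explicitly invalidated.
    if self._cached_spec is None:
      self._cached_spec = self._build_spec()
    return self._cached_spec

  def _build_spec(self):
    # Placeholder for the DTensorSpec construction done in the real class.
    return (self.mesh_shape, self.partition_spec)

  def invalidate_spec_cache(self):
    # Called when mesh_shape/partition_spec or the underlying data change.
    self._cached_spec = None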

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 19 additions & 6 deletions

@@ -765,14 +765,27 @@ def wrap_as_sharded_tensor(t: Union[torch.Tensor, XLAShardedTensor],
                            partition_spec=None) -> XLAShardedTensor:
   # pass along mesh and partition spec information
   if not isinstance(t, XLAShardedTensor):
+    # Create a new XLAShardedTensor
     return XLAShardedTensor(
         t, mesh_shape=mesh_shape, partition_spec=partition_spec)
-  else:
-    if mesh_shape is not None:
-      t.mesh_shape = mesh_shape
-    if partition_spec is not None:
-      t.partition_spec = partition_spec
-    return t
+
+  # Update existing XLAShardedTensor if needed
+  needs_invalidate = False
+
+  # Always set mesh_shape and partition_spec if provided
+  if mesh_shape is not None:
+    t.mesh_shape = mesh_shape
+    needs_invalidate = True
+
+  if partition_spec is not None:
+    t.partition_spec = partition_spec
+    needs_invalidate = True
+
+  # Invalidate cached spec if resharding occurred
+  if needs_invalidate:
+    t.invalidate_spec_cache()
+
+  return t


 def unwrap_sharded_tensor(
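
A short usage sketch of the updated wrap_as_sharded_tensor, mirroring the resharding tests
above (assumes at least four XLA devices so a 2D mesh is possible; shapes are illustrative):

import torch
import torch_xla
import torch_xla.runtime as xr
from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor

device_count = xr.global_runtime_device_count()
assert device_count >= 4, "sketch assumes a 2D mesh is possible"

# Wrapping a plain tensor creates a new XLAShardedTensor carrying the metadata.
t = wrap_as_sharded_tensor(
    torch.randn(100, 50).to('xla'),
    mesh_shape=(device_count,),
    partition_spec=(0, None))
first_spec = t._spec  # built and cached on first access

# Re-wrapping an existing XLAShardedTensor updates it in place and invalidates
# the cached spec, so the next _spec access reflects the new sharding.
t = wrap_as_sharded_tensor(
    t, mesh_shape=(2, device_count // 2), partition_spec=(0, 1))
assert t._spec is not first_spec
assert t._spec.mesh.shape == (2, device_count // 2)

With this change, metadata updates and cache invalidation live entirely in
wrap_as_sharded_tensor rather than being inferred from op arguments at dispatch time.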
