import torch
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
+from torch.distributed.tensor.placement_types import Replicate

import torch_xla
import torch_xla.runtime as xr
+from torch_xla.distributed.spmd import XLAShardedTensor
+from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor

import unittest
import test_xla_sharding_base
@@ -31,7 +34,6 @@ def test_xla_to_dtensor_spec_conversion(self):
    mesh = DeviceMesh("xla", list(range(device_count)))

    # Test different sharding patterns
-    from torch.distributed.tensor.placement_types import Replicate
    test_cases = [
        (torch.randn(100, 50), [Shard(0)]),
        (torch.randn(100, 50), [Shard(1)]),
@@ -64,30 +66,27 @@ def test_mesh_conversion(self):
    assert converted_spec.mesh.shape == original_mesh.shape

  def test_spec_caching(self):
-    """Test that _spec property caches results for better performance"""
-    import time
+    """Test that _spec property caches results
+
+    Addresses PR comment: "These sorts of tests that rely on the wall clock often lead to
+    annoying flakes in my experience. I think it's sufficient to just test that
+    self._cached_spec has a permanent value after the first call."
+    """
    device_count = xr.global_runtime_device_count()
    mesh = DeviceMesh("xla", list(range(device_count)))
-    tensor = torch.randn(1000,
-                         1000)  # Large tensor to make spec creation noticeable
+    tensor = torch.randn(100, 100)
    xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])

-    # first access should create and cache the spec
-    start_time = time.time()
+    # First access should create and cache the spec
    spec1 = xla_tensor._spec
-    first_access_time = time.time() - start_time

-    # should be much faster due to caching
-    start_time = time.time()
-    spec2 = xla_tensor._spec
-    second_access_time = time.time() - start_time
+    # Verify the spec is cached
+    assert xla_tensor._cached_spec is not None
+    assert xla_tensor._cached_spec is spec1

+    # Second access should return the cached spec
+    spec2 = xla_tensor._spec
    assert spec1 is spec2
-    print(
-        f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s"
-    )
-    assert second_access_time * 10 < first_access_time, \
-        f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"

  def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
    """Helper to create tensor and mesh for testing"""
@@ -114,22 +113,8 @@ def test_multi_dim_sharding_spec(self):
    assert len(spec.placements) == 2
    assert spec.mesh.ndim == 2

-  def test_tensor_operations_preserve_spec(self):
-    """Test that tensor operations preserve sharding metadata"""
-    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,),
-                                                         [Shard(0)])
-
-    result_add = xla_tensor + 1
-    result_mul = xla_tensor * 2
-    result_relu = torch.relu(xla_tensor)
-
-    for result in [result_add, result_mul, result_relu]:
-      assert hasattr(result, '_spec')
-      assert result._spec.mesh.device_type == "xla"
-
  def test_mixed_placement_spec(self):
    """Test _spec for tensors with mixed shard/replicate placements"""
-    from torch.distributed.tensor.placement_types import Replicate
    device_count = xr.global_runtime_device_count()
    if device_count < 4:
      self.skipTest("Need at least 4 devices for 2D mesh")
@@ -143,6 +128,97 @@ def test_mixed_placement_spec(self):
    assert isinstance(spec.placements[0], Shard)
    assert isinstance(spec.placements[1], Replicate)

+  def test_sharding_info_acquisition(self):
+    """Test that non-XLAShardedTensor can acquire sharding information
+
+    Tests case of 'elem is not an XLAShardedTensor but there exists
+    sharding information we want to acquire'
+    """
+
+    device_count = xr.global_runtime_device_count()
+    mesh_shape = (device_count,)
+    partition_spec = (0, None)
+
+    regular_tensor = torch.randn(100, 50).to('xla')
+
+    sharded_tensor = wrap_as_sharded_tensor(
+        regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec)
+
+    # Verify the tensor acquired the sharding information
+    assert isinstance(sharded_tensor, XLAShardedTensor)
+    assert sharded_tensor.mesh_shape == mesh_shape
+    assert sharded_tensor.partition_spec == partition_spec
+
+  def test_resharding_logic(self):
+    """
+    Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t.
+    """
+
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for resharding test")
+
+    # Initial sharding
+    initial_mesh_shape = (device_count,)
+    initial_partition_spec = (0, None)
+    new_mesh_shape = (2, device_count // 2)
+    new_partition_spec = (0, 1)
+
+    # Create tensor and verify resharding
+    tensor = torch.randn(100, 50).to('xla')
+    sharded_tensor = wrap_as_sharded_tensor(
+        tensor,
+        mesh_shape=initial_mesh_shape,
+        partition_spec=initial_partition_spec)
+    initial_spec = sharded_tensor._spec
+
+    resharded_tensor = wrap_as_sharded_tensor(
+        sharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=new_partition_spec)
+
+    # Verify resharding worked and cache was invalidated
+    assert resharded_tensor.mesh_shape == new_mesh_shape
+    assert resharded_tensor.partition_spec == new_partition_spec
+    assert resharded_tensor._spec is not initial_spec
+
+  def test_spec_invalidation_on_resharding(self):
+    """Tests cases where the cached spec may become outdated.
+    """
+
+    device_count = xr.global_runtime_device_count()
+    if device_count < 4:
+      self.skipTest("Need at least 4 devices for resharding test")
+
+    tensor = torch.randn(100, 50).to('xla')
+    initial_mesh_shape = (device_count,)
+    initial_partition_spec = (0, None)
+    new_mesh_shape = (2, device_count // 2)
+    new_partition_spec = (0, 1)
+
+    sharded_tensor = wrap_as_sharded_tensor(
+        tensor,
+        mesh_shape=initial_mesh_shape,
+        partition_spec=initial_partition_spec)
+    initial_spec = sharded_tensor._spec
+    assert sharded_tensor._cached_spec is not None
+
+    # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache
+    resharded_tensor = wrap_as_sharded_tensor(
+        sharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=initial_partition_spec)
+    assert resharded_tensor._spec is not initial_spec
+    assert resharded_tensor._spec.mesh.shape == new_mesh_shape
+
+    initial_spec = resharded_tensor._spec
+    resharded_tensor = wrap_as_sharded_tensor(
+        resharded_tensor,
+        mesh_shape=new_mesh_shape,
+        partition_spec=new_partition_spec)
+    assert resharded_tensor._spec is not initial_spec
+    assert resharded_tensor._spec.placements[1].dim == 1
+

if __name__ == '__main__':
  test = unittest.main()