 
 import torch
 from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
-from torch.distributed.tensor.placement_types import Replicate
 
 import torch_xla
 import torch_xla.runtime as xr
-from torch_xla.distributed.spmd import XLAShardedTensor
-from torch_xla.distributed.spmd.xla_sharding import wrap_as_sharded_tensor
 
 import unittest
 import test_xla_sharding_base
@@ -34,6 +31,7 @@ def test_xla_to_dtensor_spec_conversion(self):
     mesh = DeviceMesh("xla", list(range(device_count)))
 
     # Test different sharding patterns
+    from torch.distributed.tensor.placement_types import Replicate
     test_cases = [
         (torch.randn(100, 50), [Shard(0)]),
         (torch.randn(100, 50), [Shard(1)]),
@@ -66,20 +64,30 @@ def test_mesh_conversion(self):
     assert converted_spec.mesh.shape == original_mesh.shape
 
   def test_spec_caching(self):
-    """Test that _spec property caches results
-    """
+    """Test that _spec property caches results for better performance"""
+    import time
     device_count = xr.global_runtime_device_count()
     mesh = DeviceMesh("xla", list(range(device_count)))
-    tensor = torch.randn(100, 100)
+    tensor = torch.randn(1000,
+                         1000)  # Large tensor to make spec creation noticeable
     xla_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
 
+    # first access should create and cache the spec
+    start_time = time.time()
     spec1 = xla_tensor._spec
+    first_access_time = time.time() - start_time
 
-    assert xla_tensor._cached_spec is not None
-    assert xla_tensor._cached_spec is spec1
-
+    # should be much faster due to caching
+    start_time = time.time()
     spec2 = xla_tensor._spec
+    second_access_time = time.time() - start_time
+
     assert spec1 is spec2
+    print(
+        f"First access: {first_access_time:.6f}s, Second access: {second_access_time:.6f}s"
+    )
+    assert second_access_time * 10 < first_access_time, \
+        f"Cached access should be much faster: {first_access_time:.6f}s vs {second_access_time:.6f}s"
 
   def _create_test_tensor_and_mesh(self, tensor_shape, mesh_shape, placements):
     """Helper to create tensor and mesh for testing"""
@@ -106,8 +114,22 @@ def test_multi_dim_sharding_spec(self):
     assert len(spec.placements) == 2
     assert spec.mesh.ndim == 2
 
+  def test_tensor_operations_preserve_spec(self):
+    """Test that tensor operations preserve sharding metadata"""
+    xla_tensor, mesh = self._create_test_tensor_and_mesh((100, 50), (-1,),
+                                                         [Shard(0)])
+
+    result_add = xla_tensor + 1
+    result_mul = xla_tensor * 2
+    result_relu = torch.relu(xla_tensor)
+
+    for result in [result_add, result_mul, result_relu]:
+      assert hasattr(result, '_spec')
+      assert result._spec.mesh.device_type == "xla"
+
   def test_mixed_placement_spec(self):
     """Test _spec for tensors with mixed shard/replicate placements"""
+    from torch.distributed.tensor.placement_types import Replicate
     device_count = xr.global_runtime_device_count()
     if device_count < 4:
       self.skipTest("Need at least 4 devices for 2D mesh")
@@ -121,114 +143,6 @@ def test_mixed_placement_spec(self):
     assert isinstance(spec.placements[0], Shard)
     assert isinstance(spec.placements[1], Replicate)
 
-  def test_sharding_info_acquisition(self):
-    """Test that non-XLAShardedTensor can acquire sharding information
-
-    Tests case of 'elem is not an XLAShardedTensor but there exists
-    sharding information we want to acquire'
-    """
-
-    device_count = xr.global_runtime_device_count()
-    mesh_shape = (device_count,)
-    partition_spec = (0, None)
-
-    regular_tensor = torch.randn(100, 50).to('xla')
-
-    sharded_tensor = wrap_as_sharded_tensor(
-        regular_tensor, mesh_shape=mesh_shape, partition_spec=partition_spec)
-
-    # Verify the tensor acquired the sharding information
-    assert isinstance(sharded_tensor, XLAShardedTensor)
-    assert sharded_tensor.mesh_shape == mesh_shape
-    assert sharded_tensor.partition_spec == partition_spec
-
-  def test_resharding_logic(self):
-    """
-    Tests wrap_as_sharded_tensor resharding before returning XLAShardedTensor t.
-    """
-
-    device_count = xr.global_runtime_device_count()
-    if device_count < 4:
-      self.skipTest("Need at least 4 devices for resharding test")
-
-    # Initial sharding
-    initial_mesh_shape = (device_count,)
-    initial_partition_spec = (0, None)
-    new_mesh_shape = (2, device_count // 2)
-    new_partition_spec = (0, 1)
-
-    # Create tensor and verify resharding
-    tensor = torch.randn(100, 50).to('xla')
-    sharded_tensor = wrap_as_sharded_tensor(
-        tensor,
-        mesh_shape=initial_mesh_shape,
-        partition_spec=initial_partition_spec)
-    initial_spec = sharded_tensor._spec
-
-    resharded_tensor = wrap_as_sharded_tensor(
-        sharded_tensor,
-        mesh_shape=new_mesh_shape,
-        partition_spec=new_partition_spec)
-
-    # Verify resharding worked and cache was invalidated
-    assert resharded_tensor.mesh_shape == new_mesh_shape
-    assert resharded_tensor.partition_spec == new_partition_spec
-    assert resharded_tensor._spec is not initial_spec
-
-  def test_spec_invalidation_on_resharding(self):
-    """Tests cases where the cached spec may become outdated.
-    """
-
-    device_count = xr.global_runtime_device_count()
-    if device_count < 4:
-      self.skipTest("Need at least 4 devices for resharding test")
-
-    tensor = torch.randn(100, 50).to('xla')
-    initial_mesh_shape = (device_count,)
-    initial_partition_spec = (0, None)
-    new_mesh_shape = (2, device_count // 2)
-    new_partition_spec = (0, 1)
-
-    sharded_tensor = wrap_as_sharded_tensor(
-        tensor,
-        mesh_shape=initial_mesh_shape,
-        partition_spec=initial_partition_spec)
-    initial_spec = sharded_tensor._spec
-    assert sharded_tensor._cached_spec is not None
-
-    # Changing mesh_shape / partition_spec through wrap_as_sharded_tensor invalidates cache
-    resharded_tensor = wrap_as_sharded_tensor(
-        sharded_tensor,
-        mesh_shape=new_mesh_shape,
-        partition_spec=initial_partition_spec)
-    assert resharded_tensor._spec is not initial_spec
-    assert resharded_tensor._spec.mesh.shape == new_mesh_shape
-
-    initial_spec = resharded_tensor._spec
-    resharded_tensor = wrap_as_sharded_tensor(
-        resharded_tensor,
-        mesh_shape=new_mesh_shape,
-        partition_spec=new_partition_spec)
-    assert resharded_tensor._spec is not initial_spec
-    assert resharded_tensor._spec.placements[1].dim == 1
-
-  def test_auto_wrapped_tensor_spec_failure(self):
-    """Test that auto-wrapped tensors fail when accessing _spec property.
-
-    Auto-wrapped tensors are created through operations that trigger __torch_dispatch__
-    but don't yet have access to the sharding propagation done through open xla,
-    causing ._spec to fail.
-    """
-    device_count = xr.global_runtime_device_count()
-    mesh = DeviceMesh("xla", torch.arange(device_count))
-    tensor = torch.randn(4, 4)
-    sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
-
-    auto_wrapped = sharded_tensor + sharded_tensor
-
-    with self.assertRaises(ValueError):
-      _ = auto_wrapped._spec
-
 
 if __name__ == '__main__':
   test = unittest.main()
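For context, a minimal usage sketch of the pattern the remaining tests exercise, assuming an SPMD-enabled torch_xla runtime with at least one XLA device (the setup that test_xla_sharding_base provides); it only uses calls that already appear in the diff above and is not part of the commit itself.

import torch
from torch.distributed.tensor import DeviceMesh, Shard, distribute_tensor
import torch_xla.runtime as xr

# Build a 1-D mesh over every XLA device visible to the runtime.
device_count = xr.global_runtime_device_count()
mesh = DeviceMesh("xla", list(range(device_count)))

# Shard dim 0 of a CPU tensor across the mesh, as the tests above do.
xla_tensor = distribute_tensor(torch.randn(100, 50), mesh, [Shard(0)])

# The first _spec access builds the DTensor spec; later accesses return the cached object.
spec = xla_tensor._spec
assert spec is xla_tensor._spec
assert spec.mesh.device_type == "xla"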