File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -212,6 +212,23 @@ def test_spec_invalidation_on_resharding(self):
212
212
assert resharded_tensor ._spec is not initial_spec
213
213
assert resharded_tensor ._spec .placements [1 ].dim == 1
214
214
215
+ def test_auto_wrapped_tensor_spec_failure (self ):
216
+ """Test that auto-wrapped tensors fail when accessing _spec property.
217
+
218
+ Auto-wrapped tensors are created through operations that trigger __torch_dispatch__
219
+ but don't yet have access to the sharding propagation done through open xla,
220
+ causing ._spec to fail.
221
+ """
222
+ device_count = xr .global_runtime_device_count ()
223
+ mesh = DeviceMesh ("xla" , torch .arange (device_count ))
224
+ tensor = torch .randn (4 , 4 )
225
+ sharded_tensor = distribute_tensor (tensor , mesh , [Shard (0 )])
226
+
227
+ auto_wrapped = sharded_tensor + sharded_tensor
228
+
229
+ with self .assertRaises (ValueError ):
230
+ _ = auto_wrapped ._spec
231
+
215
232
216
233
if __name__ == '__main__' :
217
234
test = unittest .main ()
You can’t perform that action at this time.
0 commit comments