Skip to content

Commit 7322629

Browse files
aws-cphHoomaaan
authored andcommitted
Added test for catching thrown error in spec
1 parent 97f67e2 commit 7322629

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

test/spmd/test_xla_dtensor_spec_conversion.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,23 @@ def test_spec_invalidation_on_resharding(self):
212212
assert resharded_tensor._spec is not initial_spec
213213
assert resharded_tensor._spec.placements[1].dim == 1
214214

215+
def test_auto_wrapped_tensor_spec_failure(self):
216+
"""Test that auto-wrapped tensors fail when accessing _spec property.
217+
218+
Auto-wrapped tensors are created through operations that trigger __torch_dispatch__
219+
but don't yet have access to the sharding propagation done through open xla,
220+
causing ._spec to fail.
221+
"""
222+
device_count = xr.global_runtime_device_count()
223+
mesh = DeviceMesh("xla", torch.arange(device_count))
224+
tensor = torch.randn(4, 4)
225+
sharded_tensor = distribute_tensor(tensor, mesh, [Shard(0)])
226+
227+
auto_wrapped = sharded_tensor + sharded_tensor
228+
229+
with self.assertRaises(ValueError):
230+
_ = auto_wrapped._spec
231+
215232

216233
if __name__ == '__main__':
217234
test = unittest.main()

0 commit comments

Comments
 (0)