-
Notifications
You must be signed in to change notification settings - Fork 566
Open
Description
The following test, which exercises shard_as, is failing. Disabling it to unblock the pin update for the release.
@unittest.skipIf(
    xr.device_type() == 'CPU',
    "sharding will be the same for both tensors on single device")
def test_shard_as(self):
  """Verify that xs.shard_as propagates one tensor's sharding to another.

  Marks `x` with an explicit 1-D sharding over all devices, leaves `y`
  unsharded, then calls `xs.shard_as(x, y)` and checks that after a sync
  both tensors carry the same (device-partitioned) sharding spec.

  NOTE(review): per the issue text above, this test is currently failing
  and was disabled to unblock a pin update — the assertions below document
  the *intended* behavior, not the observed one.
  """
  # 1-D device mesh covering every available device; shard dim 0 across it.
  mesh = self._get_mesh((self.n_devices,))
  partition_spec = (0,)
  x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, device='xla')
  # Only `x` gets an explicit sharding annotation (gradient-aware variant).
  x = xs.mark_sharding_with_gradients(x, mesh, partition_spec)
  y = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], dtype=torch.float, device='xla')
  # shard_as should copy x's sharding onto y; both returned tensors are used.
  x, y = xs.shard_as(x, y)
  # Materialize the lazy graph so sharding specs are attached to the tensors.
  torch_xla.sync()
  # Expected spec prefix, e.g. '{devices=[4]' on a 4-device mesh.
  sharding_spec = '{devices=[%d]' % self.n_devices
  x_sharding = torch_xla._XLAC._get_xla_sharding_spec(x)
  y_sharding = torch_xla._XLAC._get_xla_sharding_spec(y)
  # x must actually be partitioned across devices...
  self.assertIn(sharding_spec, x_sharding)
  # ...and y must have inherited exactly the same sharding from x.
  self.assertEqual(x_sharding, y_sharding)
Metadata
Metadata
Assignees
Labels
No labels