
Commit d03452f

Add some tests and apply reviewer suggestions. The V2 op sharding logic will be updated in a later commit.
1 parent 275f369 commit d03452f

File tree: 2 files changed, +40 −14 lines
  test/spmd/test_xla_sharding.py
  torch_xla/csrc/xla_sharding_util.cpp


test/spmd/test_xla_sharding.py

Lines changed: 38 additions & 14 deletions
@@ -6,6 +6,7 @@
 import unittest
 from unittest.mock import patch
 import sys
+import os

 import torch
 from torch import nn
@@ -26,6 +27,11 @@
 from torch_xla._internal import tpu


+def should_convert_to_shardy():
+  return os.environ.get("CONVERT_SHLO_TO_SHARDY",
+                        "").lower() in ("1", "true", "yes")
+
+
 class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):

   @classmethod
@@ -238,6 +244,8 @@ def test_custom_tile_assignment(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
          [str(i) for i in reversed(range(self.n_devices))]))
+      if should_convert_to_shardy():
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

   def test_mark_sharding_2d(self):
@@ -252,6 +260,8 @@ def test_mark_sharding_2d(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
          [str(i) for i in range(self.n_devices)]))
+      if should_convert_to_shardy():
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))

     actual = (xt1 + xt2).cpu()
@@ -271,6 +281,9 @@ def test_mark_sharding_4d(self):
       annotation = '{devices=[1,1,%d,%d]%s}' % (
          z_dim, self.n_devices // z_dim, ','.join(
              [str(i) for i in range(self.n_devices)]))
+      if should_convert_to_shardy():
+        annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
+                                                      z_dim, self.n_devices)
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

     actual = (xt + xt).cpu()
@@ -403,9 +416,11 @@ def test_tupled_partition_spec(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16).to('xla')
     xs.mark_sharding(t, mesh, ((0, 1),))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if should_convert_to_shardy():
+      annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
                        "Multiple devices required for tupled partition spec")
@@ -452,19 +467,25 @@ def test_multiple_tuples_in_spec(self):
                          ('a', 'b', 'c', 'd'))
     t = torch.randn(2, 2).to('xla')
     xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if should_convert_to_shardy():
+      annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
+                                               self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'At least 2 devices needed for 2D mesh')
   def test_3d_tensor_2d_mesh(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16, 16, 16).to('xla')
     xs.mark_sharding(t, mesh, (None, 0, 1))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    expected = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if should_convert_to_shardy():
+      expected = '{devices=[1,2,%d]<=[%d] last_tile_dim_replicate}' % (
+          self.n_devices // 2, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), expected)

   def test_partial_replication_addmm(self):
     device = torch_xla.device()
@@ -983,18 +1004,21 @@ def test_op_sharding_cache(self):

     t = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(t, mesh, (0, 1))
-    self.assertIn("CreateOpSharding", met.counter_names())
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    counter_name = "CreateIotaOpSharding" if should_convert_to_shardy(
+    ) else "CreateOpSharding"
+    self.assertIn(counter_name, met.counter_names())
+    self.assertEqual(met.counter_value(counter_name), 1)

     # Sharding with the same partition spec should not result in another call
     u = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(u, mesh, (0, 1))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    self.assertEqual(met.counter_value(counter_name), 1)

-    # Changing the partition spec will result in another CreateOpSharding
+    # Changing the partition spec will result in another
+    # CreateOpSharding or CreateIotaOpSharding call
     v = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(v, mesh, (0, None))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 2)
+    self.assertEqual(met.counter_value(counter_name), 2)

   def test_from_cpu_shards_replicated(self):
     from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
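
Every updated test above follows the same pattern: when the CONVERT_SHLO_TO_SHARDY environment variable is set, _get_xla_sharding_spec is expected to return the compact V2 iota annotation (<=[n]) instead of the V1 annotation that spells out every device id. A minimal sketch of the two expected strings, assuming a hypothetical 4-device mesh (illustration only, not part of the commit):

n_devices = 4

# V1 style: the device order is listed explicitly.
v1 = '{devices=[1,%d]%s}' % (n_devices, ','.join(
    str(i) for i in range(n_devices)))
# -> '{devices=[1,4]0,1,2,3}'

# V2 / iota style expected when CONVERT_SHLO_TO_SHARDY is set: the same
# ascending device order written as an iota over 4 devices.
v2 = '{devices=[1,%d]<=[%d]}' % (n_devices, n_devices)
# -> '{devices=[1,4]<=[4]}'

print(v1)
print(v2)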

torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 2 additions & 0 deletions
@@ -221,9 +221,11 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a,
 xla::OpSharding ShardingUtil::CreateIotaOpSharding(
     const py::list& dims, const py::list& reshape_dims,
     const py::list& transpose_perm) {
+  TORCH_LAZY_COUNTER("CreateIotaOpSharding", 1);
   auto dims_vec = dims.cast<std::vector<int64_t>>();
   auto reshape_dims_vec = reshape_dims.cast<std::vector<int64_t>>();
   auto transpose_perm_vec = transpose_perm.cast<std::vector<int>>();
+  CHECK_EQ(reshape_dims_vec.size(), transpose_perm_vec.size());
   std::vector<xla::OpSharding::Type> subgroup_types;
   if (dims_vec.size() > transpose_perm.size()) {
     subgroup_types.push_back(xla::OpSharding::REPLICATED);
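
The new TORCH_LAZY_COUNTER is what lets test_op_sharding_cache track the iota path separately from CreateOpSharding, and the CHECK_EQ enforces that reshape_dims and transpose_perm describe the same number of axes. As a rough, standalone illustration of what an iota tile assignment denotes (an assumption based on the usual HloSharding V2 reading, not code from this repo), the explicit device list can be recovered by reshaping an iota over the device count to reshape_dims, transposing by transpose_perm, and flattening:

import itertools
import math


def expand_iota(reshape_dims, transpose_perm):
  # Illustrative helper (hypothetical name): expand an iota tile assignment
  # into the explicit device order it denotes.
  strides = [0] * len(reshape_dims)
  acc = 1
  for i in reversed(range(len(reshape_dims))):
    strides[i] = acc
    acc *= reshape_dims[i]
  devices = []
  # Walk the transposed array in row-major order and map each coordinate
  # back to the flat iota index it came from.
  for coord in itertools.product(
      *(range(reshape_dims[p]) for p in transpose_perm)):
    devices.append(sum(c * strides[p] for c, p in zip(coord, transpose_perm)))
  assert len(devices) == math.prod(reshape_dims)
  return devices


# '{devices=[2,2]<=[4]}' is simply devices 0..3 in ascending order.
print(expand_iota([4], [0]))        # [0, 1, 2, 3]
# '{devices=[2,2]<=[2,2]T(1,0)}' permutes the device order.
print(expand_iota([2, 2], [1, 0]))  # [0, 2, 1, 3]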

0 commit comments
