
Commit 5dfbb4d

hshahTTsshonTT authored and committed
Change V2 sharding spec algorithm + Fix tensor sharding spec visualization (#7)
This PR adds support for all previously unsupported partition specs and fixes the visualize_tensor_sharding() function to support V2 sharding specs. See pytorch#9541 for the upstream PR discussion and additional context.

* Add some tests and apply reviewer suggestions; the V2 op sharding logic will be updated in a later commit.
* New implementation (WIP)
* Fix the new implementation
* Fix the visualize_tensor_sharding function for V2 shardings
1 parent 686cb76 commit 5dfbb4d
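
Background for this change: with CONVERT_SHLO_TO_SHARDY enabled, sharding specs are emitted in the HLO V2 ("iota") form such as {devices=[1,4]<=[4]} rather than the V1 form with an explicit device list such as {devices=[1,4]0,1,2,3}. The following is a minimal usage sketch of the code paths this commit touches; it is not part of the commit and assumes a 4-device host:

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd.debugging import visualize_tensor_sharding

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(n)), (1, n))

t = torch.randn(1, 128).to(torch_xla.device())
xs.mark_sharding(t, mesh, (0, 1))

# With CONVERT_SHLO_TO_SHARDY=1 this prints the V2 form, e.g.
# '{devices=[1,4]<=[4]}'; otherwise the V1 form '{devices=[1,4]0,1,2,3}'.
print(torch_xla._XLAC._get_xla_sharding_spec(t))

# visualize_tensor_sharding() is the function fixed here to accept V2 specs.
visualize_tensor_sharding(t)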

File tree

8 files changed: +262 -96 lines


test/spmd/test_spmd_debugging.py

Lines changed: 72 additions & 0 deletions
@@ -17,6 +17,7 @@
 import torch_xla.distributed.spmd as xs
 from torch_xla.distributed.spmd import XLAShardedTensor
 from torch_xla.distributed.spmd import Mesh
+from torch_xla.distributed.spmd.debugging import construct_v1_sharding_str

 import test_xla_sharding_base

@@ -822,6 +823,77 @@ def test_multi_host_replicated_cpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output

+
+class ConvertV2ShardingToV1Test(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"
+
+  def run_test(self):
+    mesh = self._get_mesh(self.device_mesh_shape)
+    t = torch.randn(self.tensor_shape).to(torch_xla.device())
+    xs.mark_sharding(t, mesh, self.partition_spec)
+    actual_str = construct_v1_sharding_str(t)
+    self.assertEqual(self.expected_str, actual_str)
+
+  def test_tiled_sharding(self):
+    self.device_mesh_shape = (1, self.n_devices)
+    self.tensor_shape = (1, 128)
+    self.partition_spec = (0, 1)
+    self.expected_str = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
+        [str(i) for i in range(self.n_devices)]))
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 2,
+                   f"Requires at least 2 devices.")
+  def test_tupled_tiled_sharding(self):
+    self.device_mesh_shape = (2, self.n_devices // 2)
+    self.tensor_shape = (16,)
+    self.partition_spec = ((0, 1),)
+    self.expected_str = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  def test_replicated_sharding(self):
+    self.device_mesh_shape = (1, self.n_devices)
+    self.tensor_shape = (4, 4)
+    self.partition_spec = (None, None)
+    self.expected_str = '{replicated}'
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 4,
+                   f"Requires at least 4 devices.")
+  def test_partial_replication_sharding(self):
+    self.device_mesh_shape = (2, self.n_devices // 2)
+    self.tensor_shape = (4, 4)
+    self.partition_spec = (0, None)
+    self.expected_str = '{devices=[2,1,%d]%s last_tile_dim_replicate}' % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 4,
+                   f"Requires at least 4 devices.")
+  def test_tupled_partial_replication_sharding(self):
+    self.device_mesh_shape = (1, 2, self.n_devices // 2)
+    self.tensor_shape = (16, 16)
+    self.partition_spec = ((0, 1), None)
+    self.expected_str = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  def test_tupled_partial_replication_sharding_with_transpose(self):
+    self.device_mesh_shape = (1, 2, self.n_devices // 2)
+    self.tensor_shape = (16, 16)
+    self.partition_spec = (None, (2, 1))
+    device_order = self.device_ids.reshape(self.device_mesh_shape).transpose(
+        (2, 1, 0)).flatten()
+    self.expected_str = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in device_order))
+    self.run_test()
+
+
 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
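
To make the transpose case in test_tupled_partial_replication_sharding_with_transpose concrete, here is a small standalone calculation (numpy only, assuming 4 devices; not part of the commit) of the device order and V1 string the test expects:

import numpy as np

n_devices = 4
device_ids = np.arange(n_devices)
device_mesh_shape = (1, 2, n_devices // 2)

# partition_spec = (None, (2, 1)) shards the second tensor dim over mesh axes
# 2 and 1, so the expected device order walks the mesh transposed as (2, 1, 0).
device_order = device_ids.reshape(device_mesh_shape).transpose((2, 1, 0)).flatten()
print(device_order)  # [0 2 1 3]
print("{devices=[1,%d]%s}" % (n_devices, ','.join(str(x) for x in device_order)))
# {devices=[1,4]0,2,1,3}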

test/spmd/test_xla_sharding.py

Lines changed: 63 additions & 36 deletions
@@ -31,6 +31,7 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
   @classmethod
   def setUpClass(cls):
     super().setUpClass()
+    cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY")

   def test_xla_sharded_tensor(self):
     partition_spec = (0, 1)
@@ -238,6 +239,8 @@ def test_custom_tile_assignment(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in reversed(range(self.n_devices))]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
       self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

   def test_mark_sharding_2d(self):
@@ -252,6 +255,8 @@ def test_mark_sharding_2d(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
           [str(i) for i in range(self.n_devices)]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
       self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))

     actual = (xt1 + xt2).cpu()
@@ -271,6 +276,9 @@ def test_mark_sharding_4d(self):
       annotation = '{devices=[1,1,%d,%d]%s}' % (
           z_dim, self.n_devices // z_dim, ','.join(
               [str(i) for i in range(self.n_devices)]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
+                                                      z_dim, self.n_devices)
       self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

     actual = (xt + xt).cpu()
@@ -403,9 +411,11 @@ def test_tupled_partition_spec(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16).to('xla')
     xs.mark_sharding(t, mesh, ((0, 1),))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
                        "Multiple devices required for tupled partition spec")
@@ -415,34 +425,43 @@ def test_named_partial_tupled_partition_spec(self):
     # Shard the first dimension on `r` and `b`, replicate the second dimension
     t = torch.randn(16, 16).to('xla')
     xs.mark_sharding(t, mesh, (('r', 'b'), None))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t),
-        "{devices=[2,1,%d]%s last_tile_dim_replicate}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
+          self.n_devices // 2, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

     # Replicate the first dimension, shard the second on `b` and `m`
     u = torch.randn(16, 16).to('xla')
     xs.mark_sharding(u, mesh, (None, ('b', 'm')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation)

     # Replicate the first dimension, shard the second on `r` and `m`
     v = torch.randn(16, 16).to('xla')
     xs.mark_sharding(v, mesh, (None, ('r', 'm')))
     device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten()
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(v),
-        "{devices=[1,%d,2]%s last_tile_dim_replicate}" %
-        (self.n_devices // 2, ','.join(str(x) for x in device_order)))
+    annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in device_order))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
+          self.n_devices // 2, self.n_devices // 2)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

     # Replicate the first dimension, shard the second on `m` and `b`
     v = torch.randn(16, 16).to('xla')
     xs.mark_sharding(v, mesh, (None, ('m', 'b')))
     device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in device_order)))
+    annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in device_order))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices,
+                                                       self.n_devices // 2)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'Multiple devices required for tupled partition spec')
@@ -452,19 +471,25 @@ def test_multiple_tuples_in_spec(self):
                           ('a', 'b', 'c', 'd'))
     t = torch.randn(2, 2).to('xla')
     xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
+                                               self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'At least 2 devices needed for 2D mesh')
   def test_3d_tensor_2d_mesh(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16, 16, 16).to('xla')
     xs.mark_sharding(t, mesh, (None, 0, 1))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2,
+                                                 self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

   def test_partial_replication_addmm(self):
     device = torch_xla.device()
@@ -983,18 +1008,20 @@ def test_op_sharding_cache(self):

     t = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(t, mesh, (0, 1))
-    self.assertIn("CreateOpSharding", met.counter_names())
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding"
+    self.assertIn(counter_name, met.counter_names())
+    self.assertEqual(met.counter_value(counter_name), 1)

     # Sharding with the same partition spec should not result in another call
     u = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(u, mesh, (0, 1))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    self.assertEqual(met.counter_value(counter_name), 1)

-    # Changing the partition spec will result in another CreateOpSharding
+    # Changing the partition spec will result in another
+    # CreateOpSharding or CreatingIotaOpSharding call
     v = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(v, mesh, (0, None))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 2)
+    self.assertEqual(met.counter_value(counter_name), 2)

   def test_from_cpu_shards_replicated(self):
     from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
@@ -1397,10 +1424,10 @@ def test_data_loader_with_sharding(self):
         input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(data),
-        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    )
+    annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    if self.convert_to_shardy:
+      annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

   @unittest.skipUnless(
       xr.global_runtime_device_count() > 1,
@@ -1420,10 +1447,10 @@ def test_data_loader_with_non_batch_size(self):
         input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(data),
-        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    )
+    annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    if self.convert_to_shardy:
+      annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

   @unittest.skipUnless(
       xr.global_runtime_device_count() > 1,
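
The <=[...] annotations asserted above are HLO V2 (iota) tile assignments: the device list is an arange over the bracketed reshape dims, optionally transposed by T(...). A small numpy sketch (illustrative only, assuming 4 devices) of how the two V2 forms used in these assertions expand to the V1 device lists they replace:

import numpy as np

n_devices = 4

# '<=[N]' is simply iota(N), i.e. the V1 list 0,1,...,N-1.
simple = np.arange(n_devices)
print('{devices=[1,%d]<=[%d]}' % (n_devices, n_devices), '==',
      '{devices=[1,%d]%s}' % (n_devices, ','.join(map(str, simple))))

# '<=[2,M]T(1,0)' is iota(2*M) reshaped to (2, M), transposed, then flattened.
transposed = np.arange(n_devices).reshape(2, n_devices // 2).transpose(1, 0).flatten()
print('{devices=[1,%d]<=[2,%d]T(1,0)}' % (n_devices, n_devices // 2), '==',
      '{devices=[1,%d]%s}' % (n_devices, ','.join(map(str, transposed))))
# Output:
# {devices=[1,4]<=[4]} == {devices=[1,4]0,1,2,3}
# {devices=[1,4]<=[2,2]T(1,0)} == {devices=[1,4]0,2,1,3}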

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 38 additions & 5 deletions
@@ -760,6 +760,16 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
   return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode);
 }

+std::optional<xla::OpSharding> GetXLAOpSharding(const at::Tensor& input) {
+  XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+  XLATensor::ShardingSpecPtr sharding_spec =
+      xtensor ? xtensor->sharding_spec() : nullptr;
+  if (sharding_spec != nullptr) {
+    return sharding_spec->sharding;
+  }
+  return std::nullopt;
+}
+
 std::string GetXLAShardingSpec(const XLATensorPtr xtensor) {
   auto sharding_spec = xtensor->sharding_spec();
   if (sharding_spec != nullptr) {
@@ -1526,6 +1536,10 @@ at::Tensor tensor_fromDLPack(PyObject* data) {
 void InitXlaModuleBindings(py::module m) {
   PythonScope<py::module> module(m);

+  using TileAssignmentDims = std::vector<int64_t>;
+  using ReshapeDims = std::vector<int64_t>;
+  using TransposePerm = std::vector<int>;
+
   // Define the _XLAC.XlaShardingSpec class.
   PythonScope<py::class_<XLATensor::ShardingSpec, XLATensor::ShardingSpecPtr>>(
       m, "XlaShardingSpec")
@@ -1543,12 +1557,12 @@ void InitXlaModuleBindings(py::module m) {
          })
      .def_init([](at::Tensor tensor, const py::list& dims,
                   const py::list& reshape_dims, const py::list& transpose_perm,
-                  bool minibatch) {
+                  const py::list& types, bool minibatch) {
        xla::Shape global_shape =
            ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
        return std::make_shared<XLATensor::ShardingSpec>(
            ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
-                                              transpose_perm),
+                                              transpose_perm, types),
            global_shape, minibatch);
      });

@@ -1578,9 +1592,9 @@ void InitXlaModuleBindings(py::module m) {
          })
      // Constructor for V2 shardings.
      .def_init([](const py::list& dims, const py::list& reshape_dims,
-                  const py::list& transpose_perm) {
+                  const py::list& transpose_perm, const py::list& types) {
        return ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
-                                                 transpose_perm);
+                                                 transpose_perm, types);
      });

   // Define the _XLAC.PjRtPlugin class.
@@ -2703,7 +2717,26 @@ void InitXlaModuleBindings(py::module m) {
            if (sharding_spec != nullptr) {
              return sharding_spec->sharding;
            }
-           return std::nullopt;
+           return GetXLAOpSharding(input);
+         })
+      .def("_get_xla_op_sharding_v2_params",
+           [](const at::Tensor& input) -> std::optional<std::tuple<TileAssignmentDims, ReshapeDims, TransposePerm, bool>> {
+             std::optional<xla::OpSharding> maybe_sharding =
+                 GetXLAOpSharding(input);
+             if (!maybe_sharding) {
+               return std::nullopt;
+             }
+             const xla::OpSharding& sharding = maybe_sharding.value();
+             TileAssignmentDims tile_assignment_dims(
+                 sharding.tile_assignment_dimensions().begin(),
+                 sharding.tile_assignment_dimensions().end());
+             ReshapeDims reshape_dims(sharding.iota_reshape_dims().begin(),
+                                      sharding.iota_reshape_dims().end());
+             TransposePerm transpose_perm(sharding.iota_transpose_perm().begin(),
+                                          sharding.iota_transpose_perm().end());
+             return std::make_tuple(tile_assignment_dims, reshape_dims,
+                                    transpose_perm,
+                                    sharding.replicate_on_last_tile_dim());
          })
      .def("_get_xla_sharding_specs",
           [](const std::vector<at::Tensor>& tensors)
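
The new _get_xla_op_sharding_v2_params binding returns the V2 parameters (tile_assignment_dims, reshape_dims, transpose_perm, replicate_on_last_tile_dim) to Python. Below is a hedged sketch of how a caller such as construct_v1_sharding_str could expand them into an explicit device assignment; the helper is illustrative only and not the actual implementation in this commit:

import numpy as np
import torch_xla

def expand_v2_params(t):
  """Illustrative only: turn the tuple returned by the new
  _get_xla_op_sharding_v2_params binding into an explicit device array."""
  params = torch_xla._XLAC._get_xla_op_sharding_v2_params(t)
  if params is None:
    return None  # the tensor carries no sharding spec
  dims, reshape_dims, transpose_perm, replicate_last_dim = params
  # V2 semantics: iota over the reshape dims, transposed, then laid out
  # according to the tile-assignment dims.
  devices = np.arange(np.prod(reshape_dims)).reshape(reshape_dims)
  tile_assignment = devices.transpose(transpose_perm).reshape(dims)
  return tile_assignment, replicate_last_dim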

torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 8 additions & 5 deletions
@@ -220,18 +220,21 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a,

 xla::OpSharding ShardingUtil::CreateIotaOpSharding(
     const py::list& dims, const py::list& reshape_dims,
-    const py::list& transpose_perm) {
+    const py::list& transpose_perm, const py::list& types) {
+  TORCH_LAZY_COUNTER("CreateIotaOpSharding", 1);
   auto dims_vec = dims.cast<std::vector<int64_t>>();
   auto reshape_dims_vec = reshape_dims.cast<std::vector<int64_t>>();
   auto transpose_perm_vec = transpose_perm.cast<std::vector<int>>();
-  std::vector<xla::OpSharding::Type> subgroup_types;
-  if (dims_vec.size() > transpose_perm.size()) {
-    subgroup_types.push_back(xla::OpSharding::REPLICATED);
+  std::vector<xla::OpSharding::Type> subgroup_types_vec;
+  for (auto type : types) {
+    subgroup_types_vec.push_back(
+        static_cast<xla::OpSharding::Type>(type.cast<int>()));
   }
+  CHECK_EQ(reshape_dims_vec.size(), transpose_perm_vec.size());
   return xla::HloSharding::Subgroup(
              xla::TileAssignment(dims_vec, reshape_dims_vec,
                                  transpose_perm_vec),
-             subgroup_types)
+             subgroup_types_vec)
       .ToProto();
 }

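CreateIotaOpSharding now takes the subgroup types explicitly from Python instead of inferring a single REPLICATED subgroup from the dims/transpose_perm sizes. As an illustration of what the parameters mean, here is a numpy sketch for a (0, None) partition spec on a 2x2 mesh with the trailing tile axis marked REPLICATED; the concrete values are assumed for the example and are not taken from this diff:

import numpy as np

dims = [2, 1, 2]         # tile-assignment dims; the trailing axis is the replicated subgroup
reshape_dims = [2, 2]    # iota(4) reshaped to the mesh shape
transpose_perm = [0, 1]  # no reordering of mesh axes
types = [0]              # 0 == xla::OpSharding::REPLICATED for the trailing axis

tile_assignment = (np.arange(np.prod(reshape_dims))
                   .reshape(reshape_dims)
                   .transpose(transpose_perm)
                   .reshape(dims))
print(tile_assignment[0, 0], tile_assignment[1, 0])  # [0 1] [2 3]
# Shard 0 (first half of dim 0) is replicated on devices 0 and 1, shard 1 on
# devices 2 and 3: the V1 string {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}.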
0 commit comments
