Commit 96f1d55

Change V2 sharding spec algorithm + Fix tensor sharding spec visualization (#7)
This PR adds support for all previously unsupported partition specs and fixes the visualize_tensor_sharding() function to support V2 sharding specs. See pytorch#9541 for the upstream PR discussion and additional context.

* Add some tests and reviewer suggestions. V2 op sharding logic will be updated in a later commit.
* New implementation (WIP)
* Fix new implementation
* Fix visualize_tensor_sharding function for V2 shardings
1 parent 4a8c988 commit 96f1d55
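
As background for the diffs below: the V1 format spells out the device assignment explicitly, while the V2 ("iota") format produced under CONVERT_SHLO_TO_SHARDY describes it compactly as a reshape/transpose of an iota. The following minimal sketch (plain Python with numpy, an assumed 8-device mesh, not code from this commit) shows the two equivalent forms for a simple tiled sharding:

# Minimal sketch, assuming 8 devices; illustrates the V1 vs. V2 string forms
# only and is not this commit's implementation.
import numpy as np

n_devices = 8
tile_dims = (2, 4)  # tile assignment dimensions

# V2 ("iota") form: the device list is implicit, devices 0..N-1 in iota order.
v2_spec = "{devices=[%s]<=[%d]}" % (",".join(map(str, tile_dims)), n_devices)

# Equivalent V1 form: the device list is written out explicitly.
device_ids = np.arange(n_devices).reshape(tile_dims).flatten()
v1_spec = "{devices=[%s]%s}" % (",".join(map(str, tile_dims)),
                                ",".join(str(d) for d in device_ids))

print(v2_spec)  # {devices=[2,4]<=[8]}
print(v1_spec)  # {devices=[2,4]0,1,2,3,4,5,6,7}

The construct_v1_sharding_str helper exercised in the first test file below performs this kind of expansion for arbitrary partition specs, including transposed and partially replicated ones.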

File tree

7 files changed: +264, -100 lines changed


test/spmd/test_spmd_debugging.py

Lines changed: 72 additions & 0 deletions

@@ -17,6 +17,7 @@
 import torch_xla.distributed.spmd as xs
 from torch_xla.distributed.spmd import XLAShardedTensor
 from torch_xla.distributed.spmd import Mesh
+from torch_xla.distributed.spmd.debugging import construct_v1_sharding_str
 
 import test_xla_sharding_base
 
@@ -828,6 +829,77 @@ def test_multi_host_replicated_cpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
+
+class ConvertV2ShardingToV1Test(test_xla_sharding_base.XlaShardingTest):
+
+  @classmethod
+  def setUpClass(cls):
+    super().setUpClass()
+    os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"
+
+  def run_test(self):
+    mesh = self._get_mesh(self.device_mesh_shape)
+    t = torch.randn(self.tensor_shape).to(torch_xla.device())
+    xs.mark_sharding(t, mesh, self.partition_spec)
+    actual_str = construct_v1_sharding_str(t)
+    self.assertEqual(self.expected_str, actual_str)
+
+  def test_tiled_sharding(self):
+    self.device_mesh_shape = (1, self.n_devices)
+    self.tensor_shape = (1, 128)
+    self.partition_spec = (0, 1)
+    self.expected_str = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
+        [str(i) for i in range(self.n_devices)]))
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 2,
+                   f"Requires at least 2 devices.")
+  def test_tupled_tiled_sharding(self):
+    self.device_mesh_shape = (2, self.n_devices // 2)
+    self.tensor_shape = (16,)
+    self.partition_spec = ((0, 1),)
+    self.expected_str = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  def test_replicated_sharding(self):
+    self.device_mesh_shape = (1, self.n_devices)
+    self.tensor_shape = (4, 4)
+    self.partition_spec = (None, None)
+    self.expected_str = '{replicated}'
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 4,
+                   f"Requires at least 4 devices.")
+  def test_partial_replication_sharding(self):
+    self.device_mesh_shape = (2, self.n_devices // 2)
+    self.tensor_shape = (4, 4)
+    self.partition_spec = (0, None)
+    self.expected_str = '{devices=[2,1,%d]%s last_tile_dim_replicate}' % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  @unittest.skipIf(xr.global_runtime_device_count() < 4,
+                   f"Requires at least 4 devices.")
+  def test_tupled_partial_replication_sharding(self):
+    self.device_mesh_shape = (1, 2, self.n_devices // 2)
+    self.tensor_shape = (16, 16)
+    self.partition_spec = ((0, 1), None)
+    self.expected_str = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    self.run_test()
+
+  def test_tupled_partial_replication_sharding_with_transpose(self):
+    self.device_mesh_shape = (1, 2, self.n_devices // 2)
+    self.tensor_shape = (16, 16)
+    self.partition_spec = (None, (2, 1))
+    device_order = self.device_ids.reshape(self.device_mesh_shape).transpose(
+        (2, 1, 0)).flatten()
+    self.expected_str = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in device_order))
+    self.run_test()
+
+
 if __name__ == '__main__':
   test = unittest.main()
   sys.exit(0 if test.result.wasSuccessful() else 1)
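
The transpose case is the trickiest of the new tests. Here is a minimal sketch of the expected string it checks, assuming an 8-device run and using a plain numpy arange in place of the test base class's device_ids (illustration only, not part of the commit):

# Sketch, assuming 8 devices; mirrors the device_order computation in
# test_tupled_partial_replication_sharding_with_transpose above.
import numpy as np

n_devices = 8
device_mesh_shape = (1, 2, n_devices // 2)  # (1, 2, 4)
device_ids = np.arange(n_devices)           # stand-in for the base class's device_ids

# partition_spec = (None, (2, 1)): tensor dim 1 is sharded over mesh axes 2
# then 1, so the device order is the mesh transposed to (2, 1, 0), flattened.
device_order = device_ids.reshape(device_mesh_shape).transpose((2, 1, 0)).flatten()

expected_str = "{devices=[1,%d]%s}" % (n_devices,
                                       ",".join(str(x) for x in device_order))
print(expected_str)  # {devices=[1,8]0,4,1,5,2,6,3,7}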

test/spmd/test_xla_sharding.py

Lines changed: 63 additions & 36 deletions

@@ -31,6 +31,7 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
   @classmethod
   def setUpClass(cls):
     super().setUpClass()
+    cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY")
 
   def test_xla_sharded_tensor(self):
     partition_spec = (0, 1)
@@ -238,6 +239,8 @@ def test_custom_tile_assignment(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
          [str(i) for i in reversed(range(self.n_devices))]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))
 
   def test_mark_sharding_2d(self):
@@ -252,6 +255,8 @@ def test_mark_sharding_2d(self):
     if self.n_devices > 1:
       annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
          [str(i) for i in range(self.n_devices)]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))
 
     actual = (xt1 + xt2).cpu()
@@ -271,6 +276,9 @@ def test_mark_sharding_4d(self):
       annotation = '{devices=[1,1,%d,%d]%s}' % (
          z_dim, self.n_devices // z_dim, ','.join(
              [str(i) for i in range(self.n_devices)]))
+      if self.convert_to_shardy:
+        annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
+                                                      z_dim, self.n_devices)
     self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))
 
     actual = (xt + xt).cpu()
@@ -403,9 +411,11 @@ def test_tupled_partition_spec(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16).to('xla')
     xs.mark_sharding(t, mesh, ((0, 1),))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
 
   @unittest.skipUnless(xr.global_runtime_device_count() >= 4,
                        "Multiple devices required for tupled partition spec")
@@ -415,34 +425,43 @@ def test_named_partial_tupled_partition_spec(self):
     # Shard the first dimension on `r` and `b`, replicate the second dimension
     t = torch.randn(16, 16).to('xla')
     xs.mark_sharding(t, mesh, (('r', 'b'), None))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t),
-        "{devices=[2,1,%d]%s last_tile_dim_replicate}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
+          self.n_devices // 2, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
 
     # Replicate the first dimension, shard the second on `b` and `m`
     u = torch.randn(16, 16).to('xla')
     xs.mark_sharding(u, mesh, (None, ('b', 'm')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation)
 
     # Replicate the first dimension, shard the second on `r` and `m`
     v = torch.randn(16, 16).to('xla')
     xs.mark_sharding(v, mesh, (None, ('r', 'm')))
     device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten()
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(v),
-        "{devices=[1,%d,2]%s last_tile_dim_replicate}" %
-        (self.n_devices // 2, ','.join(str(x) for x in device_order)))
+    annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
+        self.n_devices // 2, ','.join(str(x) for x in device_order))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
+          self.n_devices // 2, self.n_devices // 2)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)
 
     # Replicate the first dimension, shard the second on `m` and `b`
     v = torch.randn(16, 16).to('xla')
     xs.mark_sharding(v, mesh, (None, ('m', 'b')))
     device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" %
-        (self.n_devices, ','.join(str(x) for x in device_order)))
+    annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
+        str(x) for x in device_order))
+    if self.convert_to_shardy:
+      annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices,
+                                                       self.n_devices // 2)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)
 
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'Multiple devices required for tupled partition spec')
@@ -452,19 +471,25 @@ def test_multiple_tuples_in_spec(self):
                           ('a', 'b', 'c', 'd'))
     t = torch.randn(2, 2).to('xla')
     xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
+                                               self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
 
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        'At least 2 devices needed for 2D mesh')
   def test_3d_tensor_2d_mesh(self):
     mesh = self._get_mesh((2, self.n_devices // 2))
     t = torch.randn(16, 16, 16).to('xla')
     xs.mark_sharding(t, mesh, (None, 0, 1))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' %
-        (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
+    annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
+        str(x) for x in range(self.n_devices)))
+    if self.convert_to_shardy:
+      annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2,
+                                                 self.n_devices)
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
 
   def test_partial_replication_addmm(self):
     device = torch_xla.device()
@@ -983,18 +1008,20 @@ def test_op_sharding_cache(self):
 
     t = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(t, mesh, (0, 1))
-    self.assertIn("CreateOpSharding", met.counter_names())
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding"
+    self.assertIn(counter_name, met.counter_names())
+    self.assertEqual(met.counter_value(counter_name), 1)
 
     # Sharding with the same partition spec should not result in another call
     u = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(u, mesh, (0, 1))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 1)
+    self.assertEqual(met.counter_value(counter_name), 1)
 
-    # Changing the partition spec will result in another CreateOpSharding
+    # Changing the partition spec will result in another
+    # CreateOpSharding or CreateIotaOpSharding call
     v = torch.randn(1, self.n_devices).to('xla')
     xs.mark_sharding(v, mesh, (0, None))
-    self.assertEqual(met.counter_value("CreateOpSharding"), 2)
+    self.assertEqual(met.counter_value(counter_name), 2)
 
   def test_from_cpu_shards_replicated(self):
     from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
@@ -1401,10 +1428,10 @@ def test_data_loader_with_sharding(self):
         input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(data),
-        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    )
+    annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    if self.convert_to_shardy:
+      annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)
 
   @unittest.skipUnless(
       xr.global_runtime_device_count() > 1,
@@ -1424,10 +1451,10 @@ def test_data_loader_with_non_batch_size(self):
         input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
     data, _ = iter(train_device_loader).__next__()
     self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
-    self.assertEqual(
-        torch_xla._XLAC._get_xla_sharding_spec(data),
-        f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
-    )
+    annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
+    if self.convert_to_shardy:
+      annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
+    self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)
 
   @unittest.skipUnless(
       xr.global_runtime_device_count() > 1,
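
The shardy branches above expect V2 annotations such as {devices=[1,8]<=[2,4]T(1,0)}, where <=[2,4] is the iota reshape and T(1,0) the transpose permutation. A minimal sketch of how that compact form expands to the explicit V1 device order (assuming 8 devices; illustration only, not the library's conversion code):

# Sketch, assuming 8 devices: expand {devices=[1,8]<=[2,4]T(1,0)} into its
# explicit V1 device list.
import numpy as np

n_devices = 8
tile_dims = (1, n_devices)          # devices=[1,8]
reshape_dims = (2, n_devices // 2)  # <=[2,4]
transpose_perm = (1, 0)             # T(1,0)

device_order = (np.arange(n_devices)
                .reshape(reshape_dims)
                .transpose(transpose_perm)
                .flatten())

v1 = "{devices=[%s]%s}" % (",".join(map(str, tile_dims)),
                           ",".join(str(x) for x in device_order))
print(v1)  # {devices=[1,8]0,4,1,5,2,6,3,7}

Assuming a (1, 2, 4) logical mesh, this matches the device order the non-shardy branch computes via mesh.get_logical_mesh().transpose((2, 1, 0)).flatten() in the same test.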

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 43 additions & 13 deletions

@@ -703,6 +703,16 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
   return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode);
 }
 
+std::optional<xla::OpSharding> GetXLAOpSharding(const at::Tensor& input) {
+  XLATensorPtr xtensor = bridge::GetXlaTensor(input);
+  XLATensor::ShardingSpecPtr sharding_spec =
+      xtensor ? xtensor->sharding_spec() : nullptr;
+  if (sharding_spec != nullptr) {
+    return sharding_spec->sharding;
+  }
+  return std::nullopt;
+}
+
 std::string GetXLAShardingSpec(const XLATensorPtr xtensor) {
   auto sharding_spec = xtensor->sharding_spec();
   if (sharding_spec != nullptr) {
@@ -1460,6 +1470,10 @@ at::Tensor tensor_fromDLPack(PyObject* data) {
 void InitXlaModuleBindings(py::module m) {
   PythonScope<py::module> module(m);
 
+  using TileAssignmentDims = std::vector<int64_t>;
+  using ReshapeDims = std::vector<int64_t>;
+  using TransposePerm = std::vector<int>;
+
   // Define the _XLAC.XlaShardingSpec class.
   PythonScope<py::class_<XLATensor::ShardingSpec, XLATensor::ShardingSpecPtr>>(
       m, "XlaShardingSpec")
@@ -1477,12 +1491,12 @@ void InitXlaModuleBindings(py::module m) {
       })
       .def_init([](at::Tensor tensor, const py::list& dims,
                    const py::list& reshape_dims, const py::list& transpose_perm,
-                   bool minibatch) {
+                   const py::list& types, bool minibatch) {
        xla::Shape global_shape =
            ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
        return std::make_shared<XLATensor::ShardingSpec>(
            ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
-                                              transpose_perm),
+                                              transpose_perm, types),
            global_shape, minibatch);
      });
 
@@ -1512,9 +1526,9 @@ void InitXlaModuleBindings(py::module m) {
      })
      // Constructor for V2 shardings.
      .def_init([](const py::list& dims, const py::list& reshape_dims,
-                  const py::list& transpose_perm) {
+                  const py::list& transpose_perm, const py::list& types) {
        return ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
-                                                 transpose_perm);
+                                                 transpose_perm, types);
      });
 
   // Define the _XLAC.PjRtPlugin class.
@@ -1742,7 +1756,8 @@ void InitXlaModuleBindings(py::module m) {
        }
      })
      .def("_xla_get_runtime_devices",
-          []() { return runtime::GetComputationClientOrDie()->GetLocalDevices(); })
+          []() {
+            return runtime::GetComputationClientOrDie()->GetLocalDevices(); })
      .def("_xla_num_runtime_devices",
           []() -> int64_t {
             return runtime::GetComputationClientOrDie()->GetNumLocalDevices();
@@ -2154,9 +2169,11 @@ void InitXlaModuleBindings(py::module m) {
            return device.ordinal();
          })
      .def("_xla_get_process_index",
-          []() { return runtime::GetComputationClientOrDie()->GetProcessIndex(); })
+          []() {
+            return runtime::GetComputationClientOrDie()->GetProcessIndex(); })
      .def("_xla_get_num_processes",
-          []() { return runtime::GetComputationClientOrDie()->GetNumProcesses(); })
+          []() {
+            return runtime::GetComputationClientOrDie()->GetNumProcesses(); })
      .def("_xla_get_num_cached_compilation_graph",
           []() -> int64_t {
             return XLAGraphExecutor::Get()->GetNumGraphHash();
@@ -2579,13 +2596,26 @@ void InitXlaModuleBindings(py::module m) {
          })
      .def("_get_xla_op_sharding",
           [](const at::Tensor& input) -> std::optional<xla::OpSharding> {
-            XLATensorPtr xtensor = bridge::GetXlaTensor(input);
-            XLATensor::ShardingSpecPtr sharding_spec =
-                xtensor ? xtensor->sharding_spec() : nullptr;
-            if (sharding_spec != nullptr) {
-              return sharding_spec->sharding;
+            return GetXLAOpSharding(input);
+          })
+     .def("_get_xla_op_sharding_v2_params",
+          [](const at::Tensor& input) -> std::optional<std::tuple<TileAssignmentDims, ReshapeDims, TransposePerm, bool>> {
+            std::optional<xla::OpSharding> maybe_sharding =
+                GetXLAOpSharding(input);
+            if (!maybe_sharding) {
+              return std::nullopt;
             }
-            return std::nullopt;
+            const xla::OpSharding& sharding = maybe_sharding.value();
+            TileAssignmentDims tile_assignment_dims(
+                sharding.tile_assignment_dimensions().begin(),
+                sharding.tile_assignment_dimensions().end());
+            ReshapeDims reshape_dims(sharding.iota_reshape_dims().begin(),
+                                     sharding.iota_reshape_dims().end());
+            TransposePerm transpose_perm(sharding.iota_transpose_perm().begin(),
+                                         sharding.iota_transpose_perm().end());
+            return std::make_tuple(tile_assignment_dims, reshape_dims,
+                                   transpose_perm,
+                                   sharding.replicate_on_last_tile_dim());
          })
      .def("_get_xla_sharding_specs",
           [](const std::vector<at::Tensor>& tensors)
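
The new _get_xla_op_sharding_v2_params binding returns the V2 parameters (tile assignment dims, iota reshape dims, transpose permutation, and the replicate-on-last-tile-dim flag), or None when the tensor carries no sharding annotation. A hedged Python sketch of how a caller could expand those parameters; it mirrors the V2 semantics illustrated earlier and is not the library's construct_v1_sharding_str implementation:

# Hedged sketch: consume the new binding and rebuild the explicit device order.
# Assumes `t` was sharded via xs.mark_sharding with CONVERT_SHLO_TO_SHARDY=1
# and carries a tiled (non-replicated) V2 sharding.
import numpy as np
import torch_xla


def v2_params_to_device_order(t):
  params = torch_xla._XLAC._get_xla_op_sharding_v2_params(t)
  if params is None:
    return None  # tensor has no sharding annotation
  tile_dims, reshape_dims, transpose_perm, replicate_last_dim = params
  # An empty permutation is treated as the identity here (assumption).
  perm = list(transpose_perm) or list(range(len(reshape_dims)))
  device_order = (np.arange(int(np.prod(reshape_dims)))
                  .reshape(reshape_dims)
                  .transpose(perm)
                  .flatten())
  return list(tile_dims), device_order.tolist(), replicate_last_dim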
