diff --git a/test/spmd/test_spmd_debugging.py b/test/spmd/test_spmd_debugging.py index 34221d375e9..5ffa74f9886 100644 --- a/test/spmd/test_spmd_debugging.py +++ b/test/spmd/test_spmd_debugging.py @@ -17,6 +17,7 @@ import torch_xla.distributed.spmd as xs from torch_xla.distributed.spmd import XLAShardedTensor from torch_xla.distributed.spmd import Mesh +from torch_xla.distributed.spmd.debugging import construct_v1_sharding_str import test_xla_sharding_base @@ -828,6 +829,77 @@ def test_multi_host_replicated_cpu(self): fake_output = fake_capture.get() assert output == fake_output + +class ConvertV2ShardingToV1Test(test_xla_sharding_base.XlaShardingTest): + + @classmethod + def setUpClass(cls): + super().setUpClass() + os.environ["CONVERT_SHLO_TO_SHARDY"] = "1" + + def run_test(self): + mesh = self._get_mesh(self.device_mesh_shape) + t = torch.randn(self.tensor_shape).to(torch_xla.device()) + xs.mark_sharding(t, mesh, self.partition_spec) + actual_str = construct_v1_sharding_str(t) + self.assertEqual(self.expected_str, actual_str) + + def test_tiled_sharding(self): + self.device_mesh_shape = (1, self.n_devices) + self.tensor_shape = (1, 128) + self.partition_spec = (0, 1) + self.expected_str = '{devices=[1,%d]%s}' % (self.n_devices, ','.join( + [str(i) for i in range(self.n_devices)])) + self.run_test() + + @unittest.skipIf(xr.global_runtime_device_count() < 2, + f"Requires at least 2 devices.") + def test_tupled_tiled_sharding(self): + self.device_mesh_shape = (2, self.n_devices // 2) + self.tensor_shape = (16,) + self.partition_spec = ((0, 1),) + self.expected_str = "{devices=[%d]%s}" % (self.n_devices, ','.join( + str(x) for x in range(self.n_devices))) + self.run_test() + + def test_replicated_sharding(self): + self.device_mesh_shape = (1, self.n_devices) + self.tensor_shape = (4, 4) + self.partition_spec = (None, None) + self.expected_str = '{replicated}' + self.run_test() + + @unittest.skipIf(xr.global_runtime_device_count() < 4, + f"Requires at least 4 devices.") + def test_partial_replication_sharding(self): + self.device_mesh_shape = (2, self.n_devices // 2) + self.tensor_shape = (4, 4) + self.partition_spec = (0, None) + self.expected_str = '{devices=[2,1,%d]%s last_tile_dim_replicate}' % ( + self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))) + self.run_test() + + @unittest.skipIf(xr.global_runtime_device_count() < 4, + f"Requires at least 4 devices.") + def test_tupled_partial_replication_sharding(self): + self.device_mesh_shape = (1, 2, self.n_devices // 2) + self.tensor_shape = (16, 16) + self.partition_spec = ((0, 1), None) + self.expected_str = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % ( + self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))) + self.run_test() + + def test_tupled_partial_replication_sharding_with_transpose(self): + self.device_mesh_shape = (1, 2, self.n_devices // 2) + self.tensor_shape = (16, 16) + self.partition_spec = (None, (2, 1)) + device_order = self.device_ids.reshape(self.device_mesh_shape).transpose( + (2, 1, 0)).flatten() + self.expected_str = "{devices=[1,%d]%s}" % (self.n_devices, ','.join( + str(x) for x in device_order)) + self.run_test() + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/test/spmd/test_xla_sharding.py b/test/spmd/test_xla_sharding.py index 7b1be7574a1..6fe42be517e 100644 --- a/test/spmd/test_xla_sharding.py +++ b/test/spmd/test_xla_sharding.py @@ -31,6 +31,7 @@ class 
BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest): @classmethod def setUpClass(cls): super().setUpClass() + cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY") def test_xla_sharded_tensor(self): partition_spec = (0, 1) @@ -238,6 +239,8 @@ def test_custom_tile_assignment(self): if self.n_devices > 1: annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join( [str(i) for i in reversed(range(self.n_devices))])) + if self.convert_to_shardy: + annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices) self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt)) def test_mark_sharding_2d(self): @@ -252,6 +255,8 @@ def test_mark_sharding_2d(self): if self.n_devices > 1: annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join( [str(i) for i in range(self.n_devices)])) + if self.convert_to_shardy: + annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices) self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1)) actual = (xt1 + xt2).cpu() @@ -271,6 +276,9 @@ def test_mark_sharding_4d(self): annotation = '{devices=[1,1,%d,%d]%s}' % ( z_dim, self.n_devices // z_dim, ','.join( [str(i) for i in range(self.n_devices)])) + if self.convert_to_shardy: + annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices // + z_dim, self.n_devices) self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt)) actual = (xt + xt).cpu() @@ -403,9 +411,11 @@ def test_tupled_partition_spec(self): mesh = self._get_mesh((2, self.n_devices // 2)) t = torch.randn(16).to('xla') xs.mark_sharding(t, mesh, ((0, 1),)) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" % - (self.n_devices, ','.join(str(x) for x in range(self.n_devices)))) + annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join( + str(x) for x in range(self.n_devices))) + if self.convert_to_shardy: + annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation) @unittest.skipUnless(xr.global_runtime_device_count() >= 4, "Multiple devices required for tupled partition spec") @@ -415,34 +425,43 @@ def test_named_partial_tupled_partition_spec(self): # Shard the first dimension on `r` and `b`, replicate the second dimension t = torch.randn(16, 16).to('xla') xs.mark_sharding(t, mesh, (('r', 'b'), None)) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(t), - "{devices=[2,1,%d]%s last_tile_dim_replicate}" % - (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) + annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % ( + self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))) + if self.convert_to_shardy: + annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % ( + self.n_devices // 2, self.n_devices) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation) # Replicate the first dimension, shard the second on `b` and `m` u = torch.randn(16, 16).to('xla') xs.mark_sharding(u, mesh, (None, ('b', 'm'))) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" % - (self.n_devices, ','.join(str(x) for x in range(self.n_devices)))) + annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join( + str(x) for x in range(self.n_devices))) + if self.convert_to_shardy: + annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation) # Replicate the first 
dimension, shard the second on `r` and `m` v = torch.randn(16, 16).to('xla') xs.mark_sharding(v, mesh, (None, ('r', 'm'))) device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten() - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(v), - "{devices=[1,%d,2]%s last_tile_dim_replicate}" % - (self.n_devices // 2, ','.join(str(x) for x in device_order))) + annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % ( + self.n_devices // 2, ','.join(str(x) for x in device_order)) + if self.convert_to_shardy: + annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % ( + self.n_devices // 2, self.n_devices // 2) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation) # Replicate the first dimension, shard the second on `m` and `b` v = torch.randn(16, 16).to('xla') xs.mark_sharding(v, mesh, (None, ('m', 'b'))) device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten() - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" % - (self.n_devices, ','.join(str(x) for x in device_order))) + annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join( + str(x) for x in device_order)) + if self.convert_to_shardy: + annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices, + self.n_devices // 2) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation) @unittest.skipUnless(xr.global_runtime_device_count() > 1, 'Multiple devices required for tupled partition spec') @@ -452,9 +471,12 @@ def test_multiple_tuples_in_spec(self): ('a', 'b', 'c', 'd')) t = torch.randn(2, 2).to('xla') xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd'))) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" % - (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) + annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join( + str(x) for x in range(self.n_devices))) + if self.convert_to_shardy: + annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2, + self.n_devices) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation) @unittest.skipUnless(xr.global_runtime_device_count() > 1, 'At least 2 devices needed for 2D mesh') @@ -462,9 +484,12 @@ def test_3d_tensor_2d_mesh(self): mesh = self._get_mesh((2, self.n_devices // 2)) t = torch.randn(16, 16, 16).to('xla') xs.mark_sharding(t, mesh, (None, 0, 1)) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' % - (self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))) + annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join( + str(x) for x in range(self.n_devices))) + if self.convert_to_shardy: + annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2, + self.n_devices) + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation) def test_partial_replication_addmm(self): device = torch_xla.device() @@ -983,18 +1008,20 @@ def test_op_sharding_cache(self): t = torch.randn(1, self.n_devices).to('xla') xs.mark_sharding(t, mesh, (0, 1)) - self.assertIn("CreateOpSharding", met.counter_names()) - self.assertEqual(met.counter_value("CreateOpSharding"), 1) + counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding" + self.assertIn(counter_name, met.counter_names()) + self.assertEqual(met.counter_value(counter_name), 1) # Sharding with the same partition spec should not result in another call u = torch.randn(1, self.n_devices).to('xla') xs.mark_sharding(u, mesh, (0, 1)) - 
self.assertEqual(met.counter_value("CreateOpSharding"), 1) + self.assertEqual(met.counter_value(counter_name), 1) - # Changing the partition spec will result in another CreateOpSharding + # Changing the partition spec will result in another + # CreateOpSharding or CreateIotaOpSharding call v = torch.randn(1, self.n_devices).to('xla') xs.mark_sharding(v, mesh, (0, None)) - self.assertEqual(met.counter_value("CreateOpSharding"), 2) + self.assertEqual(met.counter_value(counter_name), 2) def test_from_cpu_shards_replicated(self): from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards @@ -1401,10 +1428,10 @@ def test_data_loader_with_sharding(self): input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None))) data, _ = iter(train_device_loader).__next__() self.assertEqual(data.size(), torch.Size([8, 3, 64, 64])) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(data), - f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}" - ) + annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}" + if self.convert_to_shardy: + annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}" + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation) @unittest.skipUnless( xr.global_runtime_device_count() > 1, @@ -1424,10 +1451,10 @@ def test_data_loader_with_non_batch_size(self): input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None))) data, _ = iter(train_device_loader).__next__() self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64])) - self.assertEqual( - torch_xla._XLAC._get_xla_sharding_spec(data), - f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}" - ) + annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}" + if self.convert_to_shardy: + annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}" + self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation) @unittest.skipUnless( xr.global_runtime_device_count() > 1, diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index f8a300205d1..b2b915533bf 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -703,6 +703,16 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors, return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode); } +std::optional<xla::OpSharding> GetXLAOpSharding(const at::Tensor& input) { + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + XLATensor::ShardingSpecPtr sharding_spec = + xtensor ? xtensor->sharding_spec() : nullptr; + if (sharding_spec != nullptr) { + return sharding_spec->sharding; + } + return std::nullopt; +} + std::string GetXLAShardingSpec(const XLATensorPtr xtensor) { auto sharding_spec = xtensor->sharding_spec(); if (sharding_spec != nullptr) { @@ -1460,6 +1470,10 @@ at::Tensor tensor_fromDLPack(PyObject* data) { void InitXlaModuleBindings(py::module m) { PythonScope module(m); + using TileAssignmentDims = std::vector<int64_t>; + using ReshapeDims = std::vector<int64_t>; + using TransposePerm = std::vector<int>; + // Define the _XLAC.XlaShardingSpec class. 
PythonScope>( m, "XlaShardingSpec") @@ -1477,12 +1491,12 @@ void InitXlaModuleBindings(py::module m) { }) .def_init([](at::Tensor tensor, const py::list& dims, const py::list& reshape_dims, const py::list& transpose_perm, - bool minibatch) { + const py::list& types, bool minibatch) { xla::Shape global_shape = ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch); return std::make_shared( ShardingUtil::CreateIotaOpSharding(dims, reshape_dims, - transpose_perm), + transpose_perm, types), global_shape, minibatch); }); @@ -1512,9 +1526,9 @@ void InitXlaModuleBindings(py::module m) { }) // Constructor for V2 shardings. .def_init([](const py::list& dims, const py::list& reshape_dims, - const py::list& transpose_perm) { + const py::list& transpose_perm, const py::list& types) { return ShardingUtil::CreateIotaOpSharding(dims, reshape_dims, - transpose_perm); + transpose_perm, types); }); // Define the _XLAC.PjRtPlugin class. @@ -1742,7 +1756,8 @@ void InitXlaModuleBindings(py::module m) { } }) .def("_xla_get_runtime_devices", - []() { return runtime::GetComputationClientOrDie()->GetLocalDevices(); }) + []() { + return runtime::GetComputationClientOrDie()->GetLocalDevices(); }) .def("_xla_num_runtime_devices", []() -> int64_t { return runtime::GetComputationClientOrDie()->GetNumLocalDevices(); @@ -2154,9 +2169,11 @@ void InitXlaModuleBindings(py::module m) { return device.ordinal(); }) .def("_xla_get_process_index", - []() { return runtime::GetComputationClientOrDie()->GetProcessIndex(); }) + []() { + return runtime::GetComputationClientOrDie()->GetProcessIndex(); }) .def("_xla_get_num_processes", - []() { return runtime::GetComputationClientOrDie()->GetNumProcesses(); }) + []() { + return runtime::GetComputationClientOrDie()->GetNumProcesses(); }) .def("_xla_get_num_cached_compilation_graph", []() -> int64_t { return XLAGraphExecutor::Get()->GetNumGraphHash(); @@ -2579,13 +2596,26 @@ void InitXlaModuleBindings(py::module m) { }) .def("_get_xla_op_sharding", [](const at::Tensor& input) -> std::optional { - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - XLATensor::ShardingSpecPtr sharding_spec = - xtensor ? 
xtensor->sharding_spec() : nullptr; - if (sharding_spec != nullptr) { - return sharding_spec->sharding; + return GetXLAOpSharding(input); + }) + .def("_get_xla_op_sharding_v2_params", + [](const at::Tensor& input) -> std::optional> { + std::optional maybe_sharding = + GetXLAOpSharding(input); + if (!maybe_sharding) { + return std::nullopt; } - return std::nullopt; + const xla::OpSharding& sharding = maybe_sharding.value(); + TileAssignmentDims tile_assignment_dims( + sharding.tile_assignment_dimensions().begin(), + sharding.tile_assignment_dimensions().end()); + ReshapeDims reshape_dims(sharding.iota_reshape_dims().begin(), + sharding.iota_reshape_dims().end()); + TransposePerm transpose_perm(sharding.iota_transpose_perm().begin(), + sharding.iota_transpose_perm().end()); + return std::make_tuple(tile_assignment_dims, reshape_dims, + transpose_perm, + sharding.replicate_on_last_tile_dim()); }) .def("_get_xla_sharding_specs", [](const std::vector& tensors) diff --git a/torch_xla/csrc/xla_sharding_util.cpp b/torch_xla/csrc/xla_sharding_util.cpp index 700901e3f5d..b2768699061 100644 --- a/torch_xla/csrc/xla_sharding_util.cpp +++ b/torch_xla/csrc/xla_sharding_util.cpp @@ -220,18 +220,21 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a, xla::OpSharding ShardingUtil::CreateIotaOpSharding( const py::list& dims, const py::list& reshape_dims, - const py::list& transpose_perm) { + const py::list& transpose_perm, const py::list& types) { + TORCH_LAZY_COUNTER("CreateIotaOpSharding", 1); auto dims_vec = dims.cast>(); auto reshape_dims_vec = reshape_dims.cast>(); auto transpose_perm_vec = transpose_perm.cast>(); - std::vector subgroup_types; - if (dims_vec.size() > transpose_perm.size()) { - subgroup_types.push_back(xla::OpSharding::REPLICATED); + std::vector subgroup_types_vec; + for (auto type : types) { + subgroup_types_vec.push_back( + static_cast(type.cast())); } + CHECK_EQ(reshape_dims_vec.size(), transpose_perm_vec.size()); return xla::HloSharding::Subgroup( xla::TileAssignment(dims_vec, reshape_dims_vec, transpose_perm_vec), - subgroup_types) + subgroup_types_vec) .ToProto(); } diff --git a/torch_xla/csrc/xla_sharding_util.h b/torch_xla/csrc/xla_sharding_util.h index 2cae399e293..a925c470748 100644 --- a/torch_xla/csrc/xla_sharding_util.h +++ b/torch_xla/csrc/xla_sharding_util.h @@ -55,7 +55,8 @@ class ShardingUtil { // HloShardingV2 system. static xla::OpSharding CreateIotaOpSharding(const py::list& dims, const py::list& reshape_dims, - const py::list& transpose_perm); + const py::list& transpose_perm, + const py::list& types); // Returns the shape of the resulting shards of `tensor` after applying // `sharding`. 
This assumes the shards will be padded to ensure they all diff --git a/torch_xla/distributed/spmd/debugging.py b/torch_xla/distributed/spmd/debugging.py index e5f53d04aea..2cb9368aff0 100644 --- a/torch_xla/distributed/spmd/debugging.py +++ b/torch_xla/distributed/spmd/debugging.py @@ -157,6 +157,27 @@ def visualize_sharding(sharding: str, return table +def construct_v1_sharding_str(t: torch.Tensor) -> str: + """ + Returns the corresponding HLO V1 sharding string from the tensor + """ + sharding = torch_xla._XLAC._get_xla_sharding_spec(t) + if "<=" not in sharding: + # This is already in the V1 format + return sharding + sharding_params = torch_xla._XLAC._get_xla_op_sharding_v2_params(t) + assert sharding_params is not None + tile_assignment_dims, reshape_dims, transpose_perm, replicate_on_last_dim = sharding_params + num_devices = np.prod(reshape_dims) + device_list = np.arange(num_devices).reshape(reshape_dims).transpose( + transpose_perm).reshape(num_devices) + + tile_assignment_str = ",".join(str(dim) for dim in tile_assignment_dims) + device_list_str = ",".join(str(i) for i in device_list) + replicate_str = " last_tile_dim_replicate" if replicate_on_last_dim else "" + return f"{{devices=[{tile_assignment_str}]{device_list_str}{replicate_str}}}" + + def visualize_tensor_sharding(t, **kwargs): """Visualizes an array's sharding.""" @@ -164,5 +185,7 @@ def visualize_tensor_sharding(t, **kwargs): def maybe_unwrap(t: torch.Tensor) -> torch.Tensor: return t.global_tensor if isinstance(t, XLAShardedTensor) else t - sharding = torch_xla._XLAC._get_xla_sharding_spec(maybe_unwrap(t)) + t = maybe_unwrap(t) + sharding = construct_v1_sharding_str(t) + return visualize_sharding(sharding, **kwargs) diff --git a/torch_xla/distributed/spmd/xla_sharding.py b/torch_xla/distributed/spmd/xla_sharding.py index 042405c5a50..6f8193ba5f4 100644 --- a/torch_xla/distributed/spmd/xla_sharding.py +++ b/torch_xla/distributed/spmd/xla_sharding.py @@ -131,12 +131,6 @@ def _validate_translated_partition_spec(self, partition_spec: tuple): def _get_op_sharding_args(self, partition_spec: PartitionSpec): partition_spec = _translate_named_partition_spec(self, partition_spec) self._validate_translated_partition_spec(partition_spec) - flat_specs = np.hstack([d for d in partition_spec]) - specs = [d for d in flat_specs if d is not None] - assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \ - f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape." - assert len(specs) == len(np.unique(specs)), \ - f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}." tile_assignment = _get_tile_assignment(self, partition_spec) if len(tile_assignment.shape) > len(partition_spec): @@ -154,44 +148,58 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec): @functools.lru_cache(maxsize=None) def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec): + """ + This function returns all the sharding parameters needed for TILED or PARTIAL sharding. + (All other sharding types are handled separately by the V1 OpSharding function) + """ partition_spec = _translate_named_partition_spec(self, partition_spec) self._validate_translated_partition_spec(partition_spec) - # 1. Calculate the initial part of dims based on the partition_spec. 
- dims = [] - used_axes = OrderedDict() - for axis in partition_spec: - if isinstance(axis, tuple): - dim_size = 1 - for i in axis: - assert i is not None, "None not allowed within tuple" - dim_size *= self.mesh_shape[i] - used_axes[i] = True - dims.append(dim_size) - elif axis is not None: - assert isinstance(axis, int), "Axis must be an int or a tuple of ints" - dims.append(self.mesh_shape[axis]) - used_axes[axis] = True - else: - dims.append(1) - - # 2. If the product of dims is less than the total number of devices, - # append the sizes of the unused mesh axes. - if math.prod(dims) < math.prod(self.mesh_shape): - for i in range(len(self.mesh_shape)): - if i not in used_axes: - dims.append(self.mesh_shape[i]) + # This algorithm is adapted from + # https://github.com/openxla/xla/blob/256b633e0adaee80588a8c3a5e4b2eaa005b5414/xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.cc#L288 + tile_assignment_dims = [1] * len(partition_spec) + axisRefToShardedPos = {} + subgroup_types = [] + shardedPos = 0 - # 3. Calculate transpose_perm (sharded axes first, then unused axes). - transpose_perm = list(used_axes.keys()) - for i in range(len(self.mesh_shape)): - if i not in used_axes: - transpose_perm.append(i) + for idx, axes in enumerate(partition_spec): + if axes is None: + # Tensor dim is being replicated + continue + elif isinstance(axes, tuple): + # Tensor dim is being sharded over multiple axes + for axis in axes: + tile_assignment_dims[idx] *= self.mesh_shape[axis] + axisRefToShardedPos[axis] = shardedPos + shardedPos += 1 + else: + # Tensor dim is being sharded over just 1 axis + tile_assignment_dims[idx] *= self.mesh_shape[axes] + axisRefToShardedPos[axes] = shardedPos + shardedPos += 1 + + all_axes_ordered = [i for i in range(len(self.mesh_shape))] + reshape_dims = [0] * len(all_axes_ordered) + transpose_perm = [0] * len(all_axes_ordered) + + totalReplicatedSize = 1 + replicatedPos = shardedPos + for idx, axis in enumerate(all_axes_ordered): + reshape_dims[idx] = self.mesh_shape[axis] + if axis in axisRefToShardedPos: + # Axis is sharded + transpose_perm[axisRefToShardedPos[axis]] = idx + else: + # Axis is replicated + transpose_perm[replicatedPos] = idx + replicatedPos += 1 + totalReplicatedSize *= self.mesh_shape[axis] - # 4. reshape_dims is always the physical mesh shape. 
- reshape_dims = list(self.mesh_shape) + if totalReplicatedSize > 1: + tile_assignment_dims.append(totalReplicatedSize) + subgroup_types.append(ShardingType.REPLICATED) - return dims, reshape_dims, transpose_perm + return tile_assignment_dims, reshape_dims, transpose_perm, subgroup_types @functools.lru_cache(maxsize=None) def get_op_sharding_v2( @@ -203,11 +211,11 @@ def get_op_sharding_v2( return torch_xla._XLAC.OpSharding([], [], [], ShardingType.REPLICATED) sharding_type = _get_sharding_type(partition_spec, self.size()) if sharding_type not in (ShardingType.TILED, ShardingType.PARTIAL): - return torch_xla._XLAC.OpSharding([], [], [0], sharding_type) + return torch_xla._XLAC.OpSharding([], [], [], sharding_type) - dims, reshape_dims, transpose_perm = self._get_op_sharding_args_v2( + dims, reshape_dims, transpose_perm, types = self._get_op_sharding_args_v2( partition_spec) - return torch_xla._XLAC.OpSharding(dims, reshape_dims, transpose_perm) + return torch_xla._XLAC.OpSharding(dims, reshape_dims, transpose_perm, types) @functools.lru_cache(maxsize=None) def get_op_sharding( @@ -881,7 +889,7 @@ def __post_init__(self): self._sharding_type, tile_assignment, len(partition_spec), replicate_dims) if _use_shlo_to_shardy(): - self.dims, self.reshape_dims, self.transpose_dims = mesh._get_op_sharding_args_v2( + self.dims, self.reshape_dims, self.transpose_perm, self.subgroup_types = mesh._get_op_sharding_args_v2( partition_spec) def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]: @@ -893,9 +901,9 @@ def xla_spec(self, t: torch.Tensor) -> Union['XlaShardingSpec', None]: return None if _use_shlo_to_shardy(): - # Convert to Shardy spec if the environment variable is set. return torch_xla._XLAC.XlaShardingSpec(t, self.dims, self.reshape_dims, - self.transpose_dims, + self.transpose_perm, + self.subgroup_types, self.minibatch) return torch_xla._XLAC.XlaShardingSpec(t, self._tile_assignment,
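For reference, the V2 ("iota") annotations asserted above, e.g. `{devices=[1,4]<=[4]}` or `{devices=[1,4]<=[2,2]T(1,0)}`, describe the same shardings as the explicit V1 device lists; `construct_v1_sharding_str` recovers the V1 form by materializing the iota device order (reshape, transpose, flatten). Below is a minimal standalone sketch of that expansion, assuming 4 devices; the helper name and concrete parameter values are illustrative only, not taken from the patch.

```python
import numpy as np


def expand_iota_to_v1(tile_assignment_dims, reshape_dims, transpose_perm,
                      replicate_on_last_dim=False):
  """Expands V2 iota sharding parameters into the equivalent V1 string.

  Mirrors the reshape/transpose/flatten rule used by
  construct_v1_sharding_str; this is an illustrative sketch, not the
  library API.
  """
  num_devices = int(np.prod(reshape_dims))
  # Lay the device ids out in iota order, fold them into the iota reshape
  # dims, permute the axes, then flatten back into an explicit device list.
  device_list = (np.arange(num_devices).reshape(reshape_dims).transpose(
      transpose_perm).reshape(num_devices))
  dims_str = ",".join(str(d) for d in tile_assignment_dims)
  devices_str = ",".join(str(d) for d in device_list)
  suffix = " last_tile_dim_replicate" if replicate_on_last_dim else ""
  return f"{{devices=[{dims_str}]{devices_str}{suffix}}}"


# {devices=[2,1,2]<=[4] last_tile_dim_replicate} with 4 devices:
print(expand_iota_to_v1([2, 1, 2], [4], [0], True))
# -> {devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}

# {devices=[1,4]<=[2,2]T(1,0)} with 4 devices:
print(expand_iota_to_v1([1, 4], [2, 2], [1, 0]))
# -> {devices=[1,4]0,2,1,3}
```

With four devices, the second call reproduces the transposed device order 0,2,1,3 that the `(None, ('m', 'b'))` and `(None, (2, 1))` cases above expect.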