72 changes: 72 additions & 0 deletions test/spmd/test_spmd_debugging.py
@@ -17,6 +17,7 @@
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import XLAShardedTensor
from torch_xla.distributed.spmd import Mesh
from torch_xla.distributed.spmd.debugging import construct_v1_sharding_str

import test_xla_sharding_base

@@ -828,6 +829,77 @@ def test_multi_host_replicated_cpu(self):
fake_output = fake_capture.get()
assert output == fake_output


class ConvertV2ShardingToV1Test(test_xla_sharding_base.XlaShardingTest):

@classmethod
def setUpClass(cls):
super().setUpClass()
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"

def run_test(self):
mesh = self._get_mesh(self.device_mesh_shape)
t = torch.randn(self.tensor_shape).to(torch_xla.device())
xs.mark_sharding(t, mesh, self.partition_spec)
actual_str = construct_v1_sharding_str(t)
self.assertEqual(self.expected_str, actual_str)

def test_tiled_sharding(self):
self.device_mesh_shape = (1, self.n_devices)
self.tensor_shape = (1, 128)
self.partition_spec = (0, 1)
self.expected_str = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
[str(i) for i in range(self.n_devices)]))
self.run_test()

@unittest.skipIf(xr.global_runtime_device_count() < 2,
f"Requires at least 2 devices.")
def test_tupled_tiled_sharding(self):
self.device_mesh_shape = (2, self.n_devices // 2)
self.tensor_shape = (16,)
self.partition_spec = ((0, 1),)
self.expected_str = "{devices=[%d]%s}" % (self.n_devices, ','.join(
str(x) for x in range(self.n_devices)))
self.run_test()

def test_replicated_sharding(self):
self.device_mesh_shape = (1, self.n_devices)
self.tensor_shape = (4, 4)
self.partition_spec = (None, None)
self.expected_str = '{replicated}'
self.run_test()

@unittest.skipIf(xr.global_runtime_device_count() < 4,
f"Requires at least 4 devices.")
def test_partial_replication_sharding(self):
self.device_mesh_shape = (2, self.n_devices // 2)
self.tensor_shape = (4, 4)
self.partition_spec = (0, None)
self.expected_str = '{devices=[2,1,%d]%s last_tile_dim_replicate}' % (
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
self.run_test()

@unittest.skipIf(xr.global_runtime_device_count() < 4,
f"Requires at least 4 devices.")
def test_tupled_partial_replication_sharding(self):
self.device_mesh_shape = (1, 2, self.n_devices // 2)
self.tensor_shape = (16, 16)
self.partition_spec = ((0, 1), None)
self.expected_str = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
self.run_test()

def test_tupled_partial_replication_sharding_with_transpose(self):
self.device_mesh_shape = (1, 2, self.n_devices // 2)
self.tensor_shape = (16, 16)
self.partition_spec = (None, (2, 1))
device_order = self.device_ids.reshape(self.device_mesh_shape).transpose(
(2, 1, 0)).flatten()
self.expected_str = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
str(x) for x in device_order))
self.run_test()


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
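
Note on the new tests: construct_v1_sharding_str converts the V2 (iota-based) sharding attached to a tensor back into the V1 device-list string asserted above. Conceptually, a V2 sharding is described by its tile-assignment dims, iota reshape dims, iota transpose permutation, and a replicate-on-last-tile-dim flag, and the V1 device order falls out of materializing that iota. The sketch below only illustrates the idea; it is not the implementation in torch_xla.distributed.spmd.debugging, and v1_sharding_str_from_v2_params is a hypothetical helper (fully replicated shardings, which print as {replicated}, are not handled here).

    import numpy as np

    def v1_sharding_str_from_v2_params(dims, reshape_dims, transpose_perm,
                                       replicate_last_dim):
      # Materialize the iota tile assignment: number devices 0..N-1, reshape,
      # transpose, then flatten into the explicit V1 device order.
      devices = np.arange(np.prod(reshape_dims)).reshape(reshape_dims)
      device_order = devices.transpose(transpose_perm).flatten()
      suffix = ' last_tile_dim_replicate' if replicate_last_dim else ''
      return '{devices=[%s]%s%s}' % (','.join(str(d) for d in dims),
                                     ','.join(str(d) for d in device_order),
                                     suffix)

    # Matches test_partial_replication_sharding on 4 devices:
    # '{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}'
    print(v1_sharding_str_from_v2_params([2, 1, 2], [4], [0], True))
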
99 changes: 63 additions & 36 deletions test/spmd/test_xla_sharding.py
@@ -31,6 +31,7 @@ class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY")

def test_xla_sharded_tensor(self):
partition_spec = (0, 1)
@@ -238,6 +239,8 @@ def test_custom_tile_assignment(self):
if self.n_devices > 1:
annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
[str(i) for i in reversed(range(self.n_devices))]))
if self.convert_to_shardy:
annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

def test_mark_sharding_2d(self):
@@ -252,6 +255,8 @@ def test_mark_sharding_2d(self):
if self.n_devices > 1:
annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
[str(i) for i in range(self.n_devices)]))
if self.convert_to_shardy:
annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))

actual = (xt1 + xt2).cpu()
@@ -271,6 +276,9 @@ def test_mark_sharding_4d(self):
annotation = '{devices=[1,1,%d,%d]%s}' % (
z_dim, self.n_devices // z_dim, ','.join(
[str(i) for i in range(self.n_devices)]))
if self.convert_to_shardy:
annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
z_dim, self.n_devices)
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))

actual = (xt + xt).cpu()
@@ -403,9 +411,11 @@ def test_tupled_partition_spec(self):
mesh = self._get_mesh((2, self.n_devices // 2))
t = torch.randn(16).to('xla')
xs.mark_sharding(t, mesh, ((0, 1),))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[%d]%s}" %
(self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

@unittest.skipUnless(xr.global_runtime_device_count() >= 4,
"Multiple devices required for tupled partition spec")
@@ -415,34 +425,43 @@ def test_named_partial_tupled_partition_spec(self):
# Shard the first dimension on `r` and `b`, replicate the second dimension
t = torch.randn(16, 16).to('xla')
xs.mark_sharding(t, mesh, (('r', 'b'), None))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t),
"{devices=[2,1,%d]%s last_tile_dim_replicate}" %
(self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
self.n_devices // 2, self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

# Replicate the first dimension, shard the second on `b` and `m`
u = torch.randn(16, 16).to('xla')
xs.mark_sharding(u, mesh, (None, ('b', 'm')))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(u), "{devices=[1,%d]%s}" %
(self.n_devices, ','.join(str(x) for x in range(self.n_devices))))
annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation)

# Replicate the first dimension, shard the second on `r` and `m`
v = torch.randn(16, 16).to('xla')
xs.mark_sharding(v, mesh, (None, ('r', 'm')))
device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten()
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(v),
"{devices=[1,%d,2]%s last_tile_dim_replicate}" %
(self.n_devices // 2, ','.join(str(x) for x in device_order)))
annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
self.n_devices // 2, ','.join(str(x) for x in device_order))
if self.convert_to_shardy:
annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
self.n_devices // 2, self.n_devices // 2)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

# Replicate the first dimension, shard the second on `m` and `b`
v = torch.randn(16, 16).to('xla')
xs.mark_sharding(v, mesh, (None, ('m', 'b')))
device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(v), "{devices=[1,%d]%s}" %
(self.n_devices, ','.join(str(x) for x in device_order)))
annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
str(x) for x in device_order))
if self.convert_to_shardy:
annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices,
self.n_devices // 2)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
'Multiple devices required for tupled partition spec')
@@ -452,19 +471,25 @@ def test_multiple_tuples_in_spec(self):
('a', 'b', 'c', 'd'))
t = torch.randn(2, 2).to('xla')
xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t), "{devices=[2,%d]%s}" %
(self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

@unittest.skipUnless(xr.global_runtime_device_count() > 1,
'At least 2 devices needed for 2D mesh')
def test_3d_tensor_2d_mesh(self):
mesh = self._get_mesh((2, self.n_devices // 2))
t = torch.randn(16, 16, 16).to('xla')
xs.mark_sharding(t, mesh, (None, 0, 1))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(t), '{devices=[1,2,%d]%s}' %
(self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices))))
annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
str(x) for x in range(self.n_devices)))
if self.convert_to_shardy:
annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2,
self.n_devices)
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)

def test_partial_replication_addmm(self):
device = torch_xla.device()
@@ -983,18 +1008,20 @@ def test_op_sharding_cache(self):

t = torch.randn(1, self.n_devices).to('xla')
xs.mark_sharding(t, mesh, (0, 1))
self.assertIn("CreateOpSharding", met.counter_names())
self.assertEqual(met.counter_value("CreateOpSharding"), 1)
counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding"
self.assertIn(counter_name, met.counter_names())
self.assertEqual(met.counter_value(counter_name), 1)

# Sharding with the same partition spec should not result in another call
u = torch.randn(1, self.n_devices).to('xla')
xs.mark_sharding(u, mesh, (0, 1))
self.assertEqual(met.counter_value("CreateOpSharding"), 1)
self.assertEqual(met.counter_value(counter_name), 1)

# Changing the partition spec will result in another CreateOpSharding
# Changing the partition spec will result in another
# CreateOpSharding or CreateIotaOpSharding call
v = torch.randn(1, self.n_devices).to('xla')
xs.mark_sharding(v, mesh, (0, None))
self.assertEqual(met.counter_value("CreateOpSharding"), 2)
self.assertEqual(met.counter_value(counter_name), 2)

def test_from_cpu_shards_replicated(self):
from_cpu_shards = torch_xla._XLAC._global_tensor_from_cpu_shards
@@ -1401,10 +1428,10 @@ def test_data_loader_with_sharding(self):
input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
data, _ = iter(train_device_loader).__next__()
self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(data),
f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
)
annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
if self.convert_to_shardy:
annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

@unittest.skipUnless(
xr.global_runtime_device_count() > 1,
@@ -1424,10 +1451,10 @@ def test_data_loader_with_non_batch_size(self):
input_sharding=xs.ShardingSpec(mesh, ('data', None, None, None)))
data, _ = iter(train_device_loader).__next__()
self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
self.assertEqual(
torch_xla._XLAC._get_xla_sharding_spec(data),
f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
)
annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
if self.convert_to_shardy:
annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)

@unittest.skipUnless(
xr.global_runtime_device_count() > 1,
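
The updated expectations above use XLA's V2 (iota) sharding syntax, which the tests accept when CONVERT_SHLO_TO_SHARDY is set: <=[N] stands for the identity device order 0..N-1, and <=[a,b]T(1,0) means "reshape 0..N-1 to (a, b), transpose with permutation (1, 0), then flatten". A quick way to sanity-check the equivalence of the two spellings (illustration only; expand_iota is a hypothetical helper, not part of torch_xla):

    import numpy as np

    def expand_iota(reshape_dims, transpose_perm):
      # Expand a V2 iota description into the explicit V1 device order.
      return np.arange(np.prod(reshape_dims)).reshape(reshape_dims).transpose(
          transpose_perm).flatten().tolist()

    # {devices=[1,4]<=[2,2]T(1,0)} expands to {devices=[1,4]0,2,1,3}
    assert expand_iota([2, 2], [1, 0]) == [0, 2, 1, 3]
    # {devices=[1,4]<=[4]} is just the identity order
    assert expand_iota([4], [0]) == [0, 1, 2, 3]
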
56 changes: 43 additions & 13 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -703,6 +703,16 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode);
}

std::optional<xla::OpSharding> GetXLAOpSharding(const at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
XLATensor::ShardingSpecPtr sharding_spec =
xtensor ? xtensor->sharding_spec() : nullptr;
if (sharding_spec != nullptr) {
return sharding_spec->sharding;
}
return std::nullopt;
}

std::string GetXLAShardingSpec(const XLATensorPtr xtensor) {
auto sharding_spec = xtensor->sharding_spec();
if (sharding_spec != nullptr) {
@@ -1460,6 +1470,10 @@ at::Tensor tensor_fromDLPack(PyObject* data) {
void InitXlaModuleBindings(py::module m) {
PythonScope<py::module> module(m);

using TileAssignmentDims = std::vector<int64_t>;
using ReshapeDims = std::vector<int64_t>;
using TransposePerm = std::vector<int>;

// Define the _XLAC.XlaShardingSpec class.
PythonScope<py::class_<XLATensor::ShardingSpec, XLATensor::ShardingSpecPtr>>(
m, "XlaShardingSpec")
@@ -1477,12 +1491,12 @@ void InitXlaModuleBindings(py::module m) {
})
.def_init([](at::Tensor tensor, const py::list& dims,
const py::list& reshape_dims, const py::list& transpose_perm,
bool minibatch) {
const py::list& types, bool minibatch) {
xla::Shape global_shape =
ShardingUtil::GetAdjustedGlobalShape(tensor, minibatch);
return std::make_shared<XLATensor::ShardingSpec>(
ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
transpose_perm),
transpose_perm, types),
global_shape, minibatch);
});

@@ -1512,9 +1526,9 @@ void InitXlaModuleBindings(py::module m) {
})
// Constructor for V2 shardings.
.def_init([](const py::list& dims, const py::list& reshape_dims,
const py::list& transpose_perm) {
const py::list& transpose_perm, const py::list& types) {
return ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
transpose_perm);
transpose_perm, types);
});

// Define the _XLAC.PjRtPlugin class.
Expand Down Expand Up @@ -1742,7 +1756,8 @@ void InitXlaModuleBindings(py::module m) {
}
})
.def("_xla_get_runtime_devices",
[]() { return runtime::GetComputationClientOrDie()->GetLocalDevices(); })
[]() {
return runtime::GetComputationClientOrDie()->GetLocalDevices(); })
.def("_xla_num_runtime_devices",
[]() -> int64_t {
return runtime::GetComputationClientOrDie()->GetNumLocalDevices();
@@ -2154,9 +2169,11 @@ void InitXlaModuleBindings(py::module m) {
return device.ordinal();
})
.def("_xla_get_process_index",
[]() { return runtime::GetComputationClientOrDie()->GetProcessIndex(); })
[]() {
return runtime::GetComputationClientOrDie()->GetProcessIndex(); })
.def("_xla_get_num_processes",
[]() { return runtime::GetComputationClientOrDie()->GetNumProcesses(); })
[]() {
return runtime::GetComputationClientOrDie()->GetNumProcesses(); })
.def("_xla_get_num_cached_compilation_graph",
[]() -> int64_t {
return XLAGraphExecutor::Get()->GetNumGraphHash();
@@ -2579,13 +2596,26 @@ void InitXlaModuleBindings(py::module m) {
})
.def("_get_xla_op_sharding",
[](const at::Tensor& input) -> std::optional<xla::OpSharding> {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
XLATensor::ShardingSpecPtr sharding_spec =
xtensor ? xtensor->sharding_spec() : nullptr;
if (sharding_spec != nullptr) {
return sharding_spec->sharding;
return GetXLAOpSharding(input);
})
.def("_get_xla_op_sharding_v2_params",
[](const at::Tensor& input) -> std::optional<std::tuple<TileAssignmentDims, ReshapeDims, TransposePerm, bool>> {
std::optional<xla::OpSharding> maybe_sharding =
GetXLAOpSharding(input);
if (!maybe_sharding) {
return std::nullopt;
}
return std::nullopt;
const xla::OpSharding& sharding = maybe_sharding.value();
TileAssignmentDims tile_assignment_dims(
sharding.tile_assignment_dimensions().begin(),
sharding.tile_assignment_dimensions().end());
ReshapeDims reshape_dims(sharding.iota_reshape_dims().begin(),
sharding.iota_reshape_dims().end());
TransposePerm transpose_perm(sharding.iota_transpose_perm().begin(),
sharding.iota_transpose_perm().end());
return std::make_tuple(tile_assignment_dims, reshape_dims,
transpose_perm,
sharding.replicate_on_last_tile_dim());
})
.def("_get_xla_sharding_specs",
[](const std::vector<at::Tensor>& tensors)
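
For context, the new _get_xla_op_sharding_v2_params binding returns std::nullopt for an unsharded tensor and otherwise a (tile_assignment_dims, reshape_dims, transpose_perm, replicate_on_last_tile_dim) tuple. A rough sketch of how it might be queried from Python, assuming the usual pybind11 conversions (std::nullopt -> None, std::tuple -> Python tuple); the mesh construction here is illustrative, not prescribed by this PR:

    import numpy as np
    import torch
    import torch_xla
    import torch_xla.runtime as xr
    import torch_xla.distributed.spmd as xs

    xr.use_spmd()
    n = xr.global_runtime_device_count()
    mesh = xs.Mesh(np.arange(n), (n, 1))

    t = torch.randn(8, 8).to(torch_xla.device())
    xs.mark_sharding(t, mesh, (0, 1))

    params = torch_xla._XLAC._get_xla_op_sharding_v2_params(t)
    if params is not None:
      tile_dims, reshape_dims, transpose_perm, replicate_last_dim = params
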