Skip to content

Commit 8d72b9a

Browse files
committed
Fix visualize_tensor_sharding function for V2 shardings
1 parent ecada8b commit 8d72b9a

File tree

4 files changed

+144
-30
lines changed

4 files changed

+144
-30
lines changed

test/spmd/test_spmd_debugging.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import torch_xla.distributed.spmd as xs
1818
from torch_xla.distributed.spmd import XLAShardedTensor
1919
from torch_xla.distributed.spmd import Mesh
20+
from torch_xla.distributed.spmd.debugging import construct_v1_sharding_str
2021

2122
import test_xla_sharding_base
2223

@@ -822,6 +823,77 @@ def test_multi_host_replicated_cpu(self):
822823
fake_output = fake_capture.get()
823824
assert output == fake_output
824825

826+
827+
class ConvertV2ShardingToV1Test(test_xla_sharding_base.XlaShardingTest):
828+
829+
@classmethod
830+
def setUpClass(cls):
831+
super().setUpClass()
832+
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"
833+
834+
def run_test(self):
835+
mesh = self._get_mesh(self.device_mesh_shape)
836+
t = torch.randn(self.tensor_shape).to(torch_xla.device())
837+
xs.mark_sharding(t, mesh, self.partition_spec)
838+
actual_str = construct_v1_sharding_str(t)
839+
self.assertEqual(self.expected_str, actual_str)
840+
841+
def test_tiled_sharding(self):
842+
self.device_mesh_shape = (1, self.n_devices)
843+
self.tensor_shape = (1, 128)
844+
self.partition_spec = (0, 1)
845+
self.expected_str = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
846+
[str(i) for i in range(self.n_devices)]))
847+
self.run_test()
848+
849+
@unittest.skipIf(xr.global_runtime_device_count() < 2,
850+
f"Requires at least 2 devices.")
851+
def test_tupled_tiled_sharding(self):
852+
self.device_mesh_shape = (2, self.n_devices // 2)
853+
self.tensor_shape = (16,)
854+
self.partition_spec = ((0, 1),)
855+
self.expected_str = "{devices=[%d]%s}" % (self.n_devices, ','.join(
856+
str(x) for x in range(self.n_devices)))
857+
self.run_test()
858+
859+
def test_replicated_sharding(self):
860+
self.device_mesh_shape = (1, self.n_devices)
861+
self.tensor_shape = (4, 4)
862+
self.partition_spec = (None, None)
863+
self.expected_str = '{replicated}'
864+
self.run_test()
865+
866+
@unittest.skipIf(xr.global_runtime_device_count() < 4,
867+
f"Requires at least 4 devices.")
868+
def test_partial_replication_sharding(self):
869+
self.device_mesh_shape = (2, self.n_devices // 2)
870+
self.tensor_shape = (4, 4)
871+
self.partition_spec = (0, None)
872+
self.expected_str = '{devices=[2,1,%d]%s last_tile_dim_replicate}' % (
873+
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
874+
self.run_test()
875+
876+
@unittest.skipIf(xr.global_runtime_device_count() < 4,
877+
f"Requires at least 4 devices.")
878+
def test_tupled_partial_replication_sharding(self):
879+
self.device_mesh_shape = (1, 2, self.n_devices // 2)
880+
self.tensor_shape = (16, 16)
881+
self.partition_spec = ((0, 1), None)
882+
self.expected_str = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
883+
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
884+
self.run_test()
885+
886+
def test_tupled_partial_replication_sharding_with_transpose(self):
887+
self.device_mesh_shape = (1, 2, self.n_devices // 2)
888+
self.tensor_shape = (16, 16)
889+
self.partition_spec = (None, (2, 1))
890+
device_order = self.device_ids.reshape(self.device_mesh_shape).transpose(
891+
(2, 1, 0)).flatten()
892+
self.expected_str = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
893+
str(x) for x in device_order))
894+
self.run_test()
895+
896+
825897
if __name__ == '__main__':
826898
test = unittest.main()
827899
sys.exit(0 if test.result.wasSuccessful() else 1)

test/spmd/test_xla_sharding.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import unittest
77
from unittest.mock import patch
88
import sys
9-
import os
109

1110
import torch
1211
from torch import nn
@@ -27,16 +26,12 @@
2726
from torch_xla._internal import tpu
2827

2928

30-
def should_convert_to_shardy():
31-
return os.environ.get("CONVERT_SHLO_TO_SHARDY",
32-
"").lower() in ("1", "true", "yes")
33-
34-
3529
class BasicXlaShardingTest(test_xla_sharding_base.XlaShardingTest):
3630

3731
@classmethod
3832
def setUpClass(cls):
3933
super().setUpClass()
34+
cls.convert_to_shardy = xu.check_env_flag("CONVERT_SHLO_TO_SHARDY")
4035

4136
def test_xla_sharded_tensor(self):
4237
partition_spec = (0, 1)
@@ -244,7 +239,7 @@ def test_custom_tile_assignment(self):
244239
if self.n_devices > 1:
245240
annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
246241
[str(i) for i in reversed(range(self.n_devices))]))
247-
if should_convert_to_shardy():
242+
if self.convert_to_shardy:
248243
annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
249244
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))
250245

@@ -260,7 +255,7 @@ def test_mark_sharding_2d(self):
260255
if self.n_devices > 1:
261256
annotation = '{devices=[1,%d]%s}' % (self.n_devices, ','.join(
262257
[str(i) for i in range(self.n_devices)]))
263-
if should_convert_to_shardy():
258+
if self.convert_to_shardy:
264259
annotation = '{devices=[1,%d]<=[%d]}' % (self.n_devices, self.n_devices)
265260
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt1))
266261

@@ -281,7 +276,7 @@ def test_mark_sharding_4d(self):
281276
annotation = '{devices=[1,1,%d,%d]%s}' % (
282277
z_dim, self.n_devices // z_dim, ','.join(
283278
[str(i) for i in range(self.n_devices)]))
284-
if should_convert_to_shardy():
279+
if self.convert_to_shardy:
285280
annotation = '{devices=[1,1,%d,%d]<=[%d]}' % (z_dim, self.n_devices //
286281
z_dim, self.n_devices)
287282
self.assertEqual(annotation, torch_xla._XLAC._get_xla_sharding_spec(xt))
@@ -418,7 +413,7 @@ def test_tupled_partition_spec(self):
418413
xs.mark_sharding(t, mesh, ((0, 1),))
419414
annotation = "{devices=[%d]%s}" % (self.n_devices, ','.join(
420415
str(x) for x in range(self.n_devices)))
421-
if should_convert_to_shardy():
416+
if self.convert_to_shardy:
422417
annotation = "{devices=[%d]<=[%d]}" % (self.n_devices, self.n_devices)
423418
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
424419

@@ -432,7 +427,7 @@ def test_named_partial_tupled_partition_spec(self):
432427
xs.mark_sharding(t, mesh, (('r', 'b'), None))
433428
annotation = "{devices=[2,1,%d]%s last_tile_dim_replicate}" % (
434429
self.n_devices // 2, ','.join(str(x) for x in range(self.n_devices)))
435-
if should_convert_to_shardy():
430+
if self.convert_to_shardy:
436431
annotation = "{devices=[2,1,%d]<=[%d] last_tile_dim_replicate}" % (
437432
self.n_devices // 2, self.n_devices)
438433
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
@@ -442,7 +437,7 @@ def test_named_partial_tupled_partition_spec(self):
442437
xs.mark_sharding(u, mesh, (None, ('b', 'm')))
443438
annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
444439
str(x) for x in range(self.n_devices)))
445-
if should_convert_to_shardy():
440+
if self.convert_to_shardy:
446441
annotation = "{devices=[1,%d]<=[%d]}" % (self.n_devices, self.n_devices)
447442
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(u), annotation)
448443

@@ -452,7 +447,7 @@ def test_named_partial_tupled_partition_spec(self):
452447
device_order = mesh.get_logical_mesh().transpose((0, 2, 1)).flatten()
453448
annotation = "{devices=[1,%d,2]%s last_tile_dim_replicate}" % (
454449
self.n_devices // 2, ','.join(str(x) for x in device_order))
455-
if should_convert_to_shardy():
450+
if self.convert_to_shardy:
456451
annotation = "{devices=[1,%d,2]<=[2,%d]T(1,0) last_tile_dim_replicate}" % (
457452
self.n_devices // 2, self.n_devices // 2)
458453
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)
@@ -463,7 +458,7 @@ def test_named_partial_tupled_partition_spec(self):
463458
device_order = mesh.get_logical_mesh().transpose((2, 1, 0)).flatten()
464459
annotation = "{devices=[1,%d]%s}" % (self.n_devices, ','.join(
465460
str(x) for x in device_order))
466-
if should_convert_to_shardy():
461+
if self.convert_to_shardy:
467462
annotation = "{devices=[1,%d]<=[2,%d]T(1,0)}" % (self.n_devices,
468463
self.n_devices // 2)
469464
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(v), annotation)
@@ -478,7 +473,7 @@ def test_multiple_tuples_in_spec(self):
478473
xs.mark_sharding(t, mesh, (('a', 'b'), ('c', 'd')))
479474
annotation = "{devices=[2,%d]%s}" % (self.n_devices // 2, ','.join(
480475
str(x) for x in range(self.n_devices)))
481-
if should_convert_to_shardy():
476+
if self.convert_to_shardy:
482477
annotation = "{devices=[2,%d]<=[%d]}" % (self.n_devices // 2,
483478
self.n_devices)
484479
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
@@ -491,7 +486,7 @@ def test_3d_tensor_2d_mesh(self):
491486
xs.mark_sharding(t, mesh, (None, 0, 1))
492487
annotation = '{devices=[1,2,%d]%s}' % (self.n_devices // 2, ','.join(
493488
str(x) for x in range(self.n_devices)))
494-
if should_convert_to_shardy():
489+
if self.convert_to_shardy:
495490
annotation = '{devices=[1,2,%d]<=[%d]}' % (self.n_devices // 2,
496491
self.n_devices)
497492
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(t), annotation)
@@ -1013,8 +1008,7 @@ def test_op_sharding_cache(self):
10131008

10141009
t = torch.randn(1, self.n_devices).to('xla')
10151010
xs.mark_sharding(t, mesh, (0, 1))
1016-
counter_name = "CreateIotaOpSharding" if should_convert_to_shardy(
1017-
) else "CreateOpSharding"
1011+
counter_name = "CreateIotaOpSharding" if self.convert_to_shardy else "CreateOpSharding"
10181012
self.assertIn(counter_name, met.counter_names())
10191013
self.assertEqual(met.counter_value(counter_name), 1)
10201014

@@ -1435,7 +1429,7 @@ def test_data_loader_with_sharding(self):
14351429
data, _ = iter(train_device_loader).__next__()
14361430
self.assertEqual(data.size(), torch.Size([8, 3, 64, 64]))
14371431
annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
1438-
if should_convert_to_shardy():
1432+
if self.convert_to_shardy:
14391433
annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
14401434
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)
14411435

@@ -1458,7 +1452,7 @@ def test_data_loader_with_non_batch_size(self):
14581452
data, _ = iter(train_device_loader).__next__()
14591453
self.assertEqual(data.size(), torch.Size([mesh.size() - 1, 3, 64, 64]))
14601454
annotation = f"{{devices=[{mesh.size()},1,1,1]{','.join([str(i) for i in range(mesh.size())])}}}"
1461-
if should_convert_to_shardy():
1455+
if self.convert_to_shardy:
14621456
annotation = f"{{devices=[{mesh.size()},1,1,1]<=[{mesh.size()}]}}"
14631457
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(data), annotation)
14641458

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,6 @@
8383
#include "xla/pjrt/distributed/distributed.h"
8484
#include "xla/python/profiler/internal/traceme_wrapper.h"
8585

86-
#define PYBIND11_DETAILED_ERROR_MESSAGES
87-
8886
namespace torch_xla {
8987
namespace {
9088

@@ -762,6 +760,16 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
762760
return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode);
763761
}
764762

763+
std::optional<xla::OpSharding> GetXLAOpSharding(const at::Tensor& input) {
764+
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
765+
XLATensor::ShardingSpecPtr sharding_spec =
766+
xtensor ? xtensor->sharding_spec() : nullptr;
767+
if (sharding_spec != nullptr) {
768+
return sharding_spec->sharding;
769+
}
770+
return std::nullopt;
771+
}
772+
765773
std::string GetXLAShardingSpec(const XLATensorPtr xtensor) {
766774
auto sharding_spec = xtensor->sharding_spec();
767775
if (sharding_spec != nullptr) {
@@ -1528,6 +1536,10 @@ at::Tensor tensor_fromDLPack(PyObject* data) {
15281536
void InitXlaModuleBindings(py::module m) {
15291537
PythonScope<py::module> module(m);
15301538

1539+
using TileAssignmentDims = std::vector<int64_t>;
1540+
using ReshapeDims = std::vector<int64_t>;
1541+
using TransposePerm = std::vector<int>;
1542+
15311543
// Define the _XLAC.XlaShardingSpec class.
15321544
PythonScope<py::class_<XLATensor::ShardingSpec, XLATensor::ShardingSpecPtr>>(
15331545
m, "XlaShardingSpec")
@@ -2699,13 +2711,26 @@ void InitXlaModuleBindings(py::module m) {
26992711
})
27002712
.def("_get_xla_op_sharding",
27012713
[](const at::Tensor& input) -> std::optional<xla::OpSharding> {
2702-
XLA_ASSIGN_OR_THROW(XLATensorPtr xtensor, bridge::GetXlaTensor(input));
2703-
XLATensor::ShardingSpecPtr sharding_spec =
2704-
xtensor ? xtensor->sharding_spec() : nullptr;
2705-
if (sharding_spec != nullptr) {
2706-
return sharding_spec->sharding;
2714+
return GetXLAOpSharding(input);
2715+
})
2716+
.def("_get_xla_op_sharding_v2_params",
2717+
[](const at::Tensor& input) -> std::optional<std::tuple<TileAssignmentDims, ReshapeDims, TransposePerm, bool>> {
2718+
std::optional<xla::OpSharding> maybe_sharding =
2719+
GetXLAOpSharding(input);
2720+
if (!maybe_sharding) {
2721+
return std::nullopt;
27072722
}
2708-
return std::nullopt;
2723+
const xla::OpSharding& sharding = maybe_sharding.value();
2724+
TileAssignmentDims tile_assignment_dims(
2725+
sharding.tile_assignment_dimensions().begin(),
2726+
sharding.tile_assignment_dimensions().end());
2727+
ReshapeDims reshape_dims(sharding.iota_reshape_dims().begin(),
2728+
sharding.iota_reshape_dims().end());
2729+
TransposePerm transpose_perm(sharding.iota_transpose_perm().begin(),
2730+
sharding.iota_transpose_perm().end());
2731+
return std::make_tuple(tile_assignment_dims, reshape_dims,
2732+
transpose_perm,
2733+
sharding.replicate_on_last_tile_dim());
27092734
})
27102735
.def("_get_xla_sharding_specs",
27112736
[](const std::vector<at::Tensor>& tensors)

torch_xla/distributed/spmd/debugging.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import functools
33
import string
44
import sys
5-
from typing import Any, Callable, Optional, Union
5+
from typing import Any, Callable, Optional, Union, Tuple
66
import weakref
77

88
import numpy as np
@@ -157,12 +157,35 @@ def visualize_sharding(sharding: str,
157157
return table
158158

159159

160+
def construct_v1_sharding_str(t: torch.Tensor) -> str:
161+
"""
162+
Returns the corresponding HLO V1 sharding string from the tensor
163+
"""
164+
sharding = torch_xla._XLAC._get_xla_sharding_spec(t)
165+
if "<=" not in sharding:
166+
# This is already in the V1 format
167+
return sharding
168+
sharding_params = torch_xla._XLAC._get_xla_op_sharding_v2_params(t)
169+
assert sharding_params is not None
170+
tile_assignment_dims, reshape_dims, transpose_perm, replicate_on_last_dim = sharding_params
171+
num_devices = np.prod(reshape_dims)
172+
device_list = np.arange(num_devices).reshape(reshape_dims).transpose(
173+
transpose_perm).reshape(num_devices)
174+
175+
tile_assignment_str = ",".join(str(dim) for dim in tile_assignment_dims)
176+
device_list_str = ",".join(str(i) for i in device_list)
177+
replicate_str = " last_tile_dim_replicate" if replicate_on_last_dim else ""
178+
return f"{{devices=[{tile_assignment_str}]{device_list_str}{replicate_str}}}"
179+
180+
160181
def visualize_tensor_sharding(t, **kwargs):
161182
"""Visualizes an array's sharding."""
162183

163184
# XLAShardedTensor is-a torch.Tensor
164185
def maybe_unwrap(t: torch.Tensor) -> torch.Tensor:
165186
return t.global_tensor if isinstance(t, XLAShardedTensor) else t
166187

167-
sharding = torch_xla._XLAC._get_xla_sharding_spec(maybe_unwrap(t))
188+
t = maybe_unwrap(t)
189+
sharding = construct_v1_sharding_str(t)
190+
168191
return visualize_sharding(sharding, **kwargs)

0 commit comments

Comments
 (0)