
Commit e0cedaf

[Experimental] Add initial implementation of GSPMD->Shardy pass within PyTorch/XLA (#1)
Adds an environment variable CONVERT_SHLO_TO_SHARDY that does two things:

- Uses V2 sharding annotations when generating the GSPMD StableHLO (SHLO) module (i.e., a V1 mesh annotation string like devices=[2,1,4]0,1,2,3,4,5,6,7 becomes devices=[2,1,4]<=[8] in V2).
- Converts the new GSPMD module with the V2 annotations into a Shardy module.
1 parent 763e5b7 commit e0cedaf
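
As an end-to-end usage sketch (the mesh shape, tensor, and partition spec below are illustrative; the commit-specific piece is the CONVERT_SHLO_TO_SHARDY variable, which the Python side reads via os.environ and the compile path reads via GetEnvBool, so it must be set before any sharding annotation or compilation):

```python
import os
# Use "1" so both the Python-side truthiness check and the C++ GetEnvBool agree.
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"

import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(list(range(n)), (n, 1), ('data', 'model'))

t = torch.randn(16, 128, device=xm.xla_device())
# With the flag set, mark_sharding routes through Mesh.get_op_sharding_v2 below.
xs.mark_sharding(t, mesh, ('data', 'model'))
```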

8 files changed: +112 -1 lines changed

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 7 additions & 0 deletions
@@ -1561,12 +1561,19 @@ void InitXlaModuleBindings(py::module m) {
 
   // Define the _XLAC.OpSharding class.
   PythonScope<py::class_<xla::OpSharding>>(m, "OpSharding")
+      // Constructor for V1 shardings
      .def_init([](const py::list& tile_assignment,
                   const py::list& group_assignment,
                   const py::list& replication_groups, int sharding_type) {
        return ShardingUtil::CreateOpSharding(
            tile_assignment, group_assignment, replication_groups,
            ShardingUtil::ShardingType(sharding_type));
+      })
+      // Constructor for V2 shardings.
+      .def_init([](const py::list& dims, const py::list& reshape_dims,
+                   const py::list& transpose_perm) {
+        return ShardingUtil::CreateIotaOpSharding(dims, reshape_dims,
+                                                  transpose_perm);
      });
 
  // Define the _XLAC.PjRtPlugin class.
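
For reference, a sketch of how the new three-argument overload is reached from Python (it mirrors the get_op_sharding_v2 call added later in this commit; the concrete values are illustrative):

```python
import torch_xla

# dims, reshape_dims, transpose_perm -> ShardingUtil::CreateIotaOpSharding.
# With an identity transpose this should describe the iota tiling the commit
# message writes as devices=[2,1,4]<=[8].
v2_sharding = torch_xla._XLAC.OpSharding([2, 1, 4], [2, 1, 4], [0, 1, 2])
```

The existing four-argument form (tile_assignment, group_assignment, replication_groups, sharding_type) still resolves to the V1 constructor.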

torch_xla/csrc/runtime/BUILD

Lines changed: 1 addition & 0 deletions
@@ -366,6 +366,7 @@ cc_library(
         "@xla//xla/mlir_hlo:all_passes",
         "@xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo",
         "@xla//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
+        "@xla//xla/service/spmd/shardy/stablehlo_round_trip:stablehlo_import",
     ],
 )
 
torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,7 @@
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/pjrt_registry.h"
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
+#include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/tensor_source.h"
 #include "torch_xla/csrc/runtime/tf_logging.h"
 #include "torch_xla/csrc/runtime/util.h"
@@ -638,6 +639,9 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
     mlir::ModuleOp mlir_module =
         mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
     ConvertHloToStableHlo(instance.computation.mutable_proto(), &mlir_module);
+    if (runtime::sys_util::GetEnvBool("CONVERT_SHLO_TO_SHARDY", false)) {
+      ConvertStableHloToSdy(&mlir_module);
+    }
     executable = util::RaisePythonValueErrorOnFailure([&] {
       return fake_xla_compile_
                  ? fake_xla_compile_()
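
Because the flag is read with GetEnvBool inside Compile, it has to be present in the environment before the first graph compilation. A minimal sketch (the computation itself is arbitrary):

```python
import os
os.environ["CONVERT_SHLO_TO_SHARDY"] = "1"  # checked again here, at compile time

import torch
import torch_xla.core.xla_model as xm

x = torch.ones(4, 4, device=xm.xla_device())
y = x @ x
# mark_step forces compilation; with the flag set, the StableHLO module is
# converted to a Shardy (SDY) module inside PjRtComputationClient::Compile.
xm.mark_step()
```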

torch_xla/csrc/runtime/stablehlo_helper.cpp

Lines changed: 10 additions & 0 deletions
@@ -18,6 +18,7 @@
 #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h"
 #include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
 #include "xla/mlir_hlo/mhlo/transforms/passes.h"
+#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h"
 
 namespace torch_xla {
 
@@ -89,6 +90,7 @@ static absl::Status mhloToStablehloHelper(mlir::ModuleOp* mlir_module,
       torch_xla::runtime::CreateRemoveXlaMarkTensorOpsPass());
   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
+
   if (!mlir::succeeded(pm.run(*mlir_module))) {
     return absl::Status(
         absl::StatusCode::kInternal,
@@ -111,6 +113,14 @@ void ConvertHloToStableHlo(const xla::HloModuleProto* proto,
       << getHloModuleStr(proto);
 }
 
+void ConvertStableHloToSdy(mlir::ModuleOp* mlir_module) {
+  mlir::PassManager pm(mlir_module->getContext());
+  xla::sdy::addStablehloImportPipeline(pm, false, false);
+  if (!mlir::succeeded(pm.run(*mlir_module))) {
+    XLA_ERROR() << "StableHLO -> SDY conversion failed.\n";
+  }
+}
+
 std::string hloToStablehlo(const xla::HloModuleProto* proto,
                            bool emit_bytecode) {
   mlir::MLIRContext context;

torch_xla/csrc/runtime/stablehlo_helper.h

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ namespace torch_xla {
 std::string hloToStablehlo(const xla::HloModuleProto* proto,
                            bool emit_bytecode);
 
+void ConvertStableHloToSdy(mlir::ModuleOp* mlir_module);
+
 void ConvertHloToStableHlo(const xla::HloModuleProto* proto,
                            mlir::ModuleOp* mlir_module);
 
torch_xla/csrc/xla_sharding_util.cpp

Lines changed: 17 additions & 0 deletions
@@ -218,6 +218,23 @@ bool ShardingUtil::EqualOpShardings(const xla::OpSharding& a,
   return xla::protobuf_util::HaveSameSerialization(a, b);
 }
 
+xla::OpSharding ShardingUtil::CreateIotaOpSharding(
+    const py::list& dims, const py::list& reshape_dims,
+    const py::list& transpose_perm) {
+  auto dims_vec = dims.cast<std::vector<int64_t>>();
+  auto reshape_dims_vec = reshape_dims.cast<std::vector<int64_t>>();
+  auto transpose_perm_vec = transpose_perm.cast<std::vector<int>>();
+  std::vector<xla::OpSharding::Type> subgroup_types;
+  if (dims_vec.size() > transpose_perm.size()) {
+    subgroup_types.push_back(xla::OpSharding::REPLICATED);
+  }
+  return xla::HloSharding::Subgroup(
+             xla::TileAssignment(dims_vec, reshape_dims_vec,
+                                 transpose_perm_vec),
+             subgroup_types)
+      .ToProto();
+}
+
 xla::OpSharding ShardingUtil::CreateOpSharding(
     const py::list& tile_assignment, const py::list& group_assignment,
     const py::list& replication_groups, ShardingType sharding_type) {
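
The REPLICATED branch above is easiest to see with a direct constructor call (a sketch; the values are illustrative and not taken from the commit):

```python
import torch_xla

# dims has one more entry than transpose_perm, so CreateIotaOpSharding appends a
# REPLICATED subgroup type: the trailing tile dimension (size 4) becomes the
# replication subgroup, i.e. a partially replicated (PARTIAL) sharding.
partial = torch_xla._XLAC.OpSharding([2, 1, 4], [2, 4], [0, 1])
```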

torch_xla/csrc/xla_sharding_util.h

Lines changed: 5 additions & 0 deletions
@@ -51,6 +51,11 @@ class ShardingUtil {
                                          const py::list& group_assignment,
                                          const py::list& replication_groups,
                                          ShardingType sharding_type);
+  // Creates an xla::OpSharding for TILED and PARTIAL types using the
+  // HloShardingV2 system.
+  static xla::OpSharding CreateIotaOpSharding(const py::list& dims,
+                                              const py::list& reshape_dims,
+                                              const py::list& transpose_perm);
 
   // Returns the shape of the resulting shards of `tensor` after applying
   // `sharding`. This assumes the shards will be padded to ensure they all

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 66 additions & 1 deletion
@@ -1,6 +1,7 @@
 import collections
 from collections.abc import Generator, MutableMapping
 import math
+import os
 from collections import OrderedDict, defaultdict
 from dataclasses import dataclass, field
 import torch
@@ -118,9 +119,18 @@ def get_axis_name_idx(self, name: str) -> int:
       return None
     return self.axis_names.index(name)
 
+  def _validate_translated_partition_spec(self, partition_spec: tuple):
+    flat_specs = np.hstack([d for d in partition_spec])
+    specs = [d for d in flat_specs if d is not None]
+    assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
+        f"partition_spec ({partition_spec}) contains out of bound index into mesh_shape."
+    assert len(specs) == len(np.unique(specs)), \
+        f"Each device mesh dimension should appear at most once in partition_spec {partition_spec}."
+
   @functools.lru_cache(maxsize=None)
   def _get_op_sharding_args(self, partition_spec: PartitionSpec):
     partition_spec = _translate_named_partition_spec(self, partition_spec)
+    self._validate_translated_partition_spec(partition_spec)
     flat_specs = np.hstack([d for d in partition_spec])
     specs = [d for d in flat_specs if d is not None]
     assert all(d >= 0 and d < len(self.mesh_shape) for d in specs), \
@@ -142,6 +152,57 @@ def _get_op_sharding_args(self, partition_spec: PartitionSpec):
     sharding_type = int(sharding_type)
     return tile_assignment, group_assignment, replication_groups, sharding_type
 
+  @functools.lru_cache(maxsize=None)
+  def _get_op_sharding_args_v2(self, partition_spec: PartitionSpec):
+    """
+    Returns the appropriate dims, reshape_dims, and transpose_perm for the given partition spec.
+    """
+    partition_spec = _translate_named_partition_spec(self, partition_spec)
+    self._validate_translated_partition_spec(partition_spec)
+
+    dims = []
+    used_axes = OrderedDict()
+    for axis in partition_spec:
+      if isinstance(axis, tuple):
+        dim_size = 1
+        for i in axis:
+          assert i is not None, "None not allowed within tuple"
+          dim_size *= self.mesh_shape[i]
+          used_axes[i] = True
+        dims.append(dim_size)
+      elif axis is not None:
+        assert isinstance(axis, int), "Axis must be an int or a tuple of ints"
+        dims.append(self.mesh_shape[axis])
+        used_axes[axis] = True
+      else:
+        # Replicated mesh axis
+        dims.append(1)
+
+    transpose_perm = [k for k in used_axes.keys()]
+    for i in range(len(self.mesh_shape)):
+      if i not in used_axes:
+        dims.append(self.mesh_shape[i])
+        transpose_perm.append(i)
+    reshape_dims = list(self.mesh_shape)
+
+    return dims, reshape_dims, transpose_perm
+
+  @functools.lru_cache(maxsize=None)
+  def get_op_sharding_v2(
+      self, partition_spec: PartitionSpec) -> torch_xla._XLAC.OpSharding:
+    """
+    Return the OpSharding for the given partition spec using V2 annotations.
+    """
+    if len(partition_spec) == 0:
+      return torch_xla._XLAC.OpSharding([], [], [], ShardingType.REPLICATED)
+    sharding_type = _get_sharding_type(partition_spec, self.size())
+    if sharding_type not in (ShardingType.TILED, ShardingType.PARTIAL):
+      return torch_xla._XLAC.OpSharding([], [], [0], sharding_type)
+
+    dims, reshape_dims, transpose_perm = self._get_op_sharding_args_v2(
+        partition_spec)
+    return torch_xla._XLAC.OpSharding(dims, reshape_dims, transpose_perm)
+
   @functools.lru_cache(maxsize=None)
   def get_op_sharding(
       self, partition_spec: PartitionSpec) -> torch_xla._XLAC.OpSharding:
@@ -157,6 +218,7 @@ def get_op_sharding(
 
     tile_assignment, group_assignment, replication_groups, sharding_type = self._get_op_sharding_args(
         partition_spec)
+
     return torch_xla._XLAC.OpSharding(tile_assignment, group_assignment,
                                       replication_groups, sharding_type)
 
@@ -653,7 +715,10 @@ def mark_sharding(t: Union[torch.Tensor, XLAShardedTensor], mesh: Mesh,
     t.shard_(NamedSharding(jmesh, P(*partition_spec)))
     return t
 
-  op_sharding = mesh.get_op_sharding(partition_spec)
+  if os.environ.get('CONVERT_SHLO_TO_SHARDY', False):
+    op_sharding = mesh.get_op_sharding_v2(partition_spec)
+  else:
+    op_sharding = mesh.get_op_sharding(partition_spec)
   annotate_func = torch_xla._XLAC._xla_mark_sharding
   annotate_func(unwrap_sharded_tensor(t), op_sharding)
   # Pass mesh and partition spec information for DTensor compatibility
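
To make the V2 argument construction concrete, a worked example for a hypothetical 2x4 mesh (assuming an 8-device SPMD runtime); the tuples in the comments follow directly from the _get_op_sharding_args_v2 logic above:

```python
import torch_xla.distributed.spmd as xs

# Hypothetical mesh: mesh_shape = (2, 4), axis 0 is 'data', axis 1 is 'model'.
mesh = xs.Mesh(list(range(8)), (2, 4), ('data', 'model'))

# ('data', 'model')    -> dims=[2, 4],    reshape_dims=[2, 4], transpose_perm=[0, 1]
# ('model', 'data')    -> dims=[4, 2],    reshape_dims=[2, 4], transpose_perm=[1, 0]
# (('data', 'model'),) -> dims=[8],       reshape_dims=[2, 4], transpose_perm=[0, 1]
# ('data', None)       -> dims=[2, 1, 4], reshape_dims=[2, 4], transpose_perm=[0, 1]
#   The unused 'model' axis is appended to dims; since len(dims) now exceeds
#   len(transpose_perm), CreateIotaOpSharding adds a REPLICATED subgroup,
#   i.e. a PARTIAL sharding.
dims, reshape_dims, transpose_perm = mesh._get_op_sharding_args_v2(('data', None))
```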
