Commit 938250c

Author: kvshbg-aws (committed)
feat: abstraction of xla::OpSharding proto using wrapper class
Parent: 93a5e58

20 files changed: +588 / -182 lines
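
The wrapper type itself never appears in this diff, so the sketch below reconstructs, purely from the call sites visible in this commit, what torch_xla/csrc/torch_xla_op_sharding.h plausibly declares: a two-argument constructor taking the xla::OpSharding proto plus an optional denormalized tile assignment, a GetXlaOpSharding() accessor, and a GetDenormalizedTileAssignment() accessor. Everything beyond those call sites, including the int64_t element type, the member names, and the omitted pybind-facing constructor, is an assumption and not the actual header.

// Hedged sketch only: reconstructed from call sites in this commit, not the
// real torch_xla/csrc/torch_xla_op_sharding.h.
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

#include "xla/xla_data.pb.h"  // defines the xla::OpSharding proto

namespace torch_xla {

class OpSharding {
 public:
  // Matches call sites such as:
  //   torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
  OpSharding(xla::OpSharding xla_op_sharding,
             std::optional<std::vector<int64_t>> denormalized_tile_assignment)
      : xla_op_sharding_(std::move(xla_op_sharding)),
        denormalized_tile_assignment_(
            std::move(denormalized_tile_assignment)) {}

  // Unwraps the raw proto for XLA APIs, e.g.
  //   xla::HloSharding::FromProto(sharding.GetXlaOpSharding());
  const xla::OpSharding& GetXlaOpSharding() const { return xla_op_sharding_; }

  // Used by _spmd_full_to_shard_shape below to carry the tile assignment
  // forward into a new sharding.
  const std::optional<std::vector<int64_t>>& GetDenormalizedTileAssignment()
      const {
    return denormalized_tile_assignment_;
  }

 private:
  xla::OpSharding xla_op_sharding_;
  std::optional<std::vector<int64_t>> denormalized_tile_assignment_;
};

}  // namespace torch_xla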

test/cpp/test_xla_sharding.cpp

Lines changed: 65 additions & 31 deletions
@@ -50,15 +50,18 @@ TEST_F(XLAShardingTest, GetShardShape) {
       {0, 1},
       {2, 3},
   });
-  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
 
   auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   // For tiled sharding, each dimension should be halved
   EXPECT_EQ(shard_shape, std::vector<int64_t>({4, 4}));
 
-  sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
+  xla_sharding = xla::HloSharding::Replicate().ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   // For replicated sharding, each dimension should be preserved
   EXPECT_EQ(shard_shape, std::vector<int64_t>({8, 7}));
@@ -74,7 +77,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
       {0, 1},
       {2, 3},
   });
-  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shard_shape = ShardingUtil::GetShardShape(sharding_spec);
@@ -103,7 +107,8 @@ TEST_F(XLAShardingTest, GetShardIndicesForDevices) {
       EXPECT_EQ(slice.step(), 1);
     }
   }
-  sharding = xla::HloSharding::Replicate().ToProto();
+  xla_sharding = xla::HloSharding::Replicate().ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
   sharding_spec->sharding = sharding;
   shard_shape = ShardingUtil::GetShardShape(sharding_spec);
   replica_and_indices = ShardingUtil::GetShardReplicaAndIndicesForDevices(
@@ -126,11 +131,12 @@ TEST_F(XLAShardingTest, ShardTensor) {
   at::Tensor tensor = at::ones({8}, at::TensorOptions(at::kFloat));
   xla::Shape tensor_shape =
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
-  xla::OpSharding sharding =
+  xla::OpSharding xla_sharding =
       xla::HloSharding::Tile1D(
           CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice()),
           devices.size())
          .ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   auto shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -148,7 +154,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
       {0, 1, 2, 3},
       {4, 5, 6, 7},
   });
-  sharding = xla::HloSharding::Tile(mesh).ToProto();
+  xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -160,15 +167,19 @@ TEST_F(XLAShardingTest, ShardTensor) {
   // 3D tiled, the first dim is replicated and the last halved. The last shard
   // size should be smaller in dim=1 because it's not evenly divisible.
   xla::Array3D<int64_t> cube({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}});
-  sharding_spec->sharding = xla::HloSharding::Tile(cube).ToProto();
+  xla_sharding = xla::HloSharding::Tile(cube).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
   EXPECT_EQ(shards.size(), 8);
   EXPECT_EQ(shards[0].sizes(), c10::ArrayRef<long>({8, 2, 2}));
   EXPECT_EQ(shards[7].sizes(), c10::ArrayRef<long>({8, 1, 2}));
 
   // Replicated, all shards should be identical.
-  sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();
+  xla_sharding = xla::HloSharding::Replicate().ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
   EXPECT_EQ(shards.size(), 8);
@@ -182,7 +193,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
   tensor_shape =
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
   xla::Array4D<int64_t> tesseract({{{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}});
-  sharding = xla::HloSharding::Tile(tesseract).ToProto();
+  xla_sharding = xla::HloSharding::Tile(tesseract).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -206,7 +218,8 @@ TEST_F(XLAShardingTest, ShardTensor) {
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
   xla::Array<int64_t> hypercube(std::vector<int64_t>{1, 1, 2, 2, 2});
   hypercube.FillIota(0);
-  sharding = xla::HloSharding::Tile(hypercube).ToProto();
+  xla_sharding = xla::HloSharding::Tile(hypercube).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
   sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
@@ -234,7 +247,8 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {4, 5, 0, 1},
       {6, 7, 2, 3},
   });
-  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec =
       std::make_shared<XLATensor::ShardingSpec>(sharding, tensor_shape);
   // For devices at the start of the mesh, all shards should have the same
@@ -251,7 +265,9 @@ TEST_F(XLAShardingTest, ShardTensorMultiHost) {
       {0, 1, 4, 5},
       {2, 3, 6, 7},
   });
-  sharding_spec->sharding = xla::HloSharding::Tile(mesh).ToProto();
+  xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  sharding_spec->sharding = sharding;
   shards = ShardingUtil::ShardTensor(tensor, sharding_spec, devices,
                                      /*padded=*/false);
   EXPECT_EQ(shards.size(), 4);
@@ -278,7 +294,8 @@ TEST_F(XLAShardingTest, ShardTensorMiniBatch) {
       {{7}},
   });
 
-  auto sharding = xla::HloSharding::Tile(mesh).ToProto();
+  auto xla_sharding = xla::HloSharding::Tile(mesh).ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
   auto sharding_spec = std::make_shared<XLATensor::ShardingSpec>(
       sharding, global_shape, /*minibatch=*/true);
   auto shards = ShardingUtil::ShardTensor(minibatch_tensor, sharding_spec,
@@ -292,17 +309,20 @@ TEST_F(XLAShardingTest, EqualShardingSpecs) {
   auto tensor = at::ones({8, 7}, at::TensorOptions(at::kFloat));
   xla::Shape tensor_shape =
       CreateComputationShapeFromTensor(tensor, bridge::GetDefaultDevice());
-  XLATensor::ShardingSpec tiled_2d(xla::HloSharding::Tile({
-                                       {0, 1, 2, 3},
-                                       {4, 5, 6, 7},
-                                   })
-                                       .ToProto(),
-                                   tensor_shape);
-  XLATensor::ShardingSpec tiled_3d(
-      xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto(),
-      tensor_shape);
-  XLATensor::ShardingSpec replicated(xla::HloSharding::Replicate().ToProto(),
-                                     tensor_shape);
+  auto xla_sharding = xla::HloSharding::Tile({
+                                                 {0, 1, 2, 3},
+                                                 {4, 5, 6, 7},
+                                             })
+                          .ToProto();
+  torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
+  XLATensor::ShardingSpec tiled_2d(sharding, tensor_shape);
+  xla_sharding =
+      xla::HloSharding::Tile({{{0, 1}, {2, 3}, {4, 5}, {6, 7}}}).ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  XLATensor::ShardingSpec tiled_3d(sharding, tensor_shape);
+  xla_sharding = xla::HloSharding::Replicate().ToProto();
+  sharding = torch_xla::OpSharding(xla_sharding, std::nullopt);
+  XLATensor::ShardingSpec replicated(sharding, tensor_shape);
   EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_2d));
   EXPECT_FALSE(ShardingUtil::EqualShardingSpecs(tiled_2d, tiled_3d));
   EXPECT_TRUE(ShardingUtil::EqualShardingSpecs(replicated, replicated));
@@ -323,12 +343,17 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
   std::vector<std::string> devices(3);
   std::fill_n(devices.begin(), devices.size(),
               bridge::GetDefaultDevice()->toString());
+  auto replicate_xla_sharding = xla::HloSharding::Replicate().ToProto();
+  auto unknown_xla_sharding = xla::HloSharding::Unknown().ToProto();
+  torch_xla::OpSharding replicate_sharding(replicate_xla_sharding,
+                                           std::nullopt);
+  torch_xla::OpSharding unknown_sharding(unknown_xla_sharding, std::nullopt);
   std::vector<XLATensor::ShardingSpecPtr> shardings = {
       nullptr,
-      std::make_shared<XLATensor::ShardingSpec>(
-          xla::HloSharding::Replicate().ToProto(), tensor_shape),
-      std::make_shared<XLATensor::ShardingSpec>(
-          xla::HloSharding::Unknown().ToProto(), tensor_shape)};
+      std::make_shared<XLATensor::ShardingSpec>(replicate_sharding,
+                                                tensor_shape),
+      std::make_shared<XLATensor::ShardingSpec>(unknown_sharding,
+                                                tensor_shape)};
   std::vector<torch::lazy::BackendDataPtr> tensors_data =
       CreateTensorsData(tensors, shardings, devices);
 
@@ -387,13 +412,21 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   auto y = xla::Add(x, xla::ConstantR0<float>(&b, 3));
   xla::XlaComputation xla_computation =
       GetValueOrThrow(b.Build(/*remove_dynamic_dimensions=*/false));
+
+  std::vector<torch::lazy::BackendDataPtr> parameters_data;
+  parameters_data.push_back(
+      torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
+          bridge::GetDefaultDevice()->toString(), std::move(shape)));
+
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
   instances.push_back({std::move(xla_computation),
                        bridge::GetDefaultDevice()->toString(),
                        {bridge::GetDefaultDevice()->toString()},
                        &shape,
                        /*should_wrap_parameter=*/false,
-                       /*is_sharded=*/true});
+                       /*is_sharded=*/true,
+                       /*allow_spmd_sharding_propagation_to_output=*/true,
+                       /*parameters_data=*/parameters_data});
 
   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -417,11 +450,12 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   if (n_devices > 1) {
     // Tiled sharding requires multiple devices.
     EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization(
-        tiled, sharding_specs[0]->sharding));
+        tiled, sharding_specs[0]->sharding.GetXlaOpSharding()));
   } else {
     // Single device execution defaults to replication sharding.
     EXPECT_TRUE(xla::protobuf_util::HaveSameSerialization(
-        xla::HloSharding::Replicate().ToProto(), sharding_specs[0]->sharding));
+        xla::HloSharding::Replicate().ToProto(),
+        sharding_specs[0]->sharding.GetXlaOpSharding()));
   }
 
   // Check if the placeholder is on a SPMD device (sharded) with no real values.
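
Every test change above applies the same mechanical migration. A minimal sketch of the pattern, assuming a sharding_spec already in scope; the std::nullopt argument is the wrapper's optional denormalized tile assignment, which these tests leave empty:

// Old pattern (before this commit): store the raw proto in the spec.
//   sharding_spec->sharding = xla::HloSharding::Replicate().ToProto();

// New pattern: wrap the proto in torch_xla::OpSharding first.
xla::OpSharding xla_sharding = xla::HloSharding::Replicate().ToProto();
torch_xla::OpSharding sharding(xla_sharding, std::nullopt);
sharding_spec->sharding = sharding;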

torch_xla/csrc/BUILD

Lines changed: 12 additions & 0 deletions
@@ -126,6 +126,7 @@ ptxla_cc_library(
         ":shape_builder",
         ":shape_helper",
         ":status",
+        ":torch_xla_op_sharding",
         ":version",
         "//torch_xla/csrc:hash_util",
         "//torch_xla/csrc:thread_pool",
@@ -313,6 +314,7 @@ ptxla_cc_library(
         ":shape_helper",
         ":status",
         ":unwrap_data",
+        ":torch_xla_op_sharding",
         "//torch_xla/csrc/runtime:cache",
         "//torch_xla/csrc/runtime:computation_client",
         "@com_google_absl//absl/log:absl_check",
@@ -382,3 +384,13 @@
         "@com_google_absl//absl/status:statusor",
     ],
 )
+
+cc_library(
+    name = "torch_xla_op_sharding",
+    srcs = ["torch_xla_op_sharding.cpp"],
+    hdrs = ["torch_xla_op_sharding.h"],
+    deps = [
+        "//torch_xla/csrc/runtime:debug_macros",
+        "@xla//xla/hlo/builder:xla_builder",
+    ],
+)

torch_xla/csrc/debug_util.cpp

Lines changed: 3 additions & 1 deletion
@@ -21,6 +21,7 @@
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
 #include "torch_xla/csrc/status.h"
+#include "torch_xla/csrc/torch_xla_op_sharding.h"
 #include "torch_xla/csrc/xla_graph_executor.h"
 
 namespace torch_xla {
@@ -218,7 +219,8 @@ void DebugUtil::SaveOutputShardingInfo(std::vector<XLATensorPtr>* tensors,
     auto xtensor = (*tensors)[indices[i]];
     ss << xtensor->shape().get().ToString() << " ";
     if (xtensor->sharding_spec()) {
-      ss << xla::HloSharding::FromProto(xtensor->sharding_spec()->sharding)
+      ss << xla::HloSharding::FromProto(
+                xtensor->sharding_spec()->sharding.GetXlaOpSharding())
                 ->ToString();
     } else {
       ss << xla::HloSharding::FromProto(xla::HloSharding::Unknown().ToProto())
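
The debug_util.cpp change shows the read side of the same migration: code that hands the spec to XLA helpers must now unwrap the proto first. A fragment, assuming a non-null sharding_spec in scope:

// Before: the spec held the raw proto.
//   auto hlo = xla::HloSharding::FromProto(sharding_spec->sharding);

// After: unwrap via GetXlaOpSharding() before calling XLA APIs.
auto hlo_sharding = xla::HloSharding::FromProto(
    sharding_spec->sharding.GetXlaOpSharding());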

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 14 additions & 10 deletions
@@ -69,6 +69,7 @@
 #include "torch_xla/csrc/tensor_methods.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/torch_util.h"
+#include "torch_xla/csrc/torch_xla_op_sharding.h"
 #include "torch_xla/csrc/version.h"
 #include "torch_xla/csrc/xla_backend_impl.h"
 #include "torch_xla/csrc/xla_graph_executor.h"
@@ -706,7 +707,8 @@ std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
 std::string GetXLAShardingSpec(const XLATensorPtr xtensor) {
   auto sharding_spec = xtensor->sharding_spec();
   if (sharding_spec != nullptr) {
-    auto hlo_sharding = xla::HloSharding::FromProto(sharding_spec->sharding);
+    auto hlo_sharding =
+        xla::HloSharding::FromProto(sharding_spec->sharding.GetXlaOpSharding());
     return hlo_sharding->ToString();
   }
   return std::string();
@@ -1503,7 +1505,7 @@ void InitXlaModuleBindings(py::module m) {
       runtime::ComputationClient::ComputationPtr>(m, "XlaComputation");
 
   // Define the _XLAC.OpSharding class.
-  PythonScope<py::class_<xla::OpSharding>>(m, "OpSharding")
+  PythonScope<py::class_<torch_xla::OpSharding>>(m, "OpSharding")
       .def_init([](const py::list& tile_assignment,
                    const py::list& group_assignment,
                    const py::list& replication_groups, int sharding_type) {
@@ -2268,6 +2270,7 @@ void InitXlaModuleBindings(py::module m) {
        [](const std::vector<at::Tensor>& tensors, const std::string& device,
           const std::vector<std::string>& devices,
           bool emit_bytecode) -> py::bytes {
+         NoGilSection nogil;
          EmitMode mode = emit_bytecode ? EmitMode::kStableHloBytecode
                                        : EmitMode::kStableHloReadable;
          std::vector<XLATensorPtr> xtensors;
@@ -2504,16 +2507,16 @@ void InitXlaModuleBindings(py::module m) {
            }
          })
     .def("_xla_mark_sharding",
-         [](const at::Tensor& input, xla::OpSharding sharding) {
+         [](const at::Tensor& input, torch_xla::OpSharding sharding) {
            ShardingUtil::XlaMarkSharding(input, sharding);
          })
     .def("_xla_annotate_custom_sharding",
-         [](const at::Tensor& input, xla::OpSharding sharding) {
+         [](const at::Tensor& input, torch_xla::OpSharding sharding) {
            XLATensorPtr xtensor = bridge::GetXlaTensor(input);
            ShardingUtil::XlaAnnotateCustomSharding(xtensor, sharding);
          })
     .def("_mark_manual_sharding",
-         [](const at::Tensor& input, xla::OpSharding sharding) {
+         [](const at::Tensor& input, torch_xla::OpSharding sharding) {
            XLA_CHECK(IsNonDeviceDataIR(input))
                << "Marking any data tensors as manual is not supported";
            ShardingUtil::XlaMarkSharding(input, sharding);
@@ -2533,13 +2536,14 @@ void InitXlaModuleBindings(py::module m) {
               xtensor->CreateFrom(torch_xla::MakeNode<CustomSharding>(
                   xtensor->GetIrValue(), shard_shape,
                   CustomSharding::Type::kSPMDFullToShardShape));
-          output->SetShardingSpec(XLATensor::ShardingSpec(
-              xla::HloSharding::Manual().ToProto(), shard_shape));
+          torch_xla::OpSharding sharding(xla::HloSharding::Manual().ToProto(),
+              sharding_spec->sharding.GetDenormalizedTileAssignment());
+          output->SetShardingSpec(XLATensor::ShardingSpec(sharding, shard_shape));
           return bridge::AtenFromXlaTensor(output);
         })
     .def(
         "_spmd_shard_to_full_shape",
-        [](const at::Tensor& input, const xla::OpSharding& sharding,
+        [](const at::Tensor& input, const torch_xla::OpSharding& sharding,
            const std::vector<int64_t>& output_shape,
            const py::object& output_dtype) -> at::Tensor {
           XLATensorPtr xtensor = bridge::GetXlaTensor(input);
@@ -2578,7 +2582,7 @@ void InitXlaModuleBindings(py::module m) {
           XLATensor::ShardingSpecPtr sharding_spec =
               xtensor ? xtensor->sharding_spec() : nullptr;
           if (sharding_spec != nullptr) {
-            return sharding_spec->sharding;
+            return sharding_spec->sharding.GetXlaOpSharding();
           }
           return std::nullopt;
         })
@@ -2613,7 +2617,7 @@ void InitXlaModuleBindings(py::module m) {
         // `torch_xla.runtime.local_runtime_devices()`.
         "_global_tensor_from_cpu_shards",
         [](const std::vector<at::Tensor>& shards,
-           const xla::OpSharding& sharding,
+           const torch_xla::OpSharding& sharding,
            std::optional<std::vector<int64_t>>& global_shape) -> at::Tensor {
           XLA_CHECK(UseVirtualDevice())
               << "Please enable SPMD via `torch_xla.runtime.use_spmd()`";
