
Commit 49b0b3b

Fix usage of dimensions_size() to check for tuple (#9347)
1 parent edd38ca · commit 49b0b3b

File tree

2 files changed: +12 −12 lines changed


test/test_mp_reduce_scatter.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def _mp_fn(index):
   shard_size = 2
   input_list_size = 5
 
-  if xm.xla_device_hw(device) in ['TPU', 'CUDA']:
+  if xm.xla_device_hw(device) in ['TPU', 'CUDA', 'CPU']:
     rand = torch.rand((32, shard_size * world_size, 32))
     xrand = rand.to(device)
 
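Side note: with 'CPU' added to the allowlist, the reduce-scatter test body now also runs on CPU-only setups. Presumably it can be exercised locally via torch_xla's usual device selection (for example PJRT_DEVICE=CPU), though the exact invocation is an assumption and not part of this commit.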

torch_xla/csrc/cross_replica_reduces.cpp

Lines changed: 11 additions & 11 deletions
@@ -62,8 +62,8 @@ ReduceContext GetReduceContext(absl::Span<const xla::XlaOp> operands) {
   return redux;
 }
 
-xla::XlaComputation GetReduceComutation(AllReduceType reduce_type,
-                                        xla::PrimitiveType type) {
+xla::XlaComputation GetReduceComputation(AllReduceType reduce_type,
+                                         xla::PrimitiveType type) {
   switch (reduce_type) {
     case AllReduceType::kSum:
       return XlaHelpers::CreateAddComputation(type);
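This hunk fixes a spelling error: GetReduceComutation becomes GetReduceComputation, and every call site below is updated to match. For orientation, the helper maps an AllReduceType onto an XLA scalar reducer computation; only the kSum arm is visible in this diff, so the sketch below fills in the remaining arms as assumptions modeled on the same XlaHelpers factory pattern.

// Sketch only: the kSum arm is taken from this hunk; the other arms and
// the error path are assumptions following torch_xla's XlaHelpers factories.
xla::XlaComputation GetReduceComputation(AllReduceType reduce_type,
                                         xla::PrimitiveType type) {
  switch (reduce_type) {
    case AllReduceType::kSum:
      return XlaHelpers::CreateAddComputation(type);  // from this hunk
    case AllReduceType::kMul:
      return XlaHelpers::CreateMulComputation(type);  // assumed
    case AllReduceType::kMin:
      return XlaHelpers::CreateMinComputation(type);  // assumed
    case AllReduceType::kMax:
      return XlaHelpers::CreateMaxComputation(type);  // assumed
    default:
      XLA_ERROR() << "Unsupported reduce type";       // assumed error path
  }
}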
@@ -153,14 +153,14 @@ std::vector<xla::XlaOp> BuildAllReduce(
     if (pin_layout) {
       reduce = xla::AllReduce(
           xla::Tuple(operands[0].builder(), type_ctx.second.ops),
-          GetReduceComutation(reduce_type, type_ctx.first), reduce_groups,
+          GetReduceComputation(reduce_type, type_ctx.first), reduce_groups,
           /*channel_id=*/absl::nullopt,
           /*shape_with_layout=*/
           MakeReduceShape(type_ctx.second.operand_shapes));
     } else {
       reduce = xla::AllReduce(
           xla::Tuple(operands[0].builder(), type_ctx.second.ops),
-          GetReduceComutation(reduce_type, type_ctx.first), reduce_groups);
+          GetReduceComputation(reduce_type, type_ctx.first), reduce_groups);
     }
     for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
       size_t op_idx = type_ctx.second.indices[i];
@@ -192,7 +192,7 @@ xla::XlaOp BuildAllReduce(AllReduceType reduce_type, xla::XlaOp input,
   channel_handle.set_handle(1);
   channel_handle.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
   auto reduce_result = xla::AllReduce(
-      input, GetReduceComutation(reduce_type, input_shape.element_type()),
+      input, GetReduceComputation(reduce_type, input_shape.element_type()),
       std::move(reduce_groups), std::move(channel_handle), std::nullopt, true);
   if (scale != 1.0) {
     xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
@@ -426,13 +426,13 @@ ReduceScatterResult BuildReduceScatter(
         static_cast<XlaDeviceType>(xla_device.type()));
     reduce_result = xla::ReduceScatter(
         token_handler.GetInput(input, &input_shape),
-        GetReduceComutation(reduce_type, input_shape.element_type()),
+        GetReduceComputation(reduce_type, input_shape.element_type()),
         scatter_dim, shard_count, reduce_groups, channel_handle,
         /*layout=*/reduce_shape.layout(), use_global_device_ids);
   } else {
     reduce_result = xla::ReduceScatter(
         token_handler.GetInput(input, &input_shape),
-        GetReduceComutation(reduce_type, input_shape.element_type()),
+        GetReduceComputation(reduce_type, input_shape.element_type()),
         scatter_dim, shard_count, reduce_groups, channel_handle,
         /*layout=*/std::nullopt, use_global_device_ids);
   }
@@ -459,7 +459,7 @@ xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input,
   channel_handle.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
   xla::XlaOp reduce_result;
   reduce_result = xla::ReduceScatter(
-      input, GetReduceComutation(reduce_type, input_shape.element_type()),
+      input, GetReduceComputation(reduce_type, input_shape.element_type()),
      scatter_dim, shard_count, std::move(reduce_groups),
      std::move(channel_handle), std::nullopt, true);
   if (scale != 1.0) {
@@ -528,20 +528,20 @@ ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
     if (pin_layout) {
       reduce_result = xla::ReduceScatter(
           xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
-          GetReduceComutation(reduce_type, type_ctx.first), scatter_dim,
+          GetReduceComputation(reduce_type, type_ctx.first), scatter_dim,
           shard_count, cc_groups, /*channel_id=*/absl::nullopt,
           /*layout=*/
           MakeReduceShape(type_ctx.second.operand_shapes).layout());
     } else {
       reduce_result = xla::ReduceScatter(
           xla::Tuple(inputs[0].builder(), type_ctx.second.ops),
-          GetReduceComutation(reduce_type, type_ctx.first), scatter_dim,
+          GetReduceComputation(reduce_type, type_ctx.first), scatter_dim,
           shard_count, cc_groups);
     }
     for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
       size_t op_idx = type_ctx.second.indices[i];
       xla::XlaOp gte;
-      if (ShapeHelper::ShapeOfXlaOp(reduce_result).dimensions_size() == 0) {
+      if (ShapeHelper::ShapeOfXlaOp(reduce_result).IsTuple()) {
         gte = xla::GetTupleElement(reduce_result, i);
       } else {
         gte = reduce_result;
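The final hunk is the fix named in the commit title. When several reduce-scatter outputs are coalesced, reduce_result is a tuple and each element must be unpacked with xla::GetTupleElement; with a single output it is a plain array op. The old probe relied on dimensions_size() returning 0 for tuple shapes, but dimensions_size() is an array-shape accessor (recent XLA versions check this), and a zero dimension count equally describes a rank-0 scalar array. IsTuple() asks the intended question directly. A minimal sketch of the distinction, assuming OpenXLA's xla::Shape / xla::ShapeUtil API (header paths vary across XLA versions):

#include "xla/shape.h"
#include "xla/shape_util.h"

// Illustration only, not part of this commit.
void TupleCheckExample() {
  // A rank-0 (scalar) f32 array shape: zero dimensions, yet not a tuple.
  xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
  // A genuine tuple shape, as produced for coalesced results.
  xla::Shape tuple = xla::ShapeUtil::MakeTupleShape(
      {xla::ShapeUtil::MakeShape(xla::F32, {8}),
       xla::ShapeUtil::MakeShape(xla::F32, {8})});

  // Old check: also true for a scalar, so xla::GetTupleElement could be
  // emitted on a non-tuple op and fail the builder's shape checks. Calling
  // dimensions_size() on `tuple` is itself questionable, since it is an
  // array-shape accessor.
  bool old_check = (scalar.dimensions_size() == 0);  // true, but misleading

  // New check: true only for actual tuples.
  bool new_check_tuple = tuple.IsTuple();    // true
  bool new_check_scalar = scalar.IsTuple();  // false
  (void)old_check; (void)new_check_tuple; (void)new_check_scalar;
}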
