@@ -62,8 +62,8 @@ ReduceContext GetReduceContext(absl::Span<const xla::XlaOp> operands) {
6262 return redux;
6363}
6464
65- xla::XlaComputation GetReduceComutation (AllReduceType reduce_type,
66- xla::PrimitiveType type) {
65+ xla::XlaComputation GetReduceComputation (AllReduceType reduce_type,
66+ xla::PrimitiveType type) {
6767 switch (reduce_type) {
6868 case AllReduceType::kSum :
6969 return XlaHelpers::CreateAddComputation (type);
@@ -153,14 +153,14 @@ std::vector<xla::XlaOp> BuildAllReduce(
153153 if (pin_layout) {
154154 reduce = xla::AllReduce (
155155 xla::Tuple (operands[0 ].builder (), type_ctx.second .ops ),
156- GetReduceComutation (reduce_type, type_ctx.first ), reduce_groups,
156+ GetReduceComputation (reduce_type, type_ctx.first ), reduce_groups,
157157 /* channel_id=*/ absl::nullopt ,
158158 /* shape_with_layout=*/
159159 MakeReduceShape (type_ctx.second .operand_shapes ));
160160 } else {
161161 reduce = xla::AllReduce (
162162 xla::Tuple (operands[0 ].builder (), type_ctx.second .ops ),
163- GetReduceComutation (reduce_type, type_ctx.first ), reduce_groups);
163+ GetReduceComputation (reduce_type, type_ctx.first ), reduce_groups);
164164 }
165165 for (size_t i = 0 ; i < type_ctx.second .indices .size (); ++i) {
166166 size_t op_idx = type_ctx.second .indices [i];
@@ -192,7 +192,7 @@ xla::XlaOp BuildAllReduce(AllReduceType reduce_type, xla::XlaOp input,
192192 channel_handle.set_handle (1 );
193193 channel_handle.set_type (xla::ChannelHandle::DEVICE_TO_DEVICE);
194194 auto reduce_result = xla::AllReduce (
195- input, GetReduceComutation (reduce_type, input_shape.element_type ()),
195+ input, GetReduceComputation (reduce_type, input_shape.element_type ()),
196196 std::move (reduce_groups), std::move (channel_handle), std::nullopt , true );
197197 if (scale != 1.0 ) {
198198 xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float >(
@@ -426,13 +426,13 @@ ReduceScatterResult BuildReduceScatter(
426426 static_cast <XlaDeviceType>(xla_device.type ()));
427427 reduce_result = xla::ReduceScatter (
428428 token_handler.GetInput (input, &input_shape),
429- GetReduceComutation (reduce_type, input_shape.element_type ()),
429+ GetReduceComputation (reduce_type, input_shape.element_type ()),
430430 scatter_dim, shard_count, reduce_groups, channel_handle,
431431 /* layout=*/ reduce_shape.layout (), use_global_device_ids);
432432 } else {
433433 reduce_result = xla::ReduceScatter (
434434 token_handler.GetInput (input, &input_shape),
435- GetReduceComutation (reduce_type, input_shape.element_type ()),
435+ GetReduceComputation (reduce_type, input_shape.element_type ()),
436436 scatter_dim, shard_count, reduce_groups, channel_handle,
437437 /* layout=*/ std::nullopt , use_global_device_ids);
438438 }
@@ -459,7 +459,7 @@ xla::XlaOp BuildReduceScatter(AllReduceType reduce_type, xla::XlaOp input,
459459 channel_handle.set_type (xla::ChannelHandle::DEVICE_TO_DEVICE);
460460 xla::XlaOp reduce_result;
461461 reduce_result = xla::ReduceScatter (
462- input, GetReduceComutation (reduce_type, input_shape.element_type ()),
462+ input, GetReduceComputation (reduce_type, input_shape.element_type ()),
463463 scatter_dim, shard_count, std::move (reduce_groups),
464464 std::move (channel_handle), std::nullopt , true );
465465 if (scale != 1.0 ) {
@@ -528,20 +528,20 @@ ReduceScatterResultCoalesced BuildReduceScatterCoalesced(
528528 if (pin_layout) {
529529 reduce_result = xla::ReduceScatter (
530530 xla::Tuple (inputs[0 ].builder (), type_ctx.second .ops ),
531- GetReduceComutation (reduce_type, type_ctx.first ), scatter_dim,
531+ GetReduceComputation (reduce_type, type_ctx.first ), scatter_dim,
532532 shard_count, cc_groups, /* channel_id=*/ absl::nullopt ,
533533 /* layout=*/
534534 MakeReduceShape (type_ctx.second .operand_shapes ).layout ());
535535 } else {
536536 reduce_result = xla::ReduceScatter (
537537 xla::Tuple (inputs[0 ].builder (), type_ctx.second .ops ),
538- GetReduceComutation (reduce_type, type_ctx.first ), scatter_dim,
538+ GetReduceComputation (reduce_type, type_ctx.first ), scatter_dim,
539539 shard_count, cc_groups);
540540 }
541541 for (size_t i = 0 ; i < type_ctx.second .indices .size (); ++i) {
542542 size_t op_idx = type_ctx.second .indices [i];
543543 xla::XlaOp gte;
544- if (ShapeHelper::ShapeOfXlaOp (reduce_result).dimensions_size () == 0 ) {
544+ if (ShapeHelper::ShapeOfXlaOp (reduce_result).IsTuple () ) {
545545 gte = xla::GetTupleElement (reduce_result, i);
546546 } else {
547547 gte = reduce_result;
0 commit comments