
Commit 2c556e7

Improve error message for shape promotion on lowering. (pytorch#9486)
1 parent 8b95c5d commit 2c556e7

6 files changed (+49, -18 lines)

test/test_operations.py

Lines changed: 13 additions & 0 deletions

@@ -2445,6 +2445,19 @@ def test_isneginf_no_fallback(self):
     t = t.to(torch.float16)
     self._test_no_fallback(torch.isneginf, (t,))
 
+  def test_add_broadcast_error(self):
+    a = torch.rand(2, 2, 4, 4, device="xla")
+    b = torch.rand(2, 2, device="xla")
+
+    expected_regex = (
+        r"Shapes are not compatible for broadcasting: f32\[2,2,4,4\] vs. f32\[2,2\]. "
+        r"Expected dimension 2 of shape f32\[2,2,4,4\] \(4\) to match dimension "
+        r"0 of shape f32\[2,2\] \(2\). .*")
+
+    with self.assertRaisesRegex(RuntimeError, expected_regex):
+      torch.add(a, b)
+      torch_xla.sync()
+
 
 class MNISTComparator(nn.Module):

torch_xla/csrc/data_ops.cpp

Lines changed: 3 additions & 1 deletion

@@ -17,6 +17,7 @@
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/util.h"
 #include "torch_xla/csrc/shape_helper.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "xla/hlo/builder/lib/constants.h"
 #include "xla/hlo/builder/lib/slicing.h"
@@ -196,7 +197,8 @@ xla::XlaOp BuildMaskedFillScalar(xla::XlaOp input, xla::XlaOp mask,
   const xla::Shape& mask_shape = ShapeHelper::ShapeOfXlaOp(mask);
 
   if (!xla::ShapeUtil::Compatible(input_shape, mask_shape)) {
-    xla::Shape shape = XlaHelpers::GetPromotedShape(input_shape, mask_shape);
+    xla::Shape shape =
+        GetValueOrThrow(XlaHelpers::GetPromotedShape(input_shape, mask_shape));
     input = BuildExpand(input, shape.dimensions());
     mask = BuildExpand(mask, shape.dimensions());
   }

torch_xla/csrc/helpers.cpp

Lines changed: 24 additions & 11 deletions

@@ -582,8 +582,8 @@ void ExtractDimensionSizesAndDynamicDimensionsFromShape(
 
 }  // namespace
 
-xla::Shape XlaHelpers::GetPromotedShape(const xla::Shape& shape1,
-                                        const xla::Shape& shape2) {
+absl::StatusOr<xla::Shape> XlaHelpers::GetPromotedShape(
+    const xla::Shape& shape1, const xla::Shape& shape2) {
   std::vector<int64_t> dimensions;
   std::vector<bool> dynamic_dimensions;
 
@@ -606,20 +606,33 @@ xla::Shape XlaHelpers::GetPromotedShape(const xla::Shape& shape1,
   size_t min_size =
       std::min(shape1.dimensions().size(), shape2.dimensions().size());
   for (size_t i = 0; i < min_size; i++) {
-    int64_t dim1 =
-        shape1.dimensions()[shape1.dimensions().size() - min_size + i];
+    int64_t dim_index1 = shape1.dimensions().size() - min_size + i;
+    int64_t dim_index2 = shape2.dimensions().size() - min_size + i;
+    int64_t dim1 = shape1.dimensions()[dim_index1];
+    int64_t dim2 = shape2.dimensions()[dim_index2];
+
     int64_t dynamic_dim1 =
         shape1.dynamic_dimensions()[shape1.dynamic_dimensions().size() -
                                     min_size + i];
-    int64_t dim2 =
-        shape2.dimensions()[shape2.dimensions().size() - min_size + i];
     int64_t dynamic_dim2 =
         shape2.dynamic_dimensions()[shape2.dynamic_dimensions().size() -
                                     min_size + i];
 
-    XLA_CHECK(dim1 == dim2 || dim1 == 1 || dim2 == 1 ||
-              dim1 == xla::Shape::kUnboundedSize ||
-              dim2 == xla::Shape::kUnboundedSize);
+    if (dim1 != dim2 && dim1 != 1 && dim2 != 1 &&
+        dim1 != xla::Shape::kUnboundedSize &&
+        dim2 != xla::Shape::kUnboundedSize) {
+      auto shape_str1 = shape1.ToString();
+      auto shape_str2 = shape2.ToString();
+      auto message = absl::StrCat(
+          "Shapes are not compatible for broadcasting: ", shape_str1, " vs. ",
+          shape_str2, ". Expected dimension ", dim_index1, " of shape ",
+          shape_str1, " (", dim1, ") ", "to match dimension ", dim_index2,
+          " of shape ", shape_str2, " (", dim2, "). ",
+          "Either that or that any of them is either 1 or unbounded. ",
+          "Try reshaping one of the tensors to match the "
+          "other.");
+      return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(message));
+    }
 
     // TODO: Consider replacing the broadcasting logic below with
     // 'xla::ShapeInference::InferDegenerateDimensionBroadcastShape' resuing the
@@ -684,7 +697,7 @@ std::vector<int64_t> XlaHelpers::getBroadcastDimensions(xla::XlaOp op1,
 xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1,
                                                 const xla::Shape& shape2) {
   if (!shape1.is_dynamic() && !shape2.is_dynamic()) {
-    auto promoted_shape = GetPromotedShape(shape1, shape2);
+    auto promoted_shape = GetValueOrThrow(GetPromotedShape(shape1, shape2));
     return xla::ShapeUtil::MakeShape(
         PromoteType(shape1.element_type(), shape2.element_type()),
         promoted_shape.dimensions());
@@ -763,7 +776,7 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteShapes(xla::XlaOp op1,
   const xla::Shape& shape1 = ShapeHelper::ShapeOfXlaOp(op1);
   const xla::Shape& shape2 = ShapeHelper::ShapeOfXlaOp(op2);
 
-  xla::Shape shape = GetPromotedShape(shape1, shape2);
+  xla::Shape shape = GetValueOrThrow(GetPromotedShape(shape1, shape2));
   if (shape1.is_unbounded_dynamic() || shape2.is_unbounded_dynamic()) {
     return ImplicitBroadcastWithUnboundedDynamicShapes(op1, op2, shape);
   }
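
The heart of the commit is the loop above: the trailing dimensions of the two shapes are compared right-aligned, and the old XLA_CHECK crash is replaced by a descriptive InvalidArgumentError. Below is a minimal standalone sketch of that rule and message, not torch_xla code: ShapeStr and ReportBroadcastError are hypothetical helper names, and the shape formatting is simplified (the real code uses xla::Shape::ToString(), which also prints the element type, e.g. f32[2,2,4,4]).

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Simplified stand-in for xla::Shape::ToString(): renders {2,2,4,4} as
// "[2,2,4,4]" without the element type prefix.
std::string ShapeStr(const std::vector<int64_t>& dims) {
  std::ostringstream os;
  os << "[";
  for (size_t i = 0; i < dims.size(); ++i) {
    os << (i ? "," : "") << dims[i];
  }
  os << "]";
  return os.str();
}

// Right-aligned broadcasting check mirroring the new loop in
// XlaHelpers::GetPromotedShape: each trailing dimension pair must match,
// or one of the two must be 1 (unbounded dynamic sizes are omitted here).
// Returns true and prints the diagnostic when the shapes are incompatible.
bool ReportBroadcastError(const std::vector<int64_t>& s1,
                          const std::vector<int64_t>& s2) {
  size_t min_size = std::min(s1.size(), s2.size());
  for (size_t i = 0; i < min_size; ++i) {
    size_t idx1 = s1.size() - min_size + i;
    size_t idx2 = s2.size() - min_size + i;
    int64_t dim1 = s1[idx1];
    int64_t dim2 = s2[idx2];
    if (dim1 != dim2 && dim1 != 1 && dim2 != 1) {
      std::cout << "Shapes are not compatible for broadcasting: "
                << ShapeStr(s1) << " vs. " << ShapeStr(s2)
                << ". Expected dimension " << idx1 << " of shape "
                << ShapeStr(s1) << " (" << dim1 << ") to match dimension "
                << idx2 << " of shape " << ShapeStr(s2) << " (" << dim2
                << ").\n";
      return true;
    }
  }
  return false;
}

int main() {
  // The shapes from the new test: {2,2,4,4} vs. {2,2}. Right-aligned,
  // dimension 2 of the first shape (4) is compared against dimension 0
  // of the second (2), so the check fails exactly as the test expects.
  ReportBroadcastError({2, 2, 4, 4}, {2, 2});
}

Running this on the test's shapes reproduces the comparison the expected regex asserts: dimension 2 (value 4) against dimension 0 (value 2).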

torch_xla/csrc/helpers.h

Lines changed: 3 additions & 2 deletions

@@ -11,6 +11,7 @@
 #include <vector>
 
 #include "absl/status/status.h"
+#include "absl/status/statusor.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
@@ -302,8 +303,8 @@ class XlaHelpers {
                                xla::XlaOp op2);
 
   // Given the two shape 'shape1' and 'shape2', infers the broadcasted shape.
-  static xla::Shape GetPromotedShape(const xla::Shape& shape1,
-                                     const xla::Shape& shape2);
+  static absl::StatusOr<xla::Shape> GetPromotedShape(const xla::Shape& shape1,
+                                                     const xla::Shape& shape2);
 
   static xla::Shape GetPromotedDynamicShape(const xla::Shape& shape1,
                                             const xla::Shape& shape2);
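
This signature change from xla::Shape to absl::StatusOr<xla::Shape> is why every call site in this commit is wrapped in GetValueOrThrow. As a hedged illustration of that pattern (assuming Abseil is linked; ValueOrDie and ParsePositive are hypothetical names, not torch_xla or Abseil APIs):

#include <cstdlib>
#include <iostream>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical stand-in for GetValueOrThrow from torch_xla/csrc/status.h:
// unwrap the value on success, otherwise report the error and bail out.
// (The real helper surfaces the error to the caller as an exception.)
template <typename T>
T ValueOrDie(absl::StatusOr<T> result) {
  if (!result.ok()) {
    std::cerr << result.status() << "\n";
    std::abort();
  }
  return *std::move(result);
}

// Toy fallible function with the same shape as the new GetPromotedShape:
// it returns either a value or an absl::InvalidArgumentError.
absl::StatusOr<int> ParsePositive(int x) {
  if (x <= 0) {
    return absl::InvalidArgumentError("expected a positive value");
  }
  return x;
}

int main() {
  std::cout << ValueOrDie(ParsePositive(3)) << "\n";  // prints 3
  ValueOrDie(ParsePositive(-1));  // prints the error message, then aborts
}

The payoff of returning a status instead of hitting XLA_CHECK is that the failure becomes a recoverable error rather than a hard crash, which is how the new Python test can assert on it as a RuntimeError.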

torch_xla/csrc/ops/triangular_solve.cpp

Lines changed: 3 additions & 2 deletions

@@ -3,6 +3,7 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/lowering_context.h"
 #include "torch_xla/csrc/shape_helper.h"
+#include "torch_xla/csrc/status.h"
 #include "xla/hlo/builder/xla_builder.h"
 #include "xla/layout_util.h"
 
@@ -32,8 +33,8 @@ std::pair<xla::Shape, xla::Shape> InferTriangularSolveShape(
     return std::pair<xla::Shape, xla::Shape>(rhs_batch_shape, lhs_batch_shape);
   }
   // Obtain the promoted shapes and add back the trailing dimension.
-  xla::Shape rhs_batch_promoted_shape =
-      XlaHelpers::GetPromotedShape(rhs_batch_shape, lhs_batch_shape);
+  xla::Shape rhs_batch_promoted_shape = GetValueOrThrow(
+      XlaHelpers::GetPromotedShape(rhs_batch_shape, lhs_batch_shape));
   xla::Shape lhs_batch_promoted_shape(rhs_batch_promoted_shape);
   rhs_batch_promoted_shape.add_dimensions(nrhs);
   lhs_batch_promoted_shape.add_dimensions(n);

torch_xla/csrc/xla_lower_util.cpp

Lines changed: 3 additions & 2 deletions

@@ -63,7 +63,7 @@ ConditionMaskData CreateConditionMaskData(xla::XlaOp condition) {
 xla::XlaOp GetPromotedMask(xla::XlaOp mask, const xla::Shape& input_shape) {
   const xla::Shape& mask_shape = ShapeHelper::ShapeOfXlaOp(mask);
   xla::Shape promoted_mask_shape =
-      XlaHelpers::GetPromotedShape(mask_shape, input_shape);
+      GetValueOrThrow(XlaHelpers::GetPromotedShape(mask_shape, input_shape));
   return XlaHelpers::ImplicitBroadcast(mask, mask_shape, promoted_mask_shape);
 }
 
@@ -543,7 +543,8 @@ std::vector<xla::XlaOp> CreateBroadcastTensors(
   for (const xla::XlaOp operand : operands) {
     const xla::Shape& operand_shape = ShapeHelper::ShapeOfXlaOp(operand);
     operand_shapes.push_back(operand_shape);
-    result_shape = XlaHelpers::GetPromotedShape(result_shape, operand_shape);
+    result_shape = GetValueOrThrow(
+        XlaHelpers::GetPromotedShape(result_shape, operand_shape));
   }
   std::vector<xla::XlaOp> result;
   for (size_t i = 0; i < operands.size(); ++i) {
