@@ -582,8 +582,8 @@ void ExtractDimensionSizesAndDynamicDimensionsFromShape(
582
582
583
583
} // namespace
584
584
585
- xla::Shape XlaHelpers::GetPromotedShape (const xla::Shape& shape1,
586
- const xla::Shape& shape2) {
585
+ absl::StatusOr< xla::Shape> XlaHelpers::GetPromotedShape (
586
+ const xla::Shape& shape1, const xla::Shape& shape2) {
587
587
std::vector<int64_t > dimensions;
588
588
std::vector<bool > dynamic_dimensions;
589
589
@@ -606,20 +606,33 @@ xla::Shape XlaHelpers::GetPromotedShape(const xla::Shape& shape1,
606
606
size_t min_size =
607
607
std::min (shape1.dimensions ().size (), shape2.dimensions ().size ());
608
608
for (size_t i = 0 ; i < min_size; i++) {
609
- int64_t dim1 =
610
- shape1.dimensions ()[shape1.dimensions ().size () - min_size + i];
609
+ int64_t dim_index1 = shape1.dimensions ().size () - min_size + i;
610
+ int64_t dim_index2 = shape2.dimensions ().size () - min_size + i;
611
+ int64_t dim1 = shape1.dimensions ()[dim_index1];
612
+ int64_t dim2 = shape2.dimensions ()[dim_index2];
613
+
611
614
int64_t dynamic_dim1 =
612
615
shape1.dynamic_dimensions ()[shape1.dynamic_dimensions ().size () -
613
616
min_size + i];
614
- int64_t dim2 =
615
- shape2.dimensions ()[shape2.dimensions ().size () - min_size + i];
616
617
int64_t dynamic_dim2 =
617
618
shape2.dynamic_dimensions ()[shape2.dynamic_dimensions ().size () -
618
619
min_size + i];
619
620
620
- XLA_CHECK (dim1 == dim2 || dim1 == 1 || dim2 == 1 ||
621
- dim1 == xla::Shape::kUnboundedSize ||
622
- dim2 == xla::Shape::kUnboundedSize );
621
+ if (dim1 != dim2 && dim1 != 1 && dim2 != 1 &&
622
+ dim1 != xla::Shape::kUnboundedSize &&
623
+ dim2 != xla::Shape::kUnboundedSize ) {
624
+ auto shape_str1 = shape1.ToString ();
625
+ auto shape_str2 = shape2.ToString ();
626
+ auto message = absl::StrCat (
627
+ " Shapes are not compatible for broadcasting: " , shape_str1, " vs. " ,
628
+ shape_str2, " . Expected dimension " , dim_index1, " of shape " ,
629
+ shape_str1, " (" , dim1, " ) " , " to match dimension " , dim_index2,
630
+ " of shape " , shape_str2, " (" , dim2, " ). " ,
631
+ " Either that or that any of them is either 1 or unbounded. " ,
632
+ " Try reshaping one of the tensors to match the "
633
+ " other." );
634
+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (message));
635
+ }
623
636
624
637
// TODO: Consider replacing the broadcasting logic below with
625
638
// 'xla::ShapeInference::InferDegenerateDimensionBroadcastShape' resuing the
@@ -684,7 +697,7 @@ std::vector<int64_t> XlaHelpers::getBroadcastDimensions(xla::XlaOp op1,
684
697
xla::Shape XlaHelpers::GetPromotedBinaryOpShape (const xla::Shape& shape1,
685
698
const xla::Shape& shape2) {
686
699
if (!shape1.is_dynamic () && !shape2.is_dynamic ()) {
687
- auto promoted_shape = GetPromotedShape (shape1, shape2);
700
+ auto promoted_shape = GetValueOrThrow ( GetPromotedShape (shape1, shape2) );
688
701
return xla::ShapeUtil::MakeShape (
689
702
PromoteType (shape1.element_type (), shape2.element_type ()),
690
703
promoted_shape.dimensions ());
@@ -763,7 +776,7 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteShapes(xla::XlaOp op1,
763
776
const xla::Shape& shape1 = ShapeHelper::ShapeOfXlaOp (op1);
764
777
const xla::Shape& shape2 = ShapeHelper::ShapeOfXlaOp (op2);
765
778
766
- xla::Shape shape = GetPromotedShape (shape1, shape2);
779
+ xla::Shape shape = GetValueOrThrow ( GetPromotedShape (shape1, shape2) );
767
780
if (shape1.is_unbounded_dynamic () || shape2.is_unbounded_dynamic ()) {
768
781
return ImplicitBroadcastWithUnboundedDynamicShapes (op1, op2, shape);
769
782
}
0 commit comments