9
9
#include < algorithm>
10
10
#include < functional>
11
11
12
+ #include " absl/log/absl_check.h"
12
13
#include " absl/strings/str_cat.h"
13
14
#include " absl/strings/str_split.h"
14
15
#include " torch_xla/csrc/LazyIr.h"
@@ -1160,18 +1161,19 @@ std::vector<XLATensorPtr> broadcast_tensors(
1160
1161
return tensors.front ()->MakeOutputTensors (node);
1161
1162
}
1162
1163
1163
- XLATensorPtr cat ( absl::Span< const XLATensorPtr> tensors, int64_t dim,
1164
- at::ScalarType dtype) {
1164
+ absl::StatusOr<absl_nonnull XLATensorPtr> cat (
1165
+ absl::Span< const XLATensorPtr> tensors, int64_t dim, at::ScalarType dtype) {
1165
1166
// Shape checks for cat:
1166
1167
// - If not empty, every tensor shape must be the same.
1167
1168
// - Empty tensor passes but is simply ignore in implementation,
1168
1169
// e.g. ([2, 3, 5], [])
1169
1170
// - If empty dimension, other dimensions must be the same.
1170
1171
// e.g. ([4, 0, 32, 32], [4, 2, 32, 32], dim=1) passes.
1171
1172
// ([4, 0, 32, 32], [4, 2, 31, 32], dim=1) throws.
1172
- XLA_CHECK_GT (tensors.size (), 0 );
1173
+ ABSL_CHECK (tensors.size () > 0 );
1173
1174
std::vector<torch::lazy::Value> values;
1174
1175
std::vector<xla::Shape> shapes;
1176
+ size_t last_tensor_index;
1175
1177
for (size_t i = 0 ; i < tensors.size (); ++i) {
1176
1178
xla::Shape tensor_shape = tensors[i]->shape ();
1177
1179
if (tensor_shape.dimensions_size () == 1 &&
@@ -1181,13 +1183,20 @@ XLATensorPtr cat(absl::Span<const XLATensorPtr> tensors, int64_t dim,
1181
1183
dim = torch::lazy::GetCanonicalDimensionIndex (
1182
1184
dim, tensor_shape.dimensions_size ());
1183
1185
tensor_shape.DeleteDimension (dim);
1184
- if (!shapes.empty ()) {
1185
- XLA_CHECK (xla::ShapeUtil::CompatibleIgnoringElementType (shapes.back (),
1186
- tensor_shape))
1187
- << shapes.back () << " vs. " << tensor_shape;
1186
+ if (!shapes.empty () && !xla::ShapeUtil::CompatibleIgnoringElementType (
1187
+ shapes.back (), tensor_shape)) {
1188
+ auto last_tensor = tensors[last_tensor_index];
1189
+ auto tensor = tensors[i];
1190
+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (absl::StrCat (
1191
+ " cat(): cannot concatenate tensors of shape " ,
1192
+ last_tensor->shape ().get ().ToString (), " with " ,
1193
+ tensor->shape ().get ().ToString (), " at dimension " , dim,
1194
+ " . Expected shapes to be equal (except at dimension " , dim,
1195
+ " ) or that either of them was a 1D empty tensor of size (0,)." )));
1188
1196
}
1189
1197
shapes.push_back (tensor_shape);
1190
1198
values.push_back (tensors[i]->GetIrValue ());
1199
+ last_tensor_index = i;
1191
1200
}
1192
1201
if (values.empty ()) {
1193
1202
return tensors[0 ];
0 commit comments