Commit 38e0f03

cat: improve error handling and error messages. (#9548)
1 parent 095faec commit 38e0f03

4 files changed: +36 -12 lines changed

test/test_operations.py

Lines changed: 13 additions & 0 deletions
@@ -2473,6 +2473,19 @@ def test_construct_large_tensor_raises_error(self):
     # OOM is raised when we try to bring data from the device.
     b.cpu()
 
+  def test_cat_raises_error_on_incompatible_shapes(self):
+    a = torch.rand(2, 2, device=torch_xla.device())
+    b = torch.rand(5, 1, device=torch_xla.device())
+
+    try:
+      torch.cat([a, b])
+    except RuntimeError as e:
+      expected_error = (
+          "cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] "
+          "at dimension 0. Expected shapes to be equal (except at dimension 0) "
+          "or that either of them was a 1D empty tensor of size (0,).")
+      self.assertEqual(str(e), expected_error)
+
 
 class MNISTComparator(nn.Module):

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 4 additions & 3 deletions
@@ -1314,9 +1314,10 @@ at::Tensor XLANativeFunctions::bmm(const at::Tensor& self,
 at::Tensor XLANativeFunctions::cat(const at::ITensorListRef& tensors,
                                    int64_t dim) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::cat(GetValueOrThrow(bridge::GetXlaTensors(tensors)), dim,
-                          at::native::result_type(tensors)));
+  auto xtensors = GetValueOrThrow(bridge::GetXlaTensors(tensors));
+  auto output = GetValueOrThrow(
+      tensor_methods::cat(xtensors, dim, at::native::result_type(tensors)));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::celu(const at::Tensor& self,
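
For context: tensor_methods::cat now returns a StatusOr, so this caller unwraps each fallible step with GetValueOrThrow, which converts a failed absl::Status into a C++ exception that Python ultimately sees as a RuntimeError (as the new test expects). A minimal sketch of that unwrap-or-throw pattern, with a hypothetical ValueOrThrow helper standing in for the repo's GetValueOrThrow:

#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical stand-in for torch_xla's GetValueOrThrow: return the value
// on success, throw on error. (The real helper presumably raises the
// c10-style exception that surfaces in Python as RuntimeError.)
template <typename T>
T ValueOrThrow(absl::StatusOr<T> maybe) {
  if (!maybe.ok()) {
    throw std::runtime_error(std::string(maybe.status().message()));
  }
  return *std::move(maybe);
}

// A fallible operation in the style of the new tensor_methods::cat.
absl::StatusOr<int> OnlyPositive(int x) {
  if (x <= 0) {
    return absl::InvalidArgumentError("expected a positive value");
  }
  return x;
}

int main() {
  int three = ValueOrThrow(OnlyPositive(3));  // returns 3
  int boom = ValueOrThrow(OnlyPositive(-1));  // throws std::runtime_error
  (void)three;
  (void)boom;
  return 0;
}

Splitting the call into two unwraps keeps each fallible step on its own line, so an error from bridge::GetXlaTensors is distinguishable from one produced by cat itself.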

torch_xla/csrc/tensor_methods.cpp

Lines changed: 16 additions & 7 deletions
@@ -9,6 +9,7 @@
 #include <algorithm>
 #include <functional>
 
+#include "absl/log/absl_check.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_split.h"
 #include "torch_xla/csrc/LazyIr.h"
@@ -1160,18 +1161,19 @@ std::vector<XLATensorPtr> broadcast_tensors(
   return tensors.front()->MakeOutputTensors(node);
 }
 
-XLATensorPtr cat(absl::Span<const XLATensorPtr> tensors, int64_t dim,
-                 at::ScalarType dtype) {
+absl::StatusOr<absl_nonnull XLATensorPtr> cat(
+    absl::Span<const XLATensorPtr> tensors, int64_t dim, at::ScalarType dtype) {
   // Shape checks for cat:
   // - If not empty, every tensor shape must be the same.
   // - Empty tensor passes but is simply ignore in implementation,
   //   e.g. ([2, 3, 5], [])
   // - If empty dimension, other dimensions must be the same.
   //   e.g. ([4, 0, 32, 32], [4, 2, 32, 32], dim=1) passes.
   //        ([4, 0, 32, 32], [4, 2, 31, 32], dim=1) throws.
-  XLA_CHECK_GT(tensors.size(), 0);
+  ABSL_CHECK(tensors.size() > 0);
   std::vector<torch::lazy::Value> values;
   std::vector<xla::Shape> shapes;
+  size_t last_tensor_index;
   for (size_t i = 0; i < tensors.size(); ++i) {
     xla::Shape tensor_shape = tensors[i]->shape();
     if (tensor_shape.dimensions_size() == 1 &&
@@ -1181,13 +1183,20 @@ XLATensorPtr cat(absl::Span<const XLATensorPtr> tensors, int64_t dim,
     dim = torch::lazy::GetCanonicalDimensionIndex(
         dim, tensor_shape.dimensions_size());
     tensor_shape.DeleteDimension(dim);
-    if (!shapes.empty()) {
-      XLA_CHECK(xla::ShapeUtil::CompatibleIgnoringElementType(shapes.back(),
-                                                              tensor_shape))
-          << shapes.back() << " vs. " << tensor_shape;
+    if (!shapes.empty() && !xla::ShapeUtil::CompatibleIgnoringElementType(
+                               shapes.back(), tensor_shape)) {
+      auto last_tensor = tensors[last_tensor_index];
+      auto tensor = tensors[i];
+      return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+          "cat(): cannot concatenate tensors of shape ",
+          last_tensor->shape().get().ToString(), " with ",
+          tensor->shape().get().ToString(), " at dimension ", dim,
+          ". Expected shapes to be equal (except at dimension ", dim,
+          ") or that either of them was a 1D empty tensor of size (0,).")));
     }
     shapes.push_back(tensor_shape);
     values.push_back(tensors[i]->GetIrValue());
+    last_tensor_index = i;
   }
   if (values.empty()) {
     return tensors[0];
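
This is the substantive change: the hard XLA_CHECK (which aborts with an internal-looking message) becomes an early return of absl::InvalidArgumentError carrying a user-oriented message, with XLA_ERROR_WITH_LOCATION apparently attaching the source location; last_tensor_index tracks the most recent non-skipped tensor so the message names the actual offending pair, and the branch only fires when shapes is non-empty, so last_tensor_index is always set by then. A simplified, self-contained sketch of this report-don't-crash pattern; Shape, ToString, CompatibleExceptAt, and CheckCatShapes are illustrative stand-ins, not repo APIs:

#include <cstdint>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"

// Minimal shape stand-in: a list of dimension sizes.
using Shape = std::vector<int64_t>;

std::string ToString(const Shape& s) {
  return absl::StrCat("[", absl::StrJoin(s, ","), "]");
}

// True when the shapes agree on every dimension except `dim` (assumes
// 0 <= dim < rank, i.e. an already-canonicalized dimension index).
bool CompatibleExceptAt(Shape a, Shape b, int64_t dim) {
  if (a.size() != b.size()) return false;
  a.erase(a.begin() + dim);
  b.erase(b.begin() + dim);
  return a == b;
}

// Report-don't-crash: instead of CHECK-failing on the first incompatible
// pair, return an InvalidArgumentError naming both offending shapes.
absl::Status CheckCatShapes(absl::Span<const Shape> shapes, int64_t dim) {
  for (size_t i = 1; i < shapes.size(); ++i) {
    if (!CompatibleExceptAt(shapes[i - 1], shapes[i], dim)) {
      return absl::InvalidArgumentError(absl::StrCat(
          "cat(): cannot concatenate tensors of shape ",
          ToString(shapes[i - 1]), " with ", ToString(shapes[i]),
          " at dimension ", dim));
    }
  }
  return absl::OkStatus();
}

With shapes {4,0,32,32} and {4,2,31,32} at dim=1, the sketch (like the real code) reports the mismatch instead of crashing, since dimensions other than 1 differ.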

torch_xla/csrc/tensor_methods.h

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 #ifndef XLA_TORCH_XLA_CSRC_TENSOR_METHODS_H_
 #define XLA_TORCH_XLA_CSRC_TENSOR_METHODS_H_
 
+#include "absl/base/nullability.h"
 #include "torch_xla/csrc/cross_replica_reduces.h"
 #include "torch_xla/csrc/ops/custom_sharding.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
@@ -307,8 +308,8 @@ XLATensorPtr bmm(const XLATensorPtr& batch1, const XLATensorPtr& batch2);
 std::vector<XLATensorPtr> broadcast_tensors(
     absl::Span<const XLATensorPtr> tensors);
 
-XLATensorPtr cat(absl::Span<const XLATensorPtr> tensors, int64_t dim,
-                 at::ScalarType dtype);
+absl::StatusOr<absl_nonnull XLATensorPtr> cat(
+    absl::Span<const XLATensorPtr> tensors, int64_t dim, at::ScalarType dtype);
 
 XLATensorPtr cdist_forward(const XLATensorPtr& x1, const XLATensorPtr& x2,
                            double p);
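
The new declaration also carries Abseil's absl_nonnull annotation (hence the added nullability.h include), documenting that the pointer inside a successful StatusOr is never null. A tiny sketch of the same annotation on a standard smart pointer; MakeCounter is a made-up example, not a repo API:

#include <memory>

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"

// absl_nonnull documents (and lets nullability-aware tooling verify) that
// the smart pointer returned on the success path is never null.
absl::StatusOr<absl_nonnull std::unique_ptr<int>> MakeCounter() {
  return std::make_unique<int>(0);
}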
