
Commit 8243a25

gather: improve error handling and error messages. (#9566)
This PR refactors the `tensor_methods::gather` implementation by improving its error messages and returning a status type value.

**Key Changes:**

- Make `tensor_methods::gather` return `StatusOr<absl_nonnull XLATensorPtr>`
- Improve the error message on incompatible tensor shapes
1 parent: 147d2c2
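To illustrate the status-based error handling this commit adopts, here is a minimal, self-contained sketch using plain Abseil. Everything in it (FakeTensor, CheckRanksAreEqual, Gather) is an illustrative stand-in, not torch_xla code: a validation helper returns absl::Status, the operation returns absl::StatusOr<T>, and the caller inspects the status instead of the process aborting on a failed check.

// Minimal sketch of the StatusOr-based error handling adopted by this
// commit, using plain Abseil. FakeTensor, CheckRanksAreEqual, and Gather
// are illustrative stand-ins, not torch_xla code.
#include <cstdint>
#include <iostream>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"

// Stand-in for a tensor; only the rank matters for this sketch.
struct FakeTensor {
  std::vector<int64_t> sizes;
  int64_t rank() const { return static_cast<int64_t>(sizes.size()); }
};

// Validation helper: returns an error status instead of aborting.
absl::Status CheckRanksAreEqual(const FakeTensor& input,
                                const FakeTensor& index) {
  if (input.rank() != index.rank()) {
    return absl::InvalidArgumentError(
        absl::StrCat("gather(): expected rank of input (", input.rank(),
                     ") and index (", index.rank(),
                     ") tensors to be the same."));
  }
  return absl::OkStatus();
}

// The operation returns StatusOr<T>: either a value or an error status.
absl::StatusOr<FakeTensor> Gather(const FakeTensor& input,
                                  const FakeTensor& index) {
  // torch_xla spells this step XLA_RETURN_IF_ERROR; written out here.
  if (absl::Status status = CheckRanksAreEqual(input, index); !status.ok()) {
    return status;
  }
  return index;  // Placeholder for the real lowering.
}

int main() {
  FakeTensor input{{2, 2}};
  FakeTensor index{{2, 2, 2}};
  absl::StatusOr<FakeTensor> result = Gather(input, index);
  if (!result.ok()) {
    // The caller decides how to surface the error (throw, log, ...).
    std::cout << result.status() << "\n";
  }
  return 0;
}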

File tree: 4 files changed (+82, -15 lines)


test/test_operations.py

Lines changed: 33 additions & 0 deletions
@@ -2521,6 +2521,39 @@ def test_full_raises_error_on_negative_size(self):
           f"positive values. However found negative ones: {shape}.")
       self.assertEqual(str(e), expected_error)
 
+  def test_gather_raises_error_on_rank_mismatch(self):
+    S = 2
+
+    input = torch.arange(4, device=torch_xla.device()).view(S, S)
+    index = torch.randint(0, S, (S, S, S), device=torch_xla.device())
+    dim = 1
+
+    try:
+      torch.gather(input, dim, index)
+    except RuntimeError as e:
+      expected_error = (
+          "gather(): expected rank of input (2) and index (3) tensors "
+          "to be the same.")
+      self.assertEqual(str(e), expected_error)
+
+  def test_gather_raises_error_on_invalid_index_size(self):
+    S = 2
+    X = S + 2
+
+    input = torch.arange(16, device=torch_xla.device()).view(S, S, S, S)
+    index = torch.randint(0, S, (X, S, X, S), device=torch_xla.device())
+    dim = 1
+
+    try:
+      torch.gather(input, dim, index)
+    except RuntimeError as e:
+      expected_error = (
+          f"gather(): expected sizes of index [{X}, {S}, {X}, {S}] to be "
+          f"smaller or equal those of input [{S}, {S}, {S}, {S}] on all "
+          f"dimensions, except on dimension {dim}. "
+          "However, that's not true on dimensions [0, 2].")
+      self.assertEqual(str(e), expected_error)
+
 
 class MNISTComparator(nn.Module):

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 2 additions & 2 deletions
@@ -1865,9 +1865,9 @@ at::Tensor XLANativeFunctions::gather(const at::Tensor& self, int64_t dim,
                                        const at::Tensor& index,
                                        bool /* sparse_grad */) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  return bridge::AtenFromXlaTensor(
+  return bridge::AtenFromXlaTensor(GetValueOrThrow(
       tensor_methods::gather(GetValueOrThrow(bridge::GetXlaTensor(self)), dim,
-                             GetValueOrThrow(bridge::GetXlaTensor(index))));
+                             GetValueOrThrow(bridge::GetXlaTensor(index)))));
 }
 
 at::Tensor XLANativeFunctions::gelu(const at::Tensor& self,
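GetValueOrThrow is torch_xla's adapter between the status-returning core and the ATen-facing functions, which are expected to throw (PyTorch then surfaces the exception to Python as a RuntimeError, which is what the tests above check). A rough sketch of the idea, with ValueOrThrow as a hypothetical stand-in rather than the actual torch_xla implementation:

#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/statusor.h"

// Hypothetical stand-in for torch_xla's GetValueOrThrow: converts a
// StatusOr<T> into its value, or throws if it holds an error status.
template <typename T>
T ValueOrThrow(absl::StatusOr<T> status_or) {
  if (!status_or.ok()) {
    // The status message becomes the exception text.
    throw std::runtime_error(std::string(status_or.status().message()));
  }
  return *std::move(status_or);
}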

torch_xla/csrc/tensor_methods.cpp

Lines changed: 44 additions & 11 deletions
@@ -442,6 +442,43 @@ absl::Status CheckFullConcreteSizesArePositive(at::SymIntArrayRef sym_sizes) {
       });
 }
 
+absl::Status CheckGatherRanksAreEqual(const XLATensorPtr& input,
+                                      const XLATensorPtr& index) {
+  int64_t input_rank = input->shape().get().dimensions_size();
+  int64_t index_rank = index->shape().get().dimensions_size();
+  if (input_rank != index_rank) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "gather(): expected rank of input (", input_rank, ") and index (",
+        index_rank, ") tensors to be the same.")));
+  }
+  return absl::OkStatus();
+}
+
+// Checks that all index dimensions are smaller or equal to those of input,
+// except on dimension canonical_dim.
+absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input,
+                                                const XLATensorPtr& index,
+                                                int64_t canonical_dim) {
+  // Dimensions that fail the "smaller or equal" condition.
+  std::vector<int64_t> bad_dims;
+  for (int64_t dim = 0; dim < input->shape().get().dimensions_size(); dim++) {
+    if (dim != canonical_dim && input->size(dim) < index->size(dim)) {
+      bad_dims.push_back(dim);
+    }
+  }
+  if (!bad_dims.empty()) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "gather(): expected sizes of index [",
+        absl::StrJoin(index->shape().get().dimensions(), /* sep= */ ", "),
+        "] to be smaller or equal those of input [",
+        absl::StrJoin(input->shape().get().dimensions(), /* sep= */ ", "),
+        "] on all dimensions, except on dimension ", canonical_dim,
+        ". However, that's not true on dimensions [",
+        absl::StrJoin(bad_dims, /* sep= */ ", "), "].")));
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////

@@ -1838,18 +1875,14 @@ absl::StatusOr<absl_nonnull XLATensorPtr> full_symint(
       device, scalar_type);
 }
 
-XLATensorPtr gather(const XLATensorPtr& input, int64_t dim,
-                    const XLATensorPtr& index) {
-  xla::Shape input_shape = input->shape();
-  xla::Shape index_shape = index->shape();
-  XLA_CHECK_EQ(input_shape.dimensions_size(), index_shape.dimensions_size());
+absl::StatusOr<absl_nonnull XLATensorPtr> gather(const XLATensorPtr& input,
+                                                 int64_t dim,
+                                                 const XLATensorPtr& index) {
   int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex(
-      dim, input_shape.dimensions_size());
-  for (size_t dim = 0; dim < input_shape.dimensions_size(); dim++) {
-    if (dim != canonical_dim) {
-      XLA_CHECK_LE(index->size(dim), input->size(dim));
-    }
-  }
+      dim, input->shape().get().dimensions_size());
+  XLA_RETURN_IF_ERROR(CheckGatherRanksAreEqual(input, index));
+  XLA_RETURN_IF_ERROR(
+      CheckGatherDimensionsAreCompatible(input, index, canonical_dim));
   return input->CreateFrom(torch_xla::MakeNode<Gather>(
       input->GetIrValue(), canonical_dim, index->GetIrValue()));
 }
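The two XLA_RETURN_IF_ERROR calls propagate a failed check out of gather without nested if-blocks. Macros of this family generally follow the shape below; this is an illustrative sketch, and torch_xla's actual macro likely also attaches source-location information, as the XLA_ERROR_WITH_LOCATION wrapper above suggests:

#include "absl/status/status.h"

// Illustrative RETURN_IF_ERROR-style macro (not torch_xla's exact
// definition): evaluate an absl::Status expression and early-return it
// from the enclosing Status/StatusOr function if it is not OK.
#define RETURN_IF_ERROR_SKETCH(expr)             \
  do {                                           \
    if (absl::Status _st = (expr); !_st.ok()) {  \
      return _st;                                \
    }                                            \
  } while (0)

Because gather now returns a StatusOr, the early-returned error converts cleanly at the return statement, and the call site in aten_xla_type.cpp decides how to surface it (via GetValueOrThrow) instead of the check aborting deep inside tensor_methods.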

torch_xla/csrc/tensor_methods.h

Lines changed: 3 additions & 2 deletions
@@ -470,8 +470,9 @@ absl::StatusOr<absl_nonnull XLATensorPtr> full_symint(
     at::SymIntArrayRef sym_size, const at::Scalar& fill_value,
     const torch::lazy::BackendDevice& device, at::ScalarType scalar_type);
 
-XLATensorPtr gather(const XLATensorPtr& input, int64_t dim,
-                    const XLATensorPtr& index);
+absl::StatusOr<absl_nonnull XLATensorPtr> gather(const XLATensorPtr& input,
+                                                 int64_t dim,
+                                                 const XLATensorPtr& index);
 
 XLATensorPtr ge(const XLATensorPtr& input, const at::Scalar& other);
