Skip to content

Commit d5b9a6d

Browse files
authored
flip: improve error handling and error messages. (#9550)
1 parent 40f58a6 commit d5b9a6d

File tree

4 files changed

+89
-9
lines changed

4 files changed

+89
-9
lines changed

test/test_operations.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2497,6 +2497,20 @@ def test_div_raises_error_on_invalid_rounding_mode(self):
24972497
"'trunc', 'floor', or be left unspecified.")
24982498
self.assertEqual(str(e), expected_error)
24992499

2500+
def test_flip_raises_error_on_duplicated_dims(self):
2501+
a = torch.rand(2, 2, 2, 2, device=torch_xla.device())
2502+
dims = [0, 0, 0, 1, 2, 3, -1]
2503+
dims_suggestion = [0, 1, 2, 3]
2504+
2505+
try:
2506+
torch.flip(a, dims=dims)
2507+
except RuntimeError as e:
2508+
expected_error = (
2509+
"flip(): expected each dimension to appear at most once. Found "
2510+
"dimensions: 0 (3 times), 3 (2 times). Consider changing dims "
2511+
f"from {dims} to {dims_suggestion}.")
2512+
self.assertEqual(str(e), expected_error)
2513+
25002514

25012515
class MNISTComparator(nn.Module):
25022516

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,8 +1804,10 @@ at::Tensor& XLANativeFunctions::fill_(at::Tensor& self,
18041804
at::Tensor XLANativeFunctions::flip(const at::Tensor& self,
18051805
at::IntArrayRef dims) {
18061806
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
1807-
return bridge::AtenFromXlaTensor(tensor_methods::flip(
1808-
GetValueOrThrow(bridge::GetXlaTensor(self)), XlaHelpers::I64List(dims)));
1807+
auto xself = GetValueOrThrow(bridge::GetXlaTensor(self));
1808+
auto output =
1809+
GetValueOrThrow(tensor_methods::flip(xself, XlaHelpers::I64List(dims)));
1810+
return bridge::AtenFromXlaTensor(std::move(output));
18091811
}
18101812

18111813
at::Tensor XLANativeFunctions::floor_divide(const at::Tensor& self,

torch_xla/csrc/tensor_methods.cpp

Lines changed: 69 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <algorithm>
1010
#include <functional>
11+
#include <iterator>
1112

1213
#include "absl/log/absl_check.h"
1314
#include "absl/strings/str_cat.h"
@@ -345,6 +346,69 @@ XLATensorPtr DispatchComparisonOp(c10::Symbol kind, const XLATensorPtr& input,
345346
return XLATensor::Create(node, input->GetDevice(), at::ScalarType::Bool);
346347
}
347348

349+
// Checks that the canonical dimensions out of the given dimensions are unique
350+
// for the `flip` operation.
351+
//
352+
// This function fails if any canonical dimension appears more than once.
353+
// Notice that its error message is specialized for the `flip` operation.
354+
//
355+
// @param rank Input rank
356+
// @param dims (Error Message) `flip` operation original `dims` argument
357+
// @param canonical_dims (Error Message) Canonical dimensions extracted from
358+
// the `dims` argument
359+
absl::Status CheckFlipDimensionsAreUnique(
360+
int64_t rank, absl::Span<const int64_t> dims,
361+
absl::Span<const int64_t> canonical_dims) {
362+
// Counter that maps each given dimension to the number of times it has
363+
// appeared.
364+
std::vector<int64_t> count(rank, 0);
365+
366+
// Count the number of times each dimension appears.
367+
for (auto dim : canonical_dims) {
368+
count[dim] += 1;
369+
}
370+
371+
bool any_dimension_appears_more_than_once = std::any_of(
372+
count.begin(), count.end(), [](const auto n) { return n > 1; });
373+
374+
if (any_dimension_appears_more_than_once) {
375+
// Suggestion for the value of dims that wouldn't raise an error.
376+
std::vector<int64_t> dims_suggestion;
377+
// Each "bad" dimension is represented as a string of the form:
378+
//
379+
// <dimension> (<count> times)
380+
//
381+
// To be later joined with commas.
382+
std::vector<std::string> bad_count_str;
383+
384+
// Iterates each dimension, populating both `dims_suggestion` and
385+
// `bad_count_str`.
386+
for (int64_t i : c10::irange(rank)) {
387+
// Dimension does not appear. Do nothing.
388+
if (count[i] == 0) {
389+
continue;
390+
}
391+
392+
// Dimension appears in `dims`. Add it to the suggestion list.
393+
dims_suggestion.push_back(i);
394+
395+
// Dimension appears more than once. Add it to the "bad" list.
396+
if (count[i] > 1) {
397+
bad_count_str.push_back(absl::StrCat(i, " (", count[i], " times)"));
398+
}
399+
}
400+
401+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
402+
"flip(): expected each dimension to appear at most once. Found "
403+
"dimensions: ",
404+
absl::StrJoin(bad_count_str, /* sep= */ ", "),
405+
". Consider changing dims from [", absl::StrJoin(dims, /* sep= */ ", "),
406+
"] to [", absl::StrJoin(dims_suggestion, /* sep= */ ", "), "].")));
407+
}
408+
409+
return absl::OkStatus();
410+
}
411+
348412
} // namespace
349413

350414
//////////////////////////////////////////////////////////////////////////////
@@ -1680,12 +1744,11 @@ void fill_(XLATensorPtr& input, const at::Scalar& value) {
16801744
input->SetInPlaceIrValue(std::move(constant));
16811745
}
16821746

1683-
XLATensorPtr flip(const XLATensorPtr& input, absl::Span<const int64_t> dims) {
1684-
auto dimensions = torch::lazy::GetCanonicalDimensionIndices(
1685-
torch_xla::runtime::util::ToVector<int64_t>(dims),
1686-
input->shape().get().dimensions_size());
1687-
std::set<int64_t> unique_dims(dimensions.begin(), dimensions.end());
1688-
XLA_CHECK_EQ(unique_dims.size(), dimensions.size());
1747+
absl::StatusOr<absl_nonnull XLATensorPtr> flip(const XLATensorPtr& input,
1748+
absl::Span<const int64_t> dims) {
1749+
auto rank = input->shape().get().dimensions_size();
1750+
auto dimensions = torch::lazy::GetCanonicalDimensionIndices(dims, rank);
1751+
XLA_RETURN_IF_ERROR(CheckFlipDimensionsAreUnique(rank, dims, dimensions));
16891752
return input->CreateFrom(
16901753
torch_xla::MakeNode<Flip>(input->GetIrValue(), dimensions));
16911754
}

torch_xla/csrc/tensor_methods.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -450,7 +450,8 @@ void eye_out(XLATensorPtr& out, int64_t lines, int64_t cols);
450450
void fill_(XLATensorPtr& input, const at::Scalar& value);
451451

452452
// Flips (reverses) the values in the dimensions of the input tensor.
453-
XLATensorPtr flip(const XLATensorPtr& input, absl::Span<const int64_t> dims);
453+
absl::StatusOr<absl_nonnull XLATensorPtr> flip(const XLATensorPtr& input,
454+
absl::Span<const int64_t> dims);
454455

455456
XLATensorPtr fmod(
456457
const XLATensorPtr& input, const XLATensorPtr& other,

0 commit comments

Comments
 (0)