|
8 | 8 |
|
9 | 9 | #include <algorithm>
|
10 | 10 | #include <functional>
|
| 11 | +#include <iterator> |
11 | 12 |
|
12 | 13 | #include "absl/log/absl_check.h"
|
13 | 14 | #include "absl/strings/str_cat.h"
|
@@ -345,6 +346,69 @@ XLATensorPtr DispatchComparisonOp(c10::Symbol kind, const XLATensorPtr& input,
|
345 | 346 | return XLATensor::Create(node, input->GetDevice(), at::ScalarType::Bool);
|
346 | 347 | }
|
347 | 348 |
|
| 349 | +// Checks that the canonical dimensions out of the given dimensions are unique |
| 350 | +// for the `flip` operation. |
| 351 | +// |
| 352 | +// This function fails if any canonical dimension appears more than once. |
| 353 | +// Notice that its error message is specialized for the `flip` operation. |
| 354 | +// |
| 355 | +// @param rank Input rank |
| 356 | +// @param dims (Error Message) `flip` operation original `dims` argument |
| 357 | +// @param canonical_dims (Error Message) Canonical dimensions extracted from |
| 358 | +// the `dims` argument |
| 359 | +absl::Status CheckFlipDimensionsAreUnique( |
| 360 | + int64_t rank, absl::Span<const int64_t> dims, |
| 361 | + absl::Span<const int64_t> canonical_dims) { |
| 362 | + // Counter that maps each given dimension to the number of times it has |
| 363 | + // appeared. |
| 364 | + std::vector<int64_t> count(rank, 0); |
| 365 | + |
| 366 | + // Count the number of times each dimension appears. |
| 367 | + for (auto dim : canonical_dims) { |
| 368 | + count[dim] += 1; |
| 369 | + } |
| 370 | + |
| 371 | + bool any_dimension_appears_more_than_once = std::any_of( |
| 372 | + count.begin(), count.end(), [](const auto n) { return n > 1; }); |
| 373 | + |
| 374 | + if (any_dimension_appears_more_than_once) { |
| 375 | + // Suggestion for the value of dims that wouldn't raise an error. |
| 376 | + std::vector<int64_t> dims_suggestion; |
| 377 | + // Each "bad" dimension is represented as a string of the form: |
| 378 | + // |
| 379 | + // <dimension> (<count> times) |
| 380 | + // |
| 381 | + // To be later joined with commas. |
| 382 | + std::vector<std::string> bad_count_str; |
| 383 | + |
| 384 | + // Iterates each dimension, populating both `dims_suggestion` and |
| 385 | + // `bad_count_str`. |
| 386 | + for (int64_t i : c10::irange(rank)) { |
| 387 | + // Dimension does not appear. Do nothing. |
| 388 | + if (count[i] == 0) { |
| 389 | + continue; |
| 390 | + } |
| 391 | + |
| 392 | + // Dimension appears in `dims`. Add it to the suggestion list. |
| 393 | + dims_suggestion.push_back(i); |
| 394 | + |
| 395 | + // Dimension appears more than once. Add it to the "bad" list. |
| 396 | + if (count[i] > 1) { |
| 397 | + bad_count_str.push_back(absl::StrCat(i, " (", count[i], " times)")); |
| 398 | + } |
| 399 | + } |
| 400 | + |
| 401 | + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat( |
| 402 | + "flip(): expected each dimension to appear at most once. Found " |
| 403 | + "dimensions: ", |
| 404 | + absl::StrJoin(bad_count_str, /* sep= */ ", "), |
| 405 | + ". Consider changing dims from [", absl::StrJoin(dims, /* sep= */ ", "), |
| 406 | + "] to [", absl::StrJoin(dims_suggestion, /* sep= */ ", "), "]."))); |
| 407 | + } |
| 408 | + |
| 409 | + return absl::OkStatus(); |
| 410 | +} |
| 411 | + |
348 | 412 | } // namespace
|
349 | 413 |
|
350 | 414 | //////////////////////////////////////////////////////////////////////////////
|
@@ -1680,12 +1744,11 @@ void fill_(XLATensorPtr& input, const at::Scalar& value) {
|
1680 | 1744 | input->SetInPlaceIrValue(std::move(constant));
|
1681 | 1745 | }
|
1682 | 1746 |
|
1683 |
| -XLATensorPtr flip(const XLATensorPtr& input, absl::Span<const int64_t> dims) { |
1684 |
| - auto dimensions = torch::lazy::GetCanonicalDimensionIndices( |
1685 |
| - torch_xla::runtime::util::ToVector<int64_t>(dims), |
1686 |
| - input->shape().get().dimensions_size()); |
1687 |
| - std::set<int64_t> unique_dims(dimensions.begin(), dimensions.end()); |
1688 |
| - XLA_CHECK_EQ(unique_dims.size(), dimensions.size()); |
| 1747 | +absl::StatusOr<absl_nonnull XLATensorPtr> flip(const XLATensorPtr& input, |
| 1748 | + absl::Span<const int64_t> dims) { |
| 1749 | + auto rank = input->shape().get().dimensions_size(); |
| 1750 | + auto dimensions = torch::lazy::GetCanonicalDimensionIndices(dims, rank); |
| 1751 | + XLA_RETURN_IF_ERROR(CheckFlipDimensionsAreUnique(rank, dims, dimensions)); |
1689 | 1752 | return input->CreateFrom(
|
1690 | 1753 | torch_xla::MakeNode<Flip>(input->GetIrValue(), dimensions));
|
1691 | 1754 | }
|
|
0 commit comments