|
13 | 13 | #include "absl/log/absl_check.h"
|
14 | 14 | #include "absl/status/status.h"
|
15 | 15 | #include "absl/strings/str_cat.h"
|
| 16 | +#include "absl/strings/str_join.h" |
16 | 17 | #include "absl/strings/str_split.h"
|
17 | 18 | #include "torch_xla/csrc/LazyIr.h"
|
18 | 19 | #include "torch_xla/csrc/aten_xla_bridge.h"
|
@@ -506,6 +507,37 @@ absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1,
|
506 | 507 | return absl::OkStatus();
|
507 | 508 | }
|
508 | 509 |
|
| 510 | +absl::Status CheckRollShiftsRequired(absl::Span<const int64_t> shifts) { |
| 511 | + if (shifts.empty()) { |
| 512 | + return absl::InvalidArgumentError( |
| 513 | + "roll(): expected `shifts` to have at least 1 element."); |
| 514 | + } |
| 515 | + return absl::OkStatus(); |
| 516 | +} |
| 517 | + |
| 518 | +absl::Status CheckRollDimsAndShiftsAreCompatible( |
| 519 | + absl::Span<const int64_t> dims, absl::Span<const int64_t> shifts) { |
| 520 | + if (dims.empty()) { |
| 521 | + // If `dims` is empty, then return an error status if `shifts` is not |
| 522 | + // of size one. Otherwise, `dims` and `shifts` are valid. |
| 523 | + if (shifts.size() != 1) { |
| 524 | + return absl::InvalidArgumentError(absl::StrCat( |
| 525 | + "roll(): expected `shifts` [", absl::StrJoin(shifts, /* sep= */ ", "), |
| 526 | + "] (size=", shifts.size(), |
| 527 | + ") to have exactly 1 element when `dims` is empty.")); |
| 528 | + } |
| 529 | + } else if (dims.size() != shifts.size()) { |
| 530 | + // If `dims` is not empty, then return an error status if its size |
| 531 | + // does not match with `shifts` size. |
| 532 | + return absl::InvalidArgumentError(absl::StrCat( |
| 533 | + "roll(): expected `dims` [", absl::StrJoin(dims, /* sep= */ ", "), |
| 534 | + "] (size=", dims.size(), ") to match the size of `shifts` [", |
| 535 | + absl::StrJoin(shifts, /* sep= */ ", "), "] (size=", shifts.size(), |
| 536 | + ").")); |
| 537 | + } |
| 538 | + return absl::OkStatus(); |
| 539 | +} |
| 540 | + |
509 | 541 | } // namespace
|
510 | 542 |
|
511 | 543 | //////////////////////////////////////////////////////////////////////////////
|
@@ -3052,17 +3084,15 @@ void resize_(XLATensorPtr& input, std::vector<int64_t> size) {
|
3052 | 3084 | }
|
3053 | 3085 | }
|
3054 | 3086 |
|
3055 |
| -XLATensorPtr roll(const XLATensorPtr& input, absl::Span<const int64_t> shifts, |
3056 |
| - absl::Span<const int64_t> dims) { |
3057 |
| - XLA_CHECK_GT(shifts.size(), 0) << "`shifts` required"; |
3058 |
| - if (dims.size() != 0) { |
3059 |
| - XLA_CHECK_EQ(shifts.size(), dims.size()) |
3060 |
| - << "shifts and dimensions must align. shifts: " << shifts.size() |
3061 |
| - << ", dims:" << dims.size(); |
3062 |
| - } |
3063 |
| - auto canonical_dims = torch::lazy::GetCanonicalDimensionIndices( |
3064 |
| - torch::lazy::ToVector<int64_t>(dims), |
3065 |
| - input->shape().get().dimensions_size()); |
| 3087 | +absl::StatusOr<absl_nonnull XLATensorPtr> roll( |
| 3088 | + const absl_nonnull XLATensorPtr& input, absl::Span<const int64_t> shifts, |
| 3089 | + absl::Span<const int64_t> dims) { |
| 3090 | + XLA_RETURN_IF_ERROR(CheckRollShiftsRequired(shifts)); |
| 3091 | + XLA_RETURN_IF_ERROR(CheckRollDimsAndShiftsAreCompatible(dims, shifts)); |
| 3092 | + const std::vector<int64_t> canonical_dims = |
| 3093 | + torch::lazy::GetCanonicalDimensionIndices( |
| 3094 | + torch::lazy::ToVector<int64_t>(dims), |
| 3095 | + input->shape().get().dimensions().size()); |
3066 | 3096 | return input->CreateFrom(torch_xla::MakeNode<Roll>(
|
3067 | 3097 | input->GetIrValue(), torch::lazy::ToVector<int64_t>(shifts),
|
3068 | 3098 | canonical_dims));
|
|
0 commit comments