Skip to content

Commit a66cfc3

Browse files
authored
roll: improve error handling and error messages. (#9628)
This PR refactors the `roll` operation implementation by improving its error messages and by returning a status-type value. **Key Changes:** - Make `tensor_methods::roll` return `StatusOr<XLATensorPtr>` - Improve error messages and error handling - Create `CheckRollShiftsRequired` and `CheckRollDimsAndShiftsAreCompatible` functions
1 parent efe20ab commit a66cfc3

File tree

4 files changed

+94
-17
lines changed

4 files changed

+94
-17
lines changed

test/test_ops_error_message.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,48 @@ def test():
180180
expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
181181
)
182182

183+
def test_roll_raises_error_on_empty_shifts(self):
  """`torch.roll` on an XLA tensor must reject an empty `shifts` list."""
  dev = torch_xla.device()
  tensor = torch.rand(2, 2, 2, device=dev)

  def roll_with_no_shifts():
    return torch.roll(tensor, [])

  self.assertExpectedRaisesInline(
      exc_type=RuntimeError,
      callable=roll_with_no_shifts,
      expect="""roll(): expected `shifts` to have at least 1 element.""")
195+
196+
def test_roll_raises_error_on_shifts_with_empty_dims(self):
  """When `dims` is omitted, `shifts` must contain exactly one element."""
  dev = torch_xla.device()
  tensor = torch.rand(2, 2, 2, device=dev)

  def roll_without_dims():
    # Two shift values but no `dims` argument — expected to be rejected.
    return torch.roll(tensor, [2, 2])

  self.assertExpectedRaisesInline(
      exc_type=RuntimeError,
      callable=roll_without_dims,
      expect="""roll(): expected `shifts` [2, 2] (size=2) to have exactly 1 element when `dims` is empty."""
  )
209+
210+
def test_roll_raises_error_on_mismatched_dims_and_shifts(self):
  """`dims` and `shifts` of different sizes must raise a descriptive error."""
  dev = torch_xla.device()
  tensor = torch.rand(2, 2, 2, device=dev)

  def roll_with_mismatched_sizes():
    # One dim but two shifts — sizes must match, so this should fail.
    return torch.roll(tensor, [2, 2], [0])

  self.assertExpectedRaisesInline(
      exc_type=RuntimeError,
      callable=roll_with_mismatched_sizes,
      expect="""roll(): expected `dims` [0] (size=1) to match the size of `shifts` [2, 2] (size=2)."""
  )
224+
183225

184226
if __name__ == "__main__":
185227
unittest.main()

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
#include <mutex>
1818
#include <optional>
1919

20+
#include "absl/base/nullability.h"
2021
#include "absl/log/absl_check.h"
21-
#include "status.h"
2222
#include "torch/csrc/lazy/core/helpers.h"
2323
#include "torch/csrc/lazy/core/shape_inference.h"
2424
#include "torch/csrc/lazy/core/tensor_util.h"
// ATen entry point for `torch.roll` on XLA tensors.
// Bridges `self` into an XLATensor, delegates argument validation and the
// actual lowering to tensor_methods::roll (which returns a StatusOr), and
// converts any error status into a thrown exception via XLA_ASSIGN_OR_THROW.
at::Tensor XLANativeFunctions::roll(const at::Tensor& self,
                                    at::IntArrayRef shifts,
                                    at::IntArrayRef dims) {
  TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
  // Fetch the XLA representation of the input; throws if `self` is not an
  // XLA tensor.
  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
                      bridge::GetXlaTensor(self));
  // Invalid `shifts`/`dims` combinations surface here as thrown errors
  // carrying the improved roll() error messages.
  XLA_ASSIGN_OR_THROW(
      absl_nonnull XLATensorPtr output,
      tensor_methods::roll(xla_self, XlaHelpers::I64List(shifts),
                           XlaHelpers::I64List(dims)));
  // Wrap the resulting XLA tensor back into an ATen tensor.
  return bridge::AtenFromXlaTensor(std::move(output));
}
33243328

33253329
at::Tensor XLANativeFunctions::rrelu_with_noise(

torch_xla/csrc/tensor_methods.cpp

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "absl/log/absl_check.h"
1414
#include "absl/status/status.h"
1515
#include "absl/strings/str_cat.h"
16+
#include "absl/strings/str_join.h"
1617
#include "absl/strings/str_split.h"
1718
#include "torch_xla/csrc/LazyIr.h"
1819
#include "torch_xla/csrc/aten_xla_bridge.h"
@@ -506,6 +507,37 @@ absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1,
506507
return absl::OkStatus();
507508
}
508509

510+
// Validates that the caller provided at least one shift value for roll().
// Returns OkStatus on success, InvalidArgumentError otherwise.
absl::Status CheckRollShiftsRequired(absl::Span<const int64_t> shifts) {
  return shifts.empty()
             ? absl::InvalidArgumentError(
                   "roll(): expected `shifts` to have at least 1 element.")
             : absl::OkStatus();
}
517+
518+
// Validates the relationship between `dims` and `shifts` for roll():
//   - empty `dims`     => `shifts` must hold exactly one element;
//   - non-empty `dims` => `dims` and `shifts` must have the same size.
// Returns OkStatus when the arguments are compatible, and an
// InvalidArgumentError describing both spans otherwise.
absl::Status CheckRollDimsAndShiftsAreCompatible(
    absl::Span<const int64_t> dims, absl::Span<const int64_t> shifts) {
  if (dims.empty()) {
    if (shifts.size() == 1) {
      return absl::OkStatus();
    }
    return absl::InvalidArgumentError(absl::StrCat(
        "roll(): expected `shifts` [", absl::StrJoin(shifts, /* sep= */ ", "),
        "] (size=", shifts.size(),
        ") to have exactly 1 element when `dims` is empty."));
  }
  if (dims.size() == shifts.size()) {
    return absl::OkStatus();
  }
  return absl::InvalidArgumentError(absl::StrCat(
      "roll(): expected `dims` [", absl::StrJoin(dims, /* sep= */ ", "),
      "] (size=", dims.size(), ") to match the size of `shifts` [",
      absl::StrJoin(shifts, /* sep= */ ", "), "] (size=", shifts.size(),
      ")."));
}
540+
509541
} // namespace
510542

511543
//////////////////////////////////////////////////////////////////////////////
@@ -3052,17 +3084,15 @@ void resize_(XLATensorPtr& input, std::vector<int64_t> size) {
30523084
}
30533085
}
30543086

3055-
XLATensorPtr roll(const XLATensorPtr& input, absl::Span<const int64_t> shifts,
3056-
absl::Span<const int64_t> dims) {
3057-
XLA_CHECK_GT(shifts.size(), 0) << "`shifts` required";
3058-
if (dims.size() != 0) {
3059-
XLA_CHECK_EQ(shifts.size(), dims.size())
3060-
<< "shifts and dimensions must align. shifts: " << shifts.size()
3061-
<< ", dims:" << dims.size();
3062-
}
3063-
auto canonical_dims = torch::lazy::GetCanonicalDimensionIndices(
3064-
torch::lazy::ToVector<int64_t>(dims),
3065-
input->shape().get().dimensions_size());
3087+
absl::StatusOr<absl_nonnull XLATensorPtr> roll(
3088+
const absl_nonnull XLATensorPtr& input, absl::Span<const int64_t> shifts,
3089+
absl::Span<const int64_t> dims) {
3090+
XLA_RETURN_IF_ERROR(CheckRollShiftsRequired(shifts));
3091+
XLA_RETURN_IF_ERROR(CheckRollDimsAndShiftsAreCompatible(dims, shifts));
3092+
const std::vector<int64_t> canonical_dims =
3093+
torch::lazy::GetCanonicalDimensionIndices(
3094+
torch::lazy::ToVector<int64_t>(dims),
3095+
input->shape().get().dimensions().size());
30663096
return input->CreateFrom(torch_xla::MakeNode<Roll>(
30673097
input->GetIrValue(), torch::lazy::ToVector<int64_t>(shifts),
30683098
canonical_dims));

torch_xla/csrc/tensor_methods.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -821,8 +821,9 @@ XLATensorPtr replication_pad3d_backward(const XLATensorPtr& grad_output,
821821

822822
void resize_(XLATensorPtr& input, std::vector<int64_t> size);
823823

824-
XLATensorPtr roll(const XLATensorPtr& input, absl::Span<const int64_t> shifts,
825-
absl::Span<const int64_t> dims);
824+
absl::StatusOr<absl_nonnull XLATensorPtr> roll(
825+
const absl_nonnull XLATensorPtr& input, absl::Span<const int64_t> shifts,
826+
absl::Span<const int64_t> dims);
826827

827828
XLATensorPtr rrelu_with_noise(const XLATensorPtr& input, XLATensorPtr& noise,
828829
const at::Scalar& lower, const at::Scalar& upper,

0 commit comments

Comments
 (0)