
Commit 49ac22a

random_: improve error handling and error messages. (#9567)
This PR refactors the `random_` operation implementation by improving its error messages and returning a status value.

**Key Changes:**

- Make `tensor_methods::random_` return `Status`
- Replace `CheckRangeValues` with `CheckValueWithinTypeRange`, and make it return `Status`
- Refactor the `XLANativeFunctions::random_` overloads to handle the status values
- Improve error messages
1 parent 8243a25 commit 49ac22a
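For readers who want the range-check logic in isolation, here is a minimal, self-contained sketch of the pattern the new `CheckValueWithinTypeRange` helper follows. It is not the PR's implementation: the hard-coded bounds of -65504 and 65504 stand in for the float16 limits that the real helper derives from the tensor's dtype via `XlaHelpers::MinMaxValues`, and a plain `absl::InvalidArgumentError` replaces `XLA_ERROR_WITH_LOCATION`.

```cpp
// Sketch only: the real helper computes [min, max] from the dtype and wraps
// the error with XLA_ERROR_WITH_LOCATION to record the source location.
#include <cstdint>
#include <iostream>
#include <string_view>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"

absl::Status CheckValueWithinRangeSketch(std::string_view op, std::string_view arg,
                                         int64_t min, int64_t max, int64_t value) {
  if (value < min || value > max) {
    std::string_view comparison = value < min ? "lower" : "greater";
    std::string_view bound = value < min ? "lower bound" : "upper bound";
    return absl::InvalidArgumentError(
        absl::StrCat(op, "(): expected `", arg, "` to be within the range [",
                     min, ", ", max, "]. However got value ", value,
                     ", which is ", comparison, " than the ", bound, "."));
  }
  return absl::OkStatus();
}

int main() {
  // 65504 is the largest finite float16 value, matching the new float16 test.
  absl::Status status =
      CheckValueWithinRangeSketch("random_", "to", -65504, 65504, 65505);
  std::cout << status.message() << std::endl;  // prints the new error message
}
```

Running this prints the same message the new float16 test below asserts on.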

4 files changed: +70 additions, -22 deletions


test/test_operations.py

Lines changed: 27 additions & 0 deletions
```diff
@@ -2554,6 +2554,33 @@ def test_gather_raises_error_on_invalid_index_size(self):
           "However, that's not true on dimensions [0, 2].")
       self.assertEqual(str(e), expected_error)
 
+  def test_random__raises_error_on_empty_interval(self):
+    a = torch.empty(10, device=torch_xla.device())
+    from_ = 3
+    to_ = 1
+
+    try:
+      a.random_(from_, to_)
+    except RuntimeError as e:
+      expected_error = (
+          f"random_(): expected `from` ({from_}) to be smaller than "
+          f"`to` ({to_}).")
+      self.assertEqual(str(e), expected_error)
+
+  def test_random__raises_error_on_value_out_of_type_value_range(self):
+    a = torch.empty(10, device=torch_xla.device(), dtype=torch.float16)
+    from_ = 3
+    to_ = 65504 + 1
+
+    try:
+      a.random_(from_, to_)
+    except RuntimeError as e:
+      expected_error = (
+          f"random_(): expected `to` to be within the range "
+          f"[-65504, 65504]. However got value {to_}, which is greater "
+          "than the upper bound.")
+      self.assertEqual(str(e), expected_error)
+
 
 class MNISTComparator(nn.Module):
 
```

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 35 additions & 19 deletions
```diff
@@ -18,6 +18,7 @@
 #include <optional>
 
 #include "absl/log/absl_check.h"
+#include "status.h"
 #include "torch/csrc/lazy/core/helpers.h"
 #include "torch/csrc/lazy/core/shape_inference.h"
 #include "torch/csrc/lazy/core/tensor_util.h"
@@ -317,18 +318,27 @@ int64_t GetIntegerUpperLimitForType(torch::ScalarType dtype) {
   }
 }
 
-void CheckRangeValues(torch::ScalarType dtype, int64_t from, int64_t to) {
-  XlaHelpers::MinMax min_max;
-  // Bound the min_max by int64_t since types of "from" and "to" are int64.
-  if (IsTypeWithLargerRangeThanLong(dtype)) {
-    min_max = XlaHelpers::MinMaxValues(xla::PrimitiveType::S64);
-  } else {
-    min_max = XlaHelpers::MinMaxValues(XlaTypeFromTorchType(dtype));
+absl::Status CheckValueWithinTypeRange(const std::string_view op,
+                                       const std::string_view arg,
+                                       torch::ScalarType dtype, int64_t value) {
+  xla::PrimitiveType type = IsTypeWithLargerRangeThanLong(dtype)
+                                ? xla::PrimitiveType::S64
+                                : XlaTypeFromTorchType(dtype);
+
+  XlaHelpers::MinMax mm = XlaHelpers::MinMaxValues(type);
+  int64_t min = mm.min.toLong();
+  int64_t max = mm.max.toLong();
+
+  if (value < min || value > max) {
+    const std::string_view comparison = value < min ? "lower" : "greater";
+    const std::string_view bound = value < min ? "lower bound" : "upper bound";
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat(op, "(): expected `", arg, "` to be within the range [",
+                     min, ", ", max, "]. However got value ", value,
+                     ", which is ", comparison, " than the ", bound, ".")));
   }
-  XLA_CHECK_GE(from, min_max.min.toLong());
-  XLA_CHECK_LE(from, min_max.max.toLong());
-  XLA_CHECK_GE(to, min_max.min.toLong());
-  XLA_CHECK_LE(to, min_max.max.toLong());
+
+  return absl::OkStatus();
 }
 
 std::pair<XLATensorPtr, XLATensorPtr> GetBinaryOperands(
@@ -3025,12 +3035,14 @@ at::Tensor& XLANativeFunctions::random_(
   }
   XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
   at::ScalarType dtype = self_tensor->dtype();
+
   // Prevent "to_val" from overflowing with at::ScalarType::Long.
   int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1;
   int64_t to_val = (to) ? *to : GetIntegerUpperLimitForType(dtype) + inc;
-  XLA_CHECK_LE(from, to_val);
-  CheckRangeValues(self_tensor->dtype(), from, to_val - 1);
-  tensor_methods::random_(self_tensor, from, to_val);
+
+  OkOrThrow(CheckValueWithinTypeRange("random_", "from", dtype, from));
+  OkOrThrow(CheckValueWithinTypeRange("random_", "to", dtype, to_val - 1));
+  OkOrThrow(tensor_methods::random_(self_tensor, from, to_val));
   return self;
 }
 
@@ -3043,10 +3055,12 @@ at::Tensor& XLANativeFunctions::random_(
                        ATEN_OP2(random_, to)>::call(self, to,
                                                     generator);
   }
+
   XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
-  XLA_CHECK_GT(to, 0);
-  CheckRangeValues(self_tensor->dtype(), 0, to - 1);
-  tensor_methods::random_(self_tensor, 0, to);
+  at::ScalarType dtype = self_tensor->dtype();
+
+  OkOrThrow(CheckValueWithinTypeRange("random_", "to", dtype, to - 1));
+  OkOrThrow(tensor_methods::random_(self_tensor, 0, to));
   return self;
 }
 
@@ -3060,10 +3074,12 @@ at::Tensor& XLANativeFunctions::random_(
   }
   XLATensorPtr self_tensor = GetValueOrThrow(bridge::GetXlaTensor(self));
   at::ScalarType dtype = self_tensor->dtype();
+
  // Prevent "to_val" from overflowing with at::ScalarType::Long.
   int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1;
-  tensor_methods::random_(self_tensor, 0,
-                          GetIntegerUpperLimitForType(dtype) + inc);
+  int64_t to_val = GetIntegerUpperLimitForType(dtype) + inc;
+
+  OkOrThrow(tensor_methods::random_(self_tensor, 0, to_val));
   return self;
 }
```
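The `OkOrThrow(...)` wrappers above come from the newly included `status.h` and are not defined in this diff. As a rough mental model only, assuming the helper simply rethrows a non-OK status, it behaves like this sketch:

```cpp
// Assumed sketch of the throw-at-the-boundary behavior; the real OkOrThrow in
// torch_xla's status utilities may format or propagate errors differently.
#include <stdexcept>
#include <string>

#include "absl/status/status.h"

void OkOrThrowSketch(const absl::Status& status) {
  if (!status.ok()) {
    // Surfaces in Python as a RuntimeError carrying the status message,
    // which is what the new tests in test_operations.py assert on.
    throw std::runtime_error(std::string(status.message()));
  }
}
```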

torch_xla/csrc/tensor_methods.cpp

Lines changed: 7 additions & 2 deletions
```diff
@@ -2922,15 +2922,20 @@ XLATensorPtr dynamic_view(const XLATensorPtr& input,
 
 //////////////////////////////////////////////////////////////////////////////
 
-void random_(XLATensorPtr& input, int64_t from, int64_t to) {
-  XLA_CHECK_LE(from, to);
+absl::Status random_(XLATensorPtr& input, int64_t from, int64_t to) {
+  if (from >= to) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat("random_(): expected `from` (", from,
+                     ") to be smaller than `to` (", to, ").")));
+  }
   auto input_shape = input->shape();
   input->SetInPlaceIrValue(torch_xla::MakeNode<DiscreteUniform>(
       XLAGraphExecutor::Get()->GetIrValueForScalar(
           from, xla::PrimitiveType::S64, input->GetDevice()),
       XLAGraphExecutor::Get()->GetIrValueForScalar(to, xla::PrimitiveType::S64,
                                                    input->GetDevice()),
       XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape));
+  return absl::OkStatus();
 }
 
 XLATensorPtr randperm(int64_t n, const torch::lazy::BackendDevice& device,
```
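Since `tensor_methods::random_` now returns `absl::Status` instead of failing a check, a C++ caller that itself returns a status could propagate the error rather than throw. The caller and stub below are hypothetical illustrations, not code from this PR:

```cpp
// Hypothetical propagation example; random_stub and FillUniformInts are
// placeholders that only mirror the contract of the refactored function.
#include <cstdint>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"

absl::Status random_stub(int64_t from, int64_t to) {
  if (from >= to) {
    return absl::InvalidArgumentError(absl::StrCat(
        "random_(): expected `from` (", from, ") to be smaller than `to` (",
        to, ")."));
  }
  return absl::OkStatus();
}

absl::Status FillUniformInts(int64_t from, int64_t to) {
  absl::Status status = random_stub(from, to);
  if (!status.ok()) {
    return status;  // propagate to the caller instead of throwing or aborting
  }
  return absl::OkStatus();
}
```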

torch_xla/csrc/tensor_methods.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -776,7 +776,7 @@ void put_(XLATensorPtr& input, const XLATensorPtr& index,
 
 std::tuple<XLATensorPtr, XLATensorPtr> qr(const XLATensorPtr& input, bool some);
 
-void random_(XLATensorPtr& input, int64_t from, int64_t to);
+absl::Status random_(XLATensorPtr& input, int64_t from, int64_t to);
 
 XLATensorPtr randperm(int64_t n, const torch::lazy::BackendDevice& device,
                       at::ScalarType scalar_type);
```
