18
18
#include < optional>
19
19
20
20
#include " absl/log/absl_check.h"
21
+ #include " status.h"
21
22
#include " torch/csrc/lazy/core/helpers.h"
22
23
#include " torch/csrc/lazy/core/shape_inference.h"
23
24
#include " torch/csrc/lazy/core/tensor_util.h"
@@ -317,18 +318,27 @@ int64_t GetIntegerUpperLimitForType(torch::ScalarType dtype) {
317
318
}
318
319
}
319
320
320
- void CheckRangeValues (torch::ScalarType dtype, int64_t from, int64_t to) {
321
- XlaHelpers::MinMax min_max;
322
- // Bound the min_max by int64_t since types of "from" and "to" are int64.
323
- if (IsTypeWithLargerRangeThanLong (dtype)) {
324
- min_max = XlaHelpers::MinMaxValues (xla::PrimitiveType::S64);
325
- } else {
326
- min_max = XlaHelpers::MinMaxValues (XlaTypeFromTorchType (dtype));
321
+ absl::Status CheckValueWithinTypeRange (const std::string_view op,
322
+ const std::string_view arg,
323
+ torch::ScalarType dtype, int64_t value) {
324
+ xla::PrimitiveType type = IsTypeWithLargerRangeThanLong (dtype)
325
+ ? xla::PrimitiveType::S64
326
+ : XlaTypeFromTorchType (dtype);
327
+
328
+ XlaHelpers::MinMax mm = XlaHelpers::MinMaxValues (type);
329
+ int64_t min = mm.min .toLong ();
330
+ int64_t max = mm.max .toLong ();
331
+
332
+ if (value < min || value > max) {
333
+ const std::string_view comparison = value < min ? " lower" : " greater" ;
334
+ const std::string_view bound = value < min ? " lower bound" : " upper bound" ;
335
+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (
336
+ absl::StrCat (op, " (): expected `" , arg, " ` to be within the range [" ,
337
+ min, " , " , max, " ]. However got value " , value,
338
+ " , which is " , comparison, " than the " , bound, " ." )));
327
339
}
328
- XLA_CHECK_GE (from, min_max.min .toLong ());
329
- XLA_CHECK_LE (from, min_max.max .toLong ());
330
- XLA_CHECK_GE (to, min_max.min .toLong ());
331
- XLA_CHECK_LE (to, min_max.max .toLong ());
340
+
341
+ return absl::OkStatus ();
332
342
}
333
343
334
344
std::pair<XLATensorPtr, XLATensorPtr> GetBinaryOperands (
@@ -3025,12 +3035,14 @@ at::Tensor& XLANativeFunctions::random_(
3025
3035
}
3026
3036
XLATensorPtr self_tensor = GetValueOrThrow (bridge::GetXlaTensor (self));
3027
3037
at::ScalarType dtype = self_tensor->dtype ();
3038
+
3028
3039
// Prevent "to_val" from overflowing with at::ScalarType::Long.
3029
3040
int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1 ;
3030
3041
int64_t to_val = (to) ? *to : GetIntegerUpperLimitForType (dtype) + inc;
3031
- XLA_CHECK_LE (from, to_val);
3032
- CheckRangeValues (self_tensor->dtype (), from, to_val - 1 );
3033
- tensor_methods::random_ (self_tensor, from, to_val);
3042
+
3043
+ OkOrThrow (CheckValueWithinTypeRange (" random_" , " from" , dtype, from));
3044
+ OkOrThrow (CheckValueWithinTypeRange (" random_" , " to" , dtype, to_val - 1 ));
3045
+ OkOrThrow (tensor_methods::random_ (self_tensor, from, to_val));
3034
3046
return self;
3035
3047
}
3036
3048
@@ -3043,10 +3055,12 @@ at::Tensor& XLANativeFunctions::random_(
3043
3055
ATEN_OP2 (random_, to)>::call (self, to,
3044
3056
generator);
3045
3057
}
3058
+
3046
3059
XLATensorPtr self_tensor = GetValueOrThrow (bridge::GetXlaTensor (self));
3047
- XLA_CHECK_GT (to, 0 );
3048
- CheckRangeValues (self_tensor->dtype (), 0 , to - 1 );
3049
- tensor_methods::random_ (self_tensor, 0 , to);
3060
+ at::ScalarType dtype = self_tensor->dtype ();
3061
+
3062
+ OkOrThrow (CheckValueWithinTypeRange (" random_" , " to" , dtype, to - 1 ));
3063
+ OkOrThrow (tensor_methods::random_ (self_tensor, 0 , to));
3050
3064
return self;
3051
3065
}
3052
3066
@@ -3060,10 +3074,12 @@ at::Tensor& XLANativeFunctions::random_(
3060
3074
}
3061
3075
XLATensorPtr self_tensor = GetValueOrThrow (bridge::GetXlaTensor (self));
3062
3076
at::ScalarType dtype = self_tensor->dtype ();
3077
+
3063
3078
// Prevent "to_val" from overflowing with at::ScalarType::Long.
3064
3079
int64_t inc = (dtype == at::ScalarType::Long) ? 0 : 1 ;
3065
- tensor_methods::random_ (self_tensor, 0 ,
3066
- GetIntegerUpperLimitForType (dtype) + inc);
3080
+ int64_t to_val = GetIntegerUpperLimitForType (dtype) + inc;
3081
+
3082
+ OkOrThrow (tensor_methods::random_ (self_tensor, 0 , to_val));
3067
3083
return self;
3068
3084
}
3069
3085
0 commit comments