diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py index 42858fc84f6..e07224a666d 100644 --- a/test/test_ops_error_message.py +++ b/test/test_ops_error_message.py @@ -235,6 +235,21 @@ def test(): expect="""trace(): expected the input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor).""" ) + def test_uniform__raises_error_on_invalid_range(self): + device = torch_xla.device() + a = torch.empty(5, 5, device=device) + from_ = 5. + to_ = 2. + + def test(): + return a.uniform_(from_, to_) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""uniform_(): expected `from` (5) to be smaller or equal `to` (2).""" + ) + if __name__ == "__main__": unittest.main() diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index b6a8484a250..0e35a9a8be8 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3919,8 +3919,9 @@ at::Tensor& XLANativeFunctions::uniform_( return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(uniform_)>::call( self, from, to, generator); } - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - tensor_methods::uniform_(xla_self, from, to); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_THROW_IF_ERROR(tensor_methods::uniform_(xla_self, from, to)); return self; } diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index a52c955ae55..5992c76be92 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -581,6 +581,15 @@ absl::Status CheckStackAtLeastOneTensor( return absl::OkStatus(); } +absl::Status CheckUniformRangeIsValid(double from, double to) { + if (from > to) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat("uniform_(): expected `from` (", from, + ") to be smaller or equal `to` (", to, ")."))); + } + return absl::OkStatus(); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -3728,15 +3737,16 @@ std::vector unbind(const XLATensorPtr& input, int64_t dim) { return slices; } -void uniform_(XLATensorPtr& input, double from, double to) { - XLA_CHECK_LE(from, to); - auto input_shape = input->shape(); +absl::Status uniform_(XLATensorPtr& input, double from, double to) { + XLA_RETURN_IF_ERROR(CheckUniformRangeIsValid(from, to)); + xla::Shape input_shape = input->shape(); input->SetInPlaceIrValue(torch_xla::MakeNode( XLAGraphExecutor::Get()->GetIrValueForScalar( - from, input_shape.get().element_type(), input->GetDevice()), + from, input_shape.element_type(), input->GetDevice()), XLAGraphExecutor::Get()->GetIrValueForScalar( - to, input_shape.get().element_type(), input->GetDevice()), + to, input_shape.element_type(), input->GetDevice()), XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape)); + return absl::OkStatus(); } XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim) { diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index fb0e39cc861..ea9cf006d0e 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -989,7 +989,7 @@ std::tuple triangular_solve( // removed. std::vector unbind(const XLATensorPtr& input, int64_t dim); -void uniform_(XLATensorPtr& input, double from, double to); +absl::Status uniform_(XLATensorPtr& input, double from, double to); // Insert a dimension of size one at the specified position. XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim);