From f8fa1d5bbf4d8486a3b64df3934df4a3658ae19b Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Wed, 10 Sep 2025 16:14:28 -0300 Subject: [PATCH 1/2] Improve error handling and error messages for uniform_. --- test/test_operations.py | 13 +++++++++++++ torch_xla/csrc/aten_xla_type.cpp | 5 +++-- torch_xla/csrc/tensor_methods.cpp | 20 +++++++++++++++----- torch_xla/csrc/tensor_methods.h | 2 +- 4 files changed, 32 insertions(+), 8 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index a02bb746666..a8f61950136 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2367,6 +2367,19 @@ def test_isneginf_no_fallback(self): t = t.to(torch.float16) self._test_no_fallback(torch.isneginf, (t,)) + def test_uniform__raises_error_on_invalid_range(self): + device = torch_xla.device() + a = torch.empty(5, 5, device=device) + from_ = 5. + to_ = 2. + + try: + a.uniform_(from_, to_) + except RuntimeError as e: + expected_error = ( + "uniform_(): expected `from` (5) to be smaller or equal `to` (2).") + self.assertEqual(str(e), expected_error) + class MNISTComparator(nn.Module): diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index b6a8484a250..0e35a9a8be8 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -3919,8 +3919,9 @@ at::Tensor& XLANativeFunctions::uniform_( return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(uniform_)>::call( self, from, to, generator); } - XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self)); - tensor_methods::uniform_(xla_self, from, to); + XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self, + bridge::GetXlaTensor(self)); + XLA_THROW_IF_ERROR(tensor_methods::uniform_(xla_self, from, to)); return self; } diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index a52c955ae55..5992c76be92 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -581,6 +581,15 @@ absl::Status CheckStackAtLeastOneTensor( return absl::OkStatus(); } +absl::Status CheckUniformRangeIsValid(double from, double to) { + if (from > to) { + return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError( + absl::StrCat("uniform_(): expected `from` (", from, + ") to be smaller or equal `to` (", to, ")."))); + } + return absl::OkStatus(); +} + } // namespace ////////////////////////////////////////////////////////////////////////////// @@ -3728,15 +3737,16 @@ std::vector unbind(const XLATensorPtr& input, int64_t dim) { return slices; } -void uniform_(XLATensorPtr& input, double from, double to) { - XLA_CHECK_LE(from, to); - auto input_shape = input->shape(); +absl::Status uniform_(XLATensorPtr& input, double from, double to) { + XLA_RETURN_IF_ERROR(CheckUniformRangeIsValid(from, to)); + xla::Shape input_shape = input->shape(); input->SetInPlaceIrValue(torch_xla::MakeNode( XLAGraphExecutor::Get()->GetIrValueForScalar( - from, input_shape.get().element_type(), input->GetDevice()), + from, input_shape.element_type(), input->GetDevice()), XLAGraphExecutor::Get()->GetIrValueForScalar( - to, input_shape.get().element_type(), input->GetDevice()), + to, input_shape.element_type(), input->GetDevice()), XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape)); + return absl::OkStatus(); } XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim) { diff --git a/torch_xla/csrc/tensor_methods.h b/torch_xla/csrc/tensor_methods.h index fb0e39cc861..ea9cf006d0e 100644 --- a/torch_xla/csrc/tensor_methods.h +++ b/torch_xla/csrc/tensor_methods.h @@ -989,7 +989,7 @@ std::tuple triangular_solve( // removed. std::vector unbind(const XLATensorPtr& input, int64_t dim); -void uniform_(XLATensorPtr& input, double from, double to); +absl::Status uniform_(XLATensorPtr& input, double from, double to); // Insert a dimension of size one at the specified position. XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim); From 33b71b3d82d72e64ce3c3c5550b64d3162e0ff85 Mon Sep 17 00:00:00 2001 From: Yukio Siraichi Date: Fri, 12 Sep 2025 16:24:31 -0300 Subject: [PATCH 2/2] Move test. --- test/test_operations.py | 13 ------------- test/test_ops_error_message.py | 15 +++++++++++++++ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/test/test_operations.py b/test/test_operations.py index a8f61950136..a02bb746666 100644 --- a/test/test_operations.py +++ b/test/test_operations.py @@ -2367,19 +2367,6 @@ def test_isneginf_no_fallback(self): t = t.to(torch.float16) self._test_no_fallback(torch.isneginf, (t,)) - def test_uniform__raises_error_on_invalid_range(self): - device = torch_xla.device() - a = torch.empty(5, 5, device=device) - from_ = 5. - to_ = 2. - - try: - a.uniform_(from_, to_) - except RuntimeError as e: - expected_error = ( - "uniform_(): expected `from` (5) to be smaller or equal `to` (2).") - self.assertEqual(str(e), expected_error) - class MNISTComparator(nn.Module): diff --git a/test/test_ops_error_message.py b/test/test_ops_error_message.py index 42858fc84f6..e07224a666d 100644 --- a/test/test_ops_error_message.py +++ b/test/test_ops_error_message.py @@ -235,6 +235,21 @@ def test(): expect="""trace(): expected the input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor).""" ) + def test_uniform__raises_error_on_invalid_range(self): + device = torch_xla.device() + a = torch.empty(5, 5, device=device) + from_ = 5. + to_ = 2. + + def test(): + return a.uniform_(from_, to_) + + self.assertExpectedRaisesInline( + exc_type=RuntimeError, + callable=test, + expect="""uniform_(): expected `from` (5) to be smaller or equal `to` (2).""" + ) + if __name__ == "__main__": unittest.main()