Skip to content

Commit f8fa1d5

Browse files
committed
Improve error handling and error messages for uniform_.
1 parent 6ac4a7c commit f8fa1d5

File tree

4 files changed

+32
-8
lines changed

4 files changed

+32
-8
lines changed

test/test_operations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2367,6 +2367,19 @@ def test_isneginf_no_fallback(self):
23672367
t = t.to(torch.float16)
23682368
self._test_no_fallback(torch.isneginf, (t,))
23692369

2370+
def test_uniform__raises_error_on_invalid_range(self):
2371+
device = torch_xla.device()
2372+
a = torch.empty(5, 5, device=device)
2373+
from_ = 5.
2374+
to_ = 2.
2375+
2376+
try:
2377+
a.uniform_(from_, to_)
2378+
except RuntimeError as e:
2379+
expected_error = (
2380+
"uniform_(): expected `from` (5) to be smaller or equal `to` (2).")
2381+
self.assertEqual(str(e), expected_error)
2382+
23702383

23712384
class MNISTComparator(nn.Module):
23722385

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3919,8 +3919,9 @@ at::Tensor& XLANativeFunctions::uniform_(
39193919
return at::native::call_fallback_fn<&xla_fallback, ATEN_OP(uniform_)>::call(
39203920
self, from, to, generator);
39213921
}
3922-
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
3923-
tensor_methods::uniform_(xla_self, from, to);
3922+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
3923+
bridge::GetXlaTensor(self));
3924+
XLA_THROW_IF_ERROR(tensor_methods::uniform_(xla_self, from, to));
39243925
return self;
39253926
}
39263927

torch_xla/csrc/tensor_methods.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,15 @@ absl::Status CheckStackAtLeastOneTensor(
581581
return absl::OkStatus();
582582
}
583583

584+
absl::Status CheckUniformRangeIsValid(double from, double to) {
585+
if (from > to) {
586+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
587+
absl::StrCat("uniform_(): expected `from` (", from,
588+
") to be smaller or equal `to` (", to, ").")));
589+
}
590+
return absl::OkStatus();
591+
}
592+
584593
} // namespace
585594

586595
//////////////////////////////////////////////////////////////////////////////
@@ -3728,15 +3737,16 @@ std::vector<XLATensorPtr> unbind(const XLATensorPtr& input, int64_t dim) {
37283737
return slices;
37293738
}
37303739

3731-
void uniform_(XLATensorPtr& input, double from, double to) {
3732-
XLA_CHECK_LE(from, to);
3733-
auto input_shape = input->shape();
3740+
absl::Status uniform_(XLATensorPtr& input, double from, double to) {
3741+
XLA_RETURN_IF_ERROR(CheckUniformRangeIsValid(from, to));
3742+
xla::Shape input_shape = input->shape();
37343743
input->SetInPlaceIrValue(torch_xla::MakeNode<Uniform>(
37353744
XLAGraphExecutor::Get()->GetIrValueForScalar(
3736-
from, input_shape.get().element_type(), input->GetDevice()),
3745+
from, input_shape.element_type(), input->GetDevice()),
37373746
XLAGraphExecutor::Get()->GetIrValueForScalar(
3738-
to, input_shape.get().element_type(), input->GetDevice()),
3747+
to, input_shape.element_type(), input->GetDevice()),
37393748
XLAGraphExecutor::Get()->GetRngSeed(input->GetDevice()), input_shape));
3749+
return absl::OkStatus();
37403750
}
37413751

37423752
XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim) {

torch_xla/csrc/tensor_methods.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,7 @@ std::tuple<XLATensorPtr, XLATensorPtr> triangular_solve(
989989
// removed.
990990
std::vector<XLATensorPtr> unbind(const XLATensorPtr& input, int64_t dim);
991991

992-
void uniform_(XLATensorPtr& input, double from, double to);
992+
absl::Status uniform_(XLATensorPtr& input, double from, double to);
993993

994994
// Insert a dimension of size one at the specified position.
995995
XLATensorPtr unsqueeze(const XLATensorPtr& input, int64_t dim);

0 commit comments

Comments
 (0)