Skip to content

Commit 23158fd

Browse files
authored
div: improve error handling and error messages. (#9549)
1 parent 38e0f03 commit 23158fd

File tree

5 files changed

+32
-17
lines changed

5 files changed

+32
-17
lines changed

test/test_operations.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2486,6 +2486,17 @@ def test_cat_raises_error_on_incompatible_shapes(self):
24862486
"or that either of them was a 1D empty tensor of size (0,).")
24872487
self.assertEqual(str(e), expected_error)
24882488

2489+
def test_div_raises_error_on_invalid_rounding_mode(self):
2490+
a = torch.rand(2, 2, device=torch_xla.device())
2491+
2492+
try:
2493+
torch.div(a, 2, rounding_mode="bad")
2494+
except RuntimeError as e:
2495+
expected_error = (
2496+
"div(): invalid rounding mode `bad`. Expected it to be either "
2497+
"'trunc', 'floor', or be left unspecified.")
2498+
self.assertEqual(str(e), expected_error)
2499+
24892500

24902501
class MNISTComparator(nn.Module):
24912502

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1536,8 +1536,9 @@ at::Tensor XLANativeFunctions::div(
15361536
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
15371537
at::ScalarType dtype = at::result_type(self, other);
15381538
auto operands = GetBinaryOperands(self, UnwrapNumber(other, dtype));
1539-
return bridge::AtenFromXlaTensor(tensor_methods::div(
1539+
auto output = GetValueOrThrow(tensor_methods::div(
15401540
operands.first, operands.second, rounding_mode, dtype));
1541+
return bridge::AtenFromXlaTensor(std::move(output));
15411542
}
15421543

15431544
at::Tensor XLANativeFunctions::div(const at::Tensor& self,

torch_xla/csrc/tensor_methods.cpp

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1421,9 +1421,10 @@ XLATensorPtr diagonal(const XLATensorPtr& input, int64_t offset, int64_t dim1,
14211421
input->GetIrValue(), offset, canonical_dim1, canonical_dim2));
14221422
}
14231423

1424-
XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other,
1425-
const std::optional<std::string_view>& rounding_mode,
1426-
std::optional<at::ScalarType> logical_element_type) {
1424+
absl::StatusOr<absl_nonnull XLATensorPtr> div(
1425+
const XLATensorPtr& input, const XLATensorPtr& other,
1426+
const std::optional<std::string_view>& rounding_mode,
1427+
std::optional<at::ScalarType> logical_element_type) {
14271428
at::ScalarType scalar_type =
14281429
at::typeMetaToScalarType(c10::get_default_dtype());
14291430
xla::PrimitiveType input_type = input->shape().get().element_type();
@@ -1446,8 +1447,10 @@ XLATensorPtr div(const XLATensorPtr& input, const XLATensorPtr& other,
14461447
} else if (*rounding_mode == "floor") {
14471448
res = torch_xla::MakeNode<Floor>(res);
14481449
} else {
1449-
XLA_CHECK(false)
1450-
<< "rounding_mode must be one of None, 'trunc', or 'floor'";
1450+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
1451+
absl::StrCat("div(): invalid rounding mode `", *rounding_mode,
1452+
"`. Expected it to be either 'trunc', 'floor', or be "
1453+
"left unspecified.")));
14511454
}
14521455
}
14531456

torch_xla/csrc/tensor_methods.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -389,7 +389,7 @@ XLATensorPtr diag(const XLATensorPtr& input, int64_t offset);
389389
XLATensorPtr diagonal(const XLATensorPtr& input, int64_t offset, int64_t dim1,
390390
int64_t dim2);
391391

392-
XLATensorPtr div(
392+
absl::StatusOr<absl_nonnull XLATensorPtr> div(
393393
const XLATensorPtr& input, const XLATensorPtr& other,
394394
const std::optional<std::string_view>& rounding_mode = std::nullopt,
395395
std::optional<at::ScalarType> logical_element_type = std::nullopt);

torch_xla/csrc/tensor_ops.cpp

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -148,9 +148,9 @@ XLATensorPtr SmoothL1LossBackward(const XLATensorPtr& grad_output,
148148
XLATensorPtr grad_scale = tensor_methods::get_dimensions_size(
149149
broadcasted_input,
150150
XlaHelpers::GetAllDimensions(broadcasted_input->shape()));
151-
return tensor_methods::mul(
152-
tensor_methods::div(elementwise_loss_backward, grad_scale),
153-
grad_output);
151+
XLATensorPtr div_result = GetValueOrThrow(
152+
tensor_methods::div(elementwise_loss_backward, grad_scale));
153+
return tensor_methods::mul(div_result, grad_output);
154154
}
155155
default:
156156
XLA_ERROR() << "Invalid reduction type: "
@@ -174,12 +174,12 @@ XLATensorPtr SoftplusBackward(const XLATensorPtr& grad_output,
174174
XLATensorPtr z = tensor_methods::exp(scaled_input);
175175
XLATensorPtr one_vec =
176176
tensor_methods::full_like(z, 1, z->GetDevice(), z->dtype());
177+
XLATensorPtr div = GetValueOrThrow(
178+
tensor_methods::div(z, tensor_methods::add(z, one_vec, 1)));
177179

178-
return tensor_methods::where(
179-
tensor_methods::gt(scaled_input, threshold), grad_output,
180-
tensor_methods::mul(
181-
grad_output,
182-
tensor_methods::div(z, tensor_methods::add(z, one_vec, 1))));
180+
return tensor_methods::where(tensor_methods::gt(scaled_input, threshold),
181+
grad_output,
182+
tensor_methods::mul(grad_output, div));
183183
}
184184

185185
XLATensorPtr Select(const XLATensorPtr& input, int64_t dim, int64_t index) {
@@ -223,8 +223,8 @@ XLATensorPtr EmbeddingDenseBackward(const XLATensorPtr& grad_output,
223223
XLATensorPtr grad_weights_scale =
224224
tensor_methods::index(counts, {indices_rank1}, 0);
225225
// Scale the value of the gradient by the histogram.
226-
grad = tensor_methods::div(
227-
grad, tensor_methods::unsqueeze(grad_weights_scale, 1));
226+
grad = GetValueOrThrow(tensor_methods::div(
227+
grad, tensor_methods::unsqueeze(grad_weights_scale, 1)));
228228
}
229229
// Don't accumulate gradients for indices which are equal with the given
230230
// padding_idx.

0 commit comments

Comments (0)