Skip to content

Commit a511691

Browse files
authored
trace: improve error handling and error messages. (#9630)
This PR refactors the `trace` operation implementation by improving its error message, and returning a status type value. **Key Changes:** - Make `tensor_methods::trace` return `StatusOr<XLATensorPtr>` - Improve error messages and error handling - Renamed `CheckMMInputIsMatrix` to `CheckInputIsMatrix` - Added a new parameter for specifying the operation name, so as to build a better error message
1 parent 302c3f1 commit a511691

File tree

4 files changed

+35
-18
lines changed

4 files changed

+35
-18
lines changed

test/test_ops_error_message.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,19 @@ def test():
222222
expect="""roll(): expected `dims` [0] (size=1) to match the size of `shifts` [2, 2] (size=2)."""
223223
)
224224

225+
def test_trace_raises_error_on_non_matrix_input(self):
226+
device = torch_xla.device()
227+
a = torch.rand(2, 2, 2, device=device)
228+
229+
def test():
230+
torch.trace(a)
231+
232+
self.assertExpectedRaisesInline(
233+
exc_type=RuntimeError,
234+
callable=test,
235+
expect="""trace(): expected the input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor)."""
236+
)
237+
225238

226239
if __name__ == "__main__":
227240
unittest.main()

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3875,8 +3875,11 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::topk(
38753875

38763876
at::Tensor XLANativeFunctions::trace(const at::Tensor& self) {
38773877
TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
3878-
XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
3879-
return bridge::AtenFromXlaTensor(tensor_methods::trace(xla_self));
3878+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
3879+
bridge::GetXlaTensor(self));
3880+
XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
3881+
tensor_methods::trace(xla_self));
3882+
return bridge::AtenFromXlaTensor(std::move(output));
38803883
}
38813884

38823885
at::Tensor XLANativeFunctions::transpose_copy(const at::Tensor& self,

torch_xla/csrc/tensor_methods.cpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -511,13 +511,16 @@ absl::Status CheckGatherSizesAreCompatible(const XLATensorPtr& input,
511511
return absl::OkStatus();
512512
}
513513

514-
absl::Status CheckMMInputIsMatrix(const XLATensorPtr& mat,
515-
const std::string_view arg) {
516-
xla::Shape shape = mat->shape();
514+
absl::Status CheckInputIsMatrix(const XLATensorPtr& tensor,
515+
const std::string_view op,
516+
const std::string_view arg = "") {
517+
xla::Shape shape = tensor->shape();
517518
if (shape.dimensions().size() != 2) {
518-
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
519-
absl::StrCat("mm(): expected the ", arg, " input tensor ",
520-
shape.ToString(), " to be a matrix (i.e. a 2D tensor).")));
519+
const std::string arg_with_trailing_space =
520+
arg.empty() ? std::string("") : absl::StrCat(arg, " ");
521+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
522+
op, "(): expected the ", arg_with_trailing_space, "input tensor ",
523+
shape.ToString(), " to be a matrix (i.e. a 2D tensor).")));
521524
}
522525
return absl::OkStatus();
523526
}
@@ -2452,8 +2455,8 @@ XLATensorPtr mish(const XLATensorPtr& input) {
24522455

24532456
absl::StatusOr<XLATensorPtr> mm(const XLATensorPtr& input,
24542457
const XLATensorPtr& weight) {
2455-
XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(input, "first"));
2456-
XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(weight, "second"));
2458+
XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "mm", "first"));
2459+
XLA_RETURN_IF_ERROR(CheckInputIsMatrix(weight, "mm", "second"));
24572460
XLA_RETURN_IF_ERROR(CheckMMMatrixSizesAreCompatible(input, weight));
24582461
return input->CreateFrom(Dot(input->GetIrValue(), weight->GetIrValue()));
24592462
}
@@ -3648,13 +3651,11 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
36483651
return std::make_tuple(t1, t2);
36493652
}
36503653

3651-
XLATensorPtr trace(const XLATensorPtr& input) {
3652-
auto input_shape_ref = input->shape();
3653-
XLA_CHECK_EQ((*input_shape_ref).dimensions_size(), 2)
3654-
<< "invalid argument for trace: expected a matrix";
3655-
torch::lazy::NodePtr eye = Identity((*input_shape_ref).dimensions(0),
3656-
(*input_shape_ref).dimensions(1),
3657-
(*input_shape_ref).element_type());
3654+
absl::StatusOr<absl_nonnull XLATensorPtr> trace(const XLATensorPtr& input) {
3655+
XLA_RETURN_IF_ERROR(CheckInputIsMatrix(input, "trace"));
3656+
xla::Shape shape = input->shape();
3657+
torch::lazy::NodePtr eye =
3658+
Identity(shape.dimensions(0), shape.dimensions(1), shape.element_type());
36583659
return sum(input->CreateFrom(eye * input->GetIrValue()), {0, 1}, false,
36593660
input->dtype());
36603661
}

torch_xla/csrc/tensor_methods.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,7 @@ std::tuple<XLATensorPtr, XLATensorPtr> topk(const XLATensorPtr& input,
973973
bool stable);
974974

975975
// Returns the sum of the elements of the diagonal of the input 2-D matrix.
976-
XLATensorPtr trace(const XLATensorPtr& input);
976+
absl::StatusOr<absl_nonnull XLATensorPtr> trace(const XLATensorPtr& input);
977977

978978
// Swap given dimensions of the input.
979979
XLATensorPtr transpose(const XLATensorPtr& input, int64_t dim0, int64_t dim1);

0 commit comments

Comments
 (0)