
Commit 92dcabc

mm: improve error handling and error messages. (#9621)
This PR refactors the `mm` operation implementation by improving its error messages and returning a status-type value.

**Key Changes:**
- Make `tensor_methods::mm` return `Status`
- Refactor the `XLANativeFunctions::mm` overloads to handle the status values
- Improve error messages and error handling
1 parent 8274f94 commit 92dcabc
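
To make the described pattern concrete before diving into the diffs, here is a minimal standalone sketch written against plain Abseil rather than the torch_xla codebase: validation helpers return `absl::Status`, and the operation itself returns `absl::StatusOr<T>` so the caller decides how to surface failures. `Shape`, `CheckIsMatrix`, and `MmOutputShape` are illustrative stand-ins, not identifiers from this PR.

```cpp
// Sketch only: mirrors the commit's approach of returning a status instead of
// failing hard, using plain Abseil types (no torch_xla helpers).
#include <cstdint>
#include <string_view>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"

using Shape = std::vector<int64_t>;  // stand-in for an xla::Shape's dimensions

// Each precondition gets its own helper that returns absl::Status.
absl::Status CheckIsMatrix(const Shape& shape, std::string_view arg) {
  if (shape.size() != 2) {
    return absl::InvalidArgumentError(absl::StrCat(
        "mm(): expected the ", arg, " input tensor to be a matrix."));
  }
  return absl::OkStatus();
}

// The operation returns absl::StatusOr<T>; errors propagate to the caller,
// which decides whether to throw, log, or recover.
absl::StatusOr<Shape> MmOutputShape(const Shape& lhs, const Shape& rhs) {
  if (absl::Status s = CheckIsMatrix(lhs, "first"); !s.ok()) return s;
  if (absl::Status s = CheckIsMatrix(rhs, "second"); !s.ok()) return s;
  if (lhs[1] != rhs[0]) {
    return absl::InvalidArgumentError(
        absl::StrCat("mm(): cannot matrix-multiply tensors with inner sizes ",
                     lhs[1], " and ", rhs[0], "."));
  }
  return Shape{lhs[0], rhs[1]};
}
```

In this sketch, `MmOutputShape({2, 5}, {8, 2})` returns an `absl::InvalidArgumentError` carrying a descriptive message, which is the kind of behaviour the new tests below assert for `torch.mm` on an XLA device.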

4 files changed (+73, -10 lines)


test/test_operations.py

Lines changed: 28 additions & 0 deletions
@@ -2503,6 +2503,34 @@ def test_random__raises_error_on_value_out_of_type_value_range(self):
           "than the upper bound.")
       self.assertEqual(str(e), expected_error)

+  def test_mm_raises_error_on_non_matrix_input(self):
+    device = torch_xla.device()
+    a = torch.rand(2, 2, 2, device=device)
+    b = torch.rand(2, 2, device=device)
+
+    try:
+      torch.mm(a, b)
+    except RuntimeError as e:
+      expected_error = (
+          "mm(): expected the first input tensor f32[2,2,2] to be a "
+          "matrix (i.e. a 2D tensor).")
+      self.assertEqual(str(e), expected_error)
+
+  def test_mm_raises_error_on_incompatible_shapes(self):
+    device = torch_xla.device()
+    a = torch.rand(2, 5, device=device)
+    b = torch.rand(8, 2, device=device)
+
+    try:
+      torch.mm(a, b)
+    except RuntimeError as e:
+      expected_error = (
+          "mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. "
+          "Expected the size of dimension 1 of the first input tensor (5) "
+          "to be equal the size of dimension 0 of the second input "
+          "tensor (8).")
+      self.assertEqual(str(e), expected_error)
+

 class MNISTComparator(nn.Module):

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 3 additions & 1 deletion
@@ -2495,7 +2495,9 @@ at::Tensor XLANativeFunctions::mm(const at::Tensor& self,
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
   XLA_ASSIGN_OR_THROW(XLATensorPtr xla_mat2, bridge::GetXlaTensor(mat2));
-  return bridge::AtenFromXlaTensor(tensor_methods::mm(xla_self, xla_mat2));
+  XLA_ASSIGN_OR_THROW(XLATensorPtr output,
+                      tensor_methods::mm(xla_self, xla_mat2));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }

 at::Tensor XLANativeFunctions::mse_loss(const at::Tensor& self,
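
The wrapper above unwraps the new `absl::StatusOr<XLATensorPtr>` with torch_xla's `XLA_ASSIGN_OR_THROW` macro, whose definition is not part of this diff. As a rough approximation of what that boundary does, a hand-written equivalent might look like the sketch below; `ValueOrThrow` is a hypothetical name, and the mapping to Python's `RuntimeError` is inferred from the tests above rather than shown in this commit.

```cpp
// Hedged sketch of the "unwrap or throw" boundary; not the actual macro.
#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/statusor.h"

template <typename T>
T ValueOrThrow(absl::StatusOr<T> result) {
  if (!result.ok()) {
    // The status message becomes the exception text; the tests above expect
    // it to surface in Python as a RuntimeError.
    throw std::runtime_error(std::string(result.status().message()));
  }
  return *std::move(result);
}
```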

torch_xla/csrc/tensor_methods.cpp

Lines changed: 40 additions & 8 deletions
@@ -11,6 +11,7 @@
 #include <iterator>

 #include "absl/log/absl_check.h"
+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_split.h"
 #include "torch_xla/csrc/LazyIr.h"

@@ -453,14 +454,14 @@ absl::Status CheckGatherRanksAreEqual(const XLATensorPtr& input,
   return absl::OkStatus();
 }

-// Checks that all index dimensions are smaller or equal to those of input,
-// except on dimension canonical_dim.
-absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input,
-                                                const XLATensorPtr& index,
-                                                int64_t canonical_dim) {
+// Checks that all index dimension sizes are smaller or equal to those of
+// input, except on dimension canonical_dim.
+absl::Status CheckGatherSizesAreCompatible(const XLATensorPtr& input,
+                                           const XLATensorPtr& index,
+                                           int64_t canonical_dim) {
   // Dimensions that fail the "smaller or equal" condition.
   std::vector<int64_t> bad_dims;
-  for (int64_t dim = 0; dim < input->shape().get().dimensions_size(); dim++) {
+  for (int64_t dim = 0; dim < input->shape().get().dimensions().size(); dim++) {
     if (dim != canonical_dim && input->size(dim) < index->size(dim)) {
       bad_dims.push_back(dim);
     }

@@ -478,6 +479,33 @@ absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input,
   return absl::OkStatus();
 }

+absl::Status CheckMMInputIsMatrix(const XLATensorPtr& mat,
+                                  const std::string_view arg) {
+  xla::Shape shape = mat->shape();
+  if (shape.dimensions().size() != 2) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat("mm(): expected the ", arg, " input tensor ",
+                     shape.ToString(), " to be a matrix (i.e. a 2D tensor).")));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1,
+                                             const XLATensorPtr& mat2) {
+  xla::Shape shape1 = mat1->shape();
+  xla::Shape shape2 = mat2->shape();
+  if (shape1.dimensions(1) != shape2.dimensions(0)) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "mm(): cannot matrix-multiply tensors ", shape1.ToString(), " and ",
+        shape2.ToString(),
+        ". Expected the size of dimension 1 of the first input tensor (",
+        shape1.dimensions(1),
+        ") to be equal the size of dimension 0 of the second input tensor (",
+        shape2.dimensions(0), ").")));
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace

 //////////////////////////////////////////////////////////////////////////////

@@ -1844,7 +1872,7 @@ absl::StatusOr<absl_nonnull XLATensorPtr> gather(const XLATensorPtr& input,
       dim, input->shape().get().dimensions_size());
   XLA_RETURN_IF_ERROR(CheckGatherRanksAreEqual(input, index));
   XLA_RETURN_IF_ERROR(
-      CheckGatherDimensionsAreCompatible(input, index, canonical_dim));
+      CheckGatherSizesAreCompatible(input, index, canonical_dim));
   return input->CreateFrom(torch_xla::MakeNode<Gather>(
       input->GetIrValue(), canonical_dim, index->GetIrValue()));
 }

@@ -2349,7 +2377,11 @@ XLATensorPtr mish(const XLATensorPtr& input) {
       tensor_ops::Softplus(input, 1, 20)->GetIrValue()));
 }

-XLATensorPtr mm(const XLATensorPtr& input, const XLATensorPtr& weight) {
+absl::StatusOr<XLATensorPtr> mm(const XLATensorPtr& input,
+                                const XLATensorPtr& weight) {
+  XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(input, "first"));
+  XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(weight, "second"));
+  XLA_RETURN_IF_ERROR(CheckMMMatrixSizesAreCompatible(input, weight));
   return input->CreateFrom(Dot(input->GetIrValue(), weight->GetIrValue()));
 }
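
Within `tensor_methods.cpp`, call sites such as `gather` and the new `mm` propagate failed checks upward with `XLA_RETURN_IF_ERROR` rather than throwing; that macro is also defined elsewhere in torch_xla and not shown in this diff. A hand-written sketch of the same early-return pattern, using a hypothetical `CheckPrecondition` stand-in for helpers like `CheckMMInputIsMatrix`, could look like this:

```cpp
// Sketch of status propagation without exceptions; the real code uses the
// torch_xla XLA_RETURN_IF_ERROR macro instead of the explicit early returns.
#include "absl/status/status.h"
#include "absl/status/statusor.h"

namespace {

// Hypothetical stand-in for checks like CheckMMInputIsMatrix.
absl::Status CheckPrecondition(bool ok) {
  return ok ? absl::OkStatus()
            : absl::InvalidArgumentError("precondition failed");
}

}  // namespace

// A non-throwing internal operation: any failed check short-circuits, and the
// error travels up to the ATen wrapper, which converts it into an exception.
absl::StatusOr<int> OperationWithChecks(bool input_ok, bool sizes_ok) {
  if (absl::Status s = CheckPrecondition(input_ok); !s.ok()) return s;
  if (absl::Status s = CheckPrecondition(sizes_ok); !s.ok()) return s;
  return 42;  // placeholder for building and returning the real result
}
```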

torch_xla/csrc/tensor_methods.h

Lines changed: 2 additions & 1 deletion
@@ -646,7 +646,8 @@ void min_out(XLATensorPtr& min, XLATensorPtr& min_indices,

 XLATensorPtr mish(const XLATensorPtr& input);

-XLATensorPtr mm(const XLATensorPtr& input, const XLATensorPtr& weight);
+absl::StatusOr<XLATensorPtr> mm(const XLATensorPtr& input,
+                                const XLATensorPtr& weight);

 XLATensorPtr mse_loss(const XLATensorPtr& input, const XLATensorPtr& target,
                       int64_t reduction);
