#include <iterator>

#include "absl/log/absl_check.h"
+#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "torch_xla/csrc/LazyIr.h"
@@ -453,14 +454,14 @@ absl::Status CheckGatherRanksAreEqual(const XLATensorPtr& input,
  return absl::OkStatus();
}

-// Checks that all index dimensions are smaller or equal to those of input,
-// except on dimension canonical_dim.
-absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input,
-                                                const XLATensorPtr& index,
-                                                int64_t canonical_dim) {
+// Checks that all index dimension sizes are smaller or equal to those of
+// input, except on dimension canonical_dim.
+absl::Status CheckGatherSizesAreCompatible(const XLATensorPtr& input,
+                                           const XLATensorPtr& index,
+                                           int64_t canonical_dim) {
  // Dimensions that fail the "smaller or equal" condition.
  std::vector<int64_t> bad_dims;
-  for (int64_t dim = 0; dim < input->shape().get().dimensions_size(); dim++) {
+  for (int64_t dim = 0; dim < input->shape().get().dimensions().size(); dim++) {
    if (dim != canonical_dim && input->size(dim) < index->size(dim)) {
      bad_dims.push_back(dim);
    }
@@ -478,6 +479,33 @@ absl::Status CheckGatherDimensionsAreCompatible(const XLATensorPtr& input,
  return absl::OkStatus();
}

+absl::Status CheckMMInputIsMatrix(const XLATensorPtr& mat,
+                                  const std::string_view arg) {
+  xla::Shape shape = mat->shape();
+  if (shape.dimensions().size() != 2) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat("mm(): expected the ", arg, " input tensor ",
+                     shape.ToString(), " to be a matrix (i.e. a 2D tensor).")));
+  }
+  return absl::OkStatus();
+}
+
+absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1,
+                                             const XLATensorPtr& mat2) {
+  xla::Shape shape1 = mat1->shape();
+  xla::Shape shape2 = mat2->shape();
+  if (shape1.dimensions(1) != shape2.dimensions(0)) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "mm(): cannot matrix-multiply tensors ", shape1.ToString(), " and ",
+        shape2.ToString(),
+        ". Expected the size of dimension 1 of the first input tensor (",
+        shape1.dimensions(1),
+        ") to be equal to the size of dimension 0 of the second input tensor (",
+        shape2.dimensions(0), ").")));
+  }
+  return absl::OkStatus();
+}
+
}  // namespace

////////////////////////////////////////////////////////////////////////////////
@@ -1844,7 +1872,7 @@ absl::StatusOr<absl_nonnull XLATensorPtr> gather(const XLATensorPtr& input,
      dim, input->shape().get().dimensions_size());
  XLA_RETURN_IF_ERROR(CheckGatherRanksAreEqual(input, index));
  XLA_RETURN_IF_ERROR(
-      CheckGatherDimensionsAreCompatible(input, index, canonical_dim));
+      CheckGatherSizesAreCompatible(input, index, canonical_dim));
  return input->CreateFrom(torch_xla::MakeNode<Gather>(
      input->GetIrValue(), canonical_dim, index->GetIrValue()));
}
@@ -2349,7 +2377,11 @@ XLATensorPtr mish(const XLATensorPtr& input) {
      tensor_ops::Softplus(input, 1, 20)->GetIrValue()));
}

-XLATensorPtr mm(const XLATensorPtr& input, const XLATensorPtr& weight) {
+absl::StatusOr<XLATensorPtr> mm(const XLATensorPtr& input,
+                                const XLATensorPtr& weight) {
+  XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(input, "first"));
+  XLA_RETURN_IF_ERROR(CheckMMInputIsMatrix(weight, "second"));
+  XLA_RETURN_IF_ERROR(CheckMMMatrixSizesAreCompatible(input, weight));
  return input->CreateFrom(Dot(input->GetIrValue(), weight->GetIrValue()));
}
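
A minimal sketch (not part of this commit) of how a call site could consume the new absl::StatusOr<XLATensorPtr> returned by mm(); the wrapper function below and its input/weight arguments are hypothetical and only illustrate the error-propagation pattern:

// Hypothetical caller; assumes `input` and `weight` are XLATensorPtr values
// already in scope and that mm() is visible from this translation unit.
absl::StatusOr<XLATensorPtr> MatMulOrError(const XLATensorPtr& input,
                                           const XLATensorPtr& weight) {
  absl::StatusOr<XLATensorPtr> product = mm(input, weight);
  if (!product.ok()) {
    // Propagate the InvalidArgumentError produced by the new mm() checks.
    return product.status();
  }
  return product;
}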