@@ -442,6 +442,43 @@ absl::Status CheckFullConcreteSizesArePositive(at::SymIntArrayRef sym_sizes) {
442
442
});
443
443
}
444
444
445
+ absl::Status CheckGatherRanksAreEqual (const XLATensorPtr& input,
446
+ const XLATensorPtr& index) {
447
+ int64_t input_rank = input->shape ().get ().dimensions_size ();
448
+ int64_t index_rank = index->shape ().get ().dimensions_size ();
449
+ if (input_rank != index_rank) {
450
+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (absl::StrCat (
451
+ " gather(): expected rank of input (" , input_rank, " ) and index (" ,
452
+ index_rank, " ) tensors to be the same." )));
453
+ }
454
+ return absl::OkStatus ();
455
+ }
456
+
457
+ // Checks that all index dimensions are smaller or equal to those of input,
458
+ // except on dimension canonical_dim.
459
+ absl::Status CheckGatherDimensionsAreCompatible (const XLATensorPtr& input,
460
+ const XLATensorPtr& index,
461
+ int64_t canonical_dim) {
462
+ // Dimensions that fail the "smaller or equal" condition.
463
+ std::vector<int64_t > bad_dims;
464
+ for (int64_t dim = 0 ; dim < input->shape ().get ().dimensions_size (); dim++) {
465
+ if (dim != canonical_dim && input->size (dim) < index->size (dim)) {
466
+ bad_dims.push_back (dim);
467
+ }
468
+ }
469
+ if (!bad_dims.empty ()) {
470
+ return XLA_ERROR_WITH_LOCATION (absl::InvalidArgumentError (absl::StrCat (
471
+ " gather(): expected sizes of index [" ,
472
+ absl::StrJoin (index->shape ().get ().dimensions (), /* sep= */ " , " ),
473
+ " ] to be smaller or equal those of input [" ,
474
+ absl::StrJoin (input->shape ().get ().dimensions (), /* sep= */ " , " ),
475
+ " ] on all dimensions, except on dimension " , canonical_dim,
476
+ " . However, that's not true on dimensions [" ,
477
+ absl::StrJoin (bad_dims, /* sep= */ " , " ), " ]." )));
478
+ }
479
+ return absl::OkStatus ();
480
+ }
481
+
445
482
} // namespace
446
483
447
484
// ////////////////////////////////////////////////////////////////////////////
@@ -1838,18 +1875,14 @@ absl::StatusOr<absl_nonnull XLATensorPtr> full_symint(
1838
1875
device, scalar_type);
1839
1876
}
1840
1877
1841
- XLATensorPtr gather (const XLATensorPtr& input, int64_t dim,
1842
- const XLATensorPtr& index) {
1843
- xla::Shape input_shape = input->shape ();
1844
- xla::Shape index_shape = index->shape ();
1845
- XLA_CHECK_EQ (input_shape.dimensions_size (), index_shape.dimensions_size ());
1878
+ absl::StatusOr<absl_nonnull XLATensorPtr> gather (const XLATensorPtr& input,
1879
+ int64_t dim,
1880
+ const XLATensorPtr& index) {
1846
1881
int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex (
1847
- dim, input_shape.dimensions_size ());
1848
- for (size_t dim = 0 ; dim < input_shape.dimensions_size (); dim++) {
1849
- if (dim != canonical_dim) {
1850
- XLA_CHECK_LE (index->size (dim), input->size (dim));
1851
- }
1852
- }
1882
+ dim, input->shape ().get ().dimensions_size ());
1883
+ XLA_RETURN_IF_ERROR (CheckGatherRanksAreEqual (input, index));
1884
+ XLA_RETURN_IF_ERROR (
1885
+ CheckGatherDimensionsAreCompatible (input, index, canonical_dim));
1853
1886
return input->CreateFrom (torch_xla::MakeNode<Gather>(
1854
1887
input->GetIrValue (), canonical_dim, index->GetIrValue ()));
1855
1888
}
0 commit comments