
Commit 6d755ee

stack: improve error handling and error messages. (#9629)
This PR refactors the `stack` operation implementation by improving its error messages and returning a status type value.

**Key Changes:**

- `tensor_methods::stack` returns `StatusOr<XLATensorPtr>`
- Improve error messages and error handling
- Create `CheckStackAtLeastOneTensor` function
1 parent a66cfc3 commit 6d755ee
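
For illustration, here is a minimal, self-contained sketch of the error-handling pattern this change adopts, written against plain Abseil types rather than the torch_xla internals (the `FakeTensor` type and the function names below are hypothetical stand-ins; the real code uses `XLATensorPtr` and the `XLA_*` status macros shown in the diffs):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"

// Hypothetical stand-in for XLATensorPtr; just enough to show the pattern.
struct FakeTensor {
  int id = 0;
};

// Plays the role of CheckStackAtLeastOneTensor: validate, return a Status.
absl::Status CheckAtLeastOneTensor(absl::Span<const FakeTensor> tensors) {
  if (tensors.empty()) {
    return absl::InvalidArgumentError("stack(): expected at least one tensor.");
  }
  return absl::OkStatus();
}

// Mirrors the new tensor_methods::stack shape: StatusOr instead of a
// hard check that aborts on empty input.
absl::StatusOr<FakeTensor> Stack(absl::Span<const FakeTensor> tensors,
                                 int64_t dim) {
  if (absl::Status status = CheckAtLeastOneTensor(tensors); !status.ok()) {
    return status;  // Propagate the error to the caller.
  }
  (void)dim;
  return tensors[0];  // Placeholder for the real stacking logic.
}

int main() {
  std::vector<FakeTensor> empty;
  absl::StatusOr<FakeTensor> result = Stack(empty, /*dim=*/0);
  if (!result.ok()) {
    // Prints the user-facing message instead of crashing the process.
    std::cout << result.status().message() << "\n";
  }
  return 0;
}
```

The net effect of the change is that an empty input now surfaces as a recoverable `absl::InvalidArgumentError` with a readable message, instead of tripping the previous hard `XLA_CHECK_GT`.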

File tree

5 files changed: +40 -19 lines changed


torch_xla/csrc/aten_xla_type.cpp

Lines changed: 11 additions & 6 deletions
@@ -14,6 +14,7 @@
 #include <torch/csrc/lazy/core/tensor_util.h>
 #include <torch/csrc/lazy/core/util.h>
 
+#include <iterator>
 #include <mutex>
 #include <optional>
 
@@ -3696,12 +3697,16 @@ at::Tensor XLANativeFunctions::squeeze_copy(const at::Tensor& self,
 at::Tensor XLANativeFunctions::stack(at::TensorList tensors, int64_t dim) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   at::ScalarType result_type = at::native::result_type(tensors);
-  std::vector<at::Tensor> c_tensors(tensors.size());
-  std::transform(tensors.begin(), tensors.end(), c_tensors.begin(),
-                 [=](const at::Tensor& t) { return t.to(result_type); });
-  XLA_ASSIGN_OR_THROW(std::vector<absl_nonnull XLATensorPtr> xla_c_tensors,
-                      bridge::GetXlaTensors(c_tensors));
-  return bridge::AtenFromXlaTensor(tensor_methods::stack(xla_c_tensors, dim));
+  std::vector<absl_nonnull XLATensorPtr> xla_tensors;
+  std::transform(tensors.begin(), tensors.end(),
+                 std::back_inserter(xla_tensors), [=](const at::Tensor& t) {
+                   XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_t,
+                                       bridge::GetXlaTensor(t.to(result_type)));
+                   return xla_t;
+                 });
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::stack(xla_tensors, dim));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::std(const at::Tensor& self, bool unbiased) {

torch_xla/csrc/ops/index_ops.cpp

Lines changed: 5 additions & 5 deletions
@@ -339,8 +339,8 @@ XLATensorPtr IndexByTensors(const XLATensorPtr& base,
       canonical_indices.front()->shape().get().dimensions_size();
   // Stack the indices to allow the whole multi-indexing to be dispatched with a
   // single gather.
-  XLATensorPtr indices_nd =
-      tensor_methods::stack(canonical_indices, indices_rank);
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr indices_nd,
+                      tensor_methods::stack(canonical_indices, indices_rank));
   return XLATensor::Create(
       torch_xla::MakeNode<IndexGet>(base->GetIrValue(),
                                     indices_nd->GetIrValue(), start_dim),
@@ -356,11 +356,11 @@ torch::lazy::Value IndexPutByTensors(
   }
   auto canonical_indices = WrapIndicesOnce(base, indices, start_dim);
   int64_t indices_rank =
-      canonical_indices.front()->shape().get().dimensions_size();
+      canonical_indices.front()->shape().get().dimensions().size();
   // Stack the indices to allow the whole multi-indexing to be dispatched with a
   // single scatter.
-  XLATensorPtr indices_nd =
-      tensor_methods::stack(canonical_indices, indices_rank);
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr indices_nd,
+                      tensor_methods::stack(canonical_indices, indices_rank));
   return torch_xla::MakeNode<Permute>(
       torch_xla::MakeNode<IndexPut>(base->GetIrValue(),
                                     indices_nd->GetIrValue(), start_dim,

torch_xla/csrc/tensor_methods.cpp

Lines changed: 19 additions & 6 deletions
@@ -538,6 +538,15 @@ absl::Status CheckRollDimsAndShiftsAreCompatible(
   return absl::OkStatus();
 }
 
+absl::Status CheckStackAtLeastOneTensor(
+    absl::Span<const absl_nonnull XLATensorPtr> tensors) {
+  if (tensors.size() == 0) {
+    return XLA_ERROR_WITH_LOCATION(
+        absl::InvalidArgumentError("stack(): expected at least one tensor."));
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////
@@ -3422,14 +3431,18 @@ XLATensorPtr squeeze(const XLATensorPtr& input, std::vector<int64_t> dims) {
   return view(input, output_dimensions);
 }
 
-XLATensorPtr stack(absl::Span<const XLATensorPtr> tensors, int64_t dim) {
-  XLA_CHECK_GT(tensors.size(), 0);
+absl::StatusOr<absl_nonnull XLATensorPtr> stack(
+    absl::Span<const absl_nonnull XLATensorPtr> tensors, int64_t dim) {
+  XLA_RETURN_IF_ERROR(CheckStackAtLeastOneTensor(tensors));
+
   std::vector<torch::lazy::Value> values;
-  for (auto& tensor : tensors) {
-    values.push_back(tensor->GetIrValue());
-  }
+  std::transform(
+      tensors.begin(), tensors.end(), std::back_inserter(values),
+      [](const absl_nonnull XLATensorPtr t) { return t->GetIrValue(); });
+
   int64_t canonical_dim = torch::lazy::GetCanonicalDimensionIndex(
-      dim, tensors.front()->shape().get().dimensions_size() + 1);
+      dim, tensors.front()->shape().get().dimensions().size() + 1);
+
   return tensors[0]->CreateFrom(
       torch_xla::MakeNode<Stack>(values, canonical_dim));
 }
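
For reference, call sites consume the new `StatusOr` return in roughly the way sketched below. Only identifiers that appear in this commit's diffs are reused (`XLA_ASSIGN_OR_THROW`, `tensor_methods::stack`, `bridge::AtenFromXlaTensor`, `absl_nonnull`, `XLATensorPtr`); the wrapper function names are hypothetical and the required torch_xla includes are assumed:

```cpp
// Hypothetical wrappers sketching the two consumption styles in this PR.

// At an ATen boundary that must return a plain at::Tensor, a non-OK status
// is turned into an exception by XLA_ASSIGN_OR_THROW:
at::Tensor StackOrThrow(absl::Span<const absl_nonnull XLATensorPtr> tensors,
                        int64_t dim) {
  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
                      tensor_methods::stack(tensors, dim));
  return bridge::AtenFromXlaTensor(std::move(output));
}

// Code that itself returns a status simply forwards the StatusOr, letting
// its own caller decide whether to throw or keep propagating:
absl::StatusOr<absl_nonnull XLATensorPtr> StackOrPropagate(
    absl::Span<const absl_nonnull XLATensorPtr> tensors, int64_t dim) {
  return tensor_methods::stack(tensors, dim);
}
```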

torch_xla/csrc/tensor_methods.h

Lines changed: 2 additions & 1 deletion
@@ -926,7 +926,8 @@ XLATensorPtr squeeze(const XLATensorPtr& input, std::vector<int64_t> dims);
 void squeeze_(XLATensorPtr& input);
 void squeeze_(XLATensorPtr& input, int64_t dim);
 
-XLATensorPtr stack(absl::Span<const XLATensorPtr> tensors, int64_t dim);
+absl::StatusOr<absl_nonnull XLATensorPtr> stack(
+    absl::Span<const absl_nonnull XLATensorPtr> tensors, int64_t dim);
 
 XLATensorPtr std(const XLATensorPtr& input, std::vector<int64_t> dimensions,
                  bool keep_reduced_dimensions, double correction);

torch_xla/csrc/tensor_ops.cpp

Lines changed: 3 additions & 1 deletion
@@ -62,7 +62,9 @@ XLATensorPtr Cross(const XLATensorPtr& input, const XLATensorPtr& other,
   XLATensorPtr s3 = tensor_methods::sub(tensor_methods::mul(u1, v2),
                                         tensor_methods::mul(u2, v1), one);
   // Stack the terms into one result tensor.
-  return tensor_methods::stack({s1, s2, s3}, canonical_dim);
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::stack({s1, s2, s3}, canonical_dim));
+  return output;
 }
 
 XLATensorPtr MakeMatrixWithDiagonal(const XLATensorPtr& input,
