Skip to content

Commit cd3bd91

Browse files
authored
Error Handling: refactor GetXlaTensor and related functions to use status types. (#9510)
1 parent b0ffc49 commit cd3bd91

14 files changed

+848
-652
lines changed

test/cpp/cpp_test_util.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,17 +246,17 @@ void WithAllDevices(
246246
}
247247

248248
std::string GetTensorTextGraph(at::Tensor tensor) {
249-
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
249+
XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
250250
return DumpUtil::ToText({xtensor->GetIrValue().node.get()});
251251
}
252252

253253
std::string GetTensorDotGraph(at::Tensor tensor) {
254-
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
254+
XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
255255
return DumpUtil::ToDot({xtensor->GetIrValue().node.get()});
256256
}
257257

258258
std::string GetTensorHloGraph(at::Tensor tensor) {
259-
XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
259+
XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
260260
return DumpUtil::ToHlo({xtensor->GetIrValue()}, xtensor->GetDevice());
261261
}
262262

test/cpp/test_aten_xla_tensor_1.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ TEST_F(AtenXlaTensorTest, TestStorage) {
2727
torch::Tensor a = torch::tensor({0.0});
2828
ForEachDevice([&](const torch::Device& device) {
2929
torch::Tensor xla_a = CopyToDevice(a, device);
30-
XLATensorPtr xla_tensor_a = bridge::GetXlaTensor(xla_a);
30+
XLATensorPtr xla_tensor_a = GetValueOrThrow(bridge::GetXlaTensor(xla_a));
3131
EXPECT_EQ(xla_a.device(), xla_tensor_a->Storage().device());
3232
AllClose(a, xla_a);
3333
});

torch_xla/csrc/aten_autograd_ops.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "torch_xla/csrc/aten_fallback.h"
99
#include "torch_xla/csrc/aten_xla_bridge.h"
1010
#include "torch_xla/csrc/helpers.h"
11+
#include "torch_xla/csrc/status.h"
1112
#include "torch_xla/csrc/tensor_methods.h"
1213
#include "torch_xla/csrc/torch_util.h"
1314

@@ -33,7 +34,8 @@ torch::Tensor EinsumAutogradFunction::forward(
3334
}
3435
ctx->save_for_backward(vars);
3536

36-
std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(tensors);
37+
std::vector<XLATensorPtr> xla_tensors =
38+
GetValueOrThrow(bridge::GetXlaTensors(tensors));
3739
XLATensorPtr output = tensor_methods::einsum(eq_str, xla_tensors);
3840
return bridge::AtenFromXlaTensor(output);
3941
}
@@ -43,11 +45,13 @@ torch::autograd::variable_list EinsumAutogradFunction::backward(
4345
torch::autograd::variable_list grad_output) {
4446
std::string equation = ctx->saved_data["equation"].toString()->string();
4547
torch::autograd::variable_list tensors = ctx->get_saved_variables();
46-
std::vector<XLATensorPtr> xla_tensors = bridge::GetXlaTensors(tensors);
48+
std::vector<XLATensorPtr> xla_tensors =
49+
GetValueOrThrow(bridge::GetXlaTensors(tensors));
4750

4851
std::tuple<XLATensorPtr, XLATensorPtr> outputs =
49-
tensor_methods::einsum_backward(bridge::GetXlaTensor(grad_output[0]),
50-
xla_tensors, equation);
52+
tensor_methods::einsum_backward(
53+
GetValueOrThrow(bridge::GetXlaTensor(grad_output[0])), xla_tensors,
54+
equation);
5155

5256
// For both einsum and max pool, we use "undef" as a placeholder for the
5357
// non-tensor grad inputs, in this case the equation string.
@@ -190,7 +194,7 @@ torch::Tensor MaxPool3dAutogradFunction::forward(
190194
}
191195
ctx->save_for_backward({self});
192196
auto outputs = tensor_methods::max_pool_nd(
193-
bridge::GetXlaTensor(self), /*spatial_dim_count=*/3,
197+
GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/3,
194198
XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
195199
XlaHelpers::I64List(padding), ceil_mode);
196200
return bridge::AtenFromXlaTensor(std::get<0>(outputs));
@@ -218,7 +222,8 @@ torch::autograd::variable_list MaxPool3dAutogradFunction::backward(
218222
ceil_mode, indices);
219223
}
220224
grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
221-
bridge::GetXlaTensor(grad_output[0]), bridge::GetXlaTensor(self),
225+
GetValueOrThrow(bridge::GetXlaTensor(grad_output[0])),
226+
GetValueOrThrow(bridge::GetXlaTensor(self)),
222227
/*spatial_dim_count=*/3, XlaHelpers::I64List(kernel_size),
223228
XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode));
224229

@@ -234,7 +239,7 @@ torch::Tensor max_pool2d_forward(torch::Tensor self,
234239
torch::IntArrayRef padding,
235240
torch::IntArrayRef dilation, bool ceil_mode) {
236241
auto outputs = tensor_methods::max_pool_nd(
237-
bridge::GetXlaTensor(self), /*spatial_dim_count=*/2,
242+
GetValueOrThrow(bridge::GetXlaTensor(self)), /*spatial_dim_count=*/2,
238243
XlaHelpers::I64List(kernel_size), XlaHelpers::I64List(stride),
239244
XlaHelpers::I64List(padding), ceil_mode);
240245
return bridge::AtenFromXlaTensor(std::get<0>(outputs));
@@ -245,7 +250,8 @@ torch::Tensor max_pool2d_backward(torch::Tensor grad_output, torch::Tensor self,
245250
torch::IntArrayRef stride,
246251
torch::IntArrayRef padding, bool ceil_mode) {
247252
auto grad = bridge::AtenFromXlaTensor(tensor_methods::max_pool_nd_backward(
248-
bridge::GetXlaTensor(grad_output), bridge::GetXlaTensor(self),
253+
GetValueOrThrow(bridge::GetXlaTensor(grad_output)),
254+
GetValueOrThrow(bridge::GetXlaTensor(self)),
249255
/*spatial_dim_count=*/2, XlaHelpers::I64List(kernel_size),
250256
XlaHelpers::I64List(stride), XlaHelpers::I64List(padding), ceil_mode));
251257
return grad;

torch_xla/csrc/aten_fallback.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ static bool validate_tensor_list(const c10::List<at::Tensor>& tensorlist) {
137137

138138
// Retrieve the inner XLATensorPtr, and check it lives inside CUDA.
139139
static XLATensorPtr get_xla_cuda_tensor(const at::Tensor& tensor) {
140-
XLATensorPtr xla_tensor = bridge::GetXlaTensor(tensor);
140+
XLATensorPtr xla_tensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
141141
const torch::lazy::BackendDevice& device = xla_tensor->GetDevice();
142142
TORCH_CHECK(device.type() == static_cast<int8_t>(XlaDeviceType::CUDA),
143143
"OpenXLA CUDA fallback only supports XLA:CUDA tensors. Found a "

torch_xla/csrc/aten_xla_bridge.cpp

Lines changed: 64 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@
77
#include <string>
88
#include <vector>
99

10+
#include "absl/status/status.h"
1011
#include "absl/strings/str_cat.h"
1112
#include "torch_xla/csrc/device.h"
1213
#include "torch_xla/csrc/runtime/debug_macros.h"
1314
#include "torch_xla/csrc/runtime/runtime.h"
15+
#include "torch_xla/csrc/status.h"
1416
#include "torch_xla/csrc/tensor_impl.h"
1517
#include "torch_xla/csrc/torch_util.h"
1618
#include "torch_xla/csrc/xla_graph_executor.h"
@@ -72,72 +74,68 @@ AtenXlaDeviceMapper* AtenXlaDeviceMapper::Get() {
7274
return device_mapper;
7375
}
7476

75-
XLATensorImpl* GetXlaTensorImpl(const at::Tensor& tensor) {
77+
static absl::StatusOr<XLATensorImpl * absl_nonnull> GetXlaTensorImpl(
78+
const at::Tensor& tensor) {
7679
auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor);
77-
return dynamic_cast<XLATensorImpl*>(inner_tensor.unsafeGetTensorImpl());
80+
XLATensorImpl* impl =
81+
dynamic_cast<XLATensorImpl*>(inner_tensor.unsafeGetTensorImpl());
82+
if (impl == nullptr) {
83+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
84+
"Input tensor is not an XLA tensor: ", tensor.toString())));
85+
}
86+
return impl;
7887
}
7988

8089
} // namespace
8190

8291
XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) {
92+
return GetXlaTensor(tensor).value_or(XLATensorPtr{});
93+
}
94+
95+
absl::StatusOr<absl_nonnull XLATensorPtr> GetXlaTensor(
96+
const at::Tensor& tensor) {
8397
if (tensor.defined() &&
8498
at::functionalization::impl::isFunctionalTensor(tensor)) {
8599
// To make sure we have the most updated version of tensor.
86100
at::functionalization::impl::sync(tensor);
87101
}
88-
XLATensorImpl* impl = GetXlaTensorImpl(tensor);
89-
if (impl == nullptr) {
90-
return XLATensorPtr();
91-
}
102+
XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor));
92103
return impl->tensor();
93104
}
94105

95-
std::vector<XLATensorPtr> TryGetXlaTensors(const at::ITensorListRef& tensors) {
96-
std::vector<XLATensorPtr> xla_tensors;
106+
absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> GetXlaTensors(
107+
const at::ITensorListRef& tensors) {
108+
std::vector<absl_nonnull XLATensorPtr> xla_tensors;
97109
xla_tensors.reserve(tensors.size());
98110
for (const auto& tensor : tensors) {
99-
xla_tensors.push_back(bridge::TryGetXlaTensor(tensor));
111+
XLA_ASSIGN_OR_RETURN(XLATensorPtr ptr, bridge::GetXlaTensor(tensor));
112+
xla_tensors.push_back(std::move(ptr));
100113
}
101114
return xla_tensors;
102115
}
103116

104117
bool IsXlaTensor(const at::Tensor& tensor) {
105-
return GetXlaTensorImpl(tensor) != nullptr;
106-
}
107-
108-
XLATensorPtr GetXlaTensor(const at::Tensor& tensor) {
109-
auto xtensor = TryGetXlaTensor(tensor);
110-
XLA_CHECK(xtensor) << "Input tensor is not an XLA tensor: "
111-
<< tensor.toString();
112-
return xtensor;
118+
return GetXlaTensorImpl(tensor).ok();
113119
}
114120

115-
void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor) {
116-
auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor);
117-
XLATensorImpl* impl =
118-
dynamic_cast<XLATensorImpl*>(inner_tensor.unsafeGetTensorImpl());
119-
XLA_CHECK(impl != nullptr)
120-
<< "Input tensor is not an XLA tensor: " << inner_tensor.toString();
121+
absl::Status ReplaceXlaTensor(const at::Tensor& tensor,
122+
XLATensorPtr new_xla_tensor) {
123+
XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor));
121124
impl->set_tensor(std::move(new_xla_tensor));
125+
return absl::OkStatus();
122126
}
123127

124-
void ReplaceXlaTensor(const std::vector<at::Tensor>& tensors,
125-
const std::vector<XLATensorPtr> new_xla_tensors) {
126-
XLA_CHECK(tensors.size() == new_xla_tensors.size())
127-
<< "The size of tensors and new_xla_tensors are not equal: "
128-
<< tensors.size() << " vs. " << new_xla_tensors.size();
129-
for (size_t i = 0; i < tensors.size(); ++i) {
130-
ReplaceXlaTensor(tensors[i], new_xla_tensors[i]);
128+
absl::Status ReplaceXlaTensor(const std::vector<at::Tensor>& tensors,
129+
const std::vector<XLATensorPtr> new_xla_tensors) {
130+
if (tensors.size() != new_xla_tensors.size()) {
131+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
132+
absl::StrCat("The size of tensors and new_xla_tensors are not equal: ",
133+
tensors.size(), " vs. ", new_xla_tensors.size())));
131134
}
132-
}
133-
134-
std::vector<XLATensorPtr> GetXlaTensors(const at::ITensorListRef& tensors) {
135-
std::vector<XLATensorPtr> xla_tensors;
136-
xla_tensors.reserve(tensors.size());
137-
for (const auto& tensor : tensors) {
138-
xla_tensors.push_back(bridge::GetXlaTensor(tensor));
135+
for (size_t i = 0; i < tensors.size(); ++i) {
136+
XLA_RETURN_IF_ERROR(ReplaceXlaTensor(tensors[i], new_xla_tensors[i]));
139137
}
140-
return xla_tensors;
138+
return absl::OkStatus();
141139
}
142140

143141
torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
@@ -146,7 +144,7 @@ torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
146144
(tensor.dim() == 0 && tensor.numel() == 1)) {
147145
return torch_xla::bridge::GetOrCreateXlaTensor(tensor, device);
148146
} else {
149-
return torch_xla::bridge::GetXlaTensor(tensor);
147+
return GetValueOrThrow(torch_xla::bridge::GetXlaTensor(tensor));
150148
}
151149
}
152150

@@ -155,22 +153,23 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor,
155153
if (!tensor.defined()) {
156154
return XLATensorPtr();
157155
}
156+
158157
auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor);
159158
if (!inner_tensor.defined()) {
160159
return XLATensorPtr();
161160
}
162-
auto xtensor = TryGetXlaTensor(tensor);
163-
return xtensor ? xtensor : XLATensor::Create(inner_tensor, device);
161+
162+
auto xtensor = GetXlaTensor(tensor);
163+
return xtensor.ok() ? xtensor.value()
164+
: XLATensor::Create(inner_tensor, device);
164165
}
165166

166167
XLATensorPtr GetOrCreateXlaTensor(const std::optional<at::Tensor>& tensor,
167168
const torch::lazy::BackendDevice& device) {
168-
if (!IsDefined(tensor)) {
169+
if (!tensor.has_value()) {
169170
return XLATensorPtr();
170171
}
171-
auto xtensor = TryGetXlaTensor(*tensor);
172-
auto inner_tensor = torch::lazy::maybe_unwrap_functional(*tensor);
173-
return xtensor ? xtensor : XLATensor::Create(inner_tensor, device);
172+
return GetOrCreateXlaTensor(*tensor, device);
174173
}
175174

176175
std::vector<XLATensorPtr> GetOrCreateXlaTensors(
@@ -199,10 +198,10 @@ std::vector<at::Tensor> XlaCreateTensorList(const at::ITensorListRef& tensors) {
199198
continue;
200199
}
201200

202-
auto xtensor = TryGetXlaTensor(tensor);
203-
if (xtensor) {
201+
auto xtensor_status = GetXlaTensor(tensor);
202+
if (xtensor_status.ok()) {
204203
to_translate[ix] = true;
205-
xla_tensors.push_back(xtensor);
204+
xla_tensors.push_back(xtensor_status.value());
206205
} else {
207206
aten_xla_tensors[ix] = tensor;
208207
}
@@ -253,13 +252,14 @@ void XlaUpdateTensors(absl::Span<const at::Tensor> dest_xla_tensors,
253252
for (auto index : indices) {
254253
at::Tensor dest = dest_xla_tensors.at(index);
255254
at::Tensor source = source_cpu_tensors.at(index);
256-
XLATensorImpl* dest_impl = GetXlaTensorImpl(dest);
257-
if (dest_impl != nullptr) {
258-
auto xla_source = TryGetXlaTensor(source);
259-
if (!xla_source) {
260-
dest_impl->tensor()->UpdateFromTensorOut(source);
255+
auto dest_impl_status = GetXlaTensorImpl(dest);
256+
if (dest_impl_status.ok()) {
257+
auto dest_impl = std::move(dest_impl_status).value();
258+
auto xla_source_status = GetXlaTensor(source);
259+
if (xla_source_status.ok()) {
260+
dest_impl->tensor()->UpdateFromTensorOut(xla_source_status.value());
261261
} else {
262-
dest_impl->tensor()->UpdateFromTensorOut(xla_source);
262+
dest_impl->tensor()->UpdateFromTensorOut(source);
263263
}
264264
dest_impl->force_refresh_sizes();
265265
} else {
@@ -270,11 +270,11 @@ void XlaUpdateTensors(absl::Span<const at::Tensor> dest_xla_tensors,
270270

271271
std::optional<torch::lazy::BackendDevice> GetXlaDevice(
272272
const at::Tensor& tensor) {
273-
auto xtensor = TryGetXlaTensor(tensor);
274-
if (!xtensor) {
273+
auto xtensor_status = GetXlaTensor(tensor);
274+
if (!xtensor_status.ok()) {
275275
return std::nullopt;
276276
}
277-
return xtensor->GetDevice();
277+
return xtensor_status.value()->GetDevice();
278278
}
279279

280280
std::optional<torch::lazy::BackendDevice> GetXlaDevice(
@@ -469,12 +469,15 @@ std::vector<at::Tensor> CreateXlaTensors(
469469
}
470470

471471
const at::Tensor& GetRootBase(const at::Tensor& tensor) {
472-
auto xla_tensor = TryGetXlaTensor(tensor);
473-
if (xla_tensor && xla_tensor->Base().defined()) {
474-
return GetRootBase(xla_tensor->Base());
475-
} else {
472+
auto xla_tensor_status = GetXlaTensor(tensor);
473+
if (!xla_tensor_status.ok()) {
474+
return tensor;
475+
}
476+
auto xla_tensor = std::move(xla_tensor_status).value();
477+
if (!xla_tensor->Base().defined()) {
476478
return tensor;
477479
}
480+
return GetRootBase(xla_tensor->Base());
478481
}
479482

480483
XLATensorPtr SetBaseTensor(XLATensorPtr tensor, const at::Tensor& base) {

0 commit comments

Comments
 (0)