 #include <string>
 #include <vector>

+#include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/runtime.h"
+#include "torch_xla/csrc/status.h"
 #include "torch_xla/csrc/tensor_impl.h"
 #include "torch_xla/csrc/torch_util.h"
 #include "torch_xla/csrc/xla_graph_executor.h"
@@ -72,72 +74,68 @@ AtenXlaDeviceMapper* AtenXlaDeviceMapper::Get() {
   return device_mapper;
 }

-XLATensorImpl* GetXlaTensorImpl(const at::Tensor& tensor) {
+static absl::StatusOr<XLATensorImpl * absl_nonnull> GetXlaTensorImpl(
+    const at::Tensor& tensor) {
   auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor);
-  return dynamic_cast<XLATensorImpl*>(inner_tensor.unsafeGetTensorImpl());
+  XLATensorImpl* impl =
+      dynamic_cast<XLATensorImpl*>(inner_tensor.unsafeGetTensorImpl());
+  if (impl == nullptr) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
+        "Input tensor is not an XLA tensor: ", tensor.toString())));
+  }
+  return impl;
 }

 }  // namespace

 XLATensorPtr TryGetXlaTensor(const at::Tensor& tensor) {
+  return GetXlaTensor(tensor).value_or(XLATensorPtr{});
+}
+
+absl::StatusOr<absl_nonnull XLATensorPtr> GetXlaTensor(
+    const at::Tensor& tensor) {
   if (tensor.defined() &&
       at::functionalization::impl::isFunctionalTensor(tensor)) {
     // To make sure we have the most updated version of tensor.
     at::functionalization::impl::sync(tensor);
   }
-  XLATensorImpl* impl = GetXlaTensorImpl(tensor);
-  if (impl == nullptr) {
-    return XLATensorPtr();
-  }
+  XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor));
   return impl->tensor();
 }

-std::vector<XLATensorPtr> TryGetXlaTensors(const at::ITensorListRef& tensors) {
-  std::vector<XLATensorPtr> xla_tensors;
+absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> GetXlaTensors(
+    const at::ITensorListRef& tensors) {
+  std::vector<absl_nonnull XLATensorPtr> xla_tensors;
   xla_tensors.reserve(tensors.size());
   for (const auto& tensor : tensors) {
-    xla_tensors.push_back(bridge::TryGetXlaTensor(tensor));
+    XLA_ASSIGN_OR_RETURN(XLATensorPtr ptr, bridge::GetXlaTensor(tensor));
+    xla_tensors.push_back(std::move(ptr));
   }
   return xla_tensors;
 }

 bool IsXlaTensor(const at::Tensor& tensor) {
-  return GetXlaTensorImpl(tensor) != nullptr;
-}
-
-XLATensorPtr GetXlaTensor(const at::Tensor& tensor) {
-  auto xtensor = TryGetXlaTensor(tensor);
-  XLA_CHECK(xtensor) << "Input tensor is not an XLA tensor: "
-                     << tensor.toString();
-  return xtensor;
+  return GetXlaTensorImpl(tensor).ok();
 }

-void ReplaceXlaTensor(const at::Tensor& tensor, XLATensorPtr new_xla_tensor) {
-  auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor);
-  XLATensorImpl* impl =
-      dynamic_cast<XLATensorImpl*>(inner_tensor.unsafeGetTensorImpl());
-  XLA_CHECK(impl != nullptr)
-      << "Input tensor is not an XLA tensor: " << inner_tensor.toString();
+absl::Status ReplaceXlaTensor(const at::Tensor& tensor,
+                              XLATensorPtr new_xla_tensor) {
+  XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor));
   impl->set_tensor(std::move(new_xla_tensor));
+  return absl::OkStatus();
 }

-void ReplaceXlaTensor(const std::vector<at::Tensor>& tensors,
-                      const std::vector<XLATensorPtr> new_xla_tensors) {
-  XLA_CHECK(tensors.size() == new_xla_tensors.size())
-      << "The size of tensors and new_xla_tensors are not equal: "
-      << tensors.size() << " vs. " << new_xla_tensors.size();
-  for (size_t i = 0; i < tensors.size(); ++i) {
-    ReplaceXlaTensor(tensors[i], new_xla_tensors[i]);
+absl::Status ReplaceXlaTensor(const std::vector<at::Tensor>& tensors,
+                              const std::vector<XLATensorPtr> new_xla_tensors) {
+  if (tensors.size() != new_xla_tensors.size()) {
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
+        absl::StrCat("The size of tensors and new_xla_tensors are not equal: ",
+                     tensors.size(), " vs. ", new_xla_tensors.size())));
   }
-}
-
-std::vector<XLATensorPtr> GetXlaTensors(const at::ITensorListRef& tensors) {
-  std::vector<XLATensorPtr> xla_tensors;
-  xla_tensors.reserve(tensors.size());
-  for (const auto& tensor : tensors) {
-    xla_tensors.push_back(bridge::GetXlaTensor(tensor));
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    XLA_RETURN_IF_ERROR(ReplaceXlaTensor(tensors[i], new_xla_tensors[i]));
   }
-  return xla_tensors;
+  return absl::OkStatus();
 }

 torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
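With this hunk, the bridge's lookup helpers report failure through absl::Status and absl::StatusOr instead of aborting the process through XLA_CHECK. Status-aware callers can now propagate the error upward. A minimal sketch of such a caller (the function itself is hypothetical; XLA_ASSIGN_OR_RETURN is assumed to unwrap the value on success and early-return the location-annotated error status otherwise, analogous to absl's ASSIGN_OR_RETURN):

    // Hypothetical status-aware caller of the new bridge API.
    absl::Status MarkStepOnTensor(const at::Tensor& tensor) {
      // On a non-XLA input this now yields an InvalidArgumentError with the
      // failing source location attached, rather than crashing via XLA_CHECK.
      XLA_ASSIGN_OR_RETURN(XLATensorPtr xtensor, bridge::GetXlaTensor(tensor));
      // ... operate on xtensor ...
      return absl::OkStatus();
    }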
@@ -146,7 +144,7 @@ torch_xla::XLATensorPtr GetXlaTensorOrCreateForWrappedNumber(
       (tensor.dim() == 0 && tensor.numel() == 1)) {
     return torch_xla::bridge::GetOrCreateXlaTensor(tensor, device);
   } else {
-    return torch_xla::bridge::GetXlaTensor(tensor);
+    return GetValueOrThrow(torch_xla::bridge::GetXlaTensor(tensor));
   }
 }

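At the ATen operator boundary the surrounding signature cannot return a status, so the StatusOr result is unwrapped with GetValueOrThrow, turning a non-OK status into a C++ exception instead of a hard crash. A sketch of the pattern (the operator is hypothetical, and the exact exception type raised by GetValueOrThrow is an assumption, not taken from this diff):

    // Hypothetical ATen-facing function that cannot propagate absl::Status.
    at::Tensor CloneViaBridge(const at::Tensor& self) {
      // Throws on a non-XLA input; assumed to surface in Python as a
      // RuntimeError rather than terminating the process.
      XLATensorPtr xself = GetValueOrThrow(bridge::GetXlaTensor(self));
      // ... dispatch the op through xself ...
      return self;
    }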
@@ -155,22 +153,23 @@ XLATensorPtr GetOrCreateXlaTensor(const at::Tensor& tensor,
   if (!tensor.defined()) {
     return XLATensorPtr();
   }
+
   auto inner_tensor = torch::lazy::maybe_unwrap_functional(tensor);
   if (!inner_tensor.defined()) {
     return XLATensorPtr();
   }
-  auto xtensor = TryGetXlaTensor(tensor);
-  return xtensor ? xtensor : XLATensor::Create(inner_tensor, device);
+
+  auto xtensor = GetXlaTensor(tensor);
+  return xtensor.ok() ? xtensor.value()
+                      : XLATensor::Create(inner_tensor, device);
 }

 XLATensorPtr GetOrCreateXlaTensor(const std::optional<at::Tensor>& tensor,
                                   const torch::lazy::BackendDevice& device) {
-  if (!IsDefined(tensor)) {
+  if (!tensor.has_value()) {
     return XLATensorPtr();
   }
-  auto xtensor = TryGetXlaTensor(*tensor);
-  auto inner_tensor = torch::lazy::maybe_unwrap_functional(*tensor);
-  return xtensor ? xtensor : XLATensor::Create(inner_tensor, device);
+  return GetOrCreateXlaTensor(*tensor, device);
 }

 std::vector<XLATensorPtr> GetOrCreateXlaTensors(
@@ -199,10 +198,10 @@ std::vector<at::Tensor> XlaCreateTensorList(const at::ITensorListRef& tensors) {
       continue;
     }

-    auto xtensor = TryGetXlaTensor(tensor);
-    if (xtensor) {
+    auto xtensor_status = GetXlaTensor(tensor);
+    if (xtensor_status.ok()) {
       to_translate[ix] = true;
-      xla_tensors.push_back(xtensor);
+      xla_tensors.push_back(xtensor_status.value());
     } else {
       aten_xla_tensors[ix] = tensor;
     }
@@ -253,13 +252,14 @@ void XlaUpdateTensors(absl::Span<const at::Tensor> dest_xla_tensors,
   for (auto index : indices) {
     at::Tensor dest = dest_xla_tensors.at(index);
     at::Tensor source = source_cpu_tensors.at(index);
-    XLATensorImpl* dest_impl = GetXlaTensorImpl(dest);
-    if (dest_impl != nullptr) {
-      auto xla_source = TryGetXlaTensor(source);
-      if (!xla_source) {
-        dest_impl->tensor()->UpdateFromTensorOut(source);
+    auto dest_impl_status = GetXlaTensorImpl(dest);
+    if (dest_impl_status.ok()) {
+      auto dest_impl = std::move(dest_impl_status).value();
+      auto xla_source_status = GetXlaTensor(source);
+      if (xla_source_status.ok()) {
+        dest_impl->tensor()->UpdateFromTensorOut(xla_source_status.value());
       } else {
-        dest_impl->tensor()->UpdateFromTensorOut(xla_source);
+        dest_impl->tensor()->UpdateFromTensorOut(source);
       }
       dest_impl->force_refresh_sizes();
     } else {
@@ -270,11 +270,11 @@ void XlaUpdateTensors(absl::Span<const at::Tensor> dest_xla_tensors,

 std::optional<torch::lazy::BackendDevice> GetXlaDevice(
     const at::Tensor& tensor) {
-  auto xtensor = TryGetXlaTensor(tensor);
-  if (!xtensor) {
+  auto xtensor_status = GetXlaTensor(tensor);
+  if (!xtensor_status.ok()) {
     return std::nullopt;
   }
-  return xtensor->GetDevice();
+  return xtensor_status.value()->GetDevice();
 }

 std::optional<torch::lazy::BackendDevice> GetXlaDevice(
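GetXlaDevice shows the third consumption pattern: where a non-XLA tensor is an expected case rather than an error, the status is probed with .ok() and folded into a std::optional, deliberately discarding the error message. A generic sketch of that fold (the ToOptional helper is hypothetical, not part of this change):

    #include <optional>
    #include <utility>

    #include "absl/status/statusor.h"

    // Hypothetical adapter: collapse an error status to std::nullopt,
    // keeping only the success value.
    template <typename T>
    std::optional<T> ToOptional(absl::StatusOr<T> status_or) {
      if (!status_or.ok()) {
        return std::nullopt;
      }
      return std::move(status_or).value();
    }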
@@ -469,12 +469,15 @@ std::vector<at::Tensor> CreateXlaTensors(
 }

 const at::Tensor& GetRootBase(const at::Tensor& tensor) {
-  auto xla_tensor = TryGetXlaTensor(tensor);
-  if (xla_tensor && xla_tensor->Base().defined()) {
-    return GetRootBase(xla_tensor->Base());
-  } else {
+  auto xla_tensor_status = GetXlaTensor(tensor);
+  if (!xla_tensor_status.ok()) {
+    return tensor;
+  }
+  auto xla_tensor = std::move(xla_tensor_status).value();
+  if (!xla_tensor->Base().defined()) {
     return tensor;
   }
+  return GetRootBase(xla_tensor->Base());
 }

 XLATensorPtr SetBaseTensor(XLATensorPtr tensor, const at::Tensor& base) {