@@ -25,7 +25,7 @@
 #include <vector>
 
 #include "absl/container/flat_hash_map.h"
-#include "absl/status/status.h"
+#include "absl/log/absl_check.h"
 #include "absl/strings/str_cat.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "absl/types/variant.h"
@@ -38,6 +38,7 @@
 #include "pybind11/pytypes.h"
 #include "pybind11/stl.h"
 #include "pybind11/stl_bind.h"
+#include "status.h"
 #include "torch_xla/csrc/XLANativeFunctions.h"
 #include "torch_xla/csrc/aten_autograd_ops.h"
 #include "torch_xla/csrc/aten_fallback.h"
@@ -87,6 +88,23 @@ namespace {
 
 constexpr int64_t kSeedInfoId = -127389;
 
+// Traits related to the return type of the lambda function that wraps the
+// actual implementation inside PythonScope.
+template <class T>
+struct RemoveStatus {
+  using type = T;
+};
+
+template <>
+struct RemoveStatus<absl::Status> {
+  using type = void;
+};
+
+template <class T>
+struct RemoveStatus<absl::StatusOr<T>> {
+  using type = T;
+};
+
 // Wraps a python scope (e.g. py::module) to provide more convenient APIs.
 // It behaves like a Scope object but has enhanced behaviors for the def*()
 // methods. This class has reference semantics, just like the Scope class.
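The mapping this trait computes can be pinned down with a few compile-time checks. A minimal sketch, assuming only the RemoveStatus definitions above plus Abseil's status headers:

#include <type_traits>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// RemoveStatus leaves plain types untouched...
static_assert(std::is_same<RemoveStatus<int>::type, int>::value);
// ...maps a bare absl::Status (no payload) to void...
static_assert(std::is_same<RemoveStatus<absl::Status>::type, void>::value);
// ...and unwraps absl::StatusOr<T> down to its payload T.
static_assert(std::is_same<RemoveStatus<absl::StatusOr<double>>::type,
                           double>::value);

Mapping absl::Status to void is what lets the wrapper below return nothing for status-only functions once the status has been checked.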
@@ -153,15 +171,29 @@ class PythonScope : public Scope {
   template <typename F>
   static void Bind(Scope& scope, const char* const name, F&& f,
                    const Extra&... extra) {
-    using RetType =
+    // `f` return type.
+    using FnRetType =
         typename c10::guts::infer_function_traits<F>::type::return_type;
-    auto lambda = [f = std::move(f)](Args... args) -> RetType {
+    // Wrapper lambda return type.
+    // This is needed to handle status return types.
+    using LambdaRetType = typename RemoveStatus<FnRetType>::type;
+    // FnRetType is a status type iff, after unwrapping the status type,
+    // the resulting type (i.e. LambdaRetType) is NOT the same as FnRetType.
+    constexpr bool returns_status_type =
+        !std::is_same<FnRetType, LambdaRetType>::value;
+
+    auto lambda = [f = std::move(f)](Args... args) -> LambdaRetType {
       // RAII for emitting Python warnings.
       //
       // This turns messages passed to `TORCH_WARN()` in `f` into Python
       // warnings.
       torch::PyWarningHandler handler;
-      return f(args...);
+
+      if constexpr (returns_status_type) {
+        return GetValueOrThrow(f(args...));
+      } else {
+        return f(args...);
+      }
     };
 
     if constexpr (kind == FunctionKind::kInit) {
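The same dispatch is easier to see in isolation. Below is a self-contained sketch of the pattern, not the torch_xla implementation: it assumes the RemoveStatus trait above is visible, and Unwrap is a hypothetical stand-in for GetValueOrThrow that turns a failed status into a C++ exception, which the pybind11 layer can then surface as a Python error:

#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical stand-in for GetValueOrThrow(): throw on error, otherwise
// unwrap the payload (or nothing, for a bare absl::Status).
void Unwrap(const absl::Status& status) {
  if (!status.ok()) throw std::runtime_error(status.ToString());
}

template <class T>
T Unwrap(absl::StatusOr<T> status_or) {
  if (!status_or.ok()) throw std::runtime_error(status_or.status().ToString());
  return *std::move(status_or);
}

// Mirrors the `if constexpr` dispatch in PythonScope::Bind(): status-like
// results are unwrapped (throwing on failure); anything else passes through.
template <class F, class... Args>
auto CallUnwrapped(F&& f, Args&&... args) {
  using FnRetType = std::invoke_result_t<F, Args...>;
  using RetType = typename RemoveStatus<FnRetType>::type;
  if constexpr (!std::is_same<FnRetType, RetType>::value) {
    return Unwrap(std::forward<F>(f)(std::forward<Args>(args)...));
  } else {
    return std::forward<F>(f)(std::forward<Args>(args)...);
  }
}

With this shape, a callable returning absl::StatusOr<int> comes back as a plain int, while a callable with an ordinary return type is invoked untouched.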
@@ -237,13 +269,11 @@ std::string GetTensorsDump(
     const std::vector<at::Tensor>& tensors,
     const std::function<
         std::string(absl::Span<const torch::lazy::Node* const>)>& coverter) {
+  auto xtensors = GetValueOrThrow(bridge::GetXlaTensors(tensors));
   std::vector<const torch::lazy::Node*> nodes;
-  std::vector<torch::lazy::Value> values;
-  for (auto& tensor : tensors) {
-    XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
-    values.push_back(xtensor->GetIrValue());
-    nodes.push_back(values.back().node.get());
-  }
+  std::transform(
+      xtensors.begin(), xtensors.end(), std::back_inserter(nodes),
+      [](const XLATensorPtr& ptr) { return ptr->GetIrValue().node.get(); });
   return coverter(nodes);
 }
 
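The rewritten body follows the standard map shape: one fallible batch conversion up front, then a pure projection of each tensor to its IR node. A standalone illustration of the std::transform/std::back_inserter idiom, with placeholder Node/Value types standing in for the torch::lazy ones:

#include <algorithm>
#include <iterator>
#include <memory>
#include <vector>

struct Node {};
struct Value {
  std::shared_ptr<Node> node;
};

// Project each owning Value into a non-owning Node* observer pointer,
// appending via back_inserter (no manual loop or index bookkeeping).
std::vector<const Node*> CollectNodes(const std::vector<Value>& values) {
  std::vector<const Node*> nodes;
  nodes.reserve(values.size());
  std::transform(values.begin(), values.end(), std::back_inserter(nodes),
                 [](const Value& v) { return v.node.get(); });
  return nodes;
}

As in the hunk, the output holds non-owning pointers, so it is only valid while the owners (here `values`, there the XLA tensors) stay alive.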
@@ -363,7 +393,7 @@ std::vector<std::vector<int>> ExtractXlaDotGeneralDimVectors(
   return dim_vectors;
 }
 
-at::Tensor XlaDotGeneral(const at::Tensor& lhs, const at::Tensor& rhs,
+at::Tensor XlaDotGeneral(const XLATensorPtr& xlhs, const XLATensorPtr& xrhs,
                          const std::vector<std::vector<int>>& dim_vectors,
                          std::optional<py::object> preferred_element_type) {
   std::optional<at::ScalarType> at_preferred_element_type;
@@ -373,9 +403,7 @@ at::Tensor XlaDotGeneral(const at::Tensor& lhs, const at::Tensor& rhs,
         ->scalar_type;
   }
   return bridge::AtenFromXlaTensor(tensor_methods::xla_dot_general(
-      GetValueOrThrow(bridge::GetXlaTensor(lhs)),
-      GetValueOrThrow(bridge::GetXlaTensor(rhs)), dim_vectors,
-      at_preferred_element_type));
+      xlhs, xrhs, dim_vectors, at_preferred_element_type));
 }
 
 std::vector<std::pair<int64_t, int64_t>> CreateSourceTargetPairs(
@@ -1841,20 +1869,25 @@ void InitXlaModuleBindings(py::module m) {
         })
       .def(
           "_xla_dot_general",
-          [](const at::Tensor& lhs, const at::Tensor& rhs,
+          [](const at::Tensor& lhs,
+             const at::Tensor& rhs,
              py::tuple dimension_numbers,
              std::optional<std::string>& precision_config,
-             std::optional<py::object>& preferred_element_type) -> at::Tensor {
+             std::optional<py::object>& preferred_element_type) -> absl::StatusOr<at::Tensor> {
             // Python binding for xla::DotGeneral
             // https://openxla.org/xla/operation_semantics#dotgeneral
             std::vector<std::vector<int>> dim_vectors =
                 ExtractXlaDotGeneralDimVectors(dimension_numbers);
             XLA_CHECK(!precision_config.has_value())
                 << "_xla_dot_general: precision_config is not supported yet, "
                    "default precision setting will be applied.";
-            at::Tensor result =
-                XlaDotGeneral(lhs, rhs, dim_vectors, preferred_element_type);
-            return result;
+            XLA_ASSIGN_OR_RETURN(
+                XLATensorPtr xlhs,
+                bridge::GetInputXlaTensor(lhs, /* param= */ "first argument"));
+            XLA_ASSIGN_OR_RETURN(
+                XLATensorPtr xrhs,
+                bridge::GetInputXlaTensor(rhs, /* param= */ "second argument"));
+            return XlaDotGeneral(xlhs, xrhs, dim_vectors, preferred_element_type);
           },
           py::arg("lhs"),  //
          py::arg("rhs"),  //
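XLA_ASSIGN_OR_RETURN is used here in the common Abseil assign-or-return style: evaluate the status-returning expression once, early-return its error from the enclosing absl::StatusOr<at::Tensor> lambda on failure, otherwise bind the unwrapped value. Its actual definition lives in the newly included status.h; the sketch below only shows how such a macro is typically assembled and may differ from the real one:

#include <utility>

#include "absl/status/statusor.h"

// Token-pasting helpers so each expansion gets a unique temporary name,
// allowing several uses of the macro in one scope (as in the hunk above).
#define SKETCH_CONCAT_INNER(a, b) a##b
#define SKETCH_CONCAT(a, b) SKETCH_CONCAT_INNER(a, b)

// Sketch of an assign-or-return macro. The real XLA_ASSIGN_OR_RETURN may
// differ in detail (naming, error-message decoration, etc.).
#define SKETCH_ASSIGN_OR_RETURN(lhs, rexpr) \
  SKETCH_ASSIGN_OR_RETURN_IMPL(SKETCH_CONCAT(status_or_, __LINE__), lhs, rexpr)

#define SKETCH_ASSIGN_OR_RETURN_IMPL(tmp, lhs, rexpr) \
  auto tmp = (rexpr);                                 \
  if (!tmp.ok()) {                                    \
    return tmp.status();                              \
  }                                                   \
  lhs = *std::move(tmp)

Note the interplay with the PythonScope::Bind() change earlier in this diff: the binding lambda now merely propagates the status, and it is the Bind() wrapper that finally converts a returned error into a Python exception via GetValueOrThrow().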
@@ -3340,19 +3373,25 @@ void InitXlaModuleBindings(py::module m) {
                 opt_device ? &opt_device.value() : nullptr);
             return check_materialization_helper(xtensors);
           })
-      .def(
-          "_get_graph_hash",
-          [](const std::vector<at::Tensor>& tensors) {
-            std::vector<XLATensorPtr> xtensors;
-            xtensors.reserve(tensors.size());
-            for (auto& tensor : tensors) {
-              xtensors.push_back(GetValueOrThrow(bridge::GetXlaTensor(tensor)));
-            }
-            torch::lazy::hash_t hash =
-                XLAGraphExecutor::Get()->GetGraphHash(xtensors);
-            std::string bin((const char*)&hash, sizeof(hash));
-            return py::bytes(bin);
-          })
+      .def("_get_graph_hash",
+           [](const std::vector<at::Tensor>& tensors) -> py::bytes {
+             absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>>
+                 xtensors_status = bridge::GetXlaTensors(tensors);
+             ABSL_CHECK(xtensors_status.ok())
+                 << "_get_graph_hash(): error retrieving the XLA tensors from "
+                 << "the given tensor arguments. "
+                 << "This is a bug! Please, open an issue in the PyTorch/XLA "
+                 << "GitHub repository: https://github.com/pytorch/xla"
+                 << std::endl
+                 << "Status Error: "
+                 << BuildStatusErrorMessage(xtensors_status.status());
+             std::vector<absl_nonnull XLATensorPtr> xtensors =
+                 xtensors_status.value();
+             torch::lazy::hash_t hash =
+                 XLAGraphExecutor::Get()->GetGraphHash(xtensors);
+             std::string bin((const char*)&hash, sizeof(hash));
+             return py::bytes(bin);
+           })
       .def("_clear_pending_irs",
            [](const std::string& device) {
             // Use with caution. Those tensor whole ir was cleared
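In contrast with _xla_dot_general, _get_graph_hash treats a failed tensor conversion as an internal invariant violation and aborts through ABSL_CHECK with a bug-report message rather than propagating a status. A minimal sketch of that check-or-die shape (UnwrapOrDie is a hypothetical helper; the real code decorates the message with BuildStatusErrorMessage):

#include <utility>
#include <vector>

#include "absl/log/absl_check.h"
#include "absl/status/statusor.h"

// Hypothetical helper: die loudly if `result` carries an error that can only
// mean a library bug (bad user input should have been rejected earlier).
std::vector<int> UnwrapOrDie(absl::StatusOr<std::vector<int>> result) {
  ABSL_CHECK(result.ok()) << "internal invariant violated: "
                          << result.status();
  return *std::move(result);
}

The trade-off is deliberate: a CHECK failure produces an actionable crash report for a condition that should be unreachable, while status returns are reserved for errors a user can actually trigger and handle.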