
Commit 6050927

Improve error message of functions related to GetXlaTensor(). (#9520)
1 parent: 30ad68a

File tree: 7 files changed (+160, -46 lines)

- test/quantized_ops/test_dot_general.py
- torch_xla/csrc/BUILD
- torch_xla/csrc/aten_xla_bridge.cpp
- torch_xla/csrc/aten_xla_bridge.h
- torch_xla/csrc/init_python_bindings.cpp
- torch_xla/csrc/status.cpp
- torch_xla/csrc/status.h

test/quantized_ops/test_dot_general.py

Lines changed: 19 additions & 0 deletions
@@ -56,6 +56,25 @@ def test_dot_general_int32_dtype(self):
         preferred_element_type=torch.int32)
     self.assertTrue(torch.allclose(xla_out.cpu(), expected_out))

+  def test_raises_error_on_non_xla_tensor(self):
+    lhs = torch.rand(10, 3, 4, dtype=torch.bfloat16)
+    rhs = torch.rand(10, 4, 5, dtype=torch.bfloat16)
+
+    def test(args, non_xla_tensor_arg):
+      arg_number_to_str = ["first", "second"]
+      position = arg_number_to_str[non_xla_tensor_arg]
+      try:
+        torch_xla._XLAC._xla_dot_general(*args, (([2], [1]), ([0], [0])))
+      except RuntimeError as err:
+        error_message = (
+            f"Expected input tensor ({position} argument) to be an actual XLA tensor. "
+            f"Got: CPUBFloat16Type. Consider moving it ({position} argument) to XLA."
+        )
+        self.assertEqual(str(err), error_message)
+
+    test((lhs, rhs.to(device)), non_xla_tensor_arg=0)
+    test((lhs.to(device), rhs), non_xla_tensor_arg=1)
+

 if __name__ == '__main__':
   test = unittest.main()
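
Note: the literal "CPUBFloat16Type" asserted above is produced by at::Tensor::toString(), which names the tensor's backend and dtype. A minimal standalone C++ sketch (illustrative, not part of this commit) showing where that string comes from:

// Minimal sketch, assuming only ATen: prints the backend/dtype string that
// the new error message splices in for a plain CPU bfloat16 input.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor t = at::rand({10, 3, 4}, at::kBFloat16);  // CPU tensor, not XLA
  std::cout << t.toString() << std::endl;  // prints "CPUBFloat16Type"
}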

torch_xla/csrc/BUILD

Lines changed: 1 addition & 0 deletions
@@ -280,6 +280,7 @@ ptxla_cc_library(
         "//torch_xla/csrc/runtime:xla_coordinator",
         "//torch_xla/csrc/runtime:xla_util",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log:absl_check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:variant",

torch_xla/csrc/aten_xla_bridge.cpp

Lines changed: 38 additions & 12 deletions
@@ -7,7 +7,7 @@
 #include <string>
 #include <vector>

-#include "absl/status/status.h"
+#include "absl/log/absl_check.h"
 #include "absl/strings/str_cat.h"
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
@@ -80,8 +80,12 @@ static absl::StatusOr<XLATensorImpl * absl_nonnull> GetXlaTensorImpl(
   XLATensorImpl* impl =
       dynamic_cast<XLATensorImpl*>(inner_tensor.unsafeGetTensorImpl());
   if (impl == nullptr) {
-    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(absl::StrCat(
-        "Input tensor is not an XLA tensor: ", tensor.toString())));
+    auto error_message =
+        absl::StrCat("Failed retrieving the inner XLATensorImpl* from ",
+                     tensor.toString(), ". ",
+                     "It's likely that `tensor` is not an actual XLA tensor, "
+                     "i.e. it wasn't created inside PyTorch/XLA.");
+    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(error_message));
   }
   return impl;
 }
@@ -99,41 +103,63 @@ absl::StatusOr<absl_nonnull XLATensorPtr> GetXlaTensor(
     // To make sure we have the most updated version of tensor.
     at::functionalization::impl::sync(tensor);
   }
-  XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor));
+  XLA_ASSIGN_OR_RETURN(
+      XLATensorImpl * impl, GetXlaTensorImpl(tensor),
+      absl::StrCat("Expected XLA tensor. Got: ", tensor.toString()));
   return impl->tensor();
 }

 absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> GetXlaTensors(
     const at::ITensorListRef& tensors) {
   std::vector<absl_nonnull XLATensorPtr> xla_tensors;
   xla_tensors.reserve(tensors.size());
+  std::size_t index = 0;
   for (const auto& tensor : tensors) {
-    XLA_ASSIGN_OR_RETURN(XLATensorPtr ptr, bridge::GetXlaTensor(tensor));
+    XLA_ASSIGN_OR_RETURN(
+        XLATensorPtr ptr, bridge::GetXlaTensor(tensor),
+        absl::StrCat("Expected all tensors in the given list to be XLA "
+                     "tensors. Element at index ",
+                     index, " is not an XLA tensor. Got: ", tensor.toString()));
     xla_tensors.push_back(std::move(ptr));
+    index += 1;
   }
   return xla_tensors;
 }

+absl::StatusOr<absl_nonnull XLATensorPtr> GetInputXlaTensor(
+    const at::Tensor& tensor, const std::string_view param) {
+  XLA_ASSIGN_OR_RETURN(
+      XLATensorPtr ptr, GetXlaTensor(tensor),
+      absl::StrCat("Expected input tensor (", param,
+                   ") to be an actual XLA tensor. Got: ", tensor.toString(),
+                   ". Consider moving it (", param, ") to XLA."));
+  return ptr;
+}
+
 bool IsXlaTensor(const at::Tensor& tensor) {
   return GetXlaTensorImpl(tensor).ok();
 }

 absl::Status ReplaceXlaTensor(const at::Tensor& tensor,
                               XLATensorPtr new_xla_tensor) {
-  XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor));
+  XLA_ASSIGN_OR_RETURN(XLATensorImpl * impl, GetXlaTensorImpl(tensor),
+                       "Failed replacing the XLA tensor in the given tensor.");
   impl->set_tensor(std::move(new_xla_tensor));
   return absl::OkStatus();
 }

 absl::Status ReplaceXlaTensor(const std::vector<at::Tensor>& tensors,
                               const std::vector<XLATensorPtr> new_xla_tensors) {
-  if (tensors.size() != new_xla_tensors.size()) {
-    return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(
-        absl::StrCat("The size of tensors and new_xla_tensors are not equal: ",
-                     tensors.size(), " vs. ", new_xla_tensors.size())));
-  }
+  ABSL_CHECK(tensors.size() == new_xla_tensors.size())
+      << "Expected the size of the list of tensors (" << tensors.size()
+      << ") to match the size of the list of XLATensorPtr ("
+      << new_xla_tensors.size() << ")";
   for (size_t i = 0; i < tensors.size(); ++i) {
-    XLA_RETURN_IF_ERROR(ReplaceXlaTensor(tensors[i], new_xla_tensors[i]));
+    XLA_RETURN_IF_ERROR(
+        ReplaceXlaTensor(tensors[i], new_xla_tensors[i]),
+        absl::StrCat(
+            "Failed replacing the XLA tensor at index ", i,
+            ". The reason being that that tensor is not an XLA tensor."));
   }
   return absl::OkStatus();
 }
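
The change above relies on the three-argument forms of XLA_ASSIGN_OR_RETURN and XLA_RETURN_IF_ERROR, which attach a context message to the propagated status. A rough sketch of that pattern, assuming only Abseil (an illustration of the idea, not PyTorch/XLA's actual macro definition):

#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"

// Illustrative macro: evaluates `expr`; on error, returns a status whose
// message is prefixed with `ctx`. Single-use-per-scope only: the temporary
// name is not uniquified, unlike production macros.
#define SKETCH_ASSIGN_OR_RETURN(lhs, expr, ctx)                      \
  auto statusor_tmp = (expr);                                        \
  if (!statusor_tmp.ok()) {                                          \
    return absl::Status(                                             \
        statusor_tmp.status().code(),                                \
        absl::StrCat(ctx, " ", statusor_tmp.status().message()));    \
  }                                                                  \
  lhs = std::move(statusor_tmp).value();

absl::StatusOr<int> ParseInt(const char* s);  // hypothetical helper

absl::StatusOr<int> ParsePort(const char* s) {
  SKETCH_ASSIGN_OR_RETURN(int port, ParseInt(s), "Expected a valid port.");
  return port;
}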

torch_xla/csrc/aten_xla_bridge.h

Lines changed: 9 additions & 0 deletions
@@ -59,6 +59,15 @@ absl::StatusOr<absl_nonnull XLATensorPtr> GetXlaTensor(
 absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> GetXlaTensors(
     const at::ITensorListRef& tensors);

+// Retrieves the underlying `XLATensorPtr` from `tensor`.
+//
+// If `tensor` is not an actual XLA tensor, this function will craft a
+// specialized error message for PyTorch operations or Python API
+// functions, i.e. functions where the parameter name makes sense for
+// the end user.
+absl::StatusOr<absl_nonnull XLATensorPtr> GetInputXlaTensor(
+    const at::Tensor& tensor, std::string_view param);
+
 bool IsXlaTensor(const at::Tensor& tensor);

 // Replaces the XLA tensor embedded within `tensor`'s XLA TensorImpl with
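
A hypothetical call site for the new helper (MyUnaryOp and its "input" parameter are illustrative names, not part of this commit), showing the intended usage where a user-facing parameter name is available:

// Hedged usage sketch: validates a single tensor argument with
// GetInputXlaTensor() so the user sees which parameter was at fault.
absl::StatusOr<at::Tensor> MyUnaryOp(const at::Tensor& input) {
  XLA_ASSIGN_OR_RETURN(
      XLATensorPtr xinput,
      bridge::GetInputXlaTensor(input, /* param= */ "input"));
  // ... build and run the actual XLA computation on `xinput` ...
  return bridge::AtenFromXlaTensor(xinput);
}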

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 71 additions & 32 deletions
@@ -25,7 +25,7 @@
 #include <vector>

 #include "absl/container/flat_hash_map.h"
-#include "absl/status/status.h"
+#include "absl/log/absl_check.h"
 #include "absl/strings/str_cat.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "absl/types/variant.h"
@@ -38,6 +38,7 @@
 #include "pybind11/pytypes.h"
 #include "pybind11/stl.h"
 #include "pybind11/stl_bind.h"
+#include "status.h"
 #include "torch_xla/csrc/XLANativeFunctions.h"
 #include "torch_xla/csrc/aten_autograd_ops.h"
 #include "torch_xla/csrc/aten_fallback.h"
@@ -87,6 +88,23 @@ namespace {

 constexpr int64_t kSeedInfoId = -127389;

+// Traits related to the return type of the lambda function that wraps the
+// actual implementation inside PythonScope.
+template <class T>
+struct RemoveStatus {
+  using type = T;
+};
+
+template <>
+struct RemoveStatus<absl::Status> {
+  using type = void;
+};
+
+template <class T>
+struct RemoveStatus<absl::StatusOr<T>> {
+  using type = T;
+};
+
 // Wraps a python scope (e.g. py::module) to provide more convenient APIs.
 // It behaves like a Scope object but has enhanced behaviors for the def*()
 // methods. This class has reference semantics, just like the Scope class.
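
RemoveStatus reads as a type-level function: identity for ordinary types, void for absl::Status, and T for absl::StatusOr<T>. A few illustrative static_asserts (not in the commit; assumes <type_traits> is available alongside the headers above) make the mapping concrete:

static_assert(std::is_same_v<RemoveStatus<int>::type, int>);
static_assert(std::is_same_v<RemoveStatus<absl::Status>::type, void>);
static_assert(
    std::is_same_v<RemoveStatus<absl::StatusOr<at::Tensor>>::type, at::Tensor>);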
@@ -153,15 +171,29 @@ class PythonScope : public Scope {
   template <typename F>
   static void Bind(Scope& scope, const char* const name, F&& f,
                    const Extra&... extra) {
-    using RetType =
+    // `f` return type.
+    using FnRetType =
         typename c10::guts::infer_function_traits<F>::type::return_type;
-    auto lambda = [f = std::move(f)](Args... args) -> RetType {
+    // Wrapper lambda return type.
+    // This is needed for handling returning status types.
+    using LambdaRetType = typename RemoveStatus<FnRetType>::type;
+    // FnRetType is a status type iff after unwrapping the status type,
+    // the resulting type (i.e. LambdaRetType) is NOT the same as FnRetType.
+    constexpr bool returns_status_type =
+        !std::is_same<FnRetType, LambdaRetType>::value;
+
+    auto lambda = [f = std::move(f)](Args... args) -> LambdaRetType {
       // RAII for emitting Python warnings.
       //
       // This turns messages passed to `TORCH_WARN()` in `f` into Python
       // warnings.
       torch::PyWarningHandler handler;
-      return f(args...);
+
+      if constexpr (returns_status_type) {
+        return GetValueOrThrow(f(args...));
+      } else {
+        return f(args...);
+      }
     };

     if constexpr (kind == FunctionKind::kInit) {
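
The net effect: a bound function returning absl::Status or absl::StatusOr<T> surfaces to Python as void or T, with non-ok statuses converted into exceptions by GetValueOrThrow(). A condensed standalone sketch of the dispatch (reusing the RemoveStatus trait above; CallAndUnwrap is an illustrative name, not the commit's code):

template <class F, class... Args>
auto CallAndUnwrap(F&& f, Args&&... args) {
  using R = decltype(f(std::forward<Args>(args)...));
  using Unwrapped = typename RemoveStatus<R>::type;
  if constexpr (!std::is_same_v<R, Unwrapped>) {
    // Status-returning: unwrap the value (or throw) before crossing into
    // Python.
    return GetValueOrThrow(f(std::forward<Args>(args)...));
  } else {
    // Plain return type: forward as-is.
    return f(std::forward<Args>(args)...);
  }
}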
@@ -237,13 +269,11 @@ std::string GetTensorsDump(
     const std::vector<at::Tensor>& tensors,
     const std::function<
         std::string(absl::Span<const torch::lazy::Node* const>)>& coverter) {
+  auto xtensors = GetValueOrThrow(bridge::GetXlaTensors(tensors));
   std::vector<const torch::lazy::Node*> nodes;
-  std::vector<torch::lazy::Value> values;
-  for (auto& tensor : tensors) {
-    XLATensorPtr xtensor = GetValueOrThrow(bridge::GetXlaTensor(tensor));
-    values.push_back(xtensor->GetIrValue());
-    nodes.push_back(values.back().node.get());
-  }
+  std::transform(
+      xtensors.begin(), xtensors.end(), std::back_inserter(nodes),
+      [](const XLATensorPtr& ptr) { return ptr->GetIrValue().node.get(); });
   return coverter(nodes);
 }
@@ -363,7 +393,7 @@ std::vector<std::vector<int>> ExtractXlaDotGeneralDimVectors(
   return dim_vectors;
 }

-at::Tensor XlaDotGeneral(const at::Tensor& lhs, const at::Tensor& rhs,
+at::Tensor XlaDotGeneral(const XLATensorPtr& xlhs, const XLATensorPtr& xrhs,
                          const std::vector<std::vector<int>>& dim_vectors,
                          std::optional<py::object> preferred_element_type) {
   std::optional<at::ScalarType> at_preferred_element_type;
@@ -373,9 +403,7 @@ at::Tensor XlaDotGeneral(const at::Tensor& lhs, const at::Tensor& rhs,
         ->scalar_type;
   }
   return bridge::AtenFromXlaTensor(tensor_methods::xla_dot_general(
-      GetValueOrThrow(bridge::GetXlaTensor(lhs)),
-      GetValueOrThrow(bridge::GetXlaTensor(rhs)), dim_vectors,
-      at_preferred_element_type));
+      xlhs, xrhs, dim_vectors, at_preferred_element_type));
 }

 std::vector<std::pair<int64_t, int64_t>> CreateSourceTargetPairs(
@@ -1841,20 +1869,25 @@ void InitXlaModuleBindings(py::module m) {
          })
      .def(
          "_xla_dot_general",
-         [](const at::Tensor& lhs, const at::Tensor& rhs,
+         [](const at::Tensor& lhs,
+            const at::Tensor& rhs,
             py::tuple dimension_numbers,
             std::optional<std::string>& precision_config,
-            std::optional<py::object>& preferred_element_type) -> at::Tensor {
+            std::optional<py::object>& preferred_element_type) -> absl::StatusOr<at::Tensor> {
            // Python binding for xla::DotGeneral
            // https://openxla.org/xla/operation_semantics#dotgeneral
            std::vector<std::vector<int>> dim_vectors =
                ExtractXlaDotGeneralDimVectors(dimension_numbers);
            XLA_CHECK(!precision_config.has_value())
                << "_xla_dot_general: precision_config is not supported yet, "
                   "default precision setting will be applied.";
-           at::Tensor result =
-               XlaDotGeneral(lhs, rhs, dim_vectors, preferred_element_type);
-           return result;
+           XLA_ASSIGN_OR_RETURN(
+               XLATensorPtr xlhs,
+               bridge::GetInputXlaTensor(lhs, /* param= */ "first argument"));
+           XLA_ASSIGN_OR_RETURN(
+               XLATensorPtr xrhs,
+               bridge::GetInputXlaTensor(rhs, /* param= */ "second argument"));
+           return XlaDotGeneral(xlhs, xrhs, dim_vectors, preferred_element_type);
         },
         py::arg("lhs"),  //
         py::arg("rhs"),  //
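
With the lambda now returning absl::StatusOr<at::Tensor>, the PythonScope::Bind() wrapper above unwraps the result: on success Python receives a plain tensor, and on failure GetValueOrThrow() raises a RuntimeError carrying the GetInputXlaTensor() message that test_raises_error_on_non_xla_tensor asserts against.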
@@ -3340,19 +3373,25 @@ void InitXlaModuleBindings(py::module m) {
               opt_device ? &opt_device.value() : nullptr);
           return check_materialization_helper(xtensors);
         })
-     .def(
-         "_get_graph_hash",
-         [](const std::vector<at::Tensor>& tensors) {
-           std::vector<XLATensorPtr> xtensors;
-           xtensors.reserve(tensors.size());
-           for (auto& tensor : tensors) {
-             xtensors.push_back(GetValueOrThrow(bridge::GetXlaTensor(tensor)));
-           }
-           torch::lazy::hash_t hash =
-               XLAGraphExecutor::Get()->GetGraphHash(xtensors);
-           std::string bin((const char*)&hash, sizeof(hash));
-           return py::bytes(bin);
-         })
+     .def("_get_graph_hash",
+          [](const std::vector<at::Tensor>& tensors) -> py::bytes {
+            absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>>
+                xtensors_status = bridge::GetXlaTensors(tensors);
+            ABSL_CHECK(xtensors_status.ok())
+                << "_get_graph_hash(): error retrieving the XLA tensors from "
+                << "the given tensor arguments. "
+                << "This is a bug! Please, open an issue in the PyTorch/XLA "
+                << "GitHub repository: https://github.com/pytorch/xla"
+                << std::endl
+                << "Status Error: "
+                << BuildStatusErrorMessage(xtensors_status.status());
+            std::vector<absl_nonnull XLATensorPtr> xtensors =
+                xtensors_status.value();
+            torch::lazy::hash_t hash =
+                XLAGraphExecutor::Get()->GetGraphHash(xtensors);
+            std::string bin((const char*)&hash, sizeof(hash));
+            return py::bytes(bin);
+          })
      .def("_clear_pending_irs",
           [](const std::string& device) {
             // Use with caution. Those tensor whole ir was cleared
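
Unlike _xla_dot_general, _get_graph_hash keeps its plain py::bytes return type and treats a non-XLA input as an internal invariant violation: ABSL_CHECK() aborts with a bug-report message (including the status formatted by BuildStatusErrorMessage()) instead of propagating the error to the caller.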

torch_xla/csrc/status.cpp

Lines changed: 8 additions & 2 deletions
@@ -119,9 +119,15 @@ static std::string MaybeGetMessageWithLineBreak(const absl::Status& status) {
       : std::string(status.message());
 }

+std::string BuildStatusErrorMessage(const absl::Status& status) {
+  return absl::StrCat(MaybeGetMessageWithLineBreak(status),
+                      GetFormattedStatusPropagationTrace(status));
+}
+
 void MaybeThrow(const absl::Status& status) {
-  TORCH_CHECK(status.ok(), MaybeGetMessageWithLineBreak(status),
-              GetFormattedStatusPropagationTrace(status));
+  TORCH_CHECK(status.ok(), BuildStatusErrorMessage(status));
 }

+void GetValueOrThrow(const absl::Status& status) { MaybeThrow(status); }
+
 }  // namespace torch_xla

torch_xla/csrc/status.h

Lines changed: 14 additions & 0 deletions
@@ -174,6 +174,17 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file,

 }  // namespace status_internal

+// Builds the complete error message for the given `status`.
+//
+// If `TORCH_SHOW_CPP_STACKTRACES` is enabled, returns the concatenation of
+// `status.message()` with its inner status propagation trace.
+//
+// TODO(ysiraichi): this call does not append the C++ stacktrace, which,
+// ideally, should. It can be done by not using `TORCH_CHECK()` macro directly
+// in `MaybeThrow()`, but using PyTorch `c10::get_lazy_backtrace()`
+// (at c10/util/Backtrace.h).
+std::string BuildStatusErrorMessage(const absl::Status& status);
+
 // Maybe throws an exception if `status` has a non-ok code.
 //
 // Ideally, this function should be used only used in the project's
@@ -200,6 +211,9 @@ T GetValueOrThrow(absl::StatusOr<T>&& status) {
   return std::move(status).value();
 }

+// `GetValueOrThrow` overload for `Status`.
+void GetValueOrThrow(const absl::Status& status);
+
 }  // namespace torch_xla

 #endif  // XLA_TORCH_XLA_CSRC_STATUS_H_
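
A short usage sketch of the error-reporting entry points declared here (SyncDevice and CountLiveTensors are hypothetical helpers, not part of the commit):

#include <iostream>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "torch_xla/csrc/status.h"

absl::Status SyncDevice();               // hypothetical status-returning call
absl::StatusOr<int> CountLiveTensors();  // hypothetical value-returning call

void Example() {
  // StatusOr<T> overload: yields the wrapped value, or throws via
  // TORCH_CHECK().
  int n = torch_xla::GetValueOrThrow(CountLiveTensors());
  (void)n;
  // New Status overload: nothing to yield; throws only on a non-ok status.
  torch_xla::GetValueOrThrow(SyncDevice());
  // Manual formatting: message plus propagation trace (when
  // TORCH_SHOW_CPP_STACKTRACES is enabled).
  if (absl::Status s = SyncDevice(); !s.ok()) {
    std::cerr << torch_xla::BuildStatusErrorMessage(s) << std::endl;
  }
}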
