Skip to content

Commit 2c34318

Browse files
authored
Generalize crash message for non-ok status. (#9552)
1 parent d5b9a6d commit 2c34318

File tree

4 files changed

+55
-12
lines changed

4 files changed

+55
-12
lines changed

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3377,17 +3377,9 @@ void InitXlaModuleBindings(py::module m) {
33773377
[](const std::vector<at::Tensor>& tensors) -> py::bytes {
33783378
absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>>
33793379
xtensors_status = bridge::GetXlaTensors(tensors);
3380-
ABSL_CHECK(xtensors_status.ok())
3381-
<< "\n\n"
3382-
<< "Internal Error:\n"
3383-
<< " _get_graph_hash(): error retrieving the XLA tensors "
3384-
"from the given tensor arguments. "
3385-
<< "This is a bug! Please, open an issue in the PyTorch/XLA "
3386-
<< "GitHub repository: https://github.com/pytorch/xla"
3387-
<< "\n\n"
3388-
<< "Status Error:\n"
3389-
<< " " << BuildStatusErrorMessage(xtensors_status.status())
3390-
<< "\n";
3380+
XLA_CHECK_OK(xtensors_status,
3381+
"_get_graph_hash(): error retrieving the XLA tensors "
3382+
"from the given tensor arguments.");
33913383
std::vector<absl_nonnull XLATensorPtr> xtensors =
33923384
xtensors_status.value();
33933385
torch::lazy::hash_t hash =

torch_xla/csrc/runtime/debug_macros.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
// unnecessary or undesirable.
2929
#define XLA_ERROR() TF_ERROR_STREAM()
3030
#define XLA_CHECK(c) TF_CHECK(c)
31-
#define XLA_CHECK_OK(c) TF_CHECK_OK(c)
3231
#define XLA_CHECK_EQ(a, b) TF_CHECK_EQ(a, b)
3332
#define XLA_CHECK_NE(a, b) TF_CHECK_NE(a, b)
3433
#define XLA_CHECK_LE(a, b) TF_CHECK_LE(a, b)

torch_xla/csrc/status.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,4 +124,30 @@ void MaybeThrow(const absl::Status& status) {
124124

125125
void GetValueOrThrow(const absl::Status& status) { MaybeThrow(status); }
126126

127+
void OkOrDie(const absl::Status& status, const char* file, const int32_t line,
128+
const char* function, std::string_view message) {
129+
if (status.ok()) {
130+
return;
131+
}
132+
133+
std::ostringstream oss;
134+
oss << "\n\n"
135+
<< "Internal Error:\n";
136+
137+
if (!message.empty()) {
138+
oss << " " << message << "\n";
139+
}
140+
141+
oss << " This is a bug! Please, open an issue in the PyTorch/XLA "
142+
<< "GitHub repository: https://github.com/pytorch/xla"
143+
<< "\n\n"
144+
<< "Status Error:\n"
145+
<< " "
146+
<< BuildStatusErrorMessage(
147+
status_internal::MaybeWithNewMessage(status, file, line, function))
148+
<< "\n";
149+
150+
ABSL_CHECK(status.ok()) << oss.str();
151+
}
152+
127153
} // namespace torch_xla

torch_xla/csrc/status.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#ifndef XLA_TORCH_XLA_CSRC_STATUS_H_
1111
#define XLA_TORCH_XLA_CSRC_STATUS_H_
1212

13+
#include <sstream>
14+
1315
#include "absl/status/statusor.h"
1416

1517
namespace torch_xla {
@@ -125,6 +127,22 @@ constexpr char kStatusPropagationTraceKey[] =
125127
lhs = std::move(XLA_STATUS_VAR_).value(), \
126128
##__VA_ARGS__)
127129

130+
// Crashes if `status` is not an ok status.
131+
//
132+
// Example:
133+
//
134+
// XLA_CHECK_OK(
135+
// FnThatReturnStatus(),
136+
// "New error message"
137+
// );
138+
//
139+
// If `FnThatReturnStatus()` returns a non-ok status, this macro will
140+
// call `ABSL_CHECK()`, which will crash.
141+
//
142+
#define XLA_CHECK_OK(status, ...) \
143+
::torch_xla::OkOrDie(::torch_xla::status_internal::GetStatus(status), \
144+
__FILE__, __LINE__, __FUNCTION__, ##__VA_ARGS__)
145+
128146
namespace status_internal {
129147

130148
// Adds source location information to the status propagation trace if
@@ -211,6 +229,14 @@ T GetValueOrThrow(absl::StatusOr<T>&& status) {
211229
// `GetValueOrThrow` overload for `Status`.
212230
void GetValueOrThrow(const absl::Status& status);
213231

232+
// Checks that `status` is an ok status.
233+
//
234+
// Otherwise, it will create a new status instance with the given source
235+
// location information, and incorporate its message (alongside the
236+
// status propagation trace) to the crash report.
237+
void OkOrDie(const absl::Status& status, const char* file, const int32_t line,
238+
const char* function, std::string_view message = "");
239+
214240
} // namespace torch_xla
215241

216242
#endif // XLA_TORCH_XLA_CSRC_STATUS_H_

0 commit comments

Comments
 (0)