Skip to content

Commit 4c586bd

Browse files
authored
Create new macros for throwing status errors. (#9588)
Follow-up: #9580 This PR introduces 2 new macros for throwing status errors: `XLA_THROW_IF_ERROR()`, and `XLA_ASSIGN_OR_THROW()`. These macros are analogous to the already existing `XLA_RETURN_IF_ERROR()` and `XLA_ASSIGN_OR_RETURN()`, where instead of propagating (i.e. returning) error status, they throw an exception with the given error status. **Key Changes:** - (_status.h_ and _status.cpp_) New function: `ThrowStatusError(...)` - (_status.h_) Refactors the implementation of the existing macros, so as to use its definition for all of the aforementioned macros - `XLA_PROCESS_STATUS_IMPL_(...)` : core implementation of those macros. - `XLA_PROPAGATE_STATUS_IMPL_(var, ...)`: propagates the given status `var`. - `XLA_THROW_STATUS_IMPL_(...)`: calls the newly added `ThrowStatusError()` function, which throws an exception - `XLA_DO_IF_ERROR_IMPL_(...)`: core implementation of `XLA_*_IF_ERROR()` macros - `XLA_RETURN_IF_ERROR(...)`: combines `XLA_DO_IF_ERROR_IMPL_` with `XLA_PROPAGATE_STATUS_IMPL_` - `XLA_THROW_IF_ERROR(...)`: combines `XLA_DO_IF_ERROR_IMPL_` with `XLA_THROW_STATUS_IMPL_` - `XLA_ASSIGN_OR_DO_IMPL_(...)`: core implementation of `XLA_ASSIGN_OR_*()` macros - `XLA_ASSIGN_OR_RETURN(...)`: combines `XLA_ASSIGN_OR_DO_IMPL_` with `XLA_PROPAGATE_STATUS_IMPL_` - `XLA_ASSGIN_OR_THROW(...)`: combines `XLA_ASSIGN_OR_DO_IMPL_` with `XLA_THROW_STATUS_IMPL_` - (_test_status_common.h_) Add one test for each of the 2 new public macros
1 parent 163193e commit 4c586bd

File tree

3 files changed

+266
-39
lines changed

3 files changed

+266
-39
lines changed

test/cpp/test_status_common.h

Lines changed: 131 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,10 @@ class StatusTest : public testing::TestWithParam<CppStacktracesMode> {
8080
namespace cpp_test {
8181

8282
// Prefix of the C++ stacktrace PyTorch adds to the error message.
83-
constexpr inline char kTorchCppStacktracePrefix[] =
83+
constexpr inline char kTorchCppStacktracePrefixDeprecated[] =
8484
"Exception raised from OkOrThrow at torch_xla/csrc/status.cpp:";
85+
constexpr inline char kTorchCppStacktracePrefix[] =
86+
"Exception raised from ThrowStatusError at torch_xla/csrc/status.cpp:";
8587

8688
constexpr inline char kNewMessage[] = "New test error message";
8789
constexpr inline char kMessage[] = "Test error message";
@@ -113,7 +115,7 @@ TEST_P(StatusTest, OkOrThrowWithErrorStatus) {
113115
if (IsShowCppStacktracesMode()) {
114116
EXPECT_THAT(std::string_view(error.what()),
115117
::testing::StartsWith(absl::StrCat(
116-
kMessage, "\n\n", kTorchCppStacktracePrefix)));
118+
kMessage, "\n\n", kTorchCppStacktracePrefixDeprecated)));
117119
} else {
118120
EXPECT_EQ(std::string_view(error.what_without_backtrace()),
119121
std::string_view(kMessage));
@@ -136,7 +138,7 @@ TEST_P(StatusTest, GetValueOrThrowWithErrorStatusOr) {
136138
if (IsShowCppStacktracesMode()) {
137139
EXPECT_THAT(std::string_view(error.what()),
138140
::testing::StartsWith(absl::StrCat(
139-
kMessage, "\n\n", kTorchCppStacktracePrefix)));
141+
kMessage, "\n\n", kTorchCppStacktracePrefixDeprecated)));
140142
} else {
141143
EXPECT_EQ(std::string_view(error.what_without_backtrace()),
142144
std::string_view(kMessage));
@@ -388,6 +390,132 @@ TEST_P(StatusTest, OkOrThrowWithErrorPropagationWithNewMessage) {
388390
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":"
389391
<< errline2;
390392
oss << "\n\n";
393+
oss << kTorchCppStacktracePrefixDeprecated;
394+
EXPECT_THAT(std::string_view(error.what()),
395+
::testing::StartsWith(oss.str()));
396+
} else {
397+
EXPECT_EQ(std::string_view(error.what_without_backtrace()),
398+
std::string_view(kNewMessage));
399+
}
400+
}
401+
}
402+
403+
TEST_P(StatusTest, MacroThrowIfErrorWithErrorPropagationWithNewMessage) {
404+
int32_t errline0 = __LINE__ + 2;
405+
auto innerfn = [&]() -> absl::Status {
406+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage));
407+
};
408+
409+
int32_t errline1 = __LINE__ + 2;
410+
auto midfn = [&]() -> absl::Status {
411+
XLA_RETURN_IF_ERROR(innerfn(), kNewMessage);
412+
return absl::OkStatus();
413+
};
414+
415+
int32_t errline2 = __LINE__ + 2;
416+
auto outerfn = [&]() -> absl::Status {
417+
XLA_RETURN_IF_ERROR(midfn());
418+
return absl::OkStatus();
419+
};
420+
421+
int32_t errline3 = __LINE__ + 2;
422+
try {
423+
XLA_THROW_IF_ERROR(outerfn());
424+
FAIL() << "Expected `XLA_THROW_IF_ERROR(outerfn())` to throw.";
425+
} catch (const c10::Error& error) {
426+
if (IsShowCppStacktracesMode()) {
427+
// clang-format off
428+
//
429+
// Expected Error Message Prefix
430+
// =============================
431+
//
432+
// New test error kMessage
433+
//
434+
// Status Propagation Stacktrace:
435+
// From: operator() at ./test/cpp/test_status_common.h:334 (error: Test error kMessage)
436+
// From: operator() at ./test/cpp/test_status_common.h:339 (error: New test error kMessage)
437+
// From: operator() at ./test/cpp/test_status_common.h:345
438+
// From: TestBody at ./test/cpp/test_status_common.h:350
439+
//
440+
// C++ Stacktrace:
441+
//
442+
// clang-format on
443+
std::ostringstream oss;
444+
oss << kNewMessage;
445+
oss << "\n\n";
446+
oss << "Status Propagation Trace:";
447+
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":"
448+
<< errline0 << " (error: " << kMessage << ")";
449+
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":"
450+
<< errline1 << " (error: " << kNewMessage << ")";
451+
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":"
452+
<< errline2;
453+
oss << kEntryPrefix << "From: TestBody at " << __FILE__ << ":"
454+
<< errline3;
455+
oss << "\n\n";
456+
oss << kTorchCppStacktracePrefix;
457+
EXPECT_THAT(std::string_view(error.what()),
458+
::testing::StartsWith(oss.str()));
459+
} else {
460+
EXPECT_EQ(std::string_view(error.what_without_backtrace()),
461+
std::string_view(kNewMessage));
462+
}
463+
}
464+
}
465+
466+
TEST_P(StatusTest, MacroAssignOrThrowWithErrorPropagationWithNewMessage) {
467+
int32_t errline0 = __LINE__ + 2;
468+
auto innerfn = [&]() -> absl::Status {
469+
return XLA_ERROR_WITH_LOCATION(absl::InvalidArgumentError(kMessage));
470+
};
471+
472+
int32_t errline1 = __LINE__ + 2;
473+
auto midfn = [&]() -> absl::Status {
474+
XLA_RETURN_IF_ERROR(innerfn(), kNewMessage);
475+
return absl::OkStatus();
476+
};
477+
478+
int32_t errline2 = __LINE__ + 2;
479+
auto outerfn = [&]() -> absl::StatusOr<int> {
480+
XLA_RETURN_IF_ERROR(midfn());
481+
return 42;
482+
};
483+
484+
int32_t errline3 = __LINE__ + 2;
485+
try {
486+
XLA_ASSIGN_OR_THROW(int ret, outerfn());
487+
FAIL() << "Expected `XLA_ASSIGN_OR_THROW(int ret, outerfn())` to throw.";
488+
} catch (const c10::Error& error) {
489+
if (IsShowCppStacktracesMode()) {
490+
// clang-format off
491+
//
492+
// Expected Error Message Prefix
493+
// =============================
494+
//
495+
// New test error kMessage
496+
//
497+
// Status Propagation Stacktrace:
498+
// From: operator() at ./test/cpp/test_status_common.h:393 (error: Test error kMessage)
499+
// From: operator() at ./test/cpp/test_status_common.h:398 (error: New test error kMessage)
500+
// From: operator() at ./test/cpp/test_status_common.h:404
501+
// From: TestBody at ./test/cpp/test_status_common.h:410
502+
//
503+
// C++ Stacktrace:
504+
//
505+
// clang-format on
506+
std::ostringstream oss;
507+
oss << kNewMessage;
508+
oss << "\n\n";
509+
oss << "Status Propagation Trace:";
510+
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":"
511+
<< errline0 << " (error: " << kMessage << ")";
512+
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":"
513+
<< errline1 << " (error: " << kNewMessage << ")";
514+
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":"
515+
<< errline2;
516+
oss << kEntryPrefix << "From: TestBody at " << __FILE__ << ":"
517+
<< errline3;
518+
oss << "\n\n";
391519
oss << kTorchCppStacktracePrefix;
392520
EXPECT_THAT(std::string_view(error.what()),
393521
::testing::StartsWith(oss.str()));

torch_xla/csrc/status.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,17 @@ static std::string LineBreakIfCppStacktracesEnabled() {
118118
return torch::get_cpp_stacktraces_enabled() ? "\n" : "";
119119
}
120120

121+
void status_internal::ThrowStatusError(const absl::Status& status,
122+
const char* file, const int32_t line,
123+
const char* function,
124+
std::string_view message) {
125+
ABSL_CHECK(!status.ok());
126+
absl::Status new_status = status_internal::MaybeWithNewMessage(
127+
status, file, line, function, message);
128+
TORCH_CHECK(false, absl::StrCat(BuildStatusErrorMessage(new_status),
129+
LineBreakIfCppStacktracesEnabled()));
130+
}
131+
121132
void OkOrThrow(const absl::Status& status) {
122133
TORCH_CHECK(status.ok(), absl::StrCat(BuildStatusErrorMessage(status),
123134
LineBreakIfCppStacktracesEnabled()));

torch_xla/csrc/status.h

Lines changed: 124 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -62,30 +62,56 @@ constexpr char kStatusPropagationTraceKey[] =
6262
#define XLA_STATUS_VAR_ XLA_CONCAT_(status_, __LINE__)
6363

6464
// Provides a flexible way to handle error checking with optional message
65-
// modification. It evaluates `expr`, checks if it's OK, and either:
66-
// 1. Returns early with an error status
67-
// 2. Proceeds with the given `then` block if successful
68-
#define XLA_RETURN_IF_ERROR_IMPL_(expr, var, then, ...) \
69-
auto var = (expr); \
70-
if (!var.ok()) { \
71-
return ::torch_xla::status_internal::MaybeWithNewMessage( \
72-
::torch_xla::status_internal::GetStatus(var), __FILE__, __LINE__, \
73-
__FUNCTION__, ##__VA_ARGS__); \
74-
} \
75-
then
76-
77-
// Propagates `rexpr`, in case it's a non-ok status.
65+
// modification. It evaluates `expr`, and:
7866
//
79-
// Example:
67+
// 1. Runs the `on_error` block, if the returned status is an error
68+
// 2. Runs the `on_success` block, otherwise
8069
//
81-
// XLA_RETURN_IF_ERROR(
82-
// FnThatReturnsStatus(),
83-
// "New error message."
84-
// );
70+
#define XLA_PROCESS_STATUS_IMPL_(on_error, on_success, expr, var, ...) \
71+
auto var = (expr); \
72+
if (!var.ok()) { \
73+
on_error(var, ##__VA_ARGS__); \
74+
} \
75+
on_success
76+
77+
// `on_error` implementation for propagating the status `var`.
78+
//
79+
// This macro wraps `var` (error status returned) into a new status, adding
80+
// source location information to the status propagation trace if
81+
// `TORCH_SHOW_CPP_STACKTRACES` is set. And then, returns the newly created
82+
// status.
83+
//
84+
// It should be only used as parameter to `XLA_PROCESS_STATUS_IMPL_` macro
85+
// defined above.
86+
//
87+
#define XLA_PROPAGATE_STATUS_IMPL_(var, ...) \
88+
return ::torch_xla::status_internal::MaybeWithNewMessage( \
89+
::torch_xla::status_internal::GetStatus(var), __FILE__, __LINE__, \
90+
__FUNCTION__, ##__VA_ARGS__)
91+
92+
// `on_error` implementation for throwing an exception with the status `var`.
8593
//
86-
// If the function call results in an ok status, execution continues. Otherwise,
87-
// we early return a non-ok status. Then, if `TORCH_SHOW_CPP_STACKTRACES` is
88-
// set, the error shown will be:
94+
// This macro wraps `var` (error status returned) into a new status, adding
95+
// source location information to the status propagation trace if
96+
// `TORCH_SHOW_CPP_STACKTRACES` is set. And then, throws an exception using the
97+
// `ThrowStatusError()` function.
98+
//
99+
// It should be only used as parameter to `XLA_PROCESS_STATUS_IMPL_` macro
100+
// defined above.
101+
//
102+
#define XLA_THROW_STATUS_IMPL_(var, ...) \
103+
::torch_xla::status_internal::ThrowStatusError( \
104+
::torch_xla::status_internal::GetStatus(var), __FILE__, __LINE__, \
105+
__FUNCTION__, ##__VA_ARGS__)
106+
107+
// Macro implementation for processing an `absl::Status` value. This is the core
108+
// definition of `XLA_*_IF_ERROR()` macros that, given that `rexpr` is an error
109+
// status, either throws or returns (i.e. propagates) a newly created status
110+
// with source location information.
111+
//
112+
// If `rexpr` results in an ok status, execution continues. Otherwise, we run
113+
// `on_error`. Then, if `TORCH_SHOW_CPP_STACKTRACES` is set, the error shown
114+
// will be:
89115
//
90116
// RuntimeError: New error message.
91117
//
@@ -95,18 +121,61 @@ constexpr char kStatusPropagationTraceKey[] =
95121
// ...
96122
// From: <cpp-source-file>:<line> (error: New error message.)
97123
//
98-
#define XLA_RETURN_IF_ERROR(rexpr, ...) \
99-
do { \
100-
XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR_, {}, ##__VA_ARGS__) \
124+
#define XLA_DO_IF_ERROR_IMPL_(on_error, rexpr, ...) \
125+
do { \
126+
XLA_PROCESS_STATUS_IMPL_(on_error, /* on_success= */ {}, rexpr, \
127+
XLA_STATUS_VAR_, ##__VA_ARGS__) \
101128
} while (false)
102129

103-
// Propagates `rexpr`, in case it's a non-ok status. Otherwise, assign
104-
// its result to `lhs`.
130+
// If `rexpr` returns a non-ok status, this macro propagates the returned status
131+
// by early-returning a, possibly, new status with source location information.
132+
// Otherwise, continues execution.
133+
//
134+
// Example:
135+
//
136+
// XLA_RETURN_IF_ERROR(
137+
// FnThatReturnsStatus(),
138+
// "New error message."
139+
// );
140+
//
141+
#define XLA_RETURN_IF_ERROR(rexpr, ...) \
142+
XLA_DO_IF_ERROR_IMPL_(XLA_PROPAGATE_STATUS_IMPL_, rexpr, ##__VA_ARGS__)
143+
144+
// If `rexpr` returns a non-ok status, this macro throws an exception with the
145+
// returned status, possibly, wrapped by a new status with source location
146+
// information. Otherwise, continues execution.
147+
//
148+
// Example:
149+
//
150+
// XLA_THROW_IF_ERROR(
151+
// FnThatReturnsStatus(),
152+
// "New error message."
153+
// );
154+
//
155+
#define XLA_THROW_IF_ERROR(rexpr, ...) \
156+
XLA_DO_IF_ERROR_IMPL_(XLA_THROW_STATUS_IMPL_, rexpr, ##__VA_ARGS__)
157+
158+
// Macro implementation for processing an `absl::Status` value. This is the core
159+
// definition of `XLA_ASSIGN_OR_*()` macros that, given that `rexpr` is an error
160+
// status, either throws or returns (i.e. propagates) a newly created status
161+
// with source location information.
162+
//
163+
// If `rexpr` results in an ok status, we assign the value held by the status
164+
// returned by `rexpr` to `lhs`. Otherwise, we run `on_error`.
105165
//
106166
// Note 1: `lhs` might be a variable declarate, e.g:
107167
//
108168
// Note 2: this macro will be replaced by multiple statements that live on
109-
// the scope it was called (see XLA_RETURN_IF_ERROR_IMPL).
169+
// the scope it was called (see `XLA_PROCESS_STATUS_IMPL_`).
170+
//
171+
#define XLA_ASSIGN_OR_DO_IMPL_(on_error, lhs, rexpr, ...) \
172+
XLA_PROCESS_STATUS_IMPL_( \
173+
on_error, /* on_success= */ lhs = std::move(XLA_STATUS_VAR_).value(), \
174+
rexpr, XLA_STATUS_VAR_, ##__VA_ARGS__)
175+
176+
// If `rexpr` returns a non-ok status, this macro propagates the returned status
177+
// by early-returning a, possibly, new status with source location information.
178+
// Otherwise, assigns `rexpr` to `lhs`.
110179
//
111180
// Example:
112181
//
@@ -116,16 +185,23 @@ constexpr char kStatusPropagationTraceKey[] =
116185
// "New error message."
117186
// );
118187
//
119-
// If the function call results in an ok status, execution continues with
120-
// `result` set to `ret.value()`, where `ret` is the returned value of the
121-
// function. Otherwise, we early return a non-ok status. Then, if
122-
// `TORCH_SHOW_CPP_STACKTRACES` is set, the error shown will be similar to
123-
// the one above.
188+
#define XLA_ASSIGN_OR_RETURN(lhs, rexpr, ...) \
189+
XLA_ASSIGN_OR_DO_IMPL_(XLA_PROPAGATE_STATUS_IMPL_, lhs, rexpr, ##__VA_ARGS__)
190+
191+
// If `rexpr` returns a non-ok status, this macro throws an exception with the
192+
// returned status, possibly, wrapped by a new status with source location
193+
// information. Otherwise, assigns `rexpr` to `lhs`.
194+
//
195+
// Example:
196+
//
197+
// XLA_ASSIGN_OR_THROW(
198+
// int result,
199+
// FnThatReturnsStatus(),
200+
// "New error message."
201+
// );
124202
//
125-
#define XLA_ASSIGN_OR_RETURN(lhs, rexpr, ...) \
126-
XLA_RETURN_IF_ERROR_IMPL_(rexpr, XLA_STATUS_VAR_, \
127-
lhs = std::move(XLA_STATUS_VAR_).value(), \
128-
##__VA_ARGS__)
203+
#define XLA_ASSIGN_OR_THROW(lhs, rexpr, ...) \
204+
XLA_ASSIGN_OR_DO_IMPL_(XLA_THROW_STATUS_IMPL_, lhs, rexpr, ##__VA_ARGS__)
129205

130206
// Crashes if `status` is not an ok status.
131207
//
@@ -191,6 +267,18 @@ absl::Status MaybeWithNewMessage(const absl::Status& status, const char* file,
191267
int32_t line, const char* function,
192268
std::string_view new_message = "");
193269

270+
// Throws an exception from the given `status`
271+
//
272+
// This function wraps `status` within a new status, with the current source
273+
// location information added to its status propagation trace payload.
274+
//
275+
// Then, it throws an exception by using the `TORCH_CHECK(false)` macro, which
276+
// also displays the C++ stacktrace at the end, if `TORCH_SHOW_CPP_STACKTRACES`
277+
// is set.
278+
void ThrowStatusError(const absl::Status& status, const char* file,
279+
const int32_t line, const char* function,
280+
std::string_view message = "");
281+
194282
// Checks that `status` is an ok status.
195283
//
196284
// Otherwise, it will create a new status instance with the given source

0 commit comments

Comments
 (0)