Skip to content

Commit 8c1449f

Browse files
authored
Use TORCH_CHECK() instead of throwing std::runtime_error in XLA_CHECK*() macros' implementation. (#9542)
1 parent 57cd41c commit 8c1449f

12 files changed

+134
-116
lines changed

test/cpp/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ cc_library(
167167
name = "test_status_common",
168168
hdrs = ["test_status_common.h"],
169169
deps = [
170+
":cpp_test_util",
170171
"//torch_xla/csrc:status",
171172
"//torch_xla/csrc/runtime:env_vars",
172173
"@com_google_absl//absl/status:statusor",
@@ -196,6 +197,7 @@ ptxla_cc_test(
196197
name = "test_debug_macros",
197198
srcs = ["test_debug_macros.cpp"],
198199
deps = [
200+
":cpp_test_util",
199201
"//torch_xla/csrc:status",
200202
"//torch_xla/csrc/runtime:debug_macros",
201203
"//torch_xla/csrc/runtime:env_vars",

test/cpp/test_aten_xla_tensor_1.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2429,7 +2429,7 @@ TEST_F(AtenXlaTensorTest, TestCount_Nonzero_error_case) {
24292429
torch::Tensor xla_a = CopyToDevice(a, device);
24302430

24312431
std::vector<long int> dims = {0, 0};
2432-
EXPECT_THROW(torch::count_nonzero(xla_a, dims), std::runtime_error);
2432+
EXPECT_THROW(torch::count_nonzero(xla_a, dims), c10::Error);
24332433

24342434
dims = {10};
24352435
EXPECT_THROW(torch::count_nonzero(xla_a, dims), c10::Error);

test/cpp/test_aten_xla_tensor_4.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ TEST_F(AtenXlaTensorTest, TestGettingSizeOnDynamicTensor) {
307307
torch::TensorOptions(torch::kFloat));
308308
torch::Tensor xla_b = CopyToDevice(b, device);
309309
torch::Tensor xla_nonzero = torch::nonzero(xla_b);
310-
EXPECT_THROW(xla_nonzero.sizes(), std::runtime_error);
310+
EXPECT_THROW(xla_nonzero.sizes(), c10::Error);
311311
EXPECT_NO_THROW(xla_nonzero.sym_sizes());
312312
});
313313
}

test/cpp/test_debug_macros.cpp

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,42 @@
11
#include <gmock/gmock.h>
22
#include <gtest/gtest.h>
33

4+
#include "test/cpp/cpp_test_util.h"
45
#include "torch_xla/csrc/runtime/debug_macros.h"
56
#include "torch_xla/csrc/runtime/env_vars.h"
67

7-
namespace torch_xla {
8+
namespace torch_xla::cpp_test {
89
namespace {
910

10-
using absl::StrCat;
11+
// Prefix of the C++ stacktrace PyTorch adds to the error message.
12+
constexpr char kTorchCppStacktracePrefix[] =
13+
"Exception raised from operator& at torch_xla/csrc/runtime/tf_logging.cpp:";
1114

1215
TEST(DebugMacrosTest, Check) {
13-
auto line = __LINE__ + 1;
14-
EXPECT_THAT([] { XLA_CHECK(false) << "Error message."; },
15-
testing::ThrowsMessage<std::runtime_error>(testing::StartsWith(
16-
StrCat("Check failed: false: Error message. (at ", __FILE__,
17-
":", line, ")\n*** Begin stack trace ***"))));
16+
int32_t line;
17+
try {
18+
line = __LINE__ + 1;
19+
XLA_CHECK(false) << "Error message.";
20+
} catch (const c10::Error& error) {
21+
EXPECT_THAT(error.what(),
22+
testing::StartsWith(absl::StrCat(
23+
"Check failed: false: Error message. (at ", __FILE__, ":",
24+
line, ")\n\n", kTorchCppStacktracePrefix)));
25+
}
1826
}
1927

20-
#define TEST_XLA_CHECK_OP_(opstr, lhs, rhs, compstr, valstr) \
21-
TEST(DebugMacrosTest, Check##opstr) { \
22-
EXPECT_THAT( \
23-
[] { XLA_CHECK_##opstr(lhs, rhs) << " Error message."; }, \
24-
testing::ThrowsMessage<std::runtime_error>(testing::StartsWith(StrCat( \
25-
"Check failed: " compstr " (" valstr ") Error message. (at ", \
26-
__FILE__, ":", __LINE__, ")\n*** Begin stack trace ***")))); \
28+
#define TEST_XLA_CHECK_OP_(opstr, lhs, rhs, compstr, valstr) \
29+
TEST(DebugMacrosTest, Check##opstr) { \
30+
try { \
31+
XLA_CHECK_##opstr(lhs, rhs) << " Error message."; \
32+
} catch (const c10::Error& error) { \
33+
EXPECT_THAT( \
34+
error.what(), \
35+
::testing::StartsWith(absl::StrCat( \
36+
"Check failed: " compstr " (" valstr ") Error message. (at ", \
37+
__FILE__, ":", __LINE__, ")\n\n", \
38+
::torch_xla::cpp_test::kTorchCppStacktracePrefix))); \
39+
} \
2740
}
2841

2942
#define TEST_XLA_CHECK_OP(opstr, op, lhs, rhs) \
@@ -52,15 +65,15 @@ TEST_XLA_CHECK_OP(LT, <, 5, 1)
5265
TEST_XLA_CHECK_GREATER(GE, <=, 5, 8)
5366
TEST_XLA_CHECK_GREATER(GT, <, 5, 8)
5467

55-
} // namespace
56-
} // namespace torch_xla
57-
5868
static void SetUp() {
5969
setenv("TORCH_SHOW_CPP_STACKTRACES", /* value= */ "1", /* replace= */ 1);
6070
}
6171

72+
} // namespace
73+
} // namespace torch_xla::cpp_test
74+
6275
int main(int argc, char** argv) {
63-
SetUp();
76+
::torch_xla::cpp_test::SetUp();
6477
::testing::InitGoogleTest(&argc, argv);
6578
return RUN_ALL_TESTS();
6679
}

test/cpp/test_ir.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ TEST_F(IrTest, TestSizeDivNodeDynamicByZero) {
358358
std::shared_ptr<torch::lazy::DimensionNode> dim_node_div =
359359
std::dynamic_pointer_cast<torch::lazy::DimensionNode>(node_div);
360360

361-
EXPECT_THROW(dim_node_div->getDynamicValue(), std::runtime_error);
361+
EXPECT_THROW(dim_node_div->getDynamicValue(), c10::Error);
362362
}
363363

364364
} // namespace cpp_test

test/cpp/test_status_common.h

Lines changed: 74 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
#include <gtest/gtest.h>
2424

2525
#include <cstdlib>
26+
#include <sstream>
2627
#include <stdexcept>
2728
#include <utility>
2829

2930
#include "absl/status/status.h"
3031
#include "absl/status/statusor.h"
32+
#include "test/cpp/cpp_test_util.h"
3133
#include "torch_xla/csrc/runtime/env_vars.h"
3234
#include "torch_xla/csrc/status.h"
3335

@@ -75,7 +77,11 @@ class StatusTest : public testing::TestWithParam<CppStacktracesMode> {
7577
[](const ::testing::TestParamInfo<::torch_xla::CppStacktracesMode>& \
7678
info) { return ToString(info.param); })
7779

78-
namespace testing {
80+
namespace cpp_test {
81+
82+
// Prefix of the C++ stacktrace PyTorch adds to the error message.
83+
constexpr inline char kTorchCppStacktracePrefix[] =
84+
"Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:";
7985

8086
constexpr inline char kNewMessage[] = "New test error message";
8187
constexpr inline char kMessage[] = "Test error message";
@@ -84,29 +90,6 @@ constexpr inline char kFunction[] = "foo";
8490
constexpr inline char kEntryPrefix[] = "\n ";
8591
constexpr inline int32_t kLine = 42;
8692

87-
// The PyTorch C++ stacktrace is ALWAYS appended to the error message.
88-
// More specifically, when `what()` function is called.
89-
//
90-
// However, it's only when the raised `c10::Error` gets translated to a
91-
// Python exception that PyTorch checks the value of the
92-
// `TORCH_SHOW_CPP_STACKTRACES` environment variable, which actually
93-
// controls whether the stacktrace will get shown or not by calling
94-
// `what_without_backtraces()`, instead.
95-
//
96-
// Therefore, we need to mimic this behavior.
97-
#define THROW_RUNTIME_ERROR_FROM_C10_ERROR(block) \
98-
try { \
99-
block; \
100-
} catch (const c10::Error& error) { \
101-
throw std::runtime_error(IsShowCppStacktracesMode() \
102-
? error.what() \
103-
: error.what_without_backtrace()); \
104-
}
105-
106-
// Prefix of the C++ stacktrace PyTorch adds to the error message.
107-
constexpr inline char kTorchCppStacktracePrefix[] =
108-
"Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:";
109-
11093
inline std::string GetStatusPropagationTrace(const absl::Status& status) {
11194
if (status.ok()) {
11295
return "";
@@ -123,21 +106,18 @@ TEST_P(StatusTest, MaybeThrowWithOkStatus) {
123106
}
124107

125108
TEST_P(StatusTest, MaybeThrowWithErrorStatus) {
126-
auto throw_exception = [=]() {
127-
THROW_RUNTIME_ERROR_FROM_C10_ERROR({
128-
absl::Status error_status = absl::InvalidArgumentError(kMessage);
129-
MaybeThrow(error_status);
130-
});
131-
};
132-
133-
if (IsShowCppStacktracesMode()) {
134-
std::string expected_prefix =
135-
absl::StrCat(kMessage, "\n\n", kTorchCppStacktracePrefix);
136-
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
137-
::testing::StartsWith(expected_prefix)));
138-
} else {
139-
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
140-
::testing::Eq(kMessage)));
109+
try {
110+
absl::Status error_status = absl::InvalidArgumentError(kMessage);
111+
MaybeThrow(error_status);
112+
} catch (const c10::Error& error) {
113+
if (IsShowCppStacktracesMode()) {
114+
EXPECT_THAT(std::string_view(error.what()),
115+
::testing::StartsWith(absl::StrCat(
116+
kMessage, "\n\n", kTorchCppStacktracePrefix)));
117+
} else {
118+
EXPECT_EQ(std::string_view(error.what_without_backtrace()),
119+
std::string_view(kMessage));
120+
}
141121
}
142122
}
143123

@@ -149,20 +129,18 @@ TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) {
149129
}
150130

151131
TEST_P(StatusTest, GetValueOrThrowWithErrorStatusOr) {
152-
auto throw_exception = [=]() {
153-
THROW_RUNTIME_ERROR_FROM_C10_ERROR({
154-
absl::StatusOr<int> error_status = absl::InvalidArgumentError(kMessage);
155-
int value = GetValueOrThrow(error_status);
156-
});
157-
};
158-
if (IsShowCppStacktracesMode()) {
159-
std::string expected_prefix =
160-
absl::StrCat(kMessage, "\n\n", kTorchCppStacktracePrefix);
161-
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
162-
::testing::StartsWith(expected_prefix)));
163-
} else {
164-
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
165-
::testing::Eq(kMessage)));
132+
try {
133+
absl::StatusOr<int> error_status = absl::InvalidArgumentError(kMessage);
134+
int value = GetValueOrThrow(error_status);
135+
} catch (const c10::Error& error) {
136+
if (IsShowCppStacktracesMode()) {
137+
EXPECT_THAT(std::string_view(error.what()),
138+
::testing::StartsWith(absl::StrCat(
139+
kMessage, "\n\n", kTorchCppStacktracePrefix)));
140+
} else {
141+
EXPECT_EQ(std::string_view(error.what_without_backtrace()),
142+
std::string_view(kMessage));
143+
}
166144
}
167145
}
168146

@@ -272,14 +250,14 @@ TEST_P(StatusTest, MacroReturnIfErrorWithNestedError) {
272250
EXPECT_EQ(result.message(), std::string_view(kMessage));
273251

274252
if (IsShowCppStacktracesMode()) {
275-
auto frame0 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__,
276-
":", errline0, " (error: ", kMessage, ")");
277-
auto frame1 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__,
278-
":", errline1);
279-
auto frame2 = absl::StrCat(kEntryPrefix, "From: operator() at ", __FILE__,
280-
":", errline2);
281-
EXPECT_EQ(GetStatusPropagationTrace(result),
282-
absl::StrCat(frame0, frame1, frame2));
253+
std::ostringstream oss;
254+
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":" << errline0
255+
<< " (error: " << kMessage << ")";
256+
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":"
257+
<< errline1;
258+
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":"
259+
<< errline2;
260+
EXPECT_EQ(GetStatusPropagationTrace(result), oss.str());
283261
}
284262
}
285263

@@ -383,39 +361,44 @@ TEST_P(StatusTest, MaybeThrowWithErrorPropagationWithNewMessage) {
383361
return absl::OkStatus();
384362
};
385363

386-
auto throw_exception = [&]() {
387-
THROW_RUNTIME_ERROR_FROM_C10_ERROR(MaybeThrow(outerfn()));
388-
};
389-
390-
if (IsShowCppStacktracesMode()) {
391-
// Expected Error Message Prefix
392-
// =============================
393-
//
394-
// New test error kMessage
395-
//
396-
// Status Propagation Stacktrace:
397-
// From: ./test/cpp/test_status_common.h:329 (error: Test error
398-
// kMessage) From: ./test/cpp/test_status_common.h:335 (error: New test
399-
// error kMessage) From: ./test/cpp/test_status_common.h:342
400-
//
401-
// C++ Stacktrace:
402-
//
403-
std::string expected_prefix = absl::StrCat(
404-
kNewMessage, "\n\nStatus Propagation Trace:", kEntryPrefix,
405-
"From: operator() at ", __FILE__, ":", errline0, " (error: ", kMessage,
406-
")", kEntryPrefix, "From: operator() at ", __FILE__, ":", errline1,
407-
" (error: ", kNewMessage, ")", kEntryPrefix, "From: operator() at ",
408-
__FILE__, ":", errline2, "\n\n", kTorchCppStacktracePrefix);
409-
410-
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
411-
::testing::StartsWith(expected_prefix)));
412-
} else {
413-
EXPECT_THAT(throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
414-
::testing::Eq(kNewMessage)));
364+
try {
365+
MaybeThrow(outerfn());
366+
} catch (const c10::Error& error) {
367+
if (IsShowCppStacktracesMode()) {
368+
// Expected Error Message Prefix
369+
// =============================
370+
//
371+
// New test error kMessage
372+
//
373+
// Status Propagation Stacktrace:
374+
// From: ./test/cpp/test_status_common.h:329 (error: Test error
375+
// kMessage) From: ./test/cpp/test_status_common.h:335 (error: New
376+
// test error kMessage) From: ./test/cpp/test_status_common.h:342
377+
//
378+
// C++ Stacktrace:
379+
//
380+
std::ostringstream oss;
381+
oss << kNewMessage;
382+
oss << "\n\n";
383+
oss << "Status Propagation Trace:";
384+
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":"
385+
<< errline0 << " (error: " << kMessage << ")";
386+
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":"
387+
<< errline1 << " (error: " << kNewMessage << ")";
388+
oss << kEntryPrefix << "From: operator() at " << __FILE__ << ":"
389+
<< errline2;
390+
oss << "\n\n";
391+
oss << kTorchCppStacktracePrefix;
392+
EXPECT_THAT(std::string_view(error.what()),
393+
::testing::StartsWith(oss.str()));
394+
} else {
395+
EXPECT_EQ(std::string_view(error.what_without_backtrace()),
396+
std::string_view(kNewMessage));
397+
}
415398
}
416399
}
417400

418-
} // namespace testing
401+
} // namespace cpp_test
419402
} // namespace torch_xla
420403

421404
#endif // XLA_TEST_CPP_TEST_STATUS_COMMON_H_

test/cpp/test_status_dont_show_cpp_stacktraces.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
using torch_xla::StatusTest;
44

5+
namespace torch_xla::cpp_test {
6+
namespace {
7+
58
// This file instantiates the parameterized tests defined in
69
// `test_status_common.h`. It specifically configures the test environment by
710
// explicitly setting the `TORCH_SHOW_CPP_STACKTRACES` environment variable to
@@ -11,3 +14,6 @@ using torch_xla::StatusTest;
1114
// automatically be run in this mode (without C++ error context).
1215
//
1316
INSTANTIATE_WITH_CPP_STACKTRACES_MODE(StatusTest, StatusTest, kHide);
17+
18+
} // namespace
19+
} // namespace torch_xla::cpp_test

test/cpp/test_status_show_cpp_stacktraces.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
using torch_xla::StatusTest;
44

5+
namespace torch_xla::cpp_test {
6+
namespace {
7+
58
// This file instantiates the parameterized tests defined in
69
// `test_status_common.h`. It specifically configures the test environment by
710
// explicitly setting the `TORCH_SHOW_CPP_STACKTRACES` environment variable to
@@ -11,3 +14,6 @@ using torch_xla::StatusTest;
1114
// automatically be run in this mode (with C++ error context).
1215
INSTANTIATE_WITH_CPP_STACKTRACES_MODE(StatusWithCppErrorContextTest, StatusTest,
1316
kShow);
17+
18+
} // namespace
19+
} // namespace torch_xla::cpp_test

torch_xla/csrc/runtime/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,8 @@ cc_library(
395395
hdrs = ["tf_logging.h"],
396396
deps = [
397397
"//torch_xla/csrc:status",
398+
"@torch//:headers",
399+
"@torch//:runtime_headers",
398400
"@tsl//tsl/platform:stacktrace",
399401
"@tsl//tsl/platform:statusor",
400402
"@xla//xla/service:platform_util",

torch_xla/csrc/runtime/debug_macros.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@
66
#include "tsl/platform/stacktrace.h"
77
#include "tsl/platform/statusor.h"
88

9+
// DEPRECATED
10+
// ==========
11+
// These macros are deprecated in favor of error handling by propagating abseil
12+
// status types (e.g. `absl::Status` and `absl::StatusOr<T>`).
13+
//
14+
// Description
15+
// ===========
916
// TORCH_SHOW_CPP_STACKTRACES environment variable changes the behavior of the
1017
// macros below, such as XLA_CHECK(), XLA_CHECK_EQ(), etc. (except for
1118
// XLA_CHECK_OK) in the following way:

0 commit comments

Comments
 (0)