23
23
#include < gtest/gtest.h>
24
24
25
25
#include < cstdlib>
26
+ #include < sstream>
26
27
#include < stdexcept>
27
28
#include < utility>
28
29
29
30
#include " absl/status/status.h"
30
31
#include " absl/status/statusor.h"
32
+ #include " test/cpp/cpp_test_util.h"
31
33
#include " torch_xla/csrc/runtime/env_vars.h"
32
34
#include " torch_xla/csrc/status.h"
33
35
@@ -75,7 +77,11 @@ class StatusTest : public testing::TestWithParam<CppStacktracesMode> {
75
77
[](const ::testing::TestParamInfo<::torch_xla::CppStacktracesMode>& \
76
78
info) { return ToString (info.param ); })
77
79
78
- namespace testing {
80
+ namespace cpp_test {
81
+
82
+ // Prefix of the C++ stacktrace PyTorch adds to the error message.
83
+ constexpr inline char kTorchCppStacktracePrefix [] =
84
+ " Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:" ;
79
85
80
86
constexpr inline char kNewMessage [] = " New test error message" ;
81
87
constexpr inline char kMessage [] = " Test error message" ;
@@ -84,29 +90,6 @@ constexpr inline char kFunction[] = "foo";
84
90
constexpr inline char kEntryPrefix [] = " \n " ;
85
91
constexpr inline int32_t kLine = 42 ;
86
92
87
- // The PyTorch C++ stacktrace is ALWAYS appended to the error message.
88
- // More specifically, when `what()` function is called.
89
- //
90
- // However, it's only when the raised `c10::Error` gets translated to a
91
- // Python exception that PyTorch checks the value of the
92
- // `TORCH_SHOW_CPP_STACKTRACES` environment variable, which actually
93
- // controls whether the stacktrace will get shown or not by calling
94
- // `what_without_backtraces()`, instead.
95
- //
96
- // Therefore, we need to mimic this behavior.
97
- #define THROW_RUNTIME_ERROR_FROM_C10_ERROR (block ) \
98
- try { \
99
- block; \
100
- } catch (const c10::Error& error) { \
101
- throw std::runtime_error (IsShowCppStacktracesMode () \
102
- ? error.what () \
103
- : error.what_without_backtrace ()); \
104
- }
105
-
106
- // Prefix of the C++ stacktrace PyTorch adds to the error message.
107
- constexpr inline char kTorchCppStacktracePrefix [] =
108
- " Exception raised from MaybeThrow at torch_xla/csrc/status.cpp:" ;
109
-
110
93
inline std::string GetStatusPropagationTrace (const absl::Status& status) {
111
94
if (status.ok ()) {
112
95
return " " ;
@@ -123,21 +106,18 @@ TEST_P(StatusTest, MaybeThrowWithOkStatus) {
123
106
}
124
107
125
108
TEST_P (StatusTest, MaybeThrowWithErrorStatus) {
126
- auto throw_exception = [=]() {
127
- THROW_RUNTIME_ERROR_FROM_C10_ERROR ({
128
- absl::Status error_status = absl::InvalidArgumentError (kMessage );
129
- MaybeThrow (error_status);
130
- });
131
- };
132
-
133
- if (IsShowCppStacktracesMode ()) {
134
- std::string expected_prefix =
135
- absl::StrCat (kMessage , " \n\n " , kTorchCppStacktracePrefix );
136
- EXPECT_THAT (throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
137
- ::testing::StartsWith (expected_prefix)));
138
- } else {
139
- EXPECT_THAT (throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
140
- ::testing::Eq (kMessage )));
109
+ try {
110
+ absl::Status error_status = absl::InvalidArgumentError (kMessage );
111
+ MaybeThrow (error_status);
112
+ } catch (const c10::Error& error) {
113
+ if (IsShowCppStacktracesMode ()) {
114
+ EXPECT_THAT (std::string_view (error.what ()),
115
+ ::testing::StartsWith (absl::StrCat(
116
+ kMessage , " \n\n " , kTorchCppStacktracePrefix )));
117
+ } else {
118
+ EXPECT_EQ (std::string_view (error.what_without_backtrace ()),
119
+ std::string_view (kMessage ));
120
+ }
141
121
}
142
122
}
143
123
@@ -149,20 +129,18 @@ TEST_P(StatusTest, GetValueOrThrowWithOkStatusOr) {
149
129
}
150
130
151
131
TEST_P (StatusTest, GetValueOrThrowWithErrorStatusOr) {
152
- auto throw_exception = [=]() {
153
- THROW_RUNTIME_ERROR_FROM_C10_ERROR ({
154
- absl::StatusOr<int > error_status = absl::InvalidArgumentError (kMessage );
155
- int value = GetValueOrThrow (error_status);
156
- });
157
- };
158
- if (IsShowCppStacktracesMode ()) {
159
- std::string expected_prefix =
160
- absl::StrCat (kMessage , " \n\n " , kTorchCppStacktracePrefix );
161
- EXPECT_THAT (throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
162
- ::testing::StartsWith (expected_prefix)));
163
- } else {
164
- EXPECT_THAT (throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
165
- ::testing::Eq (kMessage )));
132
+ try {
133
+ absl::StatusOr<int > error_status = absl::InvalidArgumentError (kMessage );
134
+ int value = GetValueOrThrow (error_status);
135
+ } catch (const c10::Error& error) {
136
+ if (IsShowCppStacktracesMode ()) {
137
+ EXPECT_THAT (std::string_view (error.what ()),
138
+ ::testing::StartsWith (absl::StrCat(
139
+ kMessage , " \n\n " , kTorchCppStacktracePrefix )));
140
+ } else {
141
+ EXPECT_EQ (std::string_view (error.what_without_backtrace ()),
142
+ std::string_view (kMessage ));
143
+ }
166
144
}
167
145
}
168
146
@@ -272,14 +250,14 @@ TEST_P(StatusTest, MacroReturnIfErrorWithNestedError) {
272
250
EXPECT_EQ (result.message (), std::string_view (kMessage ));
273
251
274
252
if (IsShowCppStacktracesMode ()) {
275
- auto frame0 = absl::StrCat ( kEntryPrefix , " From: operator() at " , __FILE__,
276
- " :" , errline0, " (error: " , kMessage , " ) " );
277
- auto frame1 = absl::StrCat ( kEntryPrefix , " From: operator() at " , __FILE__,
278
- " :" , errline1);
279
- auto frame2 = absl::StrCat ( kEntryPrefix , " From: operator() at " , __FILE__,
280
- " :" , errline2);
281
- EXPECT_EQ ( GetStatusPropagationTrace (result),
282
- absl::StrCat (frame0, frame1, frame2 ));
253
+ std::ostringstream oss;
254
+ oss << kEntryPrefix << " From: operator() at " << __FILE__ << " :" << errline0
255
+ << " (error: " << kMessage << " ) " ;
256
+ oss << kEntryPrefix << " From: operator() at " << __FILE__ << " :"
257
+ << errline1;
258
+ oss << kEntryPrefix << " From: operator() at " << __FILE__ << " :"
259
+ << errline2;
260
+ EXPECT_EQ ( GetStatusPropagationTrace (result), oss. str ( ));
283
261
}
284
262
}
285
263
@@ -383,39 +361,44 @@ TEST_P(StatusTest, MaybeThrowWithErrorPropagationWithNewMessage) {
383
361
return absl::OkStatus ();
384
362
};
385
363
386
- auto throw_exception = [&]() {
387
- THROW_RUNTIME_ERROR_FROM_C10_ERROR (MaybeThrow (outerfn ()));
388
- };
389
-
390
- if (IsShowCppStacktracesMode ()) {
391
- // Expected Error Message Prefix
392
- // =============================
393
- //
394
- // New test error kMessage
395
- //
396
- // Status Propagation Stacktrace:
397
- // From: ./test/cpp/test_status_common.h:329 (error: Test error
398
- // kMessage) From: ./test/cpp/test_status_common.h:335 (error: New test
399
- // error kMessage) From: ./test/cpp/test_status_common.h:342
400
- //
401
- // C++ Stacktrace:
402
- //
403
- std::string expected_prefix = absl::StrCat (
404
- kNewMessage , " \n\n Status Propagation Trace:" , kEntryPrefix ,
405
- " From: operator() at " , __FILE__, " :" , errline0, " (error: " , kMessage ,
406
- " )" , kEntryPrefix , " From: operator() at " , __FILE__, " :" , errline1,
407
- " (error: " , kNewMessage , " )" , kEntryPrefix , " From: operator() at " ,
408
- __FILE__, " :" , errline2, " \n\n " , kTorchCppStacktracePrefix );
409
-
410
- EXPECT_THAT (throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
411
- ::testing::StartsWith (expected_prefix)));
412
- } else {
413
- EXPECT_THAT (throw_exception, ::testing::ThrowsMessage<std::runtime_error>(
414
- ::testing::Eq (kNewMessage )));
364
+ try {
365
+ MaybeThrow (outerfn ());
366
+ } catch (const c10::Error& error) {
367
+ if (IsShowCppStacktracesMode ()) {
368
+ // Expected Error Message Prefix
369
+ // =============================
370
+ //
371
+ // New test error kMessage
372
+ //
373
+ // Status Propagation Stacktrace:
374
+ // From: ./test/cpp/test_status_common.h:329 (error: Test error
375
+ // kMessage) From: ./test/cpp/test_status_common.h:335 (error: New
376
+ // test error kMessage) From: ./test/cpp/test_status_common.h:342
377
+ //
378
+ // C++ Stacktrace:
379
+ //
380
+ std::ostringstream oss;
381
+ oss << kNewMessage ;
382
+ oss << " \n\n " ;
383
+ oss << " Status Propagation Trace:" ;
384
+ oss << kEntryPrefix << " From: operator() at " << __FILE__ << " :"
385
+ << errline0 << " (error: " << kMessage << " )" ;
386
+ oss << kEntryPrefix << " From: operator() at " << __FILE__ << " :"
387
+ << errline1 << " (error: " << kNewMessage << " )" ;
388
+ oss << kEntryPrefix << " From: operator() at " << __FILE__ << " :"
389
+ << errline2;
390
+ oss << " \n\n " ;
391
+ oss << kTorchCppStacktracePrefix ;
392
+ EXPECT_THAT (std::string_view (error.what ()),
393
+ ::testing::StartsWith (oss.str()));
394
+ } else {
395
+ EXPECT_EQ (std::string_view (error.what_without_backtrace ()),
396
+ std::string_view (kNewMessage ));
397
+ }
415
398
}
416
399
}
417
400
418
- } // namespace testing
401
+ } // namespace cpp_test
419
402
} // namespace torch_xla
420
403
421
404
#endif // XLA_TEST_CPP_TEST_STATUS_COMMON_H_
0 commit comments