Skip to content

Commit 5b8ddb1

Browse files
authored
Error Handling: refactor XlaCoordinator to use status types. (#9386)
1 parent 752ddba commit 5b8ddb1

18 files changed

+626
-39
lines changed

.github/scripts/run_tests.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,13 @@ function run_torch_xla_cpp_tests() {
5353
"test_lazy"
5454
"test_replication"
5555
"test_tensor"
56-
"test_runtime"
5756
# disable test_xla_backend_intf since it is flaky on upstream
5857
#"test_xla_backend_intf"
59-
"test_xla_sharding")
58+
"test_xla_sharding"
59+
"test_runtime"
60+
"test_status"
61+
"test_status_dont_show_cpp_error_context"
62+
"test_status_show_cpp_error_context")
6063
for name in "${test_names[@]}"; do
6164
echo "Running $name cpp test..."
6265
/tmp/test/bin/${name}

BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ test_suite(
7878
"//test/cpp:test_tensor",
7979
"//test/cpp:test_xla_sharding",
8080
"//test/cpp:test_runtime",
81+
"//test/cpp:test_status",
82+
"//test/cpp:test_status_dont_show_cpp_error_context",
83+
"//test/cpp:test_status_show_cpp_error_context",
8184
"//torch_xla/csrc/runtime:pjrt_computation_client_test",
8285
# "//torch_xla/csrc/runtime:ifrt_computation_client_test",
8386
],

test/cpp/BUILD

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,32 @@ ptxla_cc_test(
158158
"@com_google_googletest//:gtest_main",
159159
],
160160
)
161+
162+
ptxla_cc_test(
163+
name = "test_status",
164+
srcs = ["test_status.cpp"],
165+
deps = [
166+
"//torch_xla/csrc:status",
167+
"@com_google_googletest//:gtest_main",
168+
],
169+
)
170+
171+
ptxla_cc_test(
172+
name = "test_status_dont_show_cpp_error_context",
173+
srcs = ["test_status_dont_show_cpp_error_context.cpp"],
174+
deps = [
175+
"//torch_xla/csrc:status",
176+
"//torch_xla/csrc/runtime:env_vars",
177+
"@com_google_googletest//:gtest_main",
178+
],
179+
)
180+
181+
ptxla_cc_test(
182+
name = "test_status_show_cpp_error_context",
183+
srcs = ["test_status_show_cpp_error_context.cpp"],
184+
deps = [
185+
"//torch_xla/csrc:status",
186+
"//torch_xla/csrc/runtime:env_vars",
187+
"@com_google_googletest//:gtest_main",
188+
],
189+
)

test/cpp/run_tests.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ if [[ "$RUN_CPP_TESTS" == "cpp_tests" ]]; then
100100
# disable test_xla_backend_intf since it is flaky on upstream
101101
#"test_xla_backend_intf"
102102
"test_xla_sharding"
103-
"test_runtime")
103+
"test_runtime"
104+
"test_status"
105+
"test_status_dont_show_cpp_error_context"
106+
"test_status_show_cpp_error_context")
104107
fi
105108
for name in "${test_names[@]}"; do
106109
echo "Running $name cpp test..."

test/cpp/test_status.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#include <gtest/gtest.h>
2+
3+
#include "absl/status/status.h"
4+
#include "absl/status/statusor.h"
5+
#include "torch_xla/csrc/status.h"
6+
7+
namespace torch_xla {
8+
9+
TEST(StatusTest, MaybeThrowWithOkStatus) {
10+
absl::Status ok_status = absl::OkStatus();
11+
EXPECT_NO_THROW(MaybeThrow(ok_status));
12+
}
13+
14+
TEST(StatusTest, MaybeThrowWithErrorStatus) {
15+
absl::Status error_status = absl::InvalidArgumentError("Test error");
16+
EXPECT_THROW(MaybeThrow(error_status), std::runtime_error);
17+
}
18+
19+
TEST(StatusTest, GetValueOrThrowWithOkStatusOr) {
20+
int value = 42;
21+
absl::StatusOr<int> status_or = value;
22+
int result = GetValueOrThrow(std::move(status_or));
23+
EXPECT_EQ(result, value);
24+
}
25+
26+
TEST(StatusTest, GetValueOrThrowWithErrorStatusOr) {
27+
absl::StatusOr<int> status_or = absl::InvalidArgumentError("Test error");
28+
EXPECT_THROW(GetValueOrThrow(std::move(status_or)), std::runtime_error);
29+
}
30+
31+
TEST(StatusTest, MacroReturnIfError) {
32+
int value = 42;
33+
34+
auto test_function = [=]() -> absl::StatusOr<int> {
35+
absl::Status ok_status = absl::OkStatus();
36+
XLA_RETURN_IF_ERROR(ok_status);
37+
return value;
38+
};
39+
40+
absl::StatusOr<int> result = test_function();
41+
ASSERT_TRUE(result.ok());
42+
EXPECT_EQ(result.value(), value);
43+
}
44+
45+
TEST(StatusTest, MacroAssignOrReturn) {
46+
int initial_value = 42;
47+
int expected_value = initial_value * 2;
48+
49+
auto test_function = [=]() -> absl::StatusOr<int> {
50+
absl::StatusOr<int> status_or = initial_value;
51+
XLA_ASSIGN_OR_RETURN(int value, status_or);
52+
return value * 2;
53+
};
54+
55+
absl::StatusOr<int> result = test_function();
56+
ASSERT_TRUE(result.ok());
57+
EXPECT_EQ(result.value(), expected_value);
58+
}
59+
60+
} // namespace torch_xla
test/cpp/test_status_dont_show_cpp_error_context.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#include <gtest/gtest.h>
2+
3+
#include "absl/status/status.h"
4+
#include "absl/status/statusor.h"
5+
#include "torch_xla/csrc/runtime/env_vars.h"
6+
#include "torch_xla/csrc/status.h"
7+
8+
namespace torch_xla {
9+
10+
TEST(StatusWithoutErrorContextTest, MaybeWithLocationRetunsSameStatus) {
11+
absl::Status error_status = absl::InvalidArgumentError("Test error message");
12+
absl::Status result = MaybeWithLocation(error_status, "test_file.cpp", 42);
13+
EXPECT_EQ(result, error_status);
14+
}
15+
16+
TEST(StatusWithoutErrorContextTest, MaybeWithNewMessageEmptyNewMessage) {
17+
absl::Status error_status = absl::InvalidArgumentError("Original error");
18+
absl::Status result = MaybeWithNewMessage(error_status, "test_file.cpp", 42);
19+
EXPECT_EQ(result, error_status);
20+
}
21+
22+
TEST(StatusWithoutErrorContextTest, MaybeWithNewMessageNonEmptyNewMessage) {
23+
constexpr char new_err_string[] = "New error message";
24+
absl::Status error_status = absl::InvalidArgumentError("Original error");
25+
absl::Status result =
26+
MaybeWithNewMessage(error_status, "test_file.cpp", 42, new_err_string);
27+
28+
ASSERT_FALSE(result.ok());
29+
ASSERT_NE(result, error_status);
30+
EXPECT_EQ(result.code(), error_status.code());
31+
EXPECT_EQ(result.message(), new_err_string);
32+
}
33+
34+
TEST(StatusWithoutErrorContextTest, MacroReturnIfErrorWithError) {
35+
constexpr char err_string[] = "Test error";
36+
37+
auto test_function = [=]() -> absl::Status {
38+
absl::Status error_status = absl::InvalidArgumentError(err_string);
39+
XLA_RETURN_IF_ERROR(error_status);
40+
return absl::OkStatus();
41+
};
42+
43+
absl::Status result = test_function();
44+
ASSERT_FALSE(result.ok());
45+
EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument);
46+
EXPECT_EQ(result.message(), err_string);
47+
}
48+
49+
TEST(StatusWithoutErrorContextTest, MacroAssignOrReturnWithError) {
50+
auto test_function = []() -> absl::StatusOr<int> {
51+
absl::StatusOr<int> status_or = absl::InvalidArgumentError("Test error");
52+
XLA_ASSIGN_OR_RETURN(int value, status_or);
53+
return value * 2;
54+
};
55+
56+
absl::StatusOr<int> result = test_function();
57+
ASSERT_FALSE(result.ok());
58+
EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
59+
}
60+
61+
TEST(StatusWithoutErrorContextTest, MacroErrorWithLocation) {
62+
absl::Status error_status = absl::InvalidArgumentError("Test error");
63+
absl::Status result = XLA_ERROR_WITH_LOCATION(error_status);
64+
EXPECT_EQ(result, error_status);
65+
}
66+
67+
void SetUp() {
68+
setenv(runtime::env::kEnvShowCppErrorContext, /* value= */ "false",
69+
/* replace= */ 1);
70+
}
71+
72+
} // namespace torch_xla
73+
74+
int main(int argc, char **argv) {
75+
::torch_xla::SetUp();
76+
::testing::InitGoogleTest(&argc, argv);
77+
return RUN_ALL_TESTS();
78+
}
test/cpp/test_status_show_cpp_error_context.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
#include <gtest/gtest.h>
2+
3+
#include <cstdlib>
4+
#include <string_view>
5+
6+
#include "absl/status/status.h"
7+
#include "absl/status/statusor.h"
8+
#include "torch_xla/csrc/runtime/env_vars.h"
9+
#include "torch_xla/csrc/status.h"
10+
11+
namespace torch_xla {
12+
namespace {
13+
14+
constexpr char new_message[] = "New test error message";
15+
constexpr char message[] = "Test error message";
16+
constexpr char test_file[] = "test_file.cpp";
17+
constexpr int32_t line = 42;
18+
19+
TEST(StatusWithErrorContextTest, MaybeWithLocationRetunsSameStatus) {
20+
absl::Status error_status = absl::InvalidArgumentError(message);
21+
absl::Status result = MaybeWithLocation(error_status, test_file, line);
22+
ASSERT_NE(result, error_status);
23+
ASSERT_EQ(result.code(), error_status.code());
24+
EXPECT_EQ(result.message(), "Test error message (at test_file.cpp:42)");
25+
}
26+
27+
TEST(StatusWithErrorContextTest, MaybeWithNewMessageEmptyNewMessage) {
28+
absl::Status error_status = absl::InvalidArgumentError(message);
29+
absl::Status result = MaybeWithNewMessage(error_status, test_file, line);
30+
ASSERT_NE(result, error_status);
31+
ASSERT_EQ(result.code(), error_status.code());
32+
EXPECT_EQ(result.message(), "Test error message (at test_file.cpp:42)");
33+
}
34+
35+
TEST(StatusWithErrorContextTest, MaybeWithNewMessageNonEmptyNewMessage) {
36+
absl::Status error_status = absl::InvalidArgumentError(message);
37+
absl::Status result =
38+
MaybeWithNewMessage(error_status, test_file, line, new_message);
39+
ASSERT_NE(result, error_status);
40+
ASSERT_FALSE(result.ok());
41+
EXPECT_EQ(result.code(), error_status.code());
42+
EXPECT_EQ(result.message(),
43+
"New test error message (at test_file.cpp:42)\n"
44+
"From Error: Test error message");
45+
}
46+
47+
TEST(StatusWithErrorContextTest, MacroReturnIfErrorWithError) {
48+
int32_t err_line = 0;
49+
50+
auto test_function = [=, &err_line]() -> absl::Status {
51+
absl::Status error_status = absl::InvalidArgumentError(message);
52+
err_line = __LINE__ + 1;
53+
XLA_RETURN_IF_ERROR(error_status);
54+
return absl::OkStatus();
55+
};
56+
57+
absl::Status result = test_function();
58+
ASSERT_FALSE(result.ok());
59+
EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument);
60+
EXPECT_EQ(result.message(), absl::StrCat("Test error message (at ", __FILE__,
61+
":", err_line, ")"));
62+
}
63+
64+
TEST(StatusWithErrorContextTest, MacroAssignOrReturnWithError) {
65+
int32_t err_line = 0;
66+
67+
auto test_function = [&err_line]() -> absl::StatusOr<int> {
68+
absl::StatusOr<int> status_or = absl::InvalidArgumentError(message);
69+
err_line = __LINE__ + 1;
70+
XLA_ASSIGN_OR_RETURN(int value, status_or);
71+
return value * 2;
72+
};
73+
74+
absl::StatusOr<int> result = test_function();
75+
ASSERT_FALSE(result.ok());
76+
EXPECT_EQ(result.status().code(), absl::StatusCode::kInvalidArgument);
77+
EXPECT_EQ(
78+
result.status().message(),
79+
absl::StrCat("Test error message (at ", __FILE__, ":", err_line, ")"));
80+
}
81+
82+
TEST(StatusWithErrorContextTest, MacroErrorWithLocation) {
83+
absl::Status error_status = absl::InvalidArgumentError(message);
84+
int32_t err_line = __LINE__ + 1;
85+
absl::Status result = XLA_ERROR_WITH_LOCATION(error_status);
86+
ASSERT_NE(result, error_status);
87+
ASSERT_FALSE(result.ok());
88+
EXPECT_EQ(result.code(), absl::StatusCode::kInvalidArgument);
89+
EXPECT_EQ(result.message(), absl::StrCat("Test error message (at ", __FILE__,
90+
":", err_line, ")"));
91+
}
92+
93+
void SetUp() {
94+
setenv(runtime::env::kEnvShowCppErrorContext, /* value= */ "true",
95+
/* replace= */ 1);
96+
}
97+
98+
} // namespace
99+
} // namespace torch_xla
100+
101+
int main(int argc, char** argv) {
102+
::torch_xla::SetUp();
103+
::testing::InitGoogleTest(&argc, argv);
104+
return RUN_ALL_TESTS();
105+
}

torch_xla/csrc/BUILD

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ ptxla_cc_library(
264264
deps = [
265265
":device",
266266
":dtype",
267+
":status",
267268
":tensor",
268269
":version",
269270
"//torch_xla/csrc/runtime",
@@ -365,3 +366,15 @@ ptxla_cc_library(
365366
"@pybind11//:pybind11_embed",
366367
],
367368
)
369+
370+
cc_library(
371+
name = "status",
372+
srcs = ["status.cpp"],
373+
hdrs = ["status.h"],
374+
deps = [
375+
"//torch_xla/csrc/runtime:sys_util",
376+
"//torch_xla/csrc/runtime:env_vars",
377+
"@com_google_absl//absl/log:absl_check",
378+
"@com_google_absl//absl/status:statusor",
379+
],
380+
)

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "pybind11/pytypes.h"
3737
#include "pybind11/stl.h"
3838
#include "pybind11/stl_bind.h"
39+
#include "status.h"
3940
#include "torch_xla/csrc/XLANativeFunctions.h"
4041
#include "torch_xla/csrc/aten_autograd_ops.h"
4142
#include "torch_xla/csrc/aten_fallback.h"
@@ -64,6 +65,7 @@
6465
#include "torch_xla/csrc/runtime/xla_coordinator.h"
6566
#include "torch_xla/csrc/runtime/xla_util.h"
6667
#include "torch_xla/csrc/shape_helper.h"
68+
#include "torch_xla/csrc/status.h"
6769
#include "torch_xla/csrc/tensor_impl.h"
6870
#include "torch_xla/csrc/tensor_methods.h"
6971
#include "torch_xla/csrc/tensor_util.h"
@@ -172,18 +174,6 @@ class PythonScope : public Scope {
172174
};
173175
};
174176

175-
static void ConsumeAndMaybeThrow(absl::Status status) {
176-
if (!status.ok()) {
177-
throw std::runtime_error(std::string(status.message()));
178-
}
179-
}
180-
181-
template <class T>
182-
static T ConsumeAndMaybeThrow(absl::StatusOr<T> status) {
183-
ConsumeAndMaybeThrow(status.status());
184-
return std::move(status.value());
185-
}
186-
187177
struct NoGilSection {
188178
NoGilSection() : state(PyEval_SaveThread()) {}
189179
~NoGilSection() { PyEval_RestoreThread(state); }
@@ -1694,7 +1684,7 @@ void InitXlaModuleBindings(py::module m) {
16941684
})
16951685
.def("_init_computation_client",
16961686
[]() {
1697-
ConsumeAndMaybeThrow(runtime::GetComputationClient());
1687+
GetValueOrThrow(runtime::GetComputationClient());
16981688
})
16991689
.def("_xla_get_device_hw_type",
17001690
[](const at::Tensor& tensor) {

0 commit comments

Comments
 (0)