Commit 10ed554

ErrorHandling: make GetComputationClient() return StatusOr<T> type. (#9384)
1 parent a36d3e5 commit 10ed554

22 files changed: +260 -179 lines
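
For reference, the runtime now exposes three accessors instead of the single fail-fast GetComputationClient(). The declarations below are a minimal sketch reconstructed from this diff and the new test; the use of absl::StatusOr as the status type is an assumption, not confirmed by this page:

namespace torch_xla::runtime {

// Returns the client, or an error status if initialization failed.
// (Assumed return type: absl::StatusOr<ComputationClient*>.)
absl::StatusOr<ComputationClient*> GetComputationClient();

// Returns the client, aborting the process on initialization failure.
// Existing call sites in this commit are migrated to this variant.
ComputationClient* GetComputationClientOrDie();

// Returns the client only if it is already initialized, else nullptr;
// per the new test, this never triggers initialization itself.
ComputationClient* GetComputationClientIfInitialized();

}  // namespace torch_xla::runtime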

.github/scripts/run_tests.sh

Lines changed: 1 addition & 0 deletions

@@ -53,6 +53,7 @@ function run_torch_xla_cpp_tests() {
       "test_lazy"
       "test_replication"
       "test_tensor"
+      "test_runtime"
       # disable test_xla_backend_intf since it is flaky on upstream
       #"test_xla_backend_intf"
       "test_xla_sharding")

BUILD

Lines changed: 1 addition & 0 deletions

@@ -77,6 +77,7 @@ test_suite(
         "//test/cpp:test_replication",
         "//test/cpp:test_tensor",
         "//test/cpp:test_xla_sharding",
+        "//test/cpp:test_runtime",
         "//torch_xla/csrc/runtime:pjrt_computation_client_test",
         # "//torch_xla/csrc/runtime:ifrt_computation_client_test",
     ],

test/cpp/BUILD

Lines changed: 9 additions & 0 deletions

@@ -149,3 +149,12 @@ ptxla_cc_test(
     )
     for test in glob(["test_aten_xla_tensor*cpp"])
 ]
+
+ptxla_cc_test(
+    name = "test_runtime",
+    srcs = ["test_runtime.cpp"],
+    deps = [
+        "//torch_xla/csrc/runtime:runtime",
+        "@com_google_googletest//:gtest_main",
+    ],
+)

test/cpp/cpp_test_util.cpp

Lines changed: 6 additions & 6 deletions

@@ -225,14 +225,14 @@ void WithAllDevices(
   std::vector<torch::lazy::BackendDevice> devices;
   std::vector<torch::lazy::BackendDevice> all_devices;
   for (const auto& device_str :
-       torch_xla::runtime::GetComputationClient()->GetLocalDevices()) {
+       torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices()) {
     torch::lazy::BackendDevice device = ParseDeviceString(device_str);
     if (device.type() == device_type.type) {
       devices.push_back(device);
     }
   }
   for (const auto& device_str :
-       torch_xla::runtime::GetComputationClient()->GetAllDevices()) {
+       torch_xla::runtime::GetComputationClientOrDie()->GetAllDevices()) {
     torch::lazy::BackendDevice device = ParseDeviceString(device_str);
     if (device.type() == device_type.type) {
       all_devices.push_back(device);
@@ -283,17 +283,17 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
   instances.push_back(
       {std::move(computation), device.toString(),
-       torch_xla::runtime::GetComputationClient()->GetCompilationDevices(
+       torch_xla::runtime::GetComputationClientOrDie()->GetCompilationDevices(
            device.toString(), {}),
        &shape});
 
   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
-      computations = torch_xla::runtime::GetComputationClient()->Compile(
+      computations = torch_xla::runtime::GetComputationClientOrDie()->Compile(
           std::move(instances));
 
   torch_xla::runtime::ComputationClient::ExecuteComputationOptions options;
-  return torch_xla::runtime::GetComputationClient()->ExecuteComputation(
+  return torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
       *computations.front(), UnwrapXlaData(lowering_ctx.GetParametersData()),
       device.toString(), options);
 }
@@ -302,7 +302,7 @@ std::vector<at::Tensor> Fetch(
     absl::Span<const torch_xla::runtime::ComputationClient::DataPtr>
         device_data) {
   std::vector<xla::Literal> literals =
-      torch_xla::runtime::GetComputationClient()->TransferFromDevice(
+      torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
          device_data);
   std::vector<at::Tensor> tensors;
   for (auto& literal : literals) {

test/cpp/run_tests.sh

Lines changed: 2 additions & 1 deletion

@@ -99,7 +99,8 @@ if [[ "$RUN_CPP_TESTS" == "cpp_tests" ]]; then
       "test_tensor"
       # disable test_xla_backend_intf since it is flaky on upstream
       #"test_xla_backend_intf"
-      "test_xla_sharding")
+      "test_xla_sharding"
+      "test_runtime")
 fi
 for name in "${test_names[@]}"; do
   echo "Running $name cpp test..."

test/cpp/test_replication.cpp

Lines changed: 3 additions & 3 deletions

@@ -48,7 +48,7 @@ void TestSingleReplication(
   }
   std::vector<torch_xla::runtime::ComputationClient::ComputationPtr>
       compiled_computations =
-          torch_xla::runtime::GetComputationClient()->Compile(
+          torch_xla::runtime::GetComputationClientOrDie()->Compile(
              std::move(instances));
 
   std::vector<at::Tensor> tensors;
@@ -65,7 +65,7 @@ void TestSingleReplication(
   for (size_t i = 0; i < device_strings.size(); ++i) {
     auto executor = [&, i]() {
       results[i] =
-          torch_xla::runtime::GetComputationClient()->ExecuteComputation(
+          torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
              *compiled_computations[i],
              {std::dynamic_pointer_cast<
                  torch_xla::runtime::ComputationClient::Data>(
@@ -79,7 +79,7 @@ void TestSingleReplication(
 
   for (size_t i = 0; i < results.size(); ++i) {
     std::vector<xla::Literal> literals =
-        torch_xla::runtime::GetComputationClient()->TransferFromDevice(
+        torch_xla::runtime::GetComputationClientOrDie()->TransferFromDevice(
            results[i]);
     ASSERT_EQ(literals.size(), 1);
 

test/cpp/test_runtime.cpp

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+#include <gtest/gtest.h>
+
+#include "torch_xla/csrc/runtime/runtime.h"
+
+namespace torch_xla::runtime {
+
+TEST(RuntimeTest, ComputationClientInitialization) {
+  ComputationClient* client;
+
+  client = GetComputationClientIfInitialized();
+  EXPECT_EQ(client, nullptr);
+
+  // Initialize the ComputationClient.
+  // Check all the APIs return the same valid ComputationClient.
+
+  client = GetComputationClientOrDie();
+  ASSERT_NE(client, nullptr);
+
+  auto status = GetComputationClient();
+  ASSERT_TRUE(status.ok());
+
+  EXPECT_EQ(status.value(), client);
+  EXPECT_EQ(GetComputationClientIfInitialized(), client);
+}
+
+}  // namespace torch_xla::runtime
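
For callers that prefer to surface the error rather than crash, the StatusOr return of GetComputationClient() supports the usual propagation pattern. The sketch below is hypothetical (the function name TransferToServer and the absl::Status plumbing are illustrative, not part of this commit):

absl::Status TransferToServer(/* ... */) {
  // Propagate an initialization failure to the caller instead of aborting,
  // which is what GetComputationClientOrDie() would do.
  auto client_or = torch_xla::runtime::GetComputationClient();
  if (!client_or.ok()) {
    return client_or.status();
  }
  torch_xla::runtime::ComputationClient* client = client_or.value();
  // ... use `client` as before ...
  return absl::OkStatus();
}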

test/cpp/test_xla_sharding.cpp

Lines changed: 8 additions & 7 deletions

@@ -332,14 +332,15 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
       CreateTensorsData(tensors, shardings, devices);
 
   int64_t n_devices =
-      torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
+      torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices().size();
   if (n_devices > 1) {
     // null sharding is treated as replicated.
     auto xla_data =
         std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
             tensors_data[0]);
     std::vector<torch_xla::runtime::ComputationClient::DataPtr> shards =
-        torch_xla::runtime::GetComputationClient()->GetDataShards(xla_data);
+        torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
+            xla_data);
     EXPECT_EQ(shards.size(), n_devices);
     EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(xla_data->shape(),
                                                    shards[0]->shape()));
@@ -349,7 +350,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
     auto sharded_xla_data =
         std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
            tensors_data[1]);
-    shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
+    shards = torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
        sharded_xla_data);
     EXPECT_EQ(shards.size(), n_devices);
     EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
@@ -360,7 +361,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
     sharded_xla_data =
         std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
            tensors_data[2]);
-    shards = torch_xla::runtime::GetComputationClient()->GetDataShards(
+    shards = torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
        sharded_xla_data);
     EXPECT_EQ(shards.size(), n_devices);
     EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
@@ -372,7 +373,7 @@
 TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
   xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4});
   int64_t n_devices =
-      torch_xla::runtime::GetComputationClient()->GetLocalDevices().size();
+      torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices().size();
   xla::Array<int64_t> tile_assignment({1, n_devices});
   tile_assignment.FillIota(0);
   xla::OpSharding tiled = xla::HloSharding::Tile(tile_assignment).ToProto();
@@ -395,15 +396,15 @@
 
   std::vector<
       std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
-      computations = torch_xla::runtime::GetComputationClient()->Compile(
+      computations = torch_xla::runtime::GetComputationClientOrDie()->Compile(
          std::move(instances));
   torch_xla::runtime::ComputationClient::ComputationPtr computation =
       std::make_shared<torch_xla::runtime::ComputationClient::Computation>(
          "add", std::move(computations[0]->move_computation()));
 
   // Prepare output sharding propagation, expect a sharded output placeholder.
   std::vector<XLATensorPtr> tensors{XLATensor::Create(
-      torch_xla::runtime::GetComputationClient()->CreateDataPlaceholder(
+      torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
          bridge::GetDefaultDevice()->toString(), std::move(shape)))};
   std::vector<torch::lazy::BackendDataPtr> data_placeholders;
   std::vector<XLATensor::ShardingSpecPtr> sharding_specs;

torch_xla/csrc/aten_fallback.cpp

Lines changed: 1 addition & 1 deletion

@@ -77,7 +77,7 @@ bool UseOpenXLAFallbackOnCUDA(const c10::OperatorHandle& op) {
   // support running OpenXLA fallback operations on CUDA if the current
   // PyTorch/XLA DeviceType is not CUDA.
   bool device_is_cuda =
-      runtime::GetComputationClient()->GetDeviceType().getType() ==
+      runtime::GetComputationClientOrDie()->GetDeviceType().getType() ==
       XlaDeviceType::CUDA;
 
   // 3. PyTorch must have been compiled with CUDA support. Otherwise, our

torch_xla/csrc/aten_xla_bridge.cpp

Lines changed: 4 additions & 3 deletions

@@ -56,7 +56,7 @@ class AtenXlaDeviceMapper {
       devices_ordinals_[devices_.back()] = 0;
     } else {
       for (auto& device_str :
-           torch_xla::runtime::GetComputationClient()->GetLocalDevices()) {
+           torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices()) {
         devices_.emplace_back(ParseDeviceString(device_str));
         devices_ordinals_[devices_.back()] = devices_.size() - 1;
       }
@@ -366,8 +366,9 @@ std::string ToXlaString(const c10::Device& device) {
 
 const torch::lazy::BackendDevice* GetDefaultDevice() {
   static std::string default_device_spec =
-      UseVirtualDevice() ? "SPMD:0"
-                         : runtime::GetComputationClient()->GetDefaultDevice();
+      UseVirtualDevice()
+          ? "SPMD:0"
+          : runtime::GetComputationClientOrDie()->GetDefaultDevice();
   XLA_CHECK(!default_device_spec.empty());
   static const torch::lazy::BackendDevice default_device =
       ParseDeviceString(default_device_spec);
