Skip to content

Commit 8274f94

Browse files
authored
Replace GetComputationClientOrDie() with GetComputationClient() (part 1). (#9617)
This PR replaces calls to the deprecated function `GetComputationClientOrDie()` with calls to the `GetComputationClient()` function. The difference between them is that the former throws an exception on error, while the latter returns a status object. _Note: this is part 1 of 2 PRs. Together, they will phase out the `GetComputationClientOrDie()` function._
1 parent 6ee7627 commit 8274f94

File tree

8 files changed

+167
-142
lines changed

8 files changed

+167
-142
lines changed

test/cpp/cpp_test_util.cpp

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -222,18 +222,19 @@ void WithAllDevices(
222222
const std::function<void(const std::vector<torch::lazy::BackendDevice>&,
223223
const std::vector<torch::lazy::BackendDevice>&)>&
224224
devfn) {
225+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
226+
runtime::GetComputationClient());
225227
for (auto device_type : device_types) {
226228
std::vector<torch::lazy::BackendDevice> devices;
227229
std::vector<torch::lazy::BackendDevice> all_devices;
228-
for (const auto& device_str :
229-
torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices()) {
230+
231+
for (const auto& device_str : client->GetLocalDevices()) {
230232
torch::lazy::BackendDevice device = ParseDeviceString(device_str);
231233
if (device.type() == device_type.type) {
232234
devices.push_back(device);
233235
}
234236
}
235-
for (const auto& device_str :
236-
torch_xla::runtime::GetComputationClientOrDie()->GetAllDevices()) {
237+
for (const auto& device_str : client->GetAllDevices()) {
237238
torch::lazy::BackendDevice device = ParseDeviceString(device_str);
238239
if (device.type() == device_type.type) {
239240
all_devices.push_back(device);
@@ -279,37 +280,36 @@ std::vector<torch_xla::runtime::ComputationClient::DataPtr> Execute(
279280
XLA_ASSIGN_OR_THROW(xla::XlaComputation computation, lowering_ctx.BuildXla());
280281
XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape,
281282
computation.GetProgramShape());
283+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
284+
runtime::GetComputationClient());
282285
xla::Shape shape = MakeShapeWithDeviceLayout(
283286
program_shape.result(), static_cast<XlaDeviceType>(device.type()));
284287

285288
std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
286-
instances.push_back(
287-
{std::move(computation), device.toString(),
288-
torch_xla::runtime::GetComputationClientOrDie()->GetCompilationDevices(
289-
device.toString(), {}),
290-
&shape});
289+
instances.push_back({std::move(computation), device.toString(),
290+
client->GetCompilationDevices(device.toString(), {}),
291+
&shape});
291292

292293
std::vector<
293294
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
294-
computations = torch_xla::runtime::GetComputationClientOrDie()->Compile(
295-
std::move(instances));
295+
computations = client->Compile(std::move(instances));
296296

297297
torch_xla::runtime::ComputationClient::ExecuteComputationOptions options;
298-
XLA_ASSIGN_OR_THROW(
299-
std::vector<runtime::ComputationClient::DataPtr> outputs,
300-
torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
301-
*computations.front(),
302-
UnwrapXlaData(lowering_ctx.GetParametersData()), device.toString(),
303-
options));
298+
XLA_ASSIGN_OR_THROW(std::vector<runtime::ComputationClient::DataPtr> outputs,
299+
client->ExecuteComputation(
300+
*computations.front(),
301+
UnwrapXlaData(lowering_ctx.GetParametersData()),
302+
device.toString(), options));
304303
return outputs;
305304
}
306305

307306
std::vector<at::Tensor> Fetch(
308307
absl::Span<const torch_xla::runtime::ComputationClient::DataPtr>
309308
device_data) {
310-
XLA_ASSIGN_OR_THROW(
311-
std::vector<xla::Literal> literals,
312-
runtime::GetComputationClientOrDie()->TransferFromDevice(device_data));
309+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
310+
runtime::GetComputationClient());
311+
XLA_ASSIGN_OR_THROW(std::vector<xla::Literal> literals,
312+
client->TransferFromDevice(device_data));
313313
std::vector<at::Tensor> tensors;
314314
for (auto& literal : literals) {
315315
tensors.push_back(MakeTensorFromXlaLiteral(

test/cpp/test_replication.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ void TestSingleReplication(
4848
instances.emplace_back(CreateCrsComputation(shape), device_str,
4949
all_device_strings, &shape);
5050
}
51+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
52+
runtime::GetComputationClient());
5153
std::vector<torch_xla::runtime::ComputationClient::ComputationPtr>
52-
compiled_computations =
53-
torch_xla::runtime::GetComputationClientOrDie()->Compile(
54-
std::move(instances));
54+
compiled_computations = client->Compile(std::move(instances));
5555

5656
std::vector<at::Tensor> tensors;
5757
for (size_t i = 0; i < device_strings.size(); ++i) {
@@ -66,24 +66,22 @@ void TestSingleReplication(
6666
torch_xla::runtime::ComputationClient::ExecuteComputationOptions exec_options;
6767
for (size_t i = 0; i < device_strings.size(); ++i) {
6868
auto executor = [&, i]() {
69-
XLA_ASSIGN_OR_THROW(
70-
results[i],
71-
torch_xla::runtime::GetComputationClientOrDie()->ExecuteComputation(
72-
*compiled_computations[i],
73-
{std::dynamic_pointer_cast<
74-
torch_xla::runtime::ComputationClient::Data>(
75-
tensors_data[i])},
76-
device_strings[i], exec_options));
69+
XLA_ASSIGN_OR_THROW(results[i],
70+
client->ExecuteComputation(
71+
*compiled_computations[i],
72+
{std::dynamic_pointer_cast<
73+
torch_xla::runtime::ComputationClient::Data>(
74+
tensors_data[i])},
75+
device_strings[i], exec_options));
7776
counter.DecrementCount();
7877
};
7978
torch_xla::thread::Schedule(std::move(executor));
8079
}
8180
counter.Wait();
8281

8382
for (size_t i = 0; i < results.size(); ++i) {
84-
XLA_ASSIGN_OR_THROW(
85-
std::vector<xla::Literal> literals,
86-
runtime::GetComputationClientOrDie()->TransferFromDevice(results[i]));
83+
XLA_ASSIGN_OR_THROW(std::vector<xla::Literal> literals,
84+
client->TransferFromDevice(results[i]));
8785
ASSERT_EQ(literals.size(), 1);
8886

8987
// The result must be the original tensor value, multiplied by the number of

test/cpp/test_runtime.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,10 @@ TEST(RuntimeTest, ComputationClientInitialization) {
1313
// Initialize the ComputationClient.
1414
// Check all the APIs return the same valid ComputationClient.
1515

16-
client = GetComputationClientOrDie();
17-
ASSERT_NE(client, nullptr);
18-
1916
auto status = GetComputationClient();
2017
ASSERT_TRUE(status.ok());
2118

22-
EXPECT_EQ(status.value(), client);
19+
client = status.value();
2320
EXPECT_EQ(GetComputationClientIfInitialized(), client);
2421
}
2522

test/cpp/test_xla_sharding.cpp

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -332,16 +332,16 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
332332
std::vector<torch::lazy::BackendDataPtr> tensors_data =
333333
CreateTensorsData(tensors, shardings, devices);
334334

335-
int64_t n_devices =
336-
torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices().size();
335+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
336+
runtime::GetComputationClient());
337+
int64_t n_devices = client->GetLocalDevices().size();
337338
if (n_devices > 1) {
338339
// null sharding is treated as replicated.
339340
auto xla_data =
340341
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
341342
tensors_data[0]);
342343
std::vector<torch_xla::runtime::ComputationClient::DataPtr> shards =
343-
torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
344-
xla_data);
344+
client->GetDataShards(xla_data);
345345
EXPECT_EQ(shards.size(), n_devices);
346346
EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(xla_data->shape(),
347347
shards[0]->shape()));
@@ -351,8 +351,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
351351
auto sharded_xla_data =
352352
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
353353
tensors_data[1]);
354-
shards = torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
355-
sharded_xla_data);
354+
shards = client->GetDataShards(sharded_xla_data);
356355
EXPECT_EQ(shards.size(), n_devices);
357356
EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
358357
shards[0]->shape()));
@@ -362,8 +361,7 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
362361
sharded_xla_data =
363362
std::dynamic_pointer_cast<torch_xla::runtime::ComputationClient::Data>(
364363
tensors_data[2]);
365-
shards = torch_xla::runtime::GetComputationClientOrDie()->GetDataShards(
366-
sharded_xla_data);
364+
shards = client->GetDataShards(sharded_xla_data);
367365
EXPECT_EQ(shards.size(), n_devices);
368366
EXPECT_TRUE(xla::Shape::Equal().IgnoreLayout()(sharded_xla_data->shape(),
369367
shards[0]->shape()));
@@ -373,8 +371,9 @@ TEST_F(XLAShardingTest, CreateTensorsData) {
373371

374372
TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
375373
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {4, 4});
376-
int64_t n_devices =
377-
torch_xla::runtime::GetComputationClientOrDie()->GetLocalDevices().size();
374+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
375+
runtime::GetComputationClient());
376+
int64_t n_devices = client->GetLocalDevices().size();
378377
xla::Array<int64_t> tile_assignment({1, n_devices});
379378
tile_assignment.FillIota(0);
380379
xla::OpSharding tiled = xla::HloSharding::Tile(tile_assignment).ToProto();
@@ -397,15 +396,14 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
397396

398397
std::vector<
399398
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
400-
computations = torch_xla::runtime::GetComputationClientOrDie()->Compile(
401-
std::move(instances));
399+
computations = client->Compile(std::move(instances));
402400
torch_xla::runtime::ComputationClient::ComputationPtr computation =
403401
std::make_shared<torch_xla::runtime::ComputationClient::Computation>(
404402
"add", std::move(computations[0]->move_computation()));
405403

406404
// Prepare output sharding propagation, expect a sharded output placeholder.
407-
std::vector<XLATensorPtr> tensors{XLATensor::Create(
408-
torch_xla::runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
405+
std::vector<XLATensorPtr> tensors{
406+
XLATensor::Create(client->CreateDataPlaceholder(
409407
bridge::GetDefaultDevice()->toString(), std::move(shape)))};
410408
std::vector<torch::lazy::BackendDataPtr> data_placeholders;
411409
std::vector<XLATensor::ShardingSpecPtr> sharding_specs;

torch_xla/csrc/tensor_util.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -550,11 +550,14 @@ torch::lazy::BackendDataPtr TensorToXlaData(
550550
const at::Tensor& tensor, const xla::Shape& shape,
551551
const torch::lazy::BackendDevice& device) {
552552
TORCH_LAZY_TIMED("TensorToData");
553+
554+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
555+
runtime::GetComputationClient());
556+
553557
if (static_cast<XlaDeviceType>(device.type()) == XlaDeviceType::SPMD) {
554558
// The tensor is bypassing the virtual device, so it should be replicated
555559
// to all devices.
556-
std::vector<std::string> local_devices =
557-
runtime::GetComputationClientOrDie()->GetLocalDevices();
560+
std::vector<std::string> local_devices = client->GetLocalDevices();
558561
auto replicated_data =
559562
std::vector<at::Tensor>(local_devices.size(), tensor);
560563
return ShardingUtil::CreateShardedData(replicated_data, local_devices,
@@ -565,8 +568,7 @@ torch::lazy::BackendDataPtr TensorToXlaData(
565568
source_tensors.push_back(
566569
std::make_shared<runtime::AtenSource>(tensor, shape, device.toString()));
567570

568-
auto handles =
569-
runtime::GetComputationClientOrDie()->TransferToDevice(source_tensors);
571+
auto handles = client->TransferToDevice(source_tensors);
570572
XLA_CHECK_EQ(handles.size(), 1);
571573
return handles.front();
572574
}
@@ -806,15 +808,17 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
806808
return {};
807809
}
808810

811+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
812+
runtime::GetComputationClient());
813+
809814
// CreateTensorsData should be implicitly replicated to all devices.
810815
if (IsVirtualDevice(devices[0])) {
811816
XLA_CHECK(
812817
std::all_of(devices.begin(), devices.end(),
813818
[&](const std::string& s) { return s == devices[0]; }))
814819
<< "can't mix virtual device and real device.";
815820

816-
std::vector<std::string> local_devices =
817-
runtime::GetComputationClientOrDie()->GetLocalDevices();
821+
std::vector<std::string> local_devices = client->GetLocalDevices();
818822
std::vector<runtime::ComputationClient::DataPtr> handles;
819823
for (size_t i = 0; i < tensors.size(); ++i) {
820824
auto device = ParseDeviceString(devices[i]);
@@ -834,8 +838,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
834838
source_tensors.push_back(std::make_shared<runtime::AtenSource>(
835839
tensors[i], std::move(shape), devices[i]));
836840
}
837-
return WrapXlaData(
838-
runtime::GetComputationClientOrDie()->TransferToDevice(source_tensors));
841+
return WrapXlaData(client->TransferToDevice(source_tensors));
839842
}
840843

841844
std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
@@ -846,6 +849,9 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
846849
XLA_CHECK_EQ(tensors.size(), shardings.size());
847850
XLA_CHECK_EQ(tensors.size(), devices.size());
848851

852+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
853+
runtime::GetComputationClient());
854+
849855
std::vector<runtime::ComputationClient::DataPtr> handles;
850856
for (size_t i = 0; i < tensors.size(); ++i) {
851857
torch::lazy::BackendDevice device = ParseDeviceString(devices[i]);
@@ -858,8 +864,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
858864
// GetLocalDevices returns the list of local devices specified by their
859865
// global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]).
860866

861-
std::vector<std::string> local_devices =
862-
runtime::GetComputationClientOrDie()->GetLocalDevices();
867+
std::vector<std::string> local_devices = client->GetLocalDevices();
863868
// Shards the input tensors with padding, to split evenly.
864869
// The execution requires consistent shard sizes, and the zero-padded
865870
// values should be ignored.
@@ -871,8 +876,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
871876
} else {
872877
source_tensors.push_back(std::make_shared<runtime::AtenSource>(
873878
tensors[i], std::move(shape), devices[i]));
874-
new_handles = runtime::GetComputationClientOrDie()->TransferToDevice(
875-
source_tensors);
879+
new_handles = client->TransferToDevice(source_tensors);
876880
}
877881
handles.insert(handles.end(), new_handles.begin(), new_handles.end());
878882
}
@@ -910,7 +914,7 @@ absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
910914
save = PyEval_SaveThread();
911915
}
912916

913-
XLA_ASSIGN_OR_RETURN(runtime::ComputationClient * client,
917+
XLA_ASSIGN_OR_RETURN(runtime::ComputationClient * absl_nonnull const client,
914918
runtime::GetComputationClient());
915919
XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
916920
client->TransferFromDevice(UnwrapXlaData(xla_data)));

torch_xla/csrc/xla_backend_impl.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
2828
if (!default_device_type_inited_) {
2929
// bridge::GetDefaultDevice will trigger the runtime device init, should
3030
// not do it during class init time.
31-
default_device_type_ = std::make_shared<DeviceType>(
32-
runtime::GetComputationClientOrDie()->GetDeviceType());
31+
XLA_ASSIGN_OR_THROW(
32+
runtime::ComputationClient * absl_nonnull const client,
33+
runtime::GetComputationClient());
34+
default_device_type_ =
35+
std::make_shared<DeviceType>(client->GetDeviceType());
3336
default_device_type_inited_ = true;
3437
}
3538
return true;
@@ -77,8 +80,10 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
7780
const torch::lazy::BackendDevice& device,
7881
const torch::lazy::Shape& shape) const override {
7982
xla::Shape xla_shape = MakeXlaShapeFromLazyShape(shape, device);
80-
return runtime::GetComputationClientOrDie()->CreateDataPlaceholder(
81-
device.toString(), std::move(xla_shape));
83+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
84+
runtime::GetComputationClient());
85+
return client->CreateDataPlaceholder(device.toString(),
86+
std::move(xla_shape));
8287
}
8388

8489
torch::lazy::BackendDataPtr GetComputationDataFromNode(
@@ -121,8 +126,9 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
121126
std::vector<std::string> GetCompilationDevices(
122127
const std::string& device,
123128
c10::ArrayRef<std::string> devices) const override {
124-
return runtime::GetComputationClientOrDie()->GetCompilationDevices(device,
125-
devices);
129+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
130+
runtime::GetComputationClient());
131+
return client->GetCompilationDevices(device, devices);
126132
}
127133

128134
std::vector<torch::lazy::ComputationPtr> Compile(
@@ -155,19 +161,22 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
155161
torch_xla_computation->get_device_string(),
156162
{current_device.toString()}, &output_shapes.back()));
157163
}
164+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
165+
runtime::GetComputationClient());
158166
std::vector<std::shared_ptr<runtime::ComputationClient::Computation>>
159-
client_computations = runtime::GetComputationClientOrDie()->Compile(
160-
std::move(compile_instances));
167+
client_computations = client->Compile(std::move(compile_instances));
161168
return {client_computations.begin(), client_computations.end()};
162169
}
163170

164171
std::vector<torch::lazy::BackendDataPtr> ExecuteComputation(
165172
torch::lazy::ComputationPtr computation,
166173
c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
167174
const torch::lazy::BackendDevice& device) const override {
175+
XLA_ASSIGN_OR_THROW(runtime::ComputationClient * absl_nonnull const client,
176+
runtime::GetComputationClient());
168177
XLA_ASSIGN_OR_THROW(
169178
std::vector<runtime::ComputationClient::DataPtr> results,
170-
runtime::GetComputationClientOrDie()->ExecuteComputation(
179+
client->ExecuteComputation(
171180
*std::dynamic_pointer_cast<runtime::ComputationClient::Computation>(
172181
computation),
173182
UnwrapXlaData(arguments), device.toString()));

0 commit comments

Comments
 (0)