Commit d9a9e44

runtime: Use new macros for throwing exceptions. (#9591)
Follow-up: #9588 and #9580
Target: `torch_xla/csrc/runtime` directory

In summary, this PR:

- Replaces all calls to `OkOrThrow()` and `GetValueOrThrow()` (which throw an exception without source-location information for the *throw-site*) with the macros `XLA_THROW_IF_ERROR()` and `XLA_ASSIGN_OR_THROW()`.
- Corresponds to the fine-grained set of PRs that came from breaking down PR #9580.
- Focuses on the `torch_xla/csrc/runtime` directory, replacing every use of those now-deprecated functions with the newly introduced macros.
1 parent d214faf commit d9a9e44

8 files changed (+65, -47 lines)
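Every hunk below follows the same shape: a `GetValueOrThrow(...)` wrapper around an `absl::StatusOr<T>`-returning call becomes `XLA_ASSIGN_OR_THROW(lhs, expr)`, and a status-only `OkOrThrow(...)` becomes `XLA_THROW_IF_ERROR(expr)`. A minimal before/after sketch of the pattern, taken from the `computation_client.h` hunk below:

```cpp
// Before: the throw happens inside GetValueOrThrow()'s body, so the
// resulting exception cannot point back at this call site.
program_shape_ = GetValueOrThrow(computation_.GetProgramShape());

// After: the macro expands right here, so it can stamp the error with
// the call site's file and line.
XLA_ASSIGN_OR_THROW(program_shape_, computation_.GetProgramShape());
```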

torch_xla/csrc/runtime/computation_client.h

Lines changed: 4 additions & 3 deletions
@@ -116,9 +116,9 @@ class ComputationClient {
       : name_(name),
         computation_(std::move(computation)),
         devices_(std::move(devices)) {
-    program_shape_ = GetValueOrThrow(computation_.GetProgramShape());
+    XLA_ASSIGN_OR_THROW(program_shape_, computation_.GetProgramShape());
     const xla::HloModuleProto& proto = computation_.proto();
-    hash_ = GetValueOrThrow(ComputeHash(proto, name));
+    XLA_ASSIGN_OR_THROW(hash_, ComputeHash(proto, name));
   }

   Computation(std::string name, xla::XlaComputation computation,
@@ -191,7 +191,8 @@ class ComputationClient {

   const std::string to_string() const override {
     xla::HloModuleConfig hlo_config(program_shape());
-    std::unique_ptr<xla::HloModule> module = GetValueOrThrow(
+    XLA_ASSIGN_OR_THROW(
+        std::unique_ptr<xla::HloModule> module,
         xla::HloModule::CreateFromProto(computation().proto(), hlo_config));
     return module->ToString();
   }

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 15 additions & 10 deletions
@@ -177,7 +177,8 @@ void IfrtComputationClient::InitializeCoordinator(int global_rank,
                                                   std::string port) {
   XLA_CHECK(coordinator_ == nullptr)
       << "Can only initialize the XlaCoordinator once.";
-  coordinator_ = GetValueOrThrow(
+  XLA_ASSIGN_OR_THROW(
+      coordinator_,
       XlaCoordinator::Create(global_rank, world_size, master_addr, port));
 }
@@ -395,10 +396,10 @@ tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
   auto instruction = XlaBuilderFriend::GetInstruction(y);
   *instruction->mutable_sharding() = xla::HloSharding::Replicate().ToProto();

-  xla::XlaComputation computation =
-      GetValueOrThrow(builder.Build(/*remove_dynamic_dimensions=*/false));
-  xla::ProgramShape program_shape =
-      GetValueOrThrow(computation.GetProgramShape());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation computation,
+                      builder.Build(/*remove_dynamic_dimensions=*/false));
+  XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape,
+                      computation.GetProgramShape());

   std::string device = GetDefaultDevice();
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
@@ -417,8 +418,9 @@ tsl::RCReference<xla::ifrt::Array> IfrtComputationClient::ReplicateShardedData(
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;

-  auto sharded_results = GetValueOrThrow(ExecuteReplicated(
-      *computations.front(), {{handle}}, GetLocalDevices(), execute_options));
+  XLA_ASSIGN_OR_THROW(std::vector<ComputationClient::DataPtr> sharded_results,
+                      ExecuteReplicated(*computations.front(), {{handle}},
+                                        GetLocalDevices(), execute_options));
   auto replicated_output =
       std::dynamic_pointer_cast<IfrtData>(sharded_results[0])
           ->buffer->FullyReplicatedShard(
@@ -516,14 +518,17 @@ std::vector<ComputationClient::ComputationPtr> IfrtComputationClient::Compile(
       mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
   torch_xla::ConvertHloToStableHlo(instance.computation.mutable_proto(),
                                    &mlir_module);
-  std::shared_ptr<xla::ifrt::LoadedExecutable> executable =
-      GetValueOrThrow(client_->GetDefaultCompiler()->CompileAndLoad(
+  XLA_ASSIGN_OR_THROW(
+      std::shared_ptr<xla::ifrt::LoadedExecutable> executable,
+      client_->GetDefaultCompiler()->CompileAndLoad(
           std::make_unique<xla::ifrt::HloProgram>(mlir_module),
           std::make_unique<xla::ifrt::XlaCompileOptions>(compile_options,
                                                          devices_list)));
   StableHloCompileCounter()->AddValue(1);

-  const auto& hlo_modules = GetValueOrThrow(executable->GetHloModules());
+  XLA_ASSIGN_OR_THROW(
+      const std::vector<std::shared_ptr<xla::HloModule>>& hlo_modules,
+      executable->GetHloModules());

   std::shared_ptr<IfrtComputation> ifrt_computation =
       std::make_shared<IfrtComputation>(

torch_xla/csrc/runtime/ifrt_computation_client_test.cpp

Lines changed: 7 additions & 4 deletions
@@ -36,7 +36,8 @@ absl::StatusOr<xla::XlaComputation> MakeComputation() {
 TEST(PjRtComputationClientTest, Init) {
   // Get a CPU client.
   tsl::setenv("PJRT_DEVICE", "CPU", true);
-  auto client = GetValueOrThrow(IfrtComputationClient::Create());
+  XLA_ASSIGN_OR_THROW(std::unique_ptr<IfrtComputationClient> client,
+                      IfrtComputationClient::Create());
   std::string device = client->GetDefaultDevice();

   // Compose a computation.
@@ -64,14 +65,16 @@ TEST(PjRtComputationClientTest, Init) {
       std::make_shared<LiteralSource>(std::move(literal_y), device)};

   // Execute the graph.
-  std::vector<ComputationClient::DataPtr> results =
-      GetValueOrThrow(client->ExecuteReplicated(
+  XLA_ASSIGN_OR_THROW(
+      std::vector<ComputationClient::DataPtr> results,
+      client->ExecuteReplicated(
          *computations[0], client->TransferToDevice(absl::MakeConstSpan(args)),
          {device}, options));

   // Copy the output from device back to host and assert correctness.
   ASSERT_EQ(results.size(), 1);
-  auto result_literals = GetValueOrThrow(client->TransferFromDevice(results));
+  XLA_ASSIGN_OR_THROW(std::vector<xla::Literal> result_literals,
+                      client->TransferFromDevice(results));
   ASSERT_THAT(result_literals, ::testing::SizeIs(1));
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(
       xla::LiteralUtil::CreateR2<float>({{6.0f, 8.0f}, {10.0f, 12.0f}}),

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 23 additions & 17 deletions
@@ -168,7 +168,8 @@ void PjRtComputationClient::InitializeCoordinator(int global_rank,
                                                   std::string port) {
   XLA_CHECK(coordinator_ == nullptr)
       << "Can only initialize the XlaCoordinator once.";
-  coordinator_ = GetValueOrThrow(
+  XLA_ASSIGN_OR_THROW(
+      coordinator_,
       XlaCoordinator::Create(global_rank, world_size, master_addr, port));
 }
@@ -367,10 +368,10 @@ PjRtComputationClient::ReplicateShardedData(
   auto instruction = XlaBuilderFriend::GetInstruction(y);
   *instruction->mutable_sharding() = xla::HloSharding::Replicate().ToProto();

-  xla::XlaComputation computation =
-      GetValueOrThrow(builder.Build(/*remove_dynamic_dimensions=*/false));
-  xla::ProgramShape program_shape =
-      GetValueOrThrow(computation.GetProgramShape());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation computation,
+                      builder.Build(/*remove_dynamic_dimensions=*/false));
+  XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape,
+                      computation.GetProgramShape());

   std::string device = GetDefaultDevice();
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance>
@@ -386,8 +387,8 @@ PjRtComputationClient::ReplicateShardedData(

   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
-  auto sharded_results =
-      GetValueOrThrow(ExecuteReplicated(*computations.front(), {sharded_data},
+  XLA_ASSIGN_OR_THROW(std::vector<ComputationClient::DataPtr> sharded_results,
+                      ExecuteReplicated(*computations.front(), {sharded_data},
                                         GetLocalDevices(), execute_options));
   XLA_CHECK(sharded_results.size() > 0)
       << "empty ExecuteReplicated results returned.";
@@ -433,8 +434,9 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(
   XLA_CHECK_NE(sharding.type(), xla::OpSharding::UNKNOWN)
       << "Resharding by UNKNOWN sharding type is not allowed.";

-  hlo_shardings.push_back(
-      GetValueOrThrow(xla::HloSharding::FromProto(sharding)));
+  XLA_ASSIGN_OR_THROW(xla::HloSharding hlo_sharding,
+                      xla::HloSharding::FromProto(sharding));
+  hlo_shardings.push_back(std::move(hlo_sharding));

   xla::OpSharding fallback_sharding;
   fallback_sharding.set_type(xla::OpSharding::REPLICATED);
@@ -457,9 +459,9 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(
     root = xla::Tuple(&builder, param_ops);
   }

-  xla::XlaComputation xla_computation = GetValueOrThrow(builder.Build(root));
-  xla::ProgramShape program_shape =
-      GetValueOrThrow(xla_computation.GetProgramShape());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation xla_computation, builder.Build(root));
+  XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape,
+                      xla_computation.GetProgramShape());

   std::string device = GetDefaultDevice();
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
@@ -474,8 +476,9 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(

   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
-  auto resharded_results = GetValueOrThrow(ExecuteReplicated(
-      *computation, handles, GetLocalDevices(), execute_options));
+  XLA_ASSIGN_OR_THROW(std::vector<ComputationClient::DataPtr> resharded_results,
+                      ExecuteReplicated(*computation, handles,
+                                        GetLocalDevices(), execute_options));
   return resharded_results;
 }
@@ -660,7 +663,9 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
     TF_VLOG(3) << "memory usage is not availiable";
   }

-  const auto& hlo_modules = GetValueOrThrow(executable->GetHloModules());
+  XLA_ASSIGN_OR_THROW(
+      const std::vector<std::shared_ptr<xla::HloModule>>& hlo_modules,
+      executable->GetHloModules());
   xla::HloComputation* hlo_computation = hlo_modules[0]->entry_computation();
   std::shared_ptr<PjRtComputation> pjrt_computation =
       std::make_shared<PjRtComputation>(
@@ -679,8 +684,9 @@ std::string PjRtComputationClient::SerializeComputation(
     const ComputationPtr computation) {
   const PjRtComputation& pjrt_computation =
       dynamic_cast<const PjRtComputation&>(*computation);
-
-  return GetValueOrThrow(pjrt_computation.executable->SerializeExecutable());
+  XLA_ASSIGN_OR_THROW(std::string serialized_executable,
+                      pjrt_computation.executable->SerializeExecutable());
+  return serialized_executable;
 }

 ComputationClient::ComputationPtr PjRtComputationClient::DeserializeComputation(

torch_xla/csrc/runtime/pjrt_computation_client_test.cpp

Lines changed: 8 additions & 7 deletions
@@ -25,7 +25,7 @@ class PjRtComputationClientTest : public ::testing::Test {
   PjRtComputationClientTest() {
     // Get a CPU client.
     tsl::setenv("PJRT_DEVICE", "CPU", true);
-    client_ = GetValueOrThrow(PjRtComputationClient::Create());
+    XLA_ASSIGN_OR_THROW(client_, PjRtComputationClient::Create());
     device_ = client_->GetDefaultDevice();
   }
@@ -114,15 +114,16 @@ TEST_F(PjRtComputationClientTest, Init) {
       std::make_shared<LiteralSource>(std::move(literal_y), device_)};

   // Execute the graph.
-  std::vector<ComputationClient::DataPtr> results =
-      GetValueOrThrow(client_->ExecuteComputation(
-          *computations[0],
-          client_->TransferToDevice(absl::MakeConstSpan(args)), device_,
-          options));
+  XLA_ASSIGN_OR_THROW(std::vector<ComputationClient::DataPtr> results,
+                      client_->ExecuteComputation(
+                          *computations[0],
+                          client_->TransferToDevice(absl::MakeConstSpan(args)),
+                          device_, options));

   // Copy the output from device back to host and assert correctness.
   ASSERT_EQ(results.size(), 1);
-  auto result_literals = GetValueOrThrow(client_->TransferFromDevice(results));
+  XLA_ASSIGN_OR_THROW(std::vector<xla::Literal> result_literals,
+                      client_->TransferFromDevice(results));
   ASSERT_THAT(result_literals, ::testing::SizeIs(1));
   EXPECT_TRUE(xla::LiteralTestUtil::Equal(
       xla::LiteralUtil::CreateR2<float>({{6.0f, 8.0f}, {10.0f, 12.0f}}),

torch_xla/csrc/runtime/runtime.cpp

Lines changed: 2 additions & 1 deletion
@@ -61,7 +61,8 @@ const absl::StatusOr<ComputationClient * absl_nonnull>& GetComputationClient() {
 }

 ComputationClient* absl_nonnull GetComputationClientOrDie() {
-  return GetValueOrThrow(GetComputationClient());
+  XLA_ASSIGN_OR_THROW(ComputationClient * client, GetComputationClient());
+  return client;
 }

 ComputationClient* GetComputationClientIfInitialized() {

torch_xla/csrc/runtime/tensor_source.h

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ class TensorSource {

   virtual std::vector<int64_t> byte_strides() const {
     std::vector<int64_t> byte_strides(shape().dimensions_size());
-    OkOrThrow(
+    XLA_THROW_IF_ERROR(
         xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
     return byte_strides;
   }

torch_xla/csrc/runtime/xla_util_test.cpp

Lines changed: 5 additions & 4 deletions
@@ -121,10 +121,10 @@ TEST(XlaUtilTest, XlaToHlo) {

 TEST(XlaUtilTest, TestDeterministicModuleProtoSerializationEmptyProto) {
   xla::HloModuleProto empty_proto;
-  auto result =
-      GetValueOrThrow(GetDeterministicSerializedModuleProto(empty_proto));
+  XLA_ASSIGN_OR_THROW(std::string serialized_result,
+                      GetDeterministicSerializedModuleProto(empty_proto));
   // Verify that the result is an empty string
-  EXPECT_TRUE(result.empty());
+  EXPECT_TRUE(serialized_result.empty());
 }

 TEST(XlaUtilTest, TestDeterministicModuleProtoSerialization) {
@@ -250,7 +250,8 @@ TEST(XlaUtilTest, TestDeterministicModuleProtoSerialization) {
       }
     }
   }
-  std::string serialized_proto = GetValueOrThrow(
+  XLA_ASSIGN_OR_THROW(
+      std::string serialized_proto,
       GetDeterministicSerializedModuleProto(hlo_module_proto));
   return torch::lazy::Hash(serialized_proto);
 };
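Why macros instead of the old helper functions: `GetValueOrThrow()` throws from inside its own function body, so any `__FILE__`/`__LINE__` (or top stack frame) it records points at the helper, not at the caller. A macro expands at the call site and can capture the real throw-site. Below is a minimal sketch of how such macros could be defined, assuming `absl::Status`/`absl::StatusOr<T>` return types and exception-based reporting; this is an illustrative assumption, not the actual definitions introduced in #9588:

```cpp
#include <sstream>
#include <stdexcept>
#include <utility>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Sketch: throw std::runtime_error if `expr` yields a non-OK absl::Status,
// tagging the message with the caller's file and line.
#define XLA_THROW_IF_ERROR(expr)                                  \
  do {                                                            \
    const absl::Status _xla_status = (expr);                      \
    if (!_xla_status.ok()) {                                      \
      std::ostringstream _xla_oss;                                \
      _xla_oss << __FILE__ << ":" << __LINE__ << ": "             \
               << _xla_status.ToString();                         \
      throw std::runtime_error(_xla_oss.str());                   \
    }                                                             \
  } while (false)

// Sketch: declare or assign `lhs` from an absl::StatusOr<T>, throwing with
// the caller's location on error. Deliberately not wrapped in do/while so
// a declaration passed as `lhs` stays visible in the enclosing scope; a
// real implementation would generate a collision-free temporary name
// (e.g. via __COUNTER__ token pasting) so the macro can be used more than
// once per scope.
#define XLA_ASSIGN_OR_THROW(lhs, expr)                            \
  auto _xla_status_or = (expr);                                   \
  if (!_xla_status_or.ok()) {                                     \
    std::ostringstream _xla_oss;                                  \
    _xla_oss << __FILE__ << ":" << __LINE__ << ": "               \
             << _xla_status_or.status().ToString();               \
    throw std::runtime_error(_xla_oss.str());                     \
  }                                                               \
  lhs = std::move(_xla_status_or).value()
```

With definitions along these lines, `XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape, computation.GetProgramShape());` both declares `program_shape` and initializes it, or throws an error carrying the exact file:line of the call.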
