@@ -168,7 +168,8 @@ void PjRtComputationClient::InitializeCoordinator(int global_rank,
                                                   std::string port) {
   XLA_CHECK(coordinator_ == nullptr)
       << "Can only initialize the XlaCoordinator once.";
-  coordinator_ = GetValueOrThrow(
+  XLA_ASSIGN_OR_THROW(
+      coordinator_,
       XlaCoordinator::Create(global_rank, world_size, master_addr, port));
 }
 
@@ -367,10 +368,10 @@ PjRtComputationClient::ReplicateShardedData(
   auto instruction = XlaBuilderFriend::GetInstruction(y);
   *instruction->mutable_sharding() = xla::HloSharding::Replicate().ToProto();
 
-  xla::XlaComputation computation =
-      GetValueOrThrow(builder.Build(/*remove_dynamic_dimensions=*/false));
-  xla::ProgramShape program_shape =
-      GetValueOrThrow(computation.GetProgramShape());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation computation,
+                      builder.Build(/*remove_dynamic_dimensions=*/false));
+  XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape,
+                      computation.GetProgramShape());
 
   std::string device = GetDefaultDevice();
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance>
@@ -386,8 +387,8 @@ PjRtComputationClient::ReplicateShardedData(
 
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
-  auto sharded_results =
-      GetValueOrThrow(ExecuteReplicated(*computations.front(), {sharded_data},
+  XLA_ASSIGN_OR_THROW(std::vector<ComputationClient::DataPtr> sharded_results,
+                      ExecuteReplicated(*computations.front(), {sharded_data},
                                         GetLocalDevices(), execute_options));
   XLA_CHECK(sharded_results.size() > 0)
       << "empty ExecuteReplicated results returned.";
@@ -433,8 +434,9 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(
     XLA_CHECK_NE(sharding.type(), xla::OpSharding::UNKNOWN)
         << "Resharding by UNKNOWN sharding type is not allowed.";
 
-    hlo_shardings.push_back(
-        GetValueOrThrow(xla::HloSharding::FromProto(sharding)));
+    XLA_ASSIGN_OR_THROW(xla::HloSharding hlo_sharding,
+                        xla::HloSharding::FromProto(sharding));
+    hlo_shardings.push_back(std::move(hlo_sharding));
 
     xla::OpSharding fallback_sharding;
     fallback_sharding.set_type(xla::OpSharding::REPLICATED);
@@ -457,9 +459,9 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(
     root = xla::Tuple(&builder, param_ops);
   }
 
-  xla::XlaComputation xla_computation = GetValueOrThrow(builder.Build(root));
-  xla::ProgramShape program_shape =
-      GetValueOrThrow(xla_computation.GetProgramShape());
+  XLA_ASSIGN_OR_THROW(xla::XlaComputation xla_computation, builder.Build(root));
+  XLA_ASSIGN_OR_THROW(xla::ProgramShape program_shape,
+                      xla_computation.GetProgramShape());
 
   std::string device = GetDefaultDevice();
   std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
@@ -474,8 +476,9 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::ReshardData(
 
   torch_xla::runtime::ComputationClient::ExecuteReplicatedOptions
       execute_options;
-  auto resharded_results = GetValueOrThrow(ExecuteReplicated(
-      *computation, handles, GetLocalDevices(), execute_options));
+  XLA_ASSIGN_OR_THROW(std::vector<ComputationClient::DataPtr> resharded_results,
+                      ExecuteReplicated(*computation, handles,
+                                        GetLocalDevices(), execute_options));
   return resharded_results;
 }
 
@@ -660,7 +663,9 @@ std::vector<ComputationClient::ComputationPtr> PjRtComputationClient::Compile(
     TF_VLOG(3) << "memory usage is not availiable";
   }
 
-  const auto& hlo_modules = GetValueOrThrow(executable->GetHloModules());
+  XLA_ASSIGN_OR_THROW(
+      const std::vector<std::shared_ptr<xla::HloModule>>& hlo_modules,
+      executable->GetHloModules());
   xla::HloComputation* hlo_computation = hlo_modules[0]->entry_computation();
   std::shared_ptr<PjRtComputation> pjrt_computation =
       std::make_shared<PjRtComputation>(
@@ -679,8 +684,9 @@ std::string PjRtComputationClient::SerializeComputation(
     const ComputationPtr computation) {
   const PjRtComputation& pjrt_computation =
       dynamic_cast<const PjRtComputation&>(*computation);
-
-  return GetValueOrThrow(pjrt_computation.executable->SerializeExecutable());
+  XLA_ASSIGN_OR_THROW(std::string serialized_executable,
+                      pjrt_computation.executable->SerializeExecutable());
+  return serialized_executable;
 }
 
 ComputationClient::ComputationPtr PjRtComputationClient::DeserializeComputation(
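
Every call site in this change follows the same pattern: an `absl::StatusOr<T>`-returning expression is unwrapped in place, throwing on error instead of assigning through a `GetValueOrThrow` wrapper. Below is a minimal sketch of how an assign-or-throw macro of this shape can be built over `absl::StatusOr`; the `ASSIGN_OR_THROW` name, the helper macros, and the `std::runtime_error` exception type are illustrative assumptions, not the actual `XLA_ASSIGN_OR_THROW` definition from torch_xla.

```cpp
// Sketch only: NOT the real XLA_ASSIGN_OR_THROW from torch_xla; it just
// demonstrates the unwrap-or-throw pattern the diff relies on.
#include <stdexcept>
#include <string>
#include <utility>

#include "absl/status/statusor.h"

// Token pasting so each expansion gets a uniquely named temporary.
#define AOT_CONCAT_INNER(a, b) a##b
#define AOT_CONCAT(a, b) AOT_CONCAT_INNER(a, b)

#define ASSIGN_OR_THROW(lhs, rexpr) \
  AOT_IMPL(AOT_CONCAT(aot_statusor_, __LINE__), lhs, rexpr)

// Evaluates `rexpr` once; throws if the status is an error, otherwise moves
// the value into `lhs`. `lhs` may declare a new variable
// ("xla::ProgramShape program_shape") or name an existing one ("coordinator_").
#define AOT_IMPL(statusor, lhs, rexpr)                                   \
  auto statusor = (rexpr);                                               \
  if (!statusor.ok()) {                                                  \
    throw std::runtime_error(std::string(statusor.status().message()));  \
  }                                                                      \
  lhs = std::move(statusor).value()

// Example, mirroring the call sites in this commit:
//   ASSIGN_OR_THROW(xla::XlaComputation computation,
//                   builder.Build(/*remove_dynamic_dimensions=*/false));
```

Declaring the typed variable inside the macro argument avoids a default-constructed placeholder and keeps the error path in one place, which is the main readability win over wrapping each call in `GetValueOrThrow()`.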