4
4
#include < stdexcept>
5
5
#include < vector>
6
6
7
+ #include " absl/log/absl_check.h"
7
8
#include " absl/strings/ascii.h"
8
9
#include " absl/synchronization/blocking_counter.h"
9
10
#include " absl/types/span.h"
@@ -508,8 +509,8 @@ std::shared_ptr<xla::PjRtBuffer> PjRtComputationClient::GetPjRtBuffer(
508
509
}
509
510
}
510
511
511
- std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice (
512
- absl::Span<const DataPtr> handles) {
512
+ absl::StatusOr< std::vector<xla::Literal>>
513
+ PjRtComputationClient::TransferFromDevice ( absl::Span<const DataPtr> handles) {
513
514
metrics::TimedSection timed (TransferFromDeviceMetric ());
514
515
tsl::profiler::TraceMe activity (" PjRtComputationClient::TransferFromDevice" ,
515
516
tsl::profiler::TraceMeLevel::kInfo );
@@ -522,21 +523,17 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
522
523
// Use XLA replication to reassemble the sharded data. If input handle
523
524
// is not sharded, then it is a no-op.
524
525
std::shared_ptr<PjRtData> pjrt_data = ReplicateShardedData (handle);
525
- XLA_CHECK (pjrt_data) << " PjRt_data is null in " << __FUNCTION__;
526
- XLA_CHECK (pjrt_data->buffer != nullptr )
526
+ ABSL_CHECK (pjrt_data) << " PjRt_data is null in " << __FUNCTION__;
527
+ ABSL_CHECK (pjrt_data->buffer != nullptr )
527
528
<< " PjRt buffer is null in " << __FUNCTION__;
528
529
529
- xla::Literal& literal =
530
- literals. emplace_back (host_output_shape (pjrt_data->buffer .get ()));
530
+ xla::Literal& literal = literals. emplace_back (
531
+ xla::Literal (host_output_shape (pjrt_data->buffer .get () )));
531
532
futures.push_back (pjrt_data->buffer ->ToLiteral (&literal));
532
533
533
534
total_size += literal.size_bytes ();
534
535
}
535
- for (auto & future : futures) {
536
- absl::Status status = future.Await ();
537
- XLA_CHECK_OK (status) << " Failed to await future from buffer to literal in"
538
- << __FUNCTION__;
539
- }
536
+ XLA_RETURN_IF_ERROR (xla::JoinFutures (futures).Await ());
540
537
InboundDataMetric ()->AddSample (total_size);
541
538
542
539
return literals;
@@ -773,10 +770,8 @@ PjRtComputationClient::ExecuteComputation(
773
770
774
771
std::optional<xla::PjRtFuture<>> returned_future;
775
772
std::vector<std::unique_ptr<xla::PjRtBuffer>> results =
776
- pjrt_computation.executable
777
- ->ExecuteSharded (buffers, pjrt_device, execute_options,
778
- returned_future)
779
- .value ();
773
+ GetValueOrThrow (pjrt_computation.executable ->ExecuteSharded (
774
+ buffers, pjrt_device, execute_options, returned_future));
780
775
781
776
returned_future->OnReady (std::move (
782
777
[timed, op_tracker = std::move (op_tracker)](absl::Status unused) mutable {
@@ -878,10 +873,8 @@ PjRtComputationClient::ExecuteReplicated(
878
873
tsl::profiler::TraceMe activity (
879
874
" PjRtComputationClient::ExecuteReplicated_execute" ,
880
875
tsl::profiler::TraceMeLevel::kInfo );
881
- results = pjrt_computation.executable
882
- ->Execute (std::move (argument_handles), execute_options,
883
- returned_futures)
884
- .value ();
876
+ results = GetValueOrThrow (pjrt_computation.executable ->Execute (
877
+ std::move (argument_handles), execute_options, returned_futures));
885
878
886
879
(*returned_futures)[0 ].OnReady (
887
880
std::move ([timed, op_tracker = std::move (op_tracker)](
0 commit comments