@@ -550,11 +550,14 @@ torch::lazy::BackendDataPtr TensorToXlaData(
     const at::Tensor& tensor, const xla::Shape& shape,
     const torch::lazy::BackendDevice& device) {
   TORCH_LAZY_TIMED("TensorToData");
+
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient* absl_nonnull const client,
+                      runtime::GetComputationClient());
+
   if (static_cast<XlaDeviceType>(device.type()) == XlaDeviceType::SPMD) {
     // The tensor is bypassing the virtual device, so it should be replicated
     // to all devices.
-    std::vector<std::string> local_devices =
-        runtime::GetComputationClientOrDie()->GetLocalDevices();
+    std::vector<std::string> local_devices = client->GetLocalDevices();
     auto replicated_data =
         std::vector<at::Tensor>(local_devices.size(), tensor);
     return ShardingUtil::CreateShardedData(replicated_data, local_devices,
@@ -565,8 +568,7 @@ torch::lazy::BackendDataPtr TensorToXlaData(
   source_tensors.push_back(
       std::make_shared<runtime::AtenSource>(tensor, shape, device.toString()));
 
-  auto handles =
-      runtime::GetComputationClientOrDie()->TransferToDevice(source_tensors);
+  auto handles = client->TransferToDevice(source_tensors);
   XLA_CHECK_EQ(handles.size(), 1);
   return handles.front();
 }
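The two hunks above show the shape of the whole change: the crashing accessor runtime::GetComputationClientOrDie() is replaced by runtime::GetComputationClient(), which presumably returns an absl::StatusOr, unwrapped once per function with XLA_ASSIGN_OR_THROW and then reused through the local client pointer. A minimal, self-contained sketch of that accessor pattern, using illustrative stand-in names rather than torch_xla APIs:

#include <stdexcept>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Illustrative stand-in for the computation client (not a torch_xla type).
struct FakeClient {
  std::string default_device = "CPU:0";
};

// Status-returning accessor: an unavailable client becomes a value the
// caller can handle, instead of a crash inside the accessor itself.
absl::StatusOr<FakeClient*> GetFakeClient(bool initialized) {
  static FakeClient client;
  if (!initialized) {
    return absl::FailedPreconditionError("computation client not initialized");
  }
  return &client;
}

// Roughly what an ASSIGN_OR_THROW-style macro does at each call site:
// unwrap the value on success, turn a non-OK status into an exception.
FakeClient* GetFakeClientOrThrow(bool initialized) {
  absl::StatusOr<FakeClient*> client = GetFakeClient(initialized);
  if (!client.ok()) {
    throw std::runtime_error(std::string(client.status().message()));
  }
  return *client;
}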
@@ -806,15 +808,17 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
     return {};
   }
 
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient* absl_nonnull const client,
+                      runtime::GetComputationClient());
+
   // CreateTensorsData should be implicitly replicated to all devices.
   if (IsVirtualDevice(devices[0])) {
     XLA_CHECK(
         std::all_of(devices.begin(), devices.end(),
                     [&](const std::string& s) { return s == devices[0]; }))
         << "can't mix virtual device and real device.";
 
-    std::vector<std::string> local_devices =
-        runtime::GetComputationClientOrDie()->GetLocalDevices();
+    std::vector<std::string> local_devices = client->GetLocalDevices();
     std::vector<runtime::ComputationClient::DataPtr> handles;
     for (size_t i = 0; i < tensors.size(); ++i) {
       auto device = ParseDeviceString(devices[i]);
@@ -834,8 +838,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
     source_tensors.push_back(std::make_shared<runtime::AtenSource>(
         tensors[i], std::move(shape), devices[i]));
   }
-  return WrapXlaData(
-      runtime::GetComputationClientOrDie()->TransferToDevice(source_tensors));
+  return WrapXlaData(client->TransferToDevice(source_tensors));
 }
 
 std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
@@ -846,6 +849,9 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
   XLA_CHECK_EQ(tensors.size(), shardings.size());
   XLA_CHECK_EQ(tensors.size(), devices.size());
 
+  XLA_ASSIGN_OR_THROW(runtime::ComputationClient* absl_nonnull const client,
+                      runtime::GetComputationClient());
+
   std::vector<runtime::ComputationClient::DataPtr> handles;
   for (size_t i = 0; i < tensors.size(); ++i) {
     torch::lazy::BackendDevice device = ParseDeviceString(devices[i]);
@@ -858,8 +864,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
       // GetLocalDevices returns the list of local devices specified by their
       // global ordinals (e.g. ["TPU:4", "TPU:5", "TPU:6", "TPU:7"]).
 
-      std::vector<std::string> local_devices =
-          runtime::GetComputationClientOrDie()->GetLocalDevices();
+      std::vector<std::string> local_devices = client->GetLocalDevices();
       // Shards the input tensors with padding, to split evenly.
       // The execution requires consistent shard sizes, and the zero-padded
       // values should be ignored.
@@ -871,8 +876,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
     } else {
       source_tensors.push_back(std::make_shared<runtime::AtenSource>(
           tensors[i], std::move(shape), devices[i]));
-      new_handles = runtime::GetComputationClientOrDie()->TransferToDevice(
-          source_tensors);
+      new_handles = client->TransferToDevice(source_tensors);
     }
     handles.insert(handles.end(), new_handles.begin(), new_handles.end());
   }
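The new client bindings carry the absl_nonnull annotation, Abseil's nullability marker, documenting that the unwrapped pointer can never be null (the failing path has already thrown or returned). A small sketch of how that annotation reads in isolation, assuming the absl/base/nullability.h macro form used in this diff:

#include "absl/base/nullability.h"

// The annotation qualifies the pointer it follows, much like `const`, so
// `int* absl_nonnull const p` is a const, never-null pointer to int.
// Under Clang's nullability checks this can be enforced; elsewhere it
// serves as documentation.
int Deref(int* absl_nonnull p) { return *p; }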
@@ -910,7 +914,7 @@ absl::StatusOr<std::vector<xla::Literal>> ReleaseGilAndTransferData(
     save = PyEval_SaveThread();
   }
 
-  XLA_ASSIGN_OR_RETURN(runtime::ComputationClient* client,
+  XLA_ASSIGN_OR_RETURN(runtime::ComputationClient* absl_nonnull const client,
                        runtime::GetComputationClient());
   XLA_ASSIGN_OR_RETURN(std::vector<xla::Literal> literals,
                        client->TransferFromDevice(UnwrapXlaData(xla_data)));
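In ReleaseGilAndTransferData the function itself returns absl::StatusOr, so this last hunk keeps XLA_ASSIGN_OR_RETURN and only tightens the pointer declaration; errors propagate to the caller rather than being converted to exceptions. A short sketch of that propagation pattern written against plain absl::StatusOr (the helper names below are hypothetical, not torch_xla APIs):

#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"

// Hypothetical parser used only to illustrate error propagation.
absl::StatusOr<int> ParseOrdinal(const std::string& device) {
  const auto colon = device.find(':');
  if (colon == std::string::npos) {
    return absl::InvalidArgumentError("expected <type>:<ordinal>, got " + device);
  }
  return std::stoi(device.substr(colon + 1));
}

// Roughly what an ASSIGN_OR_RETURN-style macro expands to: evaluate the
// StatusOr, early-return its status on failure, otherwise bind the value.
absl::StatusOr<std::vector<int>> ParseOrdinals(
    const std::vector<std::string>& devices) {
  std::vector<int> ordinals;
  for (const std::string& device : devices) {
    absl::StatusOr<int> ordinal = ParseOrdinal(device);
    if (!ordinal.ok()) {
      return ordinal.status();
    }
    ordinals.push_back(*ordinal);
  }
  return ordinals;
}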