
Commit 004f19e

Remove CUDA specific logic from runtime. (#9598)
This PR removes CUDA-specific logic from the `torch_xla/csrc/runtime` directory, as well as uses of the deleted functions and environment variables from outside it. This is in line with the CUDA deprecation that started with release 2.8.

**Key Changes:**
- Removed the environment variable `ZERO_COPY_ENABLED`, which enabled DLPack-based moves of tensors from PyTorch CUDA to the PyTorch/XLA XLA:CUDA device without copying
- Removed the Python API function `_get_stream_for_cuda_device`, which was used in `dlpack.py` for DLPack logic on CUDA capsules
- Removed `ComputationClient::GetCudaStreamForDevice()`, which was used by the Python API function above
- Removed `PjRtComputationClient::RegisterCustomCall()`, since it only worked when `platform == "CUDA"`
- Removed `GetGpuAllocatorConfig()`
- Removed the `from_xla_cuda_to_cuda()` DLPack function
- Removed the CUDA branch from `InitializePjRt()`
1 parent e7b1159 commit 004f19e

12 files changed (+12 / -195 lines)
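
For readers unfamiliar with the removed zero-copy path: `ZERO_COPY_ENABLED` let the dynamo bridge hand a CUDA buffer to XLA:CUDA through DLPack instead of staging it on the host. The sketch below only illustrates the underlying DLPack exchange with stock PyTorch APIs (`torch.utils.dlpack`); it does not use the deleted torch_xla helpers and assumes a CUDA-enabled PyTorch build.

```python
import torch
from torch.utils.dlpack import to_dlpack, from_dlpack

# Illustration only: a DLPack capsule lets two consumers share one device
# buffer instead of copying it.
src = torch.arange(4, dtype=torch.float32, device="cuda")

capsule = to_dlpack(src)      # export the CUDA buffer as a DLPack capsule
alias = from_dlpack(capsule)  # re-import it; no device copy takes place

# Both tensors alias the same memory, which is the property the removed
# CUDA to XLA:CUDA zero-copy path relied on.
alias[0] = 42.0
assert src[0].item() == 42.0
```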

test/dynamo/test_dynamo.py

Lines changed: 0 additions & 6 deletions
```diff
@@ -157,9 +157,6 @@ def _choose_proper_device(self, initialize_on_cuda):
       self.skipTest(
           "Skip this test because it requires xr.device_type()=='CUDA' and torch.cuda.is_available()."
       )
-    os.environ.update({
-        xenv.ZERO_COPY_ENABLED: "1",
-    })
     return "cuda:0"
 
   @skipOnNeuron
@@ -205,9 +202,6 @@ def test_simple_model(self):
       "1",
   )
   def test_simple_model_automoves_tensors(self, zero_copy_enabled):
-    os.environ.update({
-        xenv.ZERO_COPY_ENABLED: zero_copy_enabled,
-    })
     x = torch.tensor(100.0, requires_grad=True, device="cuda:0")
     y = torch.tensor(200.0, requires_grad=True, device="cuda:0")
     original_device = x.device
```

torch_xla/_dynamo/dynamo_bridge.py

Lines changed: 4 additions & 13 deletions
```diff
@@ -148,19 +148,10 @@ def _maybe_move_tensors_to_device(tensors: tuple,
     if dynamo_debug:
       print("Moving Tensor {} to device {}".format(tensor, target_device))
 
-    zero_copy_enabled = xu.getenv_as(xenv.ZERO_COPY_ENABLED, bool, defval=False)
-    if zero_copy_enabled and tensor.device.type == 'cuda' and target_device.type == 'xla':
-      # If the input cuda tensor requires gradient, we need to call detach. Otherwise, we'd get the error "RuntimeError: Can't export tensors that require gradient, use tensor.detach()"
-      moved_tensor = torch_xla_dlpack.from_dlpack(tensor.detach())
-    elif zero_copy_enabled and tensor.device.type == 'xla' and target_device.type == 'cuda':
-      # `torch_xla.sync()` is need to make sure the pjrt buffer is valid.
-      torch_xla.sync()
-      moved_tensor = torch_xla_dlpack.from_xla_cuda_to_cuda(tensor)
-    else:
-      # Have to move to CPU before moving it to target device.
-      cpu_device: torch.device = torch.device("cpu")
-      moved_tensor = tensor.to(cpu_device)
-      moved_tensor = moved_tensor.to(target_device)
+    # Have to move to CPU before moving it to target device.
+    cpu_device: torch.device = torch.device("cpu")
+    moved_tensor = tensor.to(cpu_device)
+    moved_tensor = moved_tensor.to(target_device)
 
     # Explicitly have to copy requires_grad attribute because it's dropped
     # with torch.to(..)
```
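
With the DLPack branches gone, every cross-device move in the bridge goes through host memory. Below is a minimal standalone sketch of that remaining path; the helper name `move_via_cpu` is hypothetical, and the `requires_grad` restoration mirrors the comment kept in the context lines above.

```python
import torch

def move_via_cpu(tensor: torch.Tensor, target_device: torch.device) -> torch.Tensor:
  """Hypothetical helper mirroring the simplified bridge logic."""
  with torch.no_grad():
    # Have to move to CPU before moving it to the target device.
    moved = tensor.to(torch.device("cpu"))
    moved = moved.to(target_device)
  # Tensor.to(..) drops requires_grad, so copy it back explicitly.
  moved.requires_grad_(tensor.requires_grad)
  return moved

# Usage (requires a CUDA build plus an XLA device):
# x = torch.tensor(100.0, requires_grad=True, device="cuda:0")
# x_xla = move_via_cpu(x, torch.device("xla:0"))
```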

torch_xla/core/xla_env_vars.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -30,4 +30,3 @@
 RANK = 'RANK'
 WORLD_SIZE = 'WORLD_SIZE'
 LOCAL_WORLD_SIZE = 'LOCAL_WORLD_SIZE'
-ZERO_COPY_ENABLED = 'ZERO_COPY_ENABLED'
```

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 0 additions & 5 deletions
```diff
@@ -1771,11 +1771,6 @@ void InitXlaModuleBindings(py::module m) {
            []() {
              return runtime::GetComputationClientOrDie()->GetPlatformVersion();
            })
-      .def("_get_stream_for_cuda_device",
-           [](const int device_id) {
-             return runtime::GetComputationClientOrDie()->GetCudaStreamForDevice(
-                 device_id);
-           })
       .def("_xla_num_devices",
            []() -> int64_t {
              if (UseVirtualDevice()) {
```
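
Any Python code that still calls the deleted binding will now fail at attribute lookup on the extension module. A defensive sketch is below; it assumes the binding was exposed on `torch_xla._XLAC`, and the `None` fallback is purely illustrative.

```python
import torch_xla

# The binding was removed in this commit, so guard the lookup instead of
# calling torch_xla._XLAC._get_stream_for_cuda_device(device_id) directly.
get_stream = getattr(torch_xla._XLAC, "_get_stream_for_cuda_device", None)
cuda_stream = get_stream(0) if get_stream is not None else None
```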

torch_xla/csrc/runtime/computation_client.h

Lines changed: 0 additions & 2 deletions
```diff
@@ -375,8 +375,6 @@ class ComputationClient {
   virtual absl::StatusOr<xla::PjRtDevice*> LookupAddressableDevice(
       int local_device_id) const = 0;
 
-  virtual std::intptr_t GetCudaStreamForDevice(int local_device_id) const = 0;
-
   virtual size_t GetNumLocalDevices() const = 0;
 
   virtual size_t GetNumDevices() const = 0;
```

torch_xla/csrc/runtime/env_vars.h

Lines changed: 0 additions & 4 deletions
```diff
@@ -10,23 +10,19 @@ namespace env {
 inline constexpr char kEnvLocalWorker[] = "LOCAL_WORKER";
 inline constexpr char kEnvTpuConfig[] = "TPU_CONFIG";
 inline constexpr char kEnvNumTpu[] = "TPU_NUM_DEVICES";
-inline constexpr char kEnvNumGpu[] = "GPU_NUM_DEVICES";
 inline constexpr char kEnvNumCpu[] = "CPU_NUM_DEVICES";
 inline constexpr char kEnvTpuvmMode[] = "TPUVM_MODE";
 inline constexpr char kEnvPjRtDevice[] = "PJRT_DEVICE";
 inline constexpr char kEnvPjRtTpuMaxInflightComputations[] =
     "PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS";
 inline constexpr char kEnvPjrtAsyncCpuClient[] = "PJRT_CPU_ASYNC_CLIENT";
-inline constexpr char kEnvPjrtAsyncGpuClient[] = "PJRT_GPU_ASYNC_CLIENT";
 inline constexpr char kEnvTpuLibraryPath[] = "TPU_LIBRARY_PATH";
 inline constexpr char kEnvInferredTpuLibraryPath[] = "PTXLA_TPU_LIBRARY_PATH";
 inline constexpr char kEnvXpuLibraryPath[] = "XPU_LIBRARY_PATH";
 inline constexpr char kEnvNeuronLibraryPath[] = "NEURON_LIBRARY_PATH";
 inline constexpr char kEnvPjrtDistServiceAddr[] = "PJRT_DIST_SERVICE_ADDR";
 inline constexpr char kEnvPjRtLocalProcessCount[] = "PJRT_LOCAL_PROCESS_COUNT";
 inline constexpr char kEnvPjRtLocalRank[] = "PJRT_LOCAL_PROCESS_RANK";
-inline constexpr char kEnvPjrtAllocatorCudaAsync[] =
-    "PJRT_ALLOCATOR_CUDA_ASYNC";
 inline constexpr char kEnvPjrtAllocatorPreallocate[] =
     "PJRT_ALLOCATOR_PREALLOCATE";
 inline constexpr char kEnvPjrtAllocatorFraction[] = "PJRT_ALLOCATOR_FRACTION";
```

torch_xla/csrc/runtime/ifrt_computation_client.cpp

Lines changed: 0 additions & 2 deletions
```diff
@@ -161,8 +161,6 @@ IfrtComputationClient::Create() {
 }
 
 IfrtComputationClient::~IfrtComputationClient() {
-  // In the GPU case, the PjRtClient depends on the DistributedRuntimeClient
-  // tracked in XlaCoordinator, so the PjRtClient must be destroyed first.
   client_ = nullptr;
   coordinator_ = nullptr;
 }
```

torch_xla/csrc/runtime/ifrt_computation_client.h

Lines changed: 0 additions & 4 deletions
```diff
@@ -110,10 +110,6 @@ class IfrtComputationClient : public ComputationClient {
     XLA_ERROR() << __FUNCTION__ << " not implemented";
   }
 
-  std::intptr_t GetCudaStreamForDevice(int local_device_id) const override {
-    XLA_ERROR() << __FUNCTION__ << " not implemented";
-  }
-
  std::vector<std::string> GetLocalDevices() const override;

  std::vector<std::string> GetAllDevices() const override;
```

torch_xla/csrc/runtime/pjrt_computation_client.cpp

Lines changed: 0 additions & 42 deletions
```diff
@@ -23,7 +23,6 @@
 #include "xla/hlo/builder/xla_builder.h"
 #include "xla/hlo/builder/xla_computation.h"
 #include "xla/literal.h"
-#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
 #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
 #include "xla/pjrt/pjrt_api.h"
 #include "xla/pjrt/pjrt_c_api_client.h"
@@ -152,8 +151,6 @@ PjRtComputationClient::Create() {
 }
 
 PjRtComputationClient::~PjRtComputationClient() {
-  // In the GPU case, the PjRtClient depends on the DistributedRuntimeClient
-  // tracked in XlaCoordinator, so the PjRtClient must be destroyed first.
   client_ = nullptr;
   coordinator_ = nullptr;
 }
@@ -1038,45 +1035,6 @@ ComputationClient::MemoryInfo PjRtComputationClient::GetMemoryInfo(
   };
 }
 
-void PjRtComputationClient::RegisterCustomCall(const std::string& fn_name,
-                                               void* function_ptr,
-                                               const std::string& platform) {
-  if (platform != "CUDA") {
-    XLA_ERROR() << "Custom call targets can only be registered for "
-                   "PJRT CUDA runtime.";
-    return;
-  }
-
-  auto* c_api_client = dynamic_cast<xla::PjRtCApiClient*>(client_.get());
-  if (!c_api_client) {
-    XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(fn_name, function_ptr, platform);
-    return;
-  }
-  const PJRT_Api* pjrt_api = c_api_client->pjrt_c_api();
-
-  // See openxla reference:
-  // https://github.com/openxla/xla/blob/b604c8d87df842002a7a8de79a434026329fbcb2/xla/pjrt/c/pjrt_c_api_gpu_test.cc#L414
-  const PJRT_Extension_Base* next =
-      reinterpret_cast<const PJRT_Extension_Base*>(pjrt_api->extension_start);
-  while (next != nullptr &&
-         next->type !=
-             PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) {
-    next = next->next;
-  }
-  XLA_CHECK(next) << "Custom call extension not found";
-  PJRT_Gpu_Register_Custom_Call_Args args;
-  args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE;
-  args.function_name = fn_name.c_str();
-  args.function_name_size = fn_name.size();
-  args.api_version = 0;
-  args.handler_execute = function_ptr;
-  PJRT_Error* error =
-      reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call(&args);
-  if (error) {
-    XLA_ERROR() << error->status;
-  }
-}
-
 void PjRtComputationClient::OnReadyCallback(
     ComputationClient::DataPtr data, const std::function<void()>& callback) {
   std::shared_ptr<xla::PjRtBuffer> buffer;
```

torch_xla/csrc/runtime/pjrt_computation_client.h

Lines changed: 3 additions & 12 deletions
```diff
@@ -118,17 +118,6 @@ class PjRtComputationClient : public ComputationClient {
         xla::PjRtLocalDeviceId(local_device_id));
   }
 
-  std::intptr_t GetCudaStreamForDevice(int local_device_id) const override {
-    absl::StatusOr<xla::PjRtDevice*> pjrt_device =
-        client_->LookupAddressableDevice(
-            xla::PjRtLocalDeviceId(local_device_id));
-    XLA_CHECK(pjrt_device.ok()) << "Failed to get a PjRt device.";
-    absl::StatusOr<std::intptr_t> stream =
-        pjrt_device.value()->GetStreamForExternalReadyEvents();
-    XLA_CHECK(stream.ok()) << "Failed to get a stream.";
-    return stream.value();
-  }
-
   std::vector<std::string> GetLocalDevices() const override;
 
   std::vector<std::string> GetAllDevices() const override;
@@ -169,7 +158,9 @@ class PjRtComputationClient : public ComputationClient {
       absl::Span<xla::PjRtDevice* const> devices) const;
 
   void RegisterCustomCall(const std::string& fn_name, void* function_ptr,
-                          const std::string& platform) override;
+                          const std::string& platform) override {
+    XLA_ERROR() << __FUNCTION__ << " not implemented";
+  };
 
   void OnReadyCallback(DataPtr data,
                        const std::function<void()>& callback) override;
```
