
Commit 05d9cba

Remove CUDA logic from C++ files in torch_xla/csrc directory. (#9603)
This PR removes CUDA-specific code from C++ files in the `torch_xla/csrc` directory, in line with the CUDA deprecation that started with release 2.8.

**Key Changes:**

- (`init_python_bindings.cpp`) Removed the `_has_cuda_support` Python API
- (`dl_convertor.cpp`) Removed CUDA handling of DLPack capsules
- (`tensor_impl.cpp`) Removed the special handling of the `Autocast` dispatch key for the XLA:CUDA device
- (`tensor_impl.cpp`) Also added a check that crashes on an `XLA:CUDA` device (which shouldn't be supported anymore)
1 parent 8fb90c8 · commit 05d9cba

File tree

5 files changed: +8 -37 lines changed

torch_xla/__init__.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -259,8 +259,7 @@ def _init_xla_lazy_backend():
 from .experimental import plugins
 from ._internal import neuron, xpu  # Additional built-in plugins
 
-if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS',
-             '0' if _XLAC._has_cuda_support() else '1') == '1':
+if os.getenv('XLA_REGISTER_INSTALLED_PLUGINS', '1') == '1':
   plugins.use_dynamic_plugins()
   plugins.register_installed_plugins()
```

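With `_XLAC._has_cuda_support()` gone, the default for `XLA_REGISTER_INSTALLED_PLUGINS` is now unconditionally `'1'`, so installed PJRT plugins always register unless the user opts out. A minimal sketch of the opt-out path (the variable must be set before `torch_xla` is imported):

```python
# Minimal sketch: dynamic plugin registration is now the unconditional
# default, since the CUDA probe no longer exists. Opt out via the env var.
import os

os.environ["XLA_REGISTER_INSTALLED_PLUGINS"] = "0"  # must precede the import
import torch_xla  # use_dynamic_plugins()/register_installed_plugins() skipped
```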
torch_xla/csrc/dl_convertor.cpp

Lines changed: 0 additions & 7 deletions

```diff
@@ -51,8 +51,6 @@ void DLPackTensorDeleter(DLManagedTensor* t) {
 DLDeviceType DLDeviceTypeForDevice(const xla::PjRtDevice& device) {
   if (device.client()->platform_id() == xla::CpuId()) {
     return DLDeviceType::kDLCPU;
-  } else if (device.client()->platform_id() == xla::CudaId()) {
-    return DLDeviceType::kDLCUDA;
   }
   XLA_ERROR() << "Device " << device.DebugString()
               << " cannot be used as a DLPack device.";
@@ -176,11 +174,6 @@ absl::StatusOr<xla::PjRtDevice*> DeviceForDLDevice(const DLDevice& context) {
                    xla::CpuId());
       return runtime::GetComputationClientOrDie()->LookupAddressableDevice(
           context.device_id);
-    case DLDeviceType::kDLCUDA:
-      XLA_CHECK_EQ(runtime::GetComputationClientOrDie()->GetPlatformID(),
-                   xla::CudaId());
-      return runtime::GetComputationClientOrDie()->LookupAddressableDevice(
-          context.device_id);
     default:
       return tsl::errors::InvalidArgument(
           "Unknown/unsupported DLPack device type %d", context.device_type);
```

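Only `kDLCPU` capsules now map to an XLA device; a `kDLCUDA` capsule falls through to the `default:` branch and returns `InvalidArgument`. A hedged sketch of the caller-visible effect, assuming `torch_xla.utils.dlpack.from_dlpack` as the Python entry point and a CUDA build of stock PyTorch to produce the capsule:

```python
# Hedged sketch: the import path torch_xla.utils.dlpack.from_dlpack is an
# assumption; the point is that a kDLCUDA capsule is no longer accepted.
import torch
from torch.utils.dlpack import to_dlpack
from torch_xla.utils.dlpack import from_dlpack  # assumed entry point

capsule = to_dlpack(torch.randn(4, device="cuda"))  # DLDevice type kDLCUDA (2)
# from_dlpack(capsule)  # now fails: "Unknown/unsupported DLPack device type 2"
```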
torch_xla/csrc/init_python_bindings.cpp

Lines changed: 0 additions & 8 deletions

```diff
@@ -3055,14 +3055,6 @@ void InitXlaModuleBindings(py::module m) {
              -> std::vector<at::Tensor> {
            return TpuCustomCall(inputs, payload, output_shapes, output_dtypes);
          })
-      .def("_has_cuda_support",
-           []() {
-#ifdef GOOGLE_CUDA
-             return true;
-#else
-             return false;
-#endif
-           })
       .def("_xla_register_custom_call_target",
            [](const std::string& fn_name, const py::capsule& function_ptr,
               const std::string& platform) {
```

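Any Python code that still calls `_XLAC._has_cuda_support()` will now raise `AttributeError`. A small sketch of a guarded lookup for code that must run against both old and new wheels:

```python
# Sketch: the binding is gone, so probe it defensively; on wheels built from
# this commit the getattr fallback is taken and has_cuda is always False.
import torch_xla

has_cuda = getattr(torch_xla._XLAC, "_has_cuda_support", lambda: False)()
print(has_cuda)  # False on builds without the binding
```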
torch_xla/csrc/random.cpp

Lines changed: 1 addition & 10 deletions

```diff
@@ -16,16 +16,7 @@
 namespace torch_xla {
 namespace {
 
-std::string GetDefaultGitGeneratorName() {
-  XlaDeviceType hw_type =
-      static_cast<XlaDeviceType>(bridge::GetCurrentDevice().type());
-  switch (hw_type) {
-    case XlaDeviceType::CUDA:
-      return "three_fry";
-    default:
-      return "default";
-  }
-}
+std::string GetDefaultGitGeneratorName() { return "default"; }
 
 xla::BitGeneratorTy GetBitGenerator() {
   static const std::string* bit_generator =
```

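The default bit generator name is now `"default"` on every backend; the CUDA-only `"three_fry"` selection is gone. User-facing RNG APIs are unchanged, as in this minimal sketch (assuming any non-CUDA XLA device, e.g. TPU or CPU, is available):

```python
# Minimal sketch: seeding and sampling on an XLA device work as before; only
# the CUDA-specific default generator choice was removed.
import torch
import torch_xla.core.xla_model as xm

xm.set_rng_state(42)  # seed the XLA RNG
t = torch.rand(2, 2, device=xm.xla_device())
```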
torch_xla/csrc/tensor_impl.cpp

Lines changed: 6 additions & 10 deletions

```diff
@@ -9,6 +9,7 @@
 #include <torch/csrc/lazy/core/tensor_util.h>
 #include <torch/csrc/lazy/core/util.h>
 
+#include "absl/log/absl_check.h"
 #include "torch_xla/csrc/aten_xla_bridge.h"
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/ir_builder.h"
@@ -71,16 +72,11 @@ XLATensorImpl::XLATensorImpl(XLATensor&& tensor)
                     GetTypeMeta(tensor),
                     bridge::XlaDeviceToAtenDevice(tensor.GetDevice())),
       tensor_(c10::make_intrusive<XLATensor>(std::move(tensor))) {
-  // Update the Autocast key based off the backend device.
-  // Upstream TensorImpl cannot differentiate between XLA:TPU and XLA:GPU
-  // so we must manually update Autocast to AutocastCUDA on XLA:GPU.
-  torch::lazy::BackendDevice current_device = bridge::GetCurrentDevice();
-  auto dev_type = static_cast<XlaDeviceType>(current_device.type());
-  if (dev_type == XlaDeviceType::CUDA) {
-    auto autocast_cuda_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastCUDA);
-    auto autocast_xla_ks = c10::DispatchKeySet(c10::DispatchKey::AutocastXLA);
-    key_set_ = (key_set_ - autocast_xla_ks) | autocast_cuda_ks;
-  }
+  auto dev_type = static_cast<XlaDeviceType>(bridge::GetCurrentDevice().type());
+  ABSL_CHECK(dev_type != XlaDeviceType::CUDA)
+      << "XLA:CUDA is not supported anymore. "
+         "If you are seeing this error, report a bug to the PyTorch/XLA GitHub "
+         "repository: https://github.com/pytorch/xla";
   const_cast<XLATensorImpl*>(this)->SetupSizeProperties();
   set_sizes_and_strides(sym_sizes_, c10::fromIntArrayRefSlow(
                             sizes_and_strides_.strides_arrayref()));
```

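The new `ABSL_CHECK` turns a stale XLA:CUDA configuration into an immediate crash rather than silently remapping `AutocastXLA` to `AutocastCUDA`. A hedged sketch of the failure mode, assuming the legacy `PJRT_DEVICE=CUDA` setting could still select such a device:

```python
# Hedged sketch: PJRT_DEVICE=CUDA was the legacy way to request XLA:CUDA.
# On builds from this commit, constructing an XLA tensor on that device
# aborts the process via the ABSL_CHECK in XLATensorImpl's constructor.
import os

os.environ["PJRT_DEVICE"] = "CUDA"  # deprecated configuration
import torch
import torch_xla.core.xla_model as xm

# torch.zeros(1, device=xm.xla_device())  # would crash:
# "XLA:CUDA is not supported anymore. If you are seeing this error, ..."
```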