Skip to content

Commit 15496cd

Browse files
authored
Do not set PJRT_DEVICE=CUDA automatically on import. (#9540)
1 parent 43589c0 commit 15496cd

File tree

1 file changed

+0
-10
lines changed

1 file changed

+0
-10
lines changed

torch_xla/runtime.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,6 @@ def _maybe_select_default_device():
6363
if torch_xla._found_libtpu and tpu.num_available_chips() > 0:
6464
logging.warning('libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.')
6565
os.environ[xenv.PJRT_DEVICE] = 'TPU'
66-
elif xu.getenv_as(xenv.GPU_NUM_DEVICES, int, 0) > 0:
67-
logging.warning('GPU_NUM_DEVICES is set. Setting PJRT_DEVICE=CUDA')
68-
os.environ[xenv.PJRT_DEVICE] = 'CUDA'
69-
elif torch.cuda.is_available() and torch.cuda.device_count() > 0:
70-
num_devices_str = str(torch.cuda.device_count())
71-
logging.warning(
72-
'Found CUDA without GPU_NUM_DEVICES. Defaulting to PJRT_DEVICE=CUDA with GPU_NUM_DEVICES='
73-
+ num_devices_str)
74-
os.environ[xenv.PJRT_DEVICE] = 'CUDA'
75-
os.environ[xenv.GPU_NUM_DEVICES] = num_devices_str
7666
elif torch_xla._found_libneuronxla:
7767
logging.warning('Found libneuronpjrt.so. Setting PJRT_DEVICE=NEURON.')
7868
os.environ[xenv.PJRT_DEVICE] = 'NEURON'

0 commit comments

Comments
 (0)