
Commit c0eeb57

Remove CUDA specific path from internal Python packages. (#9606)
This PR removes CUDA-specific code from internal Python packages, such as `_dynamo`, files in `_internal`, and the main `__init__.py` file. This is in line with the CUDA deprecation that started in release 2.8.

**Key Changes:**

- (`torch_xla/__init__.py`) Removed the GPU-specific OpenXLA flag.
- (`torch_xla/_dynamo/dynamo_bridge.py`) Removed CUDA tensor movement.
  - As far as I know, this was mainly created for zero-overhead CUDA tensor movement.
1 parent 05d9cba commit c0eeb57

File tree

4 files changed (+1 −80 lines)

torch_xla/__init__.py

Lines changed: 0 additions & 2 deletions

```diff
@@ -31,8 +31,6 @@ def _set_missing_flags(flags, sets):
 def _setup_xla_flags():
   flags = os.environ.get('XLA_FLAGS', '').split(' ')
   flags = _set_missing_flags(flags, (('xla_cpu_enable_fast_math', 'false'),))
-  flags = _set_missing_flags(flags,
-                             (('xla_gpu_force_compilation_parallelism', '8'),))
   os.environ['XLA_FLAGS'] = ' '.join(flags)
```
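After this hunk, `_setup_xla_flags` only injects the CPU fast-math default into `XLA_FLAGS`; the GPU compilation-parallelism default is gone. A minimal sketch of the flag merging is below. The body of `_set_missing_flags` is not part of this diff, so the "append `--name=value` only when the flag is absent" behavior modeled here is an assumption.

```python
import os

def _set_missing_flags(flags, sets):
  # Assumed behavior (the helper's body is not shown in this diff): append
  # --name=value only when no existing XLA_FLAGS entry mentions that flag.
  for name, value in sets:
    if all(name not in flag for flag in flags):
      flags.append('--%s=%s' % (name, value))
  return flags

def _setup_xla_flags():
  flags = os.environ.get('XLA_FLAGS', '').split(' ')
  flags = _set_missing_flags(flags, (('xla_cpu_enable_fast_math', 'false'),))
  # Post-change: no xla_gpu_force_compilation_parallelism default is injected.
  os.environ['XLA_FLAGS'] = ' '.join(flags)

os.environ.pop('XLA_FLAGS', None)
_setup_xla_flags()
print(repr(os.environ['XLA_FLAGS']))  # ' --xla_cpu_enable_fast_math=false'
```

Workloads that still want the old GPU default can presumably set it themselves, e.g. `XLA_FLAGS=--xla_gpu_force_compilation_parallelism=8`, since flags already present in the environment are left untouched.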

torch_xla/_dynamo/dynamo_bridge.py

Lines changed: 0 additions & 58 deletions

```diff
@@ -119,48 +119,6 @@ def _get_input_arg_device(input_args: tuple) -> torch.device:
   return device
 
 
-# Returns True if all the input args are on a CUDA device.
-def _args_on_cuda(input_args: tuple) -> bool:
-  input_device: torch.device = _get_input_arg_device(input_args)
-  if input_device is None:
-    return False
-
-  return input_device.type == "cuda"
-
-
-# Given an input list, moves the tensors to the given target_device.
-# The output order will be the same as the input. Non tensors will also still
-# be in the list.
-def _maybe_move_tensors_to_device(tensors: tuple,
-                                  target_device: torch.device) -> tuple:
-  assert target_device, "Moving tensors to None device not supported"
-
-  moved_tensors = []
-  for tensor in tensors:
-    if not isinstance(tensor, torch.Tensor):
-      moved_tensors.append(tensor)
-      continue
-
-    if tensor.device == target_device:
-      moved_tensors.append(tensor)
-      continue
-
-    if dynamo_debug:
-      print("Moving Tensor {} to device {}".format(tensor, target_device))
-
-    # Have to move to CPU before moving it to target device.
-    cpu_device: torch.device = torch.device("cpu")
-    moved_tensor = tensor.to(cpu_device)
-    moved_tensor = moved_tensor.to(target_device)
-
-    # Explicitly have to copy requires_grad attribute because it's dropped
-    # with torch.to(..)
-    moved_tensor.requires_grad = tensor.requires_grad
-    moved_tensors.append(moved_tensor)
-
-  return tuple(moved_tensors)
-
-
 def _split_xla_args_tensor_sym_constant(args):
   tensors = deque(maxlen=len(args))
   constants = []
@@ -552,14 +510,6 @@ def optimized_mod(*args: tuple):
      special_return_handler, xla_args_need_update) = extract_graph_helper(
          xla_model, sym_constants_to_graph_vars)
 
-    original_device: torch.device = _get_input_arg_device(args)
-    is_cuda_args: bool = False
-    if original_device:
-      is_cuda_args = original_device.type == "cuda"
-
-    if is_cuda_args:
-      args = _maybe_move_tensors_to_device(args, torch_xla.device())
-
     if not config.skip_input_data_check:
       # `torch_xla.sync()` needs to be blocking since we want to access args's
       # XLADatas and they can't be placeholder.
@@ -610,11 +560,7 @@ def optimized_mod(*args: tuple):
 
     # First few elements might be xla_args that needs to be in place updated
     result = res[len(xla_args_need_update):]
-
     result = none_remover.add_nones(result)
-    if is_cuda_args:
-      result = _maybe_move_tensors_to_device(tuple(result), original_device)
-
     if len(result) == 1:
       return result[0]
     else:
@@ -802,10 +748,6 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
 
 
 def extract_compiled_graph_helper(xla_model: torch.fx.GraphModule, xla_args):
-  if _args_on_cuda(xla_args):
-    xla_args = tuple(
-        _maybe_move_tensors_to_device(xla_args, torch_xla.device()))
-
   # Synchronize xla_args, so that each FunctionalTensorWrapper argument updates its
   # value reference before actually computing it.
   for a in xla_args:
```
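The deleted `_args_on_cuda` / `_maybe_move_tensors_to_device` pair is what used to shuttle CUDA inputs onto the XLA device (and move results back) around a dynamo-compiled graph. For callers who relied on that behavior, here is a minimal sketch of doing the move explicitly, mirroring the deleted helper; `move_to_xla_device` is an illustrative name, not a torch_xla API.

```python
import torch
import torch_xla

def move_to_xla_device(tensors, target_device):
  # Hypothetical helper mirroring the deleted _maybe_move_tensors_to_device:
  # non-tensors pass through unchanged, tensors hop via CPU, and the
  # requires_grad flag is restored on the copy.
  moved = []
  with torch.no_grad():  # keep the copies as leaf tensors
    for t in tensors:
      if not isinstance(t, torch.Tensor) or t.device == target_device:
        moved.append(t)
        continue
      m = t.to("cpu").to(target_device)
      m.requires_grad = t.requires_grad
      moved.append(m)
  return tuple(moved)

# Usage sketch: inputs the bridge used to relocate automatically are now
# moved by the caller before invoking a torch.compile(backend="openxla") module.
xla_device = torch_xla.device()
args = move_to_xla_device((torch.randn(2, 2), 3), xla_device)
```

With CUDA support deprecated, the more natural path is presumably to create or load inputs on the XLA device directly rather than round-tripping through CUDA.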

torch_xla/_internal/gpu.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

torch_xla/_internal/pjrt.py

Lines changed: 1 addition & 5 deletions

```diff
@@ -12,7 +12,7 @@
 import torch_xla.core.xla_env_vars as xenv
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_backend
-from torch_xla._internal import tpu, gpu, neuron
+from torch_xla._internal import tpu, neuron
 from torch_xla import runtime
 import torch_xla.utils.utils as xu
 from torch_xla.experimental import plugins
@@ -149,8 +149,6 @@ def run_multiprocess(fn: Callable[..., R],
     num_processes = plugins.default().physical_chip_count()
   elif runtime.device_type() == 'TPU':
     num_processes = tpu.num_local_processes()
-  elif runtime.device_type() == 'CUDA':
-    num_processes = gpu.num_local_processes()
   elif runtime.device_type() == 'NEURON':
     num_processes = neuron.num_local_processes()
   else:
@@ -220,8 +218,6 @@ def _initialize_single_process(local_rank: int, local_world_size: int):
 
 def spawn_threads(fn: Callable, args: Tuple = ()) -> None:
   """Run function in one process with one thread per addressable device."""
-  assert runtime.device_type() not in (
-      'CUDA'), "spawn_threads does not support GPU device"
   spawn_fn = _SpawnFn(fn, *args)
   _run_thread_per_device(
       local_rank=0,
```
