
Commit 647804c

Remove remaining GPU/CUDA mentions in torch_xla directory. (#9608)
This PR removes the remaining CUDA-specific code from the PyTorch/XLA package (i.e. the `torch_xla` directory), as well as a few other related files. This is in line with the CUDA deprecation that started in release 2.8.

**Key Changes:**

- (`CONTRIBUTING.md`) Removed mention of CUDA-specific environment variables
- (`configuration.yaml`) Removed descriptions of CUDA-specific environment variables
- (`docs/source/learn/_pjrt.md`) Removed PjRt documentation on CUDA
- (`torch_xla/amp`) Removed CUDA-specific branches, as well as `GradScaler`
- (`torch_xla/core/xla_env_vars.py`) Removed CUDA-specific environment variables
- (`torch_xla/utils/checkpoint.py`) Fixed an incorrect function name
1 parent 89f929b commit 647804c
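Because `GradScaler` is no longer exported from `torch_xla.amp`, AMP training loops on TPU keep only `autocast` (which defaults to `torch.bfloat16` there) and drop the gradient-scaling steps. Below is a minimal sketch of the resulting pattern, not taken from this commit; the tiny model and synthetic batches are stand-ins for a real network and data loader:

```python
# Minimal sketch: TPU AMP after the GradScaler removal. bfloat16 autocast does
# not need loss scaling, so the scaler plumbing disappears entirely.
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast, syncfree

device = torch_xla.device()
model = nn.Linear(16, 4).to(device)                   # stand-in for a real model
optimizer = syncfree.SGD(model.parameters(), lr=0.1)  # sync-free optimizers remain
loss_fn = nn.CrossEntropyLoss()

for _ in range(3):                                    # synthetic batches for illustration
  inputs = torch.randn(8, 16, device=device)
  target = torch.randint(0, 4, (8,), device=device)
  optimizer.zero_grad()
  with autocast(device):                              # bfloat16 autocast on TPU
    loss = loss_fn(model(inputs), target)
  loss.backward()
  xm.optimizer_step(optimizer)                        # reduces gradients and steps; no scaler
```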

File tree

18 files changed: +17, −300 lines

CONTRIBUTING.md

Lines changed: 0 additions & 6 deletions

@@ -291,12 +291,6 @@ To run the tests, follow __one__ of the options below:
   export PJRT_DEVICE=TPU
   ```
 
-* Run on GPU:
-
-  ```shell
-  export PJRT_DEVICE=CUDA GPU_NUM_DEVICES=${NUM_GPU}
-  ```
-
 For more detail on configuring the runtime, please refer to [this doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#quickstart)
 
 If you are planning to be building from source and hence using the latest _PyTorch/TPU_ code base,
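With the GPU option removed, `PJRT_DEVICE=TPU` is the remaining documented way to select the runtime for the tests. As a hedged aside (not part of this diff), the same selection can be made from Python before the runtime initializes:

```python
# Sketch only: Python equivalent of `export PJRT_DEVICE=TPU`, assuming a host
# with a TPU-capable PJRT runtime. Set the variable before torch_xla spins up
# its runtime client.
import os

os.environ["PJRT_DEVICE"] = "TPU"

import torch_xla

print(torch_xla.device())  # e.g. xla:0 once the TPU runtime is available
```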

configuration.yaml

Lines changed: 0 additions & 9 deletions

@@ -15,11 +15,6 @@ variables:
       - Whether or not to create an async PJRT client for the CPU device(s).
     type: bool
     default_value: false
-  PJRT_GPU_ASYNC_CLIENT:
-    description:
-      - Whether or not to create an async PJRT client for the GPU device(s).
-    type: bool
-    default_value: false
   PJRT_TPU_MAX_INFLIGHT_COMPUTATIONS:
     description:
       - Max inflight computations that the PJRT client can handle for TPU.
@@ -229,10 +224,6 @@ variables:
     description:
       - Number of CPU devices being used by this instance of XRT.
     type: int
-  GPU_NUM_DEVICES:
-    description:
-      - Number of GPU devices being used by this instance of XRT.
-    type: int
 debug_variables:
   XLA_FNTRACKER_FILE:
     description:

docs/source/learn/_pjrt.md

Lines changed: 0 additions & 63 deletions

@@ -188,69 +188,6 @@ time. See the [Cloud TPU
 documentation](https://cloud.google.com/tpu/docs/run-in-container) for
 more information.
 
-### GPU
-
-### Single-node GPU training
-
-To use GPUs with PJRT, simply set `PJRT_DEVICE=CUDA` and configure
-`GPU_NUM_DEVICES` to the number of devices on the host. For example:
-
-    PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 xla/test/test_train_mp_imagenet.py --fake_data --batch_size=128 --num_epochs=1
-
-You can also use `torchrun` to initiate the single-node multi-GPU
-training. For example,
-
-    PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node ${NUM_GPU_DEVICES} xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
-
-In the above example, `--nnodes` means how many machines (physical
-machines or VMs) to be used (it is 1 since we do single-node training).
-`--nproc-per-node` means how many GPU devices to be used.
-
-### Multi-node GPU training
-
-**Note that this feature only works for cuda 12+**. Similar to how
-PyTorch uses multi-node training, you can run the command as below:
-
-    PJRT_DEVICE=CUDA torchrun \
-      --nnodes=${NUMBER_GPU_VM} \
-      --node_rank=${CURRENT_NODE_RANK} \
-      --nproc_per_node=${NUMBER_LOCAL_GPU_DEVICES} \
-      --rdzv_endpoint=<internal_ip_address:port> multinode_training.py
-
-- `--nnodes`: how many GPU machines to be used.
-- `--node_rank`: the index of the current GPU machines. The value can
-  be 0, 1, ..., \${NUMBER_GPU_VM}-1.
-- `--nproc_per_node`: the number of GPU devices to be used on the
-  current machine.
-- `--rdzv_endpoint`: the endpoint of the GPU machine with
-  node_rank==0, in the form `host:port`. The `host` will be the
-  internal IP address. The `port` can be any available port on the
-  machine. For single-node training/inference, this parameter can be
-  omitted.
-
-For example, if you want to train on 2 GPU machines: machine_0 and
-machine_1, on the first GPU machine machine_0, run
-
-    # PJRT_DEVICE=CUDA torchrun \
-      --nnodes=2 \
-      --node_rank=0 \
-      --nproc_per_node=4 \
-      --rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
-
-On the second GPU machine, run
-
-    # PJRT_DEVICE=CUDA torchrun \
-      --nnodes=2 \
-      --node_rank=1 \
-      --nproc_per_node=4 \
-      --rdzv_endpoint="<MACHINE_0_INTERNAL_IP_ADDRESS>:12355" pytorch/xla/test/test_train_mp_imagenet.py --fake_data --pjrt_distributed --batch_size=128 --num_epochs=1
-
-the difference between the 2 commands above are `--node_rank` and
-potentially `--nproc_per_node` if you want to use different number of
-GPU devices on each machine. All the rest are identical. For more
-information about `torchrun`, please refer to this
-[page](https://pytorch.org/docs/stable/elastic/run.html).
-
 ## Differences from XRT
 
 Although in most cases we expect PJRT and XRT to work mostly

docs/source/perf/amp.md

Lines changed: 0 additions & 53 deletions

@@ -95,59 +95,6 @@ unlisted ops run if they're downstream from autocasted ops.
 
 `stack`, `cat`, `index_copy`
 
-## AMP for XLA:GPU
-
-AMP on XLA:GPU devices reuse Pytorch's AMP rules. See [Pytorch's AMP
-documentation](https://pytorch.org/docs/stable/amp.html) for CUDA
-specific behavior. A simple CUDA AMP example is below:
-
-``` python
-from torch_xla.amp import syncfree
-import torch_xla.core.xla_model as xm
-
-# Creates model and optimizer in default precision
-model = Net().to('xla')
-# Pytorch/XLA provides sync-free optimizers for improved performance
-optimizer = syncfree.SGD(model.parameters(), ...)
-scaler = GradScaler()
-
-for input, target in data:
-  optimizer.zero_grad()
-
-  # Enables autocasting for the forward pass
-  with autocast(torch_xla.device()):
-    output = model(input)
-    loss = loss_fn(output, target)
-
-  # Exits the context manager before backward pass
-  scaler.scale(loss).backward()
-  gradients = xm._fetch_gradients(optimizer)
-  xm.all_reduce('sum', gradients, scale=1.0 / xr.world_size())
-  scaler.step(optimizer)
-  scaler.update()
-```
-
-`autocast(torch_xla.device())` aliases `torch.cuda.amp.autocast()` when the
-XLA Device is a CUDA device (XLA:GPU). Alternatively, if a script is
-only used with CUDA devices, then `torch.cuda.amp.autocast` can be
-directly used, but requires `torch` is compiled with `cuda` support for
-datatype of `torch.bfloat16`. We recommend using
-`autocast(torch_xla.device())` on XLA:GPU as it does not require
-`torch.cuda` support for any datatypes, including `torch.bfloat16`.
-
-### AMP for XLA:GPU Best Practices
-
-1. `autocast` should wrap only the forward pass(es) and loss
-   computation(s) of the network. Backward ops run in the same type
-   that autocast used for the corresponding forward ops.
-2. Do not set `XLA_USE_F16` flag when using AMP on Cuda devices. This
-   will override the per-operator precision settings provided by AMP
-   and cause all operators to execute in float16.
-3. Use gradient scaling to prevent float16 gradients from underflowing.
-4. Pytorch/XLA provides modified version of
-   [optimizers](https://github.com/pytorch/xla/tree/master/torch_xla/amp/syncfree)
-   that avoid the additional sync between device and host.
-
 ## Examples
 
 Our [mnist training script](https://github.com/pytorch/xla/blob/master/test/test_train_mp_mnist_amp.py)

test/test_autocast.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 import collections
 import unittest
 from torch.testing._internal.autocast_test_lists import AutocastTestLists
-from torch_xla.amp import autocast, GradScaler
+from torch_xla.amp import autocast
 
 
 class AutocastTPUTestLists:

test/test_train_mp_imagenet_amp.py

Lines changed: 1 addition & 3 deletions

@@ -67,7 +67,7 @@
 import torch_xla.utils.utils as xu
 import torch_xla.core.xla_model as xm
 import torch_xla.test.test_utils as test_utils
-from torch_xla.amp import autocast, GradScaler
+from torch_xla.amp import autocast
 try:
   from torch_xla.amp import syncfree
 except ImportError:
@@ -220,8 +220,6 @@ def train_imagenet():
   if FLAGS.amp:
     if device_hw == 'TPU':
       scaler = None
-    elif device_hw == 'CUDA':
-      scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
 
   def train_loop_fn(loader, epoch):
     tracker = xm.RateTracker()

test/test_train_mp_mnist_amp.py

Lines changed: 2 additions & 5 deletions

@@ -38,7 +38,7 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_multiprocessing as xmp
 import torch_xla.test.test_utils as test_utils
-from torch_xla.amp import autocast, GradScaler
+from torch_xla.amp import autocast
 try:
   from torch_xla.amp import syncfree
 except ImportError:
@@ -143,11 +143,8 @@ def train_mnist(flags, **kwargs):
 
   if device_hw == 'TPU':
     scaler = None
-  elif device_hw == 'CUDA':
-    # GradScaler only used for GPU
-    scaler = GradScaler(use_zero_grad=FLAGS.use_zero_grad)
   else:
-    print("Only TPU or GPU supported for AMP.")
+    print("Only TPU supported for AMP.")
     sys.exit(1)
 
   def train_loop_fn(loader):

torch_xla/_internal/pjrt.py

Lines changed: 1 addition & 1 deletion

@@ -205,7 +205,7 @@ def spawn(fn: Callable,
     return _run_singleprocess(spawn_fn)
   elif nprocs is not None:
     raise ValueError(
-        'Unsupported nprocs (%d). Please use nprocs=1 or None (default). If None, spawn will use all available devices. Use the environment variable X_NUM_DEVICES (where X is CPU, GPU, TPU, NEURONCORE, etc) to limit the number of devices used.'
+        'Unsupported nprocs (%d). Please use nprocs=1 or None (default). If None, spawn will use all available devices. Use the environment variable X_NUM_DEVICES (where X is CPU, TPU, NEURONCORE, etc) to limit the number of devices used.'
         % nprocs)
 
   run_multiprocess(spawn_fn, start_method=start_method)
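The updated message still points at the `X_NUM_DEVICES` family of variables. Below is a hedged sketch of how that interacts with the public `xmp.spawn` wrapper, which dispatches into this `spawn`; the four-device limit is an illustrative assumption:

```python
# Sketch: nprocs is left at its default of None, so one process is spawned per
# visible device; TPU_NUM_DEVICES (or CPU_NUM_DEVICES, etc.) caps that count.
import os

import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  # Each spawned process is pinned to a single XLA device.
  print(index, xm.xla_device())


if __name__ == '__main__':
  os.environ.setdefault('TPU_NUM_DEVICES', '4')  # assumed limit for illustration
  xmp.spawn(_mp_fn)
```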

torch_xla/amp/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -1,2 +1 @@
 from .autocast_mode import autocast  # noqa: F401
-from .grad_scaler import GradScaler  # noqa: F401

torch_xla/amp/autocast_mode.py

Lines changed: 4 additions & 63 deletions

@@ -10,8 +10,7 @@ class autocast(torch.amp.autocast_mode.autocast):
   r"""
   `torch.autocast` for XLA backend devices. See :class:`torch.autocast`.
   ``torch_xla.amp.autocast(device, **kwargs)`` is equivalent to
-  ``torch.autocast("xla", **kwargs)`` for XLA:GPU and XLA:TPU for dtype torch.bfloat16,
-  ``torch.autocast("cuda", **kwargs)`` for XLA:GPU and other dtypes.
+  ``torch.autocast("xla", **kwargs)`` for XLA:TPU for dtype torch.bfloat16.
   """
 
   def __init__(self,
@@ -20,34 +19,11 @@ def __init__(self,
                dtype: torch.dtype = None,
                cache_enabled: bool = True):
     # `torch_xla.amp.autocast` is intended for XLA backend, with AutocastXLA dispatch key.
-    assert 'xla' in device.__str__(
-    ), "torch_xla.autocast is available for XLA:TPU, XLA:GPU"
+    assert 'xla' in str(device), "torch_xla.autocast is available for XLA:TPU"
 
     self._enabled = enabled
     self._xla_device = xm.xla_device_hw(device)
-    if self._xla_device == 'CUDA':
-      backend = 'cuda'
-      self._xla_bfloat16 = False  # True if xla backend with bfloat16 dtype.
-      if dtype is None:
-        dtype = torch.float16
-      elif dtype == torch.bfloat16 and not torch.cuda.is_available():
-        if xr.is_bf16_supported():
-          # XLA:GPU with bfloat16 should run on `xla` backend
-          # unless torch.autocast is compiled with cuda.
-          backend = 'xla'
-          self._xla_bfloat16 = True
-        else:
-          # This has been the default behavior for unsupported bfloat16 dtype
-          dtype = torch.float16
-          error_message = "In XLA:GPU autocast, but bfloat16 is not supported on this HW.\n"
-          error_message += ("Using the default cuda autocast dtype float16.")
-      self._dtype = dtype
-      super().__init__(
-          backend,
-          enabled=enabled,
-          dtype=self._dtype,
-          cache_enabled=cache_enabled)
-    elif self._xla_device == 'TPU' or self._xla_device == 'NEURON':
+    if self._xla_device == 'TPU' or self._xla_device == 'NEURON':
       if dtype is None:
         dtype = torch.bfloat16
       if dtype != torch.bfloat16:
@@ -63,39 +39,4 @@ def __init__(self,
           dtype=self._dtype,
           cache_enabled=cache_enabled)
     else:
-      print(
-          'Warning: AMP only supported for XLA:TPU and XLA:GPU. Ignoring autocast.'
-      )
-
-  def __enter__(self):
-    # This ensures that xla autocast is enabled even for XLA:GPU, which calls
-    # `torch.amp.autocast_mode.autocast` with `cuda` backend.
-    if self._xla_device == 'CUDA':
-      self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
-      self.prev_dtype = torch.get_autocast_xla_dtype(
-      )  # type: ignore[attr-defined]
-      if self._xla_bfloat16:
-        # autocast_xla flags will be set by `torch.autocast` and we need to
-        # set autocast flags as we call into `torch.autocast` apis.
-        torch.set_autocast_enabled(self._enabled)
-        torch.set_autocast_gpu_dtype(self._dtype)
-      else:
-        torch.set_autocast_xla_enabled(self._enabled)
-        torch.set_autocast_xla_dtype(self._dtype)
-    return super().__enter__()
-
-  def __exit__(self, exc_type: Any, exc_val: Any,
-               exc_tb: Any):  # type: ignore[override]
-    if self._xla_device == 'CUDA':
-      if self._xla_bfloat16:
-        # autocast_xla flags will be set by `torch.autocast` and we need to
-        # set autocast flags as we call into `torch.autocast` apis.
-        torch.set_autocast_enabled(self.prev)
-        torch.set_autocast_gpu_dtype(self.prev_dtype)
-      else:
-        torch.set_autocast_xla_enabled(self.prev)
-        torch.set_autocast_xla_dtype(self.prev_dtype)
-    return super().__exit__(exc_type, exc_val, exc_tb)
-
-  def __call__(self, func):
-    return super().__call__(func)
+      print('Warning: AMP only supported for XLA:TPU. Ignoring autocast.')
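After the simplification, `autocast` only accepts XLA devices and, on TPU, `torch.bfloat16` is the sole supported dtype. A short illustrative sketch of the resulting behavior follows; the printed dtype is an expectation based on the TPU branch kept above, not an output captured from this commit:

```python
# Illustrative sketch of the simplified torch_xla.amp.autocast on a TPU host.
import torch
import torch_xla
from torch_xla.amp import autocast

device = torch_xla.device()

with autocast(device):  # defaults to torch.bfloat16 on TPU
  out = torch.randn(4, 8, device=device) @ torch.randn(8, 4, device=device)
print(out.dtype)  # expected: torch.bfloat16

# A non-XLA device now trips the assertion kept in __init__:
#   assert 'xla' in str(device), "torch_xla.autocast is available for XLA:TPU"
```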
