Commit c48478a

Remove CUDA tests from distributed tests. (#9612)
This PR removes CUDA-specific logic and tests from the distributed tests, covering both the multiprocessing and SPMD tests. This is in line with the CUDA deprecation that started in release 2.8.

**Key Changes:**

- Removed `skipIf` test decorations whenever the condition checks for CUDA
- Removed `CUDA` from the list of allowed devices for a few of these tests
1 parent 8ff2ee6 commit c48478a

19 files changed: +36 -100 lines
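The two edit patterns named above repeat across every file below. As a rough composite sketch of what the tests look like after this commit (a hypothetical test module; the names `CollectiveOpTest`, `test_collective_op`, and the `_mp_fn` body are illustrative, not copied from any file in this commit):

```python
# Minimal sketch of the post-change pattern: CUDA no longer appears in skipIf
# conditions or in the allowed-device tuples. Hypothetical module for
# illustration only.
import sys
import unittest

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr


class CollectiveOpTest(unittest.TestCase):

  # Previously the skip message also listed GPU/CUDA/ROCM; now only the CPU
  # condition remains.
  @unittest.skipIf(xr.device_type() == 'CPU',
                   "Requires PJRT_DEVICE set to `TPU`.")
  def test_collective_op(self):
    device = torch_xla.device()
    self.assertIsNotNone(device)


def _mp_fn(index):
  device = torch_xla.device()
  # 'CUDA' was removed from allowed-device tuples like ('TPU', 'CUDA', 'NEURON').
  if xm.xla_device_hw(device) not in ('TPU', 'NEURON'):
    print(f'{device} is not a TPU device', file=sys.stderr)
    return
```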

test/eager/test_eager_all_reduce_in_place.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ def _mp_fn(index):
 
   device = torch_xla.device()
 
-  if xm.xla_device_hw(device) not in ('TPU', 'CUDA', 'NEURON'):
+  if xm.xla_device_hw(device) not in ('TPU', 'NEURON'):
     return
 
   ordinal_tensor_1 = torch.tensor([index], dtype=torch.float).to(device)

test/pjrt/test_ddp.py

Lines changed: 0 additions & 2 deletions

@@ -33,8 +33,6 @@ def _ddp_init(index: int = ...):
   def test_ddp_init(self):
     pjrt.run_multiprocess(self._ddp_init)
 
-  @absltest.skipIf(xr.device_type() == 'CUDA',
-                   "GPU device is not supported by pjrt.spawn_threads")
   def test_ddp_init_threaded(self):
     pjrt.spawn_threads(self._ddp_init)
 

test/spmd/test_spmd_debugging.py

Lines changed: 12 additions & 18 deletions

@@ -28,9 +28,8 @@ def setUpClass(cls):
     xr.use_spmd()
     super().setUpClass()
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'

@@ -108,9 +107,8 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}'

@@ -168,9 +166,8 @@ def test_single_host_partial_replication_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'

@@ -340,9 +337,8 @@ def test_single_host_replicated_cpu(self):
   # e.g.: sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}
   # e.g.: sharding={replicated}
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_debugging_spmd_multi_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'

@@ -468,9 +464,8 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_multi_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'

@@ -560,9 +555,8 @@ def test_multi_host_partial_replication_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   @unittest.skipIf(xr.global_runtime_device_count() != 8,
                    f"Limit test num_devices to 8 for function consistency")
   def test_multi_host_replicated_tpu(self):

test/spmd/test_train_spmd_linear_model.py

Lines changed: 0 additions & 3 deletions

@@ -20,9 +20,6 @@
 # the gradient checkpointing A/B test run for it.
 SKIP_GRADIENT_CHECKPOINTING: bool = False
 
-skipOnGpu = unittest.skipIf(xr.device_type() == 'CUDA',
-                            'https://github.com/pytorch/xla/issues/9128')
-
 
 @contextmanager
 def extended_argv(args):

test/spmd/test_xla_spmd_python_api_interaction.py

Lines changed: 1 addition & 13 deletions

@@ -98,17 +98,6 @@ def test_global_runtime_device_count(self):
       self.assertGreaterEqual(xr.global_runtime_device_count(), 4)
     elif device_type == "CPU":
       self.assertEqual(xr.global_runtime_device_count(), 1)
-    elif device_type == 'CUDA':
-      command = 'nvidia-smi --list-gpus | wc -l'
-      result = subprocess.run(
-          command,
-          capture_output=True,
-          shell=True,
-          check=True,
-          text=True,
-      )
-      expected_gpu_cnt = int(result.stdout)
-      self.assertEqual(xr.global_runtime_device_count(), expected_gpu_cnt)
 
   def test_addressable_runtime_device_count(self):
     device_type = os.environ['PJRT_DEVICE']

@@ -145,8 +134,7 @@ class BasicAutocastAPITest(test_xla_sharding_base.XlaShardingTest):
   def setUpClass(cls):
     super().setUpClass()
 
-  @unittest.skipIf(xr.device_type() not in ['TPU', 'CUDA'],
-                   f"TPU/GPU autocast test.")
+  @unittest.skipIf(xr.device_type() not in ('TPU',), f"TPU autocast test.")
   def test_xla_autocast_api(self):
     device = torch_xla.device()
     t1 = torch.ones([2, 3], device=device, dtype=torch.float32)

test/test_assume_pure_spmd.py

Lines changed: 0 additions & 20 deletions

@@ -37,10 +37,6 @@ def setUp(self):
 
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        "Multiple devices required")
-  @unittest.skipIf(
-      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
-      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
-  )
   def test_assume_pure_works_with_mark_sharding(self):
     x = torch.randn((8, 4, 5, 128), device='xla')
     result = assume_pure(mark_sharding)(x, self.spmd_mesh,

@@ -52,10 +48,6 @@ def test_assume_pure_works_with_mark_sharding(self):
 
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        "Multiple devices required")
-  @unittest.skipIf(
-      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
-      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
-  )
   def test_assume_pure_works_with_mark_sharding_with_gradients(self):
     x = torch.randn((8, 4, 5, 128)).to('xla').requires_grad_(True)
     result = assume_pure(mark_sharding_with_gradients)(

@@ -71,10 +63,6 @@ def test_assume_pure_works_with_mark_sharding_with_gradients(self):
 
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        "Multiple devices required")
-  @unittest.skipIf(
-      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
-      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
-  )
   def test_assume_pure_works_with_mark_sharding_nested(self):
     mesh = get_2d_mesh("model", "batch")
     set_global_mesh(mesh)

@@ -88,10 +76,6 @@ def test_assume_pure_works_with_mark_sharding_nested(self):
 
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        "Multiple devices required")
-  @unittest.skipIf(
-      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
-      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
-  )
   def test_assume_pure_works_with_mark_sharding_with_gradients_nested(self):
     mesh = get_2d_mesh("model", "batch")
     set_global_mesh(mesh)

@@ -109,10 +93,6 @@ def test_assume_pure_works_with_mark_sharding_with_gradients_nested(self):
 
   @unittest.skipUnless(xr.global_runtime_device_count() > 1,
                        "Multiple devices required")
-  @unittest.skipIf(
-      torch.cuda.is_available() or os.environ.get('PJRT_DEVICE') == 'CUDA',
-      "TODO(https://github.com/pytorch/xla/issues/9017): Get these tests working on GPU"
-  )
   def test_convert_to_jax_mesh(self):
     jax_mesh = self.spmd_mesh.get_jax_mesh()
     self.assertEqual(jax_mesh.devices.shape, self.spmd_mesh.mesh_shape)

test/test_fsdp_auto_wrap.py

Lines changed: 2 additions & 7 deletions

@@ -30,10 +30,6 @@ def forward(self, x):
     hidden2 = self.fc2(x)
     return hidden1, hidden2
 
-  @unittest.skipIf(
-      xr.device_type() == 'CUDA',
-      "This test fails only on GPU with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)"
-  )
   def test(self):
     dev = torch_xla.device()
     input = torch.zeros([16, 16], device=dev)

@@ -49,13 +45,12 @@ def test(self):
 
 def _mp_fn(index):
   device = torch_xla.device()
-  if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
+  if xm.xla_device_hw(device) in ('TPU',):
     test = unittest.main(exit=False)
     sys.exit(0 if test.result.wasSuccessful() else 1)
   else:
     print(
-        'Default device {} is not a TPU or CUDA device'.format(device),
-        file=sys.stderr)
+        'Default device {} is not a TPU device'.format(device), file=sys.stderr)
 
 
 if __name__ == '__main__':

test/test_mp_all_gather.py

Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@ def _mp_fn(index):
   device = torch_xla.device()
   world_size = xr.world_size()
   input_list_size = 5
-  if xm.xla_device_hw(device) in ('TPU', 'CUDA', 'NEURON'):
+  if xm.xla_device_hw(device) in ('TPU', 'NEURON'):
     # Testing with a single replica group
     ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device)
     result = xm.all_gather(ordinal_tensor, dim=0)

@@ -161,7 +161,7 @@ def _mp_fn(index):
     # TODO: add test for torch.compile when support for list input is ready
 
   else:
-    print(f'{device} is not a TPU or GPU device', file=sys.stderr)
+    print(f'{device} is not a TPU device', file=sys.stderr)
 
 
 if __name__ == '__main__':

test/test_mp_distributed_mm.py

Lines changed: 2 additions & 3 deletions

@@ -9,7 +9,7 @@
 def _mp_fn(index):
   device = torch_xla.device()
 
-  if xm.xla_device_hw(device) in ('TPU', 'CUDA'):
+  if xm.xla_device_hw(device) in ('TPU',):
     world_size = xr.world_size()
     torch_xla._XLAC._xla_set_mat_mul_precision('highest')
     torch.manual_seed(11)

@@ -34,8 +34,7 @@ def _mp_fn(index):
     sys.exit(1)
   else:
     print(
-        'Default device {} is not a TPU or GPU device'.format(device),
-        file=sys.stderr)
+        'Default device {} is not a TPU device'.format(device), file=sys.stderr)
 
 
 if __name__ == '__main__':

test/test_mp_early_exit.py

Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@
 def _mp_fn():
   dist.init_process_group('xla', init_method='xla://')
   device = torch_xla.device()
-  if xm.xla_device_hw(device) in ['TPU', 'CUDA']:
+  if xm.xla_device_hw(device) in ('TPU',):
     train_loader = xu.SampleGenerator(
         data=torch.zeros(1, 12), sample_count=1024)
     train_loader = pl.MpDeviceLoader(train_loader, device)

@@ -23,7 +23,7 @@ def _mp_fn():
     if step > max_steps:
       break
   else:
-    print(f'{device} is not a TPU or GPU device', file=sys.stderr)
+    print(f'{device} is not a TPU device', file=sys.stderr)
 
 
 if __name__ == '__main__':
