@@ -28,9 +28,8 @@ def setUpClass(cls):
     xr.use_spmd()
     super().setUpClass()
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_debugging_spmd_single_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'
@@ -108,9 +107,8 @@ def test_debugging_spmd_single_host_tiled_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate}'
@@ -168,9 +166,8 @@ def test_single_host_partial_replication_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_single_host_replicated_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{replicated}'
@@ -340,9 +337,8 @@ def test_single_host_replicated_cpu(self):
   # e.g.: sharding={devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}
   # e.g.: sharding={replicated}
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_debugging_spmd_multi_host_tiled_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[2,8]0,4,8,12,2,6,10,14,1,5,9,13,3,7,11,15}'
@@ -468,9 +464,8 @@ def test_debugging_spmd_multi_host_tiled_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   def test_multi_host_partial_replication_tpu(self):
     from torch_xla.distributed.spmd.debugging import visualize_sharding
     sharding = '{devices=[8,1,2]0,1,4,5,8,9,12,13,2,3,6,7,10,11,14,15 last_tile_dim_replicate}'
@@ -560,9 +555,8 @@ def test_multi_host_partial_replication_tpu(self):
     fake_output = fake_capture.get()
     assert output == fake_output
 
-  @unittest.skipIf(
-      xr.device_type() == 'CPU',
-      f"Requires PJRT_DEVICE set to `TPU`, `GPU`, `CUDA`, or 'ROCM'.")
+  @unittest.skipIf(xr.device_type() == 'CPU',
+                   f"Requires PJRT_DEVICE set to `TPU`.")
   @unittest.skipIf(xr.global_runtime_device_count() != 8,
                    f"Limit test num_devices to 8 for function consistency")
   def test_multi_host_replicated_tpu(self):
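
For reviewers who want to exercise the guard outside the full test file, here is a minimal, self-contained sketch of the pattern this diff converges on: skip when the PJRT device is CPU, then render a sharding spec with visualize_sharding. The positional call and the use_color=False keyword are assumptions about the helper's signature, not something shown in this diff.

import unittest

import torch_xla.runtime as xr


class VisualizeShardingSketch(unittest.TestCase):

  @unittest.skipIf(xr.device_type() == 'CPU',
                   f"Requires PJRT_DEVICE set to `TPU`.")
  def test_visualize_tiled_sharding(self):
    from torch_xla.distributed.spmd.debugging import visualize_sharding
    # Same tiled spec string used by test_debugging_spmd_single_host_tiled_tpu.
    sharding = '{devices=[2,4]0,1,2,3,4,5,6,7}'
    # Assumed signature: the spec string is passed directly; use_color=False
    # keeps the rendered table plain so it can be compared as text.
    table = visualize_sharding(sharding, use_color=False)
    self.assertIsNotNone(table)


if __name__ == '__main__':
  unittest.main()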