
Commit 358cf81

use needs_cuda decorator
1 parent 8984188 commit 358cf81

1 file changed: 2 additions, 3 deletions

test/test_transforms_v2.py

Lines changed: 2 additions & 3 deletions
@@ -60,7 +60,6 @@
 )
 
 
-CUDA_AVAILABLE = torch.cuda.is_available()
 CVCUDA_AVAILABLE = _is_cvcuda_available()
 if CVCUDA_AVAILABLE:
     cvcuda = _import_cvcuda()
@@ -6745,7 +6744,7 @@ def test_functional_error(self):
 
 
 @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
+@needs_cuda
 class TestToCVCUDATensor:
     @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
@@ -6812,7 +6811,7 @@ def test_round_trip(self, dtype, device, color_space, batch_size):
 
 
 @pytest.mark.skipif(not CVCUDA_AVAILABLE, reason="test requires CVCUDA")
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="test requires CUDA")
+@needs_cuda
 class TestCVDUDAToTensor:
     @pytest.mark.parametrize("dtype", [torch.uint8, torch.uint16, torch.float32, torch.float64])
     @pytest.mark.parametrize("device", cpu_and_cuda())
