Commit 3f6f771

add int16/int32 dtypes in round trip tests
1 parent: 32a5bf9

1 file changed (+10, -10 lines)

test/test_transforms_v2.py (10 additions, 10 deletions)

@@ -6845,15 +6845,15 @@ def test_float64_rgb_not_supported(self):
         F.to_cvcuda_tensor(img_data)

     @pytest.mark.parametrize("num_channels", [1, 3])
-    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32, torch.float64])
+    @pytest.mark.parametrize("dtype", [torch.uint8, torch.int16, torch.int32, torch.float32, torch.float64])
     def test_round_trip(self, num_channels, dtype):
-        # Skip float64 for 3-channel (not supported by CV-CUDA)
-        if num_channels == 3 and dtype == torch.float64:
-            pytest.skip("float64 is not supported for 3-channel RGB images")
+        # Skip int16/int32/float64 for 3-channel (only supported for single-channel)
+        if num_channels == 3 and dtype in (torch.int16, torch.int32, torch.float64):
+            pytest.skip(f"{dtype} is only supported for single-channel images")

         # Setup: Create a tensor in CHW format (PyTorch standard)
         # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
-        if dtype == torch.uint8:
+        if dtype in (torch.uint8, torch.int16, torch.int32):
             original_tensor = torch.randint(0, 256, (num_channels, 4, 4), dtype=dtype)
         else:
             original_tensor = torch.rand(num_channels, 4, 4, dtype=dtype)
@@ -6871,16 +6871,16 @@ def test_round_trip(self, num_channels, dtype):
         torch.testing.assert_close(result_tensor, original_tensor, rtol=0, atol=0)

     @pytest.mark.parametrize("num_channels", [1, 3])
-    @pytest.mark.parametrize("dtype", [torch.uint8, torch.float32, torch.float64])
+    @pytest.mark.parametrize("dtype", [torch.uint8, torch.int16, torch.int32, torch.float32, torch.float64])
     @pytest.mark.parametrize("batch_size", [1, 2, 4])
     def test_round_trip_batched(self, num_channels, dtype, batch_size):
-        # Skip float64 for 3-channel (not supported by CV-CUDA)
-        if num_channels == 3 and dtype == torch.float64:
-            pytest.skip("float64 is not supported for 3-channel RGB images")
+        # Skip int16/int32/float64 for 3-channel (only supported for single-channel)
+        if num_channels == 3 and dtype in (torch.int16, torch.int32, torch.float64):
+            pytest.skip(f"{dtype} is only supported for single-channel images")

         # Setup: Create a batched tensor in NCHW format
         # Create tensor on CPU first, then move to CUDA to avoid CUDA context issues
-        if dtype == torch.uint8:
+        if dtype in (torch.uint8, torch.int16, torch.int32):
             original_tensor = torch.randint(0, 256, (batch_size, num_channels, 4, 4), dtype=dtype)
         else:
             original_tensor = torch.rand(batch_size, num_channels, 4, 4, dtype=dtype)
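
For context, the sketch below shows the round-trip pattern these tests exercise with the newly added integer dtypes. It assumes a CUDA device with CV-CUDA support; F.to_cvcuda_tensor is taken from the diff above, while the reverse helper name (cvcuda_to_tensor) is an assumption used for illustration and may differ from the test suite's actual API.

# Minimal sketch of the round-trip being tested, assuming a CUDA device and
# CV-CUDA are available. The reverse helper name is an assumption.
import torch
from torchvision.transforms.v2 import functional as F


def round_trip(num_channels: int, dtype: torch.dtype) -> None:
    # Integer dtypes (uint8/int16/int32) get random integer pixel values,
    # floating dtypes get values in [0, 1), matching the tests above.
    if dtype in (torch.uint8, torch.int16, torch.int32):
        original = torch.randint(0, 256, (num_channels, 4, 4), dtype=dtype)
    else:
        original = torch.rand(num_channels, 4, 4, dtype=dtype)

    # Create on CPU first, then move to CUDA (CV-CUDA tensors live on the GPU).
    original = original.cuda()

    cvcuda_tensor = F.to_cvcuda_tensor(original)  # CHW torch.Tensor -> CV-CUDA tensor
    result = F.cvcuda_to_tensor(cvcuda_tensor)    # assumed name of the reverse helper

    # The round trip must be lossless for every supported dtype.
    torch.testing.assert_close(result, original, rtol=0, atol=0)


# int16/int32 (and float64) are expected to round-trip only for single-channel images.
round_trip(num_channels=1, dtype=torch.int16)
round_trip(num_channels=3, dtype=torch.uint8)

The zero-tolerance assert_close call is the key check: the conversion to and from CV-CUDA must preserve every pixel value bit-for-bit, which is why the new integer dtypes are generated with randint rather than rand.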
