diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index af52d1fca65..d27b2682055 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -6206,6 +6206,11 @@ def test_transform_invalid_quality_error(self, quality): with pytest.raises(ValueError, match="quality must be an integer from 1 to 100"): transforms.JPEG(quality=quality) + @pytest.mark.parametrize("quality", [None, True]) + def test_transform_quality_type_error(self, quality): + with pytest.raises(TypeError, match="quality"): + transforms.JPEG(quality=quality) + class TestUtils: # TODO: Still need to test has_all, has_any, check_type and get_bouding_boxes diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 93d4ba45d65..2aad7bd4dc3 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -352,6 +352,8 @@ class JPEG(Transform): def __init__(self, quality: Union[int, Sequence[int]]): super().__init__() if isinstance(quality, int): + if isinstance(quality, bool): + raise TypeError("quality can't be bool") quality = [quality, quality] else: _check_sequence_input(quality, "quality", req_sizes=(2,))