diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 9f6817bb60d..af52d1fca65 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -3514,6 +3514,14 @@ def test_aug_mix_severity_error(self, severity): with pytest.raises(ValueError, match="severity must be between"): transforms.AugMix(severity=severity) + @pytest.mark.parametrize("num_ops", [-1, 1.1]) + def test_rand_augment_num_ops_error(self, num_ops): + with pytest.raises( + ValueError, + match=re.escape(f"num_ops should be a non-negative integer, but got {num_ops} instead."), + ): + transforms.RandAugment(num_ops=num_ops) + class TestConvertBoundingBoxFormat: old_new_formats = list(itertools.permutations(SUPPORTED_BOX_FORMATS, 2)) diff --git a/torchvision/transforms/v2/_auto_augment.py b/torchvision/transforms/v2/_auto_augment.py index 4dd7ba343aa..240330386fb 100644 --- a/torchvision/transforms/v2/_auto_augment.py +++ b/torchvision/transforms/v2/_auto_augment.py @@ -361,7 +361,8 @@ class RandAugment(_AutoAugmentBase): If img is PIL Image, it is expected to be in mode "L" or "RGB". Args: - num_ops (int, optional): Number of augmentation transformations to apply sequentially. + num_ops (int, optional): Number of augmentation transformations to apply sequentially, + must be non-negative integer. Default: 2. magnitude (int, optional): Magnitude for all the transformations. num_magnitude_bins (int, optional): The number of different magnitude values. interpolation (InterpolationMode, optional): Desired interpolation enum defined by @@ -407,6 +408,8 @@ def __init__( fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None, ) -> None: super().__init__(interpolation=interpolation, fill=fill) + if not isinstance(num_ops, int) or (num_ops < 0): + raise ValueError(f"num_ops should be a non-negative integer, but got {num_ops} instead.") self.num_ops = num_ops self.magnitude = magnitude self.num_magnitude_bins = num_magnitude_bins