Skip to content

Commit f3f2c35

Browse files
fix: Raise error when receive non-positive value in RandAugment. (#8994)
Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 251c57a commit f3f2c35

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

test/test_transforms_v2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3514,6 +3514,14 @@ def test_aug_mix_severity_error(self, severity):
35143514
with pytest.raises(ValueError, match="severity must be between"):
35153515
transforms.AugMix(severity=severity)
35163516

3517+
@pytest.mark.parametrize("num_ops", [-1, 1.1])
3518+
def test_rand_augment_num_ops_error(self, num_ops):
3519+
with pytest.raises(
3520+
ValueError,
3521+
match=re.escape(f"num_ops should be a non-negative integer, but got {num_ops} instead."),
3522+
):
3523+
transforms.RandAugment(num_ops=num_ops)
3524+
35173525

35183526
class TestConvertBoundingBoxFormat:
35193527
old_new_formats = list(itertools.permutations(SUPPORTED_BOX_FORMATS, 2))

torchvision/transforms/v2/_auto_augment.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,8 @@ class RandAugment(_AutoAugmentBase):
361361
If img is PIL Image, it is expected to be in mode "L" or "RGB".
362362
363363
Args:
364-
num_ops (int, optional): Number of augmentation transformations to apply sequentially.
364+
num_ops (int, optional): Number of augmentation transformations to apply sequentially,
365+
must be non-negative integer. Default: 2.
365366
magnitude (int, optional): Magnitude for all the transformations.
366367
num_magnitude_bins (int, optional): The number of different magnitude values.
367368
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
@@ -407,6 +408,8 @@ def __init__(
407408
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
408409
) -> None:
409410
super().__init__(interpolation=interpolation, fill=fill)
411+
if not isinstance(num_ops, int) or (num_ops < 0):
412+
raise ValueError(f"num_ops should be a non-negative integer, but got {num_ops} instead.")
410413
self.num_ops = num_ops
411414
self.magnitude = magnitude
412415
self.num_magnitude_bins = num_magnitude_bins

0 commit comments

Comments
 (0)