Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3514,6 +3514,14 @@ def test_aug_mix_severity_error(self, severity):
with pytest.raises(ValueError, match="severity must be between"):
transforms.AugMix(severity=severity)

@pytest.mark.parametrize("num_ops", [-1, 1.1])
def test_rand_augment_num_ops_error(self, num_ops):
with pytest.raises(
ValueError,
match=re.escape(f"num_ops should be a non-negative integer, but got {num_ops} instead."),
):
transforms.RandAugment(num_ops=num_ops)


class TestConvertBoundingBoxFormat:
old_new_formats = list(itertools.permutations(SUPPORTED_BOX_FORMATS, 2))
Expand Down
5 changes: 4 additions & 1 deletion torchvision/transforms/v2/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,8 @@ class RandAugment(_AutoAugmentBase):
If img is PIL Image, it is expected to be in mode "L" or "RGB".

Args:
num_ops (int, optional): Number of augmentation transformations to apply sequentially.
num_ops (int, optional): Number of augmentation transformations to apply sequentially,
must be non-negative integer.
magnitude (int, optional): Magnitude for all the transformations.
num_magnitude_bins (int, optional): The number of different magnitude values.
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
Expand Down Expand Up @@ -407,6 +408,8 @@ def __init__(
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
) -> None:
super().__init__(interpolation=interpolation, fill=fill)
if not isinstance(num_ops, int) or (num_ops < 0):
raise ValueError(f"num_ops should be a non-negative integer, but got {num_ops} instead.")
self.num_ops = num_ops
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins
Expand Down