Skip to content

Commit 251c57a

Browse files
fix: Add type checking in posterize_image. (#8993)
Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 8a06122 commit 251c57a

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

test/test_transforms_v2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4608,6 +4608,14 @@ def test_correctness_image(self, bits, fn):
46084608

46094609
assert_equal(actual, expected)
46104610

4611+
@pytest.mark.parametrize("bits", [-1, 9, 2.1])
4612+
def test_error_functional(self, bits):
4613+
with pytest.raises(
4614+
TypeError,
4615+
match=re.escape(f"bits must be a positive integer in the range [0, 8], got {bits} instead."),
4616+
):
4617+
F.posterize(make_image(dtype=torch.uint8), bits=bits)
4618+
46114619

46124620
class TestSolarize:
46134621
def _make_threshold(self, input, *, factor=0.5):

torchvision/transforms/v2/functional/_color.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,9 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
460460
@_register_kernel_internal(posterize, torch.Tensor)
461461
@_register_kernel_internal(posterize, tv_tensors.Image)
462462
def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
463+
if not isinstance(bits, int) or not 0 <= bits <= 8:
464+
raise TypeError(f"bits must be a positive integer in the range [0, 8], got {bits} instead.")
465+
463466
if image.is_floating_point():
464467
levels = 1 << bits
465468
return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels)

0 commit comments

Comments
 (0)