Skip to content

Commit 26b689d

Browse files
committed
fix: Add type checking in posterize_image.
1 parent 8ea4772 commit 26b689d

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

test/test_transforms_v2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4608,6 +4608,14 @@ def test_correctness_image(self, bits, fn):
46084608

46094609
assert_equal(actual, expected)
46104610

4611+
@pytest.mark.parametrize("bits", [-1, 9, 2.1])
4612+
def test_error_functional(self, bits):
4613+
with pytest.raises(
4614+
TypeError,
4615+
match=re.escape(f"bits must be a positive integer in the range [0, 8], get {bits} instead."),
4616+
):
4617+
F.posterize(make_image(dtype=torch.uint8), bits=bits)
4618+
46114619

46124620
class TestSolarize:
46134621
def _make_threshold(self, input, *, factor=0.5):

torchvision/transforms/v2/functional/_color.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,9 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
460460
@_register_kernel_internal(posterize, torch.Tensor)
461461
@_register_kernel_internal(posterize, tv_tensors.Image)
462462
def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
463+
if (not isinstance(bits, int)) or (bits not in range(0, 9)):
464+
raise TypeError(f"bits must be a positive integer in the range [0, 8], get {bits} instead.")
465+
463466
if image.is_floating_point():
464467
levels = 1 << bits
465468
return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels)

0 commit comments

Comments
 (0)