Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4608,6 +4608,14 @@ def test_correctness_image(self, bits, fn):

assert_equal(actual, expected)

@pytest.mark.parametrize("bits", [-1, 9, 2.1])
def test_error_functional(self, bits):
with pytest.raises(
TypeError,
match=re.escape(f"bits must be a positive integer in the range [0, 8], get {bits} instead."),
):
F.posterize(make_image(dtype=torch.uint8), bits=bits)


class TestSolarize:
def _make_threshold(self, input, *, factor=0.5):
Expand Down
3 changes: 3 additions & 0 deletions torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,9 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
@_register_kernel_internal(posterize, torch.Tensor)
@_register_kernel_internal(posterize, tv_tensors.Image)
def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
if (not isinstance(bits, int)) or (bits not in range(0, 9)):
raise TypeError(f"bits must be a positive integer in the range [0, 8], get {bits} instead.")

if image.is_floating_point():
levels = 1 << bits
return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels)
Expand Down