Skip to content

Commit 1cf29d3

Browse files
committed
O(1) check
1 parent 26b689d commit 1cf29d3

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

test/test_transforms_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4612,7 +4612,7 @@ def test_correctness_image(self, bits, fn):
46124612
def test_error_functional(self, bits):
46134613
with pytest.raises(
46144614
TypeError,
4615-
match=re.escape(f"bits must be a positive integer in the range [0, 8], get {bits} instead."),
4615+
match=re.escape(f"bits must be a positive integer in the range [0, 8], got {bits} instead."),
46164616
):
46174617
F.posterize(make_image(dtype=torch.uint8), bits=bits)
46184618

torchvision/transforms/v2/functional/_color.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,8 +460,8 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor:
460460
@_register_kernel_internal(posterize, torch.Tensor)
461461
@_register_kernel_internal(posterize, tv_tensors.Image)
462462
def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor:
463-
if (not isinstance(bits, int)) or (bits not in range(0, 9)):
464-
raise TypeError(f"bits must be a positive integer in the range [0, 8], get {bits} instead.")
463+
if not isinstance(bits, int) or not 0 <= bits <= 8:
464+
raise TypeError(f"bits must be a positive integer in the range [0, 8], got {bits} instead.")
465465

466466
if image.is_floating_point():
467467
levels = 1 << bits

0 commit comments

Comments
 (0)