
Commit a0883e2

AntoineSimoulin, NicolasHug, and elmuz authored
[release/0.23] Cherry pick (#9138)
Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: elmuz <[email protected]>
1 parent 5dc9e7d · commit a0883e2

File tree

14 files changed: +575 −201 lines changed

setup.py

Lines changed: 2 additions & 1 deletion

@@ -111,7 +111,8 @@ def get_dist(pkgname):
 ]

 # Excluding 8.3.* because of https://github.com/pytorch/vision/issues/4934
-pillow_ver = " >= 5.3.0, !=8.3.*"
+# TODO remove <11.3 bound and address corresponding deprecation warnings
+pillow_ver = " >= 5.3.0, !=8.3.*, <11.3"
 pillow_req = "pillow-simd" if get_dist("pillow-simd") is not None else "pillow"
 requirements.append(pillow_req + pillow_ver)
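For context, a minimal sketch of how the tightened pin behaves, using the third-party packaging library (not part of this diff; versions below are illustrative):

from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Same constraint as the new pillow_ver string, minus the leading space.
spec = SpecifierSet(">=5.3.0, !=8.3.*, <11.3")
for candidate in ["8.3.1", "10.4.0", "11.3.0"]:
    # 8.3.1 -> False (excluded series), 10.4.0 -> True, 11.3.0 -> False (new upper bound)
    print(candidate, Version(candidate) in spec)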

New binary asset, 3.05 KB (image preview not rendered): the reference image test/assets/fakedata/draw_rotated_boxes_fill.png used by test_draw_rotated_boxes_fill below.

test/common_utils.py

Lines changed: 11 additions & 29 deletions

@@ -21,7 +21,7 @@
 from torch.testing._comparison import BooleanPair, NonePair, not_close_error_metas, NumberPair, TensorLikePair
 from torchvision import io, tv_tensors
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
-from torchvision.transforms.v2.functional import clamp_bounding_boxes, to_image, to_pil_image
+from torchvision.transforms.v2.functional import to_image, to_pil_image


 IN_OSS_CI = any(os.getenv(var) == "true" for var in ["CIRCLECI", "GITHUB_ACTIONS"])
@@ -410,6 +410,7 @@ def make_bounding_boxes(
     canvas_size=DEFAULT_SIZE,
     *,
     format=tv_tensors.BoundingBoxFormat.XYXY,
+    clamping_mode="soft",
     num_boxes=1,
     dtype=None,
     device="cpu",
@@ -423,13 +424,6 @@ def sample_position(values, max_value):
         format = tv_tensors.BoundingBoxFormat[format]

     dtype = dtype or torch.float32
-    int_dtype = dtype in (
-        torch.uint8,
-        torch.int8,
-        torch.int16,
-        torch.int32,
-        torch.int64,
-    )

     h, w = (torch.randint(1, s, (num_boxes,)) for s in canvas_size)
     y = sample_position(h, canvas_size[0])
@@ -456,31 +450,19 @@ def sample_position(values, max_value):
     elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
         r_rad = r * torch.pi / 180.0
         cos, sin = torch.cos(r_rad), torch.sin(r_rad)
-        x1 = torch.round(x) if int_dtype else x
-        y1 = torch.round(y) if int_dtype else y
-        x2 = torch.round(x1 + w * cos) if int_dtype else x1 + w * cos
-        y2 = torch.round(y1 - w * sin) if int_dtype else y1 - w * sin
-        x3 = torch.round(x2 + h * sin) if int_dtype else x2 + h * sin
-        y3 = torch.round(y2 + h * cos) if int_dtype else y2 + h * cos
-        x4 = torch.round(x1 + h * sin) if int_dtype else x1 + h * sin
-        y4 = torch.round(y1 + h * cos) if int_dtype else y1 + h * cos
+        x1 = x
+        y1 = y
+        x2 = x1 + w * cos
+        y2 = y1 - w * sin
+        x3 = x2 + h * sin
+        y3 = y2 + h * cos
+        x4 = x1 + h * sin
+        y4 = y1 + h * cos
         parts = (x1, y1, x2, y2, x3, y3, x4, y4)
     else:
         raise ValueError(f"Format {format} is not supported")
     out_boxes = torch.stack(parts, dim=-1).to(dtype=dtype, device=device)
-    if tv_tensors.is_rotated_bounding_format(format):
-        # The rotated bounding boxes are not guaranteed to be within the canvas by design,
-        # so we apply clamping. We also add a 2 buffer to the canvas size to avoid
-        # numerical issues during the testing
-        buffer = 4
-        out_boxes = clamp_bounding_boxes(
-            out_boxes, format=format, canvas_size=(canvas_size[0] - buffer, canvas_size[1] - buffer)
-        )
-        if format is tv_tensors.BoundingBoxFormat.XYWHR or format is tv_tensors.BoundingBoxFormat.CXCYWHR:
-            out_boxes[:, :2] += buffer // 2
-        elif format is tv_tensors.BoundingBoxFormat.XYXYXYXY:
-            out_boxes[:, :] += buffer // 2
-    return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size)
+    return tv_tensors.BoundingBoxes(out_boxes, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode)


 def make_detection_masks(size=DEFAULT_SIZE, *, num_masks=1, dtype=None, device="cpu"):
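A hedged usage sketch of the updated helper: the new clamping_mode keyword flows straight through to the tv_tensors.BoundingBoxes constructor in place of the removed eager clamping (assumes this test module is importable and that XYWHR boxes carry five values per box):

from common_utils import make_bounding_boxes

# Rotated boxes are no longer pre-clamped at creation; the mode is recorded on the tensor.
boxes = make_bounding_boxes((32, 32), format="XYWHR", clamping_mode="soft", num_boxes=3)
assert boxes.clamping_mode == "soft"
assert boxes.shape == (3, 5)  # XYWHR: x, y, w, h, rotation angle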

test/test_transforms_v2.py

Lines changed: 181 additions & 47 deletions
Large diffs are not rendered by default.

test/test_tv_tensors.py

Lines changed: 40 additions & 3 deletions

@@ -69,15 +69,39 @@ def test_bbox_instance(data, format):
 )
 @pytest.mark.parametrize("scripted", (False, True))
 def test_bbox_format(format, is_rotated_expected, scripted):
-    if isinstance(format, str):
-        format = tv_tensors.BoundingBoxFormat[(format.upper())]
-
     fn = tv_tensors.is_rotated_bounding_format
     if scripted:
         fn = torch.jit.script(fn)
     assert fn(format) == is_rotated_expected


+@pytest.mark.parametrize(
+    "format, support_integer_dtype",
+    [
+        ("XYXY", True),
+        ("XYWH", True),
+        ("CXCYWH", True),
+        ("XYXYXYXY", False),
+        ("XYWHR", False),
+        ("CXCYWHR", False),
+        (tv_tensors.BoundingBoxFormat.XYXY, True),
+        (tv_tensors.BoundingBoxFormat.XYWH, True),
+        (tv_tensors.BoundingBoxFormat.CXCYWH, True),
+        (tv_tensors.BoundingBoxFormat.XYXYXYXY, False),
+        (tv_tensors.BoundingBoxFormat.XYWHR, False),
+        (tv_tensors.BoundingBoxFormat.CXCYWHR, False),
+    ],
+)
+@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
+def test_bbox_format_dtype(format, support_integer_dtype, input_dtype):
+    tensor = torch.randint(0, 32, size=(5, 2), dtype=input_dtype)
+    if not input_dtype.is_floating_point and not support_integer_dtype:
+        with pytest.raises(ValueError, match="Rotated bounding boxes should be floating point tensors"):
+            tv_tensors.BoundingBoxes(tensor, format=format, canvas_size=(32, 32))
+    else:
+        tv_tensors.BoundingBoxes(tensor, format=format, canvas_size=(32, 32))
+
+
 def test_bbox_dim_error():
     data_3d = [[[1, 2, 3, 4]]]
     with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):
@@ -406,3 +430,16 @@ def test_return_type_input():
         tv_tensors.set_return_type("typo")

     tv_tensors.set_return_type("tensor")
+
+
+def test_box_clamping_mode_default_and_error():
+    assert (
+        tv_tensors.BoundingBoxes([0.0, 0.0, 10.0, 10.0], format="XYXY", canvas_size=(100, 100)).clamping_mode == "soft"
+    )
+    assert (
+        tv_tensors.BoundingBoxes([0.0, 0.0, 10.0, 10.0, 0.0], format="XYWHR", canvas_size=(100, 100)).clamping_mode
+        == "soft"
+    )
+
+    with pytest.raises(ValueError, match="clamping_mode must be"):
+        tv_tensors.BoundingBoxes([0, 0, 10, 10], format="XYXY", canvas_size=(100, 100), clamping_mode="bad")
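In plain terms, the new tests pin down the dtype rule these diffs introduce; a short sketch with illustrative values:

import torch
from torchvision import tv_tensors

# Axis-aligned formats still accept integer dtypes:
tv_tensors.BoundingBoxes(torch.tensor([[0, 0, 10, 10]], dtype=torch.uint8), format="XYXY", canvas_size=(32, 32))

# Rotated formats now reject them:
try:
    tv_tensors.BoundingBoxes(torch.tensor([[0, 0, 10, 10, 45]], dtype=torch.uint8), format="XYWHR", canvas_size=(32, 32))
except ValueError as e:
    print(e)  # Rotated bounding boxes should be floating point tensors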

test/test_utils.py

Lines changed: 11 additions & 0 deletions

@@ -177,6 +177,17 @@ def test_draw_rotated_boxes():
     assert_equal(result, expected)


+@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
+def test_draw_rotated_boxes_fill():
+    img = torch.full((3, 500, 500), 255, dtype=torch.uint8)
+    colors = ["blue", "yellow", (0, 255, 0), "black"]
+
+    result = utils.draw_bounding_boxes(img, rotated_boxes, colors=colors, fill=True)
+    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_rotated_boxes_fill.png")
+    expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
+    assert_equal(result, expected)
+
+
 @pytest.mark.parametrize("fill", [True, False])
 def test_draw_boxes_dtypes(fill):
     img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)
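A hedged sketch of the call this new test exercises. The rotated_boxes fixture is defined elsewhere in the test file and is not shown in this diff, so the box below is a made-up stand-in, assuming draw_bounding_boxes accepts rotated tv_tensors.BoundingBoxes as the surrounding rotated-box tests suggest:

import torch
from torchvision import tv_tensors, utils

img = torch.full((3, 500, 500), 255, dtype=torch.uint8)
boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[100.0, 100.0, 200.0, 100.0, 30.0]]),  # x, y, w, h, rotation in degrees
    format="XYWHR",
    canvas_size=(500, 500),
)
# fill=True is the behavior the new reference image covers.
out = utils.draw_bounding_boxes(img, boxes, colors=["blue"], fill=True)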

torchvision/datasets/fakedata.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ class FakeData(VisionDataset):

     Args:
         size (int, optional): Size of the dataset. Default: 1000 images
-        image_size(tuple, optional): Size if the returned images. Default: (3, 224, 224)
+        image_size(tuple, optional): Size of the returned images. Default: (3, 224, 224)
         num_classes(int, optional): Number of classes in the dataset. Default: 10
         transform (callable, optional): A function/transform that takes in a PIL image
             and returns a transformed version. E.g, ``transforms.RandomCrop``
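The docstring being fixed describes this usage; a quick sketch:

from torchvision.datasets import FakeData

ds = FakeData(size=4, image_size=(3, 224, 224), num_classes=10)
img, target = ds[0]       # a randomly generated PIL image and a class label
print(img.size, target)   # (224, 224) and a label in [0, 10)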

torchvision/transforms/v2/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -41,7 +41,7 @@
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBoxes, ClampKeyPoints, ConvertBoundingBoxFormat
+from ._meta import ClampBoundingBoxes, ClampKeyPoints, ConvertBoundingBoxFormat, SetClampingMode
 from ._misc import (
     ConvertImageDtype,
     GaussianBlur,
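A one-line check of what this re-export enables, matching the import added here:

from torchvision.transforms.v2 import SetClampingMode  # now available at the v2 top level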

torchvision/transforms/v2/_meta.py

Lines changed: 27 additions & 1 deletion

@@ -2,6 +2,7 @@

 from torchvision import tv_tensors
 from torchvision.transforms.v2 import functional as F, Transform
+from torchvision.tv_tensors._bounding_boxes import CLAMPING_MODE_TYPE


 class ConvertBoundingBoxFormat(Transform):
@@ -28,12 +29,19 @@ class ClampBoundingBoxes(Transform):

     The clamping is done according to the bounding boxes' ``canvas_size`` meta-data.

+    Args:
+        clamping_mode: TODOBB more docs. Default is None which relies on the input box' clamping_mode attribute.
+
     """

+    def __init__(self, clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto") -> None:
+        super().__init__()
+        self.clamping_mode = clamping_mode
+
     _transformed_types = (tv_tensors.BoundingBoxes,)

     def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes:
-        return F.clamp_bounding_boxes(inpt)  # type: ignore[return-value]
+        return F.clamp_bounding_boxes(inpt, clamping_mode=self.clamping_mode)  # type: ignore[return-value]


 class ClampKeyPoints(Transform):
@@ -46,3 +54,21 @@ class ClampKeyPoints(Transform):

     def transform(self, inpt: tv_tensors.KeyPoints, params: dict[str, Any]) -> tv_tensors.KeyPoints:
         return F.clamp_keypoints(inpt)  # type: ignore[return-value]
+
+
+class SetClampingMode(Transform):
+    """TODOBB"""
+
+    def __init__(self, clamping_mode: CLAMPING_MODE_TYPE) -> None:
+        super().__init__()
+        self.clamping_mode = clamping_mode
+
+        if self.clamping_mode not in (None, "soft", "hard"):
+            raise ValueError(f"clamping_mode must be soft, hard or None, got {clamping_mode}")
+
+    _transformed_types = (tv_tensors.BoundingBoxes,)
+
+    def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> tv_tensors.BoundingBoxes:
+        out: tv_tensors.BoundingBoxes = inpt.clone()  # type: ignore[assignment]
+        out.clamping_mode = self.clamping_mode
+        return out
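Taken together with the __init__.py change above, the new transform composes with the rest of v2; a hedged end-to-end sketch using only names from this diff:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

boxes = tv_tensors.BoundingBoxes(
    torch.tensor([[-5.0, -5.0, 120.0, 120.0]]), format="XYXY", canvas_size=(100, 100)
)

# ClampBoundingBoxes defaults to "auto", deferring to the input boxes' own clamping_mode;
# SetClampingMode returns a clone with the attribute rewritten for downstream transforms.
pipeline = v2.Compose([v2.ClampBoundingBoxes(), v2.SetClampingMode("hard")])
out = pipeline(boxes)
assert out.clamping_mode == "hard"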
