Skip to content

Commit 36b02dd

Browse files
Fix PR comments for rorated box transforms
1 parent e223c6f commit 36b02dd

File tree

11 files changed

+100
-123
lines changed

11 files changed

+100
-123
lines changed
1.9 KB
Loading

test/common_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -444,13 +444,13 @@ def sample_position(values, max_value):
444444
r_rad = r * torch.pi / 180.0
445445
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
446446
x1, y1 = x, y
447-
x3 = x1 + w * cos
448-
y3 = y1 - w * sin
449-
x2 = x3 + h * sin
450-
y2 = y3 + h * cos
447+
x2 = x1 + w * cos
448+
y2 = y1 - w * sin
449+
x3 = x2 + h * sin
450+
y3 = y2 + h * cos
451451
x4 = x1 + h * sin
452452
y4 = y1 + h * cos
453-
parts = (x1, y1, x3, y3, x2, y2, x4, y4)
453+
parts = (x1, y1, x2, y2, x3, y3, x4, y4)
454454
else:
455455
raise ValueError(f"Format {format} is not supported")
456456

test/test_transforms_v2.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -575,13 +575,13 @@ def affine_rotated_bounding_boxes(bounding_boxes):
575575
new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY,
576576
inplace=True,
577577
)
578-
x1, y1, x3, y3, x2, y2, x4, y4 = input_xyxyxyxy.squeeze(0).tolist()
578+
x1, y1, x2, y2, x3, y3, x4, y4 = input_xyxyxyxy.squeeze(0).tolist()
579579

580580
points = np.array(
581581
[
582582
[x1, y1, 1.0],
583-
[x3, y3, 1.0],
584583
[x2, y2, 1.0],
584+
[x3, y3, 1.0],
585585
[x4, y4, 1.0],
586586
]
587587
)
@@ -604,14 +604,14 @@ def affine_rotated_bounding_boxes(bounding_boxes):
604604
)
605605

606606
if clamp:
607-
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
607+
# It is important to clamp before casting, especially for CXCYWHR format, dtype=int64
608608
output = F.clamp_bounding_boxes(
609609
output,
610610
format=format,
611611
canvas_size=canvas_size,
612612
)
613613
else:
614-
# We leave the bounding box as float64 so the caller gets the full precision to perform any additional
614+
# We leave the bounding box as float32 so the caller gets the full precision to perform any additional
615615
# operation
616616
dtype = output.dtype
617617

@@ -1143,17 +1143,20 @@ def test_image_correctness(self, fn):
11431143

11441144
torch.testing.assert_close(actual, expected)
11451145

1146-
def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes, format):
1146+
def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
11471147
affine_matrix = np.array(
11481148
[
11491149
[-1, 0, bounding_boxes.canvas_size[1]],
11501150
[0, 1, 0],
11511151
],
11521152
)
11531153

1154-
if tv_tensors.is_rotated_bounding_format(format):
1155-
return reference_affine_rotated_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
1156-
return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
1154+
helper = (
1155+
reference_affine_rotated_bounding_boxes_helper
1156+
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
1157+
else reference_affine_bounding_boxes_helper
1158+
)
1159+
return helper(bounding_boxes, affine_matrix=affine_matrix)
11571160

11581161
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
11591162
@pytest.mark.parametrize(
@@ -1163,7 +1166,7 @@ def test_bounding_boxes_correctness(self, format, fn):
11631166
bounding_boxes = make_bounding_boxes(format=format)
11641167

11651168
actual = fn(bounding_boxes)
1166-
expected = self._reference_horizontal_flip_bounding_boxes(bounding_boxes, format)
1169+
expected = self._reference_horizontal_flip_bounding_boxes(bounding_boxes)
11671170

11681171
torch.testing.assert_close(actual, expected)
11691172

@@ -1595,25 +1598,28 @@ def test_image_correctness(self, fn):
15951598

15961599
torch.testing.assert_close(actual, expected)
15971600

1598-
def _reference_vertical_flip_bounding_boxes(self, bounding_boxes, format):
1601+
def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
15991602
affine_matrix = np.array(
16001603
[
16011604
[1, 0, 0],
16021605
[0, -1, bounding_boxes.canvas_size[0]],
16031606
],
16041607
)
16051608

1606-
if tv_tensors.is_rotated_bounding_format(format):
1607-
return reference_affine_rotated_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
1608-
return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
1609+
helper = (
1610+
reference_affine_rotated_bounding_boxes_helper
1611+
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
1612+
else reference_affine_bounding_boxes_helper
1613+
)
1614+
return helper(bounding_boxes, affine_matrix=affine_matrix)
16091615

16101616
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
16111617
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
16121618
def test_bounding_boxes_correctness(self, format, fn):
16131619
bounding_boxes = make_bounding_boxes(format=format)
16141620

16151621
actual = fn(bounding_boxes)
1616-
expected = self._reference_vertical_flip_bounding_boxes(bounding_boxes, format)
1622+
expected = self._reference_vertical_flip_bounding_boxes(bounding_boxes)
16171623

16181624
torch.testing.assert_close(actual, expected)
16191625

test/test_tv_tensors.py

Lines changed: 19 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -44,69 +44,31 @@ def test_bbox_instance(data, format):
4444

4545

4646
@pytest.mark.parametrize(
47-
"format",
47+
"format, is_rotated_expected",
4848
[
49-
"XYXY",
50-
"XYWH",
51-
"CXCYWH",
52-
"XYXYXYXY",
53-
"XYWHR",
54-
"CXCYWHR",
55-
tv_tensors.BoundingBoxFormat.XYXY,
56-
tv_tensors.BoundingBoxFormat.XYWH,
57-
tv_tensors.BoundingBoxFormat.CXCYWH,
58-
tv_tensors.BoundingBoxFormat.XYXYXYXY,
59-
tv_tensors.BoundingBoxFormat.XYWHR,
60-
tv_tensors.BoundingBoxFormat.CXCYWHR,
49+
("XYXY", False),
50+
("XYWH", False),
51+
("CXCYWH", False),
52+
("XYXYXYXY", True),
53+
("XYWHR", True),
54+
("CXCYWHR", True),
55+
(tv_tensors.BoundingBoxFormat.XYXY, False),
56+
(tv_tensors.BoundingBoxFormat.XYWH, False),
57+
(tv_tensors.BoundingBoxFormat.CXCYWH, False),
58+
(tv_tensors.BoundingBoxFormat.XYXYXYXY, True),
59+
(tv_tensors.BoundingBoxFormat.XYWHR, True),
60+
(tv_tensors.BoundingBoxFormat.CXCYWHR, True),
6161
],
6262
)
63-
def test_bbox_format(format):
63+
@pytest.mark.parametrize("scripted", (False, True))
64+
def test_bbox_format(format, is_rotated_expected, scripted):
6465
if isinstance(format, str):
6566
format = tv_tensors.BoundingBoxFormat[(format.upper())]
66-
if format == tv_tensors.BoundingBoxFormat.XYXYXYXY:
67-
assert tv_tensors.is_rotated_bounding_format(format) is True
68-
elif format == tv_tensors.BoundingBoxFormat.XYWHR:
69-
assert tv_tensors.is_rotated_bounding_format(format) is True
70-
elif format == tv_tensors.BoundingBoxFormat.CXCYWHR:
71-
assert tv_tensors.is_rotated_bounding_format(format) is True
72-
else:
73-
assert tv_tensors.is_rotated_bounding_format(format) is False
7467

75-
76-
@pytest.mark.parametrize(
77-
"format",
78-
[
79-
"XYXY",
80-
"XYWH",
81-
"CXCYWH",
82-
"XYXYXYXY",
83-
"XYWHR",
84-
"CXCYWHR",
85-
tv_tensors.BoundingBoxFormat.XYXY,
86-
tv_tensors.BoundingBoxFormat.XYWH,
87-
tv_tensors.BoundingBoxFormat.CXCYWH,
88-
tv_tensors.BoundingBoxFormat.XYXYXYXY,
89-
tv_tensors.BoundingBoxFormat.XYWHR,
90-
tv_tensors.BoundingBoxFormat.CXCYWHR,
91-
],
92-
)
93-
def test_bbox_format_scripted(format):
94-
obj = tv_tensors.is_rotated_bounding_format
95-
try:
96-
fn = torch.jit.script(obj)
97-
except Exception as error:
98-
name = getattr(obj, "__name__", obj.__class__.__name__)
99-
raise AssertionError(f"Trying to `torch.jit.script` `{name}` raised the error above.") from error
100-
if isinstance(format, str):
101-
format = tv_tensors.BoundingBoxFormat[(format.upper())]
102-
if format == tv_tensors.BoundingBoxFormat.XYXYXYXY:
103-
assert fn(format) is True
104-
elif format == tv_tensors.BoundingBoxFormat.XYWHR:
105-
assert fn(format) is True
106-
elif format == tv_tensors.BoundingBoxFormat.CXCYWHR:
107-
assert fn(format) is True
108-
else:
109-
assert fn(format) is False
68+
fn = tv_tensors.is_rotated_bounding_format
69+
if scripted:
70+
fn = torch.jit.script(fn)
71+
assert fn(format) == is_rotated_expected
11072

11173

11274
def test_bbox_dim_error():

test/test_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,14 @@
1717
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
1818

1919
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
20-
20+
rotated_boxes = torch.tensor(
21+
[
22+
[100, 150, 150, 150, 150, 250, 100, 250],
23+
[200, 350, 250, 350, 250, 250, 200, 250],
24+
[300, 200, 200, 200, 200, 250, 300, 250],
25+
],
26+
dtype=torch.float,
27+
)
2128
keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float)
2229

2330

@@ -148,6 +155,18 @@ def test_draw_boxes_with_coloured_label_backgrounds():
148155
assert_equal(result, expected)
149156

150157

158+
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
159+
def test_draw_rotatated_boxes():
160+
img = torch.full((3, 500, 500), 255, dtype=torch.uint8)
161+
colors = ["blue", "yellow", (0, 255, 0)]
162+
163+
result = utils.draw_bounding_boxes(img, rotated_boxes, colors=colors)
164+
expected = torch.as_tensor(np.array(result)).permute(2, 0, 1)
165+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_rotated_boxes.png")
166+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
167+
assert_equal(result, expected)
168+
169+
151170
@pytest.mark.parametrize("fill", [True, False])
152171
def test_draw_boxes_dtypes(fill):
153172
img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)

torchvision/ops/_box_convert.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -130,56 +130,56 @@ def _box_xywhr_to_cxcywhr(boxes: Tensor) -> Tensor:
130130

131131
def _box_xywhr_to_xyxyxyxy(boxes: Tensor) -> Tensor:
132132
"""
133-
Converts rotated bounding boxes from (x1, y1, w, h, r) format to (x1, y1, x3, y3, x2, y2, x4, y4) format.
133+
Converts rotated bounding boxes from (x1, y1, w, h, r) format to (x1, y1, x2, y2, x3, y3, x4, y4) format.
134134
(x1, y1) refer to top left of bounding box
135135
(w, h) are width and height of the rotated bounding box
136136
r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan
137137
138138
(x1, y1) refer to top left of rotated bounding box
139-
(x3, y3) refer to top right of rotated bounding box
140-
(x2, y2) refer to bottom right of rotated bounding box
139+
(x2, y2) refer to top right of rotated bounding box
140+
(x3, y3) refer to bottom right of rotated bounding box
141141
(x4, y4) refer to bottom left ofrotated bounding box
142142
Args:
143143
boxes (Tensor[N, 5]): rotated boxes in (cx, cy, w, h, r) format which will be converted.
144144
145145
Returns:
146-
boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x3, y3, x2, y2, x4, y4) format.
146+
boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x2, y2, x3, y3, x4, y4) format.
147147
"""
148148
x1, y1, w, h, r = boxes.unbind(-1)
149149
r_rad = r * torch.pi / 180.0
150150
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
151151

152-
x3 = x1 + w * cos
153-
y3 = y1 - w * sin
154-
x2 = x3 + h * sin
155-
y2 = y3 + h * cos
152+
x2 = x1 + w * cos
153+
y2 = y1 - w * sin
154+
x3 = x2 + h * sin
155+
y3 = y2 + h * cos
156156
x4 = x1 + h * sin
157157
y4 = y1 + h * cos
158158

159-
return torch.stack((x1, y1, x3, y3, x2, y2, x4, y4), dim=-1)
159+
return torch.stack((x1, y1, x2, y2, x3, y3, x4, y4), dim=-1)
160160

161161

162162
def _box_xyxyxyxy_to_xywhr(boxes: Tensor) -> Tensor:
163163
"""
164-
Converts rotated bounding boxes from (x1, y1, x3, y3, x2, y2, x4, y4) format to (x1, y1, w, h, r) format.
164+
Converts rotated bounding boxes from (x1, y1, x2, y2, x3, y3, x4, y4) format to (x1, y1, w, h, r) format.
165165
(x1, y1) refer to top left of the rotated bounding box
166-
(x3, y3) refer to bottom left of the rotated bounding box
167-
(x2, y2) refer to bottom right of the rotated bounding box
166+
(x2, y2) refer to bottom left of the rotated bounding box
167+
(x3, y3) refer to bottom right of the rotated bounding box
168168
(x4, y4) refer to top right of the rotated bounding box
169169
(w, h) refers to width and height of rotated bounding box
170170
r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan
171171
172172
Args:
173-
boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x3, y3, x2, y2, x4, y4) format.
173+
boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x2, y2, x3, y3, x4, y4) format.
174174
175175
Returns:
176176
boxes (Tensor[N, 5]): rotated boxes in (x1, y1, w, h, r) format.
177177
"""
178-
x1, y1, x3, y3, x2, y2, x4, y4 = boxes.unbind(-1)
179-
r_rad = torch.atan2(y1 - y3, x3 - x1)
178+
x1, y1, x2, y2, x3, y3, x4, y4 = boxes.unbind(-1)
179+
r_rad = torch.atan2(y1 - y2, x2 - x1)
180180
r = r_rad * 180 / torch.pi
181181

182-
w = ((x3 - x1) ** 2 + (y1 - y3) ** 2).sqrt()
182+
w = ((x2 - x1) ** 2 + (y1 - y2) ** 2).sqrt()
183183
h = ((x3 - x2) ** 2 + (y3 - y2) ** 2).sqrt()
184184

185185
boxes = torch.stack((x1, y1, w, h, r), dim=-1)

torchvision/ops/boxes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
209209
being width and height.
210210
r is rotation angle w.r.t to the box center by :math:`|r|` degrees counter clock wise in the image plan
211211
212-
``'xyxyxyxy'``: boxes are represented via corners, x1, y1 being top left, x2, y2 bottom right,
213-
x3, y3 bottom left, and x4, y4 top right.
212+
``'xyxyxyxy'``: boxes are represented via corners, x1, y1 being top left, x2, y2 top right,
213+
x3, y3 bottom right, and x4, y4 bottom left.
214214
215215
Args:
216216
boxes (Tensor[N, K]): boxes which will be converted. K is the number of coordinates (4 for unrotated bounding boxes, 5 or 8 for rotated bounding boxes)

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def horizontal_flip_bounding_boxes(
9494
dtype = bounding_boxes.dtype
9595
if not torch.is_floating_point(bounding_boxes):
9696
# Casting to float to support cos and sin computations.
97-
bounding_boxes = bounding_boxes.to(torch.float64)
97+
bounding_boxes = bounding_boxes.to(torch.float32)
9898
angle_rad = bounding_boxes[:, 4].mul(torch.pi).div(180)
9999
bounding_boxes[:, 0].add_(bounding_boxes[:, 2].mul(angle_rad.cos())).sub_(canvas_size[1]).neg_()
100100
bounding_boxes[:, 1].sub_(bounding_boxes[:, 2].mul(angle_rad.sin()))

torchvision/transforms/v2/functional/_meta.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -227,13 +227,13 @@ def _xywhr_to_xyxyxyxy(xywhr: torch.Tensor, inplace: bool) -> torch.Tensor:
227227
r_rad = xywhr[..., 4].mul(torch.pi).div(180.0)
228228
cos, sin = r_rad.cos(), r_rad.sin()
229229
xywhr = xywhr[..., :2].tile((1, 4))
230-
# x1 + w * cos = x3
230+
# x1 + w * cos = x2
231231
xywhr[..., 2].add_(wh[..., 0].mul(cos))
232-
# y1 - w * sin = y3
232+
# y1 - w * sin = y2
233233
xywhr[..., 3].sub_(wh[..., 0].mul(sin))
234-
# x1 + w * cos + h * sin = x2
234+
# x1 + w * cos + h * sin = x3
235235
xywhr[..., 4].add_(wh[..., 0].mul(cos).add(wh[..., 1].mul(sin)))
236-
# y1 - w * sin + h * cos = y2
236+
# y1 - w * sin + h * cos = y3
237237
xywhr[..., 5].sub_(wh[..., 0].mul(sin).sub(wh[..., 1].mul(cos)))
238238
# x1 + h * sin = x4
239239
xywhr[..., 6].add_(wh[..., 1].mul(sin))
@@ -252,12 +252,12 @@ def _xyxyxyxy_to_xywhr(xyxyxyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
252252
xyxyxyxy = xyxyxyxy.float()
253253

254254
r_rad = torch.atan2(xyxyxyxy[..., 1].sub(xyxyxyxy[..., 3]), xyxyxyxy[..., 2].sub(xyxyxyxy[..., 0]))
255-
# x1, y1, (x3 - x1), (y3 - y1), (x2 - x3), (y2 - y3) x4, y4
255+
# x1, y1, (x2 - x1), (y2 - y1), (x3 - x2), (y3 - y2) x4, y4
256256
xyxyxyxy[..., 4:6].sub_(xyxyxyxy[..., 2:4])
257257
xyxyxyxy[..., 2:4].sub_(xyxyxyxy[..., :2])
258-
# sqrt((x3 - x1) ** 2 + (y1 - y3) ** 2) = w
258+
# sqrt((x2 - x1) ** 2 + (y1 - y2) ** 2) = w
259259
xyxyxyxy[..., 2] = xyxyxyxy[..., 2].pow(2).add(xyxyxyxy[..., 3].pow(2)).sqrt()
260-
# sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) = h
260+
# sqrt((x2 - x3) ** 2 + (y2 - y3) ** 2) = h
261261
xyxyxyxy[..., 3] = xyxyxyxy[..., 4].pow(2).add(xyxyxyxy[..., 5].pow(2)).sqrt()
262262
xyxyxyxy[..., 4] = r_rad.div_(torch.pi).mul_(180.0)
263263
return xyxyxyxy[..., :5].to(dtype)

torchvision/tv_tensors/_bounding_boxes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ class BoundingBoxFormat(Enum):
2626
cy being center of box, w, h being width and height. r is rotation angle
2727
in degrees.
2828
* ``XYXYXYXY``: rotated boxes represented via corners, x1, y1 being top
29-
left, x2, y2 being bottom right, x3, y3 being bottom left, x4, y4 being
30-
top right.
29+
left, x2, y2 being top right, x3, y3 being bottom right, x4, y4 being
30+
bottom left.
3131
"""
3232

3333
XYXY = "XYXY"

0 commit comments

Comments
 (0)