Skip to content

Commit d4b130d

Browse files
committed
Add more tests
1 parent 0db21e0 commit d4b130d

File tree

3 files changed

+88
-34
lines changed

3 files changed

+88
-34
lines changed

test/test_transforms_v2.py

Lines changed: 86 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -654,14 +654,8 @@ def affine_keypoints(keypoints):
654654
)
655655

656656
if clamp:
657-
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
658-
output = F.clamp_keypoints(
659-
output,
660-
canvas_size=canvas_size,
661-
)
657+
output = F.clamp_keypoints(output, canvas_size=canvas_size)
662658
else:
663-
# We leave the bounding box as float64 so the caller gets the full precision to perform any additional
664-
# operation
665659
dtype = output.dtype
666660

667661
return output.to(dtype=dtype, device=device)
@@ -803,7 +797,15 @@ def test_kernel_video(self):
803797
@pytest.mark.parametrize("size", OUTPUT_SIZES)
804798
@pytest.mark.parametrize(
805799
"make_input",
806-
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
800+
[
801+
make_image_tensor,
802+
make_image_pil,
803+
make_image,
804+
make_bounding_boxes,
805+
make_segmentation_mask,
806+
make_video,
807+
make_keypoints,
808+
],
807809
)
808810
def test_functional(self, size, make_input):
809811
max_size_kwarg = self._make_max_size_kwarg(use_max_size=size is None, size=size)
@@ -844,6 +846,7 @@ def test_functional_signature(self, kernel, input_type):
844846
make_segmentation_mask,
845847
make_detection_masks,
846848
make_video,
849+
make_keypoints,
847850
],
848851
)
849852
def test_transform(self, size, device, make_input):
@@ -901,6 +904,22 @@ def _reference_resize_bounding_boxes(self, bounding_boxes, *, size, max_size=Non
901904
new_canvas_size=(new_height, new_width),
902905
)
903906

907+
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
908+
@pytest.mark.parametrize("size", OUTPUT_SIZES)
909+
@pytest.mark.parametrize("use_max_size", [True, False])
910+
@pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
911+
def test_bounding_boxes_correctness(self, format, size, use_max_size, fn):
912+
if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
913+
return
914+
915+
bounding_boxes = make_bounding_boxes(format=format, canvas_size=self.INPUT_SIZE)
916+
917+
actual = fn(bounding_boxes, size=size, **max_size_kwarg)
918+
expected = self._reference_resize_bounding_boxes(bounding_boxes, size=size, **max_size_kwarg)
919+
920+
self._check_output_size(bounding_boxes, actual, size=size, **max_size_kwarg)
921+
torch.testing.assert_close(actual, expected)
922+
904923
def _reference_resize_keypoints(self, keypoints, *, size, max_size=None):
905924
old_height, old_width = keypoints.canvas_size
906925
new_height, new_width = self._compute_output_size(
@@ -923,22 +942,6 @@ def _reference_resize_keypoints(self, keypoints, *, size, max_size=None):
923942
new_canvas_size=(new_height, new_width),
924943
)
925944

926-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
927-
@pytest.mark.parametrize("size", OUTPUT_SIZES)
928-
@pytest.mark.parametrize("use_max_size", [True, False])
929-
@pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
930-
def test_bounding_boxes_correctness(self, format, size, use_max_size, fn):
931-
if not (max_size_kwarg := self._make_max_size_kwarg(use_max_size=use_max_size, size=size)):
932-
return
933-
934-
bounding_boxes = make_bounding_boxes(format=format, canvas_size=self.INPUT_SIZE)
935-
936-
actual = fn(bounding_boxes, size=size, **max_size_kwarg)
937-
expected = self._reference_resize_bounding_boxes(bounding_boxes, size=size, **max_size_kwarg)
938-
939-
self._check_output_size(bounding_boxes, actual, size=size, **max_size_kwarg)
940-
torch.testing.assert_close(actual, expected)
941-
942945
@pytest.mark.parametrize("size", OUTPUT_SIZES)
943946
@pytest.mark.parametrize("use_max_size", [True, False])
944947
@pytest.mark.parametrize("fn", [F.resize, transform_cls_to_functional(transforms.Resize)])
@@ -989,6 +992,7 @@ def test_functional_pil_antialias_warning(self):
989992
make_segmentation_mask,
990993
make_detection_masks,
991994
make_video,
995+
make_keypoints,
992996
],
993997
)
994998
def test_max_size_error(self, size, make_input):
@@ -1031,6 +1035,7 @@ def test_max_size_error(self, size, make_input):
10311035
make_segmentation_mask,
10321036
make_detection_masks,
10331037
make_video,
1038+
make_keypoints,
10341039
],
10351040
)
10361041
def test_resize_size_none(self, input_size, max_size, expected_size, make_input):
@@ -1076,6 +1081,7 @@ def test_transform_unknown_size_error(self):
10761081
make_segmentation_mask,
10771082
make_detection_masks,
10781083
make_video,
1084+
make_keypoints,
10791085
],
10801086
)
10811087
def test_noop(self, size, make_input):
@@ -1103,6 +1109,7 @@ def test_noop(self, size, make_input):
11031109
make_segmentation_mask,
11041110
make_detection_masks,
11051111
make_video,
1112+
make_keypoints,
11061113
],
11071114
)
11081115
def test_no_regression_5405(self, make_input):
@@ -1215,7 +1222,15 @@ def test_kernel_video(self):
12151222

12161223
@pytest.mark.parametrize(
12171224
"make_input",
1218-
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
1225+
[
1226+
make_image_tensor,
1227+
make_image_pil,
1228+
make_image,
1229+
make_bounding_boxes,
1230+
make_segmentation_mask,
1231+
make_video,
1232+
make_keypoints,
1233+
],
12191234
)
12201235
def test_functional(self, make_input):
12211236
check_functional(F.horizontal_flip, make_input())
@@ -1237,7 +1252,15 @@ def test_functional_signature(self, kernel, input_type):
12371252

12381253
@pytest.mark.parametrize(
12391254
"make_input",
1240-
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
1255+
[
1256+
make_image_tensor,
1257+
make_image_pil,
1258+
make_image,
1259+
make_bounding_boxes,
1260+
make_segmentation_mask,
1261+
make_video,
1262+
make_keypoints,
1263+
],
12411264
)
12421265
@pytest.mark.parametrize("device", cpu_and_cuda())
12431266
def test_transform(self, make_input, device):
@@ -1304,7 +1327,15 @@ def test_keypoints_correctness(self, fn):
13041327

13051328
@pytest.mark.parametrize(
13061329
"make_input",
1307-
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
1330+
[
1331+
make_image_tensor,
1332+
make_image_pil,
1333+
make_image,
1334+
make_bounding_boxes,
1335+
make_segmentation_mask,
1336+
make_video,
1337+
make_keypoints,
1338+
],
13081339
)
13091340
@pytest.mark.parametrize("device", cpu_and_cuda())
13101341
def test_transform_noop(self, make_input, device):
@@ -1778,7 +1809,15 @@ def test_kernel_video(self):
17781809

17791810
@pytest.mark.parametrize(
17801811
"make_input",
1781-
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
1812+
[
1813+
make_image_tensor,
1814+
make_image_pil,
1815+
make_image,
1816+
make_bounding_boxes,
1817+
make_segmentation_mask,
1818+
make_video,
1819+
make_keypoints,
1820+
],
17821821
)
17831822
def test_functional(self, make_input):
17841823
check_functional(F.vertical_flip, make_input())
@@ -1800,7 +1839,15 @@ def test_functional_signature(self, kernel, input_type):
18001839

18011840
@pytest.mark.parametrize(
18021841
"make_input",
1803-
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
1842+
[
1843+
make_image_tensor,
1844+
make_image_pil,
1845+
make_image,
1846+
make_bounding_boxes,
1847+
make_segmentation_mask,
1848+
make_video,
1849+
make_keypoints,
1850+
],
18041851
)
18051852
@pytest.mark.parametrize("device", cpu_and_cuda())
18061853
def test_transform(self, make_input, device):
@@ -1861,7 +1908,15 @@ def test_keypoints_correctness(self, fn):
18611908

18621909
@pytest.mark.parametrize(
18631910
"make_input",
1864-
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
1911+
[
1912+
make_image_tensor,
1913+
make_image_pil,
1914+
make_image,
1915+
make_bounding_boxes,
1916+
make_segmentation_mask,
1917+
make_video,
1918+
make_keypoints,
1919+
],
18651920
)
18661921
@pytest.mark.parametrize("device", cpu_and_cuda())
18671922
def test_transform_noop(self, make_input, device):
@@ -6975,9 +7030,7 @@ def test_convert_bounding_boxes_to_points(self, boxes: tv_tensors.BoundingBoxes)
69757030
intermediate_format = tv_tensors.BoundingBoxFormat.XYXY
69767031

69777032
reconverted_bbox = F.convert_bounding_box_format(
6978-
tv_tensors.BoundingBoxes(
6979-
reconverted, format=intermediate_format, canvas_size=kp.canvas_size
6980-
),
7033+
tv_tensors.BoundingBoxes(reconverted, format=intermediate_format, canvas_size=kp.canvas_size),
69817034
new_format=boxes.format,
69827035
)
69837036
assert_equal(reconverted_bbox, boxes, atol=1e-5, rtol=0)

torchvision/transforms/v2/functional/_meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,11 +462,11 @@ def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> t
462462
return keypoints.to(dtype=dtype)
463463

464464

465+
# TODOKP there is no corresponding transform and this isn't tested
465466
def clamp_keypoints(
466467
inpt: torch.Tensor,
467468
canvas_size: Optional[tuple[int, int]] = None,
468469
) -> torch.Tensor:
469-
"""See :func:`~torchvision.transforms.v2.ClampKeyPoints` for details."""
470470
if not torch.jit.is_scripting():
471471
_log_api_usage_once(clamp_keypoints)
472472

torchvision/transforms/v2/functional/_misc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def sanitize_keypoints(
346346
.. note::
347347
348348
Points that touch the edge of the canvas are removed, unlike for :func:`sanitize_bounding_boxes`.
349+
TODOKP Is this desirable? We probably want keypoints to behave the same as bboxes?
349350
350351
Raises:
351352
ValueError: If the keypoints are not passed as a two dimensional tensor.

0 commit comments

Comments
 (0)