Commit e9e901d

lint
1 parent 4368999 commit e9e901d

6 files changed: +101 additions, -38 deletions
test/test_transforms_v2.py

Lines changed: 12 additions & 8 deletions
@@ -661,7 +661,8 @@ def affine_keypoints(keypoints):
 
     return tv_tensors.KeyPoints(
         torch.cat([affine_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape(keypoints.shape),
-        canvas_size=canvas_size, clamping_mode=clamping_mode
+        canvas_size=canvas_size,
+        clamping_mode=clamping_mode,
     )
 
 
@@ -5362,11 +5363,9 @@ def perspective_keypoints(keypoints):
         )
 
         # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
-        return F.clamp_keypoints(
-            output,
-            canvas_size=canvas_size,
-            clamping_mode=clamping_mode
-        ).to(dtype=dtype, device=device)
+        return F.clamp_keypoints(output, canvas_size=canvas_size, clamping_mode=clamping_mode).to(
+            dtype=dtype, device=device
+        )
 
     return tv_tensors.KeyPoints(
         torch.cat([perspective_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape(
@@ -5791,9 +5790,14 @@ def test_clamping_mode(self, constructor_clamping_mode, clamping_mode, pass_pure
             return
 
         keypoints = tv_tensors.KeyPoints(
-            [[0, 100], [0, 100]],canvas_size=(10, 10), clamping_mode=constructor_clamping_mode
+            [[0, 100], [0, 100]], canvas_size=(10, 10), clamping_mode=constructor_clamping_mode
+        )
+        expected_clamped_output = (
+            torch.tensor([[0, 10], [0, 10]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]])
+        )
+        expected_clamped_output = (
+            torch.tensor([[0, 9], [0, 9]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]])
         )
-        expected_clamped_output = torch.tensor([[0, 9], [0, 9]]) if clamping_mode == "hard" else torch.tensor([[0, 100], [0, 100]])
 
         if pass_pure_tensor:
             out = fn(
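
For context, here is a minimal sketch of the behavior the test above asserts. With "hard" clamping a KeyPoints coordinate is clamped to the valid pixel range (at most canvas_size - 1 along each axis), while non-"hard" modes leave out-of-canvas coordinates untouched. The sample values mirror the test; everything else is an assumption, not part of the diff:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    # Two keypoints on a 10x10 canvas; the second coordinate (100) lies far outside it.
    kp = tv_tensors.KeyPoints([[0, 100], [0, 100]], canvas_size=(10, 10), clamping_mode="hard")

    # "hard" clamps to valid pixel indices: expected tensor([[0, 9], [0, 9]]) per the test.
    print(F.clamp_keypoints(kp, clamping_mode="hard"))

    # Any non-"hard" mode (e.g. "soft") leaves keypoints untouched: expected tensor([[0, 100], [0, 100]]).
    print(F.clamp_keypoints(kp, clamping_mode="soft"))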

torchvision/transforms/v2/_meta.py

Lines changed: 8 additions & 7 deletions
@@ -46,14 +46,15 @@ def transform(self, inpt: tv_tensors.BoundingBoxes, params: dict[str, Any]) -> t
 class ClampKeyPoints(Transform):
     """Clamp keypoints to their corresponding image dimensions.
 
-    Args:
-        clamping_mode: Default is "auto" which relies on the input keypoint'
-            ``clamping_mode`` attribute.
-            The clamping is done according to the keypoints' ``canvas_size`` meta-data.
-            Read more in :ref:`clamping_mode_tuto`
-            for more details on how to use this transform.
-
+    Args:
+        clamping_mode: Default is "auto" which relies on the input keypoint'
+            ``clamping_mode`` attribute.
+            The clamping is done according to the keypoints' ``canvas_size`` meta-data.
+            Read more in :ref:`clamping_mode_tuto`
+            for more details on how to use this transform.
+
     """
+
     def __init__(self, clamping_mode: Union[CLAMPING_MODE_TYPE, str] = "auto") -> None:
         super().__init__()
         self.clamping_mode = clamping_mode
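
For readers skimming the diff, a minimal usage sketch of the ClampKeyPoints transform documented above (the sample coordinates are made up; the behavior follows the docstring and the test expectations elsewhere in this commit):

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import ClampKeyPoints

    kp = tv_tensors.KeyPoints([[3, 42], [7, 5]], canvas_size=(10, 10), clamping_mode="hard")

    # The default clamping_mode="auto" defers to the input's clamping_mode attribute
    # ("hard" here), so the out-of-canvas value 42 is clamped to canvas_size - 1 = 9.
    clamped = ClampKeyPoints()(kp)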

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 67 additions & 16 deletions
@@ -67,7 +67,9 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
     return horizontal_flip_image(mask)
 
 
-def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft"):
+def horizontal_flip_keypoints(
+    keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft"
+):
     shape = keypoints.shape
     keypoints = keypoints.clone().reshape(-1, 2)
     keypoints[..., 0] = keypoints[..., 0].sub_(canvas_size[1] - 1).neg_()
@@ -76,7 +78,9 @@ def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, i
 
 @_register_kernel_internal(horizontal_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
 def _horizontal_flip_keypoints_dispatch(keypoints: tv_tensors.KeyPoints):
-    out = horizontal_flip_keypoints(keypoints.as_subclass(torch.Tensor), canvas_size=keypoints.canvas_size, clamping_mode=keypoints.clamping_mode)
+    out = horizontal_flip_keypoints(
+        keypoints.as_subclass(torch.Tensor), canvas_size=keypoints.canvas_size, clamping_mode=keypoints.clamping_mode
+    )
     return tv_tensors.wrap(out, like=keypoints)
 

@@ -155,7 +159,11 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
     return vertical_flip_image(mask)
 
 
-def vertical_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE = "soft",) -> torch.Tensor:
+def vertical_flip_keypoints(
+    keypoints: torch.Tensor,
+    canvas_size: tuple[int, int],
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
+) -> torch.Tensor:
     shape = keypoints.shape
     keypoints = keypoints.clone().reshape(-1, 2)
     keypoints[..., 1] = keypoints[..., 1].sub_(canvas_size[0] - 1).neg_()
@@ -199,7 +207,9 @@ def vertical_flip_bounding_boxes(
 
 @_register_kernel_internal(vertical_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
 def _vertical_flip_keypoints_dispatch(inpt: tv_tensors.KeyPoints) -> tv_tensors.KeyPoints:
-    output = vertical_flip_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=inpt.clamping_mode)
+    output = vertical_flip_keypoints(
+        inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=inpt.clamping_mode
+    )
     return tv_tensors.wrap(output, like=inpt)
 

@@ -1027,7 +1037,9 @@ def _affine_keypoints_with_expand(
         new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
         canvas_size = (new_height, new_width)
 
-    out_keypoints = clamp_keypoints(transformed_points, canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(original_shape)
+    out_keypoints = clamp_keypoints(transformed_points, canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(
+        original_shape
+    )
     out_keypoints = out_keypoints.to(original_dtype)
 
     return out_keypoints, canvas_size
@@ -1417,7 +1429,12 @@ def _rotate_keypoints_dispatch(
     inpt: tv_tensors.KeyPoints, angle: float, expand: bool = False, center: Optional[list[float]] = None, **kwargs
 ) -> tv_tensors.KeyPoints:
     output, canvas_size = rotate_keypoints(
-        inpt, canvas_size=inpt.canvas_size, angle=angle, center=center, expand=expand, clamping_mode=inpt.clamping_mode,
+        inpt,
+        canvas_size=inpt.canvas_size,
+        angle=angle,
+        center=center,
+        expand=expand,
+        clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
 

@@ -1689,7 +1706,11 @@ def pad_mask(
 
 
 def pad_keypoints(
-    keypoints: torch.Tensor, canvas_size: tuple[int, int], padding: list[int], padding_mode: str = "constant", clamping_mode: CLAMPING_MODE_TYPE = "soft"
+    keypoints: torch.Tensor,
+    canvas_size: tuple[int, int],
+    padding: list[int],
+    padding_mode: str = "constant",
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ):
     SUPPORTED_MODES = ["constant"]
     if padding_mode not in SUPPORTED_MODES:
@@ -1832,7 +1853,9 @@ def crop_keypoints(
 def _crop_keypoints_dispatch(
     inpt: tv_tensors.KeyPoints, top: int, left: int, height: int, width: int
 ) -> tv_tensors.KeyPoints:
-    output, canvas_size = crop_keypoints(inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, clamping_mode=inpt.clamping_mode)
+    output, canvas_size = crop_keypoints(
+        inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, clamping_mode=inpt.clamping_mode
+    )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
 

@@ -2056,7 +2079,9 @@ def perspective_keypoints(
     numer_points = torch.matmul(points, theta1.T)
     denom_points = torch.matmul(points, theta2.T)
     transformed_points = numer_points.div_(denom_points)
-    return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(original_shape)
+    return clamp_keypoints(
+        transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode
+    ).reshape(original_shape)
 
 
 @_register_kernel_internal(perspective, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -2354,7 +2379,10 @@ def _create_identity_grid(size: tuple[int, int], device: torch.device, dtype: to
 
 
 def elastic_keypoints(
-    keypoints: torch.Tensor, canvas_size: tuple[int, int], displacement: torch.Tensor, clamping_mode: CLAMPING_MODE_TYPE = "soft",
+    keypoints: torch.Tensor,
+    canvas_size: tuple[int, int],
+    displacement: torch.Tensor,
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
 ) -> torch.Tensor:
     expected_shape = (1, canvas_size[0], canvas_size[1], 2)
     if not isinstance(displacement, torch.Tensor):
@@ -2386,12 +2414,19 @@ def elastic_keypoints(
     t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
     transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)
 
-    return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode).reshape(original_shape)
+    return clamp_keypoints(
+        transformed_points.to(keypoints.dtype), canvas_size=canvas_size, clamping_mode=clamping_mode
+    ).reshape(original_shape)
 
 
 @_register_kernel_internal(elastic, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
 def _elastic_keypoints_dispatch(inpt: tv_tensors.KeyPoints, displacement: torch.Tensor, **kwargs):
-    output = elastic_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, displacement=displacement, clamping_mode=inpt.clamping_mode)
+    output = elastic_keypoints(
+        inpt.as_subclass(torch.Tensor),
+        canvas_size=inpt.canvas_size,
+        displacement=displacement,
+        clamping_mode=inpt.clamping_mode,
+    )
     return tv_tensors.wrap(output, like=inpt)
 

@@ -2588,16 +2623,26 @@ def _center_crop_image_pil(image: PIL.Image.Image, output_size: list[int]) -> PI
     return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)
 
 
-def center_crop_keypoints(inpt: torch.Tensor, canvas_size: tuple[int, int], output_size: list[int], clamping_mode: CLAMPING_MODE_TYPE = "soft",):
+def center_crop_keypoints(
+    inpt: torch.Tensor,
+    canvas_size: tuple[int, int],
+    output_size: list[int],
+    clamping_mode: CLAMPING_MODE_TYPE = "soft",
+):
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
-    return crop_keypoints(inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width, clamping_mode=clamping_mode)
+    return crop_keypoints(
+        inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width, clamping_mode=clamping_mode
+    )
 
 
 @_register_kernel_internal(center_crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
 def _center_crop_keypoints_dispatch(inpt: tv_tensors.KeyPoints, output_size: list[int]) -> tv_tensors.KeyPoints:
     output, canvas_size = center_crop_keypoints(
-        inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, output_size=output_size, clamping_mode=inpt.clamping_mode,
+        inpt.as_subclass(torch.Tensor),
+        canvas_size=inpt.canvas_size,
+        output_size=output_size,
+        clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
 
@@ -2766,7 +2811,13 @@ def _resized_crop_keypoints_dispatch(
     inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: list[int], **kwargs
 ):
     output, canvas_size = resized_crop_keypoints(
-        inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size, clamping_mode=inpt.clamping_mode,
+        inpt.as_subclass(torch.Tensor),
+        top=top,
+        left=left,
+        height=height,
+        width=width,
+        size=size,
+        clamping_mode=inpt.clamping_mode,
     )
     return tv_tensors.wrap(output, like=inpt, canvas_size=canvas_size)
 
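
As a quick illustration of the pattern used by the dispatchers above (read canvas_size and clamping_mode off the input KeyPoints, forward them to the kernel, then re-wrap the result), here is a minimal sketch; the sample coordinates are made up and the expected result just follows the flip formula x -> (width - 1) - x shown in horizontal_flip_keypoints:

    import torch
    from torchvision import tv_tensors
    from torchvision.transforms.v2 import functional as F

    kp = tv_tensors.KeyPoints([[1, 2], [8, 5]], canvas_size=(10, 10), clamping_mode="hard")

    # The registered KeyPoints kernel flips x as (canvas width - 1) - x, so on a
    # 10-pixel-wide canvas [[1, 2], [8, 5]] becomes [[8, 2], [1, 5]].
    flipped = F.horizontal_flip(kp)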

torchvision/transforms/v2/functional/_meta.py

Lines changed: 7 additions & 4 deletions
@@ -4,8 +4,7 @@
 import torch
 from torchvision import tv_tensors
 from torchvision.transforms import _functional_pil as _FP
-from torchvision.tv_tensors import BoundingBoxFormat
-from torchvision.tv_tensors import CLAMPING_MODE_TYPE
+from torchvision.tv_tensors import BoundingBoxFormat, CLAMPING_MODE_TYPE
 
 from torchvision.utils import _log_api_usage_once
 
@@ -653,7 +652,9 @@ def clamp_bounding_boxes(
     )
 
 
-def _clamp_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE) -> torch.Tensor:
+def _clamp_keypoints(
+    keypoints: torch.Tensor, canvas_size: tuple[int, int], clamping_mode: CLAMPING_MODE_TYPE
+) -> torch.Tensor:
     if clamping_mode is None or clamping_mode != "hard":
         return keypoints.clone()
     dtype = keypoints.dtype
@@ -687,7 +688,9 @@ def clamp_keypoints(
             raise ValueError("For keypoints tv_tensor inputs, `canvas_size` must not be passed.")
        if clamping_mode is None and clamping_mode == "auto":
            clamping_mode = inpt.clamping_mode
-       output = _clamp_keypoints(inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=clamping_mode)
+       output = _clamp_keypoints(
+           inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, clamping_mode=clamping_mode
+       )
        return tv_tensors.wrap(output, like=inpt)
    else:
        raise TypeError(f"Input can either be a plain tensor or a keypoints tv_tensor, but got {type(inpt)} instead.")

torchvision/tv_tensors/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 import torch
 
-from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format, CLAMPING_MODE_TYPE
+from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, CLAMPING_MODE_TYPE, is_rotated_bounding_format
 from ._image import Image
 from ._keypoints import KeyPoints
 from ._mask import Mask

torchvision/tv_tensors/_keypoints.py

Lines changed: 6 additions & 2 deletions

@@ -5,9 +5,10 @@
 import torch
 from torch.utils._pytree import tree_flatten
 
-from ._tv_tensor import TVTensor
 from ._bounding_boxes import CLAMPING_MODE_TYPE
 
+from ._tv_tensor import TVTensor
+
 
 class KeyPoints(TVTensor):
     """:class:`torch.Tensor` subclass for tensors with shape ``[..., 2]`` that represent points in an image.
@@ -103,7 +104,10 @@ def _wrap_output(
             output = KeyPoints._wrap(output, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False)
         elif isinstance(output, (tuple, list)):
             # This branch exists for chunk() and unbind()
-            output = type(output)(KeyPoints._wrap(part, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False) for part in output)
+            output = type(output)(
+                KeyPoints._wrap(part, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False)
+                for part in output
+            )
         return output
 
     def __repr__(self, *, tensor_contents: Any = None) -> str:  # type: ignore[override]
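
For reference, a small sketch of the behavior that the _wrap_output branch above handles: tensor ops returning tuples or lists, such as unbind(), should yield parts that are still KeyPoints and keep the canvas_size and clamping_mode metadata. The sample values are illustrative only:

    import torch
    from torchvision import tv_tensors

    kp = tv_tensors.KeyPoints([[1, 2], [3, 4]], canvas_size=(10, 10), clamping_mode="hard")

    # unbind() returns a tuple; each part is re-wrapped as KeyPoints with the same metadata.
    for part in kp.unbind():
        print(type(part).__name__, part.canvas_size, part.clamping_mode)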
