Skip to content

Commit 1cc3b6f

Browse files
committed
Fixed _geometry.py after a botched merge request
1 parent fcfd597 commit 1cc3b6f

File tree

1 file changed

+74
-74
lines changed

1 file changed

+74
-74
lines changed

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 74 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,15 @@ def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
6666
return horizontal_flip_image(mask)
6767

6868

69-
def horizontal_flip_keypoints(kp: torch.Tensor, canvas_size: Tuple[int, int]):
70-
kp[..., 0] = kp[..., 0].sub_(canvas_size[1]).neg_()
71-
return kp
69+
def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int]):
70+
keypoints[..., 0] = keypoints[..., 0].sub_(canvas_size[1]).neg_()
71+
return keypoints
7272

7373

7474
@_register_kernel_internal(horizontal_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
75-
def _horizontal_flip_keypoints_dispatch(kp: tv_tensors.KeyPoints):
76-
out = horizontal_flip_keypoints(kp.as_subclass(torch.Tensor), canvas_size=kp.canvas_size)
77-
return tv_tensors.wrap(out, like=kp)
75+
def _horizontal_flip_keypoints_dispatch(keypoints: tv_tensors.KeyPoints):
76+
out = horizontal_flip_keypoints(keypoints.as_subclass(torch.Tensor), canvas_size=keypoints.canvas_size)
77+
return tv_tensors.wrap(out, like=keypoints)
7878

7979

8080
def horizontal_flip_bounding_boxes(
@@ -135,9 +135,9 @@ def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
135135

136136

137137
@_register_kernel_internal(vertical_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
138-
def vertical_flip_keypoints(kp: tv_tensors.KeyPoints):
139-
kp[..., 1] = kp[..., 1].sub_(kp.canvas_size[0]).neg_()
140-
return kp
138+
def vertical_flip_keypoints(keypoints: tv_tensors.KeyPoints):
139+
keypoints[..., 1] = keypoints[..., 1].sub_(keypoints.canvas_size[0]).neg_()
140+
return keypoints
141141

142142

143143
def vertical_flip_bounding_boxes(
@@ -352,9 +352,9 @@ def _resize_mask_dispatch(
352352

353353

354354
def resize_keypoints(
355-
kp: torch.Tensor,
356-
size: Optional[List[int]],
357-
canvas_size: Tuple[int, int],
355+
keypoints: torch.Tensor,
356+
size: Optional[list[int]],
357+
canvas_size: tuple[int, int],
358358
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
359359
max_size: Optional[int] = None,
360360
antialias: Optional[bool] = True,
@@ -364,29 +364,29 @@ def resize_keypoints(
364364

365365
w_ratio = new_width / old_width
366366
h_ratio = new_height / old_height
367-
ratios = torch.tensor([w_ratio, h_ratio], device=kp.device)
368-
kp = kp.mul(ratios).to(kp.dtype)
367+
ratios = torch.tensor([w_ratio, h_ratio], device=keypoints.device)
368+
keypoints = keypoints.mul(ratios).to(keypoints.dtype)
369369

370-
return kp, (new_height, new_width)
370+
return keypoints, (new_height, new_width)
371371

372372

373373
@_register_kernel_internal(resize, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
374374
def _resize_keypoints_dispatch(
375-
kp: tv_tensors.KeyPoints,
376-
size: Optional[List[int]],
375+
keypoints: tv_tensors.KeyPoints,
376+
size: Optional[list[int]],
377377
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
378378
max_size: Optional[int] = None,
379379
antialias: Optional[bool] = True,
380380
) -> tv_tensors.KeyPoints:
381381
out, canvas_size = resize_keypoints(
382-
kp.as_subclass(torch.Tensor),
382+
keypoints.as_subclass(torch.Tensor),
383383
size,
384-
canvas_size=kp.canvas_size,
384+
canvas_size=keypoints.canvas_size,
385385
interpolation=interpolation,
386386
max_size=max_size,
387387
antialias=antialias,
388388
)
389-
return tv_tensors.wrap(out, like=kp, canvas_size=canvas_size)
389+
return tv_tensors.wrap(out, like=keypoints, canvas_size=canvas_size)
390390

391391

392392
def resize_bounding_boxes(
@@ -816,14 +816,14 @@ def _affine_image_pil(
816816

817817
def _affine_keypoints_with_expand(
818818
keypoints: torch.Tensor,
819-
canvas_size: Tuple[int, int],
819+
canvas_size: tuple[int, int],
820820
angle: Union[int, float],
821-
translate: List[float],
821+
translate: list[float],
822822
scale: float,
823-
shear: List[float],
824-
center: Optional[List[float]] = None,
823+
shear: list[float],
824+
center: Optional[list[float]] = None,
825825
expand: bool = False,
826-
) -> Tuple[torch.Tensor, Tuple[int, int]]:
826+
) -> tuple[torch.Tensor, tuple[int, int]]:
827827
if keypoints.numel() == 0:
828828
return keypoints, canvas_size
829829

@@ -860,12 +860,12 @@ def _affine_keypoints_with_expand(
860860

861861
def affine_keypoints(
862862
keypoints: torch.Tensor,
863-
canvas_size: Tuple[int, int],
863+
canvas_size: tuple[int, int],
864864
angle: Union[int, float],
865-
translate: List[float],
865+
translate: list[float],
866866
scale: float,
867-
shear: List[float],
868-
center: Optional[List[float]] = None,
867+
shear: list[float],
868+
center: Optional[list[float]] = None,
869869
):
870870
return _affine_keypoints_with_expand(
871871
keypoints=keypoints,
@@ -883,10 +883,10 @@ def affine_keypoints(
883883
def _affine_keypoints_dispatch(
884884
inpt: tv_tensors.KeyPoints,
885885
angle: Union[int, float],
886-
translate: List[float],
886+
translate: list[float],
887887
scale: float,
888-
shear: List[float],
889-
center: Optional[List[float]] = None,
888+
shear: list[float],
889+
center: Optional[list[float]] = None,
890890
**kwargs,
891891
) -> tv_tensors.KeyPoints:
892892
output, canvas_size = affine_keypoints(
@@ -1203,9 +1203,9 @@ def rotate_keypoints(
12031203
angle: float,
12041204
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
12051205
expand: bool = False,
1206-
center: Optional[List[float]] = None,
1206+
center: Optional[list[float]] = None,
12071207
fill: _FillTypeJIT = None,
1208-
) -> Tuple[torch.Tensor, Tuple[int, int]]:
1208+
) -> tuple[torch.Tensor, tuple[int, int]]:
12091209
return _affine_keypoints_with_expand(
12101210
keypoints=keypoints.as_subclass(torch.Tensor),
12111211
canvas_size=keypoints.canvas_size,
@@ -1220,10 +1220,10 @@ def rotate_keypoints(
12201220

12211221
@_register_kernel_internal(rotate, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
12221222
def _rotate_keypoints_dispatch(
1223-
kp: tv_tensors.KeyPoints, angle: float, expand: bool = False, center: Optional[List[float]] = None, **kwargs
1223+
keypoints: tv_tensors.KeyPoints, angle: float, expand: bool = False, center: Optional[list[float]] = None, **kwargs
12241224
) -> tv_tensors.KeyPoints:
1225-
out, canvas_size = rotate_keypoints(kp, angle, center=center, expand=expand, **kwargs)
1226-
return tv_tensors.wrap(out, like=kp, canvas_size=canvas_size)
1225+
out, canvas_size = rotate_keypoints(keypoints, angle, center=center, expand=expand, **kwargs)
1226+
return tv_tensors.wrap(out, like=keypoints, canvas_size=canvas_size)
12271227

12281228

12291229
def rotate_bounding_boxes(
@@ -1490,7 +1490,7 @@ def pad_mask(
14901490

14911491

14921492
def pad_keypoints(
1493-
keypoints: torch.Tensor, canvas_size: Tuple[int, int], padding: List[int], padding_mode: str = "constant"
1493+
keypoints: torch.Tensor, canvas_size: tuple[int, int], padding: list[int], padding_mode: str = "constant"
14941494
):
14951495
SUPPORTED_MODES = ["constant"]
14961496
if padding_mode not in SUPPORTED_MODES:
@@ -1507,7 +1507,7 @@ def pad_keypoints(
15071507

15081508
@_register_kernel_internal(pad, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
15091509
def _pad_keypoints_dispatch(
1510-
keypoints: tv_tensors.KeyPoints, padding: List[int], padding_mode: str = "constant", **kwargs
1510+
keypoints: tv_tensors.KeyPoints, padding: list[int], padding_mode: str = "constant", **kwargs
15111511
) -> tv_tensors.KeyPoints:
15121512
output, canvas_size = pad_keypoints(
15131513
keypoints.as_subclass(torch.Tensor),
@@ -1605,17 +1605,17 @@ def crop_image(image: torch.Tensor, top: int, left: int, height: int, width: int
16051605

16061606

16071607
def crop_keypoints(
1608-
kp: torch.Tensor,
1608+
keypoints: torch.Tensor,
16091609
top: int,
16101610
left: int,
16111611
height: int,
16121612
width: int,
1613-
) -> Tuple[torch.Tensor, Tuple[int, int]]:
1613+
) -> tuple[torch.Tensor, tuple[int, int]]:
16141614

1615-
kp.sub_(torch.tensor([left, top], dtype=kp.dtype, device=kp.device))
1615+
keypoints.sub_(torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device))
16161616
canvas_size = (height, width)
16171617

1618-
return clamp_keypoints(kp, canvas_size=canvas_size), canvas_size
1618+
return clamp_keypoints(keypoints, canvas_size=canvas_size), canvas_size
16191619

16201620

16211621
@_register_kernel_internal(crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -1800,16 +1800,16 @@ def _perspective_image_pil(
18001800

18011801

18021802
def perspectice_keypoints(
1803-
kp: torch.Tensor,
1804-
canvas_size: Tuple[int, int],
1805-
startpoints: Optional[List[List[int]]],
1806-
endpoints: Optional[List[List[int]]],
1807-
coefficients: Optional[List[float]] = None,
1803+
keypoints: torch.Tensor,
1804+
canvas_size: tuple[int, int],
1805+
startpoints: Optional[list[list[int]]],
1806+
endpoints: Optional[list[list[int]]],
1807+
coefficients: Optional[list[float]] = None,
18081808
):
1809-
if kp.numel() == 0:
1810-
return kp
1811-
dtype = kp.dtype if torch.is_floating_point(kp) else torch.float32
1812-
device = kp.device
1809+
if keypoints.numel() == 0:
1810+
return keypoints
1811+
dtype = keypoints.dtype if torch.is_floating_point(keypoints) else torch.float32
1812+
device = keypoints.device
18131813

18141814
perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
18151815

@@ -1821,20 +1821,20 @@ def perspectice_keypoints(
18211821
)
18221822

18231823
theta1, theta2 = _compute_perspective_thetas(perspective_coeffs, dtype, device, denom)
1824-
kp = torch.cat([kp, torch.ones(kp.shape[0], 1, device=kp.device)], dim=-1)
1824+
keypoints = torch.cat([keypoints, torch.ones(keypoints.shape[0], 1, device=keypoints.device)], dim=-1)
18251825

1826-
numer_points = torch.matmul(kp, theta1.T)
1827-
denom_points = torch.matmul(kp, theta2.T)
1826+
numer_points = torch.matmul(keypoints, theta1.T)
1827+
denom_points = torch.matmul(keypoints, theta2.T)
18281828
transformed_points = numer_points.div_(denom_points)
18291829
return clamp_keypoints(transformed_points, canvas_size)
18301830

18311831

18321832
@_register_kernel_internal(perspective, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
18331833
def _perspective_keypoints_dispatch(
18341834
inpt: tv_tensors.BoundingBoxes,
1835-
startpoints: Optional[List[List[int]]],
1836-
endpoints: Optional[List[List[int]]],
1837-
coefficients: Optional[List[float]] = None,
1835+
startpoints: Optional[list[list[int]]],
1836+
endpoints: Optional[list[list[int]]],
1837+
coefficients: Optional[list[float]] = None,
18381838
**kwargs,
18391839
) -> tv_tensors.BoundingBoxes:
18401840
output = perspectice_keypoints(
@@ -1923,11 +1923,11 @@ def perspective_bounding_boxes(
19231923

19241924

19251925
def _compute_perspective_thetas(
1926-
perspective_coeffs: List[float],
1926+
perspective_coeffs: list[float],
19271927
dtype: torch.dtype,
19281928
device: torch.device,
19291929
denom: float,
1930-
) -> Tuple[torch.Tensor, torch.Tensor]:
1930+
) -> tuple[torch.Tensor, torch.Tensor]:
19311931
inv_coeffs = [
19321932
(perspective_coeffs[4] - perspective_coeffs[5] * perspective_coeffs[7]) / denom,
19331933
(-perspective_coeffs[1] + perspective_coeffs[2] * perspective_coeffs[7]) / denom,
@@ -2112,26 +2112,26 @@ def _create_identity_grid(size: tuple[int, int], device: torch.device, dtype: to
21122112
return base_grid
21132113

21142114

2115-
def elastic_keypoints(kp: torch.Tensor, canvas_size: Tuple[int, int], displacement: torch.Tensor) -> torch.Tensor:
2115+
def elastic_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int], displacement: torch.Tensor) -> torch.Tensor:
21162116
expected_shape = (1, canvas_size[0], canvas_size[1], 2)
21172117
if not isinstance(displacement, torch.Tensor):
21182118
raise TypeError("Argument displacement should be a Tensor")
21192119
elif displacement.shape != expected_shape:
21202120
raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")
21212121

2122-
if kp.numel() == 0:
2123-
return kp
2122+
if keypoints.numel() == 0:
2123+
return keypoints
21242124

2125-
device = kp.device
2126-
dtype = kp.dtype if torch.is_floating_point(kp) else torch.float32
2125+
device = keypoints.device
2126+
dtype = keypoints.dtype if torch.is_floating_point(keypoints) else torch.float32
21272127

21282128
if displacement.dtype != dtype or displacement.device != device:
21292129
displacement = displacement.to(dtype=dtype, device=device)
21302130

21312131
id_grid = _create_identity_grid(canvas_size, device=device, dtype=dtype)
21322132
inv_grid = id_grid.sub_(displacement)
21332133

2134-
index_xy = kp.to(dtype=torch.long)
2134+
index_xy = keypoints.to(dtype=torch.long)
21352135
index_x, index_y = index_xy[:, 0], index_xy[:, 1]
21362136
# Unlike bounding boxes, this may not work well.
21372137
index_x.clamp_(0, inv_grid.shape[2] - 1)
@@ -2329,14 +2329,14 @@ def _center_crop_image_pil(image: PIL.Image.Image, output_size: list[int]) -> PI
23292329
return _crop_image_pil(image, crop_top, crop_left, crop_height, crop_width)
23302330

23312331

2332-
def center_crop_keypoints(inpt: torch.Tensor, canvas_size: Tuple[int, int], output_size: List[int]):
2332+
def center_crop_keypoints(inpt: torch.Tensor, canvas_size: tuple[int, int], output_size: list[int]):
23332333
crop_height, crop_width = _center_crop_parse_output_size(output_size)
23342334
crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *canvas_size)
23352335
return crop_keypoints(inpt, top=crop_top, left=crop_left, height=crop_height, width=crop_width)
23362336

23372337

23382338
@_register_kernel_internal(center_crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
2339-
def _center_crop_keypoints_dispatch(inpt: tv_tensors.KeyPoints, output_size: List[int]) -> tv_tensors.KeyPoints:
2339+
def _center_crop_keypoints_dispatch(inpt: tv_tensors.KeyPoints, output_size: list[int]) -> tv_tensors.KeyPoints:
23402340
output, canvas_size = center_crop_keypoints(
23412341
inpt.as_subclass(torch.Tensor), canvas_size=inpt.canvas_size, output_size=output_size
23422342
)
@@ -2479,20 +2479,20 @@ def _resized_crop_image_pil_dispatch(
24792479

24802480

24812481
def resized_crop_keypoints(
2482-
kp: torch.Tensor,
2482+
keypoints: torch.Tensor,
24832483
top: int,
24842484
left: int,
24852485
height: int,
24862486
width: int,
2487-
size: List[int],
2488-
) -> Tuple[torch.Tensor, Tuple[int, int]]:
2489-
kp, canvas_size = crop_keypoints(kp, top, left, height, width)
2490-
return resize_keypoints(kp, size=size, canvas_size=canvas_size)
2487+
size: list[int],
2488+
) -> tuple[torch.Tensor, tuple[int, int]]:
2489+
keypoints, canvas_size = crop_keypoints(keypoints, top, left, height, width)
2490+
return resize_keypoints(keypoints, size=size, canvas_size=canvas_size)
24912491

24922492

24932493
@_register_kernel_internal(resized_crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
24942494
def _resized_crop_keypoints_dispatch(
2495-
inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: List[int], **kwargs
2495+
inpt: tv_tensors.BoundingBoxes, top: int, left: int, height: int, width: int, size: list[int], **kwargs
24962496
):
24972497
output, canvas_size = resized_crop_keypoints(
24982498
inpt.as_subclass(torch.Tensor), top=top, left=left, height=height, width=width, size=size

0 commit comments

Comments (0)