Skip to content

Commit 801e24d

Browse files
committed
get_all_keypoints is now get_keypoints and returns the only keypoints object in the sample (as is assumed)
1 parent b68b57b commit 801e24d

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

torchvision/transforms/v2/_utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import collections.abc
44
import numbers
5-
from collections.abc import Iterable, Sequence
5+
from collections.abc import Sequence
66
from contextlib import suppress
77

88
from typing import Any, Callable, Literal
@@ -165,18 +165,16 @@ def get_bounding_boxes(flat_inputs: list[Any]) -> tv_tensors.BoundingBoxes:
165165
raise ValueError("No bounding boxes were found in the sample")
166166

167167

168-
def get_all_keypoints(flat_inputs: list[Any]) -> Iterable[tv_tensors.KeyPoints]:
169-
"""Yields all KeyPoints in the input.
168+
def get_keypoints(flat_inputs: list[Any]) -> tv_tensors.KeyPoints:
169+
"""Returns the KeyPoints in the input.
170170
171-
Raises:
172-
ValueError: No KeyPoints can be found
171+
Assumes only one ``KeyPoints`` object is present
173172
"""
174173
generator = (inpt for inpt in flat_inputs if isinstance(inpt, tv_tensors.KeyPoints))
175174
try:
176-
yield next(generator)
175+
return next(generator)
177176
except StopIteration:
178177
raise ValueError("No Keypoints were found in the sample.")
179-
return generator
180178

181179

182180
def query_chw(flat_inputs: list[Any]) -> tuple[int, int, int]:

0 commit comments

Comments
 (0)