11from __future__ import annotations
22
3- from typing import Any , Mapping , MutableSequence , Optional , Sequence , Tuple , TYPE_CHECKING , Union
3+ from typing import Any , Mapping , MutableSequence , Optional , Sequence , Tuple , Union
44
55import torch
66from torch .utils ._pytree import tree_flatten
@@ -40,13 +40,13 @@ class KeyPoints(TVTensor):
4040 ``data`` is a :class:`torch.Tensor`, the value is taken from it. Otherwise, defaults to ``False``.
4141 """
4242
43- canvas_size : Tuple [int , int ]
43+ canvas_size : tuple [int , int ]
4444
4545 def __new__ (
4646 cls ,
4747 data : Any ,
4848 * ,
49- canvas_size : Tuple [int , int ],
49+ canvas_size : tuple [int , int ],
5050 dtype : Optional [torch .dtype ] = None ,
5151 device : Optional [Union [torch .device , str , int ]] = None ,
5252 requires_grad : Optional [bool ] = None ,
@@ -60,21 +60,6 @@ def __new__(
6060 points .canvas_size = canvas_size
6161 return points
6262
63- if TYPE_CHECKING :
64- # EVIL: Just so that MYPY+PYLANCE+others stop shouting that everything is wrong when initializeing the TVTensor
65- # Not read or defined at Runtime (only at linting time).
66- # TODO: BOUNDING BOXES needs something similar
67- def __init__ (
68- self ,
69- data : Any ,
70- * ,
71- canvas_size : Tuple [int , int ],
72- dtype : Optional [torch .dtype ] = None ,
73- device : Optional [Union [torch .device , str , int ]] = None ,
74- requires_grad : Optional [bool ] = None ,
75- ):
76- pass
77-
7863 @classmethod
7964 def _wrap_output (
8065 cls ,
@@ -87,7 +72,7 @@ def _wrap_output(
8772 # For BoundingBoxes, that included format, but we only support one format here !
8873 flat_params , _ = tree_flatten (args + (tuple (kwargs .values ()) if kwargs else ())) # type: ignore[operator]
8974 first_bbox_from_args = next (x for x in flat_params if isinstance (x , KeyPoints ))
90- canvas_size : Tuple [ int , int ] = first_bbox_from_args .canvas_size
75+ canvas_size = first_bbox_from_args .canvas_size
9176
9277 if isinstance (output , torch .Tensor ) and not isinstance (output , KeyPoints ):
9378 output = KeyPoints (output , canvas_size = canvas_size )
0 commit comments