Skip to content

Commit 40cecba

Browse files
authored
Start cleaning up type annotations (#24)
I started looking at this due to the obvious errors in the type annotations mentioned in #23. When I started this branch, running `mypy main.py` gave 150ish errors. Between this PR and viamrobotics/viam-python-sdk#846, it's down to 90ish errors. So, there's more to do, but this is at least a start. When there are no errors, we should add another CI check that running mypy (or any other type checker, if you want something else) has no errors. * add mypy to requirements.txt * add Optional type annotation to all variables that can be initialized to None * don't pass in floats when we expect ints * fix copypasta: default value of a boolean should be a boolean, not a string * add type annotations for tabulate module to requirements.txt * don't return an exception, raise it * clarify types of optional fields * fix type annotation Image -> Optional[Image.Image] * fix mistake from rebase conflict
1 parent 59841e5 commit 40cecba

File tree

5 files changed

+27
-16
lines changed

5 files changed

+27
-16
lines changed

requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ MarkupSafe==3.0.2
2020
mccabe==0.7.0
2121
mpmath==1.3.0
2222
multidict==6.1.0
23+
mypy==1.15.0
24+
mypy-extensions==1.0.0
2325
networkx==3.4.2
2426
numpy==1.26.4
2527
opencv-python==4.11.0.86
@@ -39,6 +41,7 @@ sympy==1.13.1
3941
tabulate==0.9.0
4042
tomli==2.2.1
4143
tomlkit==0.13.2
44+
types-tabulate==0.9.0.20241207
4245
typing_extensions==4.12.2
4346
viam-sdk==0.38.0
4447

src/config/attribute.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def __init__(
99
field_name: str,
1010
config: ServiceConfig,
1111
required: bool = False,
12-
default_value: Any = None,
12+
default_value: Optional[Any] = None,
1313
):
1414
self.field_name = field_name
1515
self.config = config
@@ -41,9 +41,9 @@ def __init__(
4141
field_name: str,
4242
config: ServiceConfig,
4343
required: bool = False,
44-
min_value: int = None,
45-
max_value: int = None,
46-
default_value: int = None,
44+
min_value: Optional[int] = None,
45+
max_value: Optional[int] = None,
46+
default_value: Optional[int] = None,
4747
):
4848
self.min_value = min_value
4949
self.max_value = max_value
@@ -78,9 +78,9 @@ def __init__(
7878
field_name: str,
7979
config: ServiceConfig,
8080
required: bool = False,
81-
min_value: float = None,
82-
max_value: float = None,
83-
default_value: float = None,
81+
min_value: Optional[float] = None,
82+
max_value: Optional[float] = None,
83+
default_value: Optional[float] = None,
8484
):
8585
self.min_value = min_value
8686
self.max_value = max_value
@@ -113,7 +113,7 @@ def __init__(
113113
config: "ServiceConfig",
114114
required: bool = False,
115115
allowlist: Optional[list] = None,
116-
default_value: str = None,
116+
default_value: Optional[str] = None,
117117
):
118118
self.allowlist = allowlist
119119
super().__init__(field_name, config, required, default_value)
@@ -148,7 +148,7 @@ def __init__(
148148
field_name: str,
149149
config: "ServiceConfig",
150150
required: bool = False,
151-
default_value: str = None,
151+
default_value: Optional[bool] = None,
152152
):
153153
super().__init__(field_name, config, required, default_value)
154154

src/config/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def __init__(self, config: "ServiceConfig"):
3434
field_name="max_age_track",
3535
config=config,
3636
min_value=0,
37-
max_value=1e5,
38-
default_value=1e3,
37+
max_value=100000,
38+
default_value=1000,
3939
)
4040
self.min_distance_threshold = FloatAttribute(
4141
field_name="min_distance_threshold",

src/image/image.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
from typing import Optional
23

34
import numpy as np
45
import torch
@@ -18,7 +19,12 @@ def get_tensor_from_np_array(np_array: np.ndarray) -> torch.Tensor:
1819

1920

2021
class ImageObject:
21-
def __init__(self, viam_image: ViamImage, pil_image: Image = None, device=None):
22+
def __init__(
23+
self,
24+
viam_image: ViamImage,
25+
pil_image: Optional[Image.Image] = None,
26+
device=None
27+
):
2228
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2329
if pil_image is not None:
2430
self.pil_image = pil_image

src/tracker/track.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import numpy as np
24
import torch
35
from viam.services.vision import Detection
@@ -31,11 +33,11 @@ def __init__(
3133

3234
self.label = label
3335

34-
self.label_from_reid = None
36+
self.label_from_reid: Optional[str] = None
3537
self.conf_from_reid = 0
3638

37-
self.label_from_faceid = None
38-
self.conf_from_faceid = None
39+
self.label_from_faceid: Optional[str] = None
40+
self.conf_from_faceid: Optional[str] = None
3941

4042
self.persistence: int = 0
4143
self.is_candidate: bool = is_candidate
@@ -144,7 +146,7 @@ def feature_distance(self, feature_vector):
144146
def get_detection(self, min_persistence=None) -> Detection:
145147
if self.is_candidate:
146148
if min_persistence is None:
147-
return ValueError(
149+
raise ValueError(
148150
"Need to pass persistence in argument to get track candidate"
149151
)
150152
class_name = self._get_label(min_persistence)

0 commit comments

Comments
 (0)