Skip to content

Commit 3750654

Browse files
authored
chore: Enhance type checking with mypy and improve code quality (#285)
* chore: Replace mypy with pyright for type checking and improve code quality * chore: address comments on the PR * chore: remove unused auth function
1 parent 9c361e3 commit 3750654

File tree

15 files changed

+65
-61
lines changed

15 files changed

+65
-61
lines changed

.dockerignore

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,20 @@
1+
# Python cache files
2+
__pycache__
3+
*.pyc
4+
5+
# IDE settings
6+
.vscode/
7+
8+
# Version control
9+
.git/
10+
11+
# Distribution / packaging
12+
build/
13+
dist/
14+
*.egg-info/
15+
16+
# Virtual environments
17+
.venv
18+
19+
# Testing
120
/tests/manual/data

.github/workflows/test.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,8 @@ jobs:
2727
run: |
2828
python -m pip install --upgrade pip
2929
pip install ".[dev]"
30-
- name: 🧹 Lint
30+
- name: 🧹 Check code quality
3131
run: |
3232
make check_code_quality
33-
- name: Check types with mypy
34-
run: mypy .
35-
- name: 🧪 Test
33+
- name: 🧪 Run tests
3634
run: "python -m unittest"

Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.PHONY: style check_code_quality
1+
.PHONY: style check_code_quality publish
22

33
export PYTHONPATH = .
44
check_dirs := roboflow
@@ -10,6 +10,7 @@ style:
1010
check_code_quality:
1111
ruff format $(check_dirs) --check
1212
ruff check $(check_dirs)
13+
mypy $(check_dirs)
1314

1415
publish:
1516
python setup.py sdist bdist_wheel

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ target-version = "py38"
1111
line-length = 120
1212

1313
[tool.ruff.lint]
14-
select = [
15-
"ALL",
16-
]
14+
select = ["ALL"]
1715
ignore = [
1816
"A",
1917
"ANN",

roboflow/__init__.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import sys
44
import time
55
from getpass import getpass
6+
from pathlib import Path
67
from urllib.parse import urlparse
78

89
import requests
@@ -59,27 +60,16 @@ def check_key(api_key, model, notebook, num_retries=0):
5960
return "onboarding"
6061

6162

62-
def auth(api_key):
63-
r = check_key(api_key)
64-
w = r["workspace"]
65-
66-
return Roboflow(api_key, w)
67-
68-
6963
def login(workspace=None, force=False):
7064
os_name = os.name
7165

7266
if os_name == "nt":
73-
default_path = os.path.join(os.getenv("USERPROFILE"), "roboflow/config.json")
67+
default_path = str(Path.home() / "roboflow" / "config.json")
7468
else:
75-
default_path = os.path.join(os.getenv("HOME"), ".config/roboflow/config.json")
69+
default_path = str(Path.home() / ".config" / "roboflow" / "config.json")
7670

7771
# default configuration location
78-
conf_location = os.getenv(
79-
"ROBOFLOW_CONFIG_DIR",
80-
default=default_path,
81-
)
82-
72+
conf_location = os.getenv("ROBOFLOW_CONFIG_DIR", default=default_path)
8373
if os.path.isfile(conf_location) and not force:
8474
write_line("You are already logged into Roboflow. To make a different login," "run roboflow.login(force=True).")
8575
return None
@@ -141,10 +131,7 @@ def initialize_roboflow(the_workspace=None):
141131

142132
global active_workspace
143133

144-
conf_location = os.getenv(
145-
"ROBOFLOW_CONFIG_DIR",
146-
default=os.getenv("HOME") + "/.config/roboflow/config.json",
147-
)
134+
conf_location = os.getenv("ROBOFLOW_CONFIG_DIR", default=str(Path.home() / ".config" / "roboflow" / "config.json"))
148135

149136
if not os.path.isfile(conf_location):
150137
raise RuntimeError("To use this method, you must first login - run roboflow.login()")
@@ -176,7 +163,7 @@ def load_model(model_url):
176163
project = path_parts[2]
177164
version = int(path_parts[-1])
178165
else:
179-
raise ("Model URL must be from either app.roboflow.com or universe.roboflow.com")
166+
raise ValueError("Model URL must be from either app.roboflow.com or universe.roboflow.com")
180167

181168
project = operate_workspace.project(project)
182169
version = project.version(version)
@@ -204,7 +191,7 @@ def download_dataset(dataset_url, model_format, location=None):
204191
version = int(path_parts[-1])
205192
the_workspace = path_parts[1]
206193
else:
207-
raise ("Model URL must be from either app.roboflow.com or universe.roboflow.com")
194+
raise ValueError("Model URL must be from either app.roboflow.com or universe.roboflow.com")
208195
operate_workspace = initialize_roboflow(the_workspace=the_workspace)
209196

210197
project = operate_workspace.project(project)
@@ -239,7 +226,7 @@ def auth(self):
239226
self.universe = True
240227
return self
241228
else:
242-
w = r["workspace"]
229+
w = r["workspace"] # type: ignore[arg-type]
243230
self.current_workspace = w
244231
return self
245232

roboflow/core/project.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def generate_version(self, settings):
230230
try:
231231
r_json = r.json()
232232
except Exception:
233-
raise "Error when requesting to generate a new version for project."
233+
raise RuntimeError("Error when requesting to generate a new version for project.")
234234

235235
# if the generation succeeds, return the version that is being generated
236236
if r.status_code == 200:
@@ -256,7 +256,7 @@ def train(
256256
speed=None,
257257
checkpoint=None,
258258
plot_in_notebook=False,
259-
) -> bool:
259+
):
260260
"""
261261
Ask the Roboflow API to train a previously exported version's dataset.
262262
@@ -503,7 +503,7 @@ def single_upload(
503503
sequence_size=sequence_size,
504504
**kwargs,
505505
)
506-
image_id = uploaded_image["id"]
506+
image_id = uploaded_image["id"] # type: ignore[index]
507507
upload_retry_attempts = retry.retries
508508
except BaseException as e:
509509
uploaded_image = {"error": e}
@@ -518,10 +518,10 @@ def single_upload(
518518
uploaded_annotation = rfapi.save_annotation(
519519
self.__api_key,
520520
project_url,
521-
annotation_name,
522-
annotation_str,
521+
annotation_name, # type: ignore[type-var]
522+
annotation_str, # type: ignore[type-var]
523523
image_id,
524-
job_name=batch_name,
524+
job_name=batch_name, # type: ignore[type-var]
525525
is_prediction=is_prediction,
526526
annotation_labelmap=annotation_labelmap,
527527
overwrite=annotation_overwrite,
@@ -543,10 +543,10 @@ def _annotation_params(self, annotation_path):
543543
if isinstance(annotation_path, dict) and annotation_path.get("rawText"):
544544
annotation_name = annotation_path["name"]
545545
annotation_string = annotation_path["rawText"]
546-
elif os.path.exists(annotation_path):
547-
with open(annotation_path):
548-
annotation_string = open(annotation_path).read()
549-
annotation_name = os.path.basename(annotation_path)
546+
elif os.path.exists(annotation_path): # type: ignore[arg-type]
547+
with open(annotation_path): # type: ignore[arg-type]
548+
annotation_string = open(annotation_path).read() # type: ignore[arg-type]
549+
annotation_name = os.path.basename(annotation_path) # type: ignore[arg-type]
550550
elif self.type == "classification":
551551
print(f"-> using {annotation_path} as classname for classification project")
552552
annotation_string = annotation_path

roboflow/core/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -766,7 +766,7 @@ def bar_progress(current, total, width=80):
766766

767767
# write the zip file to the desired location
768768
with open(location + "/roboflow.zip", "wb") as f:
769-
total_length = int(response.headers.get("content-length"))
769+
total_length = int(response.headers.get("content-length")) # type: ignore[arg-type]
770770
desc = None if TQDM_DISABLE else f"Downloading Dataset Version Zip in {location} to {format}:"
771771
for chunk in tqdm(
772772
response.iter_content(chunk_size=1024),

roboflow/core/workspace.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
import os
55
import sys
6-
from typing import List
6+
from typing import Any, List
77

88
import numpy as np
99
import requests
@@ -179,7 +179,7 @@ def two_stage(
179179
print(self.project(first_stage_model_name))
180180

181181
# perform first inference
182-
predictions = stage_one_model.predict(image)
182+
predictions = stage_one_model.predict(image) # type: ignore[attribute-error]
183183

184184
if stage_one_project.type == "object-detection" and stage_two_project == "classification":
185185
# interact with each detected object from stage one inference results
@@ -199,7 +199,7 @@ def two_stage(
199199
croppedImg.save("./temp.png")
200200

201201
# capture results of second stage inference from cropped image
202-
results.append(stage_two_model.predict("./temp.png")[0])
202+
results.append(stage_two_model.predict("./temp.png")[0]) # type: ignore[attribute-error]
203203

204204
# delete the written image artifact
205205
try:
@@ -244,7 +244,7 @@ def two_stage_ocr(
244244
stage_one_model = stage_one_project.version(first_stage_model_version).model
245245

246246
# perform first inference
247-
predictions = stage_one_model.predict(image)
247+
predictions = stage_one_model.predict(image) # type: ignore[attribute-error]
248248

249249
# interact with each detected object from stage one inference results
250250
if stage_one_project.type == "object-detection":
@@ -391,7 +391,7 @@ def active_learning(
391391
upload_destination: str = "",
392392
conditionals: dict = {},
393393
use_localhost: bool = False,
394-
) -> str:
394+
) -> Any:
395395
"""perform inference on each image in directory and upload based on conditions
396396
@params:
397397
raw_data_location: (str) = folder of frames to be processed
@@ -470,7 +470,7 @@ def active_learning(
470470
print(image2 + " --> similarity too high to --> " + image1)
471471
continue # skip this image if too similar or counter hits limit
472472

473-
predictions = inference_model.predict(image).json()["predictions"]
473+
predictions = inference_model.predict(image).json()["predictions"] # type: ignore[attribute-error]
474474
# collect all predictions to return to user at end
475475
prediction_results.append({"image": image, "predictions": predictions})
476476

roboflow/models/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(
6363
print(f"initalizing local classification model hosted at : {local}")
6464
self.base_url = local
6565

66-
def predict(self, image_path, hosted=False):
66+
def predict(self, image_path, hosted=False): # type: ignore[override]
6767
"""
6868
Run inference on an image.
6969

roboflow/models/inference.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,7 @@ def predict(self, image_path, prediction_type=None, **kwargs):
121121
params["api_key"] = self.__api_key
122122

123123
params.update(**kwargs)
124-
125-
url = f"{self.api_url}?{urllib.parse.urlencode(params)}"
124+
url = f"{self.api_url}?{urllib.parse.urlencode(params)}" # type: ignore[attr-defined]
126125
response = requests.post(url, **request_kwargs)
127126
response.raise_for_status()
128127

@@ -390,7 +389,7 @@ def download(self, format="pt", location="."):
390389

391390
# write the zip file to the desired location
392391
with open(location + "/weights.pt", "wb") as f:
393-
total_length = int(response.headers.get("content-length"))
392+
total_length = int(response.headers.get("content-length")) # type: ignore[arg-type]
394393
for chunk in tqdm(
395394
response.iter_content(chunk_size=1024),
396395
desc=f"Downloading weights to {location}/weights.pt",

0 commit comments

Comments
 (0)