Skip to content

Commit 3620fda

Browse files
authored
Ban module level imports for cv2, matplotlib and numpy (#292)
1 parent d4a72d9 commit 3620fda

File tree

7 files changed

+55
-17
lines changed

7 files changed

+55
-17
lines changed

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ exclude = [
7575
"tests/manual/debugme.py", # file is intentionally broken
7676
]
7777

78-
7978
# Allow unused variables when underscore-prefixed.
8079
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
8180

@@ -92,6 +91,13 @@ convention = "google"
9291
# Preserve types, even if a file imports `from __future__ import annotations`.
9392
keep-runtime-typing = true
9493

94+
[tool.ruff.lint.flake8-tidy-imports]
95+
banned-module-level-imports = [
96+
"cv2",
97+
"matplotlib",
98+
"numpy",
99+
]
100+
95101
[tool.mypy]
96102
python_version = "3.8"
97103
exclude = ["^build/"]

roboflow/core/version.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import copy
24
import json
35
import os
@@ -6,9 +8,8 @@
68
import time
79
import zipfile
810
from importlib import import_module
9-
from typing import Optional, Union
11+
from typing import TYPE_CHECKING, Optional, Union
1012

11-
import numpy as np
1213
import requests
1314
import yaml
1415
from dotenv import load_dotenv
@@ -28,7 +29,6 @@
2829
)
2930
from roboflow.core.dataset import Dataset
3031
from roboflow.models.classification import ClassificationModel
31-
from roboflow.models.inference import InferenceModel
3232
from roboflow.models.instance_segmentation import InstanceSegmentationModel
3333
from roboflow.models.keypoint_detection import KeypointDetectionModel
3434
from roboflow.models.object_detection import ObjectDetectionModel
@@ -37,6 +37,11 @@
3737
from roboflow.util.general import write_line
3838
from roboflow.util.versions import get_wrong_dependencies_versions, print_warn_for_wrong_dependencies_versions
3939

40+
if TYPE_CHECKING:
41+
import numpy as np
42+
43+
from roboflow.models.inference import InferenceModel
44+
4045
load_dotenv()
4146

4247

@@ -401,6 +406,8 @@ def live_plot(epochs, mAP, loss, title=""):
401406
loss: Union[np.ndarray, list]
402407

403408
if "roboflow-train" in models.keys():
409+
import numpy as np
410+
404411
# training has started
405412
epochs = np.array([int(epoch["epoch"]) for epoch in models["roboflow-train"]["epochs"]])
406413
mAP = np.array([float(epoch["mAP"]) for epoch in models["roboflow-train"]["epochs"]])

roboflow/core/workspace.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
import sys
66
from typing import Any, List
77

8-
import numpy as np
98
import requests
10-
from numpy import ndarray
119
from PIL import Image
1210

1311
from roboflow.adapters import rfapi
@@ -407,6 +405,8 @@ def active_learning(
407405
use_localhost: (bool) = determines if local http format used or remote endpoint
408406
local_server: (str) = local http address for inference server, use_localhost must be True for this to be used
409407
""" # noqa: E501 // docs
408+
import numpy as np
409+
410410
prediction_results = []
411411

412412
# ensure that all fields of conditionals have a key:value pair
@@ -528,7 +528,9 @@ def active_learning(
528528

529529
# return predictions with filenames if globbed images from dir,
530530
# otherwise return latest prediction result
531-
return prediction_results if type(raw_data_location) is not ndarray else prediction_results[-1]["predictions"]
531+
return (
532+
prediction_results if type(raw_data_location) is not np.ndarray else prediction_results[-1]["predictions"]
533+
)
532534

533535
def __str__(self):
534536
projects = self.projects()

roboflow/models/object_detection.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
import random
77
import urllib
88

9-
import cv2
10-
import numpy as np
119
import requests
1210
from PIL import Image
1311

@@ -178,6 +176,9 @@ def predict( # type: ignore[override]
178176
original_dimensions = None
179177
# If image is local image
180178
if not hosted:
179+
import cv2
180+
import numpy as np
181+
181182
if isinstance(image_path, str):
182183
image = Image.open(image_path).convert("RGB")
183184
dimensions = image.size
@@ -294,6 +295,7 @@ def webcam(
294295
stroke (int): Stroke width for bounding box
295296
labels (bool): Whether to show labels on bounding box
296297
""" # noqa: E501 // docs
298+
import cv2
297299

298300
os.environ["OPENCV_VIDEOIO_PRIORITY_MSMF"] = "0"
299301

roboflow/util/image_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import os
44
import urllib
55

6-
import cv2
7-
import numpy as np
86
import requests
97
import yaml
108
from PIL import Image
@@ -40,6 +38,9 @@ def mask_image(image, encoded_mask, transparency=60):
4038
:param transparency: alpha transparency of masks for semantic overlays
4139
:returns: CV2 image / numpy.ndarray matrix
4240
"""
41+
import cv2
42+
import numpy as np
43+
4344
np_data = np.fromstring(base64.b64decode(encoded_mask), np.uint8) # type: ignore[no-overload]
4445
mask = cv2.imdecode(np_data, cv2.IMREAD_UNCHANGED)
4546

@@ -71,6 +72,8 @@ def validate_image_path(image_path):
7172

7273

7374
def file2jpeg(image_path):
75+
import cv2
76+
7477
img = cv2.imread(image_path)
7578
image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
7679
pilImage = Image.fromarray(image)

roboflow/util/prediction.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,7 @@
44
import urllib.request
55
import warnings
66

7-
import cv2
8-
import matplotlib.image as mpimg
9-
import matplotlib.pyplot as plt
10-
import numpy as np
117
import requests
12-
from matplotlib import patches
138
from PIL import Image
149

1510
from roboflow.config import (
@@ -29,6 +24,8 @@ def plot_image(image_path):
2924
:param image_path: path of image to be plotted (can be hosted or local)
3025
:return:
3126
"""
27+
import matplotlib.pyplot as plt
28+
3229
validate_image_path(image_path)
3330
try:
3431
img = Image.open(image_path)
@@ -52,6 +49,8 @@ def plot_annotation(axes, prediction=None, stroke=1, transparency=60, colors=Non
5249
:param transparency: alpha transparency of masks for semantic overlays
5350
:return:
5451
"""
52+
from matplotlib import patches
53+
5554
# Object Detection annotation
5655

5756
colors = {} if colors is None else colors
@@ -88,6 +87,8 @@ def plot_annotation(axes, prediction=None, stroke=1, transparency=60, colors=Non
8887
polygon = patches.Polygon(points, linewidth=stroke, edgecolor=stroke_color, facecolor="none")
8988
axes.add_patch(polygon)
9089
elif prediction["prediction_type"] == SEMANTIC_SEGMENTATION_MODEL:
90+
import matplotlib.image as mpimg
91+
9192
encoded_mask = prediction["segmentation_mask"]
9293
mask_bytes = io.BytesIO(base64.b64decode(encoded_mask))
9394
mask = mpimg.imread(mask_bytes, format="JPG")
@@ -121,6 +122,9 @@ def json(self):
121122
return self.json_prediction
122123

123124
def __load_image(self):
125+
import cv2
126+
import numpy as np
127+
124128
if "http://" in self.image_path:
125129
req = urllib.request.urlopen(self.image_path)
126130
arr = np.asarray(bytearray(req.read()), dtype=np.uint8)
@@ -131,6 +135,8 @@ def __load_image(self):
131135
return cv2.imread(self.image_path)
132136

133137
def plot(self, stroke=1):
138+
import matplotlib.pyplot as plt
139+
134140
# Exception to check if image path exists
135141
validate_image_path(self["image_path"])
136142
_, axes = plot_image(self["image_path"])
@@ -146,6 +152,9 @@ def save(self, output_path="predictions.jpg", stroke=2, transparency=60):
146152
:param stroke: line width to use when drawing rectangles and polygons
147153
:param transparency: alpha transparency of masks for semantic overlays
148154
"""
155+
import cv2
156+
import numpy as np
157+
149158
image = self.__load_image()
150159
stroke_color = (255, 0, 0)
151160

@@ -302,6 +311,8 @@ def add_prediction(self, prediction=None):
302311
self.predictions.append(prediction)
303312

304313
def plot(self, stroke=1):
314+
import matplotlib.pyplot as plt
315+
305316
if len(self) > 0:
306317
validate_image_path(self.base_image_path)
307318
_, axes = plot_image(self.base_image_path)
@@ -311,6 +322,9 @@ def plot(self, stroke=1):
311322
plt.show()
312323

313324
def __load_image(self):
325+
import cv2
326+
import numpy as np
327+
314328
# Check if it is a hosted image and open image as needed
315329
if "http://" in self.base_image_path or "https://" in self.base_image_path:
316330
req = urllib.request.urlopen(self.base_image_path)
@@ -322,6 +336,9 @@ def __load_image(self):
322336
return cv2.imread(self.base_image_path)
323337

324338
def save(self, output_path="predictions.jpg", stroke=2):
339+
import cv2
340+
import numpy as np
341+
325342
# Load image based on image path as an array
326343
image = self.__load_image()
327344
stroke_color = (255, 0, 0)

tests/models/test_object_detection.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import unittest
22

3-
import numpy as np
43
import responses
54
from PIL import UnidentifiedImageError
65
from requests.exceptions import HTTPError
@@ -83,6 +82,8 @@ def test_predict_with_local_image_request(self):
8382

8483
@responses.activate
8584
def test_predict_with_a_numpy_array_request(self):
85+
import numpy as np
86+
8687
np_array = np.ones((100, 100, 1), dtype=np.uint8)
8788
instance = ObjectDetectionModel(self.api_key, self.version_id, version=self.version)
8889

0 commit comments

Comments
 (0)