Skip to content

Commit 4dad9d7

Browse files
Merge pull request #102 from roboflow/better-errors-for-wrong-deps-versions
[PoC] Warning for specific versions
2 parents 821bdc0 + 01f706f commit 4dad9d7

File tree

5 files changed

+127
-22
lines changed

5 files changed

+127
-22
lines changed

roboflow/core/version.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@
2525
from roboflow.models.instance_segmentation import InstanceSegmentationModel
2626
from roboflow.models.object_detection import ObjectDetectionModel
2727
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
28+
from roboflow.util.versions import (
29+
print_warn_for_wrong_dependencies_versions,
30+
warn_for_wrong_dependencies_versions,
31+
)
2832

2933
load_dotenv()
3034

@@ -153,6 +157,10 @@ def download(self, model_format=None, location=None, overwrite: bool = True):
153157

154158
self.__wait_if_generating()
155159

160+
if model_format == "yolov8":
161+
# we assume the user will want to use yolov8, for now we only support ultralytics=="8.11.0"
162+
print_warn_for_wrong_dependencies_versions(["ultralytics", "<=", "8.0.20"])
163+
156164
model_format = self.__get_format_identifier(model_format)
157165

158166
if model_format not in self.exports:
@@ -284,6 +292,7 @@ def train(self, speed=None, checkpoint=None) -> bool:
284292

285293
return True
286294

295+
@warn_for_wrong_dependencies_versions(["ultralytics", "<=", "8.0.20"])
287296
def deploy(self, model_type: str, model_path: str) -> None:
288297
"""Uploads provided weights file to Roboflow
289298

roboflow/util/versions.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from importlib import import_module
2+
from typing import List, Tuple
3+
4+
5+
def get_wrong_dependencies_versions(
6+
dependencies_versions: List[Tuple[str, str, str]]
7+
) -> List[Tuple[str, str, str, str]]:
8+
"""
9+
Get a list of missmatching dependencies with current version installed.
10+
E.g., assuming we pass `get_wrong_dependencies_versions([("torch", "==", "1.2.0")]), we will check if the current version of `torch` is `==1.2.0`. If not, we will return `[("torch", "==", "1.2.0", "<current_installed_version>")]
11+
12+
We support `<=`, `==`, `>=`
13+
14+
Args:
15+
dependencies_versions (List[Tuple[str, str]]): List of dependencies we want to check, [("<package_name>", "<version_number_to_check")]
16+
17+
Returns:
18+
List[Tuple[str, str, str]]: List of dependencies with wrong version, [("<package_name>", "<version_number_to_check", "<current_version>")]
19+
"""
20+
wrong_dependencies_versions = []
21+
# from e.g. "1.12.0" -> 1120
22+
parse_version_to_int = lambda x: int(x.replace(".", ""))
23+
order_funcs = {
24+
"==": lambda x, y: x == y,
25+
">=": lambda x, y: x >= y,
26+
"<=": lambda x, y: x <= y,
27+
}
28+
for dependency, order, version in dependencies_versions:
29+
module = import_module(dependency)
30+
module_version = module.__version__
31+
if order not in order_funcs:
32+
raise ValueError(
33+
f"order={order} not supported, please use `{', '.join(order_funcs.keys())}`"
34+
)
35+
is_okay = order_funcs[order](
36+
parse_version_to_int(module_version), parse_version_to_int(version)
37+
)
38+
if not is_okay:
39+
wrong_dependencies_versions.append(
40+
(dependency, order, version, module_version)
41+
)
42+
return wrong_dependencies_versions
43+
44+
45+
def print_warn_for_wrong_dependencies_versions(
46+
dependencies_versions: List[Tuple[str, str, str]]
47+
):
48+
wrong_dependencies_versions = get_wrong_dependencies_versions(dependencies_versions)
49+
for (dependency, order, version, module_version) in wrong_dependencies_versions:
50+
print(
51+
f"Dependency {dependency}{order}{version} is required but found version={module_version}, to fix: `pip install {dependency}{order}{version}`"
52+
)
53+
54+
55+
def warn_for_wrong_dependencies_versions(
56+
dependencies_versions: List[Tuple[str, str, str]]
57+
):
58+
"""
59+
Decorator to print a warning based on dependencies versions. E.g.
60+
61+
```python
62+
@warn_for_wrong_dependencies_versions([("torch", "==", "1.2.0")])
63+
def foo(x):
64+
# I only work with torch `1.2.0` but another one is installed
65+
print(f"foo {x}")
66+
```
67+
68+
prints:
69+
70+
```
71+
Dependency torch==1.2.0 is required but found version=1.13.1, to fix: `pip install torch==1.2.0`
72+
```
73+
74+
Args:
75+
dependencies_versions (List[Tuple[str, str]]): List of dependencies we want to check, [("<package_name>", "<version_number_to_check")]
76+
"""
77+
78+
def _inner(func):
79+
def _wrapper(*args, **kwargs):
80+
print_warn_for_wrong_dependencies_versions(dependencies_versions)
81+
func(*args, **kwargs)
82+
83+
return _wrapper
84+
85+
return _inner

setup.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
with open("README.md", "r") as fh:
1212
long_description = fh.read()
1313

14+
# we are using the packages in `requirements.txt` for now,
15+
# not 100% ideal but will do
16+
with open("requirements.txt", "r") as fh:
17+
install_requires = fh.read().split('\n')
18+
1419
setuptools.setup(
1520
name="roboflow",
1621
version=version,
@@ -20,28 +25,7 @@
2025
long_description=long_description,
2126
long_description_content_type="text/markdown",
2227
url="https://github.com/roboflow-ai/roboflow-python",
23-
install_requires=[
24-
"certifi==2022.12.7",
25-
"chardet==4.0.0",
26-
"cycler==0.10.0",
27-
"glob2",
28-
"idna==2.10",
29-
"kiwisolver>=1.3.1",
30-
"matplotlib",
31-
"numpy>=1.18.5",
32-
"opencv-python-headless>=4.5.1.48",
33-
"Pillow>=7.1.2",
34-
"pyparsing==2.4.7",
35-
"python-dateutil",
36-
"python-dotenv",
37-
"requests",
38-
"requests_toolbelt",
39-
"six",
40-
"urllib3==1.26.6",
41-
"tqdm>=4.41.0",
42-
"PyYAML>=5.3.1",
43-
"wget",
44-
],
28+
install_requires=install_requires,
4529
packages=find_packages(exclude=("tests",)),
4630
extras_require={
4731
"dev": ["flake8", "black==22.3.0", "isort", "responses", "twine", "wheel"],
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "0.3.0"

tests/util/test_versions.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from importlib import import_module
2+
from roboflow.util.versions import get_wrong_dependencies_versions
3+
import unittest
4+
import sys
5+
from pathlib import Path
6+
7+
class TestVersions(unittest.TestCase):
8+
9+
def test_wrong_dependencies_versions(self):
10+
module_path = "tests.util.dummy_module"
11+
module_version = import_module(module_path).__version__
12+
tests = [
13+
("tests.util.dummy_module", "==", module_version),
14+
("tests.util.dummy_module", "<=", "0.2.0"),
15+
("tests.util.dummy_module", "<=", "1.0.0"),
16+
("tests.util.dummy_module", ">=", "0.1.0"),
17+
("tests.util.dummy_module", ">=", "0.6.0")
18+
19+
]
20+
# true if dep is correc
21+
expected_results = [True, False, True, True, False]
22+
23+
for (test, expected_result) in zip(tests, expected_results):
24+
wrong_dependencies_versions = get_wrong_dependencies_versions([test])
25+
is_correct_dep = len(wrong_dependencies_versions) == 0
26+
self.assertEqual(is_correct_dep, expected_result)

0 commit comments

Comments
 (0)