Skip to content

Commit a2fe9b8

Browse files
added file for versions check, forgot to do it before
1 parent ab7b4e3 commit a2fe9b8

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

roboflow/util/versions.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from importlib import import_module
2+
from typing import List, Tuple
3+
4+
5+
def get_wrong_dependencies_versions(
6+
dependencies_versions: List[Tuple[str, str, str]]
7+
) -> List[Tuple[str, str, str, str]]:
8+
"""
9+
Get a list of missmatching dependencies with current version installed.
10+
E.g., assuming we pass `get_wrong_dependencies_versions([("torch", "==", "1.2.0")]), we will check if the current version of `torch` is `==1.2.0`. If not, we will return `[("torch", "==", "1.2.0", "<current_installed_version>")]
11+
12+
We support `<=`, `==`, `>=`
13+
14+
Args:
15+
dependencies_versions (List[Tuple[str, str]]): List of dependencies we want to check, [("<package_name>", "<version_number_to_check")]
16+
17+
Returns:
18+
List[Tuple[str, str, str]]: List of dependencies with wrong version, [("<package_name>", "<version_number_to_check", "<current_version>")]
19+
"""
20+
wrong_dependencies_versions = []
21+
# from e.g. "1.12.0" -> 1120
22+
parse_version_to_int = lambda x: int(x.replace(".", ""))
23+
order_funcs = {
24+
"==": lambda x, y: x == y,
25+
">=": lambda x, y: x >= y,
26+
"<=": lambda x, y: x <= y,
27+
}
28+
for dependency, order, version in dependencies_versions:
29+
module = import_module(dependency)
30+
module_version = module.__version__
31+
if order not in order_funcs:
32+
raise ValueError(
33+
f"order={order} not supported, please use `{', '.join(order_funcs.keys())}`"
34+
)
35+
is_okay = order_funcs[order](
36+
parse_version_to_int(module_version), parse_version_to_int(version)
37+
)
38+
if not is_okay:
39+
wrong_dependencies_versions.append(
40+
(dependency, order, version, module_version)
41+
)
42+
return wrong_dependencies_versions
43+
44+
45+
def print_warn_for_wrong_dependencies_versions(
46+
dependencies_versions: List[Tuple[str, str, str]]
47+
):
48+
wrong_dependencies_versions = get_wrong_dependencies_versions(dependencies_versions)
49+
for (dependency, order, version, module_version) in wrong_dependencies_versions:
50+
print(
51+
f"Dependency {dependency}{order}{version} is required but found version={module_version}, to fix: `pip install {dependency}{order}{version}`"
52+
)
53+
54+
55+
def warn_for_wrong_dependencies_versions(
56+
dependencies_versions: List[Tuple[str, str, str]]
57+
):
58+
"""
59+
Decorator to print a warning based on dependencies versions. E.g.
60+
61+
```python
62+
@warn_for_wrong_dependencies_versions([("torch", "==", "1.2.0")])
63+
def foo(x):
64+
# I only work with torch `1.2.0` but another one is installed
65+
print(f"foo {x}")
66+
```
67+
68+
prints:
69+
70+
```
71+
Dependency torch==1.2.0 is required but found version=1.13.1, to fix: `pip install torch==1.2.0`
72+
```
73+
74+
Args:
75+
dependencies_versions (List[Tuple[str, str]]): List of dependencies we want to check, [("<package_name>", "<version_number_to_check")]
76+
"""
77+
78+
def _inner(func):
79+
def _wrapper(*args, **kwargs):
80+
print_warn_for_wrong_dependencies_versions(dependencies_versions)
81+
func(*args, **kwargs)
82+
83+
return _wrapper
84+
85+
return _inner

0 commit comments

Comments
 (0)