|
| 1 | +from importlib import import_module |
| 2 | +from typing import List, Tuple |
| 3 | + |
| 4 | + |
| 5 | +def get_wrong_dependencies_versions( |
| 6 | + dependencies_versions: List[Tuple[str, str, str]] |
| 7 | +) -> List[Tuple[str, str, str, str]]: |
| 8 | + """ |
| 9 | + Get a list of missmatching dependencies with current version installed. |
| 10 | + E.g., assuming we pass `get_wrong_dependencies_versions([("torch", "==", "1.2.0")]), we will check if the current version of `torch` is `==1.2.0`. If not, we will return `[("torch", "==", "1.2.0", "<current_installed_version>")] |
| 11 | +
|
| 12 | + We support `<=`, `==`, `>=` |
| 13 | +
|
| 14 | + Args: |
| 15 | + dependencies_versions (List[Tuple[str, str]]): List of dependencies we want to check, [("<package_name>", "<version_number_to_check")] |
| 16 | +
|
| 17 | + Returns: |
| 18 | + List[Tuple[str, str, str]]: List of dependencies with wrong version, [("<package_name>", "<version_number_to_check", "<current_version>")] |
| 19 | + """ |
| 20 | + wrong_dependencies_versions = [] |
| 21 | + # from e.g. "1.12.0" -> 1120 |
| 22 | + parse_version_to_int = lambda x: int(x.replace(".", "")) |
| 23 | + order_funcs = { |
| 24 | + "==": lambda x, y: x == y, |
| 25 | + ">=": lambda x, y: x >= y, |
| 26 | + "<=": lambda x, y: x <= y, |
| 27 | + } |
| 28 | + for dependency, order, version in dependencies_versions: |
| 29 | + module = import_module(dependency) |
| 30 | + module_version = module.__version__ |
| 31 | + if order not in order_funcs: |
| 32 | + raise ValueError( |
| 33 | + f"order={order} not supported, please use `{', '.join(order_funcs.keys())}`" |
| 34 | + ) |
| 35 | + is_okay = order_funcs[order]( |
| 36 | + parse_version_to_int(module_version), parse_version_to_int(version) |
| 37 | + ) |
| 38 | + if not is_okay: |
| 39 | + wrong_dependencies_versions.append( |
| 40 | + (dependency, order, version, module_version) |
| 41 | + ) |
| 42 | + return wrong_dependencies_versions |
| 43 | + |
| 44 | + |
| 45 | +def print_warn_for_wrong_dependencies_versions( |
| 46 | + dependencies_versions: List[Tuple[str, str, str]] |
| 47 | +): |
| 48 | + wrong_dependencies_versions = get_wrong_dependencies_versions(dependencies_versions) |
| 49 | + for (dependency, order, version, module_version) in wrong_dependencies_versions: |
| 50 | + print( |
| 51 | + f"Dependency {dependency}{order}{version} is required but found version={module_version}, to fix: `pip install {dependency}{order}{version}`" |
| 52 | + ) |
| 53 | + |
| 54 | + |
| 55 | +def warn_for_wrong_dependencies_versions( |
| 56 | + dependencies_versions: List[Tuple[str, str, str]] |
| 57 | +): |
| 58 | + """ |
| 59 | + Decorator to print a warning based on dependencies versions. E.g. |
| 60 | +
|
| 61 | + ```python |
| 62 | + @warn_for_wrong_dependencies_versions([("torch", "==", "1.2.0")]) |
| 63 | + def foo(x): |
| 64 | + # I only work with torch `1.2.0` but another one is installed |
| 65 | + print(f"foo {x}") |
| 66 | + ``` |
| 67 | +
|
| 68 | + prints: |
| 69 | +
|
| 70 | + ``` |
| 71 | + Dependency torch==1.2.0 is required but found version=1.13.1, to fix: `pip install torch==1.2.0` |
| 72 | + ``` |
| 73 | +
|
| 74 | + Args: |
| 75 | + dependencies_versions (List[Tuple[str, str]]): List of dependencies we want to check, [("<package_name>", "<version_number_to_check")] |
| 76 | + """ |
| 77 | + |
| 78 | + def _inner(func): |
| 79 | + def _wrapper(*args, **kwargs): |
| 80 | + print_warn_for_wrong_dependencies_versions(dependencies_versions) |
| 81 | + func(*args, **kwargs) |
| 82 | + |
| 83 | + return _wrapper |
| 84 | + |
| 85 | + return _inner |
0 commit comments