Skip to content

Commit 27169d1

Browse files
schustmihtahir1
andauthored
Switch from pkg_resources to importlib (#3722)
* Switch from pkg_resources to importlib * Some test fixes * Improve marker parsing * WIP databricks * Databricks fixes * Revert some changes * Linting * Fix import --------- Co-authored-by: Hamza Tahir <[email protected]>
1 parent 378ac31 commit 27169d1

File tree

10 files changed

+164
-245
lines changed

10 files changed

+164
-245
lines changed

src/zenml/cli/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from zenml.io import fileio
5050
from zenml.logger import get_logger
5151
from zenml.utils.io_utils import copy_dir, get_global_config_directory
52+
from zenml.utils.package_utils import get_package_information
5253
from zenml.utils.server_utils import get_local_server
5354
from zenml.utils.yaml_utils import write_yaml
5455

@@ -640,7 +641,7 @@ def info(
640641
}
641642

642643
if all:
643-
user_info["packages"] = cli_utils.get_package_information()
644+
user_info["packages"] = get_package_information()
644645
if packages:
645646
if user_info.get("packages"):
646647
if isinstance(user_info["packages"], dict):
@@ -650,7 +651,7 @@ def info(
650651
if p in packages
651652
}
652653
else:
653-
user_info["query_packages"] = cli_utils.get_package_information(
654+
user_info["query_packages"] = get_package_information(
654655
list(packages)
655656
)
656657
if file:

src/zenml/cli/utils.py

Lines changed: 5 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
)
4141

4242
import click
43-
import pkg_resources
4443
import yaml
4544
from pydantic import BaseModel, SecretStr
4645
from rich import box, table
@@ -80,6 +79,7 @@
8079
from zenml.stack.flavor import Flavor
8180
from zenml.stack.stack_component import StackComponentConfig
8281
from zenml.utils import secret_utils
82+
from zenml.utils.package_utils import requirement_installed
8383
from zenml.utils.time_utils import expires_in
8484
from zenml.utils.typing_utils import get_origin, is_union
8585

@@ -1052,7 +1052,7 @@ def install_packages(
10521052
# just return without doing anything
10531053
return
10541054

1055-
if use_uv and not is_installed_in_python_environment("uv"):
1055+
if use_uv and not requirement_installed("uv"):
10561056
# If uv is installed globally, don't run as a python module
10571057
command = []
10581058
else:
@@ -1094,7 +1094,7 @@ def uninstall_package(package: str, use_uv: bool = False) -> None:
10941094
package: The package to uninstall.
10951095
use_uv: Whether to use uv for package uninstallation.
10961096
"""
1097-
if use_uv and not is_installed_in_python_environment("uv"):
1097+
if use_uv and not requirement_installed("uv"):
10981098
# If uv is installed globally, don't run as a python module
10991099
command = []
11001100
else:
@@ -1110,22 +1110,6 @@ def uninstall_package(package: str, use_uv: bool = False) -> None:
11101110
subprocess.check_call(command)
11111111

11121112

1113-
def is_installed_in_python_environment(package: str) -> bool:
1114-
"""Check if a package is installed in the current python environment.
1115-
1116-
Args:
1117-
package: The package to check.
1118-
1119-
Returns:
1120-
True if the package is installed, False otherwise.
1121-
"""
1122-
try:
1123-
pkg_resources.get_distribution(package)
1124-
return True
1125-
except pkg_resources.DistributionNotFound:
1126-
return False
1127-
1128-
11291113
def is_uv_installed() -> bool:
11301114
"""Check if uv is installed.
11311115
@@ -1141,7 +1125,7 @@ def is_pip_installed() -> bool:
11411125
Returns:
11421126
True if pip is installed, False otherwise.
11431127
"""
1144-
return is_installed_in_python_environment("pip")
1128+
return requirement_installed("pip")
11451129

11461130

11471131
def pretty_print_secret(
@@ -2499,30 +2483,6 @@ def temporary_active_stack(
24992483
Client().activate_stack(old_stack_id)
25002484

25012485

2502-
def get_package_information(
2503-
package_names: Optional[List[str]] = None,
2504-
) -> Dict[str, str]:
2505-
"""Get a dictionary of installed packages.
2506-
2507-
Args:
2508-
package_names: Specific package names to get the information for.
2509-
2510-
Returns:
2511-
A dictionary of the name:version for the package names passed in or
2512-
all packages and their respective versions.
2513-
"""
2514-
import pkg_resources
2515-
2516-
if package_names:
2517-
return {
2518-
pkg.key: pkg.version
2519-
for pkg in pkg_resources.working_set
2520-
if pkg.key in package_names
2521-
}
2522-
2523-
return {pkg.key: pkg.version for pkg in pkg_resources.working_set}
2524-
2525-
25262486
def print_user_info(info: Dict[str, Any]) -> None:
25272487
"""Print user information to the terminal.
25282488
@@ -2617,11 +2577,7 @@ def is_jupyter_installed() -> bool:
26172577
Returns:
26182578
bool: True if Jupyter notebook is installed, False otherwise.
26192579
"""
2620-
try:
2621-
pkg_resources.get_distribution("notebook")
2622-
return True
2623-
except pkg_resources.DistributionNotFound:
2624-
return False
2580+
return requirement_installed("notebook")
26252581

26262582

26272583
def multi_choice_prompt(

src/zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@
1717
import sys
1818
from typing import Any, List, Set
1919

20-
import pkg_resources
20+
if sys.version_info < (3, 10):
21+
from importlib_metadata import distribution
22+
else:
23+
from importlib.metadata import distribution
2124

2225
from zenml.entrypoints.step_entrypoint_configuration import (
2326
StepEntrypointConfiguration,
@@ -81,8 +84,10 @@ def run(self) -> None:
8184
"""Runs the step."""
8285
# Get the wheel package and add it to the sys path
8386
wheel_package = self.entrypoint_args[WHEEL_PACKAGE_OPTION]
84-
distribution = pkg_resources.get_distribution(wheel_package)
85-
project_root = os.path.join(distribution.location, wheel_package)
87+
88+
dist = distribution(wheel_package)
89+
project_root = os.path.join(dist.locate_file("."), wheel_package)
90+
8691
if project_root not in sys.path:
8792
sys.path.insert(0, project_root)
8893
sys.path.insert(-1, project_root)

src/zenml/integrations/integration.py

Lines changed: 23 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,14 @@
1313
# permissions and limitations under the License.
1414
"""Base and meta classes for ZenML integrations."""
1515

16-
import re
1716
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
1817

19-
import pkg_resources
20-
from pkg_resources import Requirement
18+
from packaging.requirements import Requirement
2119

2220
from zenml.integrations.registry import integration_registry
2321
from zenml.logger import get_logger
2422
from zenml.stack.flavor import Flavor
25-
from zenml.utils.integration_utils import parse_requirement
23+
from zenml.utils.package_utils import get_dependencies, requirement_installed
2624

2725
if TYPE_CHECKING:
2826
from zenml.plugins.base_plugin_flavor import BasePluginFlavor
@@ -69,65 +67,32 @@ def check_installation(cls) -> bool:
6967
Returns:
7068
True if all required packages are installed, False otherwise.
7169
"""
72-
for r in cls.get_requirements():
73-
try:
74-
# First check if the base package is installed
75-
dist = pkg_resources.get_distribution(r)
76-
77-
# Next, check if the dependencies (including extras) are
78-
# installed
79-
deps: List[Requirement] = []
80-
81-
_, extras = parse_requirement(r)
82-
if extras:
83-
extra_list = extras[1:-1].split(",")
84-
for extra in extra_list:
85-
try:
86-
requirements = dist.requires(extras=[extra]) # type: ignore[arg-type]
87-
except pkg_resources.UnknownExtra as e:
88-
logger.debug(f"Unknown extra: {str(e)}")
89-
return False
90-
deps.extend(requirements)
91-
else:
92-
deps = dist.requires()
93-
94-
for ri in deps:
95-
try:
96-
# Remove the "extra == ..." part from the requirement string
97-
cleaned_req = re.sub(
98-
r"; extra == \"\w+\"", "", str(ri)
99-
)
100-
pkg_resources.get_distribution(cleaned_req)
101-
except pkg_resources.DistributionNotFound as e:
102-
logger.debug(
103-
f"Unable to find required dependency "
104-
f"'{e.req}' for requirement '{r}' "
105-
f"necessary for integration '{cls.NAME}'."
106-
)
107-
return False
108-
except pkg_resources.VersionConflict as e:
109-
logger.debug(
110-
f"Package version '{e.dist}' does not match "
111-
f"version '{e.req}' required by '{r}' "
112-
f"necessary for integration '{cls.NAME}'."
113-
)
114-
return False
115-
116-
except pkg_resources.DistributionNotFound as e:
117-
logger.debug(
118-
f"Unable to find required package '{e.req}' for "
119-
f"integration {cls.NAME}."
120-
)
121-
return False
122-
except pkg_resources.VersionConflict as e:
70+
for requirement in cls.get_requirements():
71+
parsed_requirement = Requirement(requirement)
72+
73+
if not requirement_installed(parsed_requirement):
12374
logger.debug(
124-
f"Package version '{e.dist}' does not match version "
125-
f"'{e.req}' necessary for integration {cls.NAME}."
75+
"Requirement '%s' for integration '%s' is not installed "
76+
"or installed with the wrong version.",
77+
requirement,
78+
cls.NAME,
12679
)
12780
return False
12881

82+
dependencies = get_dependencies(parsed_requirement)
83+
84+
for dependency in dependencies:
85+
if not requirement_installed(dependency):
86+
logger.debug(
87+
"Requirement '%s' for integration '%s' is not "
88+
"installed or installed with the wrong version.",
89+
dependency,
90+
cls.NAME,
91+
)
92+
return False
93+
12994
logger.debug(
130-
f"Integration {cls.NAME} is installed correctly with "
95+
f"Integration '{cls.NAME}' is installed correctly with "
13196
f"requirements {cls.get_requirements()}."
13297
)
13398
return True

src/zenml/utils/integration_utils.py

Lines changed: 0 additions & 34 deletions
This file was deleted.

0 commit comments

Comments
 (0)