diff --git a/scripts/pyupgrade.sh b/scripts/pyupgrade.sh new file mode 100755 index 00000000000..b00034b5e6b --- /dev/null +++ b/scripts/pyupgrade.sh @@ -0,0 +1,18 @@ +#!/bin/bash + +# install helper packages + +uv pip install pyupgrade +uv pip install autoflake + +# run py-upgrade recursively to all nested .py files + +find src -name '*.py' -print0 | xargs -0 -n1 python -m pyupgrade --py310-plus + +# run autoflake to remove dangling imports + +autoflake --remove-all-unused-imports --in-place --recursive src + +# run ruff check to verify imports are fixed + +ruff check core --select F401 diff --git a/src/zenml/actions/base_action.py b/src/zenml/actions/base_action.py index 0f57511692a..91f8c93b423 100644 --- a/src/zenml/actions/base_action.py +++ b/src/zenml/actions/base_action.py @@ -14,7 +14,7 @@ """Base implementation of actions.""" from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, Optional, Type +from typing import Any, ClassVar from zenml.enums import PluginType from zenml.event_hub.base_event_hub import BaseEventHub @@ -56,10 +56,10 @@ class BaseActionFlavor(BasePluginFlavor, ABC): TYPE: ClassVar[PluginType] = PluginType.ACTION # Action specific - ACTION_CONFIG_CLASS: ClassVar[Type[ActionConfig]] + ACTION_CONFIG_CLASS: ClassVar[type[ActionConfig]] @classmethod - def get_action_config_schema(cls) -> Dict[str, Any]: + def get_action_config_schema(cls) -> dict[str, Any]: """The config schema for a flavor. Returns: @@ -97,9 +97,9 @@ def get_flavor_response_model(cls, hydrate: bool) -> ActionFlavorResponse: class BaseActionHandler(BasePlugin, ABC): """Implementation for an action handler.""" - _event_hub: Optional[BaseEventHub] = None + _event_hub: BaseEventHub | None = None - def __init__(self, event_hub: Optional[BaseEventHub] = None) -> None: + def __init__(self, event_hub: BaseEventHub | None = None) -> None: """Event source handler initialization. 
Args: @@ -122,7 +122,7 @@ def __init__(self, event_hub: Optional[BaseEventHub] = None) -> None: @property @abstractmethod - def config_class(self) -> Type[ActionConfig]: + def config_class(self) -> type[ActionConfig]: """Returns the `BasePluginConfig` config. Returns: @@ -131,7 +131,7 @@ def config_class(self) -> Type[ActionConfig]: @property @abstractmethod - def flavor_class(self) -> Type[BaseActionFlavor]: + def flavor_class(self) -> type[BaseActionFlavor]: """Returns the flavor class of the plugin. Returns: @@ -173,7 +173,7 @@ def set_event_hub(self, event_hub: BaseEventHub) -> None: def event_hub_callback( self, - config: Dict[str, Any], + config: dict[str, Any], trigger_execution: TriggerExecutionResponse, auth_context: AuthContext, ) -> None: @@ -442,7 +442,7 @@ def get_action( return action def validate_action_configuration( - self, action_config: Dict[str, Any] + self, action_config: dict[str, Any] ) -> ActionConfig: """Validate and return the action configuration. @@ -464,7 +464,7 @@ def extract_resources( self, action_config: ActionConfig, hydrate: bool = False, - ) -> Dict[ResourceType, BaseResponse[Any, Any, Any]]: + ) -> dict[ResourceType, BaseResponse[Any, Any, Any]]: """Extract related resources for this action. Args: @@ -504,7 +504,6 @@ def _validate_action_request( action: Action request. config: Action configuration instantiated from the request. """ - pass def _process_action_request( self, action: ActionResponse, config: ActionConfig @@ -529,7 +528,6 @@ def _process_action_request( action: Newly created action config: Action configuration instantiated from the response. """ - pass def _validate_action_update( self, @@ -566,7 +564,6 @@ def _validate_action_update( config_update: Action configuration instantiated from the updated action. """ - pass def _process_action_update( self, @@ -599,13 +596,12 @@ def _process_action_update( previous_config: Action configuration instantiated from the original action. 
""" - pass def _process_action_delete( self, action: ActionResponse, config: ActionConfig, - force: Optional[bool] = False, + force: bool | None = False, ) -> None: """Process an action before it is deleted from the database. @@ -625,7 +621,6 @@ def _process_action_delete( config: Action configuration before the deletion. force: Whether to force deprovision the action. """ - pass def _process_action_response( self, action: ActionResponse, config: ActionConfig @@ -650,7 +645,6 @@ def _process_action_response( action: Action response. config: Action configuration instantiated from the response. """ - pass def _populate_action_response_resources( self, diff --git a/src/zenml/actions/pipeline_run/pipeline_run_action.py b/src/zenml/actions/pipeline_run/pipeline_run_action.py index 0dd591de81f..065580925b0 100644 --- a/src/zenml/actions/pipeline_run/pipeline_run_action.py +++ b/src/zenml/actions/pipeline_run/pipeline_run_action.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Example file of what an action Plugin could look like.""" -from typing import Any, ClassVar, Dict, Optional, Type +from typing import Any, ClassVar from uuid import UUID from zenml.actions.base_action import ( @@ -47,7 +47,7 @@ class PipelineRunActionConfiguration(ActionConfig): """Configuration class to configure a pipeline run action.""" snapshot_id: UUID - run_config: Optional[PipelineRunConfiguration] = None + run_config: PipelineRunConfiguration | None = None # -------------------- Pipeline Run Plugin ----------------------------------- @@ -57,7 +57,7 @@ class PipelineRunActionHandler(BaseActionHandler): """Action handler for running pipelines.""" @property - def config_class(self) -> Type[PipelineRunActionConfiguration]: + def config_class(self) -> type[PipelineRunActionConfiguration]: """Returns the `BasePluginConfig` config. 
Returns: @@ -66,7 +66,7 @@ def config_class(self) -> Type[PipelineRunActionConfiguration]: return PipelineRunActionConfiguration @property - def flavor_class(self) -> Type[BaseActionFlavor]: + def flavor_class(self) -> type[BaseActionFlavor]: """Returns the flavor class of the plugin. Returns: @@ -169,7 +169,7 @@ def extract_resources( self, action_config: ActionConfig, hydrate: bool = False, - ) -> Dict[ResourceType, BaseResponse[Any, Any, Any]]: + ) -> dict[ResourceType, BaseResponse[Any, Any, Any]]: """Extract related resources for this action. Args: @@ -196,7 +196,7 @@ def extract_resources( f"No snapshot found with id {action_config.snapshot_id}." ) - resources: Dict[ResourceType, BaseResponse[Any, Any, Any]] = { + resources: dict[ResourceType, BaseResponse[Any, Any, Any]] = { ResourceType.PIPELINE_SNAPSHOT: snapshot } @@ -217,11 +217,11 @@ class PipelineRunActionFlavor(BaseActionFlavor): FLAVOR: ClassVar[str] = "builtin" SUBTYPE: ClassVar[PluginSubType] = PluginSubType.PIPELINE_RUN - PLUGIN_CLASS: ClassVar[Type[PipelineRunActionHandler]] = ( + PLUGIN_CLASS: ClassVar[type[PipelineRunActionHandler]] = ( PipelineRunActionHandler ) # EventPlugin specific - ACTION_CONFIG_CLASS: ClassVar[Type[PipelineRunActionConfiguration]] = ( + ACTION_CONFIG_CLASS: ClassVar[type[PipelineRunActionConfiguration]] = ( PipelineRunActionConfiguration ) diff --git a/src/zenml/alerter/base_alerter.py b/src/zenml/alerter/base_alerter.py index 528bc0af0ca..b474d77f4b7 100644 --- a/src/zenml/alerter/base_alerter.py +++ b/src/zenml/alerter/base_alerter.py @@ -14,7 +14,7 @@ """Base class for all ZenML alerters.""" from abc import ABC -from typing import Optional, Type, cast +from typing import cast from pydantic import BaseModel @@ -44,7 +44,7 @@ def config(self) -> BaseAlerterConfig: return cast(BaseAlerterConfig, self._config) def post( - self, message: str, params: Optional[BaseAlerterStepParameters] = None + self, message: str, params: BaseAlerterStepParameters | None = None ) -> bool: 
"""Post a message to a chat service. @@ -58,7 +58,7 @@ def post( return True def ask( - self, question: str, params: Optional[BaseAlerterStepParameters] = None + self, question: str, params: BaseAlerterStepParameters | None = None ) -> bool: """Post a message to a chat service and wait for approval. @@ -88,7 +88,7 @@ def type(self) -> StackComponentType: return StackComponentType.ALERTER @property - def config_class(self) -> Type[BaseAlerterConfig]: + def config_class(self) -> type[BaseAlerterConfig]: """Returns BaseAlerterConfig class. Returns: @@ -97,7 +97,7 @@ def config_class(self) -> Type[BaseAlerterConfig]: return BaseAlerterConfig @property - def implementation_class(self) -> Type[BaseAlerter]: + def implementation_class(self) -> type[BaseAlerter]: """Implementation class. Returns: diff --git a/src/zenml/analytics/__init__.py b/src/zenml/analytics/__init__.py index cb5f6237598..e39cf386c25 100644 --- a/src/zenml/analytics/__init__.py +++ b/src/zenml/analytics/__init__.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """The 'analytics' module of ZenML.""" from contextvars import ContextVar -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from uuid import UUID from zenml.enums import SourceContextTypes @@ -27,7 +27,7 @@ def identify( # type: ignore[return] - metadata: Optional[Dict[str, Any]] = None + metadata: dict[str, Any] | None = None ) -> bool: """Attach metadata to user directly. @@ -64,7 +64,7 @@ def alias(user_id: UUID, previous_id: UUID) -> bool: # type: ignore[return] def group( # type: ignore[return] group_id: UUID, - group_metadata: Optional[Dict[str, Any]] = None, + group_metadata: dict[str, Any] | None = None, ) -> bool: """Attach metadata to a segment group. 
@@ -83,7 +83,7 @@ def group( # type: ignore[return] def track( # type: ignore[return] event: "AnalyticsEvent", - metadata: Optional[Dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ) -> bool: """Track segment event if user opted-in. diff --git a/src/zenml/analytics/client.py b/src/zenml/analytics/client.py index d5606d1aa7d..93cbe2c1d40 100644 --- a/src/zenml/analytics/client.py +++ b/src/zenml/analytics/client.py @@ -15,7 +15,7 @@ import json import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any from uuid import UUID from zenml.analytics.enums import AnalyticsEvent @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) -class Client(object): +class Client: """The client class for ZenML analytics.""" def __init__(self, send: bool = True, timeout: int = 15) -> None: @@ -40,8 +40,8 @@ def __init__(self, send: bool = True, timeout: int = 15) -> None: self.timeout = timeout def identify( - self, user_id: UUID, traits: Optional[Dict[Any, Any]] - ) -> Tuple[bool, str]: + self, user_id: UUID, traits: dict[Any, Any] | None + ) -> tuple[bool, str]: """Method to identify a user with given traits. Args: @@ -59,7 +59,7 @@ def identify( } return self._enqueue(json.dumps(msg, cls=AnalyticsEncoder)) - def alias(self, user_id: UUID, previous_id: UUID) -> Tuple[bool, str]: + def alias(self, user_id: UUID, previous_id: UUID) -> tuple[bool, str]: """Method to alias user IDs. Args: @@ -81,8 +81,8 @@ def track( self, user_id: UUID, event: "AnalyticsEvent", - properties: Optional[Dict[Any, Any]], - ) -> Tuple[bool, str]: + properties: dict[Any, Any] | None, + ) -> tuple[bool, str]: """Method to track events. Args: @@ -103,8 +103,8 @@ def track( return self._enqueue(json.dumps(msg, cls=AnalyticsEncoder)) def group( - self, user_id: UUID, group_id: UUID, traits: Optional[Dict[Any, Any]] - ) -> Tuple[bool, str]: + self, user_id: UUID, group_id: UUID, traits: dict[Any, Any] | None + ) -> tuple[bool, str]: """Method to group users. 
Args: @@ -124,7 +124,7 @@ def group( } return self._enqueue(json.dumps(msg, cls=AnalyticsEncoder)) - def _enqueue(self, msg: str) -> Tuple[bool, str]: + def _enqueue(self, msg: str) -> tuple[bool, str]: """Method to queue messages to be sent. Args: diff --git a/src/zenml/analytics/context.py b/src/zenml/analytics/context.py index 12fc9872f6a..ff8feecc8ce 100644 --- a/src/zenml/analytics/context.py +++ b/src/zenml/analytics/context.py @@ -19,7 +19,7 @@ import locale from types import TracebackType -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Optional, Union from uuid import UUID from zenml import __version__ @@ -40,7 +40,7 @@ ServerDeploymentType, ) -Json = Union[Dict[str, Any], List[Any], str, int, float, bool, None] +Json = Union[dict[str, Any], list[Any], str, int, float, bool, None] logger = get_logger(__name__) @@ -57,13 +57,13 @@ def __init__(self) -> None: """ self.analytics_opt_in: bool = False - self.user_id: Optional[UUID] = None - self.external_user_id: Optional[UUID] = None - self.executed_by_service_account: Optional[bool] = None - self.client_id: Optional[UUID] = None - self.server_id: Optional[UUID] = None - self.external_server_id: Optional[UUID] = None - self.server_metadata: Optional[Dict[str, str]] = None + self.user_id: UUID | None = None + self.external_user_id: UUID | None = None + self.executed_by_service_account: bool | None = None + self.client_id: UUID | None = None + self.server_id: UUID | None = None + self.external_server_id: UUID | None = None + self.server_metadata: dict[str, str] | None = None self.database_type: Optional["ServerDatabaseType"] = None self.deployment_type: Optional["ServerDeploymentType"] = None @@ -153,9 +153,9 @@ def __enter__(self) -> "AnalyticsContext": def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: 
BaseException | None, + exc_tb: TracebackType | None, ) -> bool: """Exit context manager. @@ -181,7 +181,7 @@ def in_server(self) -> bool: """ return handle_bool_env_var(ENV_ZENML_SERVER) - def identify(self, traits: Optional[Dict[str, Any]] = None) -> bool: + def identify(self, traits: dict[str, Any] | None = None) -> bool: """Identify the user through segment. Args: @@ -221,7 +221,7 @@ def alias(self, user_id: UUID, previous_id: UUID) -> bool: def group( self, group_id: UUID, - traits: Optional[Dict[str, Any]] = None, + traits: dict[str, Any] | None = None, ) -> bool: """Group the user. @@ -245,7 +245,7 @@ def group( def track( self, event: "AnalyticsEvent", - properties: Optional[Dict[str, Any]] = None, + properties: dict[str, Any] | None = None, ) -> bool: """Track an event. diff --git a/src/zenml/analytics/models.py b/src/zenml/analytics/models.py index c4c195fc6fe..be9efb326f1 100644 --- a/src/zenml/analytics/models.py +++ b/src/zenml/analytics/models.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Helper models for ZenML analytics.""" -from typing import Any, ClassVar, Dict, List +from typing import Any, ClassVar from pydantic import BaseModel @@ -28,9 +28,9 @@ class AnalyticsTrackedModelMixin(BaseModel): tracking metadata. """ - ANALYTICS_FIELDS: ClassVar[List[str]] = [] + ANALYTICS_FIELDS: ClassVar[list[str]] = [] - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """Get the analytics metadata for the model. 
Returns: diff --git a/src/zenml/analytics/request.py b/src/zenml/analytics/request.py index bf4a9a198bc..ece46ad850c 100644 --- a/src/zenml/analytics/request.py +++ b/src/zenml/analytics/request.py @@ -18,7 +18,6 @@ """ import logging -from typing import List import requests @@ -28,7 +27,7 @@ logger = logging.getLogger(__name__) -def post(batch: List[str], timeout: int = 15) -> requests.Response: +def post(batch: list[str], timeout: int = 15) -> requests.Response: """Post a batch of messages to the ZenML analytics server. Args: diff --git a/src/zenml/analytics/utils.py b/src/zenml/analytics/utils.py index 01829b01b19..4e58427abe7 100644 --- a/src/zenml/analytics/utils.py +++ b/src/zenml/analytics/utils.py @@ -15,7 +15,8 @@ import json from functools import wraps -from typing import Any, Callable, Dict, Optional, TypeVar, cast +from typing import Any, TypeVar, cast +from collections.abc import Callable from uuid import UUID from zenml.analytics import identify, track @@ -78,7 +79,7 @@ def __str__(self) -> str: return msg.format(self.message, self.status) -def email_opt_int(opted_in: bool, email: Optional[str], source: str) -> None: +def email_opt_int(opted_in: bool, email: str | None, source: str) -> None: """Track the event of the users response to the email prompt, identify them. Args: @@ -102,7 +103,7 @@ class analytics_disabler: def __init__(self) -> None: """Initialization of the context manager.""" - self.original_value: Optional[bool] = None + self.original_value: bool | None = None def __enter__(self) -> None: """Disable the analytics.""" @@ -112,9 +113,9 @@ def __enter__(self) -> None: def __exit__( self, - exc_type: Optional[Any], - exc_value: Optional[Any], - traceback: Optional[Any], + exc_type: Any | None, + exc_value: Any | None, + traceback: Any | None, ) -> None: """Set it back to the original state. 
@@ -201,13 +202,13 @@ def inner_func(*args: Any, **kwargs: Any) -> Any: return inner_decorator -class track_handler(object): +class track_handler: """Context handler to enable tracking the success status of an event.""" def __init__( self, event: AnalyticsEvent, - metadata: Optional[Dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ): """Initialization of the context manager. @@ -216,7 +217,7 @@ def __init__( metadata: The metadata of the event. """ self.event: AnalyticsEvent = event - self.metadata: Dict[str, Any] = metadata or {} + self.metadata: dict[str, Any] = metadata or {} def __enter__(self) -> "track_handler": """Enter function of the event handler. @@ -228,9 +229,9 @@ def __enter__(self) -> "track_handler": def __exit__( self, - type_: Optional[Any], - value: Optional[Any], - traceback: Optional[Any], + type_: Any | None, + value: Any | None, + traceback: Any | None, ) -> Any: """Exit function of the event handler. diff --git a/src/zenml/annotators/base_annotator.py b/src/zenml/annotators/base_annotator.py index 4321ee4d0d8..83b98afacd2 100644 --- a/src/zenml/annotators/base_annotator.py +++ b/src/zenml/annotators/base_annotator.py @@ -14,7 +14,7 @@ """Base class for ZenML annotator stack components.""" from abc import ABC, abstractmethod -from typing import Any, ClassVar, List, Tuple, Type, cast +from typing import Any, ClassVar, cast from zenml.enums import StackComponentType from zenml.stack import Flavor, StackComponent @@ -63,7 +63,7 @@ def get_url_for_dataset(self, dataset_name: str) -> str: """ @abstractmethod - def get_datasets(self) -> List[Any]: + def get_datasets(self) -> list[Any]: """Gets the datasets currently available for annotation. Returns: @@ -71,7 +71,7 @@ def get_datasets(self) -> List[Any]: """ @abstractmethod - def get_dataset_names(self) -> List[str]: + def get_dataset_names(self) -> list[str]: """Gets the names of the datasets currently available for annotation. 
Returns: @@ -79,7 +79,7 @@ def get_dataset_names(self) -> List[str]: """ @abstractmethod - def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]: + def get_dataset_stats(self, dataset_name: str) -> tuple[int, int]: """Gets the statistics of a dataset. Args: @@ -165,7 +165,7 @@ def type(self) -> StackComponentType: return StackComponentType.ANNOTATOR @property - def config_class(self) -> Type[BaseAnnotatorConfig]: + def config_class(self) -> type[BaseAnnotatorConfig]: """Config class for this flavor. Returns: @@ -175,7 +175,7 @@ def config_class(self) -> Type[BaseAnnotatorConfig]: @property @abstractmethod - def implementation_class(self) -> Type[BaseAnnotator]: + def implementation_class(self) -> type[BaseAnnotator]: """Implementation class. Returns: diff --git a/src/zenml/artifact_stores/base_artifact_store.py b/src/zenml/artifact_stores/base_artifact_store.py index e2a18323b95..a062c51f4ba 100644 --- a/src/zenml/artifact_stores/base_artifact_store.py +++ b/src/zenml/artifact_stores/base_artifact_store.py @@ -20,18 +20,11 @@ from pathlib import Path from typing import ( Any, - Callable, ClassVar, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Type, Union, cast, ) +from collections.abc import Callable, Iterable from pydantic import Field, model_validator @@ -85,8 +78,8 @@ def __init__(self, func: Callable[..., Any], fixed_root_path: str) -> None: else: self.allow_local_file_access = True - self.path_args: List[int] = [] - self.path_kwargs: List[str] = [] + self.path_args: list[int] = [] + self.path_kwargs: list[str] = [] for i, param in enumerate( inspect.signature(self.func).parameters.values() ): @@ -208,13 +201,13 @@ class BaseArtifactStoreConfig(StackComponentConfig): "Path must be accessible with the configured credentials and permissions" ) - SUPPORTED_SCHEMES: ClassVar[Set[str]] + SUPPORTED_SCHEMES: ClassVar[set[str]] IS_IMMUTABLE_FILESYSTEM: ClassVar[bool] = False @model_validator(mode="before") @classmethod @before_validator_handler 
- def _ensure_artifact_store(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def _ensure_artifact_store(cls, data: dict[str, Any]) -> dict[str, Any]: """Validator function for the Artifact Stores. Checks whether supported schemes are defined and the given path is @@ -289,7 +282,7 @@ def path(self) -> str: return self.config.path @property - def custom_cache_key(self) -> Optional[bytes]: + def custom_cache_key(self) -> bytes | None: """Custom cache key. Any artifact store can override this property in case they need @@ -337,7 +330,7 @@ def exists(self, path: PathType) -> bool: """ @abstractmethod - def glob(self, pattern: PathType) -> List[PathType]: + def glob(self, pattern: PathType) -> list[PathType]: """Gets the paths that match a glob pattern. Args: @@ -359,7 +352,7 @@ def isdir(self, path: PathType) -> bool: """ @abstractmethod - def listdir(self, path: PathType) -> List[PathType]: + def listdir(self, path: PathType) -> list[PathType]: """Returns a list of files under a given directory in the filesystem. Args: @@ -425,7 +418,7 @@ def stat(self, path: PathType) -> Any: """ @abstractmethod - def size(self, path: PathType) -> Optional[int]: + def size(self, path: PathType) -> int | None: """Get the size of a file in bytes. Args: @@ -448,8 +441,8 @@ def walk( self, top: PathType, topdown: bool = True, - onerror: Optional[Callable[..., None]] = None, - ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]: + onerror: Callable[..., None] | None = None, + ) -> Iterable[tuple[PathType, list[PathType], list[PathType]]]: """Return an iterator that walks the contents of the given directory. Args: @@ -469,7 +462,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: *args: The positional arguments to pass to the Pydantic object. **kwargs: The keyword arguments to pass to the Pydantic object. 
""" - super(BaseArtifactStore, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self._add_path_sanitization() # If running in a ZenML server environment, we don't register @@ -497,7 +490,7 @@ def _register(self) -> None: if isinstance(self, LocalFilesystem): return - overloads: Dict[str, Any] = { + overloads: dict[str, Any] = { "SUPPORTED_SCHEMES": self.config.SUPPORTED_SCHEMES, } for method_name, method in inspect.getmembers(BaseArtifactStore): @@ -536,7 +529,7 @@ def type(self) -> StackComponentType: return StackComponentType.ARTIFACT_STORE @property - def config_class(self) -> Type[StackComponentConfig]: + def config_class(self) -> type[StackComponentConfig]: """Config class for this flavor. Returns: @@ -546,7 +539,7 @@ def config_class(self) -> Type[StackComponentConfig]: @property @abstractmethod - def implementation_class(self) -> Type["BaseArtifactStore"]: + def implementation_class(self) -> type["BaseArtifactStore"]: """Implementation class. Returns: diff --git a/src/zenml/artifact_stores/local_artifact_store.py b/src/zenml/artifact_stores/local_artifact_store.py index 67f44abad18..c6358761c71 100644 --- a/src/zenml/artifact_stores/local_artifact_store.py +++ b/src/zenml/artifact_stores/local_artifact_store.py @@ -19,7 +19,7 @@ """ import os -from typing import TYPE_CHECKING, ClassVar, Optional, Set, Type, Union +from typing import TYPE_CHECKING, ClassVar, Union from pydantic import field_validator @@ -46,7 +46,7 @@ class LocalArtifactStoreConfig(BaseArtifactStoreConfig): path: The path to the local artifact store. """ - SUPPORTED_SCHEMES: ClassVar[Set[str]] = {""} + SUPPORTED_SCHEMES: ClassVar[set[str]] = {""} path: str = "" @@ -88,7 +88,7 @@ class LocalArtifactStore(LocalFilesystem, BaseArtifactStore): All methods are inherited from the default `LocalFilesystem`. 
""" - _path: Optional[str] = None + _path: str | None = None @staticmethod def get_default_local_path(id_: "UUID") -> str: @@ -126,7 +126,7 @@ def path(self) -> str: return self._path @property - def local_path(self) -> Optional[str]: + def local_path(self) -> str | None: """Returns the local path of the artifact store. Returns: @@ -135,7 +135,7 @@ def local_path(self) -> Optional[str]: return self.path @property - def custom_cache_key(self) -> Optional[bytes]: + def custom_cache_key(self) -> bytes | None: """Custom cache key. The client ID is returned here to invalidate caching when using the same @@ -160,7 +160,7 @@ def name(self) -> str: return "local" @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -169,7 +169,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -187,7 +187,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/artifact_store/local.svg" @property - def config_class(self) -> Type[LocalArtifactStoreConfig]: + def config_class(self) -> type[LocalArtifactStoreConfig]: """Config class for this flavor. Returns: @@ -196,7 +196,7 @@ def config_class(self) -> Type[LocalArtifactStoreConfig]: return LocalArtifactStoreConfig @property - def implementation_class(self) -> Type[LocalArtifactStore]: + def implementation_class(self) -> type[LocalArtifactStore]: """Implementation class. Returns: diff --git a/src/zenml/artifacts/artifact_config.py b/src/zenml/artifacts/artifact_config.py index 5a240b34bc6..ee853008898 100644 --- a/src/zenml/artifacts/artifact_config.py +++ b/src/zenml/artifacts/artifact_config.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Artifact Config classes to support Model Control Plane feature.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any from pydantic import BaseModel, Field, model_validator @@ -60,19 +60,19 @@ def my_step() -> Annotated[ is used. """ - name: Optional[str] = None - version: Optional[Union[str, int]] = Field( + name: str | None = None + version: str | int | None = Field( default=None, union_mode="smart" ) - tags: Optional[List[str]] = None - run_metadata: Optional[Dict[str, MetadataType]] = None + tags: list[str] | None = None + run_metadata: dict[str, MetadataType] | None = None - artifact_type: Optional[ArtifactType] = None + artifact_type: ArtifactType | None = None @model_validator(mode="before") @classmethod @before_validator_handler - def _remove_old_attributes(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def _remove_old_attributes(cls, data: dict[str, Any]) -> dict[str, Any]: """Remove old attributes that are not used anymore. Args: diff --git a/src/zenml/artifacts/external_artifact.py b/src/zenml/artifacts/external_artifact.py index fdb1615da2f..443604380d1 100644 --- a/src/zenml/artifacts/external_artifact.py +++ b/src/zenml/artifacts/external_artifact.py @@ -14,7 +14,7 @@ """External artifact definition.""" import os -from typing import Any, Optional, Type, Union +from typing import Any, Union from uuid import UUID, uuid4 from pydantic import Field, model_validator @@ -27,7 +27,7 @@ from zenml.logger import get_logger from zenml.materializers.base_materializer import BaseMaterializer -MaterializerClassOrSource = Union[str, Source, Type[BaseMaterializer]] +MaterializerClassOrSource = Union[str, Source, type[BaseMaterializer]] logger = get_logger(__name__) @@ -71,8 +71,8 @@ def my_pipeline(): ``` """ - value: Optional[Any] = None - materializer: Optional[MaterializerClassOrSource] = Field( + value: Any | None = None + materializer: MaterializerClassOrSource | None = Field( default=None, union_mode="left_to_right" ) 
store_artifact_metadata: bool = True diff --git a/src/zenml/artifacts/external_artifact_config.py b/src/zenml/artifacts/external_artifact_config.py index 5a06df07031..505cd3efcba 100644 --- a/src/zenml/artifacts/external_artifact_config.py +++ b/src/zenml/artifacts/external_artifact_config.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """External artifact definition.""" -from typing import Any, Dict, Optional +from typing import Any from uuid import UUID from pydantic import BaseModel, model_validator @@ -30,12 +30,12 @@ class ExternalArtifactConfiguration(BaseModel): Lightweight class to pass in the steps for runtime inference. """ - id: Optional[UUID] = None + id: UUID | None = None @model_validator(mode="before") @classmethod @before_validator_handler - def _remove_old_attributes(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def _remove_old_attributes(cls, data: dict[str, Any]) -> dict[str, Any]: """Remove old attributes that are not used anymore. Args: diff --git a/src/zenml/artifacts/preexisting_data_materializer.py b/src/zenml/artifacts/preexisting_data_materializer.py index 517cbb7b089..9d6714050ea 100644 --- a/src/zenml/artifacts/preexisting_data_materializer.py +++ b/src/zenml/artifacts/preexisting_data_materializer.py @@ -15,7 +15,7 @@ import os from pathlib import Path -from typing import Any, ClassVar, Tuple, Type +from typing import Any, ClassVar from zenml.enums import ArtifactType from zenml.io import fileio @@ -32,11 +32,11 @@ class PreexistingDataMaterializer(BaseMaterializer): This materializer solely supports the `register_artifact` function. 
""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Path,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (Path,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA SKIP_REGISTRATION: ClassVar[bool] = True - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: type[Any]) -> Any: """Copy the artifact file(s) to a local temp directory. Args: diff --git a/src/zenml/artifacts/utils.py b/src/zenml/artifacts/utils.py index 78d7ca3aafd..257b2e5a9e2 100644 --- a/src/zenml/artifacts/utils.py +++ b/src/zenml/artifacts/utils.py @@ -23,10 +23,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Type, Union, cast, ) @@ -80,7 +77,7 @@ from zenml.metadata.metadata_types import MetadataType from zenml.zen_stores.base_zen_store import BaseZenStore - MaterializerClassOrSource = Union[str, Source, Type[BaseMaterializer]] + MaterializerClassOrSource = Union[str, Source, type[BaseMaterializer]] logger = get_logger(__name__) @@ -92,7 +89,7 @@ def _save_artifact_visualizations( data: Any, materializer: "BaseMaterializer" -) -> List[ArtifactVisualizationRequest]: +) -> list[ArtifactVisualizationRequest]: """Save artifact visualizations. Args: @@ -122,15 +119,15 @@ def _store_artifact_data_and_prepare_request( data: Any, name: str, uri: str, - materializer_class: Type["BaseMaterializer"], + materializer_class: type["BaseMaterializer"], save_type: ArtifactSaveType, - version: Optional[Union[int, str]] = None, - artifact_type: Optional[ArtifactType] = None, - tags: Optional[List[str]] = None, + version: int | str | None = None, + artifact_type: ArtifactType | None = None, + tags: list[str] | None = None, store_metadata: bool = True, store_visualizations: bool = True, has_custom_name: bool = True, - metadata: Optional[Dict[str, "MetadataType"]] = None, + metadata: dict[str, "MetadataType"] | None = None, ) -> ArtifactVersionRequest: """Store artifact data and prepare a request to the server. 
@@ -179,7 +176,7 @@ def _store_artifact_data_and_prepare_request( else None ) - combined_metadata: Dict[str, "MetadataType"] = {} + combined_metadata: dict[str, "MetadataType"] = {} if store_metadata: try: combined_metadata = materializer.extract_full_metadata(data) @@ -219,14 +216,14 @@ def _store_artifact_data_and_prepare_request( def save_artifact( data: Any, name: str, - version: Optional[Union[int, str]] = None, - artifact_type: Optional[ArtifactType] = None, - tags: Optional[List[str]] = None, + version: int | str | None = None, + artifact_type: ArtifactType | None = None, + tags: list[str] | None = None, extract_metadata: bool = True, include_visualizations: bool = True, - user_metadata: Optional[Dict[str, "MetadataType"]] = None, + user_metadata: dict[str, "MetadataType"] | None = None, materializer: Optional["MaterializerClassOrSource"] = None, - uri: Optional[str] = None, + uri: str | None = None, # TODO: remove these once external artifact does not use this function anymore save_type: ArtifactSaveType = ArtifactSaveType.MANUAL, has_custom_name: bool = True, @@ -316,11 +313,11 @@ def save_artifact( def register_artifact( folder_or_file_uri: str, name: str, - version: Optional[Union[int, str]] = None, - artifact_type: Optional[ArtifactType] = None, - tags: Optional[List[str]] = None, + version: int | str | None = None, + artifact_type: ArtifactType | None = None, + tags: list[str] | None = None, has_custom_name: bool = True, - artifact_metadata: Dict[str, "MetadataType"] = {}, + artifact_metadata: dict[str, "MetadataType"] = {}, ) -> "ArtifactVersionResponse": """Register existing data stored in the artifact store as a ZenML Artifact. @@ -389,8 +386,8 @@ def register_artifact( def load_artifact( - name_or_id: Union[str, UUID], - version: Optional[str] = None, + name_or_id: str | UUID, + version: str | None = None, ) -> Any: """Load an artifact. 
@@ -407,9 +404,9 @@ def load_artifact( def log_artifact_metadata( - metadata: Dict[str, "MetadataType"], - artifact_name: Optional[str] = None, - artifact_version: Optional[str] = None, + metadata: dict[str, "MetadataType"], + artifact_name: str | None = None, + artifact_version: str | None = None, ) -> None: """Log artifact metadata. @@ -659,7 +656,7 @@ def get_producer_step_of_artifact( def get_artifacts_versions_of_pipeline_run( pipeline_run: "PipelineRunResponse", only_produced: bool = False -) -> List["ArtifactVersionResponse"]: +) -> list["ArtifactVersionResponse"]: """Get all artifact versions produced during a pipeline run. Args: @@ -670,7 +667,7 @@ def get_artifacts_versions_of_pipeline_run( Returns: A list of all artifact versions produced during the pipeline run. """ - artifact_versions: List["ArtifactVersionResponse"] = [] + artifact_versions: list["ArtifactVersionResponse"] = [] for step in pipeline_run.steps.values(): if not only_produced or step.status == ExecutionStatus.COMPLETED: for output in step.outputs.values(): @@ -925,7 +922,7 @@ def _load_file_from_artifact_store( artifact_store: "BaseArtifactStore", mode: str = "rb", offset: int = 0, - length: Optional[int] = None, + length: int | None = None, ) -> Any: """Load the given uri from the given artifact store. @@ -965,7 +962,7 @@ def _load_file_from_artifact_store( f"File '{uri}' does not exist in artifact store " f"'{artifact_store.name}'." 
) - except (IOError, IllegalOperationError) as e: + except (OSError, IllegalOperationError) as e: raise e except Exception as e: logger.exception(e) diff --git a/src/zenml/cli/__init__.py b/src/zenml/cli/__init__.py index 138f64c915f..b49aac4a4bd 100644 --- a/src/zenml/cli/__init__.py +++ b/src/zenml/cli/__init__.py @@ -2510,7 +2510,7 @@ def my_pipeline(...): from zenml.cli.downgrade import * # noqa from zenml.cli.feature import * # noqa from zenml.cli.integration import * # noqa -from zenml.cli.login import * +from zenml.cli.login import * # noqa from zenml.cli.model import * # noqa from zenml.cli.model_registry import * # noqa from zenml.cli.pipeline import * # noqa diff --git a/src/zenml/cli/annotator.py b/src/zenml/cli/annotator.py index af5038218d9..cadc00f6c0b 100644 --- a/src/zenml/cli/annotator.py +++ b/src/zenml/cli/annotator.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Functionality for annotator CLI subcommands.""" -from typing import TYPE_CHECKING, Tuple, cast +from typing import TYPE_CHECKING, cast import click @@ -161,7 +161,7 @@ def dataset_delete( def dataset_annotate( annotator: "BaseAnnotator", dataset_name: str, - kwargs: Tuple[str, ...], + kwargs: tuple[str, ...], ) -> None: """Command to launch the annotation interface for a dataset. diff --git a/src/zenml/cli/artifact.py b/src/zenml/cli/artifact.py index a045b3c8027..e0121f2e69b 100644 --- a/src/zenml/cli/artifact.py +++ b/src/zenml/cli/artifact.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""CLI functionality to interact with artifacts.""" -from typing import Any, Dict, List, Optional +from typing import Any import click @@ -80,9 +80,9 @@ def list_artifacts(**kwargs: Any) -> None: ) def update_artifact( artifact_name_or_id: str, - name: Optional[str] = None, - tag: Optional[List[str]] = None, - remove_tag: Optional[List[str]] = None, + name: str | None = None, + tag: list[str] | None = None, + remove_tag: list[str] | None = None, ) -> None: """Update an artifact by ID or name. @@ -164,9 +164,9 @@ def list_artifact_versions(**kwargs: Any) -> None: ) def update_artifact_version( name_id_or_prefix: str, - version: Optional[str] = None, - tag: Optional[List[str]] = None, - remove_tag: Optional[List[str]] = None, + version: str | None = None, + tag: list[str] | None = None, + remove_tag: list[str] | None = None, ) -> None: """Update an artifact version by ID or artifact name. @@ -298,7 +298,7 @@ def prune_artifacts( def _artifact_version_to_print( artifact_version: ArtifactVersionResponse, -) -> Dict[str, Any]: +) -> dict[str, Any]: return { "id": artifact_version.id, "name": artifact_version.artifact.name, @@ -313,7 +313,7 @@ def _artifact_version_to_print( def _artifact_to_print( artifact_version: ArtifactResponse, -) -> Dict[str, Any]: +) -> dict[str, Any]: return { "id": artifact_version.id, "name": artifact_version.name, diff --git a/src/zenml/cli/base.py b/src/zenml/cli/base.py index d80f428b92d..c95949e3c91 100644 --- a/src/zenml/cli/base.py +++ b/src/zenml/cli/base.py @@ -17,7 +17,6 @@ import subprocess import tempfile from pathlib import Path -from typing import Optional, Tuple import click from packaging import version @@ -135,9 +134,9 @@ def copier_github_url(self) -> str: hidden=True, ) def init( - path: Optional[Path], - template: Optional[str] = None, - template_tag: Optional[str] = None, + path: Path | None, + template: str | None = None, + template_tag: str | None = None, template_with_defaults: bool = False, test: bool = False, ) -> 
None: @@ -214,7 +213,7 @@ def init( console.print(prompt_message, width=80) # Check if template is a URL or a preset template name - vcs_ref: Optional[str] = None + vcs_ref: str | None = None if template in ZENML_PROJECT_TEMPLATES: declare(f"Using the {template} template...") zenml_project_template = ZENML_PROJECT_TEMPLATES[template] @@ -591,7 +590,7 @@ def _prompt_email(event_source: AnalyticsEventSource) -> bool: type=bool, ) def info( - packages: Tuple[str], + packages: tuple[str], all: bool = False, file: str = "", stack: bool = False, @@ -731,8 +730,8 @@ def migrate_database(skip_default_registrations: bool = False) -> None: type=bool, ) def backup_database( - strategy: Optional[str] = None, - location: Optional[str] = None, + strategy: str | None = None, + location: str | None = None, overwrite: bool = False, ) -> None: """Backup the ZenML database. @@ -795,8 +794,8 @@ def backup_database( type=bool, ) def restore_database( - strategy: Optional[str] = None, - location: Optional[str] = None, + strategy: str | None = None, + location: str | None = None, cleanup: bool = False, ) -> None: """Restore the ZenML database. diff --git a/src/zenml/cli/cli.py b/src/zenml/cli/cli.py index baa29bafbe5..28da0aab441 100644 --- a/src/zenml/cli/cli.py +++ b/src/zenml/cli/cli.py @@ -14,7 +14,8 @@ """Core CLI functionality.""" import os -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any +from collections.abc import Sequence import click import rich @@ -38,12 +39,12 @@ class TagGroup(click.Group): def __init__( self, - name: Optional[str] = None, - tag: Optional[CliCategories] = None, - commands: Optional[ - Union[Dict[str, click.Command], Sequence[click.Command]] - ] = None, - **kwargs: Dict[str, Any], + name: str | None = None, + tag: CliCategories | None = None, + commands: None | ( + dict[str, click.Command] | Sequence[click.Command] + ) = None, + **kwargs: dict[str, Any], ) -> None: """Initialize the Tag group. 
@@ -53,7 +54,7 @@ def __init__( commands: The commands of the group. kwargs: Additional keyword arguments. """ - super(TagGroup, self).__init__(name, commands, **kwargs) + super().__init__(name, commands, **kwargs) self.tag = tag or CliCategories.OTHER_COMMANDS @@ -99,8 +100,8 @@ def format_commands( ctx: The click context. formatter: The click formatter. """ - commands: List[ - Tuple[CliCategories, str, Union[Command, TagGroup]] + commands: list[ + tuple[CliCategories, str, Command | TagGroup] ] = [] for subcommand in self.list_commands(ctx): cmd = self.get_command(ctx, subcommand) @@ -132,7 +133,7 @@ def format_commands( ), ) ) - rows: List[Tuple[str, str, str]] = [] + rows: list[tuple[str, str, str]] = [] for tag, subcommand, cmd in commands: help_ = cmd.get_short_help_str(limit=formatter.width) rows.append((tag.value, subcommand, help_)) diff --git a/src/zenml/cli/code_repository.py b/src/zenml/cli/code_repository.py index d83ab6e75cf..61d1fd98379 100644 --- a/src/zenml/cli/code_repository.py +++ b/src/zenml/cli/code_repository.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """CLI functionality to interact with code repositories.""" -from typing import Any, List, Optional +from typing import Any import click @@ -81,10 +81,10 @@ def code_repository() -> None: def register_code_repository( name: str, type_: str, - source_path: Optional[str], - description: Optional[str], - logo_url: Optional[str], - args: List[str], + source_path: str | None, + description: str | None, + logo_url: str | None, + args: list[str], ) -> None: """Register a code repository. @@ -243,10 +243,10 @@ def list_code_repositories(**kwargs: Any) -> None: ) def update_code_repository( name_or_id: str, - name: Optional[str], - description: Optional[str], - logo_url: Optional[str], - args: List[str], + name: str | None, + description: str | None, + logo_url: str | None, + args: list[str], ) -> None: """Update a code repository. 
diff --git a/src/zenml/cli/deployment.py b/src/zenml/cli/deployment.py index 793f115bb5e..3a6640f1f4c 100644 --- a/src/zenml/cli/deployment.py +++ b/src/zenml/cli/deployment.py @@ -14,7 +14,7 @@ """CLI functionality to interact with deployments.""" import json -from typing import Any, List, Optional +from typing import Any from uuid import UUID import click @@ -192,10 +192,10 @@ def describe_deployment( ) def provision_deployment( deployment_name_or_id: str, - snapshot_name_or_id: Optional[str] = None, - pipeline_name_or_id: Optional[str] = None, + snapshot_name_or_id: str | None = None, + pipeline_name_or_id: str | None = None, overtake: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """Provision a deployment. @@ -209,7 +209,7 @@ def provision_deployment( timeout: The maximum time in seconds to wait for the deployment to be provisioned. """ - snapshot_id: Optional[UUID] = None + snapshot_id: UUID | None = None if snapshot_name_or_id: snapshot = fetch_snapshot(snapshot_name_or_id, pipeline_name_or_id) snapshot_id = snapshot.id @@ -303,13 +303,13 @@ def provision_deployment( "deprovisioned.", ) def deprovision_deployment( - deployment_name_or_id: Optional[str] = None, + deployment_name_or_id: str | None = None, all: bool = False, mine: bool = False, yes: bool = False, ignore_errors: bool = False, max_count: int = 10, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """Deprovision a deployment. 
@@ -465,12 +465,12 @@ def deprovision_deployment( help="Force the deletion of the deployment if it cannot be deprovisioned.", ) def delete_deployment( - deployment_name_or_id: Optional[str] = None, + deployment_name_or_id: str | None = None, all: bool = False, mine: bool = False, ignore_errors: bool = False, yes: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, max_count: int = 20, force: bool = False, ) -> None: @@ -603,8 +603,8 @@ def refresh_deployment( @click.argument("args", nargs=-1, type=click.UNPROCESSED) def invoke_deployment( deployment_name_or_id: str, - args: List[str], - timeout: Optional[int] = None, + args: list[str], + timeout: int | None = None, ) -> None: """Call a deployment with arguments. @@ -677,7 +677,7 @@ def invoke_deployment( def log_deployment( deployment_name_or_id: str, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> None: """Get the logs of a deployment. diff --git a/src/zenml/cli/formatter.py b/src/zenml/cli/formatter.py index 1d35c757d36..6da6e554d9c 100644 --- a/src/zenml/cli/formatter.py +++ b/src/zenml/cli/formatter.py @@ -13,13 +13,13 @@ # permissions and limitations under the License. """Helper functions to format output for CLI.""" -from typing import Dict, Iterable, Iterator, Optional, Sequence, Tuple +from collections.abc import Iterable, Iterator, Sequence from click import formatting from click._compat import term_len -def measure_table(rows: Iterable[Tuple[str, ...]]) -> Tuple[int, ...]: +def measure_table(rows: Iterable[tuple[str, ...]]) -> tuple[int, ...]: """Measure the width of each column in a table. Args: @@ -28,7 +28,7 @@ def measure_table(rows: Iterable[Tuple[str, ...]]) -> Tuple[int, ...]: Returns: A tuple of the width of each column. 
""" - widths: Dict[int, int] = {} + widths: dict[int, int] = {} for row in rows: for idx, col in enumerate(row): widths[idx] = max(widths.get(idx, 0), term_len(col)) @@ -37,9 +37,9 @@ def measure_table(rows: Iterable[Tuple[str, ...]]) -> Tuple[int, ...]: def iter_rows( - rows: Iterable[Tuple[str, ...]], + rows: Iterable[tuple[str, ...]], col_count: int, -) -> Iterator[Tuple[str, ...]]: +) -> Iterator[tuple[str, ...]]: """Iterate over rows of a table. Args: @@ -59,8 +59,8 @@ class ZenFormatter(formatting.HelpFormatter): def __init__( self, indent_increment: int = 2, - width: Optional[int] = None, - max_width: Optional[int] = None, + width: int | None = None, + max_width: int | None = None, ) -> None: """Initialize the formatter. @@ -70,12 +70,12 @@ def __init__( width: The maximum width of the help output. max_width: The maximum width of the help output. """ - super(ZenFormatter, self).__init__(indent_increment, width, max_width) + super().__init__(indent_increment, width, max_width) self.current_indent = 0 def write_dl( self, - rows: Sequence[Tuple[str, ...]], + rows: Sequence[tuple[str, ...]], col_max: int = 30, col_spacing: int = 2, ) -> None: diff --git a/src/zenml/cli/integration.py b/src/zenml/cli/integration.py index 3ec7d9ece2a..1cfaf848e2c 100644 --- a/src/zenml/cli/integration.py +++ b/src/zenml/cli/integration.py @@ -16,7 +16,6 @@ import os import subprocess import sys -from typing import Optional, Tuple import click from rich.progress import track @@ -76,7 +75,7 @@ def list_integrations() -> None: name="requirements", help="List all requirements for an integration." ) @click.argument("integration_name", required=False, default=None) -def get_requirements(integration_name: Optional[str] = None) -> None: +def get_requirements(integration_name: str | None = None) -> None: """List all requirements for the chosen integration. 
Args: @@ -150,9 +149,9 @@ def get_requirements(integration_name: Optional[str] = None) -> None: help="Add the exported requirements to your current Poetry project.", ) def export_requirements( - integrations: Tuple[str], - ignore_integration: Tuple[str], - output_file: Optional[str] = None, + integrations: tuple[str], + ignore_integration: tuple[str], + output_file: str | None = None, overwrite: bool = False, installed_only: bool = False, poetry: bool = False, @@ -282,8 +281,8 @@ def export_requirements( default=False, ) def install( - integrations: Tuple[str], - ignore_integration: Tuple[str], + integrations: tuple[str], + ignore_integration: tuple[str], force: bool = False, uv: bool = False, ) -> None: @@ -402,7 +401,7 @@ def install( default=False, ) def uninstall( - integrations: Tuple[str], force: bool = False, uv: bool = False + integrations: tuple[str], force: bool = False, uv: bool = False ) -> None: """Uninstalls the required packages for a given integration. @@ -483,7 +482,7 @@ def uninstall( default=False, ) def upgrade( - integrations: Tuple[str], + integrations: tuple[str], force: bool = False, uv: bool = False, ) -> None: diff --git a/src/zenml/cli/login.py b/src/zenml/cli/login.py index 9daadcbfbf5..1cd11ad2b94 100644 --- a/src/zenml/cli/login.py +++ b/src/zenml/cli/login.py @@ -18,7 +18,7 @@ import re import sys import time -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any from uuid import UUID import click @@ -123,13 +123,13 @@ def _display_login_menu() -> LoginMethod: def start_local_server( docker: bool = False, - ip_address: Union[ - ipaddress.IPv4Address, ipaddress.IPv6Address, None - ] = None, - port: Optional[int] = None, + ip_address: ( + ipaddress.IPv4Address | ipaddress.IPv6Address | None + ) = None, + port: int | None = None, blocking: bool = False, - image: Optional[str] = None, - ngrok_token: Optional[str] = None, + image: str | None = None, + ngrok_token: str | None = None, restart: bool = False, ) -> None: 
"""Start the ZenML dashboard locally and connect the client to it. @@ -178,7 +178,7 @@ def start_local_server( deployer = LocalServerDeployer() - config_attrs: Dict[str, Any] = dict( + config_attrs: dict[str, Any] = dict( provider=provider, ) if not docker: @@ -218,8 +218,8 @@ def start_local_server( def connect_to_server( url: str, - api_key: Optional[str] = None, - verify_ssl: Union[str, bool] = True, + api_key: str | None = None, + verify_ssl: str | bool = True, refresh: bool = False, pro_server: bool = False, ) -> None: @@ -300,11 +300,11 @@ def connect_to_server( def connect_to_pro_server( - pro_server: Optional[str] = None, - api_key: Optional[str] = None, + pro_server: str | None = None, + api_key: str | None = None, refresh: bool = False, - pro_api_url: Optional[str] = None, - verify_ssl: Union[str, bool] = True, + pro_api_url: str | None = None, + verify_ssl: str | bool = True, ) -> None: """Connect the client to a ZenML Pro server. @@ -403,7 +403,7 @@ def connect_to_pro_server( "your session expires." ) - workspace_id: Optional[str] = None + workspace_id: str | None = None if token.device_metadata: # TODO: is this still correct? workspace_id = token.device_metadata.get("tenant_id") @@ -524,7 +524,7 @@ def connect_to_pro_server( def is_pro_server( url: str, -) -> Tuple[Optional[bool], Optional[str]]: +) -> tuple[bool | None, str | None]: """Check if the server at the given URL is a ZenML Pro server. 
Args: @@ -805,23 +805,23 @@ def _fail_if_authentication_environment_variables_set() -> None: "to a self-hosted ZenML Pro deployment.", ) def login( - server: Optional[str] = None, + server: str | None = None, pro: bool = False, refresh: bool = False, api_key: bool = False, no_verify_ssl: bool = False, - ssl_ca_cert: Optional[str] = None, + ssl_ca_cert: str | None = None, local: bool = False, docker: bool = False, restart: bool = False, - ip_address: Union[ - ipaddress.IPv4Address, ipaddress.IPv6Address, None - ] = None, - port: Optional[int] = None, + ip_address: ( + ipaddress.IPv4Address | ipaddress.IPv6Address | None + ) = None, + port: int | None = None, blocking: bool = False, - image: Optional[str] = None, - ngrok_token: Optional[str] = None, - pro_api_url: Optional[str] = None, + image: str | None = None, + ngrok_token: str | None = None, + pro_api_url: str | None = None, ) -> None: """Connect to a remote ZenML server. @@ -867,7 +867,7 @@ def login( ) return - api_key_value: Optional[str] = None + api_key_value: str | None = None if api_key: # Read the API key from the user api_key_value = click.prompt( @@ -890,14 +890,14 @@ def login( return # Get the server that the client is currently connected to, if any - current_non_local_server: Optional[str] = None + current_non_local_server: str | None = None gc = GlobalConfiguration() store_cfg = gc.store_configuration if store_cfg.type == StoreType.REST: if not connected_to_local_server(): current_non_local_server = store_cfg.url - verify_ssl: Union[str, bool] = ( + verify_ssl: str | bool = ( ssl_ca_cert if ssl_ca_cert is not None else not no_verify_ssl ) @@ -1106,11 +1106,11 @@ def login( "from a self-hosted ZenML Pro deployment.", ) def logout( - server: Optional[str] = None, + server: str | None = None, local: bool = False, clear: bool = False, pro: bool = False, - pro_api_url: Optional[str] = None, + pro_api_url: str | None = None, ) -> None: """Disconnect from a ZenML server. 
diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index 82af9513fe1..54439e324f9 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """CLI functionality to interact with Model Control Plane.""" -from typing import Any, Dict, List, Optional +from typing import Any import click @@ -36,7 +36,7 @@ logger = get_logger(__name__) -def _model_to_print(model: ModelResponse) -> Dict[str, Any]: +def _model_to_print(model: ModelResponse) -> dict[str, Any]: return { "id": model.id, "name": model.name, @@ -58,7 +58,7 @@ def _model_to_print(model: ModelResponse) -> Dict[str, Any]: def _model_version_to_print( model_version: ModelVersionResponse, -) -> Dict[str, Any]: +) -> dict[str, Any]: return { "id": model_version.id, "model": model_version.model.name, @@ -169,15 +169,15 @@ def list_models(**kwargs: Any) -> None: ) def register_model( name: str, - license: Optional[str], - description: Optional[str], - audience: Optional[str], - use_cases: Optional[str], - tradeoffs: Optional[str], - ethical: Optional[str], - limitations: Optional[str], - tag: Optional[List[str]], - save_models_to_registry: Optional[bool], + license: str | None, + description: str | None, + audience: str | None, + use_cases: str | None, + tradeoffs: str | None, + ethical: str | None, + limitations: str | None, + tag: list[str] | None, + save_models_to_registry: bool | None, ) -> None: """Register a new model in the Model Control Plane. 
@@ -299,17 +299,17 @@ def register_model( ) def update_model( model_name_or_id: str, - name: Optional[str], - license: Optional[str], - description: Optional[str], - audience: Optional[str], - use_cases: Optional[str], - tradeoffs: Optional[str], - ethical: Optional[str], - limitations: Optional[str], - tag: Optional[List[str]], - remove_tag: Optional[List[str]], - save_models_to_registry: Optional[bool], + name: str | None, + license: str | None, + description: str | None, + audience: str | None, + use_cases: str | None, + tradeoffs: str | None, + ethical: str | None, + limitations: str | None, + tag: list[str] | None, + remove_tag: list[str] | None, + save_models_to_registry: bool | None, ) -> None: """Register a new model in the Model Control Plane. @@ -460,11 +460,11 @@ def list_model_versions(**kwargs: Any) -> None: def update_model_version( model_name_or_id: str, model_version_name_or_number_or_id: str, - stage: Optional[str], - name: Optional[str], - description: Optional[str], - tag: Optional[List[str]], - remove_tag: Optional[List[str]], + stage: str | None, + name: str | None, + description: str | None, + tag: list[str] | None, + remove_tag: list[str] | None, force: bool = False, ) -> None: """Update an existing model version stage in the Model Control Plane. @@ -567,7 +567,7 @@ def delete_model_version( def _print_artifacts_links_generic( model_name_or_id: str, - model_version_name_or_number_or_id: Optional[str] = None, + model_version_name_or_number_or_id: str | None = None, only_data_artifacts: bool = False, only_deployment_artifacts: bool = False, only_model_artifacts: bool = False, @@ -625,7 +625,7 @@ def _print_artifacts_links_generic( @cli_utils.list_options(ModelVersionArtifactFilter) def list_model_version_data_artifacts( model_name: str, - model_version: Optional[str] = None, + model_version: str | None = None, **kwargs: Any, ) -> None: """List data artifacts linked to a model version in the Model Control Plane. 
@@ -653,7 +653,7 @@ def list_model_version_data_artifacts( @cli_utils.list_options(ModelVersionArtifactFilter) def list_model_version_model_artifacts( model_name: str, - model_version: Optional[str] = None, + model_version: str | None = None, **kwargs: Any, ) -> None: """List model artifacts linked to a model version in the Model Control Plane. @@ -681,7 +681,7 @@ def list_model_version_model_artifacts( @cli_utils.list_options(ModelVersionArtifactFilter) def list_model_version_deployment_artifacts( model_name: str, - model_version: Optional[str] = None, + model_version: str | None = None, **kwargs: Any, ) -> None: """List deployment artifacts linked to a model version in the Model Control Plane. @@ -709,7 +709,7 @@ def list_model_version_deployment_artifacts( @cli_utils.list_options(ModelVersionPipelineRunFilter) def list_model_version_pipeline_runs( model_name: str, - model_version: Optional[str] = None, + model_version: str | None = None, **kwargs: Any, ) -> None: """List pipeline runs of a model version in the Model Control Plane. diff --git a/src/zenml/cli/model_registry.py b/src/zenml/cli/model_registry.py index c326dfcd9a3..88be1b005ff 100644 --- a/src/zenml/cli/model_registry.py +++ b/src/zenml/cli/model_registry.py @@ -14,7 +14,7 @@ """Functionality for model deployer CLI subcommands.""" from datetime import datetime -from typing import TYPE_CHECKING, Dict, List, Optional, cast +from typing import TYPE_CHECKING, cast import click @@ -77,7 +77,7 @@ def models(ctx: click.Context) -> None: @click.pass_obj def list_registered_models( model_registry: "BaseModelRegistry", - metadata: Optional[Dict[str, str]], + metadata: dict[str, str] | None, ) -> None: """List of all registered models within the model registry. 
@@ -124,8 +124,8 @@ def list_registered_models( def register_model( model_registry: "BaseModelRegistry", name: str, - description: Optional[str], - metadata: Optional[Dict[str, str]], + description: str | None, + metadata: dict[str, str] | None, ) -> None: """Register a model with the active model registry. @@ -219,8 +219,8 @@ def delete_model( def update_model( model_registry: "BaseModelRegistry", name: str, - description: Optional[str], - metadata: Optional[Dict[str, str]], + description: str | None, + metadata: dict[str, str] | None, ) -> None: """Update a model in the active model registry. @@ -418,10 +418,10 @@ def update_model_version( model_registry: "BaseModelRegistry", name: str, version: str, - description: Optional[str], - metadata: Optional[Dict[str, str]], - stage: Optional[str], - remove_metadata: Optional[List[str]], + description: str | None, + metadata: dict[str, str] | None, + stage: str | None, + remove_metadata: list[str] | None, ) -> None: """Update a model version in the active model registry. @@ -508,12 +508,12 @@ def update_model_version( def list_model_versions( model_registry: "BaseModelRegistry", name: str, - model_uri: Optional[str], - count: Optional[int], - metadata: Optional[Dict[str, str]], + model_uri: str | None, + count: int | None, + metadata: dict[str, str] | None, order_by_date: str, - created_after: Optional[datetime], - created_before: Optional[datetime], + created_after: datetime | None, + created_before: datetime | None, ) -> None: """List all model versions in the active model registry. 
@@ -618,12 +618,12 @@ def register_model_version( name: str, version: str, model_uri: str, - description: Optional[str], - metadata: Optional[Dict[str, str]], - zenml_version: Optional[str], - zenml_run_name: Optional[str], - zenml_pipeline_name: Optional[str], - zenml_step_name: Optional[str], + description: str | None, + metadata: dict[str, str] | None, + zenml_version: str | None, + zenml_run_name: str | None, + zenml_pipeline_name: str | None, + zenml_step_name: str | None, ) -> None: """Register a model version in the active model registry. diff --git a/src/zenml/cli/pipeline.py b/src/zenml/cli/pipeline.py index d63c39395ec..971d9eb75da 100644 --- a/src/zenml/cli/pipeline.py +++ b/src/zenml/cli/pipeline.py @@ -15,7 +15,7 @@ import json import os -from typing import Any, Dict, List, Optional, Union +from typing import Any import click @@ -92,7 +92,7 @@ def pipeline() -> None: help="Path to JSON file containing parameters for the pipeline function.", ) def register_pipeline( - source: str, parameters_path: Optional[str] = None + source: str, parameters_path: str | None = None ) -> None: """Register a pipeline. @@ -118,9 +118,9 @@ def register_pipeline( pipeline_instance = _import_pipeline(source=source) - parameters: Dict[str, Any] = {} + parameters: dict[str, Any] = {} if parameters_path: - with open(parameters_path, "r") as f: + with open(parameters_path) as f: parameters = json.load(f) try: @@ -170,9 +170,9 @@ def register_pipeline( ) def build_pipeline( source: str, - config_path: Optional[str] = None, - stack_name_or_id: Optional[str] = None, - output_path: Optional[str] = None, + config_path: str | None = None, + stack_name_or_id: str | None = None, + output_path: str | None = None, ) -> None: """Build Docker images for a pipeline. 
@@ -252,9 +252,9 @@ def build_pipeline( ) def run_pipeline( source: str, - config_path: Optional[str] = None, - stack_name_or_id: Optional[str] = None, - build_path_or_id: Optional[str] = None, + config_path: str | None = None, + stack_name_or_id: str | None = None, + build_path_or_id: str | None = None, prevent_build_reuse: bool = False, ) -> None: """Run a pipeline. @@ -281,7 +281,7 @@ def run_pipeline( with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id): pipeline_instance = _import_pipeline(source=source) - build: Union[str, PipelineBuildBase, None] = None + build: str | PipelineBuildBase | None = None if build_path_or_id: if uuid_utils.is_valid_uuid(build_path_or_id): build = build_path_or_id @@ -389,15 +389,15 @@ def run_pipeline( ) def deploy_pipeline( source: str, - deployment_name: Optional[str] = None, - config_path: Optional[str] = None, - stack_name_or_id: Optional[str] = None, - build_path_or_id: Optional[str] = None, + deployment_name: str | None = None, + config_path: str | None = None, + stack_name_or_id: str | None = None, + build_path_or_id: str | None = None, prevent_build_reuse: bool = False, update: bool = False, overtake: bool = False, attach: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """Deploy a pipeline for online inference. @@ -431,7 +431,7 @@ def deploy_pipeline( with cli_utils.temporary_active_stack(stack_name_or_id=stack_name_or_id): pipeline_instance = _import_pipeline(source=source) - build: Union[str, PipelineBuildBase, None] = None + build: str | PipelineBuildBase | None = None if build_path_or_id: if uuid_utils.is_valid_uuid(build_path_or_id): build = build_path_or_id @@ -529,8 +529,8 @@ def deploy_pipeline( def create_run_template( source: str, name: str, - config_path: Optional[str] = None, - stack_name_or_id: Optional[str] = None, + config_path: str | None = None, + stack_name_or_id: str | None = None, ) -> None: """DEPRECATED: Create a run template for a pipeline. 
@@ -663,7 +663,7 @@ def list_schedules(**kwargs: Any) -> None: help="The cron expression to update the schedule with.", ) def update_schedule( - schedule_name_or_id: str, cron_expression: Optional[str] = None + schedule_name_or_id: str, cron_expression: str | None = None ) -> None: """Update a pipeline schedule. @@ -995,11 +995,11 @@ def snapshot() -> None: def create_pipeline_snapshot( source: str, name: str, - description: Optional[str] = None, - replace: Optional[bool] = None, - tags: Optional[List[str]] = None, - config_path: Optional[str] = None, - stack_name_or_id: Optional[str] = None, + description: str | None = None, + replace: bool | None = None, + tags: list[str] | None = None, + config_path: str | None = None, + stack_name_or_id: str | None = None, ) -> None: """Create a snapshot of a pipeline. @@ -1075,8 +1075,8 @@ def create_pipeline_snapshot( ) def run_snapshot( snapshot_name_or_id: str, - pipeline_name_or_id: Optional[str] = None, - config_path: Optional[str] = None, + pipeline_name_or_id: str | None = None, + config_path: str | None = None, ) -> None: """Run a snapshot. @@ -1147,11 +1147,11 @@ def run_snapshot( ) def deploy_snapshot( snapshot_name_or_id: str, - pipeline_name_or_id: Optional[str] = None, - deployment_name_or_id: Optional[str] = None, + pipeline_name_or_id: str | None = None, + deployment_name_or_id: str | None = None, update: bool = False, overtake: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """Deploy a pipeline for online inference. diff --git a/src/zenml/cli/project.py b/src/zenml/cli/project.py index 2512fe58f71..681384c9a36 100644 --- a/src/zenml/cli/project.py +++ b/src/zenml/cli/project.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Functionality to administer projects of the ZenML CLI and server.""" -from typing import Any, Optional +from typing import Any import click @@ -89,7 +89,7 @@ def list_projects(ctx: click.Context, /, **kwargs: Any) -> None: def register_project( project_name: str, set_project: bool = False, - display_name: Optional[str] = None, + display_name: str | None = None, set_default: bool = False, ) -> None: """Register a new project. @@ -167,7 +167,7 @@ def set_project(project_name_or_id: str, default: bool = False) -> None: @project.command("describe") @click.argument("project_name_or_id", type=str, required=False) -def describe_project(project_name_or_id: Optional[str] = None) -> None: +def describe_project(project_name_or_id: str | None = None) -> None: """Get the project. Args: diff --git a/src/zenml/cli/secret.py b/src/zenml/cli/secret.py index 5c5a38b74e0..6858a2a2de0 100644 --- a/src/zenml/cli/secret.py +++ b/src/zenml/cli/secret.py @@ -14,7 +14,7 @@ """Functionality to generate stack component CLI commands.""" import getpass -from typing import Any, List, Optional +from typing import Any import click @@ -84,7 +84,7 @@ def secret() -> None: ) @click.argument("args", nargs=-1, type=click.UNPROCESSED) def create_secret( - name: str, private: bool, interactive: bool, values: str, args: List[str] + name: str, private: bool, interactive: bool, values: str, args: list[str] ) -> None: """Create a secret. @@ -207,7 +207,7 @@ def list_secrets(**kwargs: Any) -> None: required=False, help="Use this flag to explicitly fetch a private secret or a public secret.", ) -def get_secret(name_id_or_prefix: str, private: Optional[bool] = None) -> None: +def get_secret(name_id_or_prefix: str, private: bool | None = None) -> None: """Get a secret and print it to the console. 
Args: @@ -228,7 +228,7 @@ def get_secret(name_id_or_prefix: str, private: Optional[bool] = None) -> None: def _get_secret( - name_id_or_prefix: str, private: Optional[bool] = None + name_id_or_prefix: str, private: bool | None = None ) -> SecretResponse: """Get a secret with a given name, prefix or id. @@ -296,9 +296,9 @@ def _get_secret( @click.argument("extra_args", nargs=-1, type=click.UNPROCESSED) def update_secret( name_or_id: str, - extra_args: List[str], - private: Optional[bool] = None, - remove_keys: List[str] = [], + extra_args: list[str], + private: bool | None = None, + remove_keys: list[str] = [], interactive: bool = False, values: str = "", ) -> None: @@ -512,8 +512,8 @@ def delete_secret(name_or_id: str, yes: bool = False) -> None: ) def export_secret( name_id_or_prefix: str, - private: Optional[bool] = None, - filename: Optional[str] = None, + private: bool | None = None, + filename: str | None = None, ) -> None: """Export a secret as a YAML file. diff --git a/src/zenml/cli/served_model.py b/src/zenml/cli/served_model.py index 56962f4a22a..a1484f44d21 100644 --- a/src/zenml/cli/served_model.py +++ b/src/zenml/cli/served_model.py @@ -14,7 +14,7 @@ """Functionality for model-deployer CLI subcommands.""" import uuid -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, cast import click from rich.errors import MarkupError @@ -124,12 +124,12 @@ def models(ctx: click.Context) -> None: @click.pass_obj def list_models( model_deployer: "BaseModelDeployer", - step: Optional[str], - pipeline_name: Optional[str], - pipeline_run_id: Optional[str], - model: Optional[str], - model_version: Optional[str], - flavor: Optional[str], + step: str | None, + pipeline_name: str | None, + pipeline_run_id: str | None, + model: str | None, + model_version: str | None, + flavor: str | None, running: bool, ) -> None: """List of all served models within the model-deployer stack component. 
@@ -393,7 +393,7 @@ def get_model_service_logs( model_deployer: "BaseModelDeployer", served_model_uuid: str, follow: bool, - tail: Optional[int], + tail: int | None, raw: bool, ) -> None: """Display the logs for a model server. diff --git a/src/zenml/cli/server.py b/src/zenml/cli/server.py index 71978dbe4d4..d9e620df481 100644 --- a/src/zenml/cli/server.py +++ b/src/zenml/cli/server.py @@ -15,7 +15,6 @@ import ipaddress import re -from typing import List, Optional, Union import click from rich.errors import MarkupError @@ -90,13 +89,13 @@ ) def up( docker: bool = False, - ip_address: Union[ - ipaddress.IPv4Address, ipaddress.IPv6Address, None - ] = None, - port: Optional[int] = None, + ip_address: ( + ipaddress.IPv4Address | ipaddress.IPv6Address | None + ) = None, + port: int | None = None, blocking: bool = False, - image: Optional[str] = None, - ngrok_token: Optional[str] = None, + image: str | None = None, + ngrok_token: str | None = None, ) -> None: """Start the ZenML dashboard locally and connect the client to it. @@ -142,7 +141,7 @@ def up( default=None, help="Specify an ngrok auth token to use for exposing the ZenML server.", ) -def legacy_show(ngrok_token: Optional[str] = None) -> None: +def legacy_show(ngrok_token: str | None = None) -> None: """Show the ZenML dashboard. Args: @@ -376,12 +375,12 @@ def status() -> None: type=str, ) def connect( - url: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - api_key: Optional[str] = None, + url: str | None = None, + username: str | None = None, + password: str | None = None, + api_key: str | None = None, no_verify_ssl: bool = False, - ssl_ca_cert: Optional[str] = None, + ssl_ca_cert: str | None = None, ) -> None: """Connect to a remote ZenML server. @@ -465,7 +464,7 @@ def disconnect_server() -> None: def logs( follow: bool = False, raw: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> None: """Display the logs for a ZenML server. 
@@ -570,7 +569,7 @@ def server() -> None: def server_list( verbose: bool = False, all: bool = False, - pro_api_url: Optional[str] = None, + pro_api_url: str | None = None, ) -> None: """List all ZenML servers that this client is authorized to access. @@ -608,7 +607,7 @@ def server_list( # that the user has never connected to (and are therefore not stored in # the credentials store). - accessible_pro_servers: List[WorkspaceRead] = [] + accessible_pro_servers: list[WorkspaceRead] = [] try: client = ZenMLProClient(pro_api_url) accessible_pro_servers = client.workspace.list(member_only=not all) @@ -727,7 +726,7 @@ def server_list( # Figure out if the client is already connected to one of the # servers in the list - current_server: List[ServerCredentials] = [] + current_server: list[ServerCredentials] = [] if current_store_config.type == StoreType.REST: current_server = [ s for s in all_servers if s.url == current_store_config.url @@ -768,7 +767,7 @@ def server_list( "server. Only used when `--local` is set. Primarily used for accessing the " "local dashboard in Colab.", ) -def show(local: bool = False, ngrok_token: Optional[str] = None) -> None: +def show(local: bool = False, ngrok_token: str | None = None) -> None: """Show the ZenML dashboard. Args: diff --git a/src/zenml/cli/service_accounts.py b/src/zenml/cli/service_accounts.py index 9147b4f09c1..d13a41a5ba7 100644 --- a/src/zenml/cli/service_accounts.py +++ b/src/zenml/cli/service_accounts.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """CLI functionality to interact with API keys.""" -from typing import Any, Optional +from typing import Any import click @@ -33,9 +33,9 @@ def _create_api_key( service_account_name_or_id: str, name: str, - description: Optional[str], + description: str | None, set_key: bool = False, - output_file: Optional[str] = None, + output_file: str | None = None, ) -> None: """Create an API key. 
@@ -130,7 +130,7 @@ def create_service_account( description: str = "", create_api_key: bool = True, set_api_key: bool = False, - output_file: Optional[str] = None, + output_file: str | None = None, ) -> None: """Create a new service account. @@ -240,9 +240,9 @@ def list_service_accounts(ctx: click.Context, /, **kwargs: Any) -> None: ) def update_service_account( service_account_name_or_id: str, - updated_name: Optional[str] = None, - description: Optional[str] = None, - active: Optional[bool] = None, + updated_name: str | None = None, + description: str | None = None, + active: bool | None = None, ) -> None: """Update an existing service account. @@ -329,9 +329,9 @@ def api_key( def create_api_key( service_account_name_or_id: str, name: str, - description: Optional[str], + description: str | None, set_key: bool = False, - output_file: Optional[str] = None, + output_file: str | None = None, ) -> None: """Create an API key. @@ -432,9 +432,9 @@ def list_api_keys(service_account_name_or_id: str, /, **kwargs: Any) -> None: def update_api_key( service_account_name_or_id: str, name_or_id: str, - name: Optional[str] = None, - description: Optional[str] = None, - active: Optional[bool] = None, + name: str | None = None, + description: str | None = None, + active: bool | None = None, ) -> None: """Update an API key. @@ -487,7 +487,7 @@ def rotate_api_key( name_or_id: str, retain: int = 0, set_key: bool = False, - output_file: Optional[str] = None, + output_file: str | None = None, ) -> None: """Rotate an API key. 
diff --git a/src/zenml/cli/service_connectors.py b/src/zenml/cli/service_connectors.py index f4819b397a2..190e1bc7c3e 100644 --- a/src/zenml/cli/service_connectors.py +++ b/src/zenml/cli/service_connectors.py @@ -14,7 +14,7 @@ """Service connector CLI commands.""" from datetime import datetime -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, cast from uuid import UUID import click @@ -49,7 +49,7 @@ def service_connector() -> None: def prompt_connector_name( - default_name: Optional[str] = None, connector: Optional[UUID] = None + default_name: str | None = None, connector: UUID | None = None ) -> str: """Prompt the user for a service connector name. @@ -96,7 +96,7 @@ def prompt_connector_name( return name -def prompt_resource_type(available_resource_types: List[str]) -> Optional[str]: +def prompt_resource_type(available_resource_types: list[str]) -> str | None: """Prompt the user for a resource type. Args: @@ -139,8 +139,8 @@ def prompt_resource_type(available_resource_types: List[str]) -> Optional[str]: def prompt_resource_id( - resource_name: str, resource_ids: List[str] -) -> Optional[str]: + resource_name: str, resource_ids: list[str] +) -> str | None: """Prompt the user for a resource ID. Args: @@ -150,7 +150,7 @@ def prompt_resource_id( Returns: The resource ID provided by the user. """ - resource_id: Optional[str] = None + resource_id: str | None = None if resource_ids: resource_ids_list = "\n - " + "\n - ".join(resource_ids) prompt = ( @@ -201,9 +201,9 @@ def prompt_resource_id( def prompt_expiration_time( - min: Optional[int] = None, - max: Optional[int] = None, - default: Optional[int] = None, + min: int | None = None, + max: int | None = None, + default: int | None = None, ) -> int: """Prompt the user for an expiration time. 
@@ -269,8 +269,8 @@ def prompt_expiration_time( def prompt_expires_at( - default: Optional[datetime] = None, -) -> Optional[datetime]: + default: datetime | None = None, +) -> datetime | None: """Prompt the user for an expiration timestamp. Args: @@ -510,18 +510,18 @@ def prompt_expires_at( ) @click.argument("args", nargs=-1, type=click.UNPROCESSED) def register_service_connector( - name: Optional[str], - args: List[str], - description: Optional[str] = None, - connector_type: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - auth_method: Optional[str] = None, - expires_at: Optional[datetime] = None, - expires_skew_tolerance: Optional[int] = None, - expiration_seconds: Optional[int] = None, + name: str | None, + args: list[str], + description: str | None = None, + connector_type: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + auth_method: str | None = None, + expires_at: datetime | None = None, + expires_skew_tolerance: int | None = None, + expiration_seconds: int | None = None, no_verify: bool = False, - labels: Optional[List[str]] = None, + labels: list[str] | None = None, interactive: bool = False, no_docs: bool = False, show_secrets: bool = False, @@ -569,7 +569,7 @@ def register_service_connector( ) # Parse the given labels - parsed_labels = cast(Dict[str, str], cli_utils.get_parsed_labels(labels)) + parsed_labels = cast(dict[str, str], cli_utils.get_parsed_labels(labels)) if interactive: # Get the list of available service connector types @@ -668,10 +668,10 @@ def register_service_connector( else: auto_configure = False - connector_model: Optional[ - Union[ServiceConnectorRequest, ServiceConnectorResponse] - ] = None - connector_resources: Optional[ServiceConnectorResourcesModel] = None + connector_model: None | ( + ServiceConnectorRequest | ServiceConnectorResponse + ) = None + connector_resources: ServiceConnectorResourcesModel | None = None if auto_configure: # Try 
to autoconfigure the service connector try: @@ -975,7 +975,7 @@ def register_service_connector( ) @click.pass_context def list_service_connectors( - ctx: click.Context, /, labels: Optional[List[str]] = None, **kwargs: Any + ctx: click.Context, /, labels: list[str] | None = None, **kwargs: Any ) -> None: """List all service connectors. @@ -1079,8 +1079,8 @@ def describe_service_connector( name_id_or_prefix: str, show_secrets: bool = False, describe_client: bool = False, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, ) -> None: """Prints details about a service connector. @@ -1137,7 +1137,7 @@ def describe_service_connector( with console.status(f"Describing connector '{connector.name}'..."): active_stack = client.active_stack_model - active_connector_ids: List[UUID] = [] + active_connector_ids: list[UUID] = [] for components in active_stack.components.values(): active_connector_ids.extend( [ @@ -1315,22 +1315,22 @@ def describe_service_connector( ) @click.argument("args", nargs=-1, type=click.UNPROCESSED) def update_service_connector( - args: List[str], - name_id_or_prefix: Optional[str] = None, - name: Optional[str] = None, - description: Optional[str] = None, - connector_type: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - auth_method: Optional[str] = None, - expires_at: Optional[datetime] = None, - expires_skew_tolerance: Optional[int] = None, - expiration_seconds: Optional[int] = None, + args: list[str], + name_id_or_prefix: str | None = None, + name: str | None = None, + description: str | None = None, + connector_type: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + auth_method: str | None = None, + expires_at: datetime | None = None, + expires_skew_tolerance: int | None = None, + expiration_seconds: int | None = None, no_verify: bool = False, - labels: 
Optional[List[str]] = None, + labels: list[str] | None = None, interactive: bool = False, show_secrets: bool = False, - remove_attrs: Optional[List[str]] = None, + remove_attrs: list[str] | None = None, ) -> None: """Updates a service connector. @@ -1755,8 +1755,8 @@ def delete_service_connector(name_id_or_prefix: str) -> None: @click.argument("name_id_or_prefix", type=str, required=True) def verify_service_connector( name_id_or_prefix: str, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, verify_only: bool = False, ) -> None: """Verifies if a service connector has access to one or more resources. @@ -1847,8 +1847,8 @@ def verify_service_connector( @click.argument("name_id_or_prefix", type=str, required=True) def login_service_connector( name_id_or_prefix: str, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, ) -> None: """Authenticate the local client/SDK with connector credentials. @@ -1955,9 +1955,9 @@ def login_service_connector( is_flag=True, ) def list_service_connector_resources( - connector_type: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + connector_type: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, exclude_errors: bool = False, ) -> None: """List resources that can be accessed by service connectors. @@ -2068,9 +2068,9 @@ def list_service_connector_resources( is_flag=True, ) def list_service_connector_types( - type: Optional[str] = None, - resource_type: Optional[str] = None, - auth_method: Optional[str] = None, + type: str | None = None, + resource_type: str | None = None, + auth_method: str | None = None, detailed: bool = False, ) -> None: """List service connector types. 
@@ -2131,8 +2131,8 @@ def list_service_connector_types( ) def describe_service_connector_type( type: str, - resource_type: Optional[str] = None, - auth_method: Optional[str] = None, + resource_type: str | None = None, + auth_method: str | None = None, ) -> None: """Describes a service connector type. diff --git a/src/zenml/cli/stack.py b/src/zenml/cli/stack.py index 6e82a0a9d32..4aebe3b9286 100644 --- a/src/zenml/cli/stack.py +++ b/src/zenml/cli/stack.py @@ -20,11 +20,6 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, - Optional, - Set, - Union, ) from uuid import UUID @@ -243,24 +238,24 @@ def stack() -> None: ) def register_stack( stack_name: str, - artifact_store: Optional[str] = None, - orchestrator: Optional[str] = None, - container_registry: Optional[str] = None, - model_registry: Optional[str] = None, - step_operator: Optional[str] = None, - feature_store: Optional[str] = None, - model_deployer: Optional[str] = None, - experiment_tracker: Optional[str] = None, - alerter: Optional[str] = None, - annotator: Optional[str] = None, - data_validator: Optional[str] = None, - image_builder: Optional[str] = None, - deployer: Optional[str] = None, + artifact_store: str | None = None, + orchestrator: str | None = None, + container_registry: str | None = None, + model_registry: str | None = None, + step_operator: str | None = None, + feature_store: str | None = None, + model_deployer: str | None = None, + experiment_tracker: str | None = None, + alerter: str | None = None, + annotator: str | None = None, + data_validator: str | None = None, + image_builder: str | None = None, + deployer: str | None = None, set_stack: bool = False, - provider: Optional[str] = None, - connector: Optional[str] = None, - secrets: List[str] = [], - environment_variables: List[str] = [], + provider: str | None = None, + connector: str | None = None, + secrets: list[str] = [], + environment_variables: list[str] = [], ) -> None: """Register a stack. 
@@ -324,17 +319,17 @@ def register_stack( except KeyError: pass - environment: Dict[str, str] = {} + environment: dict[str, str] = {} for environment_variable in environment_variables: key, value = environment_variable.split("=", 1) environment[key] = value - labels: Dict[str, str] = {} - components: Dict[StackComponentType, List[Union[UUID, ComponentInfo]]] = {} + labels: dict[str, str] = {} + components: dict[StackComponentType, list[UUID | ComponentInfo]] = {} # Cloud Flow - created_objects: Set[str] = set() - service_connector: Optional[Union[UUID, ServiceConnectorInfo]] = None + created_objects: set[str] = set() + service_connector: UUID | ServiceConnectorInfo | None = None if provider is not None and connector is None: service_connector_response = None use_auto_configure = False @@ -367,7 +362,7 @@ def register_stack( show_default=True, ) - connector_selected: Optional[int] = None + connector_selected: int | None = None if not use_auto_configure: service_connector_response = None existing_connectors = client.list_service_connectors( @@ -438,7 +433,7 @@ def register_stack( (StackComponentType.CONTAINER_REGISTRY, container_registry), ) for component_type, preset_name in needed_components: - component_info: Optional[Union[UUID, ComponentInfo]] = None + component_info: UUID | ComponentInfo | None = None if preset_name is not None: component_response = client.get_stack_component( component_type, preset_name @@ -461,7 +456,7 @@ def register_stack( ] # if some existing components are found - prompt user what to do - component_selected: Optional[int] = None + component_selected: int | None = None component_selected = cli_utils.multi_choice_prompt( object_type=component_type.value.replace("_", " "), choices=[ @@ -733,23 +728,23 @@ def register_stack( multiple=True, ) def update_stack( - stack_name_or_id: Optional[str] = None, - artifact_store: Optional[str] = None, - orchestrator: Optional[str] = None, - container_registry: Optional[str] = None, - step_operator: 
Optional[str] = None, - feature_store: Optional[str] = None, - model_deployer: Optional[str] = None, - experiment_tracker: Optional[str] = None, - alerter: Optional[str] = None, - annotator: Optional[str] = None, - data_validator: Optional[str] = None, - image_builder: Optional[str] = None, - model_registry: Optional[str] = None, - deployer: Optional[str] = None, - secrets: List[str] = [], - remove_secrets: List[str] = [], - environment_variables: List[str] = [], + stack_name_or_id: str | None = None, + artifact_store: str | None = None, + orchestrator: str | None = None, + container_registry: str | None = None, + step_operator: str | None = None, + feature_store: str | None = None, + model_deployer: str | None = None, + experiment_tracker: str | None = None, + alerter: str | None = None, + annotator: str | None = None, + data_validator: str | None = None, + image_builder: str | None = None, + model_registry: str | None = None, + deployer: str | None = None, + secrets: list[str] = [], + remove_secrets: list[str] = [], + environment_variables: list[str] = [], ) -> None: """Update a stack. 
@@ -776,7 +771,7 @@ def update_stack( """ client = Client() - environment: Dict[str, Any] = {} + environment: dict[str, Any] = {} for environment_variable in environment_variables: key, value = environment_variable.split("=", 1) # Fallback to None if the value is empty so the existing environment @@ -784,7 +779,7 @@ def update_stack( environment[key] = value or None with console.status("Updating stack...\n"): - updates: Dict[StackComponentType, List[Union[str, UUID]]] = dict() + updates: dict[StackComponentType, list[str | UUID]] = dict() if artifact_store: updates[StackComponentType.ARTIFACT_STORE] = [artifact_store] if alerter: @@ -929,18 +924,18 @@ def update_stack( required=False, ) def remove_stack_component( - stack_name_or_id: Optional[str] = None, - container_registry_flag: Optional[bool] = False, - step_operator_flag: Optional[bool] = False, - feature_store_flag: Optional[bool] = False, - model_deployer_flag: Optional[bool] = False, - experiment_tracker_flag: Optional[bool] = False, - alerter_flag: Optional[bool] = False, - annotator_flag: Optional[bool] = False, - data_validator_flag: Optional[bool] = False, - image_builder_flag: Optional[bool] = False, - model_registry_flag: Optional[str] = None, - deployer_flag: Optional[bool] = False, + stack_name_or_id: str | None = None, + container_registry_flag: bool | None = False, + step_operator_flag: bool | None = False, + feature_store_flag: bool | None = False, + model_deployer_flag: bool | None = False, + experiment_tracker_flag: bool | None = False, + alerter_flag: bool | None = False, + annotator_flag: bool | None = False, + data_validator_flag: bool | None = False, + image_builder_flag: bool | None = False, + model_registry_flag: str | None = None, + deployer_flag: bool | None = False, ) -> None: """Remove stack components from a stack. 
@@ -963,7 +958,7 @@ def remove_stack_component( client = Client() with console.status("Updating the stack...\n"): - stack_component_update: Dict[StackComponentType, List[Any]] = dict() + stack_component_update: dict[StackComponentType, list[Any]] = dict() if container_registry_flag: stack_component_update[StackComponentType.CONTAINER_REGISTRY] = [] @@ -1074,7 +1069,7 @@ def list_stacks(ctx: click.Context, /, **kwargs: Any) -> None: type=click.STRING, required=False, ) -def describe_stack(stack_name_or_id: Optional[str] = None) -> None: +def describe_stack(stack_name_or_id: str | None = None) -> None: """Show details about a named stack or the active stack. Args: @@ -1201,8 +1196,8 @@ def get_active_stack() -> None: @click.argument("stack_name_or_id", type=str, required=False) @click.argument("filename", type=str, required=False) def export_stack( - stack_name_or_id: Optional[str] = None, - filename: Optional[str] = None, + stack_name_or_id: str | None = None, + filename: str | None = None, ) -> None: """Export a stack to YAML. @@ -1232,7 +1227,7 @@ def export_stack( def _import_stack_component( component_type: StackComponentType, - component_dict: Dict[str, Any], + component_dict: dict[str, Any], ) -> UUID: """Import a single stack component with given type/config. @@ -1289,7 +1284,7 @@ def _import_stack_component( ) def import_stack( stack_name: str, - filename: Optional[str], + filename: str | None, ignore_version_mismatch: bool = False, ) -> None: """Import a stack from YAML. 
@@ -1385,7 +1380,7 @@ def copy_stack(source_stack_name_or_id: str, target_stack: str) -> None: except KeyError as err: cli_utils.exception(err) - component_mapping: Dict[StackComponentType, Union[str, UUID]] = {} + component_mapping: dict[StackComponentType, str | UUID] = {} for c_type, c_list in stack_to_copy.components.items(): if c_list: @@ -1414,7 +1409,7 @@ def copy_stack(source_stack_name_or_id: str, target_stack: str) -> None: ) def register_secrets( skip_existing: bool, - stack_name_or_id: Optional[str] = None, + stack_name_or_id: str | None = None, ) -> None: """Interactively registers all required secrets for a stack. @@ -1588,8 +1583,8 @@ def validate_name(ctx: click.Context, param: str, value: str) -> str: def deploy( ctx: click.Context, provider: str, - stack_name: Optional[str] = None, - location: Optional[str] = None, + stack_name: str | None = None, + location: str | None = None, set_stack: bool = False, ) -> None: """Deploy and register a fully functional cloud ZenML stack. @@ -1796,8 +1791,8 @@ def deploy( type=click.BOOL, ) def connect_stack( - stack_name_or_id: Optional[str] = None, - connector: Optional[str] = None, + stack_name_or_id: str | None = None, + connector: str | None = None, interactive: bool = False, no_verify: bool = False, ) -> None: @@ -1827,9 +1822,9 @@ def connect_stack( def _get_service_connector_info( cloud_provider: str, - connector_details: Optional[ - Union[ServiceConnectorResponse, ServiceConnectorRequest] - ], + connector_details: None | ( + ServiceConnectorResponse | ServiceConnectorRequest + ), ) -> ServiceConnectorInfo: """Get a service connector info with given cloud provider. @@ -1912,7 +1907,7 @@ def _get_stack_component_info( component_type: str, cloud_provider: str, resources_info: ServiceConnectorResourcesInfo, - service_connector_index: Optional[int] = None, + service_connector_index: int | None = None, ) -> ComponentInfo: """Get a stack component info with given type and service connector. 
@@ -2072,8 +2067,8 @@ def query_region( "only valid if the output file is provided.", ) def export_requirements( - stack_name_or_id: Optional[str] = None, - output_file: Optional[str] = None, + stack_name_or_id: str | None = None, + output_file: str | None = None, overwrite: bool = False, ) -> None: """Exports stack requirements so they can be installed using pip. diff --git a/src/zenml/cli/stack_components.py b/src/zenml/cli/stack_components.py index 87e435b6d1a..1d9e6cf75d2 100644 --- a/src/zenml/cli/stack_components.py +++ b/src/zenml/cli/stack_components.py @@ -15,7 +15,8 @@ import time from importlib import import_module -from typing import Any, Callable, List, Optional, Tuple, cast +from typing import Any, cast +from collections.abc import Callable from uuid import UUID import click @@ -184,7 +185,7 @@ def list_stack_components_command( def generate_stack_component_register_command( component_type: StackComponentType, -) -> Callable[[str, str, List[str]], None]: +) -> Callable[[str, str, list[str]], None]: """Generates a `register` command for the specific stack component type. Args: @@ -253,12 +254,12 @@ def generate_stack_component_register_command( def register_stack_component_command( name: str, flavor: str, - args: List[str], - labels: Optional[List[str]] = None, - connector: Optional[str] = None, - resource_id: Optional[str] = None, - secrets: List[str] = [], - environment_variables: List[str] = [], + args: list[str], + labels: list[str] | None = None, + connector: str | None = None, + resource_id: str | None = None, + secrets: list[str] = [], + environment_variables: list[str] = [], ) -> None: """Registers a stack component. @@ -328,7 +329,7 @@ def register_stack_component_command( def generate_stack_component_update_command( component_type: StackComponentType, -) -> Callable[[str, List[str]], None]: +) -> Callable[[str, list[str]], None]: """Generates an `update` command for the specific stack component type. 
Args: @@ -380,12 +381,12 @@ def generate_stack_component_update_command( ) @click.argument("args", nargs=-1, type=click.UNPROCESSED) def update_stack_component_command( - name_id_or_prefix: Optional[str], - args: List[str], - labels: Optional[List[str]] = None, - secrets: List[str] = [], - remove_secrets: List[str] = [], - environment_variables: List[str] = [], + name_id_or_prefix: str | None, + args: list[str], + labels: list[str] | None = None, + secrets: list[str] = [], + remove_secrets: list[str] = [], + environment_variables: list[str] = [], ) -> None: """Updates a stack component. @@ -445,7 +446,7 @@ def update_stack_component_command( def generate_stack_component_remove_attribute_command( component_type: StackComponentType, -) -> Callable[[str, List[str]], None]: +) -> Callable[[str, list[str]], None]: """Generates `remove_attribute` command for a specific stack component type. Args: @@ -471,8 +472,8 @@ def generate_stack_component_remove_attribute_command( @click.argument("args", nargs=-1, type=click.UNPROCESSED) def remove_attribute_stack_component_command( name_id_or_prefix: str, - args: List[str], - labels: Optional[List[str]] = None, + args: list[str], + labels: list[str] | None = None, ) -> None: """Removes one or more attributes from a stack component. 
@@ -719,7 +720,7 @@ def stack_component_logs_command( if follow: try: - with open(log_file, "r") as f: + with open(log_file) as f: # seek to the end of the file f.seek(0, 2) @@ -733,7 +734,7 @@ def stack_component_logs_command( except KeyboardInterrupt: cli_utils.declare(f"Stopped following {display_name} logs.") else: - with open(log_file, "r") as f: + with open(log_file) as f: click.echo(f.read()) return stack_component_logs_command @@ -964,7 +965,7 @@ def delete_stack_component_flavor_command(name_or_id: str) -> None: def prompt_select_resource_id( - resource_ids: List[str], + resource_ids: list[str], resource_name: str, interactive: bool = True, ) -> str: @@ -1013,8 +1014,8 @@ def prompt_select_resource_id( def prompt_select_resource( - resource_list: List[ServiceConnectorResourcesModel], -) -> Tuple[UUID, str]: + resource_list: list[ServiceConnectorResourcesModel], +) -> tuple[UUID, str]: """Prompts the user to select a resource ID from a list of resources. Args: @@ -1144,9 +1145,9 @@ def generate_stack_component_connect_command( type=click.BOOL, ) def connect_stack_component_command( - name_id_or_prefix: Optional[str], - connector: Optional[str] = None, - resource_id: Optional[str] = None, + name_id_or_prefix: str | None, + connector: str | None = None, + resource_id: str | None = None, interactive: bool = False, no_verify: bool = False, ) -> None: @@ -1393,9 +1394,9 @@ def register_all_stack_component_cli_commands() -> None: def connect_stack_component_with_service_connector( component_type: StackComponentType, - name_id_or_prefix: Optional[str] = None, - connector: Optional[str] = None, - resource_id: Optional[str] = None, + name_id_or_prefix: str | None = None, + connector: str | None = None, + resource_id: str | None = None, interactive: bool = False, no_verify: bool = False, ) -> None: @@ -1545,7 +1546,7 @@ def connect_stack_component_with_service_connector( "to select a resource interactively." 
) - connector_resources: Optional[ServiceConnectorResourcesModel] = None + connector_resources: ServiceConnectorResourcesModel | None = None if not no_verify: with console.status( "Validating service connector resource configuration...\n" diff --git a/src/zenml/cli/tag.py b/src/zenml/cli/tag.py index a45b17293be..dd6c9ef180c 100644 --- a/src/zenml/cli/tag.py +++ b/src/zenml/cli/tag.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """CLI functionality to interact with tags.""" -from typing import Any, Optional, Union +from typing import Any from uuid import UUID import click @@ -72,7 +72,7 @@ def list_tags(**kwargs: Any) -> None: type=click.Choice(choices=ColorVariants.values()), required=False, ) -def register_tag(name: str, color: Optional[ColorVariants]) -> None: +def register_tag(name: str, color: ColorVariants | None) -> None: """Register a new model in the Model Control Plane. Args: @@ -108,7 +108,7 @@ def register_tag(name: str, color: Optional[ColorVariants]) -> None: required=False, ) def update_tag( - tag_name_or_id: Union[str, UUID], name: Optional[str], color: Optional[str] + tag_name_or_id: str | UUID, name: str | None, color: str | None ) -> None: """Register a new model in the Model Control Plane. @@ -142,7 +142,7 @@ def update_tag( help="Don't ask for confirmation.", ) def delete_tag( - tag_name_or_id: Union[str, UUID], + tag_name_or_id: str | UUID, yes: bool = False, ) -> None: """Delete an existing tag. diff --git a/src/zenml/cli/text_utils.py b/src/zenml/cli/text_utils.py index 95693712f9f..ef9a4f42443 100644 --- a/src/zenml/cli/text_utils.py +++ b/src/zenml/cli/text_utils.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Utilities for CLI output.""" -from typing import List from rich.console import Console, ConsoleOptions, RenderResult from rich.markdown import Heading, Markdown @@ -61,7 +60,7 @@ ) -def zenml_go_notebook_tutorial_message(ipynb_files: List[str]) -> Markdown: +def zenml_go_notebook_tutorial_message(ipynb_files: list[str]) -> Markdown: """Outputs a message to the user about the `zenml go` tutorial. Args: diff --git a/src/zenml/cli/user_management.py b/src/zenml/cli/user_management.py index 71abbc371ca..e3107073b8e 100644 --- a/src/zenml/cli/user_management.py +++ b/src/zenml/cli/user_management.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Functionality to administer users of the ZenML CLI and server.""" -from typing import Any, Optional +from typing import Any import click @@ -39,7 +39,7 @@ def user() -> None: @user.command("describe") @click.argument("user_name_or_id", type=str, required=False) -def describe_user(user_name_or_id: Optional[str] = None) -> None: +def describe_user(user_name_or_id: str | None = None) -> None: """Get the user. Args: @@ -137,7 +137,7 @@ def list_users(ctx: click.Context, /, **kwargs: Any) -> None: ) def create_user( user_name: str, - password: Optional[str] = None, + password: str | None = None, is_admin: bool = False, ) -> None: """Create a new user. @@ -246,12 +246,12 @@ def create_user( ) def update_user( user_name_or_id: str, - updated_name: Optional[str] = None, - updated_full_name: Optional[str] = None, - updated_email: Optional[str] = None, - make_admin: Optional[bool] = None, - make_user: Optional[bool] = None, - active: Optional[bool] = None, + updated_name: str | None = None, + updated_full_name: str | None = None, + updated_email: str | None = None, + make_admin: bool | None = None, + make_user: bool | None = None, + active: bool | None = None, ) -> None: """Update an existing user. 
@@ -321,7 +321,7 @@ def update_user( type=str, ) def change_user_password( - password: Optional[str] = None, old_password: Optional[str] = None + password: str | None = None, old_password: str | None = None ) -> None: """Change the password of the current user. diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index 0baae16ca6a..50b39363cb9 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -27,20 +27,13 @@ TYPE_CHECKING, AbstractSet, Any, - Callable, - Dict, - Iterator, List, NoReturn, - Optional, - Sequence, - Set, - Tuple, - Type, TypeVar, Union, cast, ) +from collections.abc import Callable, Iterator, Sequence import click import yaml @@ -154,8 +147,8 @@ def confirmation(text: str, *args: Any, **kwargs: Any) -> bool: def declare( text: Union[str, "Text"], - bold: Optional[bool] = None, - italic: Optional[bool] = None, + bold: bool | None = None, + italic: bool | None = None, **kwargs: Any, ) -> None: """Echo a declaration on the CLI. @@ -185,7 +178,7 @@ def error(text: str) -> NoReturn: # Create a custom ClickException that bypasses Click's default "Error: " prefix class StyledClickException(click.ClickException): - def show(self, file: Optional[IO[Any]] = None) -> None: + def show(self, file: IO[Any] | None = None) -> None: if file is None: file = click.get_text_stream("stderr") # Print our custom styled message directly without Click's prefix @@ -221,8 +214,8 @@ def exception(exception: Exception) -> NoReturn: def warning( text: str, - bold: Optional[bool] = None, - italic: Optional[bool] = None, + bold: bool | None = None, + italic: bool | None = None, **kwargs: Any, ) -> None: """Echo a warning string on the CLI. @@ -240,8 +233,8 @@ def warning( def success( text: str, - bold: Optional[bool] = None, - italic: Optional[bool] = None, + bold: bool | None = None, + italic: bool | None = None, **kwargs: Any, ) -> None: """Echo a success string on the CLI. 
@@ -279,9 +272,9 @@ def print_markdown_with_pager(text: str) -> None: def print_table( - obj: List[Dict[str, Any]], - title: Optional[str] = None, - caption: Optional[str] = None, + obj: list[dict[str, Any]], + title: str | None = None, + caption: str | None = None, **columns: table.Column, ) -> None: """Prints the list of dicts in a table format. @@ -326,7 +319,7 @@ def print_table( ): # Display the URL as a hyperlink in a way that doesn't break # the URL when it needs to be wrapped over multiple lines - value: Union[str, Text] = Text(v, style=f"link {v}") + value: str | Text = Text(v, style=f"link {v}") else: value = str(v) # Escape text when square brackets are used, but allow @@ -341,12 +334,12 @@ def print_table( def print_pydantic_models( - models: Union[Page[T], List[T]], - columns: Optional[List[str]] = None, - exclude_columns: Optional[List[str]] = None, - active_models: Optional[List[T]] = None, + models: Page[T] | list[T], + columns: list[str] | None = None, + exclude_columns: list[str] | None = None, + active_models: list[T] | None = None, show_active: bool = False, - rename_columns: Dict[str, str] = {}, + rename_columns: dict[str, str] = {}, ) -> None: """Prints the list of Pydantic models in a table. @@ -369,7 +362,7 @@ def print_pydantic_models( show_active_column = False active_models = list() - def __dictify(model: T) -> Dict[str, str]: + def __dictify(model: T) -> dict[str, str]: """Helper function to map over the list to turn Models into dicts. 
Args: @@ -414,7 +407,7 @@ def __dictify(model: T) -> Dict[str, str]: else: include_columns = columns - items: Dict[str, Any] = {} + items: dict[str, Any] = {} for k in include_columns: value = getattr(model, k) @@ -440,7 +433,7 @@ def __dictify(model: T) -> Dict[str, str]: ) else: items.setdefault(k, []).append(str(v.id)) - elif isinstance(value, Set) or isinstance(value, List): + elif isinstance(value, set) or isinstance(value, list): items[k] = [str(v) for v in value] else: items[k] = str(value) @@ -497,8 +490,8 @@ def __dictify(model: T) -> Dict[str, str]: def print_pydantic_model( title: str, model: BaseModel, - exclude_columns: Optional[AbstractSet[str]] = None, - columns: Optional[AbstractSet[str]] = None, + exclude_columns: AbstractSet[str] | None = None, + columns: AbstractSet[str] | None = None, ) -> None: """Prints a single Pydantic model in a table. @@ -555,7 +548,7 @@ def print_pydantic_model( else: include_columns = list(columns) - items: Dict[str, Any] = {} + items: dict[str, Any] = {} for k in include_columns: value = getattr(model, k) @@ -576,7 +569,7 @@ def print_pydantic_model( items.setdefault(k, []).append(str(v.id)) items[k] = str(items[k]) - elif isinstance(value, Set) or isinstance(value, List): + elif isinstance(value, set) or isinstance(value, list): items[k] = str([str(v) for v in value]) else: items[k] = str(value) @@ -588,8 +581,8 @@ def print_pydantic_model( def format_integration_list( - integrations: List[Tuple[str, Type["Integration"]]], -) -> List[Dict[str, str]]: + integrations: list[tuple[str, type["Integration"]]], +) -> list[dict[str, str]]: """Formats a list of integrations into a List of Dicts. 
This list of dicts can then be printed in a table style using @@ -693,7 +686,7 @@ def print_flavor_list(flavors: Page["FlavorResponse"]) -> None: def print_stack_component_configuration( component: "ComponentResponse", active_status: bool, - connector_requirements: Optional[ServiceConnectorRequirements] = None, + connector_requirements: ServiceConnectorRequirements | None = None, ) -> None: """Prints the configuration options of a stack component. @@ -837,7 +830,7 @@ def expand_argument_value_from_file(name: str, value: str) -> str: f"{MAX_ARGUMENT_VALUE_SIZE} bytes)." ) - with open(filename, "r") as f: + with open(filename) as f: return f.read() except OSError as e: raise ValueError( @@ -846,7 +839,7 @@ def expand_argument_value_from_file(name: str, value: str) -> str: ) -def convert_structured_str_to_dict(string: str) -> Dict[str, str]: +def convert_structured_str_to_dict(string: str) -> dict[str, str]: """Convert a structured string (JSON or YAML) into a dict. Examples: @@ -864,7 +857,7 @@ def convert_structured_str_to_dict(string: str) -> Dict[str, str]: dict_: dict from structured JSON or YAML str """ try: - dict_: Dict[str, str] = json.loads(string) + dict_: dict[str, str] = json.loads(string) return dict_ except ValueError: pass @@ -882,10 +875,10 @@ def convert_structured_str_to_dict(string: str) -> Dict[str, str]: def parse_name_and_extra_arguments( - args: List[str], + args: list[str], expand_args: bool = False, name_mandatory: bool = True, -) -> Tuple[Optional[str], Dict[str, str]]: +) -> tuple[str | None, dict[str, str]]: """Parse a name and extra arguments from the CLI. This is a utility function used to parse a variable list of optional CLI @@ -914,7 +907,7 @@ def parse_name_and_extra_arguments( Returns: The name and a dict of parsed args. """ - name: Optional[str] = None + name: str | None = None # The name was not supplied as the first argument, we have to # search the other arguments for the name. 
for i, arg in enumerate(args): @@ -937,7 +930,7 @@ def parse_name_and_extra_arguments( "identifier as the key and the following structure: " '--custom_argument="value"' ) - args_dict: Dict[str, str] = {} + args_dict: dict[str, str] = {} for a in args: if not a: # Skip empty arguments. @@ -968,7 +961,7 @@ def validate_keys(key: str) -> None: error("Please provide args with a proper identifier as the key.") -def parse_unknown_component_attributes(args: List[str]) -> List[str]: +def parse_unknown_component_attributes(args: list[str]) -> list[str]: """Parse unknown options from the CLI. Args: @@ -990,10 +983,10 @@ def parse_unknown_component_attributes(args: List[str]) -> List[str]: def prompt_configuration( - config_schema: Dict[str, Any], + config_schema: dict[str, Any], show_secrets: bool = False, - existing_config: Optional[Dict[str, str]] = None, -) -> Dict[str, str]: + existing_config: dict[str, str] | None = None, +) -> dict[str, str]: """Prompt the user for configuration values using the provided schema. Args: @@ -1018,7 +1011,7 @@ def prompt_configuration( title = f"[{attr_name}] {title}" required = attr_name in config_schema.get("required", []) hidden = attr_schema.get("format", "") == "password" - subtitles: List[str] = [] + subtitles: list[str] = [] subtitles.append(attr_type_name) if hidden: subtitles.append("secret") @@ -1098,7 +1091,7 @@ def prompt_configuration( def install_packages( - packages: List[str], + packages: list[str], upgrade: bool = False, use_uv: bool = False, ) -> None: @@ -1210,7 +1203,7 @@ def is_pip_installed() -> bool: def pretty_print_secret( - secret: Dict[str, str], + secret: dict[str, str], hide_secret: bool = True, ) -> None: """Print all key-value pairs associated with a secret. 
@@ -1220,7 +1213,7 @@ def pretty_print_secret( hide_secret: boolean that configures if the secret values are shown on the CLI """ - title: Optional[str] = None + title: str | None = None def get_secret_value(value: Any) -> str: if value is None: @@ -1238,7 +1231,7 @@ def get_secret_value(value: Any) -> str: print_table(stack_dicts, title=title) -def print_list_items(list_items: List[str], column_title: str) -> None: +def print_list_items(list_items: list[str], column_title: str) -> None: """Prints the configuration options of a stack. Args: @@ -1283,7 +1276,7 @@ def get_service_state_emoji(state: "ServiceState") -> str: def pretty_print_model_deployer( - model_services: List["BaseService"], model_deployer: "BaseModelDeployer" + model_services: list["BaseService"], model_deployer: "BaseModelDeployer" ) -> None: """Given a list of served_models, print all associated key-value pairs. @@ -1316,7 +1309,7 @@ def pretty_print_model_deployer( def pretty_print_registered_model_table( - registered_models: List["RegisteredModel"], + registered_models: list["RegisteredModel"], ) -> None: """Given a list of registered_models, print all associated key-value pairs. @@ -1337,7 +1330,7 @@ def pretty_print_registered_model_table( def pretty_print_model_version_table( - model_versions: List["RegistryModelVersion"], + model_versions: list["RegistryModelVersion"], ) -> None: """Given a list of model_versions, print all associated key-value pairs. @@ -1462,7 +1455,7 @@ def print_served_model_configuration( console.print(rich_table) -def describe_pydantic_object(schema_json: Dict[str, Any]) -> None: +def describe_pydantic_object(schema_json: dict[str, Any]) -> None: """Describes a Pydantic object based on the dict-representation of its schema. 
Args: @@ -1658,7 +1651,7 @@ def print_service_connectors_table( if len(connectors) == 0: return - active_connectors: List["ServiceConnectorResponse"] = [] + active_connectors: list["ServiceConnectorResponse"] = [] for components in client.active_stack_model.components.values(): for component in components: if component.connector: @@ -1718,7 +1711,7 @@ def print_service_connectors_table( def print_service_connector_resource_table( - resources: List["ServiceConnectorResourcesModel"], + resources: list["ServiceConnectorResourcesModel"], show_resources_only: bool = False, ) -> None: """Prints a table with details for a list of service connector resources. @@ -1730,7 +1723,7 @@ def print_service_connector_resource_table( resource_table = [] for resource_model in resources: printed_connector = False - resource_row: Dict[str, Any] = {} + resource_row: dict[str, Any] = {} if resource_model.error: # Global error @@ -2063,7 +2056,7 @@ def print_service_connector_auth_method( ) message += f"{auth_method.description}\n" - attributes: List[str] = [] + attributes: list[str] = [] for attr_name, attr_schema in auth_method.config_schema.get( "properties", {} ).items(): @@ -2071,7 +2064,7 @@ def print_service_connector_auth_method( attr_type = attr_schema.get("type", "string") required = attr_name in auth_method.config_schema.get("required", []) hidden = attr_schema.get("format", "") == "password" - subtitles: List[str] = [] + subtitles: list[str] = [] subtitles.append(attr_type) if hidden: subtitles.append("secret") @@ -2201,7 +2194,7 @@ def _get_stack_components( return list(stack.components.values()) -def _scrub_secret(config: StackComponentConfig) -> Dict[str, Any]: +def _scrub_secret(config: StackComponentConfig) -> dict[str, Any]: """Remove secret values from a configuration. 
Args: @@ -2345,7 +2338,7 @@ def print_pipeline_runs_table( def fetch_snapshot( snapshot_name_or_id: str, - pipeline_name_or_id: Optional[str] = None, + pipeline_name_or_id: str | None = None, ) -> "PipelineSnapshotResponse": """Fetch a snapshot by name or ID. @@ -2393,7 +2386,7 @@ def fetch_snapshot( def get_deployment_status_emoji( - status: Optional[str], + status: str | None, ) -> str: """Returns an emoji representing the given deployment status. @@ -2415,7 +2408,7 @@ def get_deployment_status_emoji( return ":question:" -def format_deployment_status(status: Optional[str]) -> str: +def format_deployment_status(status: str | None) -> str: """Format deployment status with color. Args: @@ -2695,7 +2688,7 @@ def print_page_info(page: Page[T]) -> None: F = TypeVar("F", bound=Callable[..., None]) -def create_filter_help_text(filter_model: Type[BaseFilter], field: str) -> str: +def create_filter_help_text(filter_model: type[BaseFilter], field: str) -> str: """Create the help text used in the click option help text. Args: @@ -2747,7 +2740,7 @@ def create_filter_help_text(filter_model: Type[BaseFilter], field: str) -> str: def create_data_type_help_text( - filter_model: Type[BaseFilter], field: str + filter_model: type[BaseFilter], field: str ) -> str: """Create a general help text for a fields datatype. @@ -2805,7 +2798,7 @@ def _is_list_field(field_info: Any) -> bool: ) -def list_options(filter_model: Type[BaseFilter]) -> Callable[[F], F]: +def list_options(filter_model: type[BaseFilter]) -> Callable[[F], F]: """Create a decorator to generate the correct list of filter parameters. The Outer decorator (`list_options`) is responsible for creating the inner @@ -2904,7 +2897,7 @@ def temporary_active_stack( Client().activate_stack(old_stack_id) -def print_user_info(info: Dict[str, Any]) -> None: +def print_user_info(info: dict[str, Any]) -> None: """Print user information to the terminal. 
Args: @@ -2918,8 +2911,8 @@ def print_user_info(info: Dict[str, Any]) -> None: def get_parsed_labels( - labels: Optional[List[str]], allow_label_only: bool = False -) -> Dict[str, Optional[str]]: + labels: list[str] | None, allow_label_only: bool = False +) -> dict[str, str | None]: """Parse labels into a dictionary. Args: @@ -2975,7 +2968,7 @@ def is_sorted_or_filtered(ctx: click.Context) -> bool: return False -def print_model_url(url: Optional[str]) -> None: +def print_model_url(url: str | None) -> None: """Pretty prints a given URL on the CLI. Args: @@ -3003,12 +2996,12 @@ def is_jupyter_installed() -> bool: def multi_choice_prompt( object_type: str, - choices: List[List[Any]], - headers: List[str], + choices: list[list[Any]], + headers: list[str], prompt_text: str, allow_zero_be_a_new_object: bool = False, - default_choice: Optional[str] = None, -) -> Optional[int]: + default_choice: str | None = None, +) -> int | None: """Prompts the user to select a choice from a list of choices. Args: diff --git a/src/zenml/client.py b/src/zenml/client.py index 3972004d963..476f2bca59e 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -22,19 +22,11 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - Generator, - List, - Mapping, Optional, - Sequence, - Tuple, - Type, TypeVar, - Union, cast, ) +from collections.abc import Callable, Generator, Mapping, Sequence from uuid import UUID from pydantic import ConfigDict, SecretStr @@ -229,8 +221,8 @@ class ClientConfiguration(FileSyncModel): """Pydantic object used for serializing client configuration options.""" _active_project: Optional["ProjectResponse"] = None - active_project_id: Optional[UUID] = None - active_stack_id: Optional[UUID] = None + active_project_id: UUID | None = None + active_stack_id: UUID | None = None _active_stack: Optional["StackResponse"] = None @property @@ -365,7 +357,7 @@ class Client(metaclass=ClientMetaClass): def __init__( self, - root: Optional[Path] = None, + root: Path | 
None = None, ) -> None: """Initializes the global client instance. @@ -392,8 +384,8 @@ def __init__( current working directory. Only used to initialize new clients internally. """ - self._root: Optional[Path] = None - self._config: Optional[ClientConfiguration] = None + self._root: Path | None = None + self._config: ClientConfiguration | None = None self._set_active_root(root) @@ -420,7 +412,7 @@ def _reset_instance(cls, client: Optional["Client"] = None) -> None: """ cls._global_client = client - def _set_active_root(self, root: Optional[Path] = None) -> None: + def _set_active_root(self, root: Path | None = None) -> None: """Set the supplied path as the repository root. If a client configuration is found at the given path or the @@ -454,7 +446,7 @@ def _set_active_root(self, root: Optional[Path] = None) -> None: # settings self._sanitize_config() - def _config_path(self) -> Optional[str]: + def _config_path(self) -> str | None: """Path to the client configuration file. Returns: @@ -484,7 +476,7 @@ def _sanitize_config(self) -> None: if active_project: self._config.set_active_project(active_project) - def _load_config(self) -> Optional[ClientConfiguration]: + def _load_config(self) -> ClientConfiguration | None: """Loads the client configuration from disk. This happens if the client has an active root and the configuration @@ -513,7 +505,7 @@ def _load_config(self) -> Optional[ClientConfiguration]: @staticmethod def initialize( - root: Optional[Path] = None, + root: Path | None = None, ) -> None: """Initializes a new ZenML repository at the given path. @@ -563,8 +555,8 @@ def is_repository_directory(path: Path) -> bool: @staticmethod def find_repository( - path: Optional[Path] = None, enable_warnings: bool = False - ) -> Optional[Path]: + path: Path | None = None, enable_warnings: bool = False + ) -> Path | None: """Search for a ZenML repository directory. Args: @@ -610,7 +602,7 @@ def find_repository( f"repository, run `zenml init`." 
) - def _find_repository_helper(path_: Path) -> Optional[Path]: + def _find_repository_helper(path_: Path) -> Path | None: """Recursively search parent directories for a ZenML repository. Args: @@ -661,7 +653,7 @@ def zen_store(self) -> "BaseZenStore": return GlobalConfiguration().zen_store @property - def root(self) -> Optional[Path]: + def root(self) -> Path | None: """The root directory of this client. Returns: @@ -671,7 +663,7 @@ def root(self) -> Optional[Path]: return self._root @property - def config_directory(self) -> Optional[Path]: + def config_directory(self) -> Path | None: """The configuration directory of this client. Returns: @@ -680,7 +672,7 @@ def config_directory(self) -> Optional[Path]: """ return self.root / REPOSITORY_DIRECTORY_NAME if self.root else None - def activate_root(self, root: Optional[Path] = None) -> None: + def activate_root(self, root: Path | None = None) -> None: """Set the active repository root directory. Args: @@ -693,7 +685,7 @@ def activate_root(self, root: Optional[Path] = None) -> None: self._set_active_root(root) def set_active_project( - self, project_name_or_id: Union[str, UUID] + self, project_name_or_id: str | UUID ) -> "ProjectResponse": """Set the project for the local client. 
@@ -733,12 +725,12 @@ def get_settings(self, hydrate: bool = True) -> ServerSettingsResponse: def update_server_settings( self, - updated_name: Optional[str] = None, - updated_logo_url: Optional[str] = None, - updated_enable_analytics: Optional[bool] = None, - updated_enable_announcements: Optional[bool] = None, - updated_enable_updates: Optional[bool] = None, - updated_onboarding_state: Optional[Dict[str, Any]] = None, + updated_name: str | None = None, + updated_logo_url: str | None = None, + updated_enable_analytics: bool | None = None, + updated_enable_announcements: bool | None = None, + updated_enable_updates: bool | None = None, + updated_onboarding_state: dict[str, Any] | None = None, ) -> ServerSettingsResponse: """Update the server settings. @@ -771,7 +763,7 @@ def update_server_settings( def create_user( self, name: str, - password: Optional[str] = None, + password: str | None = None, is_admin: bool = False, ) -> UserResponse: """Create a new user. @@ -797,7 +789,7 @@ def create_user( def get_user( self, - name_id_or_prefix: Union[str, UUID], + name_id_or_prefix: str | UUID, allow_name_prefix_match: bool = True, hydrate: bool = True, ) -> UserResponse: @@ -826,15 +818,15 @@ def list_users( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - external_user_id: Optional[str] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - full_name: Optional[str] = None, - email: Optional[str] = None, - active: Optional[bool] = None, - email_opted_in: Optional[bool] = None, + id: UUID | str | None = None, + external_user_id: str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + full_name: str | None = None, + email: str | None = None, + active: bool | None = None, + email_opted_in: bool | 
None = None, hydrate: bool = False, ) -> Page[UserResponse]: """List all users. @@ -880,17 +872,17 @@ def list_users( def update_user( self, - name_id_or_prefix: Union[str, UUID], - updated_name: Optional[str] = None, - updated_full_name: Optional[str] = None, - updated_email: Optional[str] = None, - updated_email_opt_in: Optional[bool] = None, - updated_password: Optional[str] = None, - old_password: Optional[str] = None, - updated_is_admin: Optional[bool] = None, - updated_metadata: Optional[Dict[str, Any]] = None, - updated_default_project_id: Optional[UUID] = None, - active: Optional[bool] = None, + name_id_or_prefix: str | UUID, + updated_name: str | None = None, + updated_full_name: str | None = None, + updated_email: str | None = None, + updated_email_opt_in: bool | None = None, + updated_password: str | None = None, + old_password: str | None = None, + updated_is_admin: bool | None = None, + updated_metadata: dict[str, Any] | None = None, + updated_default_project_id: UUID | None = None, + active: bool | None = None, ) -> UserResponse: """Update a user. @@ -992,7 +984,7 @@ def create_project( self, name: str, description: str, - display_name: Optional[str] = None, + display_name: str | None = None, ) -> ProjectResponse: """Create a new project. 
@@ -1014,7 +1006,7 @@ def create_project( def get_project( self, - name_id_or_prefix: Optional[Union[UUID, str]], + name_id_or_prefix: UUID | str | None, allow_name_prefix_match: bool = True, hydrate: bool = True, ) -> ProjectResponse: @@ -1045,11 +1037,11 @@ def list_projects( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - display_name: Optional[str] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + display_name: str | None = None, hydrate: bool = False, ) -> Page[ProjectResponse]: """List all projects. @@ -1087,10 +1079,10 @@ def list_projects( def update_project( self, - name_id_or_prefix: Optional[Union[UUID, str]], - new_name: Optional[str] = None, - new_display_name: Optional[str] = None, - new_description: Optional[str] = None, + name_id_or_prefix: UUID | str | None, + new_name: str | None = None, + new_display_name: str | None = None, + new_description: str | None = None, ) -> ProjectResponse: """Update a project. @@ -1193,10 +1185,10 @@ def active_project(self) -> ProjectResponse: def create_stack( self, name: str, - components: Mapping[StackComponentType, Union[str, UUID]], - stack_spec_file: Optional[str] = None, - labels: Optional[Dict[str, Any]] = None, - secrets: Optional[Sequence[Union[UUID, str]]] = None, + components: Mapping[StackComponentType, str | UUID], + stack_spec_file: str | None = None, + labels: dict[str, Any] | None = None, + secrets: Sequence[UUID | str] | None = None, ) -> StackResponse: """Registers a stack and its components. 
@@ -1238,7 +1230,7 @@ def create_stack( def get_stack( self, - name_id_or_prefix: Optional[Union[UUID, str]] = None, + name_id_or_prefix: UUID | str | None = None, allow_name_prefix_match: bool = True, hydrate: bool = True, ) -> StackResponse: @@ -1272,14 +1264,14 @@ def list_stacks( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - description: Optional[str] = None, - component_id: Optional[Union[str, UUID]] = None, - user: Optional[Union[UUID, str]] = None, - component: Optional[Union[UUID, str]] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + description: str | None = None, + component_id: str | UUID | None = None, + user: UUID | str | None = None, + component: UUID | str | None = None, hydrate: bool = False, ) -> Page[StackResponse]: """Lists all stacks. 
@@ -1321,17 +1313,17 @@ def list_stacks( def update_stack( self, - name_id_or_prefix: Optional[Union[UUID, str]] = None, - name: Optional[str] = None, - stack_spec_file: Optional[str] = None, - labels: Optional[Dict[str, Any]] = None, - description: Optional[str] = None, - component_updates: Optional[ - Dict[StackComponentType, List[Union[UUID, str]]] - ] = None, - add_secrets: Optional[Sequence[Union[UUID, str]]] = None, - remove_secrets: Optional[Sequence[Union[UUID, str]]] = None, - environment: Optional[Dict[str, Any]] = None, + name_id_or_prefix: UUID | str | None = None, + name: str | None = None, + stack_spec_file: str | None = None, + labels: dict[str, Any] | None = None, + description: str | None = None, + component_updates: None | ( + dict[StackComponentType, list[UUID | str]] + ) = None, + add_secrets: Sequence[UUID | str] | None = None, + remove_secrets: Sequence[UUID | str] | None = None, + environment: dict[str, Any] | None = None, ) -> StackResponse: """Updates a stack and its components. @@ -1431,7 +1423,7 @@ def update_stack( return updated_stack def delete_stack( - self, name_id_or_prefix: Union[str, UUID], recursive: bool = False + self, name_id_or_prefix: str | UUID, recursive: bool = False ) -> None: """Deregisters a stack. @@ -1527,7 +1519,7 @@ def active_stack_model(self) -> StackResponse: return self._active_stack - stack_id: Optional[UUID] = None + stack_id: UUID | None = None if self._config: if self._config._active_stack: @@ -1547,7 +1539,7 @@ def active_stack_model(self) -> StackResponse: return self.get_stack(stack_id) def activate_stack( - self, stack_name_id_or_prefix: Union[str, UUID] + self, stack_name_id_or_prefix: str | UUID ) -> None: """Sets the stack as active. @@ -1580,11 +1572,11 @@ def _validate_stack_configuration(self, stack: StackRequest) -> None: Args: stack: The stack to validate. 
""" - local_components: List[str] = [] - remote_components: List[str] = [] + local_components: list[str] = [] + remote_components: list[str] = [] assert stack.components is not None for component_type, components in stack.components.items(): - component_flavor: Union[FlavorResponse, str] + component_flavor: FlavorResponse | str for component in components: if isinstance(component, UUID): @@ -1649,7 +1641,7 @@ def create_service( self, config: "ServiceConfig", service_type: ServiceType, - model_version_id: Optional[UUID] = None, + model_version_id: UUID | None = None, ) -> ServiceResponse: """Registers a service. @@ -1674,11 +1666,11 @@ def create_service( def get_service( self, - name_id_or_prefix: Union[str, UUID], + name_id_or_prefix: str | UUID, allow_name_prefix_match: bool = True, hydrate: bool = True, - type: Optional[str] = None, - project: Optional[Union[str, UUID]] = None, + type: str | None = None, + project: str | UUID | None = None, ) -> ServiceResponse: """Gets a service. @@ -1731,21 +1723,21 @@ def list_services( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, - type: Optional[str] = None, - flavor: Optional[str] = None, - user: Optional[Union[UUID, str]] = None, - project: Optional[Union[str, UUID]] = None, + id: UUID | str | None = None, + created: datetime | None = None, + updated: datetime | None = None, + type: str | None = None, + flavor: str | None = None, + user: UUID | str | None = None, + project: str | UUID | None = None, hydrate: bool = False, - running: Optional[bool] = None, - service_name: Optional[str] = None, - pipeline_name: Optional[str] = None, - pipeline_run_id: Optional[str] = None, - pipeline_step_name: Optional[str] = None, - model_version_id: Optional[Union[str, UUID]] = None, - config: Optional[Dict[str, Any]] = None, + running: 
bool | None = None, + service_name: str | None = None, + pipeline_name: str | None = None, + pipeline_run_id: str | None = None, + pipeline_step_name: str | None = None, + model_version_id: str | UUID | None = None, + config: dict[str, Any] | None = None, ) -> Page[ServiceResponse]: """List all services. @@ -1802,15 +1794,15 @@ def list_services( def update_service( self, id: UUID, - name: Optional[str] = None, - service_source: Optional[str] = None, - admin_state: Optional[ServiceState] = None, - status: Optional[Dict[str, Any]] = None, - endpoint: Optional[Dict[str, Any]] = None, - labels: Optional[Dict[str, str]] = None, - prediction_url: Optional[str] = None, - health_check_url: Optional[str] = None, - model_version_id: Optional[UUID] = None, + name: str | None = None, + service_source: str | None = None, + admin_state: ServiceState | None = None, + status: dict[str, Any] | None = None, + endpoint: dict[str, Any] | None = None, + labels: dict[str, str] | None = None, + prediction_url: str | None = None, + health_check_url: str | None = None, + model_version_id: UUID | None = None, ) -> ServiceResponse: """Update a service. @@ -1855,7 +1847,7 @@ def update_service( def delete_service( self, name_id_or_prefix: UUID, - project: Optional[Union[str, UUID]] = None, + project: str | UUID | None = None, ) -> None: """Delete a service. 
@@ -1875,7 +1867,7 @@ def delete_service( def get_stack_component( self, component_type: StackComponentType, - name_id_or_prefix: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID | None = None, allow_name_prefix_match: bool = True, hydrate: bool = True, ) -> ComponentResponse: @@ -1949,15 +1941,15 @@ def list_stack_components( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, - name: Optional[str] = None, - flavor: Optional[str] = None, - type: Optional[str] = None, - connector_id: Optional[Union[str, UUID]] = None, - stack_id: Optional[Union[str, UUID]] = None, - user: Optional[Union[UUID, str]] = None, + id: UUID | str | None = None, + created: datetime | None = None, + updated: datetime | None = None, + name: str | None = None, + flavor: str | None = None, + type: str | None = None, + connector_id: str | UUID | None = None, + stack_id: str | UUID | None = None, + user: UUID | str | None = None, hydrate: bool = False, ) -> Page[ComponentResponse]: """Lists all registered stack components. @@ -2007,10 +1999,10 @@ def create_stack_component( name: str, flavor: str, component_type: StackComponentType, - configuration: Dict[str, str], - labels: Optional[Dict[str, Any]] = None, - secrets: Optional[Sequence[Union[UUID, str]]] = None, - environment: Optional[Dict[str, Any]] = None, + configuration: dict[str, str], + labels: dict[str, Any] | None = None, + secrets: Sequence[UUID | str] | None = None, + environment: dict[str, Any] | None = None, ) -> "ComponentResponse": """Registers a stack component. 
@@ -2062,17 +2054,17 @@ def create_stack_component( def update_stack_component( self, - name_id_or_prefix: Optional[Union[UUID, str]], + name_id_or_prefix: UUID | str | None, component_type: StackComponentType, - name: Optional[str] = None, - configuration: Optional[Dict[str, Any]] = None, - labels: Optional[Dict[str, Any]] = None, - disconnect: Optional[bool] = None, - connector_id: Optional[UUID] = None, - connector_resource_id: Optional[str] = None, - add_secrets: Optional[Sequence[Union[UUID, str]]] = None, - remove_secrets: Optional[Sequence[Union[UUID, str]]] = None, - environment: Optional[Dict[str, Any]] = None, + name: str | None = None, + configuration: dict[str, Any] | None = None, + labels: dict[str, Any] | None = None, + disconnect: bool | None = None, + connector_id: UUID | None = None, + connector_resource_id: str | None = None, + add_secrets: Sequence[UUID | str] | None = None, + remove_secrets: Sequence[UUID | str] | None = None, + environment: dict[str, Any] | None = None, ) -> ComponentResponse: """Updates a stack component. @@ -2199,7 +2191,7 @@ def update_stack_component( def delete_stack_component( self, - name_id_or_prefix: Union[str, UUID], + name_id_or_prefix: str | UUID, component_type: StackComponentType, ) -> None: """Deletes a registered stack component. 
@@ -2289,13 +2281,13 @@ def list_flavors( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, - name: Optional[str] = None, - type: Optional[str] = None, - integration: Optional[str] = None, - user: Optional[Union[UUID, str]] = None, + id: UUID | str | None = None, + created: datetime | None = None, + updated: datetime | None = None, + name: str | None = None, + type: str | None = None, + integration: str | None = None, + user: UUID | str | None = None, hydrate: bool = False, ) -> Page[FlavorResponse]: """Fetches all the flavor models. @@ -2410,16 +2402,16 @@ def list_pipelines( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - latest_run_status: Optional[str] = None, - latest_run_user: Optional[Union[UUID, str]] = None, - project: Optional[Union[str, UUID]] = None, - user: Optional[Union[UUID, str]] = None, - tag: Optional[str] = None, - tags: Optional[List[str]] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + latest_run_status: str | None = None, + latest_run_user: UUID | str | None = None, + project: str | UUID | None = None, + user: UUID | str | None = None, + tag: str | None = None, + tags: list[str] | None = None, hydrate: bool = False, ) -> Page[PipelineResponse]: """List all pipelines. 
@@ -2470,8 +2462,8 @@ def list_pipelines( def get_pipeline( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, hydrate: bool = True, ) -> PipelineResponse: """Get a pipeline by name, id or prefix. @@ -2495,8 +2487,8 @@ def get_pipeline( def delete_pipeline( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, ) -> None: """Delete a pipeline. @@ -2513,8 +2505,8 @@ def delete_pipeline( def get_build( self, - id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + id_or_prefix: str | UUID, + project: str | UUID | None = None, hydrate: bool = True, ) -> PipelineBuildResponse: """Get a build by id or prefix. @@ -2545,7 +2537,7 @@ def get_build( hydrate=hydrate, ) - list_kwargs: Dict[str, Any] = dict( + list_kwargs: dict[str, Any] = dict( id=f"startswith:{id_or_prefix}", hydrate=hydrate, ) @@ -2582,21 +2574,21 @@ def list_builds( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - project: Optional[Union[str, UUID]] = None, - user: Optional[Union[UUID, str]] = None, - pipeline_id: Optional[Union[str, UUID]] = None, - stack_id: Optional[Union[str, UUID]] = None, - container_registry_id: Optional[Union[UUID, str]] = None, - is_local: Optional[bool] = None, - contains_code: Optional[bool] = None, - zenml_version: Optional[str] = None, - python_version: Optional[str] = None, - checksum: Optional[str] = None, - stack_checksum: Optional[str] = None, - duration: Optional[Union[int, str]] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + project: str | 
UUID | None = None, + user: UUID | str | None = None, + pipeline_id: str | UUID | None = None, + stack_id: str | UUID | None = None, + container_registry_id: UUID | str | None = None, + is_local: bool | None = None, + contains_code: bool | None = None, + zenml_version: str | None = None, + python_version: str | None = None, + checksum: str | None = None, + stack_checksum: str | None = None, + duration: int | str | None = None, hydrate: bool = False, ) -> Page[PipelineBuildResponse]: """List all builds. @@ -2655,7 +2647,7 @@ def list_builds( ) def delete_build( - self, id_or_prefix: str, project: Optional[Union[str, UUID]] = None + self, id_or_prefix: str, project: str | UUID | None = None ) -> None: """Delete a build. @@ -2672,7 +2664,7 @@ def delete_build( def create_event_source( self, name: str, - configuration: Dict[str, Any], + configuration: dict[str, Any], flavor: str, event_source_subtype: PluginSubType, description: str = "", @@ -2704,9 +2696,9 @@ def create_event_source( @_fail_for_sql_zen_store def get_event_source( self, - name_id_or_prefix: Union[UUID, str], + name_id_or_prefix: UUID | str, allow_name_prefix_match: bool = True, - project: Optional[Union[str, UUID]] = None, + project: str | UUID | None = None, hydrate: bool = True, ) -> EventSourceResponse: """Get an event source by name, ID or prefix. 
@@ -2736,14 +2728,14 @@ def list_event_sources( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, - name: Optional[str] = None, - flavor: Optional[str] = None, - event_source_type: Optional[str] = None, - project: Optional[Union[str, UUID]] = None, - user: Optional[Union[UUID, str]] = None, + id: UUID | str | None = None, + created: datetime | None = None, + updated: datetime | None = None, + name: str | None = None, + flavor: str | None = None, + event_source_type: str | None = None, + project: str | UUID | None = None, + user: UUID | str | None = None, hydrate: bool = False, ) -> Page[EventSourceResponse]: """Lists all event_sources. @@ -2788,13 +2780,13 @@ def list_event_sources( @_fail_for_sql_zen_store def update_event_source( self, - name_id_or_prefix: Union[UUID, str], - name: Optional[str] = None, - description: Optional[str] = None, - configuration: Optional[Dict[str, Any]] = None, - rotate_secret: Optional[bool] = None, - is_active: Optional[bool] = None, - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: UUID | str, + name: str | None = None, + description: str | None = None, + configuration: dict[str, Any] | None = None, + rotate_secret: bool | None = None, + is_active: bool | None = None, + project: str | UUID | None = None, ) -> EventSourceResponse: """Updates an event_source. @@ -2847,8 +2839,8 @@ def update_event_source( @_fail_for_sql_zen_store def delete_event_source( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, ) -> None: """Deletes an event_source. 
@@ -2874,9 +2866,9 @@ def create_action( name: str, flavor: str, action_type: PluginSubType, - configuration: Dict[str, Any], + configuration: dict[str, Any], service_account_id: UUID, - auth_window: Optional[int] = None, + auth_window: int | None = None, description: str = "", ) -> ActionResponse: """Create an action. @@ -2912,9 +2904,9 @@ def create_action( @_fail_for_sql_zen_store def get_action( self, - name_id_or_prefix: Union[UUID, str], + name_id_or_prefix: UUID | str, allow_name_prefix_match: bool = True, - project: Optional[Union[str, UUID]] = None, + project: str | UUID | None = None, hydrate: bool = True, ) -> ActionResponse: """Get an action by name, ID or prefix. @@ -2945,14 +2937,14 @@ def list_actions( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, - name: Optional[str] = None, - flavor: Optional[str] = None, - action_type: Optional[str] = None, - project: Optional[Union[str, UUID]] = None, - user: Optional[Union[UUID, str]] = None, + id: UUID | str | None = None, + created: datetime | None = None, + updated: datetime | None = None, + name: str | None = None, + flavor: str | None = None, + action_type: str | None = None, + project: str | UUID | None = None, + user: UUID | str | None = None, hydrate: bool = False, ) -> Page[ActionResponse]: """List actions. 
@@ -2995,13 +2987,13 @@ def list_actions( @_fail_for_sql_zen_store def update_action( self, - name_id_or_prefix: Union[UUID, str], - name: Optional[str] = None, - description: Optional[str] = None, - configuration: Optional[Dict[str, Any]] = None, - service_account_id: Optional[UUID] = None, - auth_window: Optional[int] = None, - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: UUID | str, + name: str | None = None, + description: str | None = None, + configuration: dict[str, Any] | None = None, + service_account_id: UUID | None = None, + auth_window: int | None = None, + project: str | UUID | None = None, ) -> ActionResponse: """Update an action. @@ -3042,8 +3034,8 @@ def update_action( @_fail_for_sql_zen_store def delete_action( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, ) -> None: """Delete an action. @@ -3068,7 +3060,7 @@ def create_trigger( self, name: str, event_source_id: UUID, - event_filter: Dict[str, Any], + event_filter: dict[str, Any], action_id: UUID, description: str = "", ) -> TriggerResponse: @@ -3098,9 +3090,9 @@ def create_trigger( @_fail_for_sql_zen_store def get_trigger( self, - name_id_or_prefix: Union[UUID, str], + name_id_or_prefix: UUID | str, allow_name_prefix_match: bool = True, - project: Optional[Union[str, UUID]] = None, + project: str | UUID | None = None, hydrate: bool = True, ) -> TriggerResponse: """Get a trigger by name, ID or prefix. 
@@ -3131,18 +3123,18 @@ def list_triggers( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, - name: Optional[str] = None, - event_source_id: Optional[UUID] = None, - action_id: Optional[UUID] = None, - event_source_flavor: Optional[str] = None, - event_source_subtype: Optional[str] = None, - action_flavor: Optional[str] = None, - action_subtype: Optional[str] = None, - project: Optional[Union[str, UUID]] = None, - user: Optional[Union[UUID, str]] = None, + id: UUID | str | None = None, + created: datetime | None = None, + updated: datetime | None = None, + name: str | None = None, + event_source_id: UUID | None = None, + action_id: UUID | None = None, + event_source_flavor: str | None = None, + event_source_subtype: str | None = None, + action_flavor: str | None = None, + action_subtype: str | None = None, + project: str | UUID | None = None, + user: UUID | str | None = None, hydrate: bool = False, ) -> Page[TriggerResponse]: """Lists all triggers. @@ -3197,12 +3189,12 @@ def list_triggers( @_fail_for_sql_zen_store def update_trigger( self, - name_id_or_prefix: Union[UUID, str], - name: Optional[str] = None, - description: Optional[str] = None, - event_filter: Optional[Dict[str, Any]] = None, - is_active: Optional[bool] = None, - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: UUID | str, + name: str | None = None, + description: str | None = None, + event_filter: dict[str, Any] | None = None, + is_active: bool | None = None, + project: str | UUID | None = None, ) -> TriggerResponse: """Updates a trigger. 
@@ -3251,8 +3243,8 @@ def update_trigger( @_fail_for_sql_zen_store def delete_trigger( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, ) -> None: """Deletes an trigger. @@ -3274,11 +3266,11 @@ def delete_trigger( def get_snapshot( self, - name_id_or_prefix: Union[str, UUID], + name_id_or_prefix: str | UUID, *, - pipeline_name_or_id: Optional[Union[str, UUID]] = None, - project: Optional[Union[str, UUID]] = None, - include_config_schema: Optional[bool] = None, + pipeline_name_or_id: str | UUID | None = None, + project: str | UUID | None = None, + include_config_schema: bool | None = None, allow_prefix_match: bool = True, hydrate: bool = True, ) -> PipelineSnapshotResponse: @@ -3320,7 +3312,7 @@ def get_snapshot( include_config_schema=include_config_schema, ) - list_kwargs: Dict[str, Any] = { + list_kwargs: dict[str, Any] = { "named_only": None, "project": project, "hydrate": hydrate, @@ -3389,23 +3381,23 @@ def list_snapshots( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - project: Optional[Union[str, UUID]] = None, - user: Optional[Union[UUID, str]] = None, - name: Optional[str] = None, - named_only: Optional[bool] = True, - pipeline: Optional[Union[str, UUID]] = None, - stack: Optional[Union[str, UUID]] = None, - build_id: Optional[Union[str, UUID]] = None, - schedule_id: Optional[Union[str, UUID]] = None, - source_snapshot_id: Optional[Union[str, UUID]] = None, - runnable: Optional[bool] = None, - deployable: Optional[bool] = None, - deployed: Optional[bool] = None, - tag: Optional[str] = None, - tags: Optional[List[str]] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | 
None = None, + project: str | UUID | None = None, + user: UUID | str | None = None, + name: str | None = None, + named_only: bool | None = True, + pipeline: str | UUID | None = None, + stack: str | UUID | None = None, + build_id: str | UUID | None = None, + schedule_id: str | UUID | None = None, + source_snapshot_id: str | UUID | None = None, + runnable: bool | None = None, + deployable: bool | None = None, + deployed: bool | None = None, + tag: str | None = None, + tags: list[str] | None = None, hydrate: bool = False, ) -> Page[PipelineSnapshotResponse]: """List all snapshots. @@ -3469,13 +3461,13 @@ def list_snapshots( def update_snapshot( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, - name: Optional[str] = None, - description: Optional[str] = None, - replace: Optional[bool] = None, - add_tags: Optional[List[str]] = None, - remove_tags: Optional[List[str]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, + name: str | None = None, + description: str | None = None, + replace: bool | None = None, + add_tags: list[str] | None = None, + remove_tags: list[str] | None = None, ) -> PipelineSnapshotResponse: """Update a snapshot. @@ -3511,8 +3503,8 @@ def update_snapshot( def delete_snapshot( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, ) -> None: """Delete a snapshot. 
@@ -3530,16 +3522,16 @@ def delete_snapshot( @_fail_for_sql_zen_store def trigger_pipeline( self, - snapshot_name_or_id: Optional[Union[str, UUID]] = None, - pipeline_name_or_id: Union[str, UUID, None] = None, - run_configuration: Union[ - PipelineRunConfiguration, Dict[str, Any], None - ] = None, - config_path: Optional[str] = None, - stack_name_or_id: Union[str, UUID, None] = None, + snapshot_name_or_id: str | UUID | None = None, + pipeline_name_or_id: str | UUID | None = None, + run_configuration: ( + PipelineRunConfiguration | dict[str, Any] | None + ) = None, + config_path: str | None = None, + stack_name_or_id: str | UUID | None = None, synchronous: bool = False, - project: Optional[Union[str, UUID]] = None, - template_id: Optional[UUID] = None, + project: str | UUID | None = None, + template_id: UUID | None = None, ) -> PipelineRunResponse: """Run a pipeline snapshot. @@ -3608,7 +3600,7 @@ def trigger_pipeline( if config_path: run_configuration = PipelineRunConfiguration.from_yaml(config_path) - if isinstance(run_configuration, Dict): + if isinstance(run_configuration, dict): run_configuration = PipelineRunConfiguration.model_validate( run_configuration ) @@ -3720,8 +3712,8 @@ def trigger_pipeline( def get_deployment( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, hydrate: bool = True, ) -> DeploymentResponse: """Get a deployment. @@ -3750,9 +3742,9 @@ def create_curated_visualization( *, resource_id: UUID, resource_type: VisualizationResourceTypes, - project_id: Optional[UUID] = None, - display_name: Optional[str] = None, - display_order: Optional[int] = None, + project_id: UUID | None = None, + display_name: str | None = None, + display_order: int | None = None, layout_size: CuratedVisualizationSize = CuratedVisualizationSize.FULL_WIDTH, ) -> CuratedVisualizationResponse: """Create a curated visualization associated with a resource. 
@@ -3799,9 +3791,9 @@ def update_curated_visualization( self, visualization_id: UUID, *, - display_name: Optional[str] = None, - display_order: Optional[int] = None, - layout_size: Optional[CuratedVisualizationSize] = None, + display_name: str | None = None, + display_order: int | None = None, + layout_size: CuratedVisualizationSize | None = None, ) -> CuratedVisualizationResponse: """Update display metadata for a curated visualization. @@ -3840,19 +3832,19 @@ def list_deployments( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - snapshot_id: Optional[Union[str, UUID]] = None, - deployer_id: Optional[Union[str, UUID]] = None, - project: Optional[Union[str, UUID]] = None, - status: Optional[DeploymentStatus] = None, - url: Optional[str] = None, - user: Optional[Union[UUID, str]] = None, - pipeline: Optional[Union[UUID, str]] = None, - tag: Optional[str] = None, - tags: Optional[List[str]] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + snapshot_id: str | UUID | None = None, + deployer_id: str | UUID | None = None, + project: str | UUID | None = None, + status: DeploymentStatus | None = None, + url: str | None = None, + user: UUID | str | None = None, + pipeline: UUID | str | None = None, + tag: str | None = None, + tags: list[str] | None = None, hydrate: bool = False, ) -> Page[DeploymentResponse]: """List deployments. 
@@ -3906,10 +3898,10 @@ def list_deployments( def provision_deployment( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, - snapshot_id: Optional[Union[str, UUID]] = None, - timeout: Optional[int] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, + snapshot_id: str | UUID | None = None, + timeout: int | None = None, ) -> DeploymentResponse: """Provision a deployment. @@ -3938,7 +3930,7 @@ def provision_deployment( from zenml.stack.stack import Stack from zenml.stack.stack_component import StackComponent - deployment: Optional[DeploymentResponse] = None + deployment: DeploymentResponse | None = None deployment_name_or_id = name_id_or_prefix try: deployment = self.get_deployment( @@ -3952,7 +3944,7 @@ def provision_deployment( raise stack = Client().active_stack - deployer: Optional[BaseDeployer] = None + deployer: BaseDeployer | None = None if snapshot_id: snapshot = self.get_snapshot( @@ -4020,9 +4012,9 @@ def provision_deployment( def deprovision_deployment( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, - timeout: Optional[int] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, + timeout: int | None = None, ) -> None: """Deprovision a deployment. @@ -4077,10 +4069,10 @@ def deprovision_deployment( def delete_deployment( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, force: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """Deprovision and delete a deployment. @@ -4151,8 +4143,8 @@ def delete_deployment( def refresh_deployment( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, ) -> DeploymentResponse: """Refresh the status of a deployment. 
@@ -4201,10 +4193,10 @@ def refresh_deployment( def get_deployment_logs( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> Generator[str, bool, None]: """Get the logs of a deployment. @@ -4262,8 +4254,8 @@ def create_run_template( self, name: str, snapshot_id: UUID, - description: Optional[str] = None, - tags: Optional[List[str]] = None, + description: str | None = None, + tags: list[str] | None = None, ) -> RunTemplateResponse: """Create a run template. @@ -4289,8 +4281,8 @@ def create_run_template( def get_run_template( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, hydrate: bool = True, ) -> RunTemplateResponse: """Get a run template. @@ -4321,20 +4313,20 @@ def list_run_templates( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - id: Optional[Union[UUID, str]] = None, - name: Optional[str] = None, - hidden: Optional[bool] = False, - tag: Optional[str] = None, - project: Optional[Union[str, UUID]] = None, - pipeline_id: Optional[Union[str, UUID]] = None, - build_id: Optional[Union[str, UUID]] = None, - stack_id: Optional[Union[str, UUID]] = None, - code_repository_id: Optional[Union[str, UUID]] = None, - user: Optional[Union[UUID, str]] = None, - pipeline: Optional[Union[UUID, str]] = None, - stack: Optional[Union[UUID, str]] = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + id: UUID | str | None = None, + name: str | None = None, + hidden: bool | None = False, + tag: str | None = None, + project: str | UUID | None = None, + pipeline_id: str 
| UUID | None = None, + build_id: str | UUID | None = None, + stack_id: str | UUID | None = None, + code_repository_id: str | UUID | None = None, + user: UUID | str | None = None, + pipeline: UUID | str | None = None, + stack: UUID | str | None = None, hydrate: bool = False, ) -> Page[RunTemplateResponse]: """Get a page of run templates. @@ -4391,13 +4383,13 @@ def list_run_templates( def update_run_template( self, - name_id_or_prefix: Union[str, UUID], - name: Optional[str] = None, - description: Optional[str] = None, - hidden: Optional[bool] = None, - add_tags: Optional[List[str]] = None, - remove_tags: Optional[List[str]] = None, - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + name: str | None = None, + description: str | None = None, + hidden: bool | None = None, + add_tags: list[str] | None = None, + remove_tags: list[str] | None = None, + project: str | UUID | None = None, ) -> RunTemplateResponse: """Update a run template. @@ -4439,8 +4431,8 @@ def update_run_template( def delete_run_template( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, ) -> None: """Delete a run template. @@ -4467,9 +4459,9 @@ def delete_run_template( def get_schedule( self, - name_id_or_prefix: Union[str, UUID], + name_id_or_prefix: str | UUID, allow_name_prefix_match: bool = True, - project: Optional[Union[str, UUID]] = None, + project: str | UUID | None = None, hydrate: bool = True, ) -> ScheduleResponse: """Get a schedule by name, id or prefix. 
@@ -4499,22 +4491,22 @@ def list_schedules( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - project: Optional[Union[str, UUID]] = None, - user: Optional[Union[UUID, str]] = None, - pipeline_id: Optional[Union[str, UUID]] = None, - orchestrator_id: Optional[Union[str, UUID]] = None, - active: Optional[Union[str, bool]] = None, - cron_expression: Optional[str] = None, - start_time: Optional[Union[datetime, str]] = None, - end_time: Optional[Union[datetime, str]] = None, - interval_second: Optional[int] = None, - catchup: Optional[Union[str, bool]] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + project: str | UUID | None = None, + user: UUID | str | None = None, + pipeline_id: str | UUID | None = None, + orchestrator_id: str | UUID | None = None, + active: str | bool | None = None, + cron_expression: str | None = None, + start_time: datetime | str | None = None, + end_time: datetime | str | None = None, + interval_second: int | None = None, + catchup: str | bool | None = None, hydrate: bool = False, - run_once_start_time: Optional[Union[datetime, str]] = None, + run_once_start_time: datetime | str | None = None, ) -> Page[ScheduleResponse]: """List schedules. @@ -4600,8 +4592,8 @@ def _get_orchestrator_for_schedule( def update_schedule( self, - name_id_or_prefix: Union[str, UUID], - cron_expression: Optional[str] = None, + name_id_or_prefix: str | UUID, + cron_expression: str | None = None, ) -> ScheduleResponse: """Update a schedule. 
@@ -4640,8 +4632,8 @@ def update_schedule( def delete_schedule( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, ) -> None: """Delete a schedule. @@ -4676,9 +4668,9 @@ def delete_schedule( def get_pipeline_run( self, - name_id_or_prefix: Union[str, UUID], + name_id_or_prefix: str | UUID, allow_name_prefix_match: bool = True, - project: Optional[Union[str, UUID]] = None, + project: str | UUID | None = None, hydrate: bool = True, include_full_metadata: bool = False, ) -> PipelineRunResponse: @@ -4712,42 +4704,42 @@ def list_pipeline_runs( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - project: Optional[Union[str, UUID]] = None, - pipeline_id: Optional[Union[str, UUID]] = None, - pipeline_name: Optional[str] = None, - stack_id: Optional[Union[str, UUID]] = None, - schedule_id: Optional[Union[str, UUID]] = None, - build_id: Optional[Union[str, UUID]] = None, - snapshot_id: Optional[Union[str, UUID]] = None, - code_repository_id: Optional[Union[str, UUID]] = None, - template_id: Optional[Union[str, UUID]] = None, - source_snapshot_id: Optional[Union[str, UUID]] = None, - model_version_id: Optional[Union[str, UUID]] = None, - linked_to_model_version_id: Optional[Union[str, UUID]] = None, - orchestrator_run_id: Optional[str] = None, - status: Optional[str] = None, - start_time: Optional[Union[datetime, str]] = None, - end_time: Optional[Union[datetime, str]] = None, - unlisted: Optional[bool] = None, - templatable: Optional[bool] = None, - tag: Optional[str] = None, - tags: Optional[List[str]] = None, - user: Optional[Union[UUID, str]] = None, - run_metadata: Optional[List[str]] = None, - pipeline: 
Optional[Union[UUID, str]] = None, - code_repository: Optional[Union[UUID, str]] = None, - model: Optional[Union[UUID, str]] = None, - stack: Optional[Union[UUID, str]] = None, - stack_component: Optional[Union[UUID, str]] = None, - in_progress: Optional[bool] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + project: str | UUID | None = None, + pipeline_id: str | UUID | None = None, + pipeline_name: str | None = None, + stack_id: str | UUID | None = None, + schedule_id: str | UUID | None = None, + build_id: str | UUID | None = None, + snapshot_id: str | UUID | None = None, + code_repository_id: str | UUID | None = None, + template_id: str | UUID | None = None, + source_snapshot_id: str | UUID | None = None, + model_version_id: str | UUID | None = None, + linked_to_model_version_id: str | UUID | None = None, + orchestrator_run_id: str | None = None, + status: str | None = None, + start_time: datetime | str | None = None, + end_time: datetime | str | None = None, + unlisted: bool | None = None, + templatable: bool | None = None, + tag: str | None = None, + tags: list[str] | None = None, + user: UUID | str | None = None, + run_metadata: list[str] | None = None, + pipeline: UUID | str | None = None, + code_repository: UUID | str | None = None, + model: UUID | str | None = None, + stack: UUID | str | None = None, + stack_component: UUID | str | None = None, + in_progress: bool | None = None, hydrate: bool = False, include_full_metadata: bool = False, - triggered_by_step_run_id: Optional[Union[UUID, str]] = None, - triggered_by_deployment_id: Optional[Union[UUID, str]] = None, + triggered_by_step_run_id: UUID | str | None = None, + triggered_by_deployment_id: UUID | str | None = None, ) -> Page[PipelineRunResponse]: """List all pipeline runs. 
@@ -4853,8 +4845,8 @@ def list_pipeline_runs( def delete_pipeline_run( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, ) -> None: """Deletes a pipeline run. @@ -4897,26 +4889,26 @@ def list_run_steps( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - cache_key: Optional[str] = None, - cache_expires_at: Optional[Union[datetime, str]] = None, - cache_expired: Optional[bool] = None, - code_hash: Optional[str] = None, - status: Optional[str] = None, - start_time: Optional[Union[datetime, str]] = None, - end_time: Optional[Union[datetime, str]] = None, - pipeline_run_id: Optional[Union[str, UUID]] = None, - snapshot_id: Optional[Union[str, UUID]] = None, - original_step_run_id: Optional[Union[str, UUID]] = None, - project: Optional[Union[str, UUID]] = None, - user: Optional[Union[UUID, str]] = None, - model_version_id: Optional[Union[str, UUID]] = None, - model: Optional[Union[UUID, str]] = None, - run_metadata: Optional[List[str]] = None, - exclude_retried: Optional[bool] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + cache_key: str | None = None, + cache_expires_at: datetime | str | None = None, + cache_expired: bool | None = None, + code_hash: str | None = None, + status: str | None = None, + start_time: datetime | str | None = None, + end_time: datetime | str | None = None, + pipeline_run_id: str | UUID | None = None, + snapshot_id: str | UUID | None = None, + original_step_run_id: str | UUID | None = None, + project: str | UUID | None = None, + user: UUID | str | None = None, + model_version_id: str | UUID 
| None = None, + model: UUID | str | None = None, + run_metadata: list[str] | None = None, + exclude_retried: bool | None = None, hydrate: bool = False, ) -> Page[StepRunResponse]: """List all pipelines. @@ -4988,7 +4980,7 @@ def list_run_steps( def update_step_run( self, step_run_id: UUID, - cache_expires_at: Optional[datetime] = None, + cache_expires_at: datetime | None = None, ) -> StepRunResponse: """Update a step run. @@ -5009,8 +5001,8 @@ def update_step_run( def get_artifact( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, hydrate: bool = False, ) -> ArtifactResponse: """Get an artifact by name, id or prefix. @@ -5038,16 +5030,16 @@ def list_artifacts( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - has_custom_name: Optional[bool] = None, - user: Optional[Union[UUID, str]] = None, - project: Optional[Union[str, UUID]] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + has_custom_name: bool | None = None, + user: UUID | str | None = None, + project: str | UUID | None = None, hydrate: bool = False, - tag: Optional[str] = None, - tags: Optional[List[str]] = None, + tag: str | None = None, + tags: list[str] | None = None, ) -> Page[ArtifactResponse]: """Get a list of artifacts. 
@@ -5093,12 +5085,12 @@ def list_artifacts( def update_artifact( self, - name_id_or_prefix: Union[str, UUID], - new_name: Optional[str] = None, - add_tags: Optional[List[str]] = None, - remove_tags: Optional[List[str]] = None, - has_custom_name: Optional[bool] = None, - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + new_name: str | None = None, + add_tags: list[str] | None = None, + remove_tags: list[str] | None = None, + has_custom_name: bool | None = None, + project: str | UUID | None = None, ) -> ArtifactResponse: """Update an artifact. @@ -5129,8 +5121,8 @@ def update_artifact( def delete_artifact( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, ) -> None: """Delete an artifact. @@ -5149,7 +5141,7 @@ def prune_artifacts( self, only_versions: bool = True, delete_from_artifact_store: bool = False, - project: Optional[Union[str, UUID]] = None, + project: str | UUID | None = None, ) -> None: """Delete all unused artifacts and artifact versions. @@ -5180,9 +5172,9 @@ def prune_artifacts( def get_artifact_version( self, - name_id_or_prefix: Union[str, UUID], - version: Optional[str] = None, - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + version: str | None = None, + project: str | UUID | None = None, hydrate: bool = True, ) -> ArtifactVersionResponse: """Get an artifact version by ID or artifact name. 
@@ -5238,28 +5230,28 @@ def list_artifact_versions( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - artifact: Optional[Union[str, UUID]] = None, - name: Optional[str] = None, - version: Optional[Union[str, int]] = None, - version_number: Optional[int] = None, - artifact_store_id: Optional[Union[str, UUID]] = None, - type: Optional[Union[ArtifactType, str]] = None, - data_type: Optional[str] = None, - uri: Optional[str] = None, - materializer: Optional[str] = None, - project: Optional[Union[str, UUID]] = None, - model_version_id: Optional[Union[str, UUID]] = None, - only_unused: Optional[bool] = False, - has_custom_name: Optional[bool] = None, - user: Optional[Union[UUID, str]] = None, - model: Optional[Union[UUID, str]] = None, - pipeline_run: Optional[Union[UUID, str]] = None, - run_metadata: Optional[List[str]] = None, - tag: Optional[str] = None, - tags: Optional[List[str]] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + artifact: str | UUID | None = None, + name: str | None = None, + version: str | int | None = None, + version_number: int | None = None, + artifact_store_id: str | UUID | None = None, + type: ArtifactType | str | None = None, + data_type: str | None = None, + uri: str | None = None, + materializer: str | None = None, + project: str | UUID | None = None, + model_version_id: str | UUID | None = None, + only_unused: bool | None = False, + has_custom_name: bool | None = None, + user: UUID | str | None = None, + model: UUID | str | None = None, + pipeline_run: UUID | str | None = None, + run_metadata: list[str] | None = None, + tag: str | None = None, + tags: list[str] | None = None, hydrate: bool = False, ) -> Page[ArtifactVersionResponse]: """Get a list of 
artifact versions. @@ -5335,11 +5327,11 @@ def list_artifact_versions( def update_artifact_version( self, - name_id_or_prefix: Union[str, UUID], - version: Optional[str] = None, - add_tags: Optional[List[str]] = None, - remove_tags: Optional[List[str]] = None, - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + version: str | None = None, + add_tags: list[str] | None = None, + remove_tags: list[str] | None = None, + project: str | UUID | None = None, ) -> ArtifactVersionResponse: """Update an artifact version. @@ -5370,11 +5362,11 @@ def update_artifact_version( def delete_artifact_version( self, - name_id_or_prefix: Union[str, UUID], - version: Optional[str] = None, + name_id_or_prefix: str | UUID, + version: str | None = None, delete_metadata: bool = True, delete_from_artifact_store: bool = False, - project: Optional[Union[str, UUID]] = None, + project: str | UUID | None = None, ) -> None: """Delete an artifact version. @@ -5476,10 +5468,10 @@ def _delete_artifact_from_artifact_store( def create_run_metadata( self, - metadata: Dict[str, "MetadataType"], - resources: List[RunMetadataResource], - stack_component_id: Optional[UUID] = None, - publisher_step_id: Optional[UUID] = None, + metadata: dict[str, "MetadataType"], + resources: list[RunMetadataResource], + stack_component_id: UUID | None = None, + publisher_step_id: UUID | None = None, ) -> None: """Create run metadata. @@ -5494,8 +5486,8 @@ def create_run_metadata( """ from zenml.metadata.metadata_types import get_metadata_type - values: Dict[str, "MetadataType"] = {} - types: Dict[str, "MetadataTypeEnum"] = {} + values: dict[str, "MetadataType"] = {} + types: dict[str, "MetadataTypeEnum"] = {} for key, value in metadata.items(): # Skip metadata that is too large to be stored in the database. 
if len(json.dumps(value)) > TEXT_FIELD_MAX_LENGTH: @@ -5531,7 +5523,7 @@ def create_run_metadata( def create_secret( self, name: str, - values: Dict[str, str], + values: dict[str, str], private: bool = False, ) -> SecretResponse: """Creates a new secret. @@ -5564,8 +5556,8 @@ def create_secret( def get_secret( self, - name_id_or_prefix: Union[str, UUID], - private: Optional[bool] = None, + name_id_or_prefix: str | UUID, + private: bool | None = None, allow_partial_name_match: bool = True, allow_partial_id_match: bool = True, hydrate: bool = True, @@ -5644,7 +5636,7 @@ def get_secret( ) for search_private_status in search_private_statuses: - partial_matches: List[SecretResponse] = [] + partial_matches: list[SecretResponse] = [] for secret in secrets.items: if secret.private != search_private_status: continue @@ -5697,12 +5689,12 @@ def list_secrets( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, - name: Optional[str] = None, - private: Optional[bool] = None, - user: Optional[Union[UUID, str]] = None, + id: UUID | str | None = None, + created: datetime | None = None, + updated: datetime | None = None, + name: str | None = None, + private: bool | None = None, + user: UUID | str | None = None, hydrate: bool = False, ) -> Page[SecretResponse]: """Fetches all the secret models. 
@@ -5756,12 +5748,12 @@ def list_secrets( def update_secret( self, - name_id_or_prefix: Union[str, UUID], - private: Optional[bool] = None, - new_name: Optional[str] = None, - update_private: Optional[bool] = None, - add_or_update_values: Optional[Dict[str, str]] = None, - remove_values: Optional[List[str]] = None, + name_id_or_prefix: str | UUID, + private: bool | None = None, + new_name: str | None = None, + update_private: bool | None = None, + add_or_update_values: dict[str, str] | None = None, + remove_values: list[str] | None = None, ) -> SecretResponse: """Updates a secret. @@ -5796,7 +5788,7 @@ def update_secret( if update_private: secret_update.private = update_private - values: Dict[str, Optional[SecretStr]] = {} + values: dict[str, SecretStr | None] = {} if add_or_update_values: values.update( { @@ -5825,7 +5817,7 @@ def update_secret( ) def delete_secret( - self, name_id_or_prefix: str, private: Optional[bool] = None + self, name_id_or_prefix: str, private: bool | None = None ) -> None: """Deletes a secret. @@ -5846,7 +5838,7 @@ def delete_secret( def get_secret_by_name_and_private_status( self, name: str, - private: Optional[bool] = None, + private: bool | None = None, hydrate: bool = True, ) -> SecretResponse: """Fetches a registered secret with a given name and optional private status. @@ -5965,7 +5957,7 @@ def restore_secrets( @staticmethod def _validate_code_repository_config( - source: Source, config: Dict[str, Any] + source: Source, config: dict[str, Any] ) -> None: """Validate a code repository config. 
@@ -5978,7 +5970,7 @@ def _validate_code_repository_config( """ from zenml.code_repositories import BaseCodeRepository - code_repo_class: Type[BaseCodeRepository] = ( + code_repo_class: type[BaseCodeRepository] = ( source_utils.load_and_validate_class( source=source, expected_class=BaseCodeRepository ) @@ -5993,10 +5985,10 @@ def _validate_code_repository_config( def create_code_repository( self, name: str, - config: Dict[str, Any], + config: dict[str, Any], source: Source, - description: Optional[str] = None, - logo_url: Optional[str] = None, + description: str | None = None, + logo_url: str | None = None, ) -> CodeRepositoryResponse: """Create a new code repository. @@ -6025,9 +6017,9 @@ def create_code_repository( def get_code_repository( self, - name_id_or_prefix: Union[str, UUID], + name_id_or_prefix: str | UUID, allow_name_prefix_match: bool = True, - project: Optional[Union[str, UUID]] = None, + project: str | UUID | None = None, hydrate: bool = True, ) -> CodeRepositoryResponse: """Get a code repository by name, id or prefix. @@ -6057,12 +6049,12 @@ def list_code_repositories( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - project: Optional[Union[str, UUID]] = None, - user: Optional[Union[UUID, str]] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + project: str | UUID | None = None, + user: UUID | str | None = None, hydrate: bool = False, ) -> Page[CodeRepositoryResponse]: """List all code repositories. 
@@ -6103,12 +6095,12 @@ def list_code_repositories( def update_code_repository( self, - name_id_or_prefix: Union[UUID, str], - name: Optional[str] = None, - description: Optional[str] = None, - logo_url: Optional[str] = None, - config: Optional[Dict[str, Any]] = None, - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: UUID | str, + name: str | None = None, + description: str | None = None, + logo_url: str | None = None, + config: dict[str, Any] | None = None, + project: str | UUID | None = None, ) -> CodeRepositoryResponse: """Update a code repository. @@ -6153,8 +6145,8 @@ def update_code_repository( def delete_code_repository( self, - name_id_or_prefix: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + project: str | UUID | None = None, ) -> None: """Delete a code repository. @@ -6175,27 +6167,25 @@ def create_service_connector( self, name: str, connector_type: str, - resource_type: Optional[str] = None, - auth_method: Optional[str] = None, - configuration: Optional[Dict[str, str]] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + auth_method: str | None = None, + configuration: dict[str, str] | None = None, + resource_id: str | None = None, description: str = "", - expiration_seconds: Optional[int] = None, - expires_at: Optional[datetime] = None, - expires_skew_tolerance: Optional[int] = None, - labels: Optional[Dict[str, str]] = None, + expiration_seconds: int | None = None, + expires_at: datetime | None = None, + expires_skew_tolerance: int | None = None, + labels: dict[str, str] | None = None, auto_configure: bool = False, verify: bool = True, list_resources: bool = True, register: bool = True, - ) -> Tuple[ - Optional[ - Union[ - ServiceConnectorResponse, - ServiceConnectorRequest, - ] - ], - Optional[ServiceConnectorResourcesModel], + ) -> tuple[ + None | ( + ServiceConnectorResponse | + ServiceConnectorRequest + ), + ServiceConnectorResourcesModel | None, ]: 
"""Create, validate and/or register a service connector. @@ -6237,8 +6227,8 @@ def create_service_connector( service_connector_registry, ) - connector_instance: Optional[ServiceConnector] = None - connector_resources: Optional[ServiceConnectorResourcesModel] = None + connector_instance: ServiceConnector | None = None + connector_resources: ServiceConnectorResourcesModel | None = None # Get the service connector type class try: @@ -6392,7 +6382,7 @@ def create_service_connector( def get_service_connector( self, - name_id_or_prefix: Union[str, UUID], + name_id_or_prefix: str | UUID, allow_name_prefix_match: bool = True, hydrate: bool = True, expand_secrets: bool = False, @@ -6427,16 +6417,16 @@ def list_service_connectors( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[datetime] = None, - updated: Optional[datetime] = None, - name: Optional[str] = None, - connector_type: Optional[str] = None, - auth_method: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - user: Optional[Union[UUID, str]] = None, - labels: Optional[Dict[str, Optional[str]]] = None, + id: UUID | str | None = None, + created: datetime | None = None, + updated: datetime | None = None, + name: str | None = None, + connector_type: str | None = None, + auth_method: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + user: UUID | str | None = None, + labels: dict[str, str | None] | None = None, hydrate: bool = False, expand_secrets: bool = False, ) -> Page[ServiceConnectorResponse]: @@ -6491,28 +6481,26 @@ def list_service_connectors( def update_service_connector( self, - name_id_or_prefix: Union[UUID, str], - name: Optional[str] = None, - auth_method: Optional[str] = None, - resource_type: Optional[str] = None, - configuration: Optional[Dict[str, str]] = None, - resource_id: 
Optional[str] = None, - description: Optional[str] = None, - expires_at: Optional[datetime] = None, - expires_skew_tolerance: Optional[int] = None, - expiration_seconds: Optional[int] = None, - labels: Optional[Dict[str, Optional[str]]] = None, + name_id_or_prefix: UUID | str, + name: str | None = None, + auth_method: str | None = None, + resource_type: str | None = None, + configuration: dict[str, str] | None = None, + resource_id: str | None = None, + description: str | None = None, + expires_at: datetime | None = None, + expires_skew_tolerance: int | None = None, + expiration_seconds: int | None = None, + labels: dict[str, str | None] | None = None, verify: bool = True, list_resources: bool = True, update: bool = True, - ) -> Tuple[ - Optional[ - Union[ - ServiceConnectorResponse, - ServiceConnectorUpdate, - ] - ], - Optional[ServiceConnectorResourcesModel], + ) -> tuple[ + None | ( + ServiceConnectorResponse | + ServiceConnectorUpdate + ), + ServiceConnectorResourcesModel | None, ]: """Validate and/or register an updated service connector. @@ -6575,8 +6563,8 @@ def update_service_connector( expand_secrets=configuration is None, ) - connector_instance: Optional[ServiceConnector] = None - connector_resources: Optional[ServiceConnectorResourcesModel] = None + connector_instance: ServiceConnector | None = None + connector_resources: ServiceConnectorResourcesModel | None = None if isinstance(connector_model.connector_type, str): connector = self.get_service_connector_type( @@ -6585,7 +6573,7 @@ def update_service_connector( else: connector = connector_model.connector_type - resource_types: Optional[Union[str, List[str]]] = None + resource_types: str | list[str] | None = None if resource_type == "": resource_types = None elif resource_type is None: @@ -6709,7 +6697,7 @@ def update_service_connector( def delete_service_connector( self, - name_id_or_prefix: Union[str, UUID], + name_id_or_prefix: str | UUID, ) -> None: """Deletes a registered service connector. 
@@ -6732,9 +6720,9 @@ def delete_service_connector( def verify_service_connector( self, - name_id_or_prefix: Union[UUID, str], - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + name_id_or_prefix: UUID | str, + resource_type: str | None = None, + resource_id: str | None = None, list_resources: bool = True, ) -> "ServiceConnectorResourcesModel": """Verifies if a service connector has access to one or more resources. @@ -6810,9 +6798,9 @@ def verify_service_connector( def login_service_connector( self, - name_id_or_prefix: Union[UUID, str], - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + name_id_or_prefix: UUID | str, + resource_type: str | None = None, + resource_id: str | None = None, **kwargs: Any, ) -> "ServiceConnector": """Use a service connector to authenticate a local client/SDK. @@ -6851,9 +6839,9 @@ def login_service_connector( def get_service_connector_client( self, - name_id_or_prefix: Union[UUID, str], - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + name_id_or_prefix: UUID | str, + resource_type: str | None = None, + resource_id: str | None = None, verify: bool = False, ) -> "ServiceConnector": """Get the client side of a service connector instance to use with a local client. @@ -6936,10 +6924,10 @@ def get_service_connector_client( def list_service_connector_resources( self, - connector_type: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> List[ServiceConnectorResourcesModel]: + connector_type: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + ) -> list[ServiceConnectorResourcesModel]: """List resources that can be accessed by service connectors. 
Args: @@ -6961,10 +6949,10 @@ def list_service_connector_resources( def list_service_connector_types( self, - connector_type: Optional[str] = None, - resource_type: Optional[str] = None, - auth_method: Optional[str] = None, - ) -> List[ServiceConnectorTypeModel]: + connector_type: str | None = None, + resource_type: str | None = None, + auth_method: str | None = None, + ) -> list[ServiceConnectorTypeModel]: """Get a list of service connector types. Args: @@ -7004,14 +6992,14 @@ def get_service_connector_type( def create_model( self, name: str, - license: Optional[str] = None, - description: Optional[str] = None, - audience: Optional[str] = None, - use_cases: Optional[str] = None, - limitations: Optional[str] = None, - trade_offs: Optional[str] = None, - ethics: Optional[str] = None, - tags: Optional[List[str]] = None, + license: str | None = None, + description: str | None = None, + audience: str | None = None, + use_cases: str | None = None, + limitations: str | None = None, + trade_offs: str | None = None, + ethics: str | None = None, + tags: list[str] | None = None, save_models_to_registry: bool = True, ) -> ModelResponse: """Creates a new model in Model Control Plane. @@ -7050,8 +7038,8 @@ def create_model( def delete_model( self, - model_name_or_id: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + model_name_or_id: str | UUID, + project: str | UUID | None = None, ) -> None: """Deletes a model from Model Control Plane. 
@@ -7066,19 +7054,19 @@ def delete_model( def update_model( self, - model_name_or_id: Union[str, UUID], - name: Optional[str] = None, - license: Optional[str] = None, - description: Optional[str] = None, - audience: Optional[str] = None, - use_cases: Optional[str] = None, - limitations: Optional[str] = None, - trade_offs: Optional[str] = None, - ethics: Optional[str] = None, - add_tags: Optional[List[str]] = None, - remove_tags: Optional[List[str]] = None, - save_models_to_registry: Optional[bool] = None, - project: Optional[Union[str, UUID]] = None, + model_name_or_id: str | UUID, + name: str | None = None, + license: str | None = None, + description: str | None = None, + audience: str | None = None, + use_cases: str | None = None, + limitations: str | None = None, + trade_offs: str | None = None, + ethics: str | None = None, + add_tags: list[str] | None = None, + remove_tags: list[str] | None = None, + save_models_to_registry: bool | None = None, + project: str | UUID | None = None, ) -> ModelResponse: """Updates an existing model in Model Control Plane. 
@@ -7123,8 +7111,8 @@ def update_model( def get_model( self, - model_name_or_id: Union[str, UUID], - project: Optional[Union[str, UUID]] = None, + model_name_or_id: str | UUID, + project: str | UUID | None = None, hydrate: bool = True, bypass_lazy_loader: bool = False, ) -> ModelResponse: @@ -7163,15 +7151,15 @@ def list_models( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - id: Optional[Union[UUID, str]] = None, - user: Optional[Union[UUID, str]] = None, - project: Optional[Union[str, UUID]] = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + id: UUID | str | None = None, + user: UUID | str | None = None, + project: str | UUID | None = None, hydrate: bool = False, - tag: Optional[str] = None, - tags: Optional[List[str]] = None, + tag: str | None = None, + tags: list[str] | None = None, ) -> Page[ModelResponse]: """Get models by filter from Model Control Plane. @@ -7219,11 +7207,11 @@ def list_models( def create_model_version( self, - model_name_or_id: Union[str, UUID], - name: Optional[str] = None, - description: Optional[str] = None, - tags: Optional[List[str]] = None, - project: Optional[Union[str, UUID]] = None, + model_name_or_id: str | UUID, + name: str | None = None, + description: str | None = None, + tags: list[str] | None = None, + project: str | UUID | None = None, ) -> ModelVersionResponse: """Creates a new model version in Model Control Plane. 
@@ -7266,11 +7254,11 @@ def delete_model_version( def get_model_version( self, - model_name_or_id: Optional[Union[str, UUID]] = None, - model_version_name_or_number_or_id: Optional[ - Union[str, int, ModelStages, UUID] - ] = None, - project: Optional[Union[str, UUID]] = None, + model_name_or_id: str | UUID | None = None, + model_version_name_or_number_or_id: None | ( + str | int | ModelStages | UUID + ) = None, + project: str | UUID | None = None, hydrate: bool = True, ) -> ModelVersionResponse: """Get an existing model version from Model Control Plane. @@ -7391,24 +7379,24 @@ def get_model_version( def list_model_versions( self, - model: Optional[Union[str, UUID]] = None, - model_name_or_id: Optional[Union[str, UUID]] = None, + model: str | UUID | None = None, + model_name_or_id: str | UUID | None = None, sort_by: str = "number", page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - id: Optional[Union[UUID, str]] = None, - number: Optional[int] = None, - stage: Optional[Union[str, ModelStages]] = None, - run_metadata: Optional[List[str]] = None, - user: Optional[Union[UUID, str]] = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + id: UUID | str | None = None, + number: int | None = None, + stage: str | ModelStages | None = None, + run_metadata: list[str] | None = None, + user: UUID | str | None = None, hydrate: bool = False, - tag: Optional[str] = None, - tags: Optional[List[str]] = None, - project: Optional[Union[str, UUID]] = None, + tag: str | None = None, + tags: list[str] | None = None, + project: str | UUID | None = None, ) -> Page[ModelVersionResponse]: """Get model versions by filter from Model Control Plane. 
@@ -7476,15 +7464,15 @@ def list_model_versions( def update_model_version( self, - model_name_or_id: Union[str, UUID], - version_name_or_id: Union[str, UUID], - stage: Optional[Union[str, ModelStages]] = None, + model_name_or_id: str | UUID, + version_name_or_id: str | UUID, + stage: str | ModelStages | None = None, force: bool = False, - name: Optional[str] = None, - description: Optional[str] = None, - add_tags: Optional[List[str]] = None, - remove_tags: Optional[List[str]] = None, - project: Optional[Union[str, UUID]] = None, + name: str | None = None, + description: str | None = None, + add_tags: list[str] | None = None, + remove_tags: list[str] | None = None, + project: str | UUID | None = None, ) -> ModelVersionResponse: """Get all model versions by filter. @@ -7534,16 +7522,16 @@ def list_model_version_artifact_links( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - model_version_id: Optional[Union[UUID, str]] = None, - artifact_version_id: Optional[Union[UUID, str]] = None, - artifact_name: Optional[str] = None, - only_data_artifacts: Optional[bool] = None, - only_model_artifacts: Optional[bool] = None, - only_deployment_artifacts: Optional[bool] = None, - has_custom_name: Optional[bool] = None, - user: Optional[Union[UUID, str]] = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + model_version_id: UUID | str | None = None, + artifact_version_id: UUID | str | None = None, + artifact_name: str | None = None, + only_data_artifacts: bool | None = None, + only_model_artifacts: bool | None = None, + only_deployment_artifacts: bool | None = None, + has_custom_name: bool | None = None, + user: UUID | str | None = None, hydrate: bool = False, ) -> Page[ModelVersionArtifactResponse]: """Get model version to artifact links by filter in Model Control 
Plane. @@ -7646,12 +7634,12 @@ def list_model_version_pipeline_run_links( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - model_version_id: Optional[Union[UUID, str]] = None, - pipeline_run_id: Optional[Union[UUID, str]] = None, - pipeline_run_name: Optional[str] = None, - user: Optional[Union[UUID, str]] = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + model_version_id: UUID | str | None = None, + pipeline_run_id: UUID | str | None = None, + pipeline_run_name: str | None = None, + user: UUID | str | None = None, hydrate: bool = False, ) -> Page[ModelVersionPipelineRunResponse]: """Get all model version to pipeline run links by filter. @@ -7697,16 +7685,16 @@ def list_authorized_devices( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - expires: Optional[Union[datetime, str]] = None, - client_id: Union[UUID, str, None] = None, - status: Union[OAuthDeviceStatus, str, None] = None, - trusted_device: Union[bool, str, None] = None, - user: Optional[Union[UUID, str]] = None, - failed_auth_attempts: Union[int, str, None] = None, - last_login: Optional[Union[datetime, str, None]] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + expires: datetime | str | None = None, + client_id: UUID | str | None = None, + status: OAuthDeviceStatus | str | None = None, + trusted_device: bool | str | None = None, + user: UUID | str | None = None, + failed_auth_attempts: int | str | None = None, + last_login: datetime | str | None = None, hydrate: bool = False, ) -> 
Page[OAuthDeviceResponse]: """List all authorized devices. @@ -7755,7 +7743,7 @@ def list_authorized_devices( def get_authorized_device( self, - id_or_prefix: Union[UUID, str], + id_or_prefix: UUID | str, allow_id_prefix_match: bool = True, hydrate: bool = True, ) -> OAuthDeviceResponse: @@ -7797,8 +7785,8 @@ def get_authorized_device( def update_authorized_device( self, - id_or_prefix: Union[UUID, str], - locked: Optional[bool] = None, + id_or_prefix: UUID | str, + locked: bool | None = None, ) -> OAuthDeviceResponse: """Update an authorized device. @@ -7821,7 +7809,7 @@ def update_authorized_device( def delete_authorized_device( self, - id_or_prefix: Union[str, UUID], + id_or_prefix: str | UUID, ) -> None: """Delete an authorized device. @@ -7861,10 +7849,10 @@ def list_trigger_executions( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - trigger_id: Optional[Union[UUID, str]] = None, - step_run_id: Optional[Union[UUID, str]] = None, - user: Optional[Union[UUID, str]] = None, - project: Optional[Union[UUID, str]] = None, + trigger_id: UUID | str | None = None, + step_run_id: UUID | str | None = None, + user: UUID | str | None = None, + project: UUID | str | None = None, hydrate: bool = False, ) -> Page[TriggerExecutionResponse]: """List all trigger executions matching the given filter criteria. 
@@ -7914,9 +7902,9 @@ def _get_entity_by_id_or_name_or_prefix( self, get_method: Callable[..., AnyResponse], list_method: Callable[..., Page[AnyResponse]], - name_id_or_prefix: Union[str, UUID], + name_id_or_prefix: str | UUID, allow_name_prefix_match: bool = True, - project: Optional[Union[str, UUID]] = None, + project: str | UUID | None = None, hydrate: bool = True, **kwargs: Any, ) -> AnyResponse: @@ -7951,7 +7939,7 @@ def _get_entity_by_id_or_name_or_prefix( # If not a UUID, try to find by name assert not isinstance(name_id_or_prefix, UUID) - list_kwargs: Dict[str, Any] = dict( + list_kwargs: dict[str, Any] = dict( name=f"equals:{name_id_or_prefix}", hydrate=hydrate, **kwargs, @@ -7997,9 +7985,9 @@ def _get_entity_version_by_id_or_name_or_prefix( self, get_method: Callable[..., AnyResponse], list_method: Callable[..., Page[AnyResponse]], - name_id_or_prefix: Union[str, UUID], - version: Optional[str], - project: Optional[Union[str, UUID]] = None, + name_id_or_prefix: str | UUID, + version: str | None, + project: str | UUID | None = None, hydrate: bool = True, ) -> "AnyResponse": from zenml.utils.uuid_utils import is_valid_uuid @@ -8019,7 +8007,7 @@ def _get_entity_version_by_id_or_name_or_prefix( return get_method(name_id_or_prefix, hydrate=hydrate) assert not isinstance(name_id_or_prefix, UUID) - list_kwargs: Dict[str, Any] = dict( + list_kwargs: dict[str, Any] = dict( size=1, sort_by="desc:created", name=name_id_or_prefix, @@ -8070,7 +8058,7 @@ def _get_entity_by_prefix( list_method: Callable[..., Page[AnyResponse]], partial_id_or_name: str, allow_name_prefix_match: bool, - project: Optional[Union[str, UUID]] = None, + project: str | UUID | None = None, hydrate: bool = True, **kwargs: Any, ) -> AnyResponse: @@ -8095,7 +8083,7 @@ def _get_entity_by_prefix( ZenKeyError: If there is more than one entity with that partial ID or name. 
""" - list_method_args: Dict[str, Any] = { + list_method_args: dict[str, Any] = { "logical_operator": LogicalOperators.OR, "id": f"startswith:{partial_id_or_name}", "hydrate": hydrate, @@ -8132,7 +8120,7 @@ def _get_entity_by_prefix( ) # If more than one entity is found, raise an error. - ambiguous_entities: List[str] = [] + ambiguous_entities: list[str] = [] for model in entity.items: model_name = getattr(model, "name", None) if model_name: @@ -8153,7 +8141,7 @@ def _get_entity_by_prefix( def create_service_account( self, name: str, - full_name: Optional[str] = None, + full_name: str | None = None, description: str = "", ) -> ServiceAccountResponse: """Create a new service account. @@ -8180,7 +8168,7 @@ def create_service_account( def get_service_account( self, - name_id_or_prefix: Union[str, UUID], + name_id_or_prefix: str | UUID, allow_name_prefix_match: bool = True, hydrate: bool = True, ) -> ServiceAccountResponse: @@ -8209,12 +8197,12 @@ def list_service_accounts( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - description: Optional[str] = None, - active: Optional[bool] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + description: str | None = None, + active: bool | None = None, hydrate: bool = False, ) -> Page[ServiceAccountResponse]: """List all service accounts. 
@@ -8254,10 +8242,10 @@ def list_service_accounts( def update_service_account( self, - name_id_or_prefix: Union[str, UUID], - updated_name: Optional[str] = None, - description: Optional[str] = None, - active: Optional[bool] = None, + name_id_or_prefix: str | UUID, + updated_name: str | None = None, + description: str | None = None, + active: bool | None = None, ) -> ServiceAccountResponse: """Update a service account. @@ -8286,7 +8274,7 @@ def update_service_account( def delete_service_account( self, - name_id_or_prefix: Union[str, UUID], + name_id_or_prefix: str | UUID, ) -> None: """Delete a service account. @@ -8304,7 +8292,7 @@ def delete_service_account( def create_api_key( self, - service_account_name_id_or_prefix: Union[str, UUID], + service_account_name_id_or_prefix: str | UUID, name: str, description: str = "", set_key: bool = False, @@ -8370,19 +8358,19 @@ def set_api_key(self, key: str) -> None: def list_api_keys( self, - service_account_name_id_or_prefix: Union[str, UUID], + service_account_name_id_or_prefix: str | UUID, sort_by: str = "created", page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - description: Optional[str] = None, - active: Optional[bool] = None, - last_login: Optional[Union[datetime, str]] = None, - last_rotated: Optional[Union[datetime, str]] = None, + id: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + description: str | None = None, + active: bool | None = None, + last_login: datetime | str | None = None, + last_rotated: datetime | str | None = None, hydrate: bool = False, ) -> Page[APIKeyResponse]: """List all API keys. 
@@ -8434,8 +8422,8 @@ def list_api_keys( def get_api_key( self, - service_account_name_id_or_prefix: Union[str, UUID], - name_id_or_prefix: Union[str, UUID], + service_account_name_id_or_prefix: str | UUID, + name_id_or_prefix: str | UUID, allow_name_prefix_match: bool = True, hydrate: bool = True, ) -> APIKeyResponse: @@ -8486,11 +8474,11 @@ def list_api_keys_method( def update_api_key( self, - service_account_name_id_or_prefix: Union[str, UUID], - name_id_or_prefix: Union[UUID, str], - name: Optional[str] = None, - description: Optional[str] = None, - active: Optional[bool] = None, + service_account_name_id_or_prefix: str | UUID, + name_id_or_prefix: UUID | str, + name: str | None = None, + description: str | None = None, + active: bool | None = None, ) -> APIKeyResponse: """Update an API key. @@ -8521,8 +8509,8 @@ def update_api_key( def rotate_api_key( self, - service_account_name_id_or_prefix: Union[str, UUID], - name_id_or_prefix: Union[UUID, str], + service_account_name_id_or_prefix: str | UUID, + name_id_or_prefix: UUID | str, retain_period_minutes: int = 0, set_key: bool = False, ) -> APIKeyResponse: @@ -8560,8 +8548,8 @@ def rotate_api_key( def delete_api_key( self, - service_account_name_id_or_prefix: Union[str, UUID], - name_id_or_prefix: Union[str, UUID], + service_account_name_id_or_prefix: str | UUID, + name_id_or_prefix: str | UUID, ) -> None: """Delete an API key. @@ -8585,7 +8573,7 @@ def create_tag( self, name: str, exclusive: bool = False, - color: Optional[Union[str, ColorVariants]] = None, + color: str | ColorVariants | None = None, ) -> TagResponse: """Creates a new tag. @@ -8610,7 +8598,7 @@ def create_tag( def delete_tag( self, - tag_name_or_id: Union[str, UUID], + tag_name_or_id: str | UUID, ) -> None: """Deletes a tag. 
@@ -8622,10 +8610,10 @@ def delete_tag( def update_tag( self, - tag_name_or_id: Union[str, UUID], - name: Optional[str] = None, - exclusive: Optional[bool] = None, - color: Optional[Union[str, ColorVariants]] = None, + tag_name_or_id: str | UUID, + name: str | None = None, + exclusive: bool | None = None, + color: str | ColorVariants | None = None, ) -> TagResponse: """Updates an existing tag. @@ -8665,7 +8653,7 @@ def update_tag( def get_tag( self, - tag_name_or_id: Union[str, UUID], + tag_name_or_id: str | UUID, allow_name_prefix_match: bool = True, hydrate: bool = True, ) -> TagResponse: @@ -8694,14 +8682,14 @@ def list_tags( page: int = PAGINATION_STARTING_PAGE, size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, - id: Optional[Union[UUID, str]] = None, - user: Optional[Union[UUID, str]] = None, - created: Optional[Union[datetime, str]] = None, - updated: Optional[Union[datetime, str]] = None, - name: Optional[str] = None, - color: Optional[Union[str, ColorVariants]] = None, - exclusive: Optional[bool] = None, - resource_type: Optional[Union[str, TaggableResourceTypes]] = None, + id: UUID | str | None = None, + user: UUID | str | None = None, + created: datetime | str | None = None, + updated: datetime | str | None = None, + name: str | None = None, + color: str | ColorVariants | None = None, + exclusive: bool | None = None, + resource_type: str | TaggableResourceTypes | None = None, hydrate: bool = False, ) -> Page[TagResponse]: """Get tags by filter. @@ -8745,8 +8733,8 @@ def list_tags( def attach_tag( self, - tag: Union[str, tag_utils.Tag], - resources: List[TagResource], + tag: str | tag_utils.Tag, + resources: list[TagResource], ) -> None: """Attach a tag to resources. @@ -8798,8 +8786,8 @@ def attach_tag( def detach_tag( self, - tag_name_or_id: Union[str, UUID], - resources: List[TagResource], + tag_name_or_id: str | UUID, + resources: list[TagResource], ) -> None: """Detach a tag from resources. 
diff --git a/src/zenml/client_lazy_loader.py b/src/zenml/client_lazy_loader.py index 97a7f6e370c..70eca05d156 100644 --- a/src/zenml/client_lazy_loader.py +++ b/src/zenml/client_lazy_loader.py @@ -15,7 +15,8 @@ import contextlib import functools -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any +from collections.abc import Callable from pydantic import BaseModel, Field @@ -28,18 +29,18 @@ class _CallStep(BaseModel): - attribute_name: Optional[str] = None + attribute_name: str | None = None is_call: bool = False - call_args: List[Any] = Field(default_factory=list) - call_kwargs: Dict[str, Any] = Field(default_factory=dict) - selector: Optional[Any] = None + call_args: list[Any] = Field(default_factory=list) + call_kwargs: dict[str, Any] = Field(default_factory=dict) + selector: Any | None = None class ClientLazyLoader(BaseModel): """Lazy loader for Client methods.""" method_name: str - call_chain: List[_CallStep] = [] + call_chain: list[_CallStep] = [] exclude_next_call: bool = False def __getattr__(self, name: str) -> "ClientLazyLoader": @@ -100,7 +101,7 @@ def evaluate(self) -> Any: from zenml.client import Client def _iterate_over_lazy_chain( - self: "ClientLazyLoader", self_: Any, call_chain_: List[_CallStep] + self: "ClientLazyLoader", self_: Any, call_chain_: list[_CallStep] ) -> Any: next_step = call_chain_.pop(0) try: @@ -142,7 +143,7 @@ def _iterate_over_lazy_chain( def client_lazy_loader( method_name: str, *args: Any, **kwargs: Any -) -> Optional[ClientLazyLoader]: +) -> ClientLazyLoader | None: """Lazy loader for Client methods helper. Usage: @@ -174,8 +175,8 @@ def get_something(self, arg1: Any)->SomeResponse: def evaluate_all_lazy_load_args_in_client_methods( - cls: Type["Client"], -) -> Type["Client"]: + cls: type["Client"], +) -> type["Client"]: """Class wrapper to evaluate lazy loader arguments of all methods. 
Args: @@ -214,7 +215,7 @@ def _inner(*args: Any, **kwargs: Any) -> Any: return _inner - def _decorate() -> Type["Client"]: + def _decorate() -> type["Client"]: for name, fn in inspect.getmembers(cls, inspect.isfunction): setattr( cls, diff --git a/src/zenml/code_repositories/base_code_repository.py b/src/zenml/code_repositories/base_code_repository.py index 6a94fa0e1b6..529ded0246c 100644 --- a/src/zenml/code_repositories/base_code_repository.py +++ b/src/zenml/code_repositories/base_code_repository.py @@ -14,7 +14,7 @@ """Base class for code repositories.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Type +from typing import TYPE_CHECKING, Any, Optional from uuid import UUID, uuid4 from zenml.config.secret_reference_mixin import SecretReferenceMixin @@ -45,7 +45,7 @@ def __init__( self, id: UUID, name: str, - config: Dict[str, Any], + config: dict[str, Any], ) -> None: """Initializes a code repository. @@ -78,7 +78,7 @@ def from_model(cls, model: CodeRepositoryResponse) -> "BaseCodeRepository": Returns: The loaded code repository object. """ - class_: Type[BaseCodeRepository] = ( + class_: type[BaseCodeRepository] = ( source_utils.load_and_validate_class( source=model.source, expected_class=BaseCodeRepository ) @@ -86,7 +86,7 @@ def from_model(cls, model: CodeRepositoryResponse) -> "BaseCodeRepository": return class_(id=model.id, name=model.name, config=model.config) @classmethod - def validate_config(cls, config: Dict[str, Any]) -> None: + def validate_config(cls, config: dict[str, Any]) -> None: """Validate the code repository config. This method should check that the config/credentials are valid and @@ -120,7 +120,7 @@ def name(self) -> str: return self._name @property - def requirements(self) -> Set[str]: + def requirements(self) -> set[str]: """Set of PyPI requirements for the repository. Returns: @@ -140,11 +140,10 @@ def login(self) -> None: Raises: RuntimeError: If the login fails. 
""" - pass @abstractmethod def download_files( - self, commit: str, directory: str, repo_sub_directory: Optional[str] + self, commit: str, directory: str, repo_sub_directory: str | None ) -> None: """Downloads files from the code repository to a local directory. @@ -157,7 +156,6 @@ def download_files( Raises: RuntimeError: If the download fails. """ - pass @abstractmethod def get_local_context( @@ -171,4 +169,3 @@ def get_local_context( Returns: The local repository context object. """ - pass diff --git a/src/zenml/code_repositories/git/local_git_repository_context.py b/src/zenml/code_repositories/git/local_git_repository_context.py index bf0b7ba63f2..c3627311ffe 100644 --- a/src/zenml/code_repositories/git/local_git_repository_context.py +++ b/src/zenml/code_repositories/git/local_git_repository_context.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. """Implementation of the Local git repository context.""" -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Optional +from collections.abc import Callable from zenml.code_repositories import ( LocalRepositoryContext, diff --git a/src/zenml/code_repositories/local_repository_context.py b/src/zenml/code_repositories/local_repository_context.py index 2eb3fb0e2b5..755d9b9a453 100644 --- a/src/zenml/code_repositories/local_repository_context.py +++ b/src/zenml/code_repositories/local_repository_context.py @@ -58,7 +58,6 @@ def root(self) -> str: Returns: The root path of the local repository. """ - pass @property @abstractmethod @@ -71,7 +70,6 @@ def is_dirty(self) -> bool: Returns: Whether the local repository is dirty. """ - pass @property @abstractmethod @@ -84,7 +82,6 @@ def has_local_changes(self) -> bool: Returns: Whether the local repository has local changes. """ - pass @property @abstractmethod @@ -94,4 +91,3 @@ def current_commit(self) -> str: Returns: The current commit of the local repository. 
""" - pass diff --git a/src/zenml/config/base_settings.py b/src/zenml/config/base_settings.py index 6dc9fb8fafc..cc823bb7133 100644 --- a/src/zenml/config/base_settings.py +++ b/src/zenml/config/base_settings.py @@ -14,13 +14,13 @@ """Base class for all ZenML settings.""" from enum import IntFlag, auto -from typing import Any, ClassVar, Dict, Union +from typing import Any, ClassVar, Union from pydantic import ConfigDict from zenml.config.secret_reference_mixin import SecretReferenceMixin -SettingsOrDict = Union[Dict[str, Any], "BaseSettings"] +SettingsOrDict = Union[dict[str, Any], "BaseSettings"] class ConfigurationLevel(IntFlag): diff --git a/src/zenml/config/build_configuration.py b/src/zenml/config/build_configuration.py index b36610f3981..a93f1f8e41d 100644 --- a/src/zenml/config/build_configuration.py +++ b/src/zenml/config/build_configuration.py @@ -15,7 +15,7 @@ import hashlib import json -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING, Optional from pydantic import BaseModel @@ -47,10 +47,10 @@ class BuildConfiguration(BaseModel): key: str settings: DockerSettings - step_name: Optional[str] = None - entrypoint: Optional[str] = None - extra_files: Dict[str, str] = {} - extra_requirements_files: Dict[str, List[str]] = {} + step_name: str | None = None + entrypoint: str | None = None + extra_files: dict[str, str] = {} + extra_requirements_files: dict[str, list[str]] = {} def compute_settings_checksum( self, diff --git a/src/zenml/config/cache_policy.py b/src/zenml/config/cache_policy.py index 368c3e871de..0fda6ce1254 100644 --- a/src/zenml/config/cache_policy.py +++ b/src/zenml/config/cache_policy.py @@ -13,10 +13,10 @@ # permissions and limitations under the License. 
"""Cache policy.""" -from typing import Any, List, Optional, Union +from typing import Any, Union from pydantic import BaseModel, BeforeValidator, Field, field_validator -from typing_extensions import Annotated +from typing import Annotated from zenml.config.source import Source, SourceWithValidator from zenml.logger import get_logger @@ -45,27 +45,27 @@ class CachePolicy(BaseModel): default=True, description="Whether to include the artifact IDs in the cache key.", ) - ignored_inputs: Optional[List[str]] = Field( + ignored_inputs: list[str] | None = Field( default=None, description="List of input names to ignore in the cache key.", ) - file_dependencies: Optional[List[str]] = Field( + file_dependencies: list[str] | None = Field( default=None, description="List of file paths. The contents of theses files will be " "included in the cache key. Only relative paths within the source root " "are allowed.", ) - source_dependencies: Optional[List[SourceWithValidator]] = Field( + source_dependencies: list[SourceWithValidator] | None = Field( default=None, description="List of Python objects (modules, classes, functions). " "The source code of these objects will be included in the cache key.", ) - cache_func: Optional[SourceWithValidator] = Field( + cache_func: SourceWithValidator | None = Field( default=None, description="Function without arguments that returns a string. The " "returned value will be included in the cache key.", ) - expires_after: Optional[int] = Field( + expires_after: int | None = Field( default=None, description="The number of seconds after which the cached result by a " "step with this cache policy will expire. 
If not set, the result " @@ -74,8 +74,8 @@ class CachePolicy(BaseModel): @field_validator("source_dependencies", mode="before") def _validate_source_dependencies( - cls, v: Optional[List[Any]] - ) -> Optional[List[Any]]: + cls, v: list[Any] | None + ) -> list[Any] | None: from zenml.utils import source_utils if v is None: @@ -90,7 +90,7 @@ def _validate_source_dependencies( return result @field_validator("cache_func", mode="before") - def _validate_cache_func(cls, v: Optional[Any]) -> Optional[Any]: + def _validate_cache_func(cls, v: Any | None) -> Any | None: from zenml.utils import source_utils if v is None or isinstance(v, (str, Source, dict)): diff --git a/src/zenml/config/compiler.py b/src/zenml/config/compiler.py index dc2f140ca68..878557204d8 100644 --- a/src/zenml/config/compiler.py +++ b/src/zenml/config/compiler.py @@ -19,12 +19,9 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, - Mapping, Optional, - Tuple, ) +from collections.abc import Mapping from zenml import __version__ from zenml.config.base_settings import BaseSettings, ConfigurationLevel @@ -56,7 +53,7 @@ logger = get_logger(__file__) -def get_zenml_versions() -> Tuple[str, str]: +def get_zenml_versions() -> tuple[str, str]: """Returns the version of ZenML on the client and server side. Returns: @@ -324,7 +321,7 @@ def _get_default_settings( @staticmethod def _verify_run_name( run_name: str, - substitutions: Dict[str, str], + substitutions: dict[str, str], ) -> None: """Verifies that the run name contains only valid placeholders. @@ -371,10 +368,10 @@ def _verify_upstream_steps( def _filter_and_validate_settings( self, - settings: Dict[str, "BaseSettings"], + settings: dict[str, "BaseSettings"], configuration_level: ConfigurationLevel, stack: "Stack", - ) -> Dict[str, "BaseSettings"]: + ) -> dict[str, "BaseSettings"]: """Filters and validates settings. 
Args: @@ -523,7 +520,7 @@ def _compile_step_invocation( def _get_sorted_invocations( self, pipeline: "Pipeline", - ) -> List[Tuple[str, "StepInvocation"]]: + ) -> list[tuple[str, "StepInvocation"]]: """Sorts the step invocations of a pipeline using topological sort. The resulting list of invocations will be in an order that can be @@ -539,12 +536,12 @@ def _get_sorted_invocations( from zenml.orchestrators.topsort import topsorted_layers # Sort step names using topological sort - dag: Dict[str, List[str]] = {} + dag: dict[str, list[str]] = {} for name, step in pipeline.invocations.items(): self._verify_upstream_steps(invocation=step, pipeline=pipeline) dag[name] = list(step.upstream_steps) - reversed_dag: Dict[str, List[str]] = reverse_dag(dag) + reversed_dag: dict[str, list[str]] = reverse_dag(dag) layers = topsorted_layers( nodes=list(dag), get_node_id_fn=lambda node: node, @@ -552,7 +549,7 @@ def _get_sorted_invocations( get_child_nodes=lambda node: reversed_dag[node], ) sorted_step_names = [step for layer in layers for step in layer] - sorted_invocations: List[Tuple[str, "StepInvocation"]] = [ + sorted_invocations: list[tuple[str, "StepInvocation"]] = [ (name_in_pipeline, pipeline.invocations[name_in_pipeline]) for name_in_pipeline in sorted_step_names ] @@ -620,7 +617,7 @@ def _ensure_required_stack_components_exist( @staticmethod def _compute_pipeline_spec( - pipeline: "Pipeline", step_specs: List["StepSpec"] + pipeline: "Pipeline", step_specs: list["StepSpec"] ) -> "PipelineSpec": """Computes the pipeline spec. @@ -664,7 +661,7 @@ def _compute_pipeline_spec( def convert_component_shortcut_settings_keys( - settings: Dict[str, "BaseSettings"], stack: "Stack" + settings: dict[str, "BaseSettings"], stack: "Stack" ) -> None: """Convert component shortcut settings keys. 
@@ -691,8 +688,8 @@ def convert_component_shortcut_settings_keys( def finalize_environment_variables( - environment: Dict[str, Any], -) -> Dict[str, str]: + environment: dict[str, Any], +) -> dict[str, str]: """Finalize the user environment variables. This function adds all __ZENML__ prefixed environment variables from the diff --git a/src/zenml/config/deployment_settings.py b/src/zenml/config/deployment_settings.py index ba40d74d609..5857913fc93 100644 --- a/src/zenml/config/deployment_settings.py +++ b/src/zenml/config/deployment_settings.py @@ -16,13 +16,9 @@ from enum import Enum, IntFlag, auto from typing import ( Any, - Callable, ClassVar, - Dict, - List, - Optional, - Union, ) +from collections.abc import Callable from pydantic import ( BaseModel, @@ -175,8 +171,8 @@ class constructor, if provided. handler: SourceOrObjectField native: bool = False auth_required: bool = True - init_kwargs: Dict[str, Any] = Field(default_factory=dict) - extra_kwargs: Dict[str, Any] = Field(default_factory=dict) + init_kwargs: dict[str, Any] = Field(default_factory=dict) + extra_kwargs: dict[str, Any] = Field(default_factory=dict) def load_sources(self) -> None: """Load all source strings into callables.""" @@ -281,8 +277,8 @@ async def my_middleware( middleware: SourceOrObjectField native: bool = False order: int = 0 - init_kwargs: Dict[str, Any] = Field(default_factory=dict) - extra_kwargs: Dict[str, Any] = Field(default_factory=dict) + init_kwargs: dict[str, Any] = Field(default_factory=dict) + extra_kwargs: dict[str, Any] = Field(default_factory=dict) def load_sources(self) -> None: """Load source string into callable.""" @@ -346,7 +342,7 @@ def install(self, app_runner: BaseDeploymentAppRunner) -> None: """ extension: SourceOrObjectField - extension_kwargs: Dict[str, Any] = Field(default_factory=dict) + extension_kwargs: dict[str, Any] = Field(default_factory=dict) def load_sources(self) -> None: """Load source string into callable.""" @@ -385,9 +381,9 @@ def 
resolve_extension_handler( class CORSConfig(BaseModel): """Configuration for CORS.""" - allow_origins: List[str] = ["*"] - allow_methods: List[str] = ["GET", "POST", "OPTIONS"] - allow_headers: List[str] = ["*"] + allow_origins: list[str] = ["*"] + allow_methods: list[str] = ["GET", "POST", "OPTIONS"] + allow_headers: list[str] = ["*"] allow_credentials: bool = False @@ -455,35 +451,35 @@ class SecureHeadersConfig(BaseModel): included in responses. """ - server: Union[bool, str] = Field( + server: bool | str = Field( default=True, union_mode="left_to_right", ) - hsts: Union[bool, str] = Field( + hsts: bool | str = Field( default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_HSTS, union_mode="left_to_right", ) - xfo: Union[bool, str] = Field( + xfo: bool | str = Field( default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_XFO, union_mode="left_to_right", ) - content: Union[bool, str] = Field( + content: bool | str = Field( default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_CONTENT, union_mode="left_to_right", ) - csp: Union[bool, str] = Field( + csp: bool | str = Field( default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_CSP, union_mode="left_to_right", ) - referrer: Union[bool, str] = Field( + referrer: bool | str = Field( default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_REFERRER, union_mode="left_to_right", ) - cache: Union[bool, str] = Field( + cache: bool | str = Field( default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_CACHE, union_mode="left_to_right", ) - permissions: Union[bool, str] = Field( + permissions: bool | str = Field( default=DEFAULT_DEPLOYMENT_APP_SECURE_HEADERS_PERMISSIONS, union_mode="left_to_right", ) @@ -654,10 +650,10 @@ class DeploymentSettings(BaseSettings): # These settings are only available at the pipeline level LEVEL: ClassVar[ConfigurationLevel] = ConfigurationLevel.PIPELINE - app_title: Optional[str] = None - app_description: Optional[str] = None - app_version: Optional[str] = None - app_kwargs: Dict[str, Any] = {} + app_title: str | None = None + app_description: str | 
None = None + app_version: str | None = None + app_kwargs: dict[str, Any] = {} include_default_endpoints: DeploymentDefaultEndpoints = ( DeploymentDefaultEndpoints.ALL @@ -675,23 +671,23 @@ class DeploymentSettings(BaseSettings): info_url_path: str = DEFAULT_DEPLOYMENT_APP_INFO_URL_PATH metrics_url_path: str = DEFAULT_DEPLOYMENT_APP_METRICS_URL_PATH - dashboard_files_path: Optional[str] = None + dashboard_files_path: str | None = None cors: CORSConfig = CORSConfig() secure_headers: SecureHeadersConfig = SecureHeadersConfig() thread_pool_size: int = DEFAULT_DEPLOYMENT_APP_THREAD_POOL_SIZE - startup_hook: Optional[SourceOrObjectField] = None - shutdown_hook: Optional[SourceOrObjectField] = None - startup_hook_kwargs: Dict[str, Any] = {} - shutdown_hook_kwargs: Dict[str, Any] = {} + startup_hook: SourceOrObjectField | None = None + shutdown_hook: SourceOrObjectField | None = None + startup_hook_kwargs: dict[str, Any] = {} + shutdown_hook_kwargs: dict[str, Any] = {} # Framework-agnostic endpoint/middleware configuration - custom_endpoints: Optional[List[EndpointSpec]] = None - custom_middlewares: Optional[List[MiddlewareSpec]] = None + custom_endpoints: list[EndpointSpec] | None = None + custom_middlewares: list[MiddlewareSpec] | None = None # Pluggable app extensions for advanced features - app_extensions: Optional[List[AppExtensionSpec]] = None + app_extensions: list[AppExtensionSpec] | None = None uvicorn_host: str = "0.0.0.0" # nosec uvicorn_port: int = 8000 @@ -699,12 +695,12 @@ class DeploymentSettings(BaseSettings): uvicorn_reload: bool = False log_level: LoggingLevels = LoggingLevels.INFO - uvicorn_kwargs: Dict[str, Any] = {} + uvicorn_kwargs: dict[str, Any] = {} - deployment_app_runner_flavor: Optional[SourceOrObjectField] = None - deployment_app_runner_kwargs: Dict[str, Any] = {} - deployment_service_class: Optional[SourceOrObjectField] = None - deployment_service_kwargs: Dict[str, Any] = {} + deployment_app_runner_flavor: SourceOrObjectField | None = None + 
deployment_app_runner_kwargs: dict[str, Any] = {} + deployment_service_class: SourceOrObjectField | None = None + deployment_service_kwargs: dict[str, Any] = {} def load_sources(self) -> None: """Load source string into callable.""" diff --git a/src/zenml/config/docker_settings.py b/src/zenml/config/docker_settings.py index 9274f92d49c..839f8def93f 100644 --- a/src/zenml/config/docker_settings.py +++ b/src/zenml/config/docker_settings.py @@ -14,7 +14,7 @@ """Docker settings.""" from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -68,8 +68,8 @@ class DockerBuildConfig(BaseModel): Docker image. """ - build_options: Dict[str, Any] = {} - dockerignore: Optional[str] = None + build_options: dict[str, Any] = {} + dockerignore: str | None = None _docker_settings_warnings_logged = [] @@ -212,47 +212,47 @@ class DockerSettings(BaseSettings): from the artifact store. """ - parent_image: Optional[str] = None - image_tag: Optional[str] = None - dockerfile: Optional[str] = None - build_context_root: Optional[str] = None - parent_image_build_config: Optional[DockerBuildConfig] = None + parent_image: str | None = None + image_tag: str | None = None + dockerfile: str | None = None + build_context_root: str | None = None + parent_image_build_config: DockerBuildConfig | None = None skip_build: bool = False prevent_build_reuse: bool = False - target_repository: Optional[str] = None + target_repository: str | None = None python_package_installer: PythonPackageInstaller = ( PythonPackageInstaller.UV ) - python_package_installer_args: Dict[str, Any] = {} + python_package_installer_args: dict[str, Any] = {} disable_automatic_requirements_detection: bool = True - replicate_local_python_environment: Optional[ - Union[List[str], PythonEnvironmentExportMethod, bool] - ] = Field(default=None, union_mode="left_to_right") - pyproject_path: Optional[str] = None - 
pyproject_export_command: Optional[List[str]] = None - requirements: Union[None, str, List[str]] = Field( + replicate_local_python_environment: None | ( + list[str] | PythonEnvironmentExportMethod | bool + ) = Field(default=None, union_mode="left_to_right") + pyproject_path: str | None = None + pyproject_export_command: list[str] | None = None + requirements: None | str | list[str] = Field( default=None, union_mode="left_to_right" ) - required_integrations: List[str] = [] + required_integrations: list[str] = [] install_stack_requirements: bool = True install_deployment_requirements: bool = True - local_project_install_command: Optional[str] = None - apt_packages: List[str] = [] - environment: Dict[str, Any] = {} - user: Optional[str] = None - build_config: Optional[DockerBuildConfig] = None + local_project_install_command: str | None = None + apt_packages: list[str] = [] + environment: dict[str, Any] = {} + user: str | None = None + build_config: DockerBuildConfig | None = None allow_including_files_in_images: bool = True allow_download_from_code_repository: bool = True allow_download_from_artifact_store: bool = True # Deprecated attributes - build_options: Dict[str, Any] = {} - dockerignore: Optional[str] = None + build_options: dict[str, Any] = {} + dockerignore: str | None = None copy_files: bool = True copy_global_config: bool = True - source_files: Optional[str] = None - required_hub_plugins: List[str] = [] + source_files: str | None = None + required_hub_plugins: list[str] = [] _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes( "copy_files", @@ -266,7 +266,7 @@ class DockerSettings(BaseSettings): @model_validator(mode="before") @classmethod @before_validator_handler - def _migrate_source_files(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def _migrate_source_files(cls, data: dict[str, Any]) -> dict[str, Any]: """Migrate old source_files values. 
Args: diff --git a/src/zenml/config/global_config.py b/src/zenml/config/global_config.py index 4cc8d17fb3b..051b2a1e8c8 100644 --- a/src/zenml/config/global_config.py +++ b/src/zenml/config/global_config.py @@ -16,7 +16,7 @@ import os import uuid from pathlib import Path -from typing import TYPE_CHECKING, Any, Dict, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from uuid import UUID from packaging import version @@ -117,13 +117,13 @@ class GlobalConfiguration(BaseModel, metaclass=GlobalConfigMetaClass): """ user_id: uuid.UUID = Field(default_factory=uuid.uuid4) - user_email: Optional[str] = None - user_email_opt_in: Optional[bool] = None + user_email: str | None = None + user_email_opt_in: bool | None = None analytics_opt_in: bool = True - version: Optional[str] = None - store: Optional[SerializeAsAny[StoreConfiguration]] = None - active_stack_id: Optional[uuid.UUID] = None - active_project_id: Optional[uuid.UUID] = None + version: str | None = None + store: SerializeAsAny[StoreConfiguration] | None = None + active_stack_id: uuid.UUID | None = None + active_project_id: uuid.UUID | None = None _zen_store: Optional["BaseZenStore"] = None _active_project: Optional["ProjectResponse"] = None @@ -176,7 +176,7 @@ def _reset_instance( @field_validator("version") @classmethod - def _validate_version(cls, value: Optional[str]) -> Optional[str]: + def _validate_version(cls, value: str | None) -> str | None: """Validate the version attribute. Args: @@ -288,7 +288,7 @@ def _migrate_config(self) -> None: # to ensure the schema migration results are persisted self.version = __version__ - def _read_config(self) -> Dict[str, Any]: + def _read_config(self) -> dict[str, Any]: """Reads configuration options from disk. 
If the config file doesn't exist yet, this method returns an empty @@ -439,7 +439,7 @@ def local_stores_path(self) -> str: LOCAL_STORES_DIRECTORY_NAME, ) - def get_config_environment_vars(self) -> Dict[str, str]: + def get_config_environment_vars(self) -> dict[str, str]: """Convert the global configuration to environment variables. Returns: @@ -489,7 +489,7 @@ def get_config_environment_vars(self) -> Dict[str, str]: return environment_vars def _get_store_configuration( - self, baseline: Optional[StoreConfiguration] = None + self, baseline: StoreConfiguration | None = None ) -> StoreConfiguration: """Get the store configuration. @@ -510,12 +510,12 @@ def _get_store_configuration( """ from zenml.zen_stores.base_zen_store import BaseZenStore - store: Optional[StoreConfiguration] = baseline or self.store + store: StoreConfiguration | None = baseline or self.store # Step 1: Read environment variable overrides - env_store_config: Dict[str, str] = {} - env_secrets_store_config: Dict[str, str] = {} - env_backup_secrets_store_config: Dict[str, str] = {} + env_store_config: dict[str, str] = {} + env_secrets_store_config: dict[str, str] = {} + env_backup_secrets_store_config: dict[str, str] = {} for k, v in os.environ.items(): if k.startswith(ENV_ZENML_STORE_PREFIX): env_store_config[k[len(ENV_ZENML_STORE_PREFIX) :].lower()] = v diff --git a/src/zenml/config/pipeline_configurations.py b/src/zenml/config/pipeline_configurations.py index 080be3b2b85..a983bdee215 100644 --- a/src/zenml/config/pipeline_configurations.py +++ b/src/zenml/config/pipeline_configurations.py @@ -14,7 +14,7 @@ """Pipeline configuration classes.""" from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Union from uuid import UUID from pydantic import SerializeAsAny, field_validator @@ -48,31 +48,31 @@ class PipelineConfigurationUpdate(FrozenBaseModel): """Class for pipeline configuration updates.""" - enable_cache: 
Optional[bool] = None - enable_artifact_metadata: Optional[bool] = None - enable_artifact_visualization: Optional[bool] = None - enable_step_logs: Optional[bool] = None - environment: Dict[str, Any] = {} - secrets: List[Union[str, UUID]] = [] - enable_pipeline_logs: Optional[bool] = None - execution_mode: Optional[ExecutionMode] = None - settings: Dict[str, SerializeAsAny[BaseSettings]] = {} - tags: Optional[List[Union[str, "Tag"]]] = None - extra: Dict[str, Any] = {} - failure_hook_source: Optional[SourceWithValidator] = None - success_hook_source: Optional[SourceWithValidator] = None - init_hook_source: Optional[SourceWithValidator] = None - init_hook_kwargs: Optional[Dict[str, Any]] = None - cleanup_hook_source: Optional[SourceWithValidator] = None - model: Optional[Model] = None - parameters: Optional[Dict[str, Any]] = None - retry: Optional[StepRetryConfig] = None - substitutions: Dict[str, str] = {} - cache_policy: Optional[CachePolicyWithValidator] = None + enable_cache: bool | None = None + enable_artifact_metadata: bool | None = None + enable_artifact_visualization: bool | None = None + enable_step_logs: bool | None = None + environment: dict[str, Any] = {} + secrets: list[str | UUID] = [] + enable_pipeline_logs: bool | None = None + execution_mode: ExecutionMode | None = None + settings: dict[str, SerializeAsAny[BaseSettings]] = {} + tags: list[Union[str, "Tag"]] | None = None + extra: dict[str, Any] = {} + failure_hook_source: SourceWithValidator | None = None + success_hook_source: SourceWithValidator | None = None + init_hook_source: SourceWithValidator | None = None + init_hook_kwargs: dict[str, Any] | None = None + cleanup_hook_source: SourceWithValidator | None = None + model: Model | None = None + parameters: dict[str, Any] | None = None + retry: StepRetryConfig | None = None + substitutions: dict[str, str] = {} + cache_policy: CachePolicyWithValidator | None = None def finalize_substitutions( - self, start_time: Optional[datetime] = None, inplace: 
bool = False - ) -> Dict[str, str]: + self, start_time: datetime | None = None, inplace: bool = False + ) -> dict[str, str]: """Returns the full substitutions dict. Args: diff --git a/src/zenml/config/pipeline_run_configuration.py b/src/zenml/config/pipeline_run_configuration.py index 65167eadf26..083cad89aba 100644 --- a/src/zenml/config/pipeline_run_configuration.py +++ b/src/zenml/config/pipeline_run_configuration.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Pipeline run configuration class.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any from uuid import UUID from pydantic import Field, SerializeAsAny @@ -37,99 +37,99 @@ class PipelineRunConfiguration( ): """Class for pipeline run configurations.""" - run_name: Optional[str] = Field( + run_name: str | None = Field( default=None, description="The name of the pipeline run." ) - enable_cache: Optional[bool] = Field( + enable_cache: bool | None = Field( default=None, description="Whether to enable cache for all steps of the pipeline " "run.", ) - enable_artifact_metadata: Optional[bool] = Field( + enable_artifact_metadata: bool | None = Field( default=None, description="Whether to enable metadata for the output artifacts of " "all steps of the pipeline run.", ) - enable_artifact_visualization: Optional[bool] = Field( + enable_artifact_visualization: bool | None = Field( default=None, description="Whether to enable visualizations for the output " "artifacts of all steps of the pipeline run.", ) - enable_step_logs: Optional[bool] = Field( + enable_step_logs: bool | None = Field( default=None, description="Whether to enable logs for all steps of the pipeline run.", ) - enable_pipeline_logs: Optional[bool] = Field( + enable_pipeline_logs: bool | None = Field( default=None, description="Whether to enable pipeline logs for the pipeline run.", ) - schedule: Optional[Schedule] = Field( + schedule: Schedule | None = Field( default=None, description="The schedule on 
which to run the pipeline." ) - build: Union[PipelineBuildBase, UUID, None] = Field( + build: PipelineBuildBase | UUID | None = Field( default=None, union_mode="left_to_right", description="The build to use for the pipeline run.", ) - steps: Optional[Dict[str, StepConfigurationUpdate]] = Field( + steps: dict[str, StepConfigurationUpdate] | None = Field( default=None, description="Configurations for the steps of the pipeline run.", ) - settings: Optional[Dict[str, SerializeAsAny[BaseSettings]]] = Field( + settings: dict[str, SerializeAsAny[BaseSettings]] | None = Field( default=None, description="Settings for the pipeline run." ) - environment: Optional[Dict[str, Any]] = Field( + environment: dict[str, Any] | None = Field( default=None, description="The environment for all steps of the pipeline run.", ) - secrets: Optional[List[Union[str, UUID]]] = Field( + secrets: list[str | UUID] | None = Field( default=None, description="The secrets for all steps of the pipeline run.", ) - tags: Optional[List[Union[str, Tag]]] = Field( + tags: list[str | Tag] | None = Field( default=None, description="Tags to apply to the pipeline run." ) - extra: Optional[Dict[str, Any]] = Field( + extra: dict[str, Any] | None = Field( default=None, description="Extra configurations for the pipeline run." ) - model: Optional[Model] = Field( + model: Model | None = Field( default=None, description="The model to use for the pipeline run." ) - parameters: Optional[Dict[str, Any]] = Field( + parameters: dict[str, Any] | None = Field( default=None, description="Parameters for the pipeline function." 
) - retry: Optional[StepRetryConfig] = Field( + retry: StepRetryConfig | None = Field( default=None, description="The retry configuration for all steps of the pipeline run.", ) - failure_hook_source: Optional[SourceWithValidator] = Field( + failure_hook_source: SourceWithValidator | None = Field( default=None, description="The failure hook source for all steps of the pipeline run.", ) - init_hook_source: Optional[SourceWithValidator] = Field( + init_hook_source: SourceWithValidator | None = Field( default=None, description="The init hook source for the pipeline run.", ) - init_hook_kwargs: Optional[Dict[str, Any]] = Field( + init_hook_kwargs: dict[str, Any] | None = Field( default=None, description="The init hook args for the pipeline run.", ) - cleanup_hook_source: Optional[SourceWithValidator] = Field( + cleanup_hook_source: SourceWithValidator | None = Field( default=None, description="The cleanup hook source for the pipeline run.", ) - success_hook_source: Optional[SourceWithValidator] = Field( + success_hook_source: SourceWithValidator | None = Field( default=None, description="The success hook source for all steps of the pipeline run.", ) - substitutions: Optional[Dict[str, str]] = Field( + substitutions: dict[str, str] | None = Field( default=None, description="The substitutions for the pipeline run." 
) - cache_policy: Optional[CachePolicyWithValidator] = Field( + cache_policy: CachePolicyWithValidator | None = Field( default=None, description="The cache policy for all steps of the pipeline run.", ) - execution_mode: Optional[ExecutionMode] = Field( + execution_mode: ExecutionMode | None = Field( default=None, description="The execution mode for the pipeline run.", ) diff --git a/src/zenml/config/pipeline_spec.py b/src/zenml/config/pipeline_spec.py index 4b6b4e29b72..a2a9574451d 100644 --- a/src/zenml/config/pipeline_spec.py +++ b/src/zenml/config/pipeline_spec.py @@ -14,7 +14,7 @@ """Pipeline configuration classes.""" import json -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import Field @@ -44,18 +44,18 @@ class PipelineSpec(FrozenBaseModel): # inputs in the step specs refer to the pipeline parameter names # - 0.5: Adds input schema, outputs and output schema version: str = "0.5" - source: Optional[SourceWithValidator] = None - parameters: Dict[str, Any] = {} - input_schema: Optional[Dict[str, Any]] = Field( + source: SourceWithValidator | None = None + parameters: dict[str, Any] = {} + input_schema: dict[str, Any] | None = Field( default=None, description="JSON schema of the pipeline inputs. This is only set " "for pipeline specs with version >= 0.5. If the value is None, the " "schema generation failed, which is most likely because some of the " "pipeline inputs are not JSON serializable.", ) - steps: List[StepSpec] - outputs: List[OutputSpec] = [] - output_schema: Optional[Dict[str, Any]] = Field( + steps: list[StepSpec] + outputs: list[OutputSpec] = [] + output_schema: dict[str, Any] | None = Field( default=None, description="JSON schema of the pipeline outputs. This is only set " "for pipeline specs with version >= 0.5. 
If the value is None, the " diff --git a/src/zenml/config/resource_settings.py b/src/zenml/config/resource_settings.py index 4482c3bf261..41e01d8ec73 100644 --- a/src/zenml/config/resource_settings.py +++ b/src/zenml/config/resource_settings.py @@ -14,7 +14,7 @@ """Resource settings class used to specify resources for a step.""" from enum import Enum -from typing import Literal, Optional, Union +from typing import Literal from pydantic import ( ConfigDict, @@ -128,18 +128,18 @@ class ResourceSettings(BaseSettings): Only relevant to deployed pipelines. """ - cpu_count: Optional[PositiveFloat] = None - gpu_count: Optional[NonNegativeInt] = None - memory: Optional[str] = Field(pattern=MEMORY_REGEX, default=None) + cpu_count: PositiveFloat | None = None + gpu_count: NonNegativeInt | None = None + memory: str | None = Field(pattern=MEMORY_REGEX, default=None) # Settings only applicable for deployers and deployed pipelines - min_replicas: Optional[NonNegativeInt] = None - max_replicas: Optional[NonNegativeInt] = None - autoscaling_metric: Optional[ + min_replicas: NonNegativeInt | None = None + max_replicas: NonNegativeInt | None = None + autoscaling_metric: None | ( Literal["cpu", "memory", "concurrency", "rps"] - ] = None - autoscaling_target: Optional[PositiveFloat] = None - max_concurrency: Optional[PositiveInt] = None + ) = None + autoscaling_target: PositiveFloat | None = None + max_concurrency: PositiveInt | None = None @property def empty(self) -> bool: @@ -154,8 +154,8 @@ def empty(self) -> bool: return len(self.model_dump(exclude_unset=True, exclude_none=True)) == 0 def get_memory( - self, unit: Union[str, ByteUnit] = ByteUnit.GB - ) -> Optional[float]: + self, unit: str | ByteUnit = ByteUnit.GB + ) -> float | None: """Gets the memory configuration in a specific unit. 
Args: diff --git a/src/zenml/config/schedule.py b/src/zenml/config/schedule.py index 9f4eb585229..799f7b7c17e 100644 --- a/src/zenml/config/schedule.py +++ b/src/zenml/config/schedule.py @@ -14,7 +14,6 @@ """Class for defining a pipeline schedule.""" from datetime import datetime, timedelta -from typing import Optional from pydantic import ( BaseModel, @@ -57,21 +56,21 @@ class Schedule(BaseModel): in the local timezone. """ - name: Optional[str] = None - cron_expression: Optional[str] = None - start_time: Optional[datetime] = None - end_time: Optional[datetime] = None - interval_second: Optional[timedelta] = None + name: str | None = None + cron_expression: str | None = None + start_time: datetime | None = None + end_time: datetime | None = None + interval_second: timedelta | None = None catchup: bool = False - run_once_start_time: Optional[datetime] = None + run_once_start_time: datetime | None = None @field_validator( "start_time", "end_time", "run_once_start_time", mode="after" ) @classmethod def _ensure_timezone( - cls, value: Optional[datetime], info: ValidationInfo - ) -> Optional[datetime]: + cls, value: datetime | None, info: ValidationInfo + ) -> datetime | None: """Ensures that all datetimes are timezone aware. Args: diff --git a/src/zenml/config/secret_reference_mixin.py b/src/zenml/config/secret_reference_mixin.py index 79845c16dd7..ccb6b689b17 100644 --- a/src/zenml/config/secret_reference_mixin.py +++ b/src/zenml/config/secret_reference_mixin.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Secret reference mixin implementation.""" -from typing import TYPE_CHECKING, Any, Set +from typing import TYPE_CHECKING, Any from pydantic import BaseModel @@ -141,7 +141,7 @@ def __custom_getattribute__(self, key: str) -> Any: __getattribute__ = __custom_getattribute__ @property - def required_secrets(self) -> Set[secret_utils.SecretReference]: + def required_secrets(self) -> set[secret_utils.SecretReference]: """All required secrets for this object. Returns: diff --git a/src/zenml/config/secrets_store_config.py b/src/zenml/config/secrets_store_config.py index 7216b7b1a4c..6102a4badd5 100644 --- a/src/zenml/config/secrets_store_config.py +++ b/src/zenml/config/secrets_store_config.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Functionality to support ZenML secrets store configurations.""" -from typing import Optional from pydantic import BaseModel, ConfigDict, model_validator @@ -39,7 +38,7 @@ class SecretsStoreConfiguration(BaseModel): """ type: SecretsStoreType - class_path: Optional[str] = None + class_path: str | None = None @model_validator(mode="after") def validate_custom(self) -> "SecretsStoreConfiguration": diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index 9c8b3e5a82f..0ecf4362bbf 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -16,7 +16,7 @@ import json import os from secrets import token_hex -from typing import Any, Dict, List, Optional, Union +from typing import Any from uuid import UUID from pydantic import ( @@ -253,20 +253,20 @@ class ServerConfiguration(BaseModel): """ deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER - server_url: Optional[str] = None - dashboard_url: Optional[str] = None + server_url: str | None = None + dashboard_url: str | None = None root_url_path: str = "" - metadata: Dict[str, str] = {} + metadata: dict[str, str] = {} auth_scheme: AuthScheme = AuthScheme.OAUTH2_PASSWORD_BEARER jwt_token_algorithm: 
str = DEFAULT_ZENML_JWT_TOKEN_ALGORITHM - jwt_token_issuer: Optional[str] = None - jwt_token_audience: Optional[str] = None + jwt_token_issuer: str | None = None + jwt_token_audience: str | None = None jwt_token_leeway_seconds: int = DEFAULT_ZENML_JWT_TOKEN_LEEWAY - jwt_token_expire_minutes: Optional[int] = None + jwt_token_expire_minutes: int | None = None jwt_secret_key: str = Field(default_factory=generate_jwt_secret_key) - auth_cookie_name: Optional[str] = None - auth_cookie_domain: Optional[str] = None - cors_allow_origins: Optional[List[str]] = None + auth_cookie_name: str | None = None + auth_cookie_domain: str | None = None + cors_allow_origins: list[str] | None = None max_failed_device_auth_attempts: int = ( DEFAULT_ZENML_SERVER_MAX_DEVICE_AUTH_ATTEMPTS ) @@ -274,8 +274,8 @@ class ServerConfiguration(BaseModel): device_auth_polling_interval: int = ( DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING ) - device_expiration_minutes: Optional[int] = None - trusted_device_expiration_minutes: Optional[int] = None + device_expiration_minutes: int | None = None + trusted_device_expiration_minutes: int | None = None generic_api_token_lifetime: PositiveInt = ( DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME @@ -284,14 +284,14 @@ class ServerConfiguration(BaseModel): DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_MAX_LIFETIME ) - external_login_url: Optional[str] = None - external_user_info_url: Optional[str] = None - external_server_id: Optional[UUID] = None + external_login_url: str | None = None + external_user_info_url: str | None = None + external_server_id: UUID | None = None - rbac_implementation_source: Optional[str] = None - feature_gate_implementation_source: Optional[str] = None - reportable_resources: List[str] = [] - workload_manager_implementation_source: Optional[str] = None + rbac_implementation_source: str | None = None + feature_gate_implementation_source: str | None = None + reportable_resources: list[str] = [] + workload_manager_implementation_source: str | None = 
None max_concurrent_snapshot_runs: int = ( DEFAULT_ZENML_SERVER_MAX_CONCURRENT_SNAPSHOT_RUNS ) @@ -303,35 +303,35 @@ class ServerConfiguration(BaseModel): login_rate_limit_minute: int = DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE login_rate_limit_day: int = DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY - secure_headers_server: Union[bool, str] = Field( + secure_headers_server: bool | str = Field( default=True, union_mode="left_to_right", ) - secure_headers_hsts: Union[bool, str] = Field( + secure_headers_hsts: bool | str = Field( default=DEFAULT_ZENML_SERVER_SECURE_HEADERS_HSTS, union_mode="left_to_right", ) - secure_headers_xfo: Union[bool, str] = Field( + secure_headers_xfo: bool | str = Field( default=DEFAULT_ZENML_SERVER_SECURE_HEADERS_XFO, union_mode="left_to_right", ) - secure_headers_content: Union[bool, str] = Field( + secure_headers_content: bool | str = Field( default=DEFAULT_ZENML_SERVER_SECURE_HEADERS_CONTENT, union_mode="left_to_right", ) - secure_headers_csp: Union[bool, str] = Field( + secure_headers_csp: bool | str = Field( default=DEFAULT_ZENML_SERVER_SECURE_HEADERS_CSP, union_mode="left_to_right", ) - secure_headers_referrer: Union[bool, str] = Field( + secure_headers_referrer: bool | str = Field( default=DEFAULT_ZENML_SERVER_SECURE_HEADERS_REFERRER, union_mode="left_to_right", ) - secure_headers_cache: Union[bool, str] = Field( + secure_headers_cache: bool | str = Field( default=DEFAULT_ZENML_SERVER_SECURE_HEADERS_CACHE, union_mode="left_to_right", ) - secure_headers_permissions: Union[bool, str] = Field( + secure_headers_permissions: bool | str = Field( default=DEFAULT_ZENML_SERVER_SECURE_HEADERS_PERMISSIONS, union_mode="left_to_right", ) @@ -358,12 +358,12 @@ class ServerConfiguration(BaseModel): DEFAULT_ZENML_SERVER_FILE_DOWNLOAD_SIZE_LIMIT ) - _deployment_id: Optional[UUID] = None + _deployment_id: UUID | None = None @model_validator(mode="before") @classmethod @before_validator_handler - def _validate_config(cls, data: Dict[str, Any]) -> 
Dict[str, Any]: + def _validate_config(cls, data: dict[str, Any]) -> dict[str, Any]: """Validate the server configuration. Args: @@ -589,7 +589,7 @@ def get_server_config(cls) -> "ServerConfiguration": Returns: The server configuration. """ - env_server_config: Dict[str, Any] = {} + env_server_config: dict[str, Any] = {} for k, v in os.environ.items(): if v == "": continue @@ -693,9 +693,9 @@ class ServerProConfiguration(BaseModel): oauth2_client_secret: str oauth2_audience: str organization_id: UUID - organization_name: Optional[str] = None + organization_name: str | None = None workspace_id: UUID - workspace_name: Optional[str] = None + workspace_name: str | None = None http_timeout: int = DEFAULT_HTTP_TIMEOUT @field_validator("api_url", "dashboard_url") @@ -718,7 +718,7 @@ def get_server_config(cls) -> "ServerProConfiguration": Returns: The server Pro configuration. """ - env_server_config: Dict[str, Any] = {} + env_server_config: dict[str, Any] = {} for k, v in os.environ.items(): if v == "": continue diff --git a/src/zenml/config/settings_resolver.py b/src/zenml/config/settings_resolver.py index f9688732084..0ea174100f0 100644 --- a/src/zenml/config/settings_resolver.py +++ b/src/zenml/config/settings_resolver.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Class for resolving settings.""" -from typing import TYPE_CHECKING, Type, TypeVar +from typing import TYPE_CHECKING, TypeVar from pydantic import ValidationError @@ -77,7 +77,7 @@ def resolve(self, stack: "Stack") -> "BaseSettings": def _resolve_general_settings_class( self, - ) -> Type["BaseSettings"]: + ) -> type["BaseSettings"]: """Resolves general settings. Returns: @@ -87,7 +87,7 @@ def _resolve_general_settings_class( def _resolve_stack_component_setting_class( self, stack: "Stack" - ) -> Type["BaseSettings"]: + ) -> type["BaseSettings"]: """Resolves stack component settings with the given stack. 
Args: @@ -110,7 +110,7 @@ def _resolve_stack_component_setting_class( return settings_class - def _convert_settings(self, target_class: Type["T"]) -> "T": + def _convert_settings(self, target_class: type["T"]) -> "T": """Converts the settings to their correct class. Args: diff --git a/src/zenml/config/source.py b/src/zenml/config/source.py index e6b4e2b3332..8aec0f804e9 100644 --- a/src/zenml/config/source.py +++ b/src/zenml/config/source.py @@ -15,7 +15,8 @@ from enum import Enum from types import BuiltinFunctionType, FunctionType, ModuleType -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Union +from collections.abc import Callable from uuid import UUID from pydantic import ( @@ -26,7 +27,7 @@ SerializeAsAny, field_validator, ) -from typing_extensions import Annotated +from typing import Annotated from zenml.logger import get_logger @@ -68,7 +69,7 @@ class Source(BaseModel): """ module: str - attribute: Optional[str] = None + attribute: str | None = None type: SourceType @classmethod @@ -142,7 +143,7 @@ def is_module_source(self) -> bool: model_config = ConfigDict(extra="allow") - def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + def model_dump(self, **kwargs: Any) -> dict[str, Any]: """Dump the source as a dictionary. Args: @@ -181,7 +182,7 @@ def convert_source(cls, source: Any) -> Any: ObjectType = Union[ - Type[Any], + type[Any], Callable[..., Any], ModuleType, FunctionType, @@ -210,7 +211,7 @@ class SourceOrObject(Source): _is_loaded: Whether the callable has been loaded. 
""" - _object: Optional[ObjectType] = None + _object: ObjectType | None = None _is_loaded: bool = False @classmethod @@ -296,9 +297,9 @@ def is_loaded(self) -> bool: def convert_source_or_object( cls, source: Union[ - str, "SourceOrObject", Source, Dict[str, Any], ObjectType + str, "SourceOrObject", Source, dict[str, Any], ObjectType ], - ) -> Union["SourceOrObject", Dict[str, Any]]: + ) -> Union["SourceOrObject", dict[str, Any]]: """Converts a source string or object to a SourceOrObject object. Args: @@ -321,7 +322,7 @@ def convert_source_or_object( return cls.from_object(source) - def model_dump(self, **kwargs: Any) -> Dict[str, Any]: + def model_dump(self, **kwargs: Any) -> dict[str, Any]: """Dump the source as a dictionary. Args: @@ -348,7 +349,7 @@ def model_dump_json(self, **kwargs: Any) -> str: @classmethod def serialize_source_or_object( cls, value: "SourceOrObject" - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Serialize the source or object as a dictionary. Args: @@ -369,7 +370,7 @@ class DistributionPackageSource(Source): """ package_name: str - version: Optional[str] = None + version: str | None = None type: SourceType = SourceType.DISTRIBUTION_PACKAGE @field_validator("type") @@ -437,8 +438,8 @@ class NotebookSource(Source): module code is stored. 
""" - replacement_module: Optional[str] = None - artifact_store_id: Optional[UUID] = None + replacement_module: str | None = None + artifact_store_id: UUID | None = None type: SourceType = SourceType.NOTEBOOK @field_validator("type") diff --git a/src/zenml/config/step_configurations.py b/src/zenml/config/step_configurations.py index 597074cc82b..fab5eec5303 100644 --- a/src/zenml/config/step_configurations.py +++ b/src/zenml/config/step_configurations.py @@ -16,13 +16,8 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, - Mapping, - Optional, - Tuple, - Union, ) +from collections.abc import Mapping from uuid import UUID from pydantic import ( @@ -60,18 +55,18 @@ class PartialArtifactConfiguration(FrozenBaseModel): """Class representing a partial input/output artifact configuration.""" - materializer_source: Optional[Tuple[SourceWithValidator, ...]] = None + materializer_source: tuple[SourceWithValidator, ...] | None = None # TODO: This could be moved to the `PipelineSnapshot` as it's the same # for all steps/outputs - default_materializer_source: Optional[SourceWithValidator] = None - artifact_config: Optional[ArtifactConfig] = None + default_materializer_source: SourceWithValidator | None = None + artifact_config: ArtifactConfig | None = None @model_validator(mode="before") @classmethod @before_validator_handler def _remove_deprecated_attributes( - cls, data: Dict[str, Any] - ) -> Dict[str, Any]: + cls, data: dict[str, Any] + ) -> dict[str, Any]: """Removes deprecated attributes from the values dict. Args: @@ -92,8 +87,8 @@ def _remove_deprecated_attributes( @classmethod def _convert_source( cls, - value: Union[None, Source, Dict[str, Any], str, Tuple[Source, ...]], - ) -> Optional[Tuple[Source, ...]]: + value: None | Source | dict[str, Any] | str | tuple[Source, ...], + ) -> tuple[Source, ...] | None: """Converts old source strings to tuples of source objects. 
Args: @@ -115,14 +110,14 @@ def _convert_source( class ArtifactConfiguration(PartialArtifactConfiguration): """Class representing a complete input/output artifact configuration.""" - materializer_source: Tuple[SourceWithValidator, ...] + materializer_source: tuple[SourceWithValidator, ...] @field_validator("materializer_source", mode="before") @classmethod def _convert_source( cls, - value: Union[None, Source, Dict[str, Any], str, Tuple[Source, ...]], - ) -> Optional[Tuple[Source, ...]]: + value: None | Source | dict[str, Any] | str | tuple[Source, ...], + ) -> tuple[Source, ...] | None: """Converts old source strings to tuples of source objects. Args: @@ -144,73 +139,73 @@ def _convert_source( class StepConfigurationUpdate(FrozenBaseModel): """Class for step configuration updates.""" - enable_cache: Optional[bool] = Field( + enable_cache: bool | None = Field( default=None, description="Whether to enable cache for the step.", ) - enable_artifact_metadata: Optional[bool] = Field( + enable_artifact_metadata: bool | None = Field( default=None, description="Whether to store metadata for the output artifacts of " "the step.", ) - enable_artifact_visualization: Optional[bool] = Field( + enable_artifact_visualization: bool | None = Field( default=None, description="Whether to enable visualizations for the output " "artifacts of the step.", ) - enable_step_logs: Optional[bool] = Field( + enable_step_logs: bool | None = Field( default=None, description="Whether to enable logs for the step.", ) - step_operator: Optional[Union[bool, str]] = Field( + step_operator: bool | str | None = Field( default=None, description="The step operator to use for the step.", ) - experiment_tracker: Optional[Union[bool, str]] = Field( + experiment_tracker: bool | str | None = Field( default=None, description="The experiment tracker to use for the step.", ) - parameters: Optional[Dict[str, Any]] = Field( + parameters: dict[str, Any] | None = Field( default=None, description="Parameters for the 
step function.", ) - settings: Optional[Dict[str, SerializeAsAny[BaseSettings]]] = Field( + settings: dict[str, SerializeAsAny[BaseSettings]] | None = Field( default=None, description="Settings for the step.", ) - environment: Optional[Dict[str, str]] = Field( + environment: dict[str, str] | None = Field( default=None, description="The environment for the step.", ) - secrets: Optional[List[Union[str, UUID]]] = Field( + secrets: list[str | UUID] | None = Field( default=None, description="The secrets for the step.", ) - extra: Optional[Dict[str, Any]] = Field( + extra: dict[str, Any] | None = Field( default=None, description="Extra configurations for the step.", ) - failure_hook_source: Optional[SourceWithValidator] = Field( + failure_hook_source: SourceWithValidator | None = Field( default=None, description="The failure hook source for the step.", ) - success_hook_source: Optional[SourceWithValidator] = Field( + success_hook_source: SourceWithValidator | None = Field( default=None, description="The success hook source for the step.", ) - model: Optional[Model] = Field( + model: Model | None = Field( default=None, description="The model to use for the step.", ) - retry: Optional[StepRetryConfig] = Field( + retry: StepRetryConfig | None = Field( default=None, description="The retry configuration for the step.", ) - substitutions: Optional[Dict[str, str]] = Field( + substitutions: dict[str, str] | None = Field( default=None, description="The substitutions for the step.", ) - cache_policy: Optional[CachePolicyWithValidator] = Field( + cache_policy: CachePolicyWithValidator | None = Field( default=None, description="The cache policy for the step.", ) @@ -254,12 +249,12 @@ class PartialStepConfiguration(StepConfigurationUpdate): """Class representing a partial step configuration.""" name: str - parameters: Dict[str, Any] = {} - settings: Dict[str, SerializeAsAny[BaseSettings]] = {} - environment: Dict[str, str] = {} - secrets: List[Union[str, UUID]] = [] - extra: 
Dict[str, Any] = {} - substitutions: Dict[str, str] = {} + parameters: dict[str, Any] = {} + settings: dict[str, SerializeAsAny[BaseSettings]] = {} + environment: dict[str, str] = {} + secrets: list[str | UUID] = [] + extra: dict[str, Any] = {} + substitutions: dict[str, str] = {} caching_parameters: Mapping[str, Any] = {} external_input_artifacts: Mapping[str, ExternalArtifactConfiguration] = {} model_artifacts_or_metadata: Mapping[str, ModelVersionDataLazyLoader] = {} @@ -404,14 +399,14 @@ class StepSpec(FrozenBaseModel): """Specification of a pipeline.""" source: SourceWithValidator - upstream_steps: List[str] - inputs: Dict[str, InputSpec] = {} + upstream_steps: list[str] + inputs: dict[str, InputSpec] = {} invocation_id: str @model_validator(mode="before") @classmethod @before_validator_handler - def _migrate_invocation_id(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def _migrate_invocation_id(cls, data: dict[str, Any]) -> dict[str, Any]: if "invocation_id" not in data: data["invocation_id"] = data.pop("pipeline_parameter_name", "") return data @@ -480,7 +475,7 @@ def _add_step_config_overrides_if_missing(cls, data: Any) -> Any: @classmethod def from_dict( cls, - data: Dict[str, Any], + data: dict[str, Any], pipeline_configuration: "PipelineConfiguration", ) -> "Step": """Create a step from a dictionary. diff --git a/src/zenml/config/step_run_info.py b/src/zenml/config/step_run_info.py index 9e3c86723c4..4df04876104 100644 --- a/src/zenml/config/step_run_info.py +++ b/src/zenml/config/step_run_info.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. 
"""Step run info.""" -from typing import Any, Callable +from typing import Any +from collections.abc import Callable from uuid import UUID from zenml.config.frozen_base_model import FrozenBaseModel diff --git a/src/zenml/config/store_config.py b/src/zenml/config/store_config.py index f9434e21fc2..6814f89e52d 100644 --- a/src/zenml/config/store_config.py +++ b/src/zenml/config/store_config.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Functionality to support ZenML store configurations.""" -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, SerializeAsAny, model_validator @@ -44,10 +44,10 @@ class StoreConfiguration(BaseModel): type: StoreType url: str - secrets_store: Optional[SerializeAsAny[SecretsStoreConfiguration]] = None - backup_secrets_store: Optional[ + secrets_store: SerializeAsAny[SecretsStoreConfiguration] | None = None + backup_secrets_store: None | ( SerializeAsAny[SecretsStoreConfiguration] - ] = None + ) = None @classmethod def supports_url_scheme(cls, url: str) -> bool: @@ -67,7 +67,7 @@ def supports_url_scheme(cls, url: str) -> bool: @model_validator(mode="before") @classmethod @before_validator_handler - def validate_store_config(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def validate_store_config(cls, data: dict[str, Any]) -> dict[str, Any]: """Validate the secrets store configuration. Args: diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 05e8db0b8f5..2ba1b8228fb 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -16,7 +16,7 @@ import json import logging import os -from typing import Any, List, Optional, Type, TypeVar +from typing import Any, TypeVar from zenml.enums import AuthScheme @@ -25,8 +25,8 @@ def handle_json_env_var( var: str, - expected_type: Type[T], - default: Optional[List[str]] = None, + expected_type: type[T], + default: list[str] | None = None, ) -> Any: """Converts a json env var into a Python object. 
diff --git a/src/zenml/container_registries/azure_container_registry.py b/src/zenml/container_registries/azure_container_registry.py index 42807fc8f62..d8b6b7c21ee 100644 --- a/src/zenml/container_registries/azure_container_registry.py +++ b/src/zenml/container_registries/azure_container_registry.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Implementation of an Azure Container Registry class.""" -from typing import Optional from zenml.constants import DOCKER_REGISTRY_RESOURCE_TYPE from zenml.container_registries.base_container_registry import ( @@ -38,7 +37,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -55,7 +54,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -64,7 +63,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at docs explaining this flavor. 
Returns: diff --git a/src/zenml/container_registries/base_container_registry.py b/src/zenml/container_registries/base_container_registry.py index abbaad757eb..1fe91e88a78 100644 --- a/src/zenml/container_registries/base_container_registry.py +++ b/src/zenml/container_registries/base_container_registry.py @@ -14,7 +14,7 @@ """Implementation of a base container registry class.""" import re -from typing import TYPE_CHECKING, Optional, Tuple, Type, cast +from typing import TYPE_CHECKING, Optional, cast from pydantic import Field, field_validator @@ -46,7 +46,7 @@ class BaseContainerRegistryConfig(AuthenticationConfigMixin): "Container Registry, 'ghcr.io' for GitHub Container Registry). This is " "the base URL where container images will be pushed to and pulled from." ) - default_repository: Optional[str] = Field( + default_repository: str | None = Field( default=None, description="Default repository namespace for image storage (e.g., " "'username' for Docker Hub, 'project-id' for GCR, 'organization' for " @@ -103,7 +103,7 @@ def requires_authentication(self) -> bool: return bool(self.config.authentication_secret) @property - def credentials(self) -> Optional[Tuple[str, str]]: + def credentials(self) -> tuple[str, str] | None: """Username and password to authenticate with this container registry. Returns: @@ -225,7 +225,7 @@ def push_image(self, image_name: str) -> str: image_name, docker_client=self.docker_client ) - def get_image_repo_digest(self, image_name: str) -> Optional[str]: + def get_image_repo_digest(self, image_name: str) -> str | None: """Get the repository digest of an image. Args: @@ -260,7 +260,7 @@ def type(self) -> StackComponentType: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. 
Specifies resource requirements that are used to filter the available @@ -276,7 +276,7 @@ def service_connector_requirements( ) @property - def config_class(self) -> Type[BaseContainerRegistryConfig]: + def config_class(self) -> type[BaseContainerRegistryConfig]: """Config class for this flavor. Returns: @@ -285,7 +285,7 @@ def config_class(self) -> Type[BaseContainerRegistryConfig]: return BaseContainerRegistryConfig @property - def implementation_class(self) -> Type[BaseContainerRegistry]: + def implementation_class(self) -> type[BaseContainerRegistry]: """Implementation class. Returns: diff --git a/src/zenml/container_registries/default_container_registry.py b/src/zenml/container_registries/default_container_registry.py index 39e98df3edc..acc0c141c91 100644 --- a/src/zenml/container_registries/default_container_registry.py +++ b/src/zenml/container_registries/default_container_registry.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Implementation of a default container registry class.""" -from typing import Optional from zenml.container_registries.base_container_registry import ( BaseContainerRegistryFlavor, @@ -34,7 +33,7 @@ def name(self) -> str: return ContainerRegistryFlavor.DEFAULT.value @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A URL to point at docs explaining this flavor. Returns: @@ -43,7 +42,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A URL to point at docs explaining this flavor. 
Returns: diff --git a/src/zenml/container_registries/dockerhub_container_registry.py b/src/zenml/container_registries/dockerhub_container_registry.py index ddffacbdf26..3e5929b6d1e 100644 --- a/src/zenml/container_registries/dockerhub_container_registry.py +++ b/src/zenml/container_registries/dockerhub_container_registry.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Implementation of a DockerHub Container Registry class.""" -from typing import Optional from zenml.constants import DOCKER_REGISTRY_RESOURCE_TYPE from zenml.container_registries.base_container_registry import ( @@ -38,7 +37,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -55,7 +54,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -64,7 +63,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: diff --git a/src/zenml/container_registries/gcp_container_registry.py b/src/zenml/container_registries/gcp_container_registry.py index 09727a743a2..1fe26a6f0cc 100644 --- a/src/zenml/container_registries/gcp_container_registry.py +++ b/src/zenml/container_registries/gcp_container_registry.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Implementation of a GCP Container Registry class.""" -from typing import Optional from zenml.constants import DOCKER_REGISTRY_RESOURCE_TYPE from zenml.container_registries.base_container_registry import ( @@ -38,7 +37,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -55,7 +54,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -64,7 +63,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: diff --git a/src/zenml/container_registries/github_container_registry.py b/src/zenml/container_registries/github_container_registry.py index 64bcb628fc7..cc638b2e4f5 100644 --- a/src/zenml/container_registries/github_container_registry.py +++ b/src/zenml/container_registries/github_container_registry.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Implementation of the GitHub Container Registry.""" -from typing import Optional from zenml.container_registries.base_container_registry import ( BaseContainerRegistryConfig, @@ -39,7 +38,7 @@ def name(self) -> str: return ContainerRegistryFlavor.GITHUB @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -48,7 +47,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at docs explaining this flavor. 
Returns: diff --git a/src/zenml/data_validators/base_data_validator.py b/src/zenml/data_validators/base_data_validator.py index 0a62b2ed142..68be665e8ab 100644 --- a/src/zenml/data_validators/base_data_validator.py +++ b/src/zenml/data_validators/base_data_validator.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. """Base class for all ZenML data validators.""" -from typing import Any, ClassVar, Optional, Sequence, Type, cast +from typing import Any, ClassVar, cast +from collections.abc import Sequence from zenml.client import Client from zenml.enums import StackComponentType @@ -29,7 +30,7 @@ class BaseDataValidator(StackComponent): """Base class for all ZenML data validators.""" NAME: ClassVar[str] - FLAVOR: ClassVar[Type["BaseDataValidatorFlavor"]] + FLAVOR: ClassVar[type["BaseDataValidatorFlavor"]] @property def config(self) -> BaseDataValidatorConfig: @@ -73,8 +74,8 @@ def get_active_data_validator(cls) -> "BaseDataValidator": def data_profiling( self, dataset: Any, - comparison_dataset: Optional[Any] = None, - profile_list: Optional[Sequence[Any]] = None, + comparison_dataset: Any | None = None, + profile_list: Sequence[Any] | None = None, **kwargs: Any, ) -> Any: """Analyze one or more datasets and generate a data profile. @@ -121,8 +122,8 @@ def data_profiling( def data_validation( self, dataset: Any, - comparison_dataset: Optional[Any] = None, - check_list: Optional[Sequence[Any]] = None, + comparison_dataset: Any | None = None, + check_list: Sequence[Any] | None = None, **kwargs: Any, ) -> Any: """Run data validation checks on a dataset. @@ -167,8 +168,8 @@ def model_validation( self, dataset: Any, model: Any, - comparison_dataset: Optional[Any] = None, - check_list: Optional[Sequence[Any]] = None, + comparison_dataset: Any | None = None, + check_list: Sequence[Any] | None = None, **kwargs: Any, ) -> Any: """Run model validation checks. 
@@ -227,7 +228,7 @@ def type(self) -> StackComponentType: return StackComponentType.DATA_VALIDATOR @property - def config_class(self) -> Type[BaseDataValidatorConfig]: + def config_class(self) -> type[BaseDataValidatorConfig]: """Config class for data validator. Returns: @@ -236,7 +237,7 @@ def config_class(self) -> Type[BaseDataValidatorConfig]: return BaseDataValidatorConfig @property - def implementation_class(self) -> Type[BaseDataValidator]: + def implementation_class(self) -> type[BaseDataValidator]: """Implementation for data validator. Returns: diff --git a/src/zenml/deployers/base_deployer.py b/src/zenml/deployers/base_deployer.py index 16377c3ca4d..298477c6cfc 100644 --- a/src/zenml/deployers/base_deployer.py +++ b/src/zenml/deployers/base_deployer.py @@ -20,14 +20,10 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - Generator, Optional, - Tuple, - Type, - Union, cast, ) +from collections.abc import Generator from uuid import UUID import requests @@ -77,7 +73,7 @@ class BaseDeployerSettings(BaseSettings): """Base settings for all deployers.""" - auth_key: Optional[str] = None + auth_key: str | None = None generate_auth_key: bool = False lcm_timeout: int = DEFAULT_DEPLOYMENT_LCM_TIMEOUT @@ -223,7 +219,7 @@ def _check_deployment_deployer( ) def _check_deployment_snapshot( - self, snapshot: Optional[PipelineSnapshotResponse] = None + self, snapshot: PipelineSnapshotResponse | None = None ) -> None: """Check if the snapshot was created for this deployer. @@ -253,7 +249,7 @@ def _check_deployment_snapshot( def _check_snapshot_already_deployed( self, snapshot: PipelineSnapshotResponse, - new_deployment_id_or_name: Union[str, UUID], + new_deployment_id_or_name: str | UUID, ) -> None: """Check if the snapshot is already deployed to another deployment. 
@@ -358,7 +354,7 @@ def _poll_deployment( deployment: DeploymentResponse, desired_status: DeploymentStatus, timeout: int, - ) -> Tuple[DeploymentResponse, DeploymentOperationalState]: + ) -> tuple[DeploymentResponse, DeploymentOperationalState]: """Poll the deployment until it reaches the desired status, an error occurs or times out. Args: @@ -432,7 +428,7 @@ def _get_deployment_analytics_metadata( self, deployment: "DeploymentResponse", stack: Optional["Stack"] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Returns the deployment metadata. Args: @@ -464,9 +460,9 @@ def provision_deployment( self, snapshot: PipelineSnapshotResponse, stack: "Stack", - deployment_name_or_id: Union[str, UUID], + deployment_name_or_id: str | UUID, replace: bool = True, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> DeploymentResponse: """Provision a deployment. @@ -696,8 +692,8 @@ def provision_deployment( def refresh_deployment( self, - deployment_name_or_id: Union[str, UUID], - project: Optional[UUID] = None, + deployment_name_or_id: str | UUID, + project: UUID | None = None, ) -> DeploymentResponse: """Refresh the status of a deployment by name or ID. @@ -754,9 +750,9 @@ def refresh_deployment( def deprovision_deployment( self, - deployment_name_or_id: Union[str, UUID], - project: Optional[UUID] = None, - timeout: Optional[int] = None, + deployment_name_or_id: str | UUID, + project: UUID | None = None, + timeout: int | None = None, ) -> DeploymentResponse: """Deprovision a deployment. @@ -861,10 +857,10 @@ def deprovision_deployment( def delete_deployment( self, - deployment_name_or_id: Union[str, UUID], - project: Optional[UUID] = None, + deployment_name_or_id: str | UUID, + project: UUID | None = None, force: bool = False, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """Deprovision and delete a deployment. 
@@ -907,10 +903,10 @@ def delete_deployment( def get_deployment_logs( self, - deployment_name_or_id: Union[str, UUID], - project: Optional[UUID] = None, + deployment_name_or_id: str | UUID, + project: UUID | None = None, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> Generator[str, bool, None]: """Get the logs of a deployment. @@ -961,8 +957,8 @@ def do_provision_deployment( self, deployment: DeploymentResponse, stack: "Stack", - environment: Dict[str, str], - secrets: Dict[str, str], + environment: dict[str, str], + secrets: dict[str, str], timeout: int, ) -> DeploymentOperationalState: """Abstract method to deploy a pipeline as an HTTP deployment. @@ -1039,7 +1035,7 @@ def do_get_deployment_state_logs( self, deployment: DeploymentResponse, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> Generator[str, bool, None]: """Abstract method to get the logs of a deployment. @@ -1066,7 +1062,7 @@ def do_deprovision_deployment( self, deployment: DeploymentResponse, timeout: int, - ) -> Optional[DeploymentOperationalState]: + ) -> DeploymentOperationalState | None: """Abstract method to deprovision a deployment. Concrete deployer subclasses must implement the following @@ -1119,7 +1115,7 @@ def type(self) -> StackComponentType: return StackComponentType.DEPLOYER @property - def config_class(self) -> Type[BaseDeployerConfig]: + def config_class(self) -> type[BaseDeployerConfig]: """Returns `BaseDeployerConfig` config class. 
Returns: @@ -1129,5 +1125,5 @@ def config_class(self) -> Type[BaseDeployerConfig]: @property @abstractmethod - def implementation_class(self) -> Type[BaseDeployer]: + def implementation_class(self) -> type[BaseDeployer]: """The class that implements the deployer.""" diff --git a/src/zenml/deployers/containerized_deployer.py b/src/zenml/deployers/containerized_deployer.py index 6335193678e..dfff2409aa8 100644 --- a/src/zenml/deployers/containerized_deployer.py +++ b/src/zenml/deployers/containerized_deployer.py @@ -14,10 +14,6 @@ """Base class for all containerized deployers.""" from abc import ABC -from typing import ( - List, - Set, -) import zenml from zenml.config.build_configuration import BuildConfiguration @@ -62,7 +58,7 @@ def get_image(snapshot: PipelineSnapshotResponse) -> str: return snapshot.build.images[DEPLOYER_DOCKER_IMAGE_KEY].image @property - def requirements(self) -> Set[str]: + def requirements(self) -> set[str]: """Set of PyPI requirements for the deployer. Returns: @@ -79,7 +75,7 @@ def requirements(self) -> Set[str]: def get_docker_builds( self, snapshot: "PipelineSnapshotBase" - ) -> List["BuildConfiguration"]: + ) -> list["BuildConfiguration"]: """Gets the Docker builds required for the component. 
Args: diff --git a/src/zenml/deployers/docker/docker_deployer.py b/src/zenml/deployers/docker/docker_deployer.py index 14740ce7d15..a5e37f56b18 100644 --- a/src/zenml/deployers/docker/docker_deployer.py +++ b/src/zenml/deployers/docker/docker_deployer.py @@ -18,14 +18,9 @@ import sys from typing import ( Any, - Dict, - Generator, - List, - Optional, - Tuple, - Type, cast, ) +from collections.abc import Generator import docker.errors as docker_errors from docker.client import DockerClient @@ -74,12 +69,12 @@ class DockerDeploymentMetadata(BaseModel): """Metadata for a Docker deployment.""" - port: Optional[int] = None - container_id: Optional[str] = None - container_name: Optional[str] = None - container_image_id: Optional[str] = None - container_image_uri: Optional[str] = None - container_status: Optional[str] = None + port: int | None = None + container_id: str | None = None + container_name: str | None = None + container_image_id: str | None = None + container_image_uri: str | None = None + container_status: str | None = None @classmethod def from_container( @@ -135,10 +130,10 @@ def from_deployment( class DockerDeployer(ContainerizedDeployer): """Deployer responsible for deploying pipelines locally using Docker.""" - _docker_client: Optional[DockerClient] = None + _docker_client: DockerClient | None = None @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Docker deployer. Returns: @@ -156,7 +151,7 @@ def config(self) -> "DockerDeployerConfig": return cast(DockerDeployerConfig, self._config) @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Ensures there is an image builder in the stack. 
Returns: @@ -192,7 +187,7 @@ def _get_container_id(self, deployment: DeploymentResponse) -> str: def _get_container( self, deployment: DeploymentResponse - ) -> Optional[Container]: + ) -> Container | None: """Get the docker container associated with a deployment. Args: @@ -253,8 +248,8 @@ def do_provision_deployment( self, deployment: DeploymentResponse, stack: "Stack", - environment: Dict[str, str], - secrets: Dict[str, str], + environment: dict[str, str], + secrets: dict[str, str], timeout: int, ) -> DeploymentOperationalState: """Deploy a pipeline as a Docker container. @@ -350,7 +345,7 @@ def do_provision_deployment( ) self.docker_client.images.pull(image) - preferred_ports: List[int] = [] + preferred_ports: list[int] = [] if settings.port: preferred_ports.append(settings.port) if existing_metadata.port: @@ -364,9 +359,9 @@ def do_provision_deployment( container_port = ( snapshot.pipeline_configuration.deployment_settings.uvicorn_port ) - ports: Dict[str, Optional[int]] = {f"{container_port}/tcp": port} + ports: dict[str, int | None] = {f"{container_port}/tcp": port} - uid_args: Dict[str, Any] = {} + uid_args: dict[str, Any] = {} if sys.platform == "win32": # File permissions are not checked on Windows. This if clause # prevents mypy from complaining about unused 'type: ignore' @@ -462,7 +457,7 @@ def do_get_deployment_state_logs( self, deployment: DeploymentResponse, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> Generator[str, bool, None]: """Get the logs of a Docker deployment. 
@@ -491,7 +486,7 @@ def do_get_deployment_state_logs( ) try: - log_kwargs: Dict[str, Any] = { + log_kwargs: dict[str, Any] = { "stdout": True, "stderr": True, "stream": follow, @@ -515,8 +510,7 @@ def do_get_deployment_state_logs( else: if isinstance(log_stream, bytes): log_text = log_stream.decode("utf-8", errors="replace") - for line in log_text.splitlines(): - yield line + yield from log_text.splitlines() else: for log_line in log_stream: if isinstance(log_line, bytes): @@ -550,7 +544,7 @@ def do_deprovision_deployment( self, deployment: DeploymentResponse, timeout: int, - ) -> Optional[DeploymentOperationalState]: + ) -> DeploymentOperationalState | None: """Deprovision a docker deployment. Args: @@ -601,10 +595,10 @@ class DockerDeployerSettings(BaseDeployerSettings): of what can be passed.) """ - port: Optional[int] = None + port: int | None = None allocate_port_if_busy: bool = True - port_range: Tuple[int, int] = (8000, 65535) - run_args: Dict[str, Any] = {} + port_range: tuple[int, int] = (8000, 65535) + run_args: dict[str, Any] = {} class DockerDeployerConfig(BaseDeployerConfig, DockerDeployerSettings): @@ -633,7 +627,7 @@ def name(self) -> str: return "docker" @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -642,7 +636,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -660,7 +654,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/deployer/docker.png" @property - def config_class(self) -> Type[BaseDeployerConfig]: + def config_class(self) -> type[BaseDeployerConfig]: """Config class for the base deployer flavor. 
Returns: @@ -669,7 +663,7 @@ def config_class(self) -> Type[BaseDeployerConfig]: return DockerDeployerConfig @property - def implementation_class(self) -> Type["DockerDeployer"]: + def implementation_class(self) -> type["DockerDeployer"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/deployers/local/local_deployer.py b/src/zenml/deployers/local/local_deployer.py index ffbfa6cced0..64df71a43e0 100644 --- a/src/zenml/deployers/local/local_deployer.py +++ b/src/zenml/deployers/local/local_deployer.py @@ -21,14 +21,9 @@ import time from typing import ( TYPE_CHECKING, - Dict, - Generator, - List, - Optional, - Tuple, - Type, cast, ) +from collections.abc import Generator from uuid import UUID import psutil @@ -81,10 +76,10 @@ class LocalDeploymentMetadata(BaseModel): log_file: Path to log file. """ - pid: Optional[int] = None - port: Optional[int] = None - address: Optional[str] = None - log_file: Optional[str] = None + pid: int | None = None + port: int | None = None + address: str | None = None + log_file: str | None = None @classmethod def from_deployment( @@ -115,9 +110,9 @@ class LocalDeployerSettings(BaseDeployerSettings): code changes. """ - port: Optional[int] = None + port: int | None = None allocate_port_if_busy: bool = True - port_range: Tuple[int, int] = (8000, 65535) + port_range: tuple[int, int] = (8000, 65535) address: str = "127.0.0.1" blocking: bool = False auto_reload: bool = False @@ -140,7 +135,7 @@ class LocalDeployer(BaseDeployer): """Deployer that runs deployments as local daemon processes.""" @property - def settings_class(self) -> Optional[Type[BaseSettings]]: + def settings_class(self) -> type[BaseSettings] | None: """Settings class for the local deployer. 
Returns: @@ -200,8 +195,8 @@ def do_provision_deployment( self, deployment: DeploymentResponse, stack: "Stack", - environment: Dict[str, str], - secrets: Dict[str, str], + environment: dict[str, str], + secrets: dict[str, str], timeout: int, ) -> DeploymentOperationalState: """Provision a local daemon deployment. @@ -221,7 +216,7 @@ def do_provision_deployment( """ assert deployment.snapshot, "Pipeline snapshot not found" - child_env: Dict[str, str] = dict(os.environ) + child_env: dict[str, str] = dict(os.environ) child_env.update(environment) child_env.update(secrets) @@ -241,7 +236,7 @@ def do_provision_deployment( f"'{deployment.name}' with PID {existing_meta.pid}: {e}" ) - preferred_ports: List[int] = [] + preferred_ports: list[int] = [] if settings.port: preferred_ports.append(settings.port) if existing_meta.port: @@ -254,7 +249,7 @@ def do_provision_deployment( range=settings.port_range, address=settings.address, ) - except IOError as e: + except OSError as e: raise DeploymentProvisionError(str(e)) address = settings.address @@ -402,7 +397,7 @@ def do_get_deployment_state_logs( self, deployment: DeploymentResponse, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> Generator[str, bool, None]: """Read logs from the local daemon log file. 
@@ -428,7 +423,7 @@ def do_get_deployment_state_logs( try: def _read_tail(path: str, n: int) -> Generator[str, bool, None]: - with open(path, "r", encoding="utf-8", errors="ignore") as f: + with open(path, encoding="utf-8", errors="ignore") as f: lines = f.readlines() for line in lines[-n:]: yield line.rstrip("\n") @@ -438,13 +433,13 @@ def _read_tail(path: str, n: int) -> Generator[str, bool, None]: yield from _read_tail(log_file, tail) else: with open( - log_file, "r", encoding="utf-8", errors="ignore" + log_file, encoding="utf-8", errors="ignore" ) as f: for line in f: yield line.rstrip("\n") return - with open(log_file, "r", encoding="utf-8", errors="ignore") as f: + with open(log_file, encoding="utf-8", errors="ignore") as f: if not tail: tail = DEFAULT_TAIL_FOLLOW_LINES lines = f.readlines() @@ -470,7 +465,7 @@ def _read_tail(path: str, n: int) -> Generator[str, bool, None]: def do_deprovision_deployment( self, deployment: DeploymentResponse, timeout: int - ) -> Optional[DeploymentOperationalState]: + ) -> DeploymentOperationalState | None: """Deprovision a local daemon deployment. Args: @@ -516,7 +511,7 @@ def name(self) -> str: return "local" @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -525,7 +520,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -543,7 +538,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/deployer/local.png" @property - def config_class(self) -> Type[BaseDeployerConfig]: + def config_class(self) -> type[BaseDeployerConfig]: """Config class for the flavor. 
Returns: @@ -552,7 +547,7 @@ def config_class(self) -> Type[BaseDeployerConfig]: return LocalDeployerConfig @property - def implementation_class(self) -> Type[LocalDeployer]: + def implementation_class(self) -> type[LocalDeployer]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/deployers/server/adapters.py b/src/zenml/deployers/server/adapters.py index 299959ff732..8cac1b69243 100644 --- a/src/zenml/deployers/server/adapters.py +++ b/src/zenml/deployers/server/adapters.py @@ -14,7 +14,8 @@ """Framework adapter interfaces.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Callable, cast +from typing import TYPE_CHECKING, Any, cast +from collections.abc import Callable from asgiref.typing import ( ASGIApplication, diff --git a/src/zenml/deployers/server/app.py b/src/zenml/deployers/server/app.py index bcf4f323f38..e851c0befa9 100644 --- a/src/zenml/deployers/server/app.py +++ b/src/zenml/deployers/server/app.py @@ -19,14 +19,9 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - List, - Optional, - Tuple, - Type, Union, ) +from collections.abc import Callable from uuid import UUID from asgiref.compatibility import guarantee_single_callable @@ -147,11 +142,11 @@ def __init__( # Create framework-specific adapters self.endpoint_adapter = self._create_endpoint_adapter() self.middleware_adapter = self._create_middleware_adapter() - self._asgi_app: Optional[ASGIApplication] = None + self._asgi_app: ASGIApplication | None = None - self.endpoints: List[EndpointSpec] = [] - self.middlewares: List[MiddlewareSpec] = [] - self.extensions: List[AppExtensionSpec] = [] + self.endpoints: list[EndpointSpec] = [] + self.middlewares: list[MiddlewareSpec] = [] + self.extensions: list[AppExtensionSpec] = [] @property def asgi_app(self) -> ASGIApplication: @@ -258,7 +253,7 @@ def load_deployment_service(self) -> BasePipelineDeploymentService: """ settings = self.snapshot.pipeline_configuration.deployment_settings if 
settings.deployment_service_class is None: - service_cls: Type[BasePipelineDeploymentService] = ( + service_cls: type[BasePipelineDeploymentService] = ( PipelineDeploymentService ) else: @@ -346,7 +341,7 @@ def _invoke_endpoint( return _invoke_endpoint - def dashboard_files_path(self) -> Optional[str]: + def dashboard_files_path(self) -> str | None: """Get the absolute path of the dashboard files directory. Returns: @@ -384,7 +379,7 @@ def dashboard_files_path(self) -> Optional[str]: return dashboard_path @abstractmethod - def _get_dashboard_endpoints(self) -> List[EndpointSpec]: + def _get_dashboard_endpoints(self) -> list[EndpointSpec]: """Get the dashboard endpoints specs. This is called if the dashboard files path is set to construct the @@ -394,7 +389,7 @@ def _get_dashboard_endpoints(self) -> List[EndpointSpec]: The dashboard endpoints specs. """ - def _create_default_endpoint_specs(self) -> List[EndpointSpec]: + def _create_default_endpoint_specs(self) -> list[EndpointSpec]: """Create EndpointSpec objects for default endpoints. 
Returns: @@ -460,7 +455,7 @@ def _get_secure_headers(self) -> "Secure": # - if set to a string, we use the string as the value for the header # - if set to `False`, we don't set the header - server: Optional[secure.Server] = None + server: secure.Server | None = None if self.settings.secure_headers.server: server = secure.Server() if isinstance(self.settings.secure_headers.server, str): @@ -468,43 +463,43 @@ def _get_secure_headers(self) -> "Secure": else: server.set(str(self.deployment.id)) - hsts: Optional[secure.StrictTransportSecurity] = None + hsts: secure.StrictTransportSecurity | None = None if self.settings.secure_headers.hsts: hsts = secure.StrictTransportSecurity() if isinstance(self.settings.secure_headers.hsts, str): hsts.set(self.settings.secure_headers.hsts) - xfo: Optional[secure.XFrameOptions] = None + xfo: secure.XFrameOptions | None = None if self.settings.secure_headers.xfo: xfo = secure.XFrameOptions() if isinstance(self.settings.secure_headers.xfo, str): xfo.set(self.settings.secure_headers.xfo) - csp: Optional[secure.ContentSecurityPolicy] = None + csp: secure.ContentSecurityPolicy | None = None if self.settings.secure_headers.csp: csp = secure.ContentSecurityPolicy() if isinstance(self.settings.secure_headers.csp, str): csp.set(self.settings.secure_headers.csp) - xcto: Optional[secure.XContentTypeOptions] = None + xcto: secure.XContentTypeOptions | None = None if self.settings.secure_headers.content: xcto = secure.XContentTypeOptions() if isinstance(self.settings.secure_headers.content, str): xcto.set(self.settings.secure_headers.content) - referrer: Optional[secure.ReferrerPolicy] = None + referrer: secure.ReferrerPolicy | None = None if self.settings.secure_headers.referrer: referrer = secure.ReferrerPolicy() if isinstance(self.settings.secure_headers.referrer, str): referrer.set(self.settings.secure_headers.referrer) - cache: Optional[secure.CacheControl] = None + cache: secure.CacheControl | None = None if 
self.settings.secure_headers.cache: cache = secure.CacheControl() if isinstance(self.settings.secure_headers.cache, str): cache.set(self.settings.secure_headers.cache) - permissions: Optional[secure.PermissionsPolicy] = None + permissions: secure.PermissionsPolicy | None = None if self.settings.secure_headers.permissions: permissions = secure.PermissionsPolicy() if isinstance(self.settings.secure_headers.permissions, str): @@ -556,7 +551,7 @@ async def set_secure_headers( async def send_wrapper(message: ASGISendEvent) -> None: if message["type"] == "http.response.start": - hdrs: List[Tuple[bytes, bytes]] = list( + hdrs: list[tuple[bytes, bytes]] = list( message.get("headers", []) ) existing = {k: i for i, (k, _) in enumerate(hdrs)} @@ -585,7 +580,7 @@ def _build_cors_middleware(self) -> MiddlewareSpec: The CORS middleware spec. """ - def _create_default_middleware_specs(self) -> List[MiddlewareSpec]: + def _create_default_middleware_specs(self) -> list[MiddlewareSpec]: """Create MiddlewareSpec objects for default middleware. Returns: @@ -847,7 +842,7 @@ def run(self) -> None: Log Level: {settings.log_level} """) - uvicorn_kwargs: Dict[str, Any] = dict( + uvicorn_kwargs: dict[str, Any] = dict( host=settings.uvicorn_host, port=settings.uvicorn_port, workers=settings.uvicorn_workers, @@ -889,9 +884,9 @@ def run(self) -> None: @abstractmethod def build( self, - middlewares: List[MiddlewareSpec], - endpoints: List[EndpointSpec], - extensions: List[AppExtensionSpec], + middlewares: list[MiddlewareSpec], + endpoints: list[EndpointSpec], + extensions: list[AppExtensionSpec], ) -> ASGIApplication: """Build the ASGI compatible web application. @@ -925,7 +920,7 @@ def name(self) -> str: @property @abstractmethod - def implementation_class(self) -> Type[BaseDeploymentAppRunner]: + def implementation_class(self) -> type[BaseDeploymentAppRunner]: """The class that implements the deployment app runner. 
Returns: @@ -933,7 +928,7 @@ def implementation_class(self) -> Type[BaseDeploymentAppRunner]: """ @property - def requirements(self) -> List[str]: + def requirements(self) -> list[str]: """The software requirements for the deployment app runner. Returns: @@ -961,7 +956,7 @@ def load_app_runner_flavor( ) if settings.deployment_app_runner_flavor is None: - app_runner_flavor_class: Type[BaseDeploymentAppRunnerFlavor] = ( + app_runner_flavor_class: type[BaseDeploymentAppRunnerFlavor] = ( FastAPIDeploymentAppRunnerFlavor ) else: @@ -1032,11 +1027,11 @@ def build_asgi_app() -> ASGIApplication: def start_deployment_app( deployment_id: UUID, - pid_file: Optional[str] = None, - log_file: Optional[str] = None, - host: Optional[str] = None, - port: Optional[int] = None, - reload: Optional[bool] = None, + pid_file: str | None = None, + log_file: str | None = None, + host: str | None = None, + port: int | None = None, + reload: bool | None = None, ) -> None: """Start the deployment app. diff --git a/src/zenml/deployers/server/entrypoint_configuration.py b/src/zenml/deployers/server/entrypoint_configuration.py index ef9f1998dea..107af415f5c 100644 --- a/src/zenml/deployers/server/entrypoint_configuration.py +++ b/src/zenml/deployers/server/entrypoint_configuration.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """ZenML Pipeline Deployment Entrypoint Configuration.""" -from typing import Any, List, Set +from typing import Any from uuid import UUID from zenml.client import Client @@ -39,7 +39,7 @@ class DeploymentEntrypointConfiguration(BaseEntrypointConfiguration): """ @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> set[str]: """Gets all options required for the deployment entrypoint. 
Returns: @@ -50,7 +50,7 @@ def get_entrypoint_options(cls) -> Set[str]: } @classmethod - def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]: + def get_entrypoint_arguments(cls, **kwargs: Any) -> list[str]: """Gets arguments for the deployment entrypoint command. Args: diff --git a/src/zenml/deployers/server/fastapi/__init__.py b/src/zenml/deployers/server/fastapi/__init__.py index b4ef6cccc46..c16b9e9efa5 100644 --- a/src/zenml/deployers/server/fastapi/__init__.py +++ b/src/zenml/deployers/server/fastapi/__init__.py @@ -14,7 +14,6 @@ """FastAPI implementation of the deployment app factory and adapters.""" -from typing import List, Type from zenml.deployers.server.app import BaseDeploymentAppRunner, BaseDeploymentAppRunnerFlavor FASTAPI_APP_RUNNER_FLAVOR_NAME = "fastapi" @@ -32,7 +31,7 @@ def name(self) -> str: return FASTAPI_APP_RUNNER_FLAVOR_NAME @property - def implementation_class(self) -> Type[BaseDeploymentAppRunner]: + def implementation_class(self) -> type[BaseDeploymentAppRunner]: """The class that implements the deployment app runner. Returns: @@ -42,7 +41,7 @@ def implementation_class(self) -> Type[BaseDeploymentAppRunner]: return FastAPIDeploymentAppRunner @property - def requirements(self) -> List[str]: + def requirements(self) -> list[str]: """The software requirements for the deployment app runner. Returns: diff --git a/src/zenml/deployers/server/fastapi/adapters.py b/src/zenml/deployers/server/fastapi/adapters.py index ed7ade426e9..63065323d5e 100644 --- a/src/zenml/deployers/server/fastapi/adapters.py +++ b/src/zenml/deployers/server/fastapi/adapters.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. 
"""FastAPI adapter implementations.""" -from typing import Any, Callable, Dict, Optional +from typing import Any +from collections.abc import Callable from fastapi import APIRouter, Depends, FastAPI, HTTPException from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -50,7 +51,7 @@ def _build_auth_dependency(self, api_key: str) -> Callable[..., Any]: ) def verify_token( - credentials: Optional[HTTPAuthorizationCredentials] = Depends( + credentials: HTTPAuthorizationCredentials | None = Depends( security ), ) -> None: @@ -118,7 +119,7 @@ def register_endpoint( return # Register with appropriate HTTP method - route_kwargs: Dict[str, Any] = {"dependencies": dependencies} + route_kwargs: dict[str, Any] = {"dependencies": dependencies} route_kwargs.update(spec.extra_kwargs) if spec.method == EndpointMethod.GET: diff --git a/src/zenml/deployers/server/fastapi/app.py b/src/zenml/deployers/server/fastapi/app.py index 8a6091f1d26..ff18c367774 100644 --- a/src/zenml/deployers/server/fastapi/app.py +++ b/src/zenml/deployers/server/fastapi/app.py @@ -16,7 +16,8 @@ import os from contextlib import asynccontextmanager from genericpath import isdir, isfile -from typing import Any, AsyncGenerator, Dict, List, Optional, cast +from typing import Any, cast +from collections.abc import AsyncGenerator from anyio import to_thread from asgiref.typing import ( @@ -103,7 +104,7 @@ def _build_cors_middleware(self) -> MiddlewareSpec: native=True, ) - def _get_dashboard_endpoints(self) -> List[EndpointSpec]: + def _get_dashboard_endpoints(self) -> list[EndpointSpec]: """Get the dashboard endpoints specs. 
This is called if the dashboard files path is set to construct the @@ -120,7 +121,7 @@ def _get_dashboard_endpoints(self) -> List[EndpointSpec]: if not dashboard_files_path or not os.path.isdir(dashboard_files_path): return [] - endpoints: List[EndpointSpec] = [] + endpoints: list[EndpointSpec] = [] async def catch_invalid_api(invalid_api_path: str) -> None: """Invalid API endpoint. @@ -247,9 +248,9 @@ def error_handler(self, request: Request, exc: ValueError) -> JSONResponse: def build( self, - middlewares: List[MiddlewareSpec], - endpoints: List[EndpointSpec], - extensions: List[AppExtensionSpec], + middlewares: list[MiddlewareSpec], + endpoints: list[EndpointSpec], + extensions: list[AppExtensionSpec], ) -> ASGIApplication: """Build the FastAPI app for the deployment. @@ -270,14 +271,14 @@ def build( or f"ZenML pipeline deployment server for the " f"{self.deployment.name} deployment" ) - docs_url_path: Optional[str] = None - redoc_url_path: Optional[str] = None + docs_url_path: str | None = None + redoc_url_path: str | None = None if self.settings.endpoint_enabled(DeploymentDefaultEndpoints.DOCS): docs_url_path = self.settings.docs_url_path if self.settings.endpoint_enabled(DeploymentDefaultEndpoints.REDOC): redoc_url_path = self.settings.redoc_url_path - fastapi_kwargs: Dict[str, Any] = dict( + fastapi_kwargs: dict[str, Any] = dict( title=title, description=description, version=self.settings.app_version diff --git a/src/zenml/deployers/server/models.py b/src/zenml/deployers/server/models.py index b2afd395530..ec5d065fcea 100644 --- a/src/zenml/deployers/server/models.py +++ b/src/zenml/deployers/server/models.py @@ -14,7 +14,7 @@ """FastAPI application models.""" from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any from uuid import UUID from pydantic import BaseModel, Field @@ -30,17 +30,17 @@ class DeploymentInvocationResponseMetadata(BaseModel): deployment_id: UUID = Field(title="The ID of the deployment.") 
deployment_name: str = Field(title="The name of the deployment.") snapshot_id: UUID = Field(title="The ID of the snapshot.") - snapshot_name: Optional[str] = Field( + snapshot_name: str | None = Field( default=None, title="The name of the snapshot." ) pipeline_name: str = Field(title="The name of the pipeline.") - run_id: Optional[UUID] = Field( + run_id: UUID | None = Field( default=None, title="The ID of the pipeline run." ) - run_name: Optional[str] = Field( + run_name: str | None = Field( default=None, title="The name of the pipeline run." ) - parameters_used: Dict[str, Any] = Field( + parameters_used: dict[str, Any] = Field( title="The parameters used for the pipeline execution." ) @@ -51,7 +51,7 @@ class BaseDeploymentInvocationRequest(BaseModel): parameters: BaseModel = Field( title="The parameters for the pipeline execution." ) - run_name: Optional[str] = Field( + run_name: str | None = Field( default=None, title="Custom name for the pipeline run." ) timeout: int = Field( @@ -70,7 +70,7 @@ class BaseDeploymentInvocationResponse(BaseModel): success: bool = Field( title="Whether the pipeline execution was successful." ) - outputs: Optional[Dict[str, Any]] = Field( + outputs: dict[str, Any] | None = Field( default=None, title="The outputs of the pipeline execution, if the pipeline execution " "was successful.", @@ -81,7 +81,7 @@ class BaseDeploymentInvocationResponse(BaseModel): metadata: DeploymentInvocationResponseMetadata = Field( title="The metadata of the pipeline execution." ) - error: Optional[str] = Field( + error: str | None = Field( default=None, title="The error that occurred, if the pipeline invocation failed.", ) @@ -91,13 +91,13 @@ class PipelineInfo(BaseModel): """Pipeline info model.""" name: str = Field(title="The name of the pipeline.") - parameters: Optional[Dict[str, Any]] = Field( + parameters: dict[str, Any] | None = Field( default=None, title="The parameters of the pipeline." 
) - input_schema: Optional[Dict[str, Any]] = Field( + input_schema: dict[str, Any] | None = Field( default=None, title="The input schema of the pipeline." ) - output_schema: Optional[Dict[str, Any]] = Field( + output_schema: dict[str, Any] | None = Field( default=None, title="The output schema of the pipeline." ) @@ -116,7 +116,7 @@ class SnapshotInfo(BaseModel): """Snapshot info model.""" id: UUID = Field(title="The ID of the snapshot.") - name: Optional[str] = Field( + name: str | None = Field( default=None, title="The name of the snapshot." ) @@ -149,7 +149,7 @@ class ServiceInfo(BaseModel): total_executions: int = Field( title="The total number of pipeline executions." ) - last_execution_time: Optional[datetime] = Field( + last_execution_time: datetime | None = Field( default=None, title="The time of the last pipeline execution." ) status: str = Field(title="The status of the pipeline service.") @@ -162,6 +162,6 @@ class ExecutionMetrics(BaseModel): total_executions: int = Field( title="The total number of pipeline executions." ) - last_execution_time: Optional[datetime] = Field( + last_execution_time: datetime | None = Field( default=None, title="The time of the last pipeline execution." 
) diff --git a/src/zenml/deployers/server/runtime.py b/src/zenml/deployers/server/runtime.py index 2660136298d..4850eeea5b1 100644 --- a/src/zenml/deployers/server/runtime.py +++ b/src/zenml/deployers/server/runtime.py @@ -23,7 +23,7 @@ """ import contextvars -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel, Field @@ -35,13 +35,13 @@ class _DeploymentState(BaseModel): active: bool = False skip_artifact_materialization: bool = False - request_id: Optional[str] = None - snapshot_id: Optional[str] = None - pipeline_parameters: Dict[str, Any] = Field(default_factory=dict) - outputs: Dict[str, Dict[str, Any]] = Field(default_factory=dict) + request_id: str | None = None + snapshot_id: str | None = None + pipeline_parameters: dict[str, Any] = Field(default_factory=dict) + outputs: dict[str, dict[str, Any]] = Field(default_factory=dict) # In-memory data storage for artifacts - in_memory_data: Dict[str, Any] = Field(default_factory=dict) + in_memory_data: dict[str, Any] = Field(default_factory=dict) def reset(self) -> None: """Reset the deployment state.""" @@ -71,7 +71,7 @@ def _get_context() -> _DeploymentState: def start( request_id: str, snapshot: PipelineSnapshotResponse, - parameters: Dict[str, Any], + parameters: dict[str, Any], skip_artifact_materialization: bool = False, ) -> None: """Initialize deployment state for the current request context. @@ -107,7 +107,7 @@ def is_active() -> bool: return _get_context().active -def record_step_outputs(step_name: str, outputs: Dict[str, Any]) -> None: +def record_step_outputs(step_name: str, outputs: dict[str, Any]) -> None: """Record raw outputs for a step by invocation id. 
Args: @@ -122,7 +122,7 @@ def record_step_outputs(step_name: str, outputs: Dict[str, Any]) -> None: state.outputs.setdefault(step_name, {}).update(outputs) -def get_outputs() -> Dict[str, Dict[str, Any]]: +def get_outputs() -> dict[str, dict[str, Any]]: """Return the outputs for all steps in the current context. Returns: diff --git a/src/zenml/deployers/server/service.py b/src/zenml/deployers/server/service.py index 97a1ef0f47d..74073b77de3 100644 --- a/src/zenml/deployers/server/service.py +++ b/src/zenml/deployers/server/service.py @@ -21,10 +21,6 @@ TYPE_CHECKING, Annotated, Any, - Dict, - Optional, - Tuple, - Type, ) from uuid import uuid4 @@ -89,7 +85,6 @@ def run_init_hook(cls, snapshot: "PipelineSnapshotResponse") -> None: """ # Bypass the init hook execution because it is run globally by # the deployment service - pass @classmethod def run_cleanup_hook(cls, snapshot: "PipelineSnapshotResponse") -> None: @@ -100,7 +95,6 @@ def run_cleanup_hook(cls, snapshot: "PipelineSnapshotResponse") -> None: """ # Bypass the cleanup hook execution because it is run globally by # the deployment service - pass class BasePipelineDeploymentService(ABC): @@ -186,7 +180,7 @@ def health_check(self) -> None: @property def input_model( self, - ) -> Type[BaseModel]: + ) -> type[BaseModel]: """Construct a Pydantic model representing pipeline input parameters. Load the pipeline class from `pipeline_spec.source` and derive the @@ -228,7 +222,7 @@ def input_model( return model @property - def input_schema(self) -> Dict[str, Any]: + def input_schema(self) -> dict[str, Any]: """Return the JSON schema for pipeline input parameters. Returns: @@ -247,7 +241,7 @@ def input_schema(self) -> Dict[str, Any]: raise RuntimeError("The pipeline input schema is not available.") @property - def output_schema(self) -> Dict[str, Any]: + def output_schema(self) -> dict[str, Any]: """Return the JSON schema for the pipeline outputs. 
Returns: @@ -267,7 +261,7 @@ def output_schema(self) -> Dict[str, Any]: def get_pipeline_invoke_models( self, - ) -> Tuple[Type[BaseModel], Type[BaseModel]]: + ) -> tuple[type[BaseModel], type[BaseModel]]: """Generate the request and response models for the pipeline invoke endpoint. Returns: @@ -288,7 +282,7 @@ class PipelineInvokeRequest(BaseDeploymentInvocationRequest): class PipelineInvokeResponse(BaseDeploymentInvocationResponse): outputs: Annotated[ - Optional[Dict[str, Any]], + dict[str, Any] | None, WithJsonSchema(self.output_schema, mode="serialization"), ] @@ -315,7 +309,7 @@ def initialize(self) -> None: # Execution tracking self.service_start_time = time.time() - self.last_execution_time: Optional[datetime] = None + self.last_execution_time: datetime | None = None self.total_executions = 0 self.orchestrator_class = SharedLocalOrchestrator @@ -353,7 +347,7 @@ def execute_pipeline( start_time = time.time() logger.info("Starting pipeline execution") - placeholder_run: Optional[PipelineRunResponse] = None + placeholder_run: PipelineRunResponse | None = None try: # Create a placeholder run separately from the actual execution, # so that we have a run ID to include in the response even if the @@ -446,12 +440,11 @@ def get_execution_metrics(self) -> ExecutionMetrics: def health_check(self) -> None: """Check service health.""" - pass def _map_outputs( self, - runtime_outputs: Optional[Dict[str, Dict[str, Any]]] = None, - ) -> Dict[str, Any]: + runtime_outputs: dict[str, dict[str, Any]] | None = None, + ) -> dict[str, Any]: """Map pipeline outputs using centralized runtime processing. Args: @@ -495,8 +488,8 @@ def _map_outputs( def _prepare_execute_with_orchestrator( self, - resolved_params: Dict[str, Any], - ) -> Tuple[PipelineRunResponse, PipelineSnapshotResponse]: + resolved_params: dict[str, Any], + ) -> tuple[PipelineRunResponse, PipelineSnapshotResponse]: """Prepare the execution with the orchestrator. 
Args: @@ -532,9 +525,9 @@ def _execute_with_orchestrator( self, placeholder_run: PipelineRunResponse, deployment_snapshot: PipelineSnapshotResponse, - resolved_params: Dict[str, Any], + resolved_params: dict[str, Any], skip_artifact_materialization: bool, - ) -> Optional[Dict[str, Dict[str, Any]]]: + ) -> dict[str, dict[str, Any]] | None: """Run the snapshot via the orchestrator and return the concrete run. Args: @@ -575,7 +568,7 @@ def _execute_with_orchestrator( skip_artifact_materialization=skip_artifact_materialization, ) - captured_outputs: Optional[Dict[str, Dict[str, Any]]] = None + captured_outputs: dict[str, dict[str, Any]] | None = None try: # Use the new deployment snapshot with pre-configured settings orchestrator.run( @@ -640,11 +633,11 @@ def _log_initialization_success(self) -> None: def _build_response( self, - resolved_params: Dict[str, Any], + resolved_params: dict[str, Any], start_time: float, - mapped_outputs: Optional[Dict[str, Any]] = None, - placeholder_run: Optional[PipelineRunResponse] = None, - error: Optional[Exception] = None, + mapped_outputs: dict[str, Any] | None = None, + placeholder_run: PipelineRunResponse | None = None, + error: Exception | None = None, ) -> BaseDeploymentInvocationResponse: """Build success response with execution tracking. 
@@ -662,7 +655,7 @@ def _build_response( self.total_executions += 1 self.last_execution_time = datetime.now(timezone.utc) - run: Optional[PipelineRunResponse] = placeholder_run + run: PipelineRunResponse | None = placeholder_run if placeholder_run: try: # Fetch the concrete run via its id diff --git a/src/zenml/deployers/utils.py b/src/zenml/deployers/utils.py index 74ffe88e8c3..0d94b3b3939 100644 --- a/src/zenml/deployers/utils.py +++ b/src/zenml/deployers/utils.py @@ -14,7 +14,7 @@ """ZenML deployers utilities.""" import json -from typing import Any, Dict, List, Optional, Union +from typing import Any from uuid import UUID import jsonref @@ -45,7 +45,7 @@ def get_deployment_input_schema( deployment: DeploymentResponse, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Get the schema for a deployment's input parameters. Args: @@ -71,7 +71,7 @@ def get_deployment_input_schema( def get_deployment_output_schema( deployment: DeploymentResponse, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Get the schema for a deployment's output parameters. Args: @@ -97,7 +97,7 @@ def get_deployment_output_schema( def get_deployment_invocation_example( deployment: DeploymentResponse, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Generate an example invocation command for a deployment. Args: @@ -133,8 +133,8 @@ def get_deployment_invocation_example( def invoke_deployment( - deployment_name_or_id: Union[str, UUID], - project: Optional[UUID] = None, + deployment_name_or_id: str | UUID, + project: UUID | None = None, timeout: int = 300, # 5 minute timeout **kwargs: Any, ) -> Any: @@ -296,7 +296,7 @@ def invoke_deployment( def deployment_snapshot_request_from_source_snapshot( source_snapshot: PipelineSnapshotResponse, - deployment_parameters: Dict[str, Any], + deployment_parameters: dict[str, Any], ) -> PipelineSnapshotRequest: """Generate a snapshot request for deployment execution. 
@@ -380,10 +380,10 @@ def deployment_snapshot_request_from_source_snapshot( source_snapshot.pipeline_spec and source_snapshot.pipeline_spec.parameters is not None ): - original_params: Dict[str, Any] = dict( + original_params: dict[str, Any] = dict( source_snapshot.pipeline_spec.parameters ) - merged_params: Dict[str, Any] = original_params.copy() + merged_params: dict[str, Any] = original_params.copy() for k, v in deployment_parameters.items(): if k in original_params: merged_params[k] = v @@ -413,7 +413,7 @@ def deployment_snapshot_request_from_source_snapshot( def load_deployment_requirements( deployment_settings: DeploymentSettings, -) -> List[str]: +) -> list[str]: """Load the software requirements for a deployment. Args: diff --git a/src/zenml/entrypoints/base_entrypoint_configuration.py b/src/zenml/entrypoints/base_entrypoint_configuration.py index c8eb1a2d396..be59c9b779a 100644 --- a/src/zenml/entrypoints/base_entrypoint_configuration.py +++ b/src/zenml/entrypoints/base_entrypoint_configuration.py @@ -17,7 +17,7 @@ import os import sys from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Set +from typing import TYPE_CHECKING, Any, NoReturn from uuid import UUID from zenml.client import Client @@ -57,7 +57,7 @@ class BaseEntrypointConfiguration(ABC): entrypoint_args: The parsed arguments passed to the entrypoint. """ - def __init__(self, arguments: List[str]): + def __init__(self, arguments: list[str]): """Initializes the entrypoint configuration. Args: @@ -66,7 +66,7 @@ def __init__(self, arguments: List[str]): self.entrypoint_args = self._parse_arguments(arguments) @classmethod - def get_entrypoint_command(cls) -> List[str]: + def get_entrypoint_command(cls) -> list[str]: """Returns a command that runs the entrypoint module. 
This entrypoint module is responsible for running the entrypoint @@ -83,7 +83,7 @@ def get_entrypoint_command(cls) -> List[str]: return DEFAULT_ENTRYPOINT_COMMAND @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> set[str]: """Gets all options required for running with this configuration. Returns: @@ -101,7 +101,7 @@ def get_entrypoint_options(cls) -> Set[str]: def get_entrypoint_arguments( cls, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Gets all arguments that the entrypoint command should be called with. The argument list should be something that @@ -144,7 +144,7 @@ def get_entrypoint_arguments( return arguments @classmethod - def _parse_arguments(cls, arguments: List[str]) -> Dict[str, Any]: + def _parse_arguments(cls, arguments: list[str]) -> dict[str, Any]: """Parses command line arguments. This method will create an `argparse.ArgumentParser` and add required @@ -201,7 +201,7 @@ def load_snapshot(self) -> "PipelineSnapshotResponse": def download_code_if_necessary( self, snapshot: "PipelineSnapshotResponse", - step_name: Optional[str] = None, + step_name: str | None = None, ) -> None: """Downloads user code if necessary. @@ -297,7 +297,7 @@ def download_code_from_code_repository( def _should_download_code( self, snapshot: "PipelineSnapshotResponse", - step_name: Optional[str] = None, + step_name: str | None = None, ) -> bool: """Checks whether code should be downloaded. 
diff --git a/src/zenml/entrypoints/step_entrypoint_configuration.py b/src/zenml/entrypoints/step_entrypoint_configuration.py index 7ba49332df6..dd3225e7f7e 100644 --- a/src/zenml/entrypoints/step_entrypoint_configuration.py +++ b/src/zenml/entrypoints/step_entrypoint_configuration.py @@ -15,7 +15,7 @@ import os import sys -from typing import TYPE_CHECKING, Any, List, Set +from typing import TYPE_CHECKING, Any from uuid import UUID from zenml.client import Client @@ -115,7 +115,7 @@ def post_run( """ @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> set[str]: """Gets all options required for running with this configuration. Returns: @@ -128,7 +128,7 @@ def get_entrypoint_options(cls) -> Set[str]: def get_entrypoint_arguments( cls, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Gets all arguments that the entrypoint command should be called with. The argument list should be something that diff --git a/src/zenml/environment.py b/src/zenml/environment.py index 1ef2cc5c3c1..d2e50785fa5 100644 --- a/src/zenml/environment.py +++ b/src/zenml/environment.py @@ -16,7 +16,7 @@ import os import platform import subprocess -from typing import Any, Dict, List +from typing import Any import distro @@ -28,7 +28,7 @@ logger = get_logger(__name__) -def get_run_environment_dict() -> Dict[str, Any]: +def get_run_environment_dict() -> dict[str, Any]: """Returns a dictionary of the current run environment. Everything that is returned here will be saved in the DB as @@ -39,7 +39,7 @@ def get_run_environment_dict() -> Dict[str, Any]: Returns: A dictionary of the current run environment. 
""" - env_dict: Dict[str, Any] = { + env_dict: dict[str, Any] = { "environment": str(get_environment()), **Environment.get_system_info(), "python_version": Environment.python_version(), @@ -113,7 +113,7 @@ def __init__(self) -> None: """ @staticmethod - def get_system_info() -> Dict[str, str]: + def get_system_info() -> dict[str, str]: """Information about the operating system. Returns: @@ -178,7 +178,7 @@ def in_docker() -> bool: return True try: - with open("/proc/1/cgroup", "rt") as ifh: + with open("/proc/1/cgroup") as ifh: info = ifh.read() return "docker" in info except (FileNotFoundError, Exception): @@ -196,7 +196,7 @@ def in_kubernetes() -> bool: return True try: - with open("/proc/1/cgroup", "rt") as ifh: + with open("/proc/1/cgroup") as ifh: info = ifh.read() return "kubepod" in info except (FileNotFoundError, Exception): @@ -364,7 +364,7 @@ def in_lightning_ai_studio() -> bool: ) @staticmethod - def get_python_packages() -> List[str]: + def get_python_packages() -> list[str]: """Returns a list of installed Python packages. Raises: diff --git a/src/zenml/event_hub/base_event_hub.py b/src/zenml/event_hub/base_event_hub.py index ddbc095ada0..d7bf616b747 100644 --- a/src/zenml/event_hub/base_event_hub.py +++ b/src/zenml/event_hub/base_event_hub.py @@ -15,7 +15,8 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any +from collections.abc import Callable from zenml import EventSourceResponse from zenml.config.global_config import GlobalConfiguration @@ -38,7 +39,7 @@ logger = get_logger(__name__) ActionHandlerCallback = Callable[ - [Dict[str, Any], TriggerExecutionResponse, AuthContext], None + [dict[str, Any], TriggerExecutionResponse, AuthContext], None ] @@ -54,7 +55,7 @@ class BaseEventHub(ABC): unaware of each other. 
""" - action_handlers: Dict[Tuple[str, str], ActionHandlerCallback] = {} + action_handlers: dict[tuple[str, str], ActionHandlerCallback] = {} @property def zen_store(self) -> "SqlZenStore": @@ -133,7 +134,7 @@ def trigger_action( token = JWTToken( user_id=trigger.action.service_account.id, ) - expires: Optional[datetime] = None + expires: datetime | None = None if trigger.action.auth_window: expires = utc_now() + timedelta(minutes=trigger.action.auth_window) encoded_token = token.encode(expires=expires) diff --git a/src/zenml/event_hub/event_hub.py b/src/zenml/event_hub/event_hub.py index 8886653abcf..a07d6073faa 100644 --- a/src/zenml/event_hub/event_hub.py +++ b/src/zenml/event_hub/event_hub.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Base class for all the Event Hub.""" -from typing import List from pydantic import ValidationError @@ -112,7 +111,7 @@ def get_matching_active_triggers_for_event( self, event: BaseEvent, event_source: EventSourceResponse, - ) -> List[TriggerResponse]: + ) -> list[TriggerResponse]: """Get all triggers that match an incoming event. Args: @@ -123,7 +122,7 @@ def get_matching_active_triggers_for_event( The list of matching triggers. 
""" # get all event sources configured for this flavor - triggers: List[TriggerResponse] = depaginate( + triggers: list[TriggerResponse] = depaginate( self.zen_store.list_triggers, trigger_filter_model=TriggerFilter( project=event_source.project_id, @@ -133,7 +132,7 @@ def get_matching_active_triggers_for_event( hydrate=True, ) - trigger_list: List[TriggerResponse] = [] + trigger_list: list[TriggerResponse] = [] for trigger in triggers: # For now, the matching of trigger filters vs event is implemented diff --git a/src/zenml/event_sources/base_event_source.py b/src/zenml/event_sources/base_event_source.py index 30e36ce074e..7628c250450 100644 --- a/src/zenml/event_sources/base_event_source.py +++ b/src/zenml/event_sources/base_event_source.py @@ -17,8 +17,6 @@ from typing import ( Any, ClassVar, - Dict, - Optional, Type, ) @@ -108,9 +106,9 @@ class provides methods that implementations can use to dispatch events to actions. """ - _event_hub: Optional[BaseEventHub] = None + _event_hub: BaseEventHub | None = None - def __init__(self, event_hub: Optional[BaseEventHub] = None) -> None: + def __init__(self, event_hub: BaseEventHub | None = None) -> None: """Event source handler initialization. Args: @@ -133,7 +131,7 @@ def __init__(self, event_hub: Optional[BaseEventHub] = None) -> None: @property @abstractmethod - def config_class(self) -> Type[EventSourceConfig]: + def config_class(self) -> type[EventSourceConfig]: """Returns the event source configuration class. Returns: @@ -142,7 +140,7 @@ def config_class(self) -> Type[EventSourceConfig]: @property @abstractmethod - def filter_class(self) -> Type[EventFilterConfig]: + def filter_class(self) -> type[EventFilterConfig]: """Returns the event filter configuration class. 
Returns: @@ -396,7 +394,7 @@ def get_event_source( return event_source def validate_event_source_configuration( - self, event_source_config: Dict[str, Any] + self, event_source_config: dict[str, Any] ) -> EventSourceConfig: """Validate and return the event source configuration. @@ -418,7 +416,7 @@ def validate_event_source_configuration( def validate_event_filter_configuration( self, - configuration: Dict[str, Any], + configuration: dict[str, Any], ) -> EventFilterConfig: """Validate and return the configuration of an event filter. @@ -488,7 +486,6 @@ def _validate_event_source_request( event_source: Event source request. config: Event source configuration instantiated from the request. """ - pass def _process_event_source_request( self, event_source: EventSourceResponse, config: EventSourceConfig @@ -513,7 +510,6 @@ def _process_event_source_request( event_source: Newly created event source config: Event source configuration instantiated from the response. """ - pass def _validate_event_source_update( self, @@ -550,7 +546,6 @@ def _validate_event_source_update( config_update: Event source configuration instantiated from the updated event source. """ - pass def _process_event_source_update( self, @@ -583,13 +578,12 @@ def _process_event_source_update( previous_config: Event source configuration instantiated from the original event source. """ - pass def _process_event_source_delete( self, event_source: EventSourceResponse, config: EventSourceConfig, - force: Optional[bool] = False, + force: bool | None = False, ) -> None: """Process an event source before it is deleted from the database. @@ -610,7 +604,6 @@ def _process_event_source_delete( the deletion. force: Whether to force deprovision the event source. """ - pass def _process_event_source_response( self, event_source: EventSourceResponse, config: EventSourceConfig @@ -635,7 +628,6 @@ def _process_event_source_response( event_source: Event source response. 
config: Event source configuration instantiated from the response. """ - pass # -------------------- Flavors ---------------------------------- @@ -647,11 +639,11 @@ class BaseEventSourceFlavor(BasePluginFlavor, ABC): TYPE: ClassVar[PluginType] = PluginType.EVENT_SOURCE # EventPlugin specific - EVENT_SOURCE_CONFIG_CLASS: ClassVar[Type[EventSourceConfig]] - EVENT_FILTER_CONFIG_CLASS: ClassVar[Type[EventFilterConfig]] + EVENT_SOURCE_CONFIG_CLASS: ClassVar[type[EventSourceConfig]] + EVENT_FILTER_CONFIG_CLASS: ClassVar[type[EventFilterConfig]] @classmethod - def get_event_filter_config_schema(cls) -> Dict[str, Any]: + def get_event_filter_config_schema(cls) -> dict[str, Any]: """The config schema for a flavor. Returns: @@ -660,7 +652,7 @@ def get_event_filter_config_schema(cls) -> Dict[str, Any]: return cls.EVENT_SOURCE_CONFIG_CLASS.model_json_schema() @classmethod - def get_event_source_config_schema(cls) -> Dict[str, Any]: + def get_event_source_config_schema(cls) -> dict[str, Any]: """The config schema for a flavor. Returns: diff --git a/src/zenml/event_sources/webhooks/base_webhook_event_source.py b/src/zenml/event_sources/webhooks/base_webhook_event_source.py index 5db1964df63..07c21c596e6 100644 --- a/src/zenml/event_sources/webhooks/base_webhook_event_source.py +++ b/src/zenml/event_sources/webhooks/base_webhook_event_source.py @@ -17,7 +17,7 @@ import hmac import json from abc import ABC, abstractmethod -from typing import Any, ClassVar, Dict, Optional, Type +from typing import Any, ClassVar, Type from zenml.enums import PluginSubType from zenml.event_sources.base_event import BaseEvent @@ -60,7 +60,7 @@ class BaseWebhookEventSourceHandler(BaseEventSourceHandler, ABC): @property @abstractmethod - def config_class(self) -> Type[WebhookEventSourceConfig]: + def config_class(self) -> type[WebhookEventSourceConfig]: """Returns the webhook event source configuration class. 
Returns: @@ -69,7 +69,7 @@ def config_class(self) -> Type[WebhookEventSourceConfig]: @property @abstractmethod - def filter_class(self) -> Type[WebhookEventFilterConfig]: + def filter_class(self) -> type[WebhookEventFilterConfig]: """Returns the webhook event filter configuration class. Returns: @@ -110,7 +110,7 @@ def is_valid_signature( return True @abstractmethod - def _interpret_event(self, event: Dict[str, Any]) -> BaseEvent: + def _interpret_event(self, event: dict[str, Any]) -> BaseEvent: """Converts the generic event body into a event-source specific pydantic model. Args: @@ -123,7 +123,7 @@ def _interpret_event(self, event: Dict[str, Any]) -> BaseEvent: @abstractmethod def _get_webhook_secret( self, event_source: EventSourceResponse - ) -> Optional[str]: + ) -> str | None: """Get the webhook secret for the event source. Inheriting classes should implement this method to retrieve the webhook @@ -139,7 +139,7 @@ def _get_webhook_secret( """ def _validate_webhook_event_signature( - self, raw_body: bytes, headers: Dict[str, str], webhook_secret: str + self, raw_body: bytes, headers: dict[str, str], webhook_secret: str ) -> None: """Validate the signature of an incoming webhook event. @@ -169,8 +169,8 @@ def _validate_webhook_event_signature( ) def _load_payload( - self, raw_body: bytes, headers: Dict[str, str] - ) -> Dict[Any, Any]: + self, raw_body: bytes, headers: dict[str, str] + ) -> dict[Any, Any]: """Converts the raw body of the request into a python dictionary. Args: @@ -186,7 +186,7 @@ def _load_payload( # For now assume all webhook events are json encoded and parse # the body as such. 
try: - body_dict: Dict[Any, Any] = json.loads(raw_body) + body_dict: dict[Any, Any] = json.loads(raw_body) return body_dict except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON body received: {e}") @@ -195,7 +195,7 @@ def process_webhook_event( self, event_source: EventSourceResponse, raw_body: bytes, - headers: Dict[str, str], + headers: dict[str, str], ) -> None: """Process an incoming webhook event. diff --git a/src/zenml/exceptions.py b/src/zenml/exceptions.py index 9988d0cef7a..c6f9687f435 100644 --- a/src/zenml/exceptions.py +++ b/src/zenml/exceptions.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """ZenML specific exception definitions.""" -from typing import Dict, Optional class ZenMLBaseException(Exception): @@ -21,8 +20,8 @@ class ZenMLBaseException(Exception): def __init__( self, - message: Optional[str] = None, - url: Optional[str] = None, + message: str | None = None, + url: str | None = None, ): """The BaseException used to format messages displayed to the user. @@ -177,8 +176,8 @@ def __init__( self, error: str, status_code: int = 400, - error_description: Optional[str] = None, - error_uri: Optional[str] = None, + error_description: str | None = None, + error_uri: str | None = None, ) -> None: """Initializes the OAuthError. @@ -193,7 +192,7 @@ def __init__( self.error_description = error_description self.error_uri = error_uri - def to_dict(self) -> Dict[str, Optional[str]]: + def to_dict(self) -> dict[str, str | None]: """Returns the OAuthError as a dictionary. 
Returns: diff --git a/src/zenml/experiment_trackers/base_experiment_tracker.py b/src/zenml/experiment_trackers/base_experiment_tracker.py index c54aa0e53bd..3209d67ee01 100644 --- a/src/zenml/experiment_trackers/base_experiment_tracker.py +++ b/src/zenml/experiment_trackers/base_experiment_tracker.py @@ -14,7 +14,7 @@ """Base class for all ZenML experiment trackers.""" from abc import ABC, abstractmethod -from typing import Type, cast +from typing import cast from zenml.enums import StackComponentType from zenml.stack import Flavor, StackComponent @@ -51,7 +51,7 @@ def type(self) -> StackComponentType: return StackComponentType.EXPERIMENT_TRACKER @property - def config_class(self) -> Type[BaseExperimentTrackerConfig]: + def config_class(self) -> type[BaseExperimentTrackerConfig]: """Config class for this flavor. Returns: @@ -61,7 +61,7 @@ def config_class(self) -> Type[BaseExperimentTrackerConfig]: @property @abstractmethod - def implementation_class(self) -> Type[StackComponent]: + def implementation_class(self) -> type[StackComponent]: """Returns the implementation class for this flavor. Returns: diff --git a/src/zenml/feature_stores/base_feature_store.py b/src/zenml/feature_stores/base_feature_store.py index 0ab6304310e..f36e5c7a8d7 100644 --- a/src/zenml/feature_stores/base_feature_store.py +++ b/src/zenml/feature_stores/base_feature_store.py @@ -14,7 +14,7 @@ """The base class for feature stores.""" from abc import ABC, abstractmethod -from typing import Any, Dict, List, Type, cast +from typing import Any, cast from zenml.enums import StackComponentType from zenml.stack import Flavor, StackComponent @@ -41,7 +41,7 @@ def config(self) -> BaseFeatureStoreConfig: def get_historical_features( self, entity_df: Any, - features: List[str], + features: list[str], full_feature_names: bool = False, ) -> Any: """Returns the historical features for training or batch scoring. 
@@ -58,10 +58,10 @@ def get_historical_features( @abstractmethod def get_online_features( self, - entity_rows: List[Dict[str, Any]], - features: List[str], + entity_rows: list[dict[str, Any]], + features: list[str], full_feature_names: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Returns the latest online feature data. Args: @@ -87,7 +87,7 @@ def type(self) -> StackComponentType: return StackComponentType.FEATURE_STORE @property - def config_class(self) -> Type[BaseFeatureStoreConfig]: + def config_class(self) -> type[BaseFeatureStoreConfig]: """Config class for this flavor. Returns: @@ -97,7 +97,7 @@ def config_class(self) -> Type[BaseFeatureStoreConfig]: @property @abstractmethod - def implementation_class(self) -> Type[BaseFeatureStore]: + def implementation_class(self) -> type[BaseFeatureStore]: """Implementation class. Returns: diff --git a/src/zenml/hooks/hook_validators.py b/src/zenml/hooks/hook_validators.py index cdf7d5f55b6..b34cb432588 100644 --- a/src/zenml/hooks/hook_validators.py +++ b/src/zenml/hooks/hook_validators.py @@ -16,12 +16,9 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - Optional, - Tuple, Union, ) +from collections.abc import Callable from pydantic import ConfigDict, ValidationError @@ -40,9 +37,9 @@ def _validate_hook_arguments( _func: Callable[..., Any], - hook_kwargs: Dict[str, Any], - exception_arg: Union[BaseException, bool] = False, -) -> Dict[str, Any]: + hook_kwargs: dict[str, Any], + exception_arg: BaseException | bool = False, +) -> dict[str, Any]: """Validates hook arguments. Args: @@ -58,7 +55,7 @@ def _validate_hook_arguments( """ # Validate hook arguments try: - hook_args: Tuple[Any, ...] = () + hook_args: tuple[Any, ...] 
= () if isinstance(exception_arg, BaseException): hook_args = (exception_arg,) elif exception_arg is True: @@ -98,9 +95,9 @@ def _validate_hook_arguments( def resolve_and_validate_hook( hook: Union["HookSpecification", "InitHookSpecification"], - hook_kwargs: Optional[Dict[str, Any]] = None, + hook_kwargs: dict[str, Any] | None = None, allow_exception_arg: bool = False, -) -> Tuple[Source, Optional[Dict[str, Any]]]: +) -> tuple[Source, dict[str, Any] | None]: """Resolves and validates a hook callback and its arguments. Args: @@ -135,8 +132,8 @@ def resolve_and_validate_hook( def load_and_run_hook( hook_source: "Source", - hook_parameters: Optional[Dict[str, Any]] = None, - step_exception: Optional[BaseException] = None, + hook_parameters: dict[str, Any] | None = None, + step_exception: BaseException | None = None, raise_on_error: bool = False, ) -> Any: """Loads hook source and runs the hook. diff --git a/src/zenml/image_builders/base_image_builder.py b/src/zenml/image_builders/base_image_builder.py index 4ad38cd8652..43f9d840452 100644 --- a/src/zenml/image_builders/base_image_builder.py +++ b/src/zenml/image_builders/base_image_builder.py @@ -17,7 +17,7 @@ import os import tempfile from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast +from typing import TYPE_CHECKING, Any, Optional, cast from zenml.client import Client from zenml.enums import StackComponentType @@ -51,7 +51,7 @@ def config(self) -> BaseImageBuilderConfig: return cast(BaseImageBuilderConfig, self._config) @property - def build_context_class(self) -> Type["BuildContext"]: + def build_context_class(self) -> type["BuildContext"]: """Build context class to use. 
The default build context class creates a build context that works @@ -79,7 +79,7 @@ def build( self, image_name: str, build_context: "BuildContext", - docker_build_options: Dict[str, Any], + docker_build_options: dict[str, Any], container_registry: Optional["BaseContainerRegistry"] = None, ) -> str: """Builds a Docker image. @@ -155,7 +155,7 @@ def type(self) -> StackComponentType: return StackComponentType.IMAGE_BUILDER @property - def config_class(self) -> Type[BaseImageBuilderConfig]: + def config_class(self) -> type[BaseImageBuilderConfig]: """Config class. Returns: @@ -164,7 +164,7 @@ def config_class(self) -> Type[BaseImageBuilderConfig]: return BaseImageBuilderConfig @property - def implementation_class(self) -> Type[BaseImageBuilder]: + def implementation_class(self) -> type[BaseImageBuilder]: """Implementation class. Returns: diff --git a/src/zenml/image_builders/build_context.py b/src/zenml/image_builders/build_context.py index 610348ef1a1..01f757b3e23 100644 --- a/src/zenml/image_builders/build_context.py +++ b/src/zenml/image_builders/build_context.py @@ -14,7 +14,7 @@ """Image build context.""" import os -from typing import IO, Dict, List, Optional, Set, cast +from typing import IO, cast from zenml.constants import REPOSITORY_DIRECTORY_NAME from zenml.io import fileio @@ -34,8 +34,8 @@ class BuildContext(Archivable): def __init__( self, - root: Optional[str] = None, - dockerignore_file: Optional[str] = None, + root: str | None = None, + dockerignore_file: str | None = None, ) -> None: """Initializes a build context. @@ -50,7 +50,7 @@ def __init__( self._dockerignore_file = dockerignore_file @property - def dockerignore_file(self) -> Optional[str]: + def dockerignore_file(self) -> str | None: """The dockerignore file to use. 
Returns: @@ -98,7 +98,7 @@ def write_archive( os.path.join(self._root, ".dockerignore"), ) - def get_files(self) -> Dict[str, str]: + def get_files(self) -> dict[str, str]: """Gets all regular files that should be included in the archive. Returns: @@ -111,7 +111,7 @@ def get_files(self) -> Dict[str, str]: exclude_patterns = self._get_exclude_patterns() archive_paths = cast( - Set[str], + set[str], docker_build_utils.exclude_paths( self._root, patterns=exclude_patterns ), @@ -123,7 +123,7 @@ def get_files(self) -> Dict[str, str]: else: return {} - def _get_exclude_patterns(self) -> List[str]: + def _get_exclude_patterns(self) -> list[str]: """Gets all exclude patterns from the dockerignore file. Returns: @@ -143,7 +143,7 @@ def _get_exclude_patterns(self) -> List[str]: return [] @staticmethod - def _parse_dockerignore(dockerignore_path: str) -> List[str]: + def _parse_dockerignore(dockerignore_path: str) -> list[str]: """Parses a dockerignore file and returns a list of patterns to ignore. Args: diff --git a/src/zenml/image_builders/local_image_builder.py b/src/zenml/image_builders/local_image_builder.py index 91aab432c32..7238a569ea1 100644 --- a/src/zenml/image_builders/local_image_builder.py +++ b/src/zenml/image_builders/local_image_builder.py @@ -15,7 +15,7 @@ import shutil import tempfile -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast +from typing import TYPE_CHECKING, Any, Optional, cast from zenml.image_builders import ( BaseImageBuilder, @@ -93,7 +93,7 @@ def build( self, image_name: str, build_context: "BuildContext", - docker_build_options: Optional[Dict[str, Any]] = None, + docker_build_options: dict[str, Any] | None = None, container_registry: Optional["BaseContainerRegistry"] = None, ) -> str: """Builds and optionally pushes an image using the local Docker client. 
@@ -147,7 +147,7 @@ def name(self) -> str: return "local" @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -156,7 +156,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -174,7 +174,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/image_builder/local.svg" @property - def config_class(self) -> Type[LocalImageBuilderConfig]: + def config_class(self) -> type[LocalImageBuilderConfig]: """Config class. Returns: @@ -183,7 +183,7 @@ def config_class(self) -> Type[LocalImageBuilderConfig]: return LocalImageBuilderConfig @property - def implementation_class(self) -> Type[LocalImageBuilder]: + def implementation_class(self) -> type[LocalImageBuilder]: """Implementation class. Returns: diff --git a/src/zenml/integrations/airflow/__init__.py b/src/zenml/integrations/airflow/__init__.py index 43c14e70dca..4fd5195bef1 100644 --- a/src/zenml/integrations/airflow/__init__.py +++ b/src/zenml/integrations/airflow/__init__.py @@ -15,7 +15,6 @@ The Airflow integration powers an alternative orchestrator. """ -from typing import List, Type from zenml.integrations.constants import AIRFLOW from zenml.integrations.integration import Integration @@ -31,7 +30,7 @@ class AirflowIntegration(Integration): REQUIREMENTS = [] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Airflow integration. 
Returns: diff --git a/src/zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py b/src/zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py index 1c6f0987bcf..a634124b78f 100644 --- a/src/zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py +++ b/src/zenml/integrations/airflow/flavors/airflow_orchestrator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Airflow orchestrator flavor.""" -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any from pydantic import field_validator @@ -76,16 +76,16 @@ class AirflowOrchestratorSettings(BaseSettings): pipeline and ignored if defined on a step. """ - dag_output_dir: Optional[str] = None + dag_output_dir: str | None = None - dag_id: Optional[str] = None - dag_tags: List[str] = [] - dag_args: Dict[str, Any] = {} + dag_id: str | None = None + dag_tags: list[str] = [] + dag_args: dict[str, Any] = {} operator: str = OperatorType.DOCKER.source - operator_args: Dict[str, Any] = {} + operator_args: dict[str, Any] = {} - custom_dag_generator: Optional[str] = None + custom_dag_generator: str | None = None @field_validator("operator", mode="before") @classmethod @@ -151,7 +151,7 @@ def name(self) -> str: return AIRFLOW_ORCHESTRATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -160,7 +160,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. 
Returns: @@ -178,7 +178,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/airflow.png" @property - def config_class(self) -> Type[AirflowOrchestratorConfig]: + def config_class(self) -> type[AirflowOrchestratorConfig]: """Returns `AirflowOrchestratorConfig` config class. Returns: @@ -187,7 +187,7 @@ def config_class(self) -> Type[AirflowOrchestratorConfig]: return AirflowOrchestratorConfig @property - def implementation_class(self) -> Type["AirflowOrchestrator"]: + def implementation_class(self) -> type["AirflowOrchestrator"]: """Implementation class. Returns: diff --git a/src/zenml/integrations/airflow/orchestrators/airflow_orchestrator.py b/src/zenml/integrations/airflow/orchestrators/airflow_orchestrator.py index 647ea86cf05..2102dbdbfb5 100644 --- a/src/zenml/integrations/airflow/orchestrators/airflow_orchestrator.py +++ b/src/zenml/integrations/airflow/orchestrators/airflow_orchestrator.py @@ -20,11 +20,8 @@ from typing import ( TYPE_CHECKING, Any, - Dict, NamedTuple, Optional, - Tuple, - Type, cast, ) @@ -67,12 +64,12 @@ class DagGeneratorValues(NamedTuple): file: str config_file_name: str run_id_env_variable_name: str - dag_configuration_class: Type["DagConfiguration"] - task_configuration_class: Type["TaskConfiguration"] + dag_configuration_class: type["DagConfiguration"] + task_configuration_class: type["TaskConfiguration"] def get_dag_generator_values( - custom_dag_generator_source: Optional[str] = None, + custom_dag_generator_source: str | None = None, ) -> DagGeneratorValues: """Gets values from the DAG generator module. @@ -126,7 +123,7 @@ def config(self) -> AirflowOrchestratorConfig: return cast(AirflowOrchestratorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Kubeflow orchestrator. 
Returns: @@ -151,7 +148,7 @@ def validator(self) -> Optional["StackValidator"]: def _validate_remote_components( stack: "Stack", - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: for component in stack.components.values(): if not component.config.is_local: continue @@ -181,10 +178,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. This method should only submit the pipeline and not wait for it to @@ -283,7 +280,7 @@ def submit_pipeline( def _apply_resource_settings( self, resource_settings: "ResourceSettings", - operator_args: Dict[str, Any], + operator_args: dict[str, Any], ) -> None: """Adds resource settings to the operator args. @@ -399,7 +396,7 @@ def get_orchestrator_run_id(self) -> str: @staticmethod def _translate_schedule( schedule: Optional["ScheduleResponse"] = None, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Convert ZenML schedule into Airflow schedule. 
The Airflow schedule uses slightly different naming and needs some diff --git a/src/zenml/integrations/airflow/orchestrators/dag_generator.py b/src/zenml/integrations/airflow/orchestrators/dag_generator.py index a1c7feb9bca..ac65fab8eef 100644 --- a/src/zenml/integrations/airflow/orchestrators/dag_generator.py +++ b/src/zenml/integrations/airflow/orchestrators/dag_generator.py @@ -17,7 +17,7 @@ import importlib import os import zipfile -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any from pydantic import BaseModel, Field @@ -31,38 +31,38 @@ class TaskConfiguration(BaseModel): id: str zenml_step_name: str - upstream_steps: List[str] + upstream_steps: list[str] docker_image: str - command: List[str] - arguments: List[str] + command: list[str] + arguments: list[str] - environment: Dict[str, str] = {} + environment: dict[str, str] = {} operator_source: str - operator_args: Dict[str, Any] = {} + operator_args: dict[str, Any] = {} class DagConfiguration(BaseModel): """Airflow DAG configuration.""" id: str - tasks: List[TaskConfiguration] + tasks: list[TaskConfiguration] - local_stores_path: Optional[str] = None + local_stores_path: str | None = None - schedule: Union[datetime.timedelta, str] = Field( + schedule: datetime.timedelta | str = Field( union_mode="left_to_right" ) start_date: datetime.datetime - end_date: Optional[datetime.datetime] = None + end_date: datetime.datetime | None = None catchup: bool = False - tags: List[str] = [] - dag_args: Dict[str, Any] = {} + tags: list[str] = [] + dag_args: dict[str, Any] = {} -def import_class_by_path(class_path: str) -> Type[Any]: +def import_class_by_path(class_path: str) -> type[Any]: """Imports a class based on a given path. 
Args: @@ -77,10 +77,10 @@ def import_class_by_path(class_path: str) -> Type[Any]: def get_operator_init_kwargs( - operator_class: Type[Any], + operator_class: type[Any], dag_config: DagConfiguration, task_config: TaskConfiguration, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Gets keyword arguments to pass to the operator init method. Args: @@ -141,7 +141,7 @@ def get_operator_init_kwargs( def get_docker_operator_init_kwargs( dag_config: DagConfiguration, task_config: TaskConfiguration -) -> Dict[str, Any]: +) -> dict[str, Any]: """Gets keyword arguments to pass to the DockerOperator. Args: @@ -179,7 +179,7 @@ def get_docker_operator_init_kwargs( def get_kubernetes_pod_operator_init_kwargs( dag_config: DagConfiguration, task_config: TaskConfiguration -) -> Dict[str, Any]: +) -> dict[str, Any]: """Gets keyword arguments to pass to the KubernetesPodOperator. Args: diff --git a/src/zenml/integrations/argilla/__init__.py b/src/zenml/integrations/argilla/__init__.py index a1584f31e22..373cadd4299 100644 --- a/src/zenml/integrations/argilla/__init__.py +++ b/src/zenml/integrations/argilla/__init__.py @@ -12,7 +12,6 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Initialization of the Argilla integration.""" -from typing import List, Type from zenml.integrations.constants import ARGILLA from zenml.integrations.integration import Integration @@ -30,7 +29,7 @@ class ArgillaIntegration(Integration): ] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Argilla integration. 
Returns: diff --git a/src/zenml/integrations/argilla/annotators/argilla_annotator.py b/src/zenml/integrations/argilla/annotators/argilla_annotator.py index fe04ce40dc5..08c30d6c11c 100644 --- a/src/zenml/integrations/argilla/annotators/argilla_annotator.py +++ b/src/zenml/integrations/argilla/annotators/argilla_annotator.py @@ -15,7 +15,7 @@ import json import webbrowser -from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast +from typing import Any, cast import argilla as rg from argilla._exceptions._api import ArgillaAPIError @@ -47,7 +47,7 @@ def config(self) -> ArgillaAnnotatorConfig: return cast(ArgillaAnnotatorConfig, self._config) @property - def settings_class(self) -> Type[ArgillaAnnotatorSettings]: + def settings_class(self) -> type[ArgillaAnnotatorSettings]: """Settings class for the Argilla annotator. Returns: @@ -124,7 +124,7 @@ def get_url_for_dataset(self, dataset_name: str, **kwargs: Any) -> str: ).id return f"{self.get_url()}/dataset/{dataset_id}/annotation-mode" - def get_datasets(self, **kwargs: Any) -> List[Any]: + def get_datasets(self, **kwargs: Any) -> list[Any]: """Gets the datasets currently available for annotation. Args: @@ -144,7 +144,7 @@ def get_datasets(self, **kwargs: Any) -> List[Any]: return datasets - def get_dataset_names(self, **kwargs: Any) -> List[str]: + def get_dataset_names(self, **kwargs: Any) -> list[str]: """Gets the names of the datasets. Args: @@ -168,7 +168,7 @@ def get_dataset_names(self, **kwargs: Any) -> List[str]: return dataset_names def _get_data_by_status( - self, dataset_name: str, status: str, workspace: Optional[str] + self, dataset_name: str, status: str, workspace: str | None ) -> Any: """Gets the dataset containing the data with the specified status. @@ -196,7 +196,7 @@ def _get_data_by_status( def get_dataset_stats( self, dataset_name: str, **kwargs: Any - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """Gets the statistics of the given dataset. 
Args: @@ -299,9 +299,9 @@ def add_dataset(self, **kwargs: Any) -> Any: def add_records( self, dataset_name: str, - records: Union[Any, List[Dict[str, Any]]], - workspace: Optional[str] = None, - mapping: Optional[Dict[str, str]] = None, + records: Any | list[dict[str, Any]], + workspace: str | None = None, + mapping: dict[str, str] | None = None, ) -> Any: """Add records to an Argilla dataset for annotation. diff --git a/src/zenml/integrations/argilla/flavors/argilla_annotator_flavor.py b/src/zenml/integrations/argilla/flavors/argilla_annotator_flavor.py index 649c9eb4cd7..ae7e1e66f82 100644 --- a/src/zenml/integrations/argilla/flavors/argilla_annotator_flavor.py +++ b/src/zenml/integrations/argilla/flavors/argilla_annotator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Argilla annotator flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import field_validator @@ -50,12 +50,12 @@ class ArgillaAnnotatorSettings(BaseSettings): """ instance_url: str = DEFAULT_LOCAL_INSTANCE_URL - api_key: Optional[str] = SecretField(default=None) - port: Optional[int] = DEFAULT_LOCAL_ARGILLA_PORT - headers: Optional[str] = None - httpx_extra_kwargs: Optional[str] = None + api_key: str | None = SecretField(default=None) + port: int | None = DEFAULT_LOCAL_ARGILLA_PORT + headers: str | None = None + httpx_extra_kwargs: str | None = None - extra_headers: Optional[str] = None + extra_headers: str | None = None _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes( ("extra_headers", "headers"), @@ -101,7 +101,7 @@ def name(self) -> str: return ARGILLA_ANNOTATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. 
Returns: @@ -110,7 +110,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -128,7 +128,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/annotator/argilla.png" @property - def config_class(self) -> Type[ArgillaAnnotatorConfig]: + def config_class(self) -> type[ArgillaAnnotatorConfig]: """Returns `ArgillaAnnotatorConfig` config class. Returns: @@ -137,7 +137,7 @@ def config_class(self) -> Type[ArgillaAnnotatorConfig]: return ArgillaAnnotatorConfig @property - def implementation_class(self) -> Type["ArgillaAnnotator"]: + def implementation_class(self) -> type["ArgillaAnnotator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/aws/__init__.py b/src/zenml/integrations/aws/__init__.py index 28cd8a82a85..12befb4f21d 100644 --- a/src/zenml/integrations/aws/__init__.py +++ b/src/zenml/integrations/aws/__init__.py @@ -18,7 +18,6 @@ Sagemaker integration submodule provides a way to run ZenML steps in Sagemaker. """ -from typing import List, Type from zenml.integrations.constants import AWS from zenml.integrations.integration import Integration @@ -54,7 +53,7 @@ def activate(cls) -> None: from zenml.integrations.aws import service_connectors # noqa @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the AWS integration. 
Returns: diff --git a/src/zenml/integrations/aws/container_registries/aws_container_registry.py b/src/zenml/integrations/aws/container_registries/aws_container_registry.py index 0e1e8bfb694..f89a6866734 100644 --- a/src/zenml/integrations/aws/container_registries/aws_container_registry.py +++ b/src/zenml/integrations/aws/container_registries/aws_container_registry.py @@ -14,7 +14,7 @@ """Implementation of the AWS container registry integration.""" import re -from typing import List, Optional, cast +from typing import cast import boto3 from botocore.client import BaseClient @@ -130,7 +130,7 @@ def prepare_image_push(self, image_name: str) -> None: return try: - repo_uris: List[str] = [ + repo_uris: list[str] = [ repository["repositoryUri"] for repository in response["repositories"] ] @@ -153,7 +153,7 @@ def prepare_image_push(self, image_name: str) -> None: ) @property - def post_registration_message(self) -> Optional[str]: + def post_registration_message(self) -> str | None: """Optional message printed after the stack component is registered. 
Returns: diff --git a/src/zenml/integrations/aws/deployers/aws_deployer.py b/src/zenml/integrations/aws/deployers/aws_deployer.py index 80c0f6fee60..c8ac11e90ba 100644 --- a/src/zenml/integrations/aws/deployers/aws_deployer.py +++ b/src/zenml/integrations/aws/deployers/aws_deployer.py @@ -19,14 +19,9 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - Generator, - List, - Optional, - Tuple, - Type, cast, ) +from collections.abc import Generator from uuid import UUID import boto3 @@ -79,47 +74,47 @@ class AppRunnerDeploymentMetadata(BaseModel): """Metadata for an App Runner deployment.""" - service_name: Optional[str] = None - service_arn: Optional[str] = None - service_url: Optional[str] = None - region: Optional[str] = None - service_id: Optional[str] = None - status: Optional[str] = None - source_configuration: Optional[Dict[str, Any]] = None - instance_configuration: Optional[Dict[str, Any]] = None - auto_scaling_configuration_summary: Optional[Dict[str, Any]] = None - auto_scaling_configuration_arn: Optional[str] = None - health_check_configuration: Optional[Dict[str, Any]] = None - network_configuration: Optional[Dict[str, Any]] = None - observability_configuration: Optional[Dict[str, Any]] = None - encryption_configuration: Optional[Dict[str, Any]] = None - cpu: Optional[str] = None - memory: Optional[str] = None - port: Optional[int] = None - auto_scaling_max_concurrency: Optional[int] = None - auto_scaling_max_size: Optional[int] = None - auto_scaling_min_size: Optional[int] = None - is_publicly_accessible: Optional[bool] = None - health_check_grace_period_seconds: Optional[int] = None - health_check_interval_seconds: Optional[int] = None - health_check_path: Optional[str] = None - health_check_protocol: Optional[str] = None - health_check_timeout_seconds: Optional[int] = None - health_check_healthy_threshold: Optional[int] = None - health_check_unhealthy_threshold: Optional[int] = None - tags: Optional[Dict[str, str]] = None - traffic_allocation: 
Optional[Dict[str, int]] = None - created_at: Optional[str] = None - updated_at: Optional[str] = None - deleted_at: Optional[str] = None - secret_arn: Optional[str] = None + service_name: str | None = None + service_arn: str | None = None + service_url: str | None = None + region: str | None = None + service_id: str | None = None + status: str | None = None + source_configuration: dict[str, Any] | None = None + instance_configuration: dict[str, Any] | None = None + auto_scaling_configuration_summary: dict[str, Any] | None = None + auto_scaling_configuration_arn: str | None = None + health_check_configuration: dict[str, Any] | None = None + network_configuration: dict[str, Any] | None = None + observability_configuration: dict[str, Any] | None = None + encryption_configuration: dict[str, Any] | None = None + cpu: str | None = None + memory: str | None = None + port: int | None = None + auto_scaling_max_concurrency: int | None = None + auto_scaling_max_size: int | None = None + auto_scaling_min_size: int | None = None + is_publicly_accessible: bool | None = None + health_check_grace_period_seconds: int | None = None + health_check_interval_seconds: int | None = None + health_check_path: str | None = None + health_check_protocol: str | None = None + health_check_timeout_seconds: int | None = None + health_check_healthy_threshold: int | None = None + health_check_unhealthy_threshold: int | None = None + tags: dict[str, str] | None = None + traffic_allocation: dict[str, int] | None = None + created_at: str | None = None + updated_at: str | None = None + deleted_at: str | None = None + secret_arn: str | None = None @classmethod def from_app_runner_service( cls, - service: Dict[str, Any], + service: dict[str, Any], region: str, - secret_arn: Optional[str] = None, + secret_arn: str | None = None, ) -> "AppRunnerDeploymentMetadata": """Create metadata from an App Runner service. 
@@ -253,11 +248,11 @@ def from_deployment( class AWSDeployer(ContainerizedDeployer): """Deployer responsible for deploying pipelines on AWS App Runner.""" - _boto_session: Optional[boto3.Session] = None - _region: Optional[str] = None - _app_runner_client: Optional[Any] = None - _secrets_manager_client: Optional[Any] = None - _logs_client: Optional[Any] = None + _boto_session: boto3.Session | None = None + _region: str | None = None + _app_runner_client: Any | None = None + _secrets_manager_client: Any | None = None + _logs_client: Any | None = None @property def config(self) -> AWSDeployerConfig: @@ -269,7 +264,7 @@ def config(self) -> AWSDeployerConfig: return cast(AWSDeployerConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the AWS deployer. Returns: @@ -278,7 +273,7 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: return AWSDeployerSettings @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Ensures there is an image builder in the stack. Returns: @@ -291,7 +286,7 @@ def validator(self) -> Optional[StackValidator]: } ) - def _get_boto_session_and_region(self) -> Tuple[boto3.Session, str]: + def _get_boto_session_and_region(self) -> tuple[boto3.Session, str]: """Get an authenticated boto3 session and determine the region. Returns: @@ -389,7 +384,7 @@ def get_tags( self, deployment: DeploymentResponse, settings: AWSDeployerSettings, - ) -> List[Dict[str, str]]: + ) -> list[dict[str, str]]: """Get the tags for a deployment to be used for AWS resources. 
Args: @@ -604,7 +599,7 @@ def _create_or_update_secret( f"Failed to create/update secret {secret_name}: {e}" ) - def _get_secret_arn(self, deployment: DeploymentResponse) -> Optional[str]: + def _get_secret_arn(self, deployment: DeploymentResponse) -> str | None: """Get the existing AWS Secrets Manager secret ARN for a deployment. Args: @@ -819,10 +814,10 @@ def _cleanup_deployment_auto_scaling_config( def _prepare_environment_variables( self, deployment: DeploymentResponse, - environment: Dict[str, str], - secrets: Dict[str, str], + environment: dict[str, str], + secrets: dict[str, str], settings: AWSDeployerSettings, - ) -> Tuple[Dict[str, str], Dict[str, str], Optional[str]]: + ) -> tuple[dict[str, str], dict[str, str], str | None]: """Prepare environment variables for App Runner, handling secrets appropriately. Args: @@ -838,7 +833,7 @@ def _prepare_environment_variables( - Optional secret ARN (None if no secrets or fallback to env vars). """ secret_refs = {} - active_secret_arn: Optional[str] = None + active_secret_arn: str | None = None env_vars = {**settings.environment_variables, **environment} @@ -894,7 +889,7 @@ def _prepare_environment_variables( def _get_app_runner_service( self, deployment: DeploymentResponse - ) -> Optional[Dict[str, Any]]: + ) -> dict[str, Any] | None: """Get an existing App Runner service for a deployment. Args: @@ -925,9 +920,9 @@ def _get_app_runner_service( def _get_service_operational_state( self, - service: Dict[str, Any], + service: dict[str, Any], region: str, - secret_arn: Optional[str] = None, + secret_arn: str | None = None, ) -> DeploymentOperationalState: """Get the operational state of an App Runner service. @@ -982,7 +977,7 @@ def _get_service_operational_state( def _requires_service_replacement( self, - existing_service: Dict[str, Any], + existing_service: dict[str, Any], settings: AWSDeployerSettings, ) -> bool: """Check if the service configuration requires replacement. 
@@ -1029,9 +1024,9 @@ def _requires_service_replacement( def _convert_resource_settings_to_aws_format( self, resource_settings: ResourceSettings, - resource_combinations: List[Tuple[float, float]], + resource_combinations: list[tuple[float, float]], strict_resource_matching: bool = False, - ) -> Tuple[str, str]: + ) -> tuple[str, str]: """Convert ResourceSettings to AWS App Runner resource format. AWS App Runner only supports specific CPU-memory combinations. @@ -1065,11 +1060,11 @@ def _convert_resource_settings_to_aws_format( def _select_aws_cpu_memory_combination( self, - requested_cpu: Optional[float], - requested_memory_gb: Optional[float], - resource_combinations: List[Tuple[float, float]], + requested_cpu: float | None, + requested_memory_gb: float | None, + resource_combinations: list[tuple[float, float]], strict_resource_matching: bool = False, - ) -> Tuple[str, str]: + ) -> tuple[str, str]: """Select the best AWS App Runner CPU-memory combination. AWS App Runner only supports specific CPU and memory combinations, e.g.: @@ -1162,7 +1157,7 @@ def _select_aws_cpu_memory_combination( def _convert_scaling_settings_to_aws_format( self, resource_settings: ResourceSettings, - ) -> Tuple[int, int, int]: + ) -> tuple[int, int, int]: """Convert ResourceSettings scaling to AWS App Runner format. Args: @@ -1202,8 +1197,8 @@ def do_provision_deployment( self, deployment: DeploymentResponse, stack: "Stack", - environment: Dict[str, str], - secrets: Dict[str, str], + environment: dict[str, str], + secrets: dict[str, str], timeout: int, ) -> DeploymentOperationalState: """Serve a pipeline as an App Runner service. 
@@ -1309,7 +1304,7 @@ def do_provision_deployment( container_port = ( snapshot.pipeline_configuration.deployment_settings.uvicorn_port ) - image_config: Dict[str, Any] = { + image_config: dict[str, Any] = { "Port": str(container_port), "StartCommand": " ".join(entrypoint + arguments), } @@ -1590,7 +1585,7 @@ def do_get_deployment_state_logs( self, deployment: DeploymentResponse, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> Generator[str, bool, None]: """Get the logs of an App Runner deployment. @@ -1699,7 +1694,7 @@ def do_deprovision_deployment( self, deployment: DeploymentResponse, timeout: int, - ) -> Optional[DeploymentOperationalState]: + ) -> DeploymentOperationalState | None: """Deprovision an App Runner deployment. Args: diff --git a/src/zenml/integrations/aws/flavors/aws_container_registry_flavor.py b/src/zenml/integrations/aws/flavors/aws_container_registry_flavor.py index cd5f277e0ce..6e2cc976b7a 100644 --- a/src/zenml/integrations/aws/flavors/aws_container_registry_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_container_registry_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """AWS container registry flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import field_validator @@ -75,7 +75,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -92,7 +92,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. 
Returns: @@ -101,7 +101,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -119,7 +119,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/container_registry/aws.png" @property - def config_class(self) -> Type[AWSContainerRegistryConfig]: + def config_class(self) -> type[AWSContainerRegistryConfig]: """Config class for this flavor. Returns: @@ -128,7 +128,7 @@ def config_class(self) -> Type[AWSContainerRegistryConfig]: return AWSContainerRegistryConfig @property - def implementation_class(self) -> Type["AWSContainerRegistry"]: + def implementation_class(self) -> type["AWSContainerRegistry"]: """Implementation class. Returns: diff --git a/src/zenml/integrations/aws/flavors/aws_deployer_flavor.py b/src/zenml/integrations/aws/flavors/aws_deployer_flavor.py index 95348f02e1f..862e3864c6e 100644 --- a/src/zenml/integrations/aws/flavors/aws_deployer_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_deployer_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """AWS App Runner deployer flavor.""" -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -36,7 +36,7 @@ class AWSDeployerSettings(BaseDeployerSettings): """Settings for the AWS App Runner deployer.""" - region: Optional[str] = Field( + region: str | None = Field( default=None, description="AWS region where the App Runner service will be deployed. 
" "If not specified, the region will be determined from the authenticated " @@ -98,20 +98,20 @@ class AWSDeployerSettings(BaseDeployerSettings): description="Whether the App Runner service is publicly accessible.", ) - ingress_vpc_configuration: Optional[str] = Field( + ingress_vpc_configuration: str | None = Field( default=None, description="VPC configuration for private App Runner services. " "JSON string with VpcId, VpcEndpointId, and VpcIngressConnectionName.", ) # Environment and configuration - environment_variables: Dict[str, str] = Field( + environment_variables: dict[str, str] = Field( default_factory=dict, description="Environment variables to set in the App Runner service.", ) # Tags - tags: Dict[str, str] = Field( + tags: dict[str, str] = Field( default_factory=dict, description="Tags to apply to the App Runner service.", ) @@ -131,25 +131,25 @@ class AWSDeployerSettings(BaseDeployerSettings): ) # Observability - observability_configuration_arn: Optional[str] = Field( + observability_configuration_arn: str | None = Field( default=None, description="ARN of the observability configuration to associate with " "the App Runner service.", ) # Encryption - encryption_kms_key: Optional[str] = Field( + encryption_kms_key: str | None = Field( default=None, description="KMS key ARN for encrypting App Runner service data.", ) # IAM Roles - instance_role_arn: Optional[str] = Field( + instance_role_arn: str | None = Field( default=None, description="ARN of the IAM role to assign to the App Runner service instances.", ) - access_role_arn: Optional[str] = Field( + access_role_arn: str | None = Field( default=None, description="ARN of the IAM role that App Runner uses to access the " "image repository (ECR). Required for private ECR repositories. 
If not " @@ -158,7 +158,7 @@ class AWSDeployerSettings(BaseDeployerSettings): ) # Traffic allocation for A/B testing and gradual rollouts - traffic_allocation: Dict[str, int] = Field( + traffic_allocation: dict[str, int] = Field( default_factory=lambda: {"LATEST": 100}, description="Traffic allocation between revisions for A/B testing and " "gradual rollouts. Keys can be revision names, tags, or 'LATEST' for " @@ -233,7 +233,7 @@ class AWSDeployerConfig( ): """Configuration for the AWS App Runner deployer.""" - resource_combinations: List[Tuple[float, float]] = Field( + resource_combinations: list[tuple[float, float]] = Field( default=DEFAULT_RESOURCE_COMBINATIONS, description="AWS App Runner supported CPU (vCPU), memory (GB) " "combinations.", @@ -268,7 +268,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -284,7 +284,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -293,7 +293,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -311,7 +311,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/deployer/aws-app-runner.png" @property - def config_class(self) -> Type[AWSDeployerConfig]: + def config_class(self) -> type[AWSDeployerConfig]: """Returns the AWSDeployerConfig config class. 
Returns: @@ -320,7 +320,7 @@ def config_class(self) -> Type[AWSDeployerConfig]: return AWSDeployerConfig @property - def implementation_class(self) -> Type["AWSDeployer"]: + def implementation_class(self) -> type["AWSDeployer"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/aws/flavors/aws_image_builder_flavor.py b/src/zenml/integrations/aws/flavors/aws_image_builder_flavor.py index 4fea01f3a52..5770ed4c864 100644 --- a/src/zenml/integrations/aws/flavors/aws_image_builder_flavor.py +++ b/src/zenml/integrations/aws/flavors/aws_image_builder_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """AWS Code Build image builder flavor.""" -from typing import TYPE_CHECKING, Dict, Optional, Type +from typing import TYPE_CHECKING from zenml.image_builders import BaseImageBuilderConfig, BaseImageBuilderFlavor from zenml.integrations.aws import ( @@ -63,7 +63,7 @@ class AWSImageBuilderConfig(BaseImageBuilderConfig): code_build_project: str build_image: str = DEFAULT_CLOUDBUILD_IMAGE - custom_env_vars: Optional[Dict[str, str]] = None + custom_env_vars: dict[str, str] | None = None compute_type: str = DEFAULT_CLOUDBUILD_COMPUTE_TYPE implicit_container_registry_auth: bool = True @@ -83,7 +83,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -99,7 +99,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. 
Returns: @@ -108,7 +108,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -126,7 +126,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/image_builder/aws.png" @property - def config_class(self) -> Type[BaseImageBuilderConfig]: + def config_class(self) -> type[BaseImageBuilderConfig]: """The config class. Returns: @@ -135,7 +135,7 @@ def config_class(self) -> Type[BaseImageBuilderConfig]: return AWSImageBuilderConfig @property - def implementation_class(self) -> Type["AWSImageBuilder"]: + def implementation_class(self) -> type["AWSImageBuilder"]: """Implementation class. Returns: diff --git a/src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py b/src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py index f7a04930455..31f546108cc 100644 --- a/src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +++ b/src/zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Amazon SageMaker orchestrator flavor.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Any from pydantic import Field, model_validator @@ -47,14 +47,14 @@ class SagemakerOrchestratorSettings(BaseSettings): "production pipelines where you don't want to maintain a connection", ) - instance_type: Optional[str] = Field( + instance_type: str | None = Field( None, description="AWS EC2 instance type for step execution. Must be a valid " "SageMaker-supported instance type. Examples: 'ml.t3.medium' (2 vCPU, 4GB RAM), " "'ml.m5.xlarge' (4 vCPU, 16GB RAM), 'ml.p3.2xlarge' (8 vCPU, 61GB RAM, 1 GPU). 
" "Defaults to ml.m5.xlarge for training steps or ml.t3.medium for processing steps", ) - execution_role: Optional[str] = Field( + execution_role: str | None = Field( None, description="IAM role ARN for SageMaker step execution permissions. Must have " "necessary policies attached (SageMakerFullAccess, S3 access, etc.). " @@ -75,45 +75,45 @@ class SagemakerOrchestratorSettings(BaseSettings): "Examples: 3600 (1 hour), 86400 (24 hours), 259200 (3 days). " "Consider your longest expected step duration", ) - tags: Dict[str, str] = Field( + tags: dict[str, str] = Field( default_factory=dict, description="Tags to apply to the Processor/Estimator assigned to the step. " "Example: {'Environment': 'Production', 'Project': 'MLOps'}", ) - pipeline_tags: Dict[str, str] = Field( + pipeline_tags: dict[str, str] = Field( default_factory=dict, description="Tags to apply to the pipeline via the " "sagemaker.workflow.pipeline.Pipeline.create method. Example: " "{'Environment': 'Production', 'Project': 'MLOps'}", ) - keep_alive_period_in_seconds: Optional[int] = Field( + keep_alive_period_in_seconds: int | None = Field( 300, # 5 minutes description="The time in seconds after which the provisioned instance " "will be terminated if not used. This is only applicable for " "TrainingStep type.", ) - use_training_step: Optional[bool] = Field( + use_training_step: bool | None = Field( None, description="Whether to use the TrainingStep type. It is not possible " "to use TrainingStep type if the `output_data_s3_uri` is set to " "Dict[str, str] or if the `output_data_s3_mode` != 'EndOfJob'.", ) - processor_args: Dict[str, Any] = Field( + processor_args: dict[str, Any] = Field( default_factory=dict, description="Arguments that are directly passed to the SageMaker " "Processor for a specific step, allowing for overriding the default " "settings provided when configuring the component. 
Example: " "{'instance_count': 2, 'base_job_name': 'my-processing-job'}", ) - estimator_args: Dict[str, Any] = Field( + estimator_args: dict[str, Any] = Field( default_factory=dict, description="Arguments that are directly passed to the SageMaker " "Estimator for a specific step, allowing for overriding the default " "settings provided when configuring the component. Example: " "{'train_instance_count': 2, 'train_max_run': 3600}", ) - environment: Dict[str, str] = Field( + environment: dict[str, str] = Field( default_factory=dict, description="Environment variables to pass to the container. " "Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", @@ -124,7 +124,7 @@ class SagemakerOrchestratorSettings(BaseSettings): description="How data is made available to the container. " "Two possible input modes: File, Pipe.", ) - input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field( + input_data_s3_uri: str | dict[str, str] | None = Field( default=None, union_mode="left_to_right", description="S3 URI where data is located if not locally. Example string: " @@ -137,19 +137,19 @@ class SagemakerOrchestratorSettings(BaseSettings): description="How data is uploaded to the S3 bucket. " "Two possible output modes: EndOfJob, Continuous.", ) - output_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field( + output_data_s3_uri: str | dict[str, str] | None = Field( default=None, union_mode="left_to_right", description="S3 URI where data is uploaded after or during processing run. " "Example string: 's3://my-bucket/my-data/output'. Example dict: " "{'output_one': 's3://bucket/out1', 'output_two': 's3://bucket/out2'}", ) - processor_role: Optional[str] = Field( + processor_role: str | None = Field( None, description="DEPRECATED: use `execution_role` instead. " "The IAM role to use for the step execution.", ) - processor_tags: Optional[Dict[str, str]] = Field( + processor_tags: dict[str, str] | None = Field( None, description="DEPRECATED: use `tags` instead. 
" "Tags to apply to the Processor assigned to the step.", @@ -159,7 +159,7 @@ class SagemakerOrchestratorSettings(BaseSettings): ) @model_validator(mode="before") - def validate_model(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def validate_model(cls, data: dict[str, Any]) -> dict[str, Any]: """Check if model is configured correctly. Args: @@ -212,39 +212,39 @@ class SagemakerOrchestratorConfig( execution_role: str = Field( ..., description="The IAM role ARN to use for the pipeline." ) - scheduler_role: Optional[str] = Field( + scheduler_role: str | None = Field( None, description="The ARN of the IAM role that will be assumed by " "the EventBridge service to launch Sagemaker pipelines. " "Required for scheduled pipelines.", ) - aws_access_key_id: Optional[str] = SecretField( + aws_access_key_id: str | None = SecretField( default=None, description="The AWS access key ID to use to authenticate to AWS. " "If not provided, the value from the default AWS config will be used.", ) - aws_secret_access_key: Optional[str] = SecretField( + aws_secret_access_key: str | None = SecretField( default=None, description="The AWS secret access key to use to authenticate to AWS. " "If not provided, the value from the default AWS config will be used.", ) - aws_profile: Optional[str] = Field( + aws_profile: str | None = Field( None, description="The AWS profile to use for authentication if not using " "service connectors or explicit credentials. If not provided, the " "default profile will be used.", ) - aws_auth_role_arn: Optional[str] = Field( + aws_auth_role_arn: str | None = Field( None, description="The ARN of an intermediate IAM role to assume when " "authenticating to AWS.", ) - region: Optional[str] = Field( + region: str | None = Field( None, description="The AWS region where the processing job will be run. 
" "If not provided, the value from the default AWS config will be used.", ) - bucket: Optional[str] = Field( + bucket: str | None = Field( None, description="Name of the S3 bucket to use for storing artifacts " "from the job run. If not provided, a default bucket will be created " @@ -298,7 +298,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -311,7 +311,7 @@ def service_connector_requirements( return ServiceConnectorRequirements(resource_type=AWS_RESOURCE_TYPE) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -320,7 +320,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -338,7 +338,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/sagemaker.png" @property - def config_class(self) -> Type[SagemakerOrchestratorConfig]: + def config_class(self) -> type[SagemakerOrchestratorConfig]: """Returns SagemakerOrchestratorConfig config class. Returns: @@ -347,7 +347,7 @@ def config_class(self) -> Type[SagemakerOrchestratorConfig]: return SagemakerOrchestratorConfig @property - def implementation_class(self) -> Type["SagemakerOrchestrator"]: + def implementation_class(self) -> type["SagemakerOrchestrator"]: """Implementation class. 
Returns: diff --git a/src/zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py b/src/zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py index 8c43737b67e..649c76a9737 100644 --- a/src/zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py +++ b/src/zenml/integrations/aws/flavors/sagemaker_step_operator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Amazon SageMaker step operator flavor.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Any from pydantic import Field @@ -36,31 +36,31 @@ class SagemakerStepOperatorSettings(BaseSettings): """Settings for the Sagemaker step operator.""" - instance_type: Optional[str] = Field( + instance_type: str | None = Field( None, description="DEPRECATED: The instance type to use for the step execution. " "Use estimator_args instead. Example: 'ml.m5.xlarge'", ) - experiment_name: Optional[str] = Field( + experiment_name: str | None = Field( None, description="The name for the experiment to which the job will be associated. " "If not provided, the job runs would be independent. Example: 'my-training-experiment'", ) - input_data_s3_uri: Optional[Union[str, Dict[str, str]]] = Field( + input_data_s3_uri: str | dict[str, str] | None = Field( default=None, union_mode="left_to_right", description="S3 URI where training data is located if not locally. " "Example string: 's3://my-bucket/my-data/train'. Example dict: " "{'training': 's3://bucket/train', 'validation': 's3://bucket/val'}", ) - estimator_args: Dict[str, Any] = Field( + estimator_args: dict[str, Any] = Field( default_factory=dict, description="Arguments that are directly passed to the SageMaker Estimator. " "See SageMaker documentation for available arguments and instance types. 
Example: " "{'instance_type': 'ml.m5.xlarge', 'instance_count': 1, " "'train_max_run': 3600, 'input_mode': 'File'}", ) - environment: Dict[str, str] = Field( + environment: dict[str, str] = Field( default_factory=dict, description="Environment variables to pass to the container during execution. " "Example: {'LOG_LEVEL': 'INFO', 'DEBUG_MODE': 'False'}", @@ -82,7 +82,7 @@ class SagemakerStepOperatorConfig( "running in SageMaker. This role must have the necessary permissions " "to access SageMaker and S3 resources.", ) - bucket: Optional[str] = Field( + bucket: str | None = Field( None, description="Name of the S3 bucket to use for storing artifacts from the job run. " "If not provided, a default bucket will be created based on the format: " @@ -118,7 +118,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -131,7 +131,7 @@ def service_connector_requirements( return ServiceConnectorRequirements(resource_type=AWS_RESOURCE_TYPE) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -140,7 +140,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -158,7 +158,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/sagemaker.png" @property - def config_class(self) -> Type[SagemakerStepOperatorConfig]: + def config_class(self) -> type[SagemakerStepOperatorConfig]: """Returns SagemakerStepOperatorConfig config class. 
Returns: @@ -167,7 +167,7 @@ def config_class(self) -> Type[SagemakerStepOperatorConfig]: return SagemakerStepOperatorConfig @property - def implementation_class(self) -> Type["SagemakerStepOperator"]: + def implementation_class(self) -> type["SagemakerStepOperator"]: """Implementation class. Returns: diff --git a/src/zenml/integrations/aws/image_builders/aws_image_builder.py b/src/zenml/integrations/aws/image_builders/aws_image_builder.py index 73319bb819c..18756d0928d 100644 --- a/src/zenml/integrations/aws/image_builders/aws_image_builder.py +++ b/src/zenml/integrations/aws/image_builders/aws_image_builder.py @@ -14,7 +14,7 @@ """AWS Code Build image builder implementation.""" import time -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Optional, cast from urllib.parse import urlparse from uuid import uuid4 @@ -41,7 +41,7 @@ class AWSImageBuilder(BaseImageBuilder): """AWS Code Build image builder implementation.""" - _code_build_client: Optional[Any] = None + _code_build_client: Any | None = None @property def config(self) -> AWSImageBuilderConfig: @@ -73,7 +73,7 @@ def validator(self) -> Optional["StackValidator"]: Stack validator. """ - def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + def _validate_remote_components(stack: "Stack") -> tuple[bool, str]: if stack.artifact_store.flavor != "s3": return False, ( "The AWS Image Builder requires an S3 Artifact Store to " @@ -126,7 +126,7 @@ def build( self, image_name: str, build_context: "BuildContext", - docker_build_options: Dict[str, Any], + docker_build_options: dict[str, Any], container_registry: Optional["BaseContainerRegistry"] = None, ) -> str: """Builds and pushes a Docker image. 
@@ -169,7 +169,7 @@ def build( # Pass authentication credentials as environment variables, if # the container registry has credentials and if implicit authentication # is disabled - environment_variables_override: Dict[str, str] = {} + environment_variables_override: dict[str, str] = {} pre_build_commands = [] if not self.config.implicit_container_registry_auth: credentials = container_registry.credentials diff --git a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py index 8ea7cd0d0d5..be5c9445ec4 100644 --- a/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +++ b/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py @@ -18,11 +18,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Tuple, - Type, Union, cast, ) @@ -90,7 +86,7 @@ def dissect_schedule_arn( schedule_arn: str, -) -> Tuple[Optional[str], Optional[str]]: +) -> tuple[str | None, str | None]: """Extracts the region and the name from an EventBridge schedule ARN. Args: @@ -120,7 +116,7 @@ def dissect_schedule_arn( def dissect_pipeline_execution_arn( pipeline_execution_arn: str, -) -> Tuple[Optional[str], Optional[str], Optional[str]]: +) -> tuple[str | None, str | None, str | None]: """Extract region name, pipeline name, and execution id from the ARN. Args: @@ -159,7 +155,7 @@ def config(self) -> SagemakerOrchestratorConfig: return cast(SagemakerOrchestratorConfig, self._config) @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates the stack. 
In the remote case, checks that the stack contains a container registry, @@ -171,7 +167,7 @@ def validator(self) -> Optional[StackValidator]: def _validate_remote_components( stack: "Stack", - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: for component in stack.components.values(): if not component.config.is_local: continue @@ -216,7 +212,7 @@ def get_orchestrator_run_id(self) -> str: ) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Sagemaker orchestrator. Returns: @@ -275,10 +271,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. 
This method should only submit the pipeline and not wait for it to @@ -437,10 +433,10 @@ def submit_pipeline( ) # Construct S3 inputs to container for step - training_inputs: Optional[ - Union[TrainingInput, Dict[str, TrainingInput]] - ] = None - processing_inputs: Optional[List[ProcessingInput]] = None + training_inputs: None | ( + TrainingInput | dict[str, TrainingInput] + ) = None + processing_inputs: list[ProcessingInput] | None = None if step_settings.input_data_s3_uri is None: pass @@ -514,7 +510,7 @@ def submit_pipeline( ) ) - final_step_environment: Dict[str, Union[str, PipelineVariable]] = { + final_step_environment: dict[str, str | PipelineVariable] = { key: str(value) for key, value in step_environment.items() } final_step_environment[ENV_ZENML_SAGEMAKER_RUN_ID] = ( @@ -535,7 +531,7 @@ def submit_pipeline( sagemaker_step = TrainingStep( name=step_name, depends_on=cast( - Optional[List[Union[str, Step, StepCollection]]], + Optional[list[Union[str, Step, StepCollection]]], step.spec.upstream_steps, ), inputs=training_inputs, @@ -545,7 +541,7 @@ def submit_pipeline( # Create Processor and ProcessingStep processor = Processor( entrypoint=cast( - Optional[List[Union[str, PipelineVariable]]], + Optional[list[Union[str, PipelineVariable]]], entrypoint, ), env=final_step_environment, @@ -556,7 +552,7 @@ def submit_pipeline( name=step_name, processor=processor, depends_on=cast( - Optional[List[Union[str, Step, StepCollection]]], + Optional[list[Union[str, Step, StepCollection]]], step.spec.upstream_steps, ), inputs=processing_inputs, @@ -787,7 +783,7 @@ def _wait_for_completion() -> None: def get_pipeline_run_metadata( self, run_id: UUID - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get general component-specific metadata for a pipeline run. 
Args: @@ -810,8 +806,8 @@ def get_pipeline_run_metadata( def fetch_status( self, run: "PipelineRunResponse", include_steps: bool = False - ) -> Tuple[ - Optional[ExecutionStatus], Optional[Dict[str, ExecutionStatus]] + ) -> tuple[ + ExecutionStatus | None, dict[str, ExecutionStatus] | None ]: """Refreshes the status of a specific pipeline run. @@ -881,7 +877,7 @@ def compute_metadata( self, execution_arn: str, settings: SagemakerOrchestratorSettings, - ) -> Dict[str, MetadataType]: + ) -> dict[str, MetadataType]: """Generate run metadata based on the generated Sagemaker Execution. Args: @@ -892,7 +888,7 @@ def compute_metadata( A dictionary of metadata related to the pipeline run. """ # Orchestrator Run ID - metadata: Dict[str, MetadataType] = { + metadata: dict[str, MetadataType] = { "pipeline_execution_arn": execution_arn, METADATA_ORCHESTRATOR_RUN_ID: execution_arn, } @@ -914,7 +910,7 @@ def compute_metadata( def _compute_orchestrator_url( self, execution_arn: Any, - ) -> Optional[str]: + ) -> str | None: """Generate the Orchestrator Dashboard URL upon pipeline execution. Args: @@ -953,7 +949,7 @@ def _compute_orchestrator_url( def _compute_orchestrator_logs_url( execution_arn: Any, settings: SagemakerOrchestratorSettings, - ) -> Optional[str]: + ) -> str | None: """Generate the CloudWatch URL upon pipeline execution. Args: @@ -989,7 +985,7 @@ def _compute_orchestrator_logs_url( @staticmethod def generate_schedule_metadata( schedule_arn: str, - ) -> Dict[str, MetadataType]: + ) -> dict[str, MetadataType]: """Attaches metadata to the ZenML Schedules. 
Args: diff --git a/src/zenml/integrations/aws/service_connectors/aws_service_connector.py b/src/zenml/integrations/aws/service_connectors/aws_service_connector.py index 793e782f185..4ce04438acc 100644 --- a/src/zenml/integrations/aws/service_connectors/aws_service_connector.py +++ b/src/zenml/integrations/aws/service_connectors/aws_service_connector.py @@ -30,7 +30,7 @@ import json import os import re -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Any, cast import boto3 from aws_profile_manager import Common # type: ignore[import-untyped] @@ -102,7 +102,7 @@ class AWSBaseConfig(AuthenticationConfig): region: str = Field( title="AWS Region", ) - endpoint_url: Optional[str] = Field( + endpoint_url: str | None = Field( default=None, title="AWS Endpoint URL", ) @@ -111,13 +111,13 @@ class AWSBaseConfig(AuthenticationConfig): class AWSSessionPolicy(AuthenticationConfig): """AWS session IAM policy configuration.""" - policy_arns: Optional[List[str]] = Field( + policy_arns: list[str] | None = Field( default=None, title="ARNs of the IAM managed policies that you want to use as a " "managed session policy. 
The policies must exist in the same account " "as the IAM user that is requesting temporary credentials.", ) - policy: Optional[str] = Field( + policy: str | None = Field( default=None, title="An IAM policy in JSON format that you want to use as an inline " "session policy", @@ -127,11 +127,11 @@ class AWSSessionPolicy(AuthenticationConfig): class AWSImplicitConfig(AWSBaseConfig, AWSSessionPolicy): """AWS implicit configuration.""" - profile_name: Optional[str] = Field( + profile_name: str | None = Field( default=None, title="AWS Profile Name", ) - role_arn: Optional[str] = Field( + role_arn: str | None = Field( default=None, title="Optional AWS IAM Role ARN to assume", ) @@ -645,10 +645,10 @@ class AWSServiceConnector(ServiceConnector): config: AWSBaseConfig - _account_id: Optional[str] = None - _session_cache: Dict[ - Tuple[str, Optional[str], Optional[str]], - Tuple[boto3.Session, Optional[datetime.datetime]], + _account_id: str | None = None + _session_cache: dict[ + tuple[str, str | None, str | None], + tuple[boto3.Session, datetime.datetime | None], ] = {} @classmethod @@ -689,9 +689,9 @@ def account_id(self) -> str: def get_boto3_session( self, auth_method: str, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> Tuple[boto3.Session, Optional[datetime.datetime]]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> tuple[boto3.Session, datetime.datetime | None]: """Get a boto3 session for the specified resource. Args: @@ -765,9 +765,9 @@ def get_ecr_client(self) -> BaseClient: def _get_iam_policy( self, region_id: str, - resource_type: Optional[str], - resource_id: Optional[str] = None, - ) -> Optional[str]: + resource_type: str | None, + resource_id: str | None = None, + ) -> str | None: """Get the IAM inline policy to use for the specified resource. 
Args: @@ -874,9 +874,9 @@ def _get_iam_policy( def _authenticate( self, auth_method: str, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> Tuple[boto3.Session, Optional[datetime.datetime]]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> tuple[boto3.Session, datetime.datetime | None]: """Authenticate to AWS and return a boto3 session. Args: @@ -894,7 +894,7 @@ def _authenticate( NotImplementedError: If the authentication method is not supported. """ cfg = self.config - policy_kwargs: Dict[str, Any] = {} + policy_kwargs: dict[str, Any] = {} if auth_method == AWSAuthenticationMethods.IMPLICIT: self._check_implicit_auth_method_allowed() @@ -1166,7 +1166,7 @@ def _parse_s3_resource_id(self, resource_id: str) -> str: # - the S3 bucket name # # We need to extract the bucket name from the provided resource ID - bucket_name: Optional[str] = None + bucket_name: str | None = None if re.match( r"^arn:aws:s3:::[a-z0-9][a-z0-9\-\.]{1,61}[a-z0-9](/.*)*$", resource_id, @@ -1220,7 +1220,7 @@ def _parse_ecr_resource_id( # We need to extract the region ID and registry ID from # the provided resource ID config_region_id = self.config.region - region_id: Optional[str] = None + region_id: str | None = None if re.match( r"^arn:aws:ecr:[a-z0-9-]+:\d{12}:repository(/.+)*$", resource_id, @@ -1276,8 +1276,8 @@ def _parse_eks_resource_id(self, resource_id: str) -> str: # We need to extract the cluster name and region ID from the # provided resource ID config_region_id = self.config.region - cluster_name: Optional[str] = None - region_id: Optional[str] = None + cluster_name: str | None = None + region_id: str | None = None if re.match( r"^arn:aws:eks:[a-z0-9-]+:\d{12}:cluster/[0-9A-Za-z][A-Za-z0-9\-_]*$", resource_id, @@ -1440,7 +1440,7 @@ def _connect_to_resource( def _configure_local_client( self, - profile_name: Optional[str] = None, + profile_name: str | None = None, **kwargs: Any, ) -> None: """Configure a local client to 
authenticate and connect to a resource. @@ -1526,12 +1526,12 @@ def _configure_local_client( @classmethod def _auto_configure( cls, - auth_method: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - region_name: Optional[str] = None, - profile_name: Optional[str] = None, - role_arn: Optional[str] = None, + auth_method: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + region_name: str | None = None, + profile_name: str | None = None, + role_arn: str | None = None, **kwargs: Any, ) -> "AWSServiceConnector": """Auto-configure the connector. @@ -1571,8 +1571,8 @@ def _auto_configure( the environment. """ auth_config: AWSBaseConfig - expiration_seconds: Optional[int] = None - expires_at: Optional[datetime.datetime] = None + expiration_seconds: int | None = None + expires_at: datetime.datetime | None = None if auth_method == AWSAuthenticationMethods.IMPLICIT: if region_name is None: raise ValueError( @@ -1776,9 +1776,9 @@ def _auto_configure( def _verify( self, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> List[str]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> list[str]: """Verify and list all the resources that the connector can access. 
Args: @@ -1899,7 +1899,7 @@ def _verify( logger.error(msg) raise AuthorizationException(msg) from e - return cast(List[str], clusters["clusters"]) + return cast(list[str], clusters["clusters"]) else: # Check if the specified EKS cluster exists cluster_name = self._parse_eks_resource_id(resource_id) diff --git a/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py b/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py index 8adee8ba810..658288628ce 100644 --- a/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py +++ b/src/zenml/integrations/aws/step_operators/sagemaker_step_operator.py @@ -15,12 +15,6 @@ from typing import ( TYPE_CHECKING, - Dict, - List, - Optional, - Tuple, - Type, - Union, cast, ) @@ -77,7 +71,7 @@ def config(self) -> SagemakerStepOperatorConfig: return cast(SagemakerStepOperatorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the SageMaker step operator. Returns: @@ -88,7 +82,7 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: @property def entrypoint_config_class( self, - ) -> Type[StepOperatorEntrypointConfiguration]: + ) -> type[StepOperatorEntrypointConfiguration]: """Returns the entrypoint configuration class for this step operator. Returns: @@ -97,7 +91,7 @@ def entrypoint_config_class( return SagemakerEntrypointConfiguration @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates the stack. Returns: @@ -105,7 +99,7 @@ def validator(self) -> Optional[StackValidator]: registry and a remote artifact store. 
""" - def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + def _validate_remote_components(stack: "Stack") -> tuple[bool, str]: if stack.artifact_store.config.is_local: return False, ( "The SageMaker step operator runs code remotely and " @@ -141,7 +135,7 @@ def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: def get_docker_builds( self, snapshot: "PipelineSnapshotBase" - ) -> List["BuildConfiguration"]: + ) -> list["BuildConfiguration"]: """Gets the Docker builds required for the component. Args: @@ -166,8 +160,8 @@ def get_docker_builds( def launch( self, info: "StepRunInfo", - entrypoint_command: List[str], - environment: Dict[str, str], + entrypoint_command: list[str], + environment: dict[str, str], ) -> None: """Launches a step on SageMaker. @@ -257,7 +251,7 @@ def launch( ) # Construct training input object, if necessary - inputs: Optional[Union[TrainingInput, Dict[str, TrainingInput]]] = None + inputs: TrainingInput | dict[str, TrainingInput] | None = None if isinstance(settings.input_data_s3_uri, str): inputs = TrainingInput(s3_data=settings.input_data_s3_uri) diff --git a/src/zenml/integrations/azure/__init__.py b/src/zenml/integrations/azure/__init__.py index 933344ab6a2..fb687c9bd34 100644 --- a/src/zenml/integrations/azure/__init__.py +++ b/src/zenml/integrations/azure/__init__.py @@ -19,7 +19,6 @@ The Azure Step Operator integration submodule provides a way to run ZenML steps in AzureML. """ -from typing import List, Type from zenml.integrations.constants import AZURE from zenml.integrations.integration import Integration @@ -65,7 +64,7 @@ def activate(cls) -> None: from zenml.integrations.azure import service_connectors # noqa @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declares the flavors for the integration. 
Returns: diff --git a/src/zenml/integrations/azure/artifact_stores/azure_artifact_store.py b/src/zenml/integrations/azure/artifact_stores/azure_artifact_store.py index cbbaa144125..6cdb1469941 100644 --- a/src/zenml/integrations/azure/artifact_stores/azure_artifact_store.py +++ b/src/zenml/integrations/azure/artifact_stores/azure_artifact_store.py @@ -15,15 +15,10 @@ from typing import ( Any, - Callable, - Dict, - Iterable, - List, - Optional, - Tuple, Union, cast, ) +from collections.abc import Callable, Iterable import adlfs @@ -41,7 +36,7 @@ class AzureArtifactStore(BaseArtifactStore, AuthenticationMixin): """Artifact Store for Microsoft Azure based artifacts.""" - _filesystem: Optional[adlfs.AzureBlobFileSystem] = None + _filesystem: adlfs.AzureBlobFileSystem | None = None @property def config(self) -> AzureArtifactStoreConfig: @@ -52,7 +47,7 @@ def config(self) -> AzureArtifactStoreConfig: """ return cast(AzureArtifactStoreConfig, self._config) - def get_credentials(self) -> Optional[AzureSecretSchema]: + def get_credentials(self) -> AzureSecretSchema | None: """Returns the credentials for the Azure Artifact Store if configured. Returns: @@ -122,7 +117,7 @@ def filesystem(self) -> adlfs.AzureBlobFileSystem: ) return self._filesystem - def _split_path(self, path: PathType) -> Tuple[str, str]: + def _split_path(self, path: PathType) -> tuple[str, str]: """Splits a path into the filesystem prefix and remainder. Example: @@ -197,7 +192,7 @@ def exists(self, path: PathType) -> bool: """ return self.filesystem.exists(path=path) # type: ignore[no-any-return] - def glob(self, pattern: PathType) -> List[PathType]: + def glob(self, pattern: PathType) -> list[PathType]: """Return all paths that match the given glob pattern. 
The glob pattern may include: @@ -229,7 +224,7 @@ def isdir(self, path: PathType) -> bool: """ return self.filesystem.isdir(path=path) # type: ignore[no-any-return] - def listdir(self, path: PathType) -> List[PathType]: + def listdir(self, path: PathType) -> list[PathType]: """Return a list of files in a directory. Args: @@ -240,7 +235,7 @@ def listdir(self, path: PathType) -> List[PathType]: """ _, path = self._split_path(path) - def _extract_basename(file_dict: Dict[str, Any]) -> str: + def _extract_basename(file_dict: dict[str, Any]) -> str: """Extracts the basename from a dictionary returned by the Azure filesystem. Args: @@ -318,7 +313,7 @@ def rmtree(self, path: PathType) -> None: """ self.filesystem.delete(path=path, recursive=True) - def stat(self, path: PathType) -> Dict[str, Any]: + def stat(self, path: PathType) -> dict[str, Any]: """Return stat info for the given path. Args: @@ -344,8 +339,8 @@ def walk( self, top: PathType, topdown: bool = True, - onerror: Optional[Callable[..., None]] = None, - ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]: + onerror: Callable[..., None] | None = None, + ) -> Iterable[tuple[PathType, list[PathType], list[PathType]]]: """Return an iterator that walks the contents of the given directory. Args: diff --git a/src/zenml/integrations/azure/azureml_utils.py b/src/zenml/integrations/azure/azureml_utils.py index 73daf71e28d..aa4556411d1 100644 --- a/src/zenml/integrations/azure/azureml_utils.py +++ b/src/zenml/integrations/azure/azureml_utils.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """AzureML definitions.""" -from typing import Optional from azure.ai.ml import MLClient from azure.ai.ml.entities import Compute @@ -58,7 +57,7 @@ def create_or_get_compute( client: MLClient, settings: AzureMLComputeSettings, default_compute_name: str, -) -> Optional[str]: +) -> str | None: """Creates or fetches the compute target if defined in the settings. 
Args: diff --git a/src/zenml/integrations/azure/flavors/azure_artifact_store_flavor.py b/src/zenml/integrations/azure/flavors/azure_artifact_store_flavor.py index 4bf209f1441..802c18105de 100644 --- a/src/zenml/integrations/azure/flavors/azure_artifact_store_flavor.py +++ b/src/zenml/integrations/azure/flavors/azure_artifact_store_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Azure artifact store flavor.""" -from typing import TYPE_CHECKING, ClassVar, Optional, Set, Type +from typing import TYPE_CHECKING, ClassVar from zenml.artifact_stores import ( BaseArtifactStoreConfig, @@ -36,7 +36,7 @@ class AzureArtifactStoreConfig( ): """Configuration class for Azure Artifact Store.""" - SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"abfs://", "az://"} + SUPPORTED_SCHEMES: ClassVar[set[str]] = {"abfs://", "az://"} class AzureArtifactStoreFlavor(BaseArtifactStoreFlavor): @@ -54,7 +54,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -71,7 +71,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -80,7 +80,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -98,7 +98,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/artifact_store/azure.png" @property - def config_class(self) -> Type[AzureArtifactStoreConfig]: + def config_class(self) -> type[AzureArtifactStoreConfig]: """Returns AzureArtifactStoreConfig config class. 
Returns: @@ -107,7 +107,7 @@ def config_class(self) -> Type[AzureArtifactStoreConfig]: return AzureArtifactStoreConfig @property - def implementation_class(self) -> Type["AzureArtifactStore"]: + def implementation_class(self) -> type["AzureArtifactStore"]: """Implementation class. Returns: diff --git a/src/zenml/integrations/azure/flavors/azureml.py b/src/zenml/integrations/azure/flavors/azureml.py index 9db7fcc223b..cc1a209baa4 100644 --- a/src/zenml/integrations/azure/flavors/azureml.py +++ b/src/zenml/integrations/azure/flavors/azureml.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """AzureML definitions.""" -from typing import Optional from pydantic import model_validator @@ -70,18 +69,18 @@ class AzureMLComputeSettings(BaseSettings): mode: AzureMLComputeTypes = AzureMLComputeTypes.SERVERLESS # Common Configuration for Compute Instances and Clusters - compute_name: Optional[str] = None - size: Optional[str] = None + compute_name: str | None = None + size: str | None = None # Additional configuration for a Compute Instance - idle_time_before_shutdown_minutes: Optional[int] = None + idle_time_before_shutdown_minutes: int | None = None # Additional configuration for a Compute Cluster - idle_time_before_scaledown_down: Optional[int] = None - location: Optional[str] = None - min_instances: Optional[int] = None - max_instances: Optional[int] = None - tier: Optional[str] = None + idle_time_before_scaledown_down: int | None = None + location: str | None = None + min_instances: int | None = None + max_instances: int | None = None + tier: str | None = None @model_validator(mode="after") def azureml_settings_validator(self) -> "AzureMLComputeSettings": diff --git a/src/zenml/integrations/azure/flavors/azureml_orchestrator_flavor.py b/src/zenml/integrations/azure/flavors/azureml_orchestrator_flavor.py index 3caadbc397c..7d1e45f937a 100644 --- a/src/zenml/integrations/azure/flavors/azureml_orchestrator_flavor.py +++ 
b/src/zenml/integrations/azure/flavors/azureml_orchestrator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of the AzureML Orchestrator flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -106,7 +106,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -119,7 +119,7 @@ def service_connector_requirements( return ServiceConnectorRequirements(resource_type=AZURE_RESOURCE_TYPE) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A URL to point at docs explaining this flavor. Returns: @@ -128,7 +128,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A URL to point at SDK docs explaining this flavor. Returns: @@ -146,7 +146,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/azureml.png" @property - def config_class(self) -> Type[AzureMLOrchestratorConfig]: + def config_class(self) -> type[AzureMLOrchestratorConfig]: """Returns AzureMLOrchestratorConfig config class. Returns: @@ -155,7 +155,7 @@ def config_class(self) -> Type[AzureMLOrchestratorConfig]: return AzureMLOrchestratorConfig @property - def implementation_class(self) -> Type["AzureMLOrchestrator"]: + def implementation_class(self) -> type["AzureMLOrchestrator"]: """Implementation class. 
Returns: diff --git a/src/zenml/integrations/azure/flavors/azureml_step_operator_flavor.py b/src/zenml/integrations/azure/flavors/azureml_step_operator_flavor.py index 49550212f96..7d8d631d435 100644 --- a/src/zenml/integrations/azure/flavors/azureml_step_operator_flavor.py +++ b/src/zenml/integrations/azure/flavors/azureml_step_operator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """AzureML step operator flavor.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type +from typing import TYPE_CHECKING, Any from pydantic import Field, model_validator @@ -45,7 +45,7 @@ class AzureMLStepOperatorSettings(AzureMLComputeSettings): Deprecated in favor of `compute_name`. """ - compute_target_name: Optional[str] = Field( + compute_target_name: str | None = Field( default=None, description="Name of the configured ComputeTarget. Deprecated in favor " "of `compute_name`.", @@ -54,7 +54,7 @@ class AzureMLStepOperatorSettings(AzureMLComputeSettings): @model_validator(mode="before") @classmethod @before_validator_handler - def _migrate_compute_name(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def _migrate_compute_name(cls, data: dict[str, Any]) -> dict[str, Any]: """Backward compatibility for compute_target_name. 
Args: @@ -102,9 +102,9 @@ class AzureMLStepOperatorConfig( # Service principal authentication # https://docs.microsoft.com/en-us/azure/machine-learning/how-to-setup-authentication#configure-a-service-principal - tenant_id: Optional[str] = SecretField(default=None) - service_principal_id: Optional[str] = SecretField(default=None) - service_principal_password: Optional[str] = SecretField(default=None) + tenant_id: str | None = SecretField(default=None) + service_principal_id: str | None = SecretField(default=None) + service_principal_password: str | None = SecretField(default=None) @property def is_remote(self) -> bool: @@ -135,7 +135,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -148,7 +148,7 @@ def service_connector_requirements( return ServiceConnectorRequirements(resource_type=AZURE_RESOURCE_TYPE) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -157,7 +157,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -175,7 +175,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/azureml.png" @property - def config_class(self) -> Type[AzureMLStepOperatorConfig]: + def config_class(self) -> type[AzureMLStepOperatorConfig]: """Returns AzureMLStepOperatorConfig config class. 
Returns: @@ -184,7 +184,7 @@ def config_class(self) -> Type[AzureMLStepOperatorConfig]: return AzureMLStepOperatorConfig @property - def implementation_class(self) -> Type["AzureMLStepOperator"]: + def implementation_class(self) -> type["AzureMLStepOperator"]: """Implementation class. Returns: diff --git a/src/zenml/integrations/azure/orchestrators/azureml_orchestrator.py b/src/zenml/integrations/azure/orchestrators/azureml_orchestrator.py index 49f15d82489..898f98df733 100644 --- a/src/zenml/integrations/azure/orchestrators/azureml_orchestrator.py +++ b/src/zenml/integrations/azure/orchestrators/azureml_orchestrator.py @@ -18,12 +18,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Tuple, - Type, - Union, cast, ) from uuid import UUID @@ -89,7 +84,7 @@ def config(self) -> AzureMLOrchestratorConfig: return cast(AzureMLOrchestratorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the AzureML orchestrator. Returns: @@ -98,7 +93,7 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: return AzureMLOrchestratorSettings @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates the stack. In the remote case, checks that the stack contains a container registry, @@ -110,7 +105,7 @@ def validator(self) -> Optional[StackValidator]: def _validate_remote_components( stack: "Stack", - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: for component in stack.components.values(): if not component.config.is_local: continue @@ -160,8 +155,8 @@ def _create_command_component( step_name: str, env_name: str, image: str, - command: List[str], - arguments: List[str], + command: list[str], + arguments: list[str], ) -> CommandComponent: """Creates a CommandComponent to run on AzureML Pipelines. 
@@ -201,10 +196,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. This method should only submit the pipeline and not wait for it to @@ -294,7 +289,7 @@ def azureml_pipeline() -> None: """Create an AzureML pipeline.""" # Here we have to track the inputs and outputs so that we can bind # the components to each other to execute them in a specific order. - component_outputs: Dict[str, Any] = {} + component_outputs: dict[str, Any] = {} for component_name, component in components.items(): # Inputs component_inputs = {} @@ -321,9 +316,9 @@ def azureml_pipeline() -> None: # Scheduling if schedule := snapshot.schedule: try: - schedule_trigger: Optional[ - Union[CronTrigger, RecurrenceTrigger] - ] = None + schedule_trigger: None | ( + CronTrigger | RecurrenceTrigger + ) = None start_time = None if schedule.start_time is not None: @@ -423,7 +418,7 @@ def _wait_for_completion() -> None: def get_pipeline_run_metadata( self, run_id: UUID - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get general component-specific metadata for a pipeline run. Args: @@ -460,8 +455,8 @@ def get_pipeline_run_metadata( def fetch_status( self, run: "PipelineRunResponse", include_steps: bool = False - ) -> Tuple[ - Optional[ExecutionStatus], Optional[Dict[str, ExecutionStatus]] + ) -> tuple[ + ExecutionStatus | None, dict[str, ExecutionStatus] | None ]: """Refreshes the status of a specific pipeline run. 
@@ -542,7 +537,7 @@ def fetch_status( # AzureML doesn't support step-level status fetching yet return pipeline_status, None - def compute_metadata(self, job: Any) -> Dict[str, MetadataType]: + def compute_metadata(self, job: Any) -> dict[str, MetadataType]: """Generate run metadata based on the generated AzureML PipelineJob. Args: @@ -552,7 +547,7 @@ def compute_metadata(self, job: Any) -> Dict[str, MetadataType]: A dictionary of metadata related to the pipeline run. """ # Metadata - metadata: Dict[str, MetadataType] = {} + metadata: dict[str, MetadataType] = {} # Orchestrator Run ID if run_id := self._compute_orchestrator_run_id(job): @@ -565,7 +560,7 @@ def compute_metadata(self, job: Any) -> Dict[str, MetadataType]: return metadata @staticmethod - def _compute_orchestrator_url(job: Any) -> Optional[str]: + def _compute_orchestrator_url(job: Any) -> str | None: """Generate the Orchestrator Dashboard URL upon pipeline execution. Args: @@ -587,7 +582,7 @@ def _compute_orchestrator_url(job: Any) -> Optional[str]: return None @staticmethod - def _compute_orchestrator_run_id(job: Any) -> Optional[str]: + def _compute_orchestrator_run_id(job: Any) -> str | None: """Generate the Orchestrator Dashboard URL upon pipeline execution. 
Args: diff --git a/src/zenml/integrations/azure/orchestrators/azureml_orchestrator_entrypoint_config.py b/src/zenml/integrations/azure/orchestrators/azureml_orchestrator_entrypoint_config.py index a2006cef723..a61ecfd153c 100644 --- a/src/zenml/integrations/azure/orchestrators/azureml_orchestrator_entrypoint_config.py +++ b/src/zenml/integrations/azure/orchestrators/azureml_orchestrator_entrypoint_config.py @@ -15,7 +15,7 @@ import json import os -from typing import Any, List, Set +from typing import Any from zenml.entrypoints.step_entrypoint_configuration import ( StepEntrypointConfiguration, @@ -30,7 +30,7 @@ class AzureMLEntrypointConfiguration(StepEntrypointConfiguration): """Entrypoint configuration for ZenML AzureML pipeline steps.""" @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> set[str]: """Gets all options required for running with this configuration. Returns: @@ -40,7 +40,7 @@ def get_entrypoint_options(cls) -> Set[str]: return super().get_entrypoint_options() | {ZENML_ENV_VARIABLES} @classmethod - def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]: + def get_entrypoint_arguments(cls, **kwargs: Any) -> list[str]: """Gets all arguments that the entrypoint command should be called with. 
Args: diff --git a/src/zenml/integrations/azure/service_connectors/azure_service_connector.py b/src/zenml/integrations/azure/service_connectors/azure_service_connector.py index f4c1865f6fb..e12bfa16082 100644 --- a/src/zenml/integrations/azure/service_connectors/azure_service_connector.py +++ b/src/zenml/integrations/azure/service_connectors/azure_service_connector.py @@ -18,7 +18,7 @@ import logging import re import subprocess -from typing import Any, Dict, List, Optional, Tuple +from typing import Any from uuid import UUID import requests @@ -77,21 +77,21 @@ class AzureBaseConfig(AuthenticationConfig): """Azure base configuration.""" - subscription_id: Optional[UUID] = Field( + subscription_id: UUID | None = Field( default=None, title="Azure Subscription ID", description="The subscription ID of the Azure account. If not " "specified, ZenML will attempt to retrieve the subscription ID from " "Azure using the configured credentials.", ) - tenant_id: Optional[UUID] = Field( + tenant_id: UUID | None = Field( default=None, title="Azure Tenant ID", description="The tenant ID of the Azure account. 
If not specified, " "ZenML will attempt to retrieve the tenant from Azure using the " "configured credentials.", ) - resource_group: Optional[str] = Field( + resource_group: str | None = Field( default=None, title="Azure Resource Group", description="A resource group may be used to restrict the scope of " @@ -99,7 +99,7 @@ class AzureBaseConfig(AuthenticationConfig): "specified, ZenML will retrieve resources from all resource groups " "accessible with the configured credentials.", ) - storage_account: Optional[str] = Field( + storage_account: str | None = Field( default=None, title="Azure Storage Account", description="The name of an Azure storage account may be used to " @@ -457,12 +457,12 @@ class AzureServiceConnector(ServiceConnector): config: AzureBaseConfig - _subscription_id: Optional[str] = None - _subscription_name: Optional[str] = None - _tenant_id: Optional[str] = None - _session_cache: Dict[ + _subscription_id: str | None = None + _subscription_name: str | None = None + _tenant_id: str | None = None + _session_cache: dict[ str, - Tuple[TokenCredential, Optional[datetime.datetime]], + tuple[TokenCredential, datetime.datetime | None], ] = {} @classmethod @@ -475,7 +475,7 @@ def _get_connector_type(cls) -> ServiceConnectorTypeModel: return AZURE_SERVICE_CONNECTOR_TYPE_SPEC @property - def subscription(self) -> Tuple[str, str]: + def subscription(self) -> tuple[str, str]: """Get the Azure subscription ID and name. Returns: @@ -578,9 +578,9 @@ def tenant_id(self) -> str: def get_azure_credential( self, auth_method: str, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> Tuple[TokenCredential, Optional[datetime.datetime]]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> tuple[TokenCredential, datetime.datetime | None]: """Get an Azure credential for the specified resource. 
Args: @@ -622,9 +622,9 @@ def get_azure_credential( def _authenticate( self, auth_method: str, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> Tuple[TokenCredential, Optional[datetime.datetime]]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> tuple[TokenCredential, datetime.datetime | None]: """Authenticate to Azure and return a token credential. Args: @@ -711,7 +711,7 @@ def _parse_blob_container_resource_id(self, resource_id: str) -> str: ValueError: If the provided resource ID is not a valid Azure blob resource ID. """ - container_name: Optional[str] = None + container_name: str | None = None if re.match( r"^(az|abfs)://[a-z0-9](?!.*--)[a-z0-9-]{1,61}[a-z0-9](/.*)*$", resource_id, @@ -758,7 +758,7 @@ def _parse_acr_resource_id( ValueError: If the provided resource ID is not a valid ACR resource ID. """ - registry_name: Optional[str] = None + registry_name: str | None = None if re.match( r"^(https?://)?[a-zA-Z0-9]+\.azurecr\.io(/.+)*$", resource_id, @@ -783,7 +783,7 @@ def _parse_acr_resource_id( def _parse_aks_resource_id( self, resource_id: str - ) -> Tuple[Optional[str], str]: + ) -> tuple[str | None, str]: """Validate and convert an AKS resource ID to an AKS cluster name. The resource ID could mean different things: @@ -804,7 +804,7 @@ def _parse_aks_resource_id( ValueError: If the provided resource ID is not a valid AKS cluster name. 
""" - resource_group: Optional[str] = self.config.resource_group + resource_group: str | None = self.config.resource_group if re.match( r"^[a-zA-Z0-9_.()-]+/[a-zA-Z0-9]+[a-zA-Z0-9_-]*[a-zA-Z0-9]+$", resource_id, @@ -1045,11 +1045,11 @@ def _configure_local_client( @classmethod def _auto_configure( cls, - auth_method: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - resource_group: Optional[str] = None, - storage_account: Optional[str] = None, + auth_method: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + resource_group: str | None = None, + storage_account: str | None = None, **kwargs: Any, ) -> "AzureServiceConnector": """Auto-configure the connector. @@ -1089,8 +1089,8 @@ def _auto_configure( the environment. """ auth_config: AzureBaseConfig - expiration_seconds: Optional[int] = None - expires_at: Optional[datetime.datetime] = None + expiration_seconds: int | None = None + expires_at: datetime.datetime | None = None if auth_method == AzureAuthenticationMethods.IMPLICIT: auth_config = AzureBaseConfig( resource_group=resource_group, @@ -1165,8 +1165,8 @@ def _get_resource_group(cls, resource_id: str) -> str: return resource_id.split("/")[4] def _list_blob_containers( - self, credential: TokenCredential, container_name: Optional[str] = None - ) -> Dict[str, str]: + self, credential: TokenCredential, container_name: str | None = None + ) -> dict[str, str]: """Get the list of blob storage containers that the connector can access. Args: @@ -1192,7 +1192,7 @@ def _list_blob_containers( # is provided, we only need to find the storage account that contains # it. 
- storage_accounts: List[str] = [] + storage_accounts: list[str] = [] if self.config.storage_account: storage_accounts = [self.config.storage_account] else: @@ -1239,7 +1239,7 @@ def _list_blob_containers( account.name for account in accounts if account.name ] - containers: Dict[str, str] = {} + containers: dict[str, str] = {} for storage_account in storage_accounts: account_url = f"https://{storage_account}.blob.core.windows.net/" @@ -1301,8 +1301,8 @@ def _list_blob_containers( return containers def _list_acr_registries( - self, credential: TokenCredential, registry_name: Optional[str] = None - ) -> Dict[str, str]: + self, credential: TokenCredential, registry_name: str | None = None + ) -> dict[str, str]: """Get the list of ACR registries that the connector can access. Args: @@ -1322,7 +1322,7 @@ def _list_acr_registries( """ subscription_id, _ = self.subscription - container_registries: Dict[str, str] = {} + container_registries: dict[str, str] = {} if registry_name and self.config.resource_group: try: container_client = ContainerRegistryManagementClient( @@ -1419,9 +1419,9 @@ def _list_acr_registries( def _list_aks_clusters( self, credential: TokenCredential, - cluster_name: Optional[str] = None, - resource_group: Optional[str] = None, - ) -> List[Tuple[str, str]]: + cluster_name: str | None = None, + resource_group: str | None = None, + ) -> list[tuple[str, str]]: """Get the list of AKS clusters that the connector can access. 
Args: @@ -1444,7 +1444,7 @@ def _list_aks_clusters( """ subscription_id, _ = self.subscription - clusters: List[Tuple[str, str]] = [] + clusters: list[tuple[str, str]] = [] if cluster_name and resource_group: try: container_client = ContainerServiceClient( @@ -1547,9 +1547,9 @@ def _list_aks_clusters( def _verify( self, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> List[str]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> list[str]: """Verify and list all the resources that the connector can access. Args: @@ -1600,7 +1600,7 @@ def _verify( "supported for blob storage resources" ) - container_name: Optional[str] = None + container_name: str | None = None if resource_id: container_name = self._parse_blob_container_resource_id( resource_id @@ -1614,7 +1614,7 @@ def _verify( return [f"az://{container}" for container in containers.keys()] if resource_type == DOCKER_REGISTRY_RESOURCE_TYPE: - registry_name: Optional[str] = None + registry_name: str | None = None if resource_id: registry_name = self._parse_acr_resource_id(resource_id) @@ -1626,7 +1626,7 @@ def _verify( return [f"{registry}.azurecr.io" for registry in registries.keys()] if resource_type == KUBERNETES_CLUSTER_RESOURCE_TYPE: - cluster_name: Optional[str] = None + cluster_name: str | None = None resource_group = self.config.resource_group if resource_id: resource_group, cluster_name = self._parse_aks_resource_id( @@ -1722,7 +1722,7 @@ def _get_connector_client( resource_id=resource_id, ) - resource_group: Optional[str] + resource_group: str | None registry_name: str cluster_name: str diff --git a/src/zenml/integrations/azure/step_operators/azureml_step_operator.py b/src/zenml/integrations/azure/step_operators/azureml_step_operator.py index 0084ccc4c01..c37eade8ea0 100644 --- a/src/zenml/integrations/azure/step_operators/azureml_step_operator.py +++ b/src/zenml/integrations/azure/step_operators/azureml_step_operator.py @@ -13,7 +13,7 @@ # 
permissions and limitations under the License. """Implementation of the ZenML AzureML Step Operator.""" -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, cast +from typing import TYPE_CHECKING, cast from azure.ai.ml import MLClient, command from azure.ai.ml.entities import Environment @@ -59,7 +59,7 @@ def config(self) -> AzureMLStepOperatorConfig: return cast(AzureMLStepOperatorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the AzureML step operator. Returns: @@ -68,7 +68,7 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: return AzureMLStepOperatorSettings @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates the stack. Returns: @@ -78,7 +78,7 @@ def validator(self) -> Optional[StackValidator]: def _validate_remote_components( stack: "Stack", - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: if stack.artifact_store.config.is_local: return False, ( "The AzureML step operator runs code remotely and " @@ -138,7 +138,7 @@ def _get_credentials(self) -> TokenCredential: def get_docker_builds( self, snapshot: "PipelineSnapshotBase" - ) -> List["BuildConfiguration"]: + ) -> list["BuildConfiguration"]: """Gets the Docker builds required for the component. Args: @@ -162,8 +162,8 @@ def get_docker_builds( def launch( self, info: "StepRunInfo", - entrypoint_command: List[str], - environment: Dict[str, str], + entrypoint_command: list[str], + environment: dict[str, str], ) -> None: """Launches a step on AzureML. 
diff --git a/src/zenml/integrations/bentoml/__init__.py b/src/zenml/integrations/bentoml/__init__.py index 43d47f8b22a..5b5603bbf3d 100644 --- a/src/zenml/integrations/bentoml/__init__.py +++ b/src/zenml/integrations/bentoml/__init__.py @@ -16,7 +16,6 @@ The BentoML integration allows you to use the BentoML model serving to implement continuous model deployment. """ -from typing import List, Type from zenml.integrations.constants import BENTOML from zenml.integrations.integration import Integration @@ -41,7 +40,7 @@ def activate(cls) -> None: from zenml.integrations.bentoml import services # noqa @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for BentoML. Returns: diff --git a/src/zenml/integrations/bentoml/flavors/bentoml_model_deployer_flavor.py b/src/zenml/integrations/bentoml/flavors/bentoml_model_deployer_flavor.py index e99ada72d3f..936b7993628 100644 --- a/src/zenml/integrations/bentoml/flavors/bentoml_model_deployer_flavor.py +++ b/src/zenml/integrations/bentoml/flavors/bentoml_model_deployer_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """BentoML model deployer flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -50,7 +50,7 @@ def name(self) -> str: return BENTOML_MODEL_DEPLOYER_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -59,7 +59,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. 
Returns: @@ -77,7 +77,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/bentoml.png" @property - def config_class(self) -> Type[BentoMLModelDeployerConfig]: + def config_class(self) -> type[BentoMLModelDeployerConfig]: """Returns `BentoMLModelDeployerConfig` config class. Returns: @@ -86,7 +86,7 @@ def config_class(self) -> Type[BentoMLModelDeployerConfig]: return BentoMLModelDeployerConfig @property - def implementation_class(self) -> Type["BentoMLModelDeployer"]: + def implementation_class(self) -> type["BentoMLModelDeployer"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/bentoml/materializers/bentoml_bento_materializer.py b/src/zenml/integrations/bentoml/materializers/bentoml_bento_materializer.py index 87370f36a02..dc0b89b27d9 100644 --- a/src/zenml/integrations/bentoml/materializers/bentoml_bento_materializer.py +++ b/src/zenml/integrations/bentoml/materializers/bentoml_bento_materializer.py @@ -14,7 +14,7 @@ """Materializer for BentoML Bento objects.""" import os -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar import bentoml from bentoml._internal.bento import Bento, bento @@ -35,10 +35,10 @@ class BentoMaterializer(BaseMaterializer): """Materializer for Bentoml Bento objects.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (bento.Bento,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (bento.Bento,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - def load(self, data_type: Type[bento.Bento]) -> bento.Bento: + def load(self, data_type: type[bento.Bento]) -> bento.Bento: """Read from artifact store and return a Bento object. Args: @@ -76,7 +76,7 @@ def save(self, bento: bento.Bento) -> None: def extract_metadata( self, bento: bento.Bento - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Extract metadata from the given `Bento` object. 
Args: diff --git a/src/zenml/integrations/bentoml/model_deployers/bentoml_model_deployer.py b/src/zenml/integrations/bentoml/model_deployers/bentoml_model_deployer.py index dbc99935c0e..5c5c4bb275b 100644 --- a/src/zenml/integrations/bentoml/model_deployers/bentoml_model_deployer.py +++ b/src/zenml/integrations/bentoml/model_deployers/bentoml_model_deployer.py @@ -15,7 +15,7 @@ import os import shutil -from typing import ClassVar, Dict, Optional, Type, cast +from typing import ClassVar, cast from uuid import UUID from zenml.config.global_config import GlobalConfiguration @@ -48,11 +48,11 @@ class BentoMLModelDeployer(BaseModelDeployer): """BentoML model deployer stack component implementation.""" NAME: ClassVar[str] = "BentoML" - FLAVOR: ClassVar[Type[BaseModelDeployerFlavor]] = ( + FLAVOR: ClassVar[type[BaseModelDeployerFlavor]] = ( BentoMLModelDeployerFlavor ) - _service_path: Optional[str] = None + _service_path: str | None = None @property def config(self) -> BentoMLModelDeployerConfig: @@ -110,7 +110,7 @@ def local_path(self) -> str: @staticmethod def get_model_server_info( service_instance: BaseService, - ) -> Dict[str, Optional[str]]: + ) -> dict[str, str | None]: """Return implementation specific information on the model server. 
Args: diff --git a/src/zenml/integrations/bentoml/services/bentoml_container_deployment.py b/src/zenml/integrations/bentoml/services/bentoml_container_deployment.py index e35ad910e16..f710c8974e7 100644 --- a/src/zenml/integrations/bentoml/services/bentoml_container_deployment.py +++ b/src/zenml/integrations/bentoml/services/bentoml_container_deployment.py @@ -15,7 +15,7 @@ import os import sys -from typing import Any, Dict, List, Optional, Union +from typing import Any import bentoml import docker.errors as docker_errors @@ -57,18 +57,18 @@ class BentoMLContainerDeploymentConfig(ContainerServiceConfig): model_name: str model_uri: str bento_tag: str - bento_uri: Optional[str] = None - platform: Optional[str] = None + bento_uri: str | None = None + platform: str | None = None image: str = "" - image_tag: Optional[str] = None - features: Optional[List[str]] = None - file: Optional[str] = None - apis: List[str] = [] - working_dir: Optional[str] = None + image_tag: str | None = None + features: list[str] | None = None + file: str | None = None + apis: list[str] = [] + working_dir: str | None = None workers: int = 1 backlog: int = 2048 - host: Optional[str] = None - port: Optional[int] = None + host: str | None = None + port: int | None = None class BentoMLContainerDeploymentEndpointConfig(ContainerServiceEndpointConfig): @@ -91,7 +91,7 @@ class BentoMLContainerDeploymentEndpoint(ContainerServiceEndpoint): config: BentoMLContainerDeploymentEndpointConfig @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """Gets the prediction URL for the endpoint. Returns: @@ -121,7 +121,7 @@ class BentoMLContainerDeploymentService( def __init__( self, - config: Union[BentoMLContainerDeploymentConfig, Dict[str, Any]], + config: BentoMLContainerDeploymentConfig | dict[str, Any], **attrs: Any, ) -> None: """Initialize the BentoML deployment service. 
@@ -203,7 +203,7 @@ def _start_container(self) -> None: self._setup_runtime_path() - ports: Dict[int, Optional[int]] = {} + ports: dict[int, int | None] = {} if self.endpoint: self.endpoint.prepare_for_start() if self.endpoint.status.port: @@ -341,7 +341,7 @@ def run(self) -> None: raise @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """Get the URI where the http server is running. Returns: @@ -353,7 +353,7 @@ def prediction_url(self) -> Optional[str]: return self.endpoint.prediction_url @property - def prediction_apis_urls(self) -> Optional[List[str]]: + def prediction_apis_urls(self) -> list[str] | None: """Get the URI where the prediction api services is answering requests. Returns: diff --git a/src/zenml/integrations/bentoml/services/bentoml_local_deployment.py b/src/zenml/integrations/bentoml/services/bentoml_local_deployment.py index 59bfd99af0f..9d4513257ff 100644 --- a/src/zenml/integrations/bentoml/services/bentoml_local_deployment.py +++ b/src/zenml/integrations/bentoml/services/bentoml_local_deployment.py @@ -14,7 +14,7 @@ """Implementation for the BentoML local deployment service.""" import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any from bentoml import AsyncHTTPClient, SyncHTTPClient from pydantic import BaseModel, Field @@ -69,7 +69,7 @@ class BentoMLDeploymentEndpoint(LocalDaemonServiceEndpoint): monitor: HTTPEndpointHealthMonitor @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """Gets the prediction URL for the endpoint. 
Returns: @@ -94,13 +94,13 @@ class SSLBentoMLParametersConfig(BaseModel): ssl_ciphers: SSL ciphers """ - ssl_certfile: Optional[str] = None - ssl_keyfile: Optional[str] = None - ssl_keyfile_password: Optional[str] = None - ssl_version: Optional[int] = None - ssl_cert_reqs: Optional[int] = None - ssl_ca_certs: Optional[str] = None - ssl_ciphers: Optional[str] = None + ssl_certfile: str | None = None + ssl_keyfile: str | None = None + ssl_keyfile_password: str | None = None + ssl_version: int | None = None + ssl_cert_reqs: int | None = None + ssl_ca_certs: str | None = None + ssl_ciphers: str | None = None class BentoMLLocalDeploymentConfig(LocalDaemonServiceConfig): @@ -123,15 +123,15 @@ class BentoMLLocalDeploymentConfig(LocalDaemonServiceConfig): model_name: str model_uri: str bento_tag: str - bento_uri: Optional[str] = None - apis: List[str] = [] + bento_uri: str | None = None + apis: list[str] = [] workers: int = 1 - port: Optional[int] = None + port: int | None = None backlog: int = 2048 production: bool = False working_dir: str - host: Optional[str] = None - ssl_parameters: Optional[SSLBentoMLParametersConfig] = Field( + host: str | None = None + ssl_parameters: SSLBentoMLParametersConfig | None = Field( default_factory=SSLBentoMLParametersConfig ) @@ -159,7 +159,7 @@ class BentoMLLocalDeploymentService(LocalDaemonService, BaseDeploymentService): def __init__( self, - config: Union[BentoMLLocalDeploymentConfig, Dict[str, Any]], + config: BentoMLLocalDeploymentConfig | dict[str, Any], **attrs: Any, ) -> None: """Initialize the BentoML deployment service. @@ -258,7 +258,7 @@ def run(self) -> None: logger.info("Stopping BentoML prediction service...") @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """Get the URI where the http server is running. 
Returns: @@ -270,7 +270,7 @@ def prediction_url(self) -> Optional[str]: return self.endpoint.prediction_url @property - def prediction_apis_urls(self) -> Optional[List[str]]: + def prediction_apis_urls(self) -> list[str] | None: """Get the URI where the prediction api services is answering requests. Returns: diff --git a/src/zenml/integrations/bentoml/steps/bento_builder.py b/src/zenml/integrations/bentoml/steps/bento_builder.py index 6163ae9ea07..1a6cc3e384a 100644 --- a/src/zenml/integrations/bentoml/steps/bento_builder.py +++ b/src/zenml/integrations/bentoml/steps/bento_builder.py @@ -15,7 +15,7 @@ import importlib import os -from typing import Any, Dict, List, Optional +from typing import Any import bentoml from bentoml import bentos @@ -37,14 +37,14 @@ def bento_builder_step( model_name: str, model_type: str, service: str, - version: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - description: Optional[str] = None, - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - python: Optional[Dict[str, Any]] = None, - docker: Optional[Dict[str, Any]] = None, - working_dir: Optional[str] = None, + version: str | None = None, + labels: dict[str, str] | None = None, + description: str | None = None, + include: list[str] | None = None, + exclude: list[str] | None = None, + python: dict[str, Any] | None = None, + docker: dict[str, Any] | None = None, + working_dir: str | None = None, ) -> bento.Bento: """Build a BentoML Model and Bento bundle. diff --git a/src/zenml/integrations/bentoml/steps/bentoml_deployer.py b/src/zenml/integrations/bentoml/steps/bentoml_deployer.py index 85560a05d84..8ffc9586f1c 100644 --- a/src/zenml/integrations/bentoml/steps/bentoml_deployer.py +++ b/src/zenml/integrations/bentoml/steps/bentoml_deployer.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Implementation of the BentoML model deployer pipeline step.""" -from typing import List, Optional, Tuple, cast +from typing import cast import bentoml from bentoml._internal.bento import bento @@ -49,21 +49,21 @@ def bentoml_model_deployer_step( port: int, deployment_type: BentoMLDeploymentType = BentoMLDeploymentType.LOCAL, deploy_decision: bool = True, - workers: Optional[int] = 1, - backlog: Optional[int] = 2048, + workers: int | None = 1, + backlog: int | None = 2048, production: bool = False, - working_dir: Optional[str] = None, - host: Optional[str] = None, - image: Optional[str] = None, - image_tag: Optional[str] = None, - platform: Optional[str] = None, - ssl_certfile: Optional[str] = None, - ssl_keyfile: Optional[str] = None, - ssl_keyfile_password: Optional[str] = None, - ssl_version: Optional[str] = None, - ssl_cert_reqs: Optional[str] = None, - ssl_ca_certs: Optional[str] = None, - ssl_ciphers: Optional[str] = None, + working_dir: str | None = None, + host: str | None = None, + image: str | None = None, + image_tag: str | None = None, + platform: str | None = None, + ssl_certfile: str | None = None, + ssl_keyfile: str | None = None, + ssl_keyfile_password: str | None = None, + ssl_version: str | None = None, + ssl_cert_reqs: str | None = None, + ssl_ca_certs: str | None = None, + ssl_ciphers: str | None = None, timeout: int = 30, ) -> BaseService: """Model deployer pipeline step for BentoML. @@ -109,7 +109,7 @@ def bentoml_model_deployer_step( # Return the apis endpoint of the defined service to use in the predict. # This is a workaround to get the endpoints of the service defined as functions # from the user code in the BentoML service. 
- def service_apis(bento_tag: str) -> List[str]: + def service_apis(bento_tag: str) -> list[str]: # Add working dir in the bentoml load service = bentoml.load( bento_identifier=bento_tag, @@ -121,7 +121,7 @@ def service_apis(bento_tag: str) -> List[str]: def create_deployment_config( deployment_type: BentoMLDeploymentType, - ) -> Tuple[ServiceConfig, ServiceType]: + ) -> tuple[ServiceConfig, ServiceType]: common_config = { "model_name": model_name, "bento_tag": str(bento.tag), @@ -167,7 +167,7 @@ def create_deployment_config( ) # Creating a new service with inactive state and status by default - service: Optional[BaseService] = None + service: BaseService | None = None if existing_services: if deployment_type == BentoMLDeploymentType.CONTAINER: service = cast( diff --git a/src/zenml/integrations/bitbucket/__init__.py b/src/zenml/integrations/bitbucket/__init__.py index df55482b097..654668acd8a 100644 --- a/src/zenml/integrations/bitbucket/__init__.py +++ b/src/zenml/integrations/bitbucket/__init__.py @@ -12,7 +12,6 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Initialization of the bitbucket ZenML integration.""" -from typing import List, Type from zenml.integrations.constants import BITBUCKET from zenml.integrations.integration import Integration @@ -25,10 +24,10 @@ class BitbucketIntegration(Integration): """Definition of bitbucket integration for ZenML.""" NAME = BITBUCKET - REQUIREMENTS: List[str] = [] + REQUIREMENTS: list[str] = [] @classmethod - def plugin_flavors(cls) -> List[Type[BasePluginFlavor]]: + def plugin_flavors(cls) -> list[type[BasePluginFlavor]]: """Declare the event flavors for the bitbucket integration. 
Returns: diff --git a/src/zenml/integrations/bitbucket/plugins/bitbucket_webhook_event_source_flavor.py b/src/zenml/integrations/bitbucket/plugins/bitbucket_webhook_event_source_flavor.py index a389b6677ee..28eb1f5d765 100644 --- a/src/zenml/integrations/bitbucket/plugins/bitbucket_webhook_event_source_flavor.py +++ b/src/zenml/integrations/bitbucket/plugins/bitbucket_webhook_event_source_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Bitbucket webhook event source flavor.""" -from typing import ClassVar, Type +from typing import ClassVar from zenml.event_sources.webhooks.base_webhook_event_source import ( BaseWebhookEventSourceFlavor, @@ -30,14 +30,14 @@ class BitbucketWebhookEventSourceFlavor(BaseWebhookEventSourceFlavor): """Enables users to configure Bitbucket event sources.""" FLAVOR: ClassVar[str] = BITBUCKET_EVENT_FLAVOR - PLUGIN_CLASS: ClassVar[Type[BitbucketWebhookEventSourceHandler]] = ( + PLUGIN_CLASS: ClassVar[type[BitbucketWebhookEventSourceHandler]] = ( BitbucketWebhookEventSourceHandler ) # EventPlugin specific EVENT_SOURCE_CONFIG_CLASS: ClassVar[ - Type[BitbucketWebhookEventSourceConfiguration] + type[BitbucketWebhookEventSourceConfiguration] ] = BitbucketWebhookEventSourceConfiguration EVENT_FILTER_CONFIG_CLASS: ClassVar[ - Type[BitbucketWebhookEventFilterConfiguration] + type[BitbucketWebhookEventFilterConfiguration] ] = BitbucketWebhookEventFilterConfiguration diff --git a/src/zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py b/src/zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py index ac997f74073..36e7f0c3e72 100644 --- a/src/zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py +++ b/src/zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Implementation of the Bitbucket webhook event source.""" -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any from uuid import UUID from pydantic import BaseModel, ConfigDict, Field @@ -58,9 +58,9 @@ class BitbucketEventType(StrEnum): class User(BaseModel): """Bitbucket User.""" - name: Optional[str] = None - email: Optional[str] = None - username: Optional[str] = None + name: str | None = None + email: str | None = None + username: str | None = None class Commit(BaseModel): @@ -68,7 +68,7 @@ class Commit(BaseModel): hash: str message: str - links: Dict[str, Any] + links: dict[str, Any] author: User @@ -78,21 +78,21 @@ class Repository(BaseModel): uuid: str name: str full_name: str - links: Dict[str, Any] + links: dict[str, Any] class PushChange(BaseModel): """Bitbucket Push Change.""" - new: Optional[Dict[str, Any]] = None - old: Optional[Dict[str, Any]] = None - commits: List[Commit] + new: dict[str, Any] | None = None + old: dict[str, Any] | None = None + commits: list[Commit] class Push(BaseModel): """Bitbucket Push.""" - changes: List[PushChange] + changes: list[PushChange] class BitbucketEvent(BaseEvent): @@ -104,7 +104,7 @@ class BitbucketEvent(BaseEvent): model_config = ConfigDict(extra="allow") @property - def branch(self) -> Optional[str]: + def branch(self) -> str | None: """The branch the event happened on. Returns: @@ -117,7 +117,7 @@ def branch(self) -> Optional[str]: return None @property - def event_type(self) -> Union[BitbucketEventType, str]: + def event_type(self) -> BitbucketEventType | str: """The type of Bitbucket event. 
Args: @@ -151,9 +151,9 @@ def event_type(self) -> Union[BitbucketEventType, str]: class BitbucketWebhookEventFilterConfiguration(WebhookEventFilterConfig): """Configuration for Bitbucket event filters.""" - repo: Optional[str] = None - branch: Optional[str] = None - event_type: Optional[BitbucketEventType] = None + repo: str | None = None + branch: str | None = None + event_type: BitbucketEventType | None = None def event_matches_filter(self, event: BaseEvent) -> bool: """Checks the filter against the inbound event. @@ -181,15 +181,15 @@ def event_matches_filter(self, event: BaseEvent) -> bool: class BitbucketWebhookEventSourceConfiguration(WebhookEventSourceConfig): """Configuration for Bitbucket source filters.""" - webhook_secret: Optional[str] = Field( + webhook_secret: str | None = Field( default=None, title="The webhook secret for the event source.", ) - webhook_secret_id: Optional[UUID] = Field( + webhook_secret_id: UUID | None = Field( default=None, description="The ID of the secret containing the webhook secret.", ) - rotate_secret: Optional[bool] = Field( + rotate_secret: bool | None = Field( default=None, description="Set to rotate the webhook secret." ) @@ -201,7 +201,7 @@ class BitbucketWebhookEventSourceHandler(BaseWebhookEventSourceHandler): """Handler for all Bitbucket events.""" @property - def config_class(self) -> Type[BitbucketWebhookEventSourceConfiguration]: + def config_class(self) -> type[BitbucketWebhookEventSourceConfiguration]: """Returns the webhook event source configuration class. Returns: @@ -210,7 +210,7 @@ def config_class(self) -> Type[BitbucketWebhookEventSourceConfiguration]: return BitbucketWebhookEventSourceConfiguration @property - def filter_class(self) -> Type[BitbucketWebhookEventFilterConfiguration]: + def filter_class(self) -> type[BitbucketWebhookEventFilterConfiguration]: """Returns the webhook event filter configuration class. 
Returns: @@ -219,7 +219,7 @@ def filter_class(self) -> Type[BitbucketWebhookEventFilterConfiguration]: return BitbucketWebhookEventFilterConfiguration @property - def flavor_class(self) -> Type[BaseWebhookEventSourceFlavor]: + def flavor_class(self) -> type[BaseWebhookEventSourceFlavor]: """Returns the flavor class of the plugin. Returns: @@ -231,7 +231,7 @@ def flavor_class(self) -> Type[BaseWebhookEventSourceFlavor]: return BitbucketWebhookEventSourceFlavor - def _interpret_event(self, event: Dict[str, Any]) -> BitbucketEvent: + def _interpret_event(self, event: dict[str, Any]) -> BitbucketEvent: """Converts the generic event body into a event-source specific pydantic model. Args: @@ -252,7 +252,7 @@ def _interpret_event(self, event: Dict[str, Any]) -> BitbucketEvent: def _get_webhook_secret( self, event_source: EventSourceResponse - ) -> Optional[str]: + ) -> str | None: """Get the webhook secret for the event source. Args: @@ -440,7 +440,7 @@ def _process_event_source_delete( self, event_source: EventSourceResponse, config: EventSourceConfig, - force: Optional[bool] = False, + force: bool | None = False, ) -> None: """Process an event source before it is deleted from the database. diff --git a/src/zenml/integrations/comet/__init__.py b/src/zenml/integrations/comet/__init__.py index 9908279a3dc..0b60722cf0a 100644 --- a/src/zenml/integrations/comet/__init__.py +++ b/src/zenml/integrations/comet/__init__.py @@ -16,9 +16,7 @@ The CometML integrations currently enables you to use Comet tracking as a convenient way to visualize your experiment runs within the Comet ui. 
""" -from typing import List, Type -from zenml.enums import StackComponentType from zenml.integrations.constants import COMET from zenml.integrations.integration import Integration from zenml.stack import Flavor @@ -33,7 +31,7 @@ class CometIntegration(Integration): REQUIREMENTS = ["comet-ml>=3.0.0"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Comet integration. Returns: diff --git a/src/zenml/integrations/comet/experiment_trackers/comet_experiment_tracker.py b/src/zenml/integrations/comet/experiment_trackers/comet_experiment_tracker.py index 7dee29cef56..a3ef2bd195e 100644 --- a/src/zenml/integrations/comet/experiment_trackers/comet_experiment_tracker.py +++ b/src/zenml/integrations/comet/experiment_trackers/comet_experiment_tracker.py @@ -14,7 +14,7 @@ """Implementation for the Comet experiment tracker.""" import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, cast +from typing import TYPE_CHECKING, Any, cast from comet_ml import Experiment # type: ignore @@ -52,7 +52,7 @@ def config(self) -> CometExperimentTrackerConfig: return cast(CometExperimentTrackerConfig, self._config) @property - def settings_class(self) -> Type[CometExperimentTrackerSettings]: + def settings_class(self) -> type[CometExperimentTrackerSettings]: """Settings class for the Comet experiment tracker. Returns: @@ -80,7 +80,7 @@ def prepare_step_run(self, info: "StepRunInfo") -> None: def get_step_run_metadata( self, info: "StepRunInfo" - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get component- and step-specific metadata after a step ran. Args: @@ -89,8 +89,8 @@ def get_step_run_metadata( Returns: A dictionary of metadata. 
""" - exp_url: Optional[str] = None - exp_name: Optional[str] = None + exp_url: str | None = None + exp_name: str | None = None if self.experiment: exp_url = self.experiment.url @@ -128,8 +128,8 @@ def cleanup_step_run(self, info: "StepRunInfo", step_failed: bool) -> None: def log_metrics( self, - metrics: Dict[str, Any], - step: Optional[int] = None, + metrics: dict[str, Any], + step: int | None = None, ) -> None: """Logs metrics to the Comet experiment. @@ -140,7 +140,7 @@ def log_metrics( if self.experiment: self.experiment.log_metrics(metrics, step=step) - def log_params(self, params: Dict[str, Any]) -> None: + def log_params(self, params: dict[str, Any]) -> None: """Logs parameters to the Comet experiment. Args: @@ -152,8 +152,8 @@ def log_params(self, params: Dict[str, Any]) -> None: def _initialize_comet( self, run_name: str, - tags: List[str], - settings: Union[Dict[str, Any], None] = None, + tags: list[str], + settings: dict[str, Any] | None = None, ) -> None: """Initializes a Comet experiment. diff --git a/src/zenml/integrations/comet/flavors/comet_experiment_tracker_flavor.py b/src/zenml/integrations/comet/flavors/comet_experiment_tracker_flavor.py index 1f8179ea9dc..9f0c5a48c6e 100644 --- a/src/zenml/integrations/comet/flavors/comet_experiment_tracker_flavor.py +++ b/src/zenml/integrations/comet/flavors/comet_experiment_tracker_flavor.py @@ -16,10 +16,6 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, - Optional, - Type, ) from zenml.config.base_settings import BaseSettings @@ -45,9 +41,9 @@ class CometExperimentTrackerSettings(BaseSettings): settings: Settings for the Comet experiment. 
""" - run_name: Optional[str] = None - tags: List[str] = [] - settings: Dict[str, Any] = {} + run_name: str | None = None + tags: list[str] = [] + settings: dict[str, Any] = {} class CometExperimentTrackerConfig( @@ -63,8 +59,8 @@ class CometExperimentTrackerConfig( """ api_key: str = SecretField() - workspace: Optional[str] = None - project_name: Optional[str] = None + workspace: str | None = None + project_name: str | None = None class CometExperimentTrackerFlavor(BaseExperimentTrackerFlavor): @@ -80,7 +76,7 @@ def name(self) -> str: return COMET_EXPERIMENT_TRACKER_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -89,7 +85,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -107,7 +103,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/experiment_tracker/comet.png" @property - def config_class(self) -> Type[CometExperimentTrackerConfig]: + def config_class(self) -> type[CometExperimentTrackerConfig]: """Returns `CometExperimentTrackerConfig` config class. Returns: @@ -116,7 +112,7 @@ def config_class(self) -> Type[CometExperimentTrackerConfig]: return CometExperimentTrackerConfig @property - def implementation_class(self) -> Type["CometExperimentTracker"]: + def implementation_class(self) -> type["CometExperimentTracker"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/databricks/__init__.py b/src/zenml/integrations/databricks/__init__.py index 40a0ab370de..5a015729284 100644 --- a/src/zenml/integrations/databricks/__init__.py +++ b/src/zenml/integrations/databricks/__init__.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Initialization of the Databricks integration for ZenML.""" -from typing import List, Type, Optional from zenml.integrations.constants import DATABRICKS @@ -35,8 +34,8 @@ class DatabricksIntegration(Integration): @classmethod def get_requirements( - cls, target_os: Optional[str] = None, python_version: Optional[str] = None - ) -> List[str]: + cls, target_os: str | None = None, python_version: str | None = None + ) -> list[str]: """Method to get the requirements for the integration. Args: @@ -54,7 +53,7 @@ def get_requirements( PandasIntegration.get_requirements(target_os=target_os, python_version=python_version) @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Databricks integration. Returns: diff --git a/src/zenml/integrations/databricks/flavors/databricks_model_deployer_flavor.py b/src/zenml/integrations/databricks/flavors/databricks_model_deployer_flavor.py index 0ce3fe30c18..acdec9a068a 100644 --- a/src/zenml/integrations/databricks/flavors/databricks_model_deployer_flavor.py +++ b/src/zenml/integrations/databricks/flavors/databricks_model_deployer_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Databricks model deployer flavor.""" -from typing import TYPE_CHECKING, Dict, Optional, Type +from typing import TYPE_CHECKING from pydantic import BaseModel @@ -35,9 +35,9 @@ class DatabricksBaseConfig(BaseModel): workload_size: str scale_to_zero_enabled: bool = False - env_vars: Optional[Dict[str, str]] = None - workload_type: Optional[str] = None - endpoint_secret_name: Optional[str] = None + env_vars: dict[str, str] | None = None + workload_type: str | None = None + endpoint_secret_name: str | None = None class DatabricksModelDeployerConfig(BaseModelDeployerConfig): @@ -51,9 +51,9 @@ class DatabricksModelDeployerConfig(BaseModelDeployerConfig): """ host: str - secret_name: Optional[str] = None - client_id: Optional[str] = SecretField(default=None) - client_secret: Optional[str] = SecretField(default=None) + secret_name: str | None = None + client_id: str | None = SecretField(default=None) + client_secret: str | None = SecretField(default=None) class DatabricksModelDeployerFlavor(BaseModelDeployerFlavor): @@ -69,7 +69,7 @@ def name(self) -> str: return DATABRICKS_MODEL_DEPLOYER_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -78,7 +78,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -96,7 +96,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/databricks.png" @property - def config_class(self) -> Type[DatabricksModelDeployerConfig]: + def config_class(self) -> type[DatabricksModelDeployerConfig]: """Returns `DatabricksModelDeployerConfig` config class. 
Returns: @@ -105,7 +105,7 @@ def config_class(self) -> Type[DatabricksModelDeployerConfig]: return DatabricksModelDeployerConfig @property - def implementation_class(self) -> Type["DatabricksModelDeployer"]: + def implementation_class(self) -> type["DatabricksModelDeployer"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py b/src/zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py index af9540a88e5..0fce3a5ea18 100644 --- a/src/zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py +++ b/src/zenml/integrations/databricks/flavors/databricks_orchestrator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Databricks orchestrator base config and settings.""" -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -49,55 +49,55 @@ class DatabricksOrchestratorSettings(BaseSettings): """ # Cluster Configuration - spark_version: Optional[str] = Field( + spark_version: str | None = Field( default=None, description="Apache Spark version for the Databricks cluster. " "Uses workspace default if not specified. Example: '3.2.x-scala2.12'", ) - num_workers: Optional[int] = Field( + num_workers: int | None = Field( default=None, description="Fixed number of worker nodes. Cannot be used with autoscaling.", ) - node_type_id: Optional[str] = Field( + node_type_id: str | None = Field( default=None, description="Databricks node type identifier. " "Refer to Databricks documentation for available instance types. 
" "Example: 'i3.xlarge'", ) - policy_id: Optional[str] = Field( + policy_id: str | None = Field( default=None, description="Databricks cluster policy ID for governance and cost control.", ) - autotermination_minutes: Optional[int] = Field( + autotermination_minutes: int | None = Field( default=None, description="Minutes of inactivity before automatic cluster termination. " "Helps control costs by shutting down idle clusters.", ) - autoscale: Tuple[int, int] = Field( + autoscale: tuple[int, int] = Field( default=(0, 1), description="Cluster autoscaling bounds as (min_workers, max_workers). " "Automatically adjusts cluster size based on workload.", ) - single_user_name: Optional[str] = Field( + single_user_name: str | None = Field( default=None, description="Databricks username for single-user cluster access mode.", ) - spark_conf: Optional[Dict[str, str]] = Field( + spark_conf: dict[str, str] | None = Field( default=None, description="Custom Spark configuration properties as key-value pairs. " "Example: {'spark.sql.adaptive.enabled': 'true', 'spark.sql.adaptive.coalescePartitions.enabled': 'true'}", ) - spark_env_vars: Optional[Dict[str, str]] = Field( + spark_env_vars: dict[str, str] | None = Field( default=None, description="Environment variables for the Spark driver and executors. " "Example: {'SPARK_WORKER_MEMORY': '4g', 'SPARK_DRIVER_MEMORY': '2g'}", ) - schedule_timezone: Optional[str] = Field( + schedule_timezone: str | None = Field( default=None, description="Timezone for scheduled pipeline execution. 
" "Uses IANA timezone format (e.g., 'America/New_York').", ) - availability_type: Optional[DatabricksAvailabilityType] = Field( + availability_type: DatabricksAvailabilityType | None = Field( default=None, description="Instance availability type: ON_DEMAND (guaranteed), SPOT (cost-optimized), " "or SPOT_WITH_FALLBACK (spot with on-demand backup).", @@ -160,7 +160,7 @@ def name(self) -> str: return DATABRICKS_ORCHESTRATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -169,7 +169,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -187,7 +187,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/databricks.png" @property - def config_class(self) -> Type[DatabricksOrchestratorConfig]: + def config_class(self) -> type[DatabricksOrchestratorConfig]: """Returns `KubeflowOrchestratorConfig` config class. Returns: @@ -196,7 +196,7 @@ def config_class(self) -> Type[DatabricksOrchestratorConfig]: return DatabricksOrchestratorConfig @property - def implementation_class(self) -> Type["DatabricksOrchestrator"]: + def implementation_class(self) -> type["DatabricksOrchestrator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/databricks/model_deployers/databricks_model_deployer.py b/src/zenml/integrations/databricks/model_deployers/databricks_model_deployer.py index 696bf1131b5..ce32881e46b 100644 --- a/src/zenml/integrations/databricks/model_deployers/databricks_model_deployer.py +++ b/src/zenml/integrations/databricks/model_deployers/databricks_model_deployer.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Implementation of the Databricks Model Deployer.""" -from typing import ClassVar, Dict, Optional, Tuple, Type, cast +from typing import ClassVar, cast from uuid import UUID from zenml.analytics.enums import AnalyticsEvent @@ -45,7 +45,7 @@ class DatabricksModelDeployer(BaseModelDeployer): """Databricks endpoint model deployer.""" NAME: ClassVar[str] = "Databricks" - FLAVOR: ClassVar[Type[BaseModelDeployerFlavor]] = ( + FLAVOR: ClassVar[type[BaseModelDeployerFlavor]] = ( DatabricksModelDeployerFlavor ) @@ -59,7 +59,7 @@ def config(self) -> DatabricksModelDeployerConfig: return cast(DatabricksModelDeployerConfig, self._config) @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates the stack. Returns: @@ -69,7 +69,7 @@ def validator(self) -> Optional[StackValidator]: def _validate_if_secret_or_token_is_present( stack: "Stack", - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: """Check if client id and client secret or secret name is present in the stack. Args: @@ -234,7 +234,7 @@ def perform_delete_model( @staticmethod def get_model_server_info( # type: ignore[override] service_instance: "DatabricksDeploymentService", - ) -> Dict[str, Optional[str]]: + ) -> dict[str, str | None]: """Return implementation specific information that might be relevant to the user. 
Args: diff --git a/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator.py b/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator.py index 9acf3cac29c..b66de51a98e 100644 --- a/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator.py +++ b/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator.py @@ -15,7 +15,7 @@ import itertools import os -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, cast +from typing import TYPE_CHECKING, Optional, cast from uuid import UUID from databricks.sdk import WorkspaceClient as DatabricksClient @@ -78,7 +78,7 @@ class DatabricksOrchestrator(WheeledOrchestrator): """Databricks orchestrator.""" @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates the stack. In the remote case, checks that the stack contains a container registry, @@ -90,7 +90,7 @@ def validator(self) -> Optional[StackValidator]: def _validate_remote_components( stack: "Stack", - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: for component in stack.components.values(): if not component.config.is_local: continue @@ -134,7 +134,7 @@ def config(self) -> DatabricksOrchestratorConfig: return cast(DatabricksOrchestratorConfig, self._config) @property - def settings_class(self) -> Type[DatabricksOrchestratorSettings]: + def settings_class(self) -> type[DatabricksOrchestratorSettings]: """Settings class for the Databricks orchestrator. Returns: @@ -173,10 +173,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. 
This method should only submit the pipeline and not wait for it to @@ -230,7 +230,7 @@ def submit_pipeline( # Create a callable for future compilation into a dsl.Pipeline. def _construct_databricks_pipeline( zenml_project_wheel: str, job_cluster_key: str - ) -> List[DatabricksTask]: + ) -> list[DatabricksTask]: """Create a databrcks task for each step. This should contain the name of the step or task and configures the @@ -356,8 +356,8 @@ def _upload_and_run_pipeline( self, pipeline_name: str, settings: DatabricksOrchestratorSettings, - tasks: List[DatabricksTask], - env_vars: Dict[str, str], + tasks: list[DatabricksTask], + env_vars: dict[str, str], job_cluster_key: str, schedule: Optional["ScheduleResponse"] = None, ) -> None: @@ -440,7 +440,7 @@ def _upload_and_run_pipeline( def get_pipeline_run_metadata( self, run_id: UUID - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get general component-specific metadata for a pipeline run. Args: diff --git a/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py b/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py index de2001e3fee..efbbd21ef01 100644 --- a/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py +++ b/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py @@ -16,7 +16,7 @@ import os import sys from importlib.metadata import distribution -from typing import Any, List, Set +from typing import Any from zenml.entrypoints.step_entrypoint_configuration import ( StepEntrypointConfiguration, @@ -38,7 +38,7 @@ class DatabricksEntrypointConfiguration(StepEntrypointConfiguration): """ @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> set[str]: """Gets all options required for running with this configuration. 
Returns: @@ -54,7 +54,7 @@ def get_entrypoint_options(cls) -> Set[str]: def get_entrypoint_arguments( cls, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Gets all arguments that the entrypoint command should be called with. The argument list should be something that diff --git a/src/zenml/integrations/databricks/services/databricks_deployment.py b/src/zenml/integrations/databricks/services/databricks_deployment.py index 1bd3a699c54..020cf148c82 100644 --- a/src/zenml/integrations/databricks/services/databricks_deployment.py +++ b/src/zenml/integrations/databricks/services/databricks_deployment.py @@ -14,7 +14,8 @@ """Implementation of the Databricks Deployment service.""" import time -from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Union +from collections.abc import Generator import numpy as np import pandas as pd @@ -56,15 +57,15 @@ class DatabricksDeploymentConfig(DatabricksBaseConfig, ServiceConfig): """Databricks service configurations.""" - model_uri: Optional[str] = Field( + model_uri: str | None = Field( None, description="URI of the model to deploy. This can be a local path or a cloud storage path.", ) - host: Optional[str] = Field( + host: str | None = Field( None, description="Databricks host URL for the deployment." ) - def get_databricks_deployment_labels(self) -> Dict[str, str]: + def get_databricks_deployment_labels(self) -> dict[str, str]: """Generate labels for the Databricks deployment from the service configuration. These labels are attached to the Databricks deployment resource @@ -119,7 +120,7 @@ def __init__(self, config: DatabricksDeploymentConfig, **attrs: Any): """ super().__init__(config=config, **attrs) - def get_client_id_and_secret(self) -> Tuple[str, str, str]: + def get_client_id_and_secret(self) -> tuple[str, str, str]: """Get the Databricks client id and secret. 
Raises: @@ -159,7 +160,7 @@ def get_client_id_and_secret(self) -> Tuple[str, str, str]: raise ValueError("Host not found.") return host, client_id, client_secret - def _get_databricks_deployment_labels(self) -> Dict[str, str]: + def _get_databricks_deployment_labels(self) -> dict[str, str]: """Generate the labels for the Databricks deployment from the service configuration. Returns: @@ -195,7 +196,7 @@ def databricks_endpoint(self) -> ServingEndpointDetailed: ) @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """The prediction URI exposed by the prediction service. Returns: @@ -246,7 +247,7 @@ def provision(self) -> None: "Failed to start Databricks inference endpoint service: No URL available, please check the Databricks console for more details." ) - def check_status(self) -> Tuple[ServiceState, str]: + def check_status(self) -> tuple[ServiceState, str]: """Check the the current operational state of the Databricks deployment. Returns: @@ -355,7 +356,7 @@ def predict( return np.array(response.json()["predictions"]) def get_logs( - self, follow: bool = False, tail: Optional[int] = None + self, follow: bool = False, tail: int | None = None ) -> Generator[str, bool, None]: """Retrieve the service logs. 
@@ -385,8 +386,7 @@ def log_generator() -> Generator[str, bool, None]: log_lines = log_lines[-tail:] # Yield only new lines - for line in log_lines[last_log_count:]: - yield line + yield from log_lines[last_log_count:] last_log_count = len(log_lines) diff --git a/src/zenml/integrations/databricks/utils/databricks_utils.py b/src/zenml/integrations/databricks/utils/databricks_utils.py index 58b371c909a..29e68f5ac86 100644 --- a/src/zenml/integrations/databricks/utils/databricks_utils.py +++ b/src/zenml/integrations/databricks/utils/databricks_utils.py @@ -14,7 +14,6 @@ """Databricks utilities.""" import re -from typing import Dict, List, Optional from databricks.sdk.service.compute import Library, PythonPyPiLibrary from databricks.sdk.service.jobs import PythonWheelTask, TaskDependency @@ -26,11 +25,11 @@ def convert_step_to_task( task_name: str, command: str, - arguments: List[str], - libraries: Optional[List[str]] = None, - depends_on: Optional[List[str]] = None, - zenml_project_wheel: Optional[str] = None, - job_cluster_key: Optional[str] = None, + arguments: list[str], + libraries: list[str] | None = None, + depends_on: list[str] | None = None, + zenml_project_wheel: str | None = None, + job_cluster_key: str | None = None, ) -> DatabricksTask: """Convert a ZenML step to a Databricks task. @@ -69,7 +68,7 @@ def convert_step_to_task( ) -def sanitize_labels(labels: Dict[str, str]) -> None: +def sanitize_labels(labels: dict[str, str]) -> None: """Update the label values to be valid Kubernetes labels. See: diff --git a/src/zenml/integrations/deepchecks/__init__.py b/src/zenml/integrations/deepchecks/__init__.py index 99e69e90992..f858aebefd1 100644 --- a/src/zenml/integrations/deepchecks/__init__.py +++ b/src/zenml/integrations/deepchecks/__init__.py @@ -21,9 +21,8 @@ `SuiteResults`. 
""" -from typing import List, Type, Optional -from zenml.integrations.constants import DEEPCHECKS, PANDAS +from zenml.integrations.constants import DEEPCHECKS from zenml.integrations.integration import Integration from zenml.stack import Flavor @@ -59,8 +58,8 @@ def activate(cls) -> None: @classmethod def get_requirements( - cls, target_os: Optional[str] = None, python_version: Optional[str] = None - ) -> List[str]: + cls, target_os: str | None = None, python_version: str | None = None + ) -> list[str]: """Method to get the requirements for the integration. Args: @@ -76,7 +75,7 @@ def get_requirements( PandasIntegration.get_requirements(target_os=target_os, python_version=python_version) @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Deepchecks integration. Returns: diff --git a/src/zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py b/src/zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py index 60dfccc8d51..2b135b1680f 100644 --- a/src/zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py +++ b/src/zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py @@ -16,14 +16,8 @@ from typing import ( Any, ClassVar, - Dict, - List, - Optional, - Sequence, - Tuple, - Type, - Union, ) +from collections.abc import Sequence import pandas as pd from deepchecks.core.checks import BaseCheck @@ -61,14 +55,14 @@ class DeepchecksDataValidator(BaseDataValidator): """Deepchecks data validator stack component.""" NAME: ClassVar[str] = "Deepchecks" - FLAVOR: ClassVar[Type[BaseDataValidatorFlavor]] = ( + FLAVOR: ClassVar[type[BaseDataValidatorFlavor]] = ( DeepchecksDataValidatorFlavor ) @staticmethod def _split_checks( check_list: Sequence[str], - ) -> Tuple[Sequence[str], Sequence[str]]: + ) -> tuple[Sequence[str], Sequence[str]]: """Split a list of check identifiers in two lists, one for tabular and one 
for computer vision checks. Args: @@ -97,16 +91,16 @@ def _split_checks( @classmethod def _create_and_run_check_suite( cls, - check_enum: Type[DeepchecksValidationCheck], - reference_dataset: Union[pd.DataFrame, DataLoader[Any]], - comparison_dataset: Optional[ - Union[pd.DataFrame, DataLoader[Any]] - ] = None, - models: Optional[List[Union[ClassifierMixin, Module]]] = None, - check_list: Optional[Sequence[str]] = None, - dataset_kwargs: Dict[str, Any] = {}, - check_kwargs: Dict[str, Dict[str, Any]] = {}, - run_kwargs: Dict[str, Any] = {}, + check_enum: type[DeepchecksValidationCheck], + reference_dataset: pd.DataFrame | DataLoader[Any], + comparison_dataset: None | ( + pd.DataFrame | DataLoader[Any] + ) = None, + models: list[ClassifierMixin | Module] | None = None, + check_list: Sequence[str] | None = None, + dataset_kwargs: dict[str, Any] = {}, + check_kwargs: dict[str, dict[str, Any]] = {}, + run_kwargs: dict[str, Any] = {}, ) -> SuiteResult: """Create and run a Deepchecks check suite corresponding to the input parameters. @@ -171,7 +165,7 @@ def _create_and_run_check_suite( is_multi_model = True # if the models are of different types, raise an error # only the same type of models can be used for comparison - if len(set(type(model) for model in models)) > 1: + if len({type(model) for model in models}) > 1: raise TypeError( "Models used for comparison checks must be of the same type." 
) @@ -345,12 +339,12 @@ def _create_and_run_check_suite( def data_validation( self, - dataset: Union[pd.DataFrame, DataLoader[Any]], - comparison_dataset: Optional[Any] = None, - check_list: Optional[Sequence[str]] = None, - dataset_kwargs: Dict[str, Any] = {}, - check_kwargs: Dict[str, Dict[str, Any]] = {}, - run_kwargs: Dict[str, Any] = {}, + dataset: pd.DataFrame | DataLoader[Any], + comparison_dataset: Any | None = None, + check_list: Sequence[str] | None = None, + dataset_kwargs: dict[str, Any] = {}, + check_kwargs: dict[str, dict[str, Any]] = {}, + run_kwargs: dict[str, Any] = {}, **kwargs: Any, ) -> SuiteResult: """Run one or more Deepchecks data validation checks on a dataset. @@ -400,7 +394,7 @@ def data_validation( Returns: A Deepchecks SuiteResult with the results of the validation. """ - check_enum: Type[DeepchecksValidationCheck] + check_enum: type[DeepchecksValidationCheck] if comparison_dataset is None: check_enum = DeepchecksDataIntegrityCheck else: @@ -418,13 +412,13 @@ def data_validation( def model_validation( self, - dataset: Union[pd.DataFrame, DataLoader[Any]], - model: Union[ClassifierMixin, Module], - comparison_dataset: Optional[Any] = None, - check_list: Optional[Sequence[str]] = None, - dataset_kwargs: Dict[str, Any] = {}, - check_kwargs: Dict[str, Dict[str, Any]] = {}, - run_kwargs: Dict[str, Any] = {}, + dataset: pd.DataFrame | DataLoader[Any], + model: ClassifierMixin | Module, + comparison_dataset: Any | None = None, + check_list: Sequence[str] | None = None, + dataset_kwargs: dict[str, Any] = {}, + check_kwargs: dict[str, dict[str, Any]] = {}, + run_kwargs: dict[str, Any] = {}, **kwargs: Any, ) -> Any: """Run one or more Deepchecks model validation checks. @@ -475,7 +469,7 @@ def model_validation( Returns: A Deepchecks SuiteResult with the results of the validation. 
""" - check_enum: Type[DeepchecksValidationCheck] + check_enum: type[DeepchecksValidationCheck] if comparison_dataset is None: check_enum = DeepchecksModelValidationCheck else: diff --git a/src/zenml/integrations/deepchecks/flavors/deepchecks_data_validator_flavor.py b/src/zenml/integrations/deepchecks/flavors/deepchecks_data_validator_flavor.py index beacdd52467..130ccff78f4 100644 --- a/src/zenml/integrations/deepchecks/flavors/deepchecks_data_validator_flavor.py +++ b/src/zenml/integrations/deepchecks/flavors/deepchecks_data_validator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Deepchecks data validator flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.data_validators.base_data_validator import BaseDataValidatorFlavor from zenml.integrations.deepchecks import DEEPCHECKS_DATA_VALIDATOR_FLAVOR @@ -37,7 +37,7 @@ def name(self) -> str: return DEEPCHECKS_DATA_VALIDATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -46,7 +46,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -64,7 +64,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/data_validator/deepchecks.png" @property - def implementation_class(self) -> Type["DeepchecksDataValidator"]: + def implementation_class(self) -> type["DeepchecksDataValidator"]: """Implementation class. 
Returns: diff --git a/src/zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py b/src/zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py index 010839e6b76..7b5809403ba 100644 --- a/src/zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py +++ b/src/zenml/integrations/deepchecks/materializers/deepchecks_dataset_materializer.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of Deepchecks dataset materializer.""" -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar from deepchecks.tabular import Dataset @@ -29,10 +29,10 @@ class DeepchecksDatasetMaterializer(PandasMaterializer): """Materializer to read data to and from Deepchecks dataset.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Dataset,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (Dataset,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - def load(self, data_type: Type[Any]) -> Dataset: + def load(self, data_type: type[Any]) -> Dataset: """Reads pandas dataframes and creates `deepchecks.Dataset` from it. Args: @@ -54,7 +54,7 @@ def save(self, dataset: Dataset) -> None: def save_visualizations( self, dataset: Dataset - ) -> Dict[str, VisualizationType]: + ) -> dict[str, VisualizationType]: """Saves visualizations for the given Deepchecks dataset. Args: @@ -65,7 +65,7 @@ def save_visualizations( """ return super().save_visualizations(dataset.data) - def extract_metadata(self, dataset: Dataset) -> Dict[str, "MetadataType"]: + def extract_metadata(self, dataset: Dataset) -> dict[str, "MetadataType"]: """Extract metadata from the given `Dataset` object. 
Args: diff --git a/src/zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py b/src/zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py index 3ec06777636..9843021d294 100644 --- a/src/zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py +++ b/src/zenml/integrations/deepchecks/materializers/deepchecks_results_materializer.py @@ -15,7 +15,7 @@ """Implementation of Deepchecks suite results materializer.""" import os -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, ClassVar from deepchecks.core.check_result import CheckResult from deepchecks.core.suite import SuiteResult @@ -35,7 +35,7 @@ class DeepchecksResultMaterializer(BaseMaterializer): """Materializer to read data to and from CheckResult and SuiteResult objects.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = ( + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = ( CheckResult, SuiteResult, ) @@ -43,7 +43,7 @@ class DeepchecksResultMaterializer(BaseMaterializer): ArtifactType.DATA_ANALYSIS ) - def load(self, data_type: Type[Any]) -> Union[CheckResult, SuiteResult]: + def load(self, data_type: type[Any]) -> CheckResult | SuiteResult: """Reads a Deepchecks check or suite result from a serialized JSON file. Args: @@ -66,7 +66,7 @@ def load(self, data_type: Type[Any]) -> Union[CheckResult, SuiteResult]: raise RuntimeError(f"Unknown data type: {data_type}") return res - def save(self, result: Union[CheckResult, SuiteResult]) -> None: + def save(self, result: CheckResult | SuiteResult) -> None: """Creates a JSON serialization for a CheckResult or SuiteResult. 
Args: @@ -77,8 +77,8 @@ def save(self, result: Union[CheckResult, SuiteResult]) -> None: io_utils.write_file_contents_as_string(filepath, serialized_json) def save_visualizations( - self, result: Union[CheckResult, SuiteResult] - ) -> Dict[str, VisualizationType]: + self, result: CheckResult | SuiteResult + ) -> dict[str, VisualizationType]: """Saves visualizations for the given Deepchecks result. Args: @@ -94,8 +94,8 @@ def save_visualizations( return {visualization_path: VisualizationType.HTML} def extract_metadata( - self, result: Union[CheckResult, SuiteResult] - ) -> Dict[str, "MetadataType"]: + self, result: CheckResult | SuiteResult + ) -> dict[str, "MetadataType"]: """Extract metadata from the given Deepchecks result. Args: diff --git a/src/zenml/integrations/deepchecks/steps/__init__.py b/src/zenml/integrations/deepchecks/steps/__init__.py index 28dee8c31e0..adb79bed1d1 100644 --- a/src/zenml/integrations/deepchecks/steps/__init__.py +++ b/src/zenml/integrations/deepchecks/steps/__init__.py @@ -13,15 +13,3 @@ # permissions and limitations under the License. """Initialization of the Deepchecks Standard Steps.""" -from zenml.integrations.deepchecks.steps.deepchecks_data_drift import ( - deepchecks_data_drift_check_step, -) -from zenml.integrations.deepchecks.steps.deepchecks_data_integrity import ( - deepchecks_data_integrity_check_step, -) -from zenml.integrations.deepchecks.steps.deepchecks_model_drift import ( - deepchecks_model_drift_check_step, -) -from zenml.integrations.deepchecks.steps.deepchecks_model_validation import ( - deepchecks_model_validation_check_step, -) diff --git a/src/zenml/integrations/deepchecks/steps/deepchecks_data_drift.py b/src/zenml/integrations/deepchecks/steps/deepchecks_data_drift.py index fbe53d82456..c5e7454eddf 100644 --- a/src/zenml/integrations/deepchecks/steps/deepchecks_data_drift.py +++ b/src/zenml/integrations/deepchecks/steps/deepchecks_data_drift.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. 
"""Implementation of the Deepchecks data drift validation step.""" -from typing import Any, Dict, Optional, Sequence, cast +from typing import Any, cast +from collections.abc import Sequence import pandas as pd from deepchecks.core.suite import SuiteResult @@ -31,10 +32,10 @@ def deepchecks_data_drift_check_step( reference_dataset: pd.DataFrame, target_dataset: pd.DataFrame, - check_list: Optional[Sequence[DeepchecksDataDriftCheck]] = None, - dataset_kwargs: Optional[Dict[str, Any]] = None, - check_kwargs: Optional[Dict[str, Any]] = None, - run_kwargs: Optional[Dict[str, Any]] = None, + check_list: Sequence[DeepchecksDataDriftCheck] | None = None, + dataset_kwargs: dict[str, Any] | None = None, + check_kwargs: dict[str, Any] | None = None, + run_kwargs: dict[str, Any] | None = None, ) -> SuiteResult: """Run data drift checks on two pandas datasets using Deepchecks. diff --git a/src/zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py b/src/zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py index dd830014f71..30f4a1898dd 100644 --- a/src/zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py +++ b/src/zenml/integrations/deepchecks/steps/deepchecks_data_integrity.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. 
"""Implementation of the Deepchecks data integrity validation step.""" -from typing import Any, Dict, Optional, Sequence, cast +from typing import Any, Optional, cast +from collections.abc import Sequence import pandas as pd from deepchecks.core.suite import SuiteResult @@ -30,10 +31,10 @@ @step def deepchecks_data_integrity_check_step( dataset: pd.DataFrame, - check_list: Optional[Sequence[DeepchecksDataIntegrityCheck]] = None, - dataset_kwargs: Optional[Dict[str, Any]] = None, - check_kwargs: Optional[Dict[str, Any]] = None, - run_kwargs: Optional[Dict[str, Any]] = None, + check_list: Sequence[DeepchecksDataIntegrityCheck] | None = None, + dataset_kwargs: dict[str, Any] | None = None, + check_kwargs: dict[str, Any] | None = None, + run_kwargs: dict[str, Any] | None = None, ) -> SuiteResult: """Run data integrity checks on a pandas dataset using Deepchecks. diff --git a/src/zenml/integrations/deepchecks/steps/deepchecks_model_drift.py b/src/zenml/integrations/deepchecks/steps/deepchecks_model_drift.py index eafe8a3fc9b..fbe0da9f47e 100644 --- a/src/zenml/integrations/deepchecks/steps/deepchecks_model_drift.py +++ b/src/zenml/integrations/deepchecks/steps/deepchecks_model_drift.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. 
"""Implementation of the Deepchecks model drift validation step.""" -from typing import Any, Dict, Optional, Sequence, cast +from typing import Any, Optional, cast +from collections.abc import Sequence import pandas as pd from deepchecks.core.suite import SuiteResult @@ -33,10 +34,10 @@ def deepchecks_model_drift_check_step( reference_dataset: pd.DataFrame, target_dataset: pd.DataFrame, model: ClassifierMixin, - check_list: Optional[Sequence[DeepchecksModelDriftCheck]] = None, - dataset_kwargs: Optional[Dict[str, Any]] = None, - check_kwargs: Optional[Dict[str, Any]] = None, - run_kwargs: Optional[Dict[str, Any]] = None, + check_list: Sequence[DeepchecksModelDriftCheck] | None = None, + dataset_kwargs: dict[str, Any] | None = None, + check_kwargs: dict[str, Any] | None = None, + run_kwargs: dict[str, Any] | None = None, ) -> SuiteResult: """Run model drift checks on two pandas DataFrames and an sklearn model. diff --git a/src/zenml/integrations/deepchecks/steps/deepchecks_model_validation.py b/src/zenml/integrations/deepchecks/steps/deepchecks_model_validation.py index b9ebd80ad13..bf472f5bdd5 100644 --- a/src/zenml/integrations/deepchecks/steps/deepchecks_model_validation.py +++ b/src/zenml/integrations/deepchecks/steps/deepchecks_model_validation.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. 
"""Implementation of the Deepchecks model validation validation step.""" -from typing import Any, Dict, Optional, Sequence, cast +from typing import Any, Optional, cast +from collections.abc import Sequence import pandas as pd from deepchecks.core.suite import SuiteResult @@ -32,10 +33,10 @@ def deepchecks_model_validation_check_step( dataset: pd.DataFrame, model: ClassifierMixin, - check_list: Optional[Sequence[DeepchecksModelValidationCheck]] = None, - dataset_kwargs: Optional[Dict[str, Any]] = None, - check_kwargs: Optional[Dict[str, Any]] = None, - run_kwargs: Optional[Dict[str, Any]] = None, + check_list: Sequence[DeepchecksModelValidationCheck] | None = None, + dataset_kwargs: dict[str, Any] | None = None, + check_kwargs: dict[str, Any] | None = None, + run_kwargs: dict[str, Any] | None = None, ) -> SuiteResult: """Run model validation checks on a pandas DataFrame and an sklearn model. diff --git a/src/zenml/integrations/deepchecks/validation_checks.py b/src/zenml/integrations/deepchecks/validation_checks.py index b381ca03dfb..3cf02eee6fc 100644 --- a/src/zenml/integrations/deepchecks/validation_checks.py +++ b/src/zenml/integrations/deepchecks/validation_checks.py @@ -14,7 +14,6 @@ """Definition of the Deepchecks validation check types.""" import re -from typing import Type import deepchecks.tabular.checks as tabular_checks import deepchecks.vision.checks as vision_checks @@ -97,7 +96,7 @@ def is_vision_check(cls, check_name: str) -> bool: return check_name.startswith("deepchecks.vision.") @classmethod - def get_check_class(cls, check_name: str) -> Type[BaseCheck]: + def get_check_class(cls, check_name: str) -> type[BaseCheck]: """Get the Deepchecks check class associated with an enum value or a custom check name. 
Args: @@ -118,7 +117,7 @@ def get_check_class(cls, check_name: str) -> Type[BaseCheck]: cls.validate_check_name(check_name) try: - check_class: Type[BaseCheck] = ( + check_class: type[BaseCheck] = ( source_utils.load_and_validate_class( check_name, expected_class=BaseCheck ) @@ -139,7 +138,7 @@ def get_check_class(cls, check_name: str) -> Type[BaseCheck]: return check_class @property - def check_class(self) -> Type[BaseCheck]: + def check_class(self) -> type[BaseCheck]: """Convert the enum value to a valid Deepchecks check class. Returns: diff --git a/src/zenml/integrations/discord/__init__.py b/src/zenml/integrations/discord/__init__.py index 7f8e213255d..5f687936bb4 100644 --- a/src/zenml/integrations/discord/__init__.py +++ b/src/zenml/integrations/discord/__init__.py @@ -13,9 +13,7 @@ # permissions and limitations under the License. """Discord integration for alerter components.""" -from typing import List, Type -from zenml.enums import StackComponentType from zenml.integrations.constants import DISCORD from zenml.integrations.integration import Integration from zenml.stack import Flavor @@ -34,7 +32,7 @@ class DiscordIntegration(Integration): REQUIREMENTS_IGNORED_ON_UNINSTALL = ["aiohttp","asyncio"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Discord integration. 
Returns: diff --git a/src/zenml/integrations/discord/alerters/discord_alerter.py b/src/zenml/integrations/discord/alerters/discord_alerter.py index db97476259a..09c1713d7ca 100644 --- a/src/zenml/integrations/discord/alerters/discord_alerter.py +++ b/src/zenml/integrations/discord/alerters/discord_alerter.py @@ -14,7 +14,7 @@ """Implementation for discord flavor of alerter component.""" import asyncio -from typing import List, Optional, cast +from typing import cast from discord import Client, DiscordException, Embed, Intents, Message from pydantic import BaseModel @@ -35,24 +35,24 @@ class DiscordAlerterPayload(BaseModel): """Discord alerter payload implementation.""" - pipeline_name: Optional[str] = None - step_name: Optional[str] = None - stack_name: Optional[str] = None + pipeline_name: str | None = None + step_name: str | None = None + stack_name: str | None = None class DiscordAlerterParameters(BaseAlerterStepParameters): """Discord alerter parameters.""" # The ID of the Discord channel to use for communication. - discord_channel_id: Optional[str] = None + discord_channel_id: str | None = None # Set of messages that lead to approval in alerter.ask() - approve_msg_options: Optional[List[str]] = None + approve_msg_options: list[str] | None = None # Set of messages that lead to disapproval in alerter.ask() - disapprove_msg_options: Optional[List[str]] = None - payload: Optional[DiscordAlerterPayload] = None - include_format_blocks: Optional[bool] = True + disapprove_msg_options: list[str] | None = None + payload: DiscordAlerterPayload | None = None + include_format_blocks: bool | None = True class DiscordAlerter(BaseAlerter): @@ -68,7 +68,7 @@ def config(self) -> DiscordAlerterConfig: return cast(DiscordAlerterConfig, self._config) def _get_channel_id( - self, params: Optional[BaseAlerterStepParameters] = None + self, params: BaseAlerterStepParameters | None = None ) -> str: """Get the Discord channel ID to be used by post/ask. 
@@ -103,8 +103,8 @@ def _get_channel_id( ) def _get_approve_msg_options( - self, params: Optional[BaseAlerterStepParameters] - ) -> List[str]: + self, params: BaseAlerterStepParameters | None + ) -> list[str]: """Define which messages will lead to approval during ask(). Args: @@ -122,8 +122,8 @@ def _get_approve_msg_options( return DEFAULT_APPROVE_MSG_OPTIONS def _get_disapprove_msg_options( - self, params: Optional[BaseAlerterStepParameters] - ) -> List[str]: + self, params: BaseAlerterStepParameters | None + ) -> list[str]: """Define which messages will lead to disapproval during ask(). Args: @@ -141,8 +141,8 @@ def _get_disapprove_msg_options( return DEFAULT_DISAPPROVE_MSG_OPTIONS def _create_blocks( - self, message: str, params: Optional[BaseAlerterStepParameters] - ) -> Optional[Embed]: + self, message: str, params: BaseAlerterStepParameters | None + ) -> Embed | None: """Helper function to create discord blocks. Args: @@ -220,7 +220,7 @@ def start_client(self, client: Client) -> None: loop.close() def post( - self, message: str, params: Optional[BaseAlerterStepParameters] = None + self, message: str, params: BaseAlerterStepParameters | None = None ) -> bool: """Post a message to a Discord channel. @@ -265,7 +265,7 @@ async def on_ready() -> None: return message_sent def ask( - self, message: str, params: Optional[BaseAlerterStepParameters] = None + self, message: str, params: BaseAlerterStepParameters | None = None ) -> bool: """Post a message to a Discord channel and wait for approval. diff --git a/src/zenml/integrations/discord/flavors/discord_alerter_flavor.py b/src/zenml/integrations/discord/flavors/discord_alerter_flavor.py index f40e423603e..76563e9c0fc 100644 --- a/src/zenml/integrations/discord/flavors/discord_alerter_flavor.py +++ b/src/zenml/integrations/discord/flavors/discord_alerter_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Discord alerter flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.alerter.base_alerter import BaseAlerterConfig, BaseAlerterFlavor from zenml.integrations.discord import DISCORD_ALERTER_FLAVOR @@ -37,7 +37,7 @@ class DiscordAlerterConfig(BaseAlerterConfig): """ discord_token: str = SecretField() - default_discord_channel_id: Optional[str] = None # TODO: Potential setting + default_discord_channel_id: str | None = None # TODO: Potential setting @property def is_valid(self) -> bool: @@ -96,7 +96,7 @@ def name(self) -> str: return DISCORD_ALERTER_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -105,7 +105,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -123,7 +123,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/alerter/discord.png" @property - def config_class(self) -> Type[DiscordAlerterConfig]: + def config_class(self) -> type[DiscordAlerterConfig]: """Returns `DiscordAlerterConfig` config class. Returns: @@ -132,7 +132,7 @@ def config_class(self) -> Type[DiscordAlerterConfig]: return DiscordAlerterConfig @property - def implementation_class(self) -> Type["DiscordAlerter"]: + def implementation_class(self) -> type["DiscordAlerter"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/discord/steps/discord_alerter_ask_step.py b/src/zenml/integrations/discord/steps/discord_alerter_ask_step.py index 57e46e622e3..9399958a09a 100644 --- a/src/zenml/integrations/discord/steps/discord_alerter_ask_step.py +++ b/src/zenml/integrations/discord/steps/discord_alerter_ask_step.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Step that allows you to send messages to Discord and wait for a response.""" -from typing import Optional from zenml import get_step_context, step from zenml.client import Client @@ -27,7 +26,7 @@ @step def discord_alerter_ask_step( message: str, - params: Optional[DiscordAlerterParameters] = None, + params: DiscordAlerterParameters | None = None, ) -> bool: """Posts a message to the Discord alerter component and waits for approval. diff --git a/src/zenml/integrations/discord/steps/discord_alerter_post_step.py b/src/zenml/integrations/discord/steps/discord_alerter_post_step.py index 9b2258af49a..4d4afcb9c8b 100644 --- a/src/zenml/integrations/discord/steps/discord_alerter_post_step.py +++ b/src/zenml/integrations/discord/steps/discord_alerter_post_step.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Step that allows you to post messages to Discord.""" -from typing import Optional from zenml import get_step_context, step from zenml.client import Client @@ -27,7 +26,7 @@ @step def discord_alerter_post_step( message: str, - params: Optional[DiscordAlerterParameters] = None, + params: DiscordAlerterParameters | None = None, ) -> bool: """Post a message to the Discord alerter component of the active stack. 
diff --git a/src/zenml/integrations/evidently/__init__.py b/src/zenml/integrations/evidently/__init__.py index 13a6bf01d86..8e480765e8d 100644 --- a/src/zenml/integrations/evidently/__init__.py +++ b/src/zenml/integrations/evidently/__init__.py @@ -25,7 +25,6 @@ import logging import os import warnings -from typing import List, Type, Optional from zenml.integrations.constants import EVIDENTLY from zenml.integrations.integration import Integration @@ -62,8 +61,8 @@ class EvidentlyIntegration(Integration): @classmethod def get_requirements( - cls, target_os: Optional[str] = None, python_version: Optional[str] = None - ) -> List[str]: + cls, target_os: str | None = None, python_version: str | None = None + ) -> list[str]: """Method to get the requirements for the integration. Args: @@ -79,7 +78,7 @@ def get_requirements( PandasIntegration.get_requirements(target_os=target_os, python_version=python_version) @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Great Expectations integration. Returns: diff --git a/src/zenml/integrations/evidently/column_mapping.py b/src/zenml/integrations/evidently/column_mapping.py index 37f9dd505d0..f074121b613 100644 --- a/src/zenml/integrations/evidently/column_mapping.py +++ b/src/zenml/integrations/evidently/column_mapping.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""ZenML representation of an Evidently column mapping.""" -from typing import List, Optional, Sequence, Union +from collections.abc import Sequence from evidently import ColumnMapping # type: ignore[import-untyped] from pydantic import BaseModel, ConfigDict, Field @@ -40,21 +40,21 @@ class EvidentlyColumnMapping(BaseModel): text_features: text features """ - target: Optional[str] = None - prediction: Optional[Union[str, Sequence[str]]] = Field( + target: str | None = None + prediction: str | Sequence[str] | None = Field( default="prediction", union_mode="left_to_right" ) - datetime: Optional[str] = None - id: Optional[str] = None - numerical_features: Optional[List[str]] = None - categorical_features: Optional[List[str]] = None - datetime_features: Optional[List[str]] = None - target_names: Optional[List[str]] = None - task: Optional[str] = None - pos_label: Optional[Union[str, int]] = Field( + datetime: str | None = None + id: str | None = None + numerical_features: list[str] | None = None + categorical_features: list[str] | None = None + datetime_features: list[str] | None = None + target_names: list[str] | None = None + task: str | None = None + pos_label: str | int | None = Field( default=1, union_mode="left_to_right" ) - text_features: Optional[List[str]] = None + text_features: list[str] | None = None model_config = ConfigDict( validate_assignment=True, diff --git a/src/zenml/integrations/evidently/data_validators/evidently_data_validator.py b/src/zenml/integrations/evidently/data_validators/evidently_data_validator.py index da3b1c6869d..f544b3d9057 100644 --- a/src/zenml/integrations/evidently/data_validators/evidently_data_validator.py +++ b/src/zenml/integrations/evidently/data_validators/evidently_data_validator.py @@ -17,12 +17,8 @@ from typing import ( Any, ClassVar, - Dict, - Optional, - Sequence, - Tuple, - Type, ) +from collections.abc import Sequence import pandas as pd from evidently.pipeline.column_mapping import ColumnMapping # type: ignore @@ 
-45,13 +41,13 @@ class EvidentlyDataValidator(BaseDataValidator): """Evidently data validator stack component.""" NAME: ClassVar[str] = "Evidently" - FLAVOR: ClassVar[Type[BaseDataValidatorFlavor]] = ( + FLAVOR: ClassVar[type[BaseDataValidatorFlavor]] = ( EvidentlyDataValidatorFlavor ) @classmethod def _unpack_options( - cls, option_list: Sequence[Tuple[str, Dict[str, Any]]] + cls, option_list: Sequence[tuple[str, dict[str, Any]]] ) -> Sequence[Any]: """Unpack Evidently options. @@ -165,10 +161,10 @@ def _download_nltk_data() -> None: def data_profiling( self, dataset: pd.DataFrame, - comparison_dataset: Optional[pd.DataFrame] = None, - profile_list: Optional[Sequence[EvidentlyMetricConfig]] = None, - column_mapping: Optional[ColumnMapping] = None, - report_options: Sequence[Tuple[str, Dict[str, Any]]] = [], + comparison_dataset: pd.DataFrame | None = None, + profile_list: Sequence[EvidentlyMetricConfig] | None = None, + column_mapping: ColumnMapping | None = None, + report_options: Sequence[tuple[str, dict[str, Any]]] = [], download_nltk_data: bool = False, **kwargs: Any, ) -> Report: @@ -235,10 +231,10 @@ def data_profiling( def data_validation( self, dataset: Any, - comparison_dataset: Optional[Any] = None, - check_list: Optional[Sequence[EvidentlyTestConfig]] = None, - test_options: Sequence[Tuple[str, Dict[str, Any]]] = [], - column_mapping: Optional[ColumnMapping] = None, + comparison_dataset: Any | None = None, + check_list: Sequence[EvidentlyTestConfig] | None = None, + test_options: Sequence[tuple[str, dict[str, Any]]] = [], + column_mapping: ColumnMapping | None = None, download_nltk_data: bool = False, **kwargs: Any, ) -> TestSuite: diff --git a/src/zenml/integrations/evidently/flavors/evidently_data_validator_flavor.py b/src/zenml/integrations/evidently/flavors/evidently_data_validator_flavor.py index 5c9ced7072f..4294bd0c763 100644 --- a/src/zenml/integrations/evidently/flavors/evidently_data_validator_flavor.py +++ 
b/src/zenml/integrations/evidently/flavors/evidently_data_validator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Evidently data validator flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.data_validators.base_data_validator import BaseDataValidatorFlavor from zenml.integrations.evidently import EVIDENTLY_DATA_VALIDATOR_FLAVOR @@ -37,7 +37,7 @@ def name(self) -> str: return EVIDENTLY_DATA_VALIDATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -46,7 +46,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -64,7 +64,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/data_validator/evidently.png" @property - def implementation_class(self) -> Type["EvidentlyDataValidator"]: + def implementation_class(self) -> type["EvidentlyDataValidator"]: """Implementation class. 
Returns: diff --git a/src/zenml/integrations/evidently/metrics.py b/src/zenml/integrations/evidently/metrics.py index 0ecc74bc99a..5f792d9e89d 100644 --- a/src/zenml/integrations/evidently/metrics.py +++ b/src/zenml/integrations/evidently/metrics.py @@ -15,11 +15,6 @@ from typing import ( Any, - Dict, - List, - Optional, - Type, - Union, ) from evidently import metric_preset, metrics # type: ignore[import-untyped] @@ -73,15 +68,15 @@ class EvidentlyMetricConfig(BaseModel): """ class_path: str - parameters: Dict[str, Any] = Field(default_factory=dict) + parameters: dict[str, Any] = Field(default_factory=dict) is_generator: bool = False - columns: Optional[Union[str, List[str]]] = Field( + columns: str | list[str] | None = Field( default=None, union_mode="left_to_right" ) skip_id_column: bool = False @staticmethod - def get_metric_class(metric_name: str) -> Union[Metric, MetricPreset]: + def get_metric_class(metric_name: str) -> Metric | MetricPreset: """Get the Evidently metric or metric preset class from a string. Args: @@ -129,8 +124,8 @@ def get_metric_class(metric_name: str) -> Union[Metric, MetricPreset]: @classmethod def metric_generator( cls, - metric: Union[Type[Metric], str], - columns: Optional[Union[str, List[str]]] = None, + metric: type[Metric] | str, + columns: str | list[str] | None = None, skip_id_column: bool = False, **parameters: Any, ) -> "EvidentlyMetricConfig": @@ -230,7 +225,7 @@ def metric_generator( @classmethod def metric( cls, - metric: Union[Type[Metric], Type[MetricPreset], str], + metric: type[Metric] | type[MetricPreset] | str, **parameters: Any, ) -> "EvidentlyMetricConfig": """Create a declarative configuration for an Evidently Metric. @@ -302,7 +297,7 @@ class path. return config @classmethod - def default_metrics(cls) -> List["EvidentlyMetricConfig"]: + def default_metrics(cls) -> list["EvidentlyMetricConfig"]: """Default Evidently metric configurations. 
Call this to fetch a default list of Evidently metrics to use in cases @@ -324,7 +319,7 @@ def default_metrics(cls) -> List["EvidentlyMetricConfig"]: def to_evidently_metric( self, - ) -> Union[Metric, MetricPreset, BaseGenerator]: + ) -> Metric | MetricPreset | BaseGenerator: """Create an Evidently Metric, MetricPreset or metric generator object. Call this method to create an Evidently Metric, MetricPreset or metric diff --git a/src/zenml/integrations/evidently/steps/__init__.py b/src/zenml/integrations/evidently/steps/__init__.py index baff052f4e1..1df7744841e 100644 --- a/src/zenml/integrations/evidently/steps/__init__.py +++ b/src/zenml/integrations/evidently/steps/__init__.py @@ -13,10 +13,3 @@ # permissions and limitations under the License. """Initialization of the Evidently Standard Steps.""" -from zenml.integrations.evidently.column_mapping import EvidentlyColumnMapping -from zenml.integrations.evidently.steps.evidently_report import ( - evidently_report_step, -) -from zenml.integrations.evidently.steps.evidently_test import ( - evidently_test_step, -) diff --git a/src/zenml/integrations/evidently/steps/evidently_report.py b/src/zenml/integrations/evidently/steps/evidently_report.py index 6d8fc38170e..a9bec27b7fd 100644 --- a/src/zenml/integrations/evidently/steps/evidently_report.py +++ b/src/zenml/integrations/evidently/steps/evidently_report.py @@ -13,10 +13,11 @@ # permissions and limitations under the License. 
"""Implementation of the Evidently Report Step.""" -from typing import Any, Dict, List, Optional, Sequence, Tuple, cast +from typing import Any, cast +from collections.abc import Sequence import pandas as pd -from typing_extensions import Annotated +from typing import Annotated from zenml import step from zenml.integrations.evidently.column_mapping import EvidentlyColumnMapping @@ -31,13 +32,13 @@ @step def evidently_report_step( reference_dataset: pd.DataFrame, - comparison_dataset: Optional[pd.DataFrame] = None, - column_mapping: Optional[EvidentlyColumnMapping] = None, - ignored_cols: Optional[List[str]] = None, - metrics: Optional[List[EvidentlyMetricConfig]] = None, - report_options: Optional[Sequence[Tuple[str, Dict[str, Any]]]] = None, + comparison_dataset: pd.DataFrame | None = None, + column_mapping: EvidentlyColumnMapping | None = None, + ignored_cols: list[str] | None = None, + metrics: list[EvidentlyMetricConfig] | None = None, + report_options: Sequence[tuple[str, dict[str, Any]]] | None = None, download_nltk_data: bool = False, -) -> Tuple[ +) -> tuple[ Annotated[str, "report_json"], Annotated[HTMLString, "report_html"] ]: """Generate an Evidently report on one or two pandas datasets. diff --git a/src/zenml/integrations/evidently/steps/evidently_test.py b/src/zenml/integrations/evidently/steps/evidently_test.py index 1449fd0c51d..4d38d9d33ca 100644 --- a/src/zenml/integrations/evidently/steps/evidently_test.py +++ b/src/zenml/integrations/evidently/steps/evidently_test.py @@ -13,10 +13,11 @@ # permissions and limitations under the License. 
"""Implementation of the Evidently Test Step.""" -from typing import Any, Dict, List, Optional, Sequence, Tuple, cast +from typing import Any, cast +from collections.abc import Sequence import pandas as pd -from typing_extensions import Annotated +from typing import Annotated from zenml import step from zenml.integrations.evidently.column_mapping import ( @@ -30,13 +31,13 @@ @step def evidently_test_step( reference_dataset: pd.DataFrame, - comparison_dataset: Optional[pd.DataFrame], - column_mapping: Optional[EvidentlyColumnMapping] = None, - ignored_cols: Optional[List[str]] = None, - tests: Optional[List[EvidentlyTestConfig]] = None, - test_options: Optional[Sequence[Tuple[str, Dict[str, Any]]]] = None, + comparison_dataset: pd.DataFrame | None, + column_mapping: EvidentlyColumnMapping | None = None, + ignored_cols: list[str] | None = None, + tests: list[EvidentlyTestConfig] | None = None, + test_options: Sequence[tuple[str, dict[str, Any]]] | None = None, download_nltk_data: bool = False, -) -> Tuple[Annotated[str, "test_json"], Annotated[HTMLString, "test_html"]]: +) -> tuple[Annotated[str, "test_json"], Annotated[HTMLString, "test_html"]]: """Run an Evidently test suite on one or two pandas datasets. 
Args: diff --git a/src/zenml/integrations/evidently/tests.py b/src/zenml/integrations/evidently/tests.py index 4478ee4ad69..92b55ca043f 100644 --- a/src/zenml/integrations/evidently/tests.py +++ b/src/zenml/integrations/evidently/tests.py @@ -15,11 +15,6 @@ from typing import ( Any, - Dict, - List, - Optional, - Type, - Union, ) from evidently import test_preset, tests # type: ignore[import-untyped] @@ -71,14 +66,14 @@ class EvidentlyTestConfig(BaseModel): """ class_path: str - parameters: Dict[str, Any] = Field(default_factory=dict) + parameters: dict[str, Any] = Field(default_factory=dict) is_generator: bool = False - columns: Optional[Union[str, List[str]]] = Field( + columns: str | list[str] | None = Field( default=None, union_mode="left_to_right" ) @staticmethod - def get_test_class(test_name: str) -> Union[Test, TestPreset]: + def get_test_class(test_name: str) -> Test | TestPreset: """Get the Evidently test or test preset class from a string. Args: @@ -126,8 +121,8 @@ def get_test_class(test_name: str) -> Union[Test, TestPreset]: @classmethod def test_generator( cls, - test: Union[Type[Test], str], - columns: Optional[Union[str, List[str]]] = None, + test: type[Test] | str, + columns: str | list[str] | None = None, **parameters: Any, ) -> "EvidentlyTestConfig": """Create a declarative configuration for an Evidently column Test generator. @@ -223,7 +218,7 @@ def test_generator( @classmethod def test( cls, - test: Union[Type[Test], Type[TestPreset], str], + test: type[Test] | type[TestPreset] | str, **parameters: Any, ) -> "EvidentlyTestConfig": """Create a declarative configuration for an Evidently Test. @@ -294,7 +289,7 @@ class path. return config @classmethod - def default_tests(cls) -> List["EvidentlyTestConfig"]: + def default_tests(cls) -> list["EvidentlyTestConfig"]: """Default Evidently test configurations. 
Call this to fetch a default list of Evidently tests to use in cases @@ -309,7 +304,7 @@ def default_tests(cls) -> List["EvidentlyTestConfig"]: for test_preset_class_name in test_preset.__all__ ] - def to_evidently_test(self) -> Union[Test, TestPreset, BaseGenerator]: + def to_evidently_test(self) -> Test | TestPreset | BaseGenerator: """Create an Evidently Test, TestPreset or test generator object. Call this method to create an Evidently Test, TestPreset or test diff --git a/src/zenml/integrations/facets/__init__.py b/src/zenml/integrations/facets/__init__.py index cd06ca9e23a..aeb6eebef73 100644 --- a/src/zenml/integrations/facets/__init__.py +++ b/src/zenml/integrations/facets/__init__.py @@ -12,8 +12,7 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Facets integration for ZenML.""" -from typing import Optional, List -from zenml.integrations.constants import FACETS, PANDAS +from zenml.integrations.constants import FACETS from zenml.integrations.integration import Integration @@ -32,8 +31,8 @@ def activate(cls) -> None: @classmethod def get_requirements( - cls, target_os: Optional[str] = None, python_version: Optional[str] = None - ) -> List[str]: + cls, target_os: str | None = None, python_version: str | None = None + ) -> list[str]: """Method to get the requirements for the integration. 
Args: diff --git a/src/zenml/integrations/facets/materializers/facets_materializer.py b/src/zenml/integrations/facets/materializers/facets_materializer.py index 638040f0ce8..87ed6df5acd 100644 --- a/src/zenml/integrations/facets/materializers/facets_materializer.py +++ b/src/zenml/integrations/facets/materializers/facets_materializer.py @@ -15,7 +15,6 @@ import base64 import os -from typing import Dict from facets_overview.generic_feature_statistics_generator import ( GenericFeatureStatisticsGenerator, @@ -45,7 +44,7 @@ class FacetsMaterializer(BaseMaterializer): def save_visualizations( self, data: FacetsComparison - ) -> Dict[str, VisualizationType]: + ) -> dict[str, VisualizationType]: """Save a Facets visualization of the data. Args: diff --git a/src/zenml/integrations/facets/models.py b/src/zenml/integrations/facets/models.py index 0cac91e37d1..8657061956e 100644 --- a/src/zenml/integrations/facets/models.py +++ b/src/zenml/integrations/facets/models.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Models used by the Facets integration.""" -from typing import Dict, List, Union import pandas as pd from pydantic import BaseModel, ConfigDict @@ -30,5 +29,5 @@ class FacetsComparison(BaseModel): `[{"name": "dataset_name", "table": pd.DataFrame}, ...]`. """ - datasets: List[Dict[str, Union[str, pd.DataFrame]]] + datasets: list[dict[str, str | pd.DataFrame]] model_config = ConfigDict(arbitrary_types_allowed=True) diff --git a/src/zenml/integrations/facets/steps/facets_visualization_steps.py b/src/zenml/integrations/facets/steps/facets_visualization_steps.py index a8324bf22b8..fe5ac6dddca 100644 --- a/src/zenml/integrations/facets/steps/facets_visualization_steps.py +++ b/src/zenml/integrations/facets/steps/facets_visualization_steps.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Facets Standard Steps.""" -from typing import Dict, List, Union import pandas as pd @@ -44,7 +43,7 @@ def facets_visualization_step( @step def facets_list_visualization_step( - dataframes: List[pd.DataFrame], + dataframes: list[pd.DataFrame], ) -> FacetsComparison: """Visualize and compare dataset statistics with Facets. @@ -54,7 +53,7 @@ def facets_list_visualization_step( Returns: `FacetsComparison` object. """ - datasets: List[Dict[str, Union[str, pd.DataFrame]]] = [] + datasets: list[dict[str, str | pd.DataFrame]] = [] for i, df in enumerate(dataframes): datasets.append({"name": f"dataset_{i}", "table": df}) return FacetsComparison(datasets=datasets) @@ -62,7 +61,7 @@ def facets_list_visualization_step( @step def facets_dict_visualization_step( - dataframes: Dict[str, pd.DataFrame], + dataframes: dict[str, pd.DataFrame], ) -> FacetsComparison: """Visualize and compare dataset statistics with Facets. @@ -73,7 +72,7 @@ def facets_dict_visualization_step( Returns: `FacetsComparison` object. """ - datasets: List[Dict[str, Union[str, pd.DataFrame]]] = [] + datasets: list[dict[str, str | pd.DataFrame]] = [] for name, df in dataframes.items(): datasets.append({"name": name, "table": df}) return FacetsComparison(datasets=datasets) diff --git a/src/zenml/integrations/feast/__init__.py b/src/zenml/integrations/feast/__init__.py index 8b81c37ee7d..728ad3834d6 100644 --- a/src/zenml/integrations/feast/__init__.py +++ b/src/zenml/integrations/feast/__init__.py @@ -17,7 +17,6 @@ implements a dedicated stack component that you can access as part of your ZenML steps in the usual ways. 
""" -from typing import List, Type, Optional from zenml.integrations.constants import FEAST from zenml.integrations.integration import Integration @@ -35,7 +34,7 @@ class FeastIntegration(Integration): REQUIREMENTS_IGNORED_ON_UNINSTALL = ["click", "pandas"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Feast integration. Returns: @@ -46,8 +45,8 @@ def flavors(cls) -> List[Type[Flavor]]: return [FeastFeatureStoreFlavor] @classmethod - def get_requirements(cls, target_os: Optional[str] = None, python_version: Optional[str] = None - ) -> List[str]: + def get_requirements(cls, target_os: str | None = None, python_version: str | None = None + ) -> list[str]: """Method to get the requirements for the integration. Args: diff --git a/src/zenml/integrations/feast/feature_stores/feast_feature_store.py b/src/zenml/integrations/feast/feature_stores/feast_feature_store.py index ffd4e84340f..7420bb68e91 100644 --- a/src/zenml/integrations/feast/feature_stores/feast_feature_store.py +++ b/src/zenml/integrations/feast/feature_stores/feast_feature_store.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of the Feast Feature Store for ZenML.""" -from typing import Any, Dict, List, Union, cast +from typing import Any, cast import pandas as pd from feast import FeatureService, FeatureStore # type: ignore @@ -42,8 +42,8 @@ def config(self) -> FeastFeatureStoreConfig: def get_historical_features( self, - entity_df: Union[pd.DataFrame, str], - features: Union[List[str], FeatureService], + entity_df: pd.DataFrame | str, + features: list[str] | FeatureService, full_feature_names: bool = False, ) -> pd.DataFrame: """Returns the historical features for training or batch scoring. 
@@ -69,10 +69,10 @@ def get_historical_features( def get_online_features( self, - entity_rows: List[Dict[str, Any]], - features: Union[List[str], FeatureService], + entity_rows: list[dict[str, Any]], + features: list[str] | FeatureService, full_feature_names: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Returns the latest online feature data. Args: @@ -94,7 +94,7 @@ def get_online_features( full_feature_names=full_feature_names, ).to_dict() - def get_data_sources(self) -> List[str]: + def get_data_sources(self) -> list[str]: """Returns the data sources' names. Raise: @@ -106,7 +106,7 @@ def get_data_sources(self) -> List[str]: fs = FeatureStore(repo_path=self.config.feast_repo) return [ds.name for ds in fs.list_data_sources()] - def get_entities(self) -> List[str]: + def get_entities(self) -> list[str]: """Returns the entity names. Raise: @@ -118,7 +118,7 @@ def get_entities(self) -> List[str]: fs = FeatureStore(repo_path=self.config.feast_repo) return [ds.name for ds in fs.list_entities()] - def get_feature_services(self) -> List[FeatureService]: + def get_feature_services(self) -> list[FeatureService]: """Returns the feature services. Raise: @@ -128,13 +128,13 @@ def get_feature_services(self) -> List[FeatureService]: The feature services. """ fs = FeatureStore(repo_path=self.config.feast_repo) - feature_services: List[FeatureService] = list( + feature_services: list[FeatureService] = list( fs.list_feature_services() ) return feature_services - def get_feature_views(self) -> List[str]: + def get_feature_views(self) -> list[str]: """Returns the feature view names. 
Raise: diff --git a/src/zenml/integrations/feast/flavors/feast_feature_store_flavor.py b/src/zenml/integrations/feast/flavors/feast_feature_store_flavor.py index 2180d82c8fc..2ae5af6ea20 100644 --- a/src/zenml/integrations/feast/flavors/feast_feature_store_flavor.py +++ b/src/zenml/integrations/feast/flavors/feast_feature_store_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Feast feature store flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -70,7 +70,7 @@ def name(self) -> str: return FEAST_FEATURE_STORE_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -79,7 +79,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -97,7 +97,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/feature_store/feast.png" @property - def config_class(self) -> Type[FeastFeatureStoreConfig]: + def config_class(self) -> type[FeastFeatureStoreConfig]: """Returns FeastFeatureStoreConfig config class. Returns: @@ -107,7 +107,7 @@ def config_class(self) -> Type[FeastFeatureStoreConfig]: return FeastFeatureStoreConfig @property - def implementation_class(self) -> Type["FeastFeatureStore"]: + def implementation_class(self) -> type["FeastFeatureStore"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/gcp/__init__.py b/src/zenml/integrations/gcp/__init__.py index bfd37ea82ce..f40e1c21ae3 100644 --- a/src/zenml/integrations/gcp/__init__.py +++ b/src/zenml/integrations/gcp/__init__.py @@ -22,7 +22,6 @@ Vertex AI environment. 
""" -from typing import List, Type from zenml.integrations.constants import GCP from zenml.integrations.integration import Integration @@ -69,7 +68,7 @@ def activate(cls) -> None: from zenml.integrations.gcp import service_connectors # noqa @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the GCP integration. Returns: diff --git a/src/zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py b/src/zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py index 9d8a47f1ae8..3358b3421f3 100644 --- a/src/zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py +++ b/src/zenml/integrations/gcp/artifact_stores/gcp_artifact_store.py @@ -15,15 +15,10 @@ from typing import ( Any, - Callable, - Dict, - Iterable, - List, - Optional, - Tuple, Union, cast, ) +from collections.abc import Callable, Iterable import gcsfs from google.cloud import storage @@ -46,7 +41,7 @@ class GCPArtifactStore(BaseArtifactStore, AuthenticationMixin): """Artifact Store for Google Cloud Storage based artifacts.""" - _filesystem: Optional[gcsfs.GCSFileSystem] = None + _filesystem: gcsfs.GCSFileSystem | None = None @property def config(self) -> GCPArtifactStoreConfig: @@ -59,7 +54,7 @@ def config(self) -> GCPArtifactStoreConfig: def get_credentials( self, - ) -> Optional[Union[Dict[str, Any], gcp_credentials.Credentials]]: + ) -> dict[str, Any] | gcp_credentials.Credentials | None: """Returns the credentials for the GCP Artifact Store if configured. Returns: @@ -153,7 +148,7 @@ def exists(self, path: PathType) -> bool: """ return self.filesystem.exists(path=path) # type: ignore[no-any-return] - def glob(self, pattern: PathType) -> List[PathType]: + def glob(self, pattern: PathType) -> list[PathType]: """Return all paths that match the given glob pattern. 
The glob pattern may include: @@ -185,7 +180,7 @@ def isdir(self, path: PathType) -> bool: """ return self.filesystem.isdir(path=path) # type: ignore[no-any-return] - def listdir(self, path: PathType) -> List[PathType]: + def listdir(self, path: PathType) -> list[PathType]: """Return a list of files in a directory. Args: @@ -198,7 +193,7 @@ def listdir(self, path: PathType) -> List[PathType]: if path_without_prefix.startswith(GCP_PATH_PREFIX): path_without_prefix = path_without_prefix[len(GCP_PATH_PREFIX) :] - def _extract_basename(file_dict: Dict[str, Any]) -> str: + def _extract_basename(file_dict: dict[str, Any]) -> str: """Extracts the basename from a file info dict returned by GCP. Args: @@ -279,7 +274,7 @@ def rmtree(self, path: PathType) -> None: """ self.filesystem.delete(path=path, recursive=True) - def stat(self, path: PathType) -> Dict[str, Any]: + def stat(self, path: PathType) -> dict[str, Any]: """Return stat info for the given path. Args: @@ -305,8 +300,8 @@ def walk( self, top: PathType, topdown: bool = True, - onerror: Optional[Callable[..., None]] = None, - ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]: + onerror: Callable[..., None] | None = None, + ) -> Iterable[tuple[PathType, list[PathType], list[PathType]]]: """Return an iterator that walks the contents of the given directory. 
Args: diff --git a/src/zenml/integrations/gcp/deployers/gcp_deployer.py b/src/zenml/integrations/gcp/deployers/gcp_deployer.py index e8f48e0fae6..56c937856b3 100644 --- a/src/zenml/integrations/gcp/deployers/gcp_deployer.py +++ b/src/zenml/integrations/gcp/deployers/gcp_deployer.py @@ -18,14 +18,9 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - Generator, - List, - Optional, - Tuple, - Type, cast, ) +from collections.abc import Generator from uuid import UUID from google.api_core import exceptions as google_exceptions @@ -82,31 +77,31 @@ class CloudRunDeploymentMetadata(BaseModel): """Metadata for a Cloud Run deployment.""" - service_name: Optional[str] = None - service_url: Optional[str] = None - project_id: Optional[str] = None - location: Optional[str] = None - revision_name: Optional[str] = None - reconciling: Optional[bool] = None - service_status: Optional[Dict[str, Any]] = None - cpu: Optional[str] = None - memory: Optional[str] = None - min_instances: Optional[int] = None - max_instances: Optional[int] = None - concurrency: Optional[int] = None - timeout_seconds: Optional[int] = None - ingress: Optional[str] = None - vpc_connector: Optional[str] = None - service_account: Optional[str] = None - execution_environment: Optional[str] = None - port: Optional[int] = None - allow_unauthenticated: Optional[bool] = None - labels: Optional[Dict[str, str]] = None - annotations: Optional[Dict[str, str]] = None - traffic_allocation: Optional[Dict[str, int]] = None - created_time: Optional[str] = None - updated_time: Optional[str] = None - secrets: List[str] = [] + service_name: str | None = None + service_url: str | None = None + project_id: str | None = None + location: str | None = None + revision_name: str | None = None + reconciling: bool | None = None + service_status: dict[str, Any] | None = None + cpu: str | None = None + memory: str | None = None + min_instances: int | None = None + max_instances: int | None = None + concurrency: int | None = None + 
timeout_seconds: int | None = None + ingress: str | None = None + vpc_connector: str | None = None + service_account: str | None = None + execution_environment: str | None = None + port: int | None = None + allow_unauthenticated: bool | None = None + labels: dict[str, str] | None = None + annotations: dict[str, str] | None = None + traffic_allocation: dict[str, int] | None = None + created_time: str | None = None + updated_time: str | None = None + secrets: list[str] = [] @classmethod def from_cloud_run_service( @@ -114,7 +109,7 @@ def from_cloud_run_service( service: run_v2.Service, project_id: str, location: str, - secrets: List[secretmanager.Secret], + secrets: list[secretmanager.Secret], ) -> "CloudRunDeploymentMetadata": """Create metadata from a Cloud Run service. @@ -247,13 +242,13 @@ def from_deployment( class GCPDeployer(ContainerizedDeployer, GoogleCredentialsMixin): """Deployer responsible for deploying pipelines on GCP Cloud Run.""" - _credentials: Optional[Any] = None - _project_id: Optional[str] = None - _cloud_run_client: Optional[run_v2.ServicesClient] = None - _logging_client: Optional[LoggingClient] = None - _secret_manager_client: Optional[ + _credentials: Any | None = None + _project_id: str | None = None + _cloud_run_client: run_v2.ServicesClient | None = None + _logging_client: LoggingClient | None = None + _secret_manager_client: None | ( secretmanager.SecretManagerServiceClient - ] = None + ) = None @property def config(self) -> GCPDeployerConfig: @@ -265,7 +260,7 @@ def config(self) -> GCPDeployerConfig: return cast(GCPDeployerConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the GCP deployer. 
Returns: @@ -274,7 +269,7 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: return GCPDeployerSettings @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Ensures there is an image builder in the stack. Returns: @@ -287,7 +282,7 @@ def validator(self) -> Optional[StackValidator]: } ) - def _get_credentials_and_project_id(self) -> Tuple[Any, str]: + def _get_credentials_and_project_id(self) -> tuple[Any, str]: """Get GCP credentials and project ID. Returns: @@ -364,7 +359,7 @@ def secret_manager_client( def get_labels( self, deployment: DeploymentResponse, settings: GCPDeployerSettings - ) -> Dict[str, str]: + ) -> dict[str, str]: """Get the labels for a deployment. Args: @@ -563,7 +558,7 @@ def _create_or_update_secret( def _get_secrets( self, deployment: DeploymentResponse - ) -> List[secretmanager.Secret]: + ) -> list[secretmanager.Secret]: """Get the existing GCP Secret Manager secrets for a deployment. Args: @@ -574,7 +569,7 @@ def _get_secrets( deployment. """ metadata = CloudRunDeploymentMetadata.from_deployment(deployment) - secrets: List[secretmanager.Secret] = [] + secrets: list[secretmanager.Secret] = [] for secret_name in metadata.secrets: try: secret = self.secret_manager_client.get_secret( @@ -622,11 +617,11 @@ def _cleanup_deployment_secrets( def _prepare_environment_variables( self, deployment: DeploymentResponse, - environment: Dict[str, str], - secrets: Dict[str, str], + environment: dict[str, str], + secrets: dict[str, str], settings: GCPDeployerSettings, project_id: str, - ) -> Tuple[List[run_v2.EnvVar], List[secretmanager.Secret]]: + ) -> tuple[list[run_v2.EnvVar], list[secretmanager.Secret]]: """Prepare environment variables for Cloud Run, handling secrets appropriately. 
Args: @@ -647,7 +642,7 @@ def _prepare_environment_variables( for key, value in merged_env.items(): env_vars.append(run_v2.EnvVar(name=key, value=value)) - active_secrets: List[secretmanager.Secret] = [] + active_secrets: list[secretmanager.Secret] = [] if secrets: if settings.use_secret_manager: for key, value in secrets.items(): @@ -726,7 +721,7 @@ def _get_service_path( def _get_cloud_run_service( self, deployment: DeploymentResponse - ) -> Optional[run_v2.Service]: + ) -> run_v2.Service | None: """Get an existing Cloud Run service for a deployment. Args: @@ -762,7 +757,7 @@ def _get_service_operational_state( service: run_v2.Service, project_id: str, location: str, - secrets: List[secretmanager.Secret], + secrets: list[secretmanager.Secret], ) -> DeploymentOperationalState: """Get the operational state of a Cloud Run service. @@ -812,7 +807,7 @@ def _get_service_operational_state( def _convert_resource_settings_to_gcp_format( self, resource_settings: ResourceSettings, - ) -> Tuple[str, str]: + ) -> tuple[str, str]: """Convert ResourceSettings to GCP Cloud Run resource format. GCP Cloud Run CPU constraints: @@ -865,8 +860,8 @@ def _convert_resource_settings_to_gcp_format( return str(cpu), memory def _validate_memory_for_cpu( - self, cpu: str, memory_gib: Optional[float] - ) -> Optional[float]: + self, cpu: str, memory_gib: float | None + ) -> float | None: """Validate and adjust memory allocation based on CPU requirements. GCP Cloud Run has minimum memory requirements per CPU configuration: @@ -908,7 +903,7 @@ def _validate_memory_for_cpu( def _convert_scaling_settings_to_gcp_format( self, resource_settings: ResourceSettings, - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """Convert ResourceSettings scaling to GCP Cloud Run format. 
Args: @@ -954,8 +949,8 @@ def do_provision_deployment( self, deployment: DeploymentResponse, stack: "Stack", - environment: Dict[str, str], - secrets: Dict[str, str], + environment: dict[str, str], + secrets: dict[str, str], timeout: int, ) -> DeploymentOperationalState: """Serve a pipeline as a Cloud Run service. @@ -1226,7 +1221,7 @@ def do_get_deployment_state_logs( self, deployment: DeploymentResponse, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> Generator[str, bool, None]: """Get the logs of a Cloud Run deployment. @@ -1307,7 +1302,7 @@ def do_deprovision_deployment( self, deployment: DeploymentResponse, timeout: int, - ) -> Optional[DeploymentOperationalState]: + ) -> DeploymentOperationalState | None: """Deprovision a Cloud Run deployment. Args: diff --git a/src/zenml/integrations/gcp/experiment_trackers/vertex_experiment_tracker.py b/src/zenml/integrations/gcp/experiment_trackers/vertex_experiment_tracker.py index e6bdb2e914d..2d300f331d2 100644 --- a/src/zenml/integrations/gcp/experiment_trackers/vertex_experiment_tracker.py +++ b/src/zenml/integrations/gcp/experiment_trackers/vertex_experiment_tracker.py @@ -14,7 +14,7 @@ """Implementation of the VertexAI experiment tracker for ZenML.""" import re -from typing import TYPE_CHECKING, Dict, Optional, Type, cast +from typing import TYPE_CHECKING, cast from google.api_core import exceptions from google.cloud import aiplatform @@ -54,7 +54,7 @@ def config(self) -> VertexExperimentTrackerConfig: return cast(VertexExperimentTrackerConfig, self._config) @property - def settings_class(self) -> Type[VertexExperimentTrackerSettings]: + def settings_class(self) -> type[VertexExperimentTrackerSettings]: """Returns the `BaseSettings` settings class. 
Returns: @@ -74,7 +74,7 @@ def prepare_step_run(self, info: "StepRunInfo") -> None: def get_step_run_metadata( self, info: "StepRunInfo" - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get component- and step-specific metadata after a step ran. Args: @@ -138,7 +138,7 @@ def _get_dashboard_url(self, experiment: str) -> str: resource = aiplatform.Experiment(experiment_name=experiment) return cast(str, resource.dashboard_url) - def _get_tensorboard_resource_name(self, experiment: str) -> Optional[str]: + def _get_tensorboard_resource_name(self, experiment: str) -> str | None: resource = aiplatform.Experiment( experiment_name=experiment ).get_backing_tensorboard_resource() diff --git a/src/zenml/integrations/gcp/flavors/gcp_artifact_store_flavor.py b/src/zenml/integrations/gcp/flavors/gcp_artifact_store_flavor.py index 41e9d9a09e7..000e8e1987d 100644 --- a/src/zenml/integrations/gcp/flavors/gcp_artifact_store_flavor.py +++ b/src/zenml/integrations/gcp/flavors/gcp_artifact_store_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """GCP artifact store flavor.""" -from typing import TYPE_CHECKING, ClassVar, Optional, Set, Type +from typing import TYPE_CHECKING, ClassVar from zenml.artifact_stores import ( BaseArtifactStoreConfig, @@ -35,7 +35,7 @@ class GCPArtifactStoreConfig( ): """Configuration for GCP Artifact Store.""" - SUPPORTED_SCHEMES: ClassVar[Set[str]] = {GCP_PATH_PREFIX} + SUPPORTED_SCHEMES: ClassVar[set[str]] = {GCP_PATH_PREFIX} IS_IMMUTABLE_FILESYSTEM: ClassVar[bool] = True @@ -54,7 +54,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. 
Specifies resource requirements that are used to filter the available @@ -70,7 +70,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -79,7 +79,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -97,7 +97,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/artifact_store/gcp.png" @property - def config_class(self) -> Type[GCPArtifactStoreConfig]: + def config_class(self) -> type[GCPArtifactStoreConfig]: """Returns GCPArtifactStoreConfig config class. Returns: @@ -106,7 +106,7 @@ def config_class(self) -> Type[GCPArtifactStoreConfig]: return GCPArtifactStoreConfig @property - def implementation_class(self) -> Type["GCPArtifactStore"]: + def implementation_class(self) -> type["GCPArtifactStore"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/gcp/flavors/gcp_deployer_flavor.py b/src/zenml/integrations/gcp/flavors/gcp_deployer_flavor.py index aa2d87eb7e9..534e17bef3f 100644 --- a/src/zenml/integrations/gcp/flavors/gcp_deployer_flavor.py +++ b/src/zenml/integrations/gcp/flavors/gcp_deployer_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """GCP Cloud Run deployer flavor.""" -from typing import TYPE_CHECKING, Dict, Optional, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -65,32 +65,32 @@ class GCPDeployerSettings(BaseDeployerSettings): "Options: 'all', 'internal', 'internal-and-cloud-load-balancing'.", ) - vpc_connector: Optional[str] = Field( + vpc_connector: str | None = Field( default=None, description="VPC connector for private networking. 
" "Format: projects/PROJECT_ID/locations/LOCATION/connectors/CONNECTOR_NAME", ) # Service account and IAM - service_account: Optional[str] = Field( + service_account: str | None = Field( default=None, description="Service account email to run the Cloud Run service. " "If not specified, uses the default Compute Engine service account.", ) # Environment and configuration - environment_variables: Dict[str, str] = Field( + environment_variables: dict[str, str] = Field( default_factory=dict, description="Environment variables to set in the Cloud Run service.", ) # Labels and annotations - labels: Dict[str, str] = Field( + labels: dict[str, str] = Field( default_factory=dict, description="Labels to apply to the Cloud Run service.", ) - annotations: Dict[str, str] = Field( + annotations: dict[str, str] = Field( default_factory=dict, description="Annotations to apply to the Cloud Run service.", ) @@ -102,7 +102,7 @@ class GCPDeployerSettings(BaseDeployerSettings): ) # Deployment configuration - traffic_allocation: Dict[str, int] = Field( + traffic_allocation: dict[str, int] = Field( default_factory=lambda: {"LATEST": 100}, description="Traffic allocation between revisions. " "Keys are revision names or 'LATEST', values are percentages.", @@ -164,7 +164,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -179,7 +179,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. 
Returns: @@ -188,7 +188,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -206,7 +206,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/deployer/google-cloud-run.svg" @property - def config_class(self) -> Type[GCPDeployerConfig]: + def config_class(self) -> type[GCPDeployerConfig]: """Returns the GCPDeployerConfig config class. Returns: @@ -215,7 +215,7 @@ def config_class(self) -> Type[GCPDeployerConfig]: return GCPDeployerConfig @property - def implementation_class(self) -> Type["GCPDeployer"]: + def implementation_class(self) -> type["GCPDeployer"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/gcp/flavors/gcp_image_builder_flavor.py b/src/zenml/integrations/gcp/flavors/gcp_image_builder_flavor.py index 3614ced37da..8bf278b52d3 100644 --- a/src/zenml/integrations/gcp/flavors/gcp_image_builder_flavor.py +++ b/src/zenml/integrations/gcp/flavors/gcp_image_builder_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Google Cloud image builder flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import PositiveInt @@ -75,7 +75,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -91,7 +91,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. 
Returns: @@ -100,7 +100,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -118,7 +118,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/image_builder/gcp.png" @property - def config_class(self) -> Type[BaseImageBuilderConfig]: + def config_class(self) -> type[BaseImageBuilderConfig]: """The config class. Returns: @@ -127,7 +127,7 @@ def config_class(self) -> Type[BaseImageBuilderConfig]: return GCPImageBuilderConfig @property - def implementation_class(self) -> Type["GCPImageBuilder"]: + def implementation_class(self) -> type["GCPImageBuilder"]: """Implementation class. Returns: diff --git a/src/zenml/integrations/gcp/flavors/vertex_experiment_tracker_flavor.py b/src/zenml/integrations/gcp/flavors/vertex_experiment_tracker_flavor.py index 40d8e0c8aa1..1e42825a237 100644 --- a/src/zenml/integrations/gcp/flavors/vertex_experiment_tracker_flavor.py +++ b/src/zenml/integrations/gcp/flavors/vertex_experiment_tracker_flavor.py @@ -14,7 +14,7 @@ """Vertex experiment tracker flavor.""" import re -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Any from pydantic import Field, field_validator @@ -42,10 +42,10 @@ class VertexExperimentTrackerSettings(BaseSettings): """Settings for the VertexAI experiment tracker.""" - experiment: Optional[str] = Field( + experiment: str | None = Field( None, description="The VertexAI experiment name." 
) - experiment_tensorboard: Optional[Union[str, bool]] = Field( + experiment_tensorboard: str | bool | None = Field( None, description="The VertexAI experiment tensorboard instance to use.", ) @@ -78,14 +78,14 @@ class VertexExperimentTrackerConfig( ): """Config for the VertexAI experiment tracker.""" - location: Optional[str] = None - staging_bucket: Optional[str] = None - network: Optional[str] = None - encryption_spec_key_name: Optional[str] = SecretField(default=None) - api_endpoint: Optional[str] = SecretField(default=None) - api_key: Optional[str] = SecretField(default=None) - api_transport: Optional[str] = None - request_metadata: Optional[Dict[str, Any]] = None + location: str | None = None + staging_bucket: str | None = None + network: str | None = None + encryption_spec_key_name: str | None = SecretField(default=None) + api_endpoint: str | None = SecretField(default=None) + api_key: str | None = SecretField(default=None) + api_transport: str | None = None + request_metadata: dict[str, Any] | None = None class VertexExperimentTrackerFlavor(BaseExperimentTrackerFlavor): @@ -101,7 +101,7 @@ def name(self) -> str: return GCP_VERTEX_EXPERIMENT_TRACKER_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A URL to point at docs explaining this flavor. Returns: @@ -110,7 +110,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A URL to point at SDK docs explaining this flavor. Returns: @@ -128,7 +128,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/experiment_tracker/vertexai.png" @property - def config_class(self) -> Type[VertexExperimentTrackerConfig]: + def config_class(self) -> type[VertexExperimentTrackerConfig]: """Returns `VertexExperimentTrackerConfig` config class. 
Returns: @@ -137,7 +137,7 @@ def config_class(self) -> Type[VertexExperimentTrackerConfig]: return VertexExperimentTrackerConfig @property - def implementation_class(self) -> Type["VertexExperimentTracker"]: + def implementation_class(self) -> type["VertexExperimentTracker"]: """Implementation class for this flavor. Returns: @@ -152,7 +152,7 @@ def implementation_class(self) -> Type["VertexExperimentTracker"]: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available diff --git a/src/zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py b/src/zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py index 38cff04705f..aea4cbbe04c 100644 --- a/src/zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py +++ b/src/zenml/integrations/gcp/flavors/vertex_orchestrator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Vertex orchestrator flavor.""" -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -40,7 +40,7 @@ class VertexOrchestratorSettings(BaseSettings): """Settings for the Vertex orchestrator.""" - labels: Dict[str, str] = Field( + labels: dict[str, str] = Field( default_factory=dict, description="Labels to assign to the pipeline job. " "Example: {'environment': 'production', 'team': 'ml-ops'}", @@ -52,7 +52,7 @@ class VertexOrchestratorSettings(BaseSettings): "the client returns immediately and the pipeline is executed " "asynchronously.", ) - node_selector_constraint: Optional[Tuple[str, str]] = Field( + node_selector_constraint: tuple[str, str] | None = Field( None, description="Each constraint is a key-value pair label. 
For the container " "to be eligible to run on a node, the node must have each of the " @@ -61,12 +61,12 @@ class VertexOrchestratorSettings(BaseSettings): "Hint: the selected region (location) must provide the requested accelerator" "(see https://cloud.google.com/compute/docs/gpus/gpu-regions-zones).", ) - pod_settings: Optional[KubernetesPodSettings] = Field( + pod_settings: KubernetesPodSettings | None = Field( None, description="Pod settings to apply to the orchestrator and step pods.", ) - custom_job_parameters: Optional[VertexCustomJobParameters] = Field( + custom_job_parameters: VertexCustomJobParameters | None = Field( None, description="Custom parameters for the Vertex AI custom job." ) @@ -90,32 +90,32 @@ class VertexOrchestratorConfig( "Vertex AI Pipelines is available in specific regions: " "https://cloud.google.com/vertex-ai/docs/general/locations#feature-availability", ) - pipeline_root: Optional[str] = Field( + pipeline_root: str | None = Field( None, description="A Cloud Storage URI that will be used by the Vertex AI Pipelines. " "If not provided but the artifact store in the stack is a GCPArtifactStore, " "then a subdirectory of the artifact store will be used.", ) - encryption_spec_key_name: Optional[str] = Field( + encryption_spec_key_name: str | None = Field( None, description="The Cloud KMS resource identifier of the customer managed " "encryption key used to protect the job. Has the form: " "projects//locations//keyRings//cryptoKeys/. " "The key needs to be in the same region as where the compute resource is created.", ) - workload_service_account: Optional[str] = Field( + workload_service_account: str | None = Field( None, description="The service account for workload run-as account. Users submitting " "jobs must have act-as permission on this run-as account. 
If not provided, " "the Compute Engine default service account for the GCP project is used.", ) - network: Optional[str] = Field( + network: str | None = Field( None, description="The full name of the Compute Engine Network to which the job " "should be peered. For example, 'projects/12345/global/networks/myVPC'. " "If not provided, the job will not be peered with any network.", ) - private_service_connect: Optional[str] = Field( + private_service_connect: str | None = Field( None, description="The full name of a Private Service Connect endpoint to which " "the job should be peered. For example, " @@ -124,27 +124,27 @@ class VertexOrchestratorConfig( ) # Deprecated - cpu_limit: Optional[str] = Field( + cpu_limit: str | None = Field( None, description="DEPRECATED: The maximum CPU limit for this operator. " "Use custom_job_parameters or pod_settings instead.", ) - memory_limit: Optional[str] = Field( + memory_limit: str | None = Field( None, description="DEPRECATED: The maximum memory limit for this operator. " "Use custom_job_parameters or pod_settings instead.", ) - gpu_limit: Optional[int] = Field( + gpu_limit: int | None = Field( None, description="DEPRECATED: The GPU limit for the operator. " "Use custom_job_parameters or pod_settings instead.", ) - function_service_account: Optional[str] = Field( + function_service_account: str | None = Field( None, description="DEPRECATED: The service account for cloud function run-as account, " "for scheduled pipelines. This functionality is no longer supported.", ) - scheduler_service_account: Optional[str] = Field( + scheduler_service_account: str | None = Field( None, description="DEPRECATED: The service account used by the Google Cloud Scheduler " "to trigger and authenticate to the pipeline Cloud Function on a schedule. 
" @@ -206,7 +206,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -221,7 +221,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -230,7 +230,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -248,7 +248,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/vertexai.png" @property - def config_class(self) -> Type[VertexOrchestratorConfig]: + def config_class(self) -> type[VertexOrchestratorConfig]: """Returns VertexOrchestratorConfig config class. Returns: @@ -257,7 +257,7 @@ def config_class(self) -> Type[VertexOrchestratorConfig]: return VertexOrchestratorConfig @property - def implementation_class(self) -> Type["VertexOrchestrator"]: + def implementation_class(self) -> type["VertexOrchestrator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py b/src/zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py index 694c554f899..fa3f5c49919 100644 --- a/src/zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +++ b/src/zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Vertex step operator flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.config.base_settings import BaseSettings from zenml.integrations.gcp import ( @@ -63,13 +63,13 @@ class VertexStepOperatorConfig( # customer managed encryption key resource name # will be applied to all Vertex AI resources if set - encryption_spec_key_name: Optional[str] = None + encryption_spec_key_name: str | None = None - network: Optional[str] = None + network: str | None = None - reserved_ip_ranges: Optional[str] = None + reserved_ip_ranges: str | None = None - service_account: Optional[str] = None + service_account: str | None = None @property def is_remote(self) -> bool: @@ -100,7 +100,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -115,7 +115,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -124,7 +124,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -142,7 +142,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/vertexai.png" @property - def config_class(self) -> Type[VertexStepOperatorConfig]: + def config_class(self) -> type[VertexStepOperatorConfig]: """Returns `VertexStepOperatorConfig` config class. 
Returns: @@ -151,7 +151,7 @@ def config_class(self) -> Type[VertexStepOperatorConfig]: return VertexStepOperatorConfig @property - def implementation_class(self) -> Type["VertexStepOperator"]: + def implementation_class(self) -> type["VertexStepOperator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/gcp/google_credentials_mixin.py b/src/zenml/integrations/gcp/google_credentials_mixin.py index 681bbf4da66..d4004c792e3 100644 --- a/src/zenml/integrations/gcp/google_credentials_mixin.py +++ b/src/zenml/integrations/gcp/google_credentials_mixin.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of the Google credentials mixin.""" -from typing import TYPE_CHECKING, Optional, Tuple, cast +from typing import TYPE_CHECKING, cast from pydantic import Field @@ -34,11 +34,11 @@ class GoogleCredentialsConfigMixin(StackComponentConfig): Field descriptions are defined inline using Field() descriptors. """ - project: Optional[str] = Field( + project: str | None = Field( default=None, description="Google Cloud Project ID. Auto-detected from environment if not specified.", ) - service_account_path: Optional[str] = Field( + service_account_path: str | None = Field( default=None, description="Path to service account JSON key file for authentication. " "Uses Application Default Credentials if not provided.", @@ -57,7 +57,7 @@ def config(self) -> GoogleCredentialsConfigMixin: """ return cast(GoogleCredentialsConfigMixin, self._config) - def _get_authentication(self) -> Tuple["Credentials", str]: + def _get_authentication(self) -> tuple["Credentials", str]: """Get GCP credentials and the project ID associated with the credentials. 
If `service_account_path` is provided, then the credentials will be diff --git a/src/zenml/integrations/gcp/image_builders/gcp_image_builder.py b/src/zenml/integrations/gcp/image_builders/gcp_image_builder.py index b83bda7e291..9979d316b83 100644 --- a/src/zenml/integrations/gcp/image_builders/gcp_image_builder.py +++ b/src/zenml/integrations/gcp/image_builders/gcp_image_builder.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Google Cloud Builder image builder implementation.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Optional, cast from urllib.parse import urlparse from google.cloud.devtools import cloudbuild_v1 @@ -69,7 +69,7 @@ def validator(self) -> Optional["StackValidator"]: Stack validator. """ - def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + def _validate_remote_components(stack: "Stack") -> tuple[bool, str]: assert stack.container_registry if stack.container_registry.config.is_local: @@ -98,7 +98,7 @@ def build( self, image_name: str, build_context: "BuildContext", - docker_build_options: Dict[str, Any], + docker_build_options: dict[str, Any], container_registry: Optional["BaseContainerRegistry"] = None, ) -> str: """Builds and pushes a Docker image. @@ -141,7 +141,7 @@ def _configure_cloud_build( self, image_name: str, cloud_build_context: str, - build_options: Dict[str, Any], + build_options: dict[str, Any], ) -> cloudbuild_v1.Build: """Configures the build to be run to generate the Docker image. 
diff --git a/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py b/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py index 27c85080483..49e579186ae 100644 --- a/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +++ b/src/zenml/integrations/gcp/orchestrators/vertex_orchestrator.py @@ -36,11 +36,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Tuple, - Type, cast, ) from uuid import UUID @@ -140,7 +136,7 @@ def config(self) -> VertexOrchestratorConfig: return cast(VertexOrchestratorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Vertex orchestrator. Returns: @@ -149,7 +145,7 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: return VertexOrchestratorSettings @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates that the stack contains a container registry. Also validates that the artifact store is not local. @@ -158,7 +154,7 @@ def validator(self) -> Optional[StackValidator]: A StackValidator instance. """ - def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]: + def _validate_stack_requirements(stack: "Stack") -> tuple[bool, str]: """Validates that all the stack components are not local. Args: @@ -249,8 +245,8 @@ def pipeline_directory(self) -> str: def _create_container_component( self, image: str, - command: List[str], - arguments: List[str], + command: list[str], + arguments: list[str], component_name: str, ) -> BaseComponent: """Creates a container component for a Vertex pipeline. @@ -291,7 +287,7 @@ def _convert_to_custom_training_job( self, component: BaseComponent, settings: VertexOrchestratorSettings, - environment: Dict[str, str], + environment: dict[str, str], ) -> BaseComponent: """Convert a component to a custom training job component. 
@@ -375,10 +371,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. This method should only submit the pipeline and not wait for it to @@ -442,7 +438,7 @@ def _create_dynamic_pipeline() -> Any: Returns: pipeline_func """ - step_name_to_dynamic_component: Dict[str, BaseComponent] = {} + step_name_to_dynamic_component: dict[str, BaseComponent] = {} for step_name, step in snapshot.step_configurations.items(): image = self.get_image( @@ -560,7 +556,7 @@ def dynamic_pipeline() -> None: pod_settings = step_settings.pod_settings - node_selector_constraint: Optional[Tuple[str, str]] = ( + node_selector_constraint: tuple[str, str] | None = ( None ) if pod_settings and ( @@ -628,7 +624,7 @@ def _upload_and_run_pipeline( run_name: str, settings: VertexOrchestratorSettings, schedule: Optional["ScheduleResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Uploads and run the pipeline on the Vertex AI Pipelines service. Args: @@ -802,7 +798,7 @@ def get_orchestrator_run_id(self) -> str: def get_pipeline_run_metadata( self, run_id: UUID - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get general component-specific metadata for a pipeline run. Args: @@ -826,7 +822,7 @@ def _configure_container_resources( self, dynamic_component: dsl.PipelineTask, resource_settings: "ResourceSettings", - node_selector_constraint: Optional[Tuple[str, str]] = None, + node_selector_constraint: tuple[str, str] | None = None, ) -> dsl.PipelineTask: """Adds resource requirements to the container. 
@@ -886,8 +882,8 @@ def _configure_container_resources( def fetch_status( self, run: "PipelineRunResponse", include_steps: bool = False - ) -> Tuple[ - Optional[ExecutionStatus], Optional[Dict[str, ExecutionStatus]] + ) -> tuple[ + ExecutionStatus | None, dict[str, ExecutionStatus] | None ]: """Refreshes the status of a specific pipeline run. @@ -967,7 +963,7 @@ def fetch_status( def compute_metadata( self, job: aiplatform.PipelineJob - ) -> Dict[str, MetadataType]: + ) -> dict[str, MetadataType]: """Generate run metadata based on the corresponding Vertex PipelineJob. Args: @@ -976,7 +972,7 @@ def compute_metadata( Returns: A dictionary of metadata related to the pipeline run. """ - metadata: Dict[str, MetadataType] = {} + metadata: dict[str, MetadataType] = {} # Orchestrator Run ID if run_id := self._compute_orchestrator_run_id(job): @@ -995,7 +991,7 @@ def compute_metadata( @staticmethod def _compute_orchestrator_url( job: aiplatform.PipelineJob, - ) -> Optional[str]: + ) -> str | None: """Generate the Orchestrator Dashboard URL upon pipeline execution. Args: @@ -1015,7 +1011,7 @@ def _compute_orchestrator_url( @staticmethod def _compute_orchestrator_logs_url( job: aiplatform.PipelineJob, - ) -> Optional[str]: + ) -> str | None: """Generate the Logs Explorer URL upon pipeline execution. Args: @@ -1042,7 +1038,7 @@ def _compute_orchestrator_logs_url( @staticmethod def _compute_orchestrator_run_id( job: aiplatform.PipelineJob, - ) -> Optional[str]: + ) -> str | None: """Fetch the Orchestrator Run ID upon pipeline execution. 
Args: diff --git a/src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py b/src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py index bebe4cd68a9..b1b7c701202 100644 --- a/src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py +++ b/src/zenml/integrations/gcp/service_connectors/gcp_service_connector.py @@ -28,7 +28,7 @@ import shutil import subprocess import tempfile -from typing import Any, Dict, List, Optional, Tuple +from typing import Any import google.api_core.exceptions import google.auth @@ -106,8 +106,8 @@ class GCPUserAccountCredentials(AuthenticationConfig): @classmethod @before_validator_handler def validate_user_account_dict( - cls, data: Dict[str, Any] - ) -> Dict[str, Any]: + cls, data: dict[str, Any] + ) -> dict[str, Any]: """Convert the user account credentials to JSON if given in dict format. Args: @@ -202,8 +202,8 @@ class GCPServiceAccountCredentials(AuthenticationConfig): @classmethod @before_validator_handler def validate_service_account_dict( - cls, data: Dict[str, Any] - ) -> Dict[str, Any]: + cls, data: dict[str, Any] + ) -> dict[str, Any]: """Convert the service account credentials to JSON if given in dict format. Args: @@ -309,8 +309,8 @@ class GCPExternalAccountCredentials(AuthenticationConfig): @classmethod @before_validator_handler def validate_service_account_dict( - cls, data: Dict[str, Any] - ) -> Dict[str, Any]: + cls, data: dict[str, Any] + ) -> dict[str, Any]: """Convert the external account credentials to JSON if given in dict format. 
Args: @@ -438,7 +438,7 @@ class GCPUserAccountConfig(GCPBaseProjectIDConfig, GCPUserAccountCredentials): class GCPServiceAccountConfig(GCPBaseConfig, GCPServiceAccountCredentials): """GCP service account configuration.""" - project_id: Optional[str] = None + project_id: str | None = None @property def gcp_project_id(self) -> str: @@ -469,7 +469,7 @@ class GCPExternalAccountConfig( class GCPOAuth2TokenConfig(GCPBaseProjectIDConfig, GCPOAuth2Token): """GCP OAuth 2.0 configuration.""" - service_account_email: Optional[str] = Field( + service_account_email: str | None = Field( default=None, title="GCP Service Account Email", description="The email address of the service account that signed the " @@ -596,7 +596,6 @@ def get_aws_region(self, context: Any, request: Any) -> str: # Before that, the AWS logic was part of the `google.auth.awsCredentials` # class itself. ZenMLAwsSecurityCredentialsSupplier = None # type: ignore[assignment,misc] - pass class ZenMLGCPAWSExternalAccountCredentials(gcp_aws.Credentials): # type: ignore[misc] @@ -622,7 +621,7 @@ class that can be subclassed instead and supplied as the def _get_security_credentials( self, request: Any, imdsv2_session_token: Any - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Get the security credentials from the local environment. 
This method is a copy of the original method from the @@ -1083,11 +1082,11 @@ class GCPServiceConnector(ServiceConnector): config: GCPBaseConfig - _session_cache: Dict[ - Tuple[str, Optional[str], Optional[str]], - Tuple[ + _session_cache: dict[ + tuple[str, str | None, str | None], + tuple[ gcp_credentials.Credentials, - Optional[datetime.datetime], + datetime.datetime | None, ], ] = {} @@ -1103,9 +1102,9 @@ def _get_connector_type(cls) -> ServiceConnectorTypeModel: def get_session( self, auth_method: str, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> Tuple[gcp_credentials.Credentials, Optional[datetime.datetime]]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> tuple[gcp_credentials.Credentials, datetime.datetime | None]: """Get a GCP session object with credentials for the specified resource. Args: @@ -1147,9 +1146,9 @@ def get_session( @classmethod def _get_scopes( cls, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> List[str]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> list[str]: """Get the OAuth 2.0 scopes to use for the specified resource type. Args: @@ -1166,11 +1165,11 @@ def _get_scopes( def _authenticate( self, auth_method: str, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> Tuple[ + resource_type: str | None = None, + resource_id: str | None = None, + ) -> tuple[ gcp_credentials.Credentials, - Optional[datetime.datetime], + datetime.datetime | None, ]: """Authenticate to GCP and return a session with credentials. 
@@ -1187,7 +1186,7 @@ def _authenticate( """ cfg = self.config scopes = self._get_scopes(resource_type, resource_id) - expires_at: Optional[datetime.datetime] = None + expires_at: datetime.datetime | None = None if auth_method == GCPAuthenticationMethods.IMPLICIT: self._check_implicit_auth_method_allowed() @@ -1330,7 +1329,7 @@ def _parse_gcs_resource_id(self, resource_id: str) -> str: # - the GCS bucket name # # We need to extract the bucket name from the provided resource ID - bucket_name: Optional[str] = None + bucket_name: str | None = None if re.match( r"^gs://[a-z0-9][a-z0-9_-]{1,61}[a-z0-9](/.*)*$", resource_id, @@ -1356,7 +1355,7 @@ def _parse_gcs_resource_id(self, resource_id: str) -> str: def _parse_gar_resource_id( self, resource_id: str, - ) -> Tuple[str, Optional[str]]: + ) -> tuple[str, str | None]: """Validate and convert a GAR resource ID to a Google Artifact Registry ID and name. Args: @@ -1379,9 +1378,9 @@ def _parse_gar_resource_id( # We need to extract the project ID and registry ID from # the provided resource ID config_project_id = self.config.gcp_project_id - project_id: Optional[str] = None + project_id: str | None = None canonical_url: str - registry_name: Optional[str] = None + registry_name: str | None = None # A Google Artifact Registry URI uses the -docker-pkg.dev # domain format with the project ID as the first part of the URL path @@ -1613,7 +1612,7 @@ def _configure_local_client( resource_type = self.resource_type if resource_type in [GCP_RESOURCE_TYPE, GCS_RESOURCE_TYPE]: - gcloud_config_json: Optional[str] = None + gcloud_config_json: str | None = None # There is no way to configure the local gcloud CLI to use # temporary OAuth 2.0 tokens. 
However, we can configure it to use @@ -1732,9 +1731,9 @@ def _configure_local_client( @classmethod def _auto_configure( cls, - auth_method: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + auth_method: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, **kwargs: Any, ) -> "GCPServiceConnector": """Auto-configure the connector. @@ -1768,7 +1767,7 @@ def _auto_configure( auth_config: GCPBaseConfig scopes = cls._get_scopes() - expires_at: Optional[datetime.datetime] = None + expires_at: datetime.datetime | None = None try: # Determine the credentials from the environment @@ -1883,7 +1882,7 @@ def _auto_configure( "account JSON file or run 'gcloud auth application-" "default login' to generate a new ADC file." ) - with open(service_account_json_file, "r") as f: + with open(service_account_json_file) as f: service_account_json = f.read() auth_config = GCPServiceAccountConfig( project_id=project_id, @@ -1928,7 +1927,7 @@ def _auto_configure( "account JSON file or run 'gcloud auth application-" "default login' to generate a new ADC file." ) - with open(external_account_json_file, "r") as f: + with open(external_account_json_file) as f: external_account_json = f.read() auth_config = GCPExternalAccountConfig( project_id=project_id, @@ -1951,9 +1950,9 @@ def _auto_configure( def _verify( self, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> List[str]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> list[str]: """Verify and list all the resources that the connector can access. 
Args: @@ -2056,7 +2055,7 @@ def _verify( # For backwards compatibility, we initialize the list of resource # IDs with all GCR supported registries for the configured GCP # project - resource_ids: List[str] = [ + resource_ids: list[str] = [ f"{location}gcr.io/{self.config.gcp_project_id}" for location in ["", "us.", "eu.", "asia."] ] @@ -2076,7 +2075,7 @@ def _verify( ] # Then, we need to fetch all the repositories in each location - repository_names: List[str] = [] + repository_names: list[str] = [] for location in location_names: repositories = gar_client.list_repositories( parent=f"projects/{self.config.gcp_project_id}/locations/{location}" diff --git a/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py b/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py index 427b4b8c7b0..6f1b94b3009 100644 --- a/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py +++ b/src/zenml/integrations/gcp/step_operators/vertex_step_operator.py @@ -19,7 +19,7 @@ """ import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast +from typing import TYPE_CHECKING, Any, cast from google.api_core.exceptions import ServerError from google.cloud import aiplatform @@ -55,7 +55,7 @@ VERTEX_DOCKER_IMAGE_KEY = "vertex_step_operator" -def validate_accelerator_type(accelerator_type: Optional[str] = None) -> None: +def validate_accelerator_type(accelerator_type: str | None = None) -> None: """Validates that the accelerator type is valid. Args: @@ -97,7 +97,7 @@ def config(self) -> VertexStepOperatorConfig: return cast(VertexStepOperatorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Vertex step operator. 
Returns: @@ -106,7 +106,7 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: return VertexStepOperatorSettings @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates the stack. Returns: @@ -114,7 +114,7 @@ def validator(self) -> Optional[StackValidator]: registry and a remote artifact store. """ - def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + def _validate_remote_components(stack: "Stack") -> tuple[bool, str]: if stack.artifact_store.config.is_local: return False, ( "The Vertex step operator runs code remotely and " @@ -150,7 +150,7 @@ def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: def get_docker_builds( self, snapshot: "PipelineSnapshotBase" - ) -> List["BuildConfiguration"]: + ) -> list["BuildConfiguration"]: """Gets the Docker builds required for the component. Args: @@ -174,8 +174,8 @@ def get_docker_builds( def launch( self, info: "StepRunInfo", - entrypoint_command: List[str], - environment: Dict[str, str], + entrypoint_command: list[str], + environment: dict[str, str], ) -> None: """Launches a step on VertexAI. diff --git a/src/zenml/integrations/gcp/vertex_custom_job_parameters.py b/src/zenml/integrations/gcp/vertex_custom_job_parameters.py index 1227e70629d..3230522b65d 100644 --- a/src/zenml/integrations/gcp/vertex_custom_job_parameters.py +++ b/src/zenml/integrations/gcp/vertex_custom_job_parameters.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Vertex custom job parameter model.""" -from typing import Any, Dict, Optional +from typing import Any from pydantic import BaseModel @@ -54,11 +54,11 @@ class VertexCustomJobParameters(BaseModel): See: https://google-cloud-pipeline-components.readthedocs.io/en/google-cloud-pipeline-components-2.19.0/api/v1/custom_job.html """ - accelerator_type: Optional[str] = None + accelerator_type: str | None = None accelerator_count: int = 0 machine_type: str = "n1-standard-4" boot_disk_size_gb: int = 100 boot_disk_type: str = "pd-ssd" - persistent_resource_id: Optional[str] = None - service_account: Optional[str] = None - additional_training_job_args: Dict[str, Any] = {} + persistent_resource_id: str | None = None + service_account: str | None = None + additional_training_job_args: dict[str, Any] = {} diff --git a/src/zenml/integrations/github/__init__.py b/src/zenml/integrations/github/__init__.py index 2358dfb0e3f..4c66eda10da 100644 --- a/src/zenml/integrations/github/__init__.py +++ b/src/zenml/integrations/github/__init__.py @@ -12,7 +12,6 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Initialization of the GitHub ZenML integration.""" -from typing import List, Type from zenml.integrations.constants import GITHUB from zenml.integrations.integration import Integration @@ -25,11 +24,11 @@ class GitHubIntegration(Integration): """Definition of GitHub integration for ZenML.""" NAME = GITHUB - REQUIREMENTS: List[str] = ["pygithub"] + REQUIREMENTS: list[str] = ["pygithub"] @classmethod - def plugin_flavors(cls) -> List[Type[BasePluginFlavor]]: + def plugin_flavors(cls) -> list[type[BasePluginFlavor]]: """Declare the event flavors for the github integration. 
Returns: diff --git a/src/zenml/integrations/github/code_repositories/github_code_repository.py b/src/zenml/integrations/github/code_repositories/github_code_repository.py index b4e1e57e76c..9a08e8ca0a4 100644 --- a/src/zenml/integrations/github/code_repositories/github_code_repository.py +++ b/src/zenml/integrations/github/code_repositories/github_code_repository.py @@ -15,7 +15,7 @@ import os import re -from typing import Any, Dict, List, Optional +from typing import Any from urllib.parse import urlparse from uuid import uuid4 @@ -50,13 +50,13 @@ class GitHubCodeRepositoryConfig(BaseCodeRepositoryConfig): token: The token to access the repository. """ - api_url: Optional[str] = None + api_url: str | None = None owner: str repository: str - host: Optional[str] = "github.com" - token: Optional[str] = SecretField(default=None) + host: str | None = "github.com" + token: str | None = SecretField(default=None) - url: Optional[str] = None + url: str | None = None _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes( ("url", "api_url") ) @@ -66,7 +66,7 @@ class GitHubCodeRepository(BaseCodeRepository): """GitHub code repository.""" @classmethod - def validate_config(cls, config: Dict[str, Any]) -> None: + def validate_config(cls, config: dict[str, Any]) -> None: """Validate the code repository config. This method should check that the config/credentials are valid and @@ -162,7 +162,7 @@ def login( raise RuntimeError(f"An error occurred while logging in: {str(e)}") def download_files( - self, commit: str, directory: str, repo_sub_directory: Optional[str] + self, commit: str, directory: str, repo_sub_directory: str | None ) -> None: """Downloads files from a commit to a local directory. 
@@ -177,7 +177,7 @@ def download_files( contents = self.github_repo.get_contents( repo_sub_directory or "", ref=commit ) - if not isinstance(contents, List): + if not isinstance(contents, list): raise RuntimeError("Invalid repository subdirectory.") os.makedirs(directory, exist_ok=True) @@ -206,7 +206,7 @@ def download_files( try: with open(local_path, "wb") as f: f.write(content.decoded_content) - except (GithubException, IOError, AssertionError) as e: + except (GithubException, OSError, AssertionError) as e: logger.error( "Error processing `%s` (%s): %s", content.path, @@ -214,7 +214,7 @@ def download_files( e, ) - def get_local_context(self, path: str) -> Optional[LocalRepositoryContext]: + def get_local_context(self, path: str) -> LocalRepositoryContext | None: """Gets the local repository context. Args: diff --git a/src/zenml/integrations/github/plugins/event_sources/github_webhook_event_source.py b/src/zenml/integrations/github/plugins/event_sources/github_webhook_event_source.py index d6771ffa931..09c3b1e8f4a 100644 --- a/src/zenml/integrations/github/plugins/event_sources/github_webhook_event_source.py +++ b/src/zenml/integrations/github/plugins/event_sources/github_webhook_event_source.py @@ -14,7 +14,7 @@ """Implementation of the github webhook event source.""" import urllib -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any from uuid import UUID from pydantic import BaseModel, ConfigDict, Field @@ -109,14 +109,14 @@ class GithubEvent(BaseEvent): before: str after: str repository: Repository - commits: List[Commit] - head_commit: Optional[Commit] = None - tags: Optional[List[Tag]] = None - pull_requests: Optional[List[PullRequest]] = None + commits: list[Commit] + head_commit: Commit | None = None + tags: list[Tag] | None = None + pull_requests: list[PullRequest] | None = None model_config = ConfigDict(extra="allow") @property - def branch(self) -> Optional[str]: + def branch(self) -> str | None: """The branch the event 
happened on. Returns: @@ -127,7 +127,7 @@ def branch(self) -> Optional[str]: return None @property - def event_type(self) -> Union[GithubEventType, str]: + def event_type(self) -> GithubEventType | str: """The type of github event. Args: @@ -152,9 +152,9 @@ def event_type(self) -> Union[GithubEventType, str]: class GithubWebhookEventFilterConfiguration(WebhookEventFilterConfig): """Configuration for github event filters.""" - repo: Optional[str] = None - branch: Optional[str] = None - event_type: Optional[GithubEventType] = None + repo: str | None = None + branch: str | None = None + event_type: GithubEventType | None = None def event_matches_filter(self, event: BaseEvent) -> bool: """Checks the filter against the inbound event. @@ -182,15 +182,15 @@ def event_matches_filter(self, event: BaseEvent) -> bool: class GithubWebhookEventSourceConfiguration(WebhookEventSourceConfig): """Configuration for github source filters.""" - webhook_secret: Optional[str] = Field( + webhook_secret: str | None = Field( default=None, title="The webhook secret for the event source.", ) - webhook_secret_id: Optional[UUID] = Field( + webhook_secret_id: UUID | None = Field( default=None, description="The ID of the secret containing the webhook secret.", ) - rotate_secret: Optional[bool] = Field( + rotate_secret: bool | None = Field( default=None, description="Set to rotate the webhook secret." ) @@ -202,7 +202,7 @@ class GithubWebhookEventSourceHandler(BaseWebhookEventSourceHandler): """Handler for all github events.""" @property - def config_class(self) -> Type[GithubWebhookEventSourceConfiguration]: + def config_class(self) -> type[GithubWebhookEventSourceConfiguration]: """Returns the webhook event source configuration class. 
Returns: @@ -211,7 +211,7 @@ def config_class(self) -> Type[GithubWebhookEventSourceConfiguration]: return GithubWebhookEventSourceConfiguration @property - def filter_class(self) -> Type[GithubWebhookEventFilterConfiguration]: + def filter_class(self) -> type[GithubWebhookEventFilterConfiguration]: """Returns the webhook event filter configuration class. Returns: @@ -220,7 +220,7 @@ def filter_class(self) -> Type[GithubWebhookEventFilterConfiguration]: return GithubWebhookEventFilterConfiguration @property - def flavor_class(self) -> Type[BaseWebhookEventSourceFlavor]: + def flavor_class(self) -> type[BaseWebhookEventSourceFlavor]: """Returns the flavor class of the plugin. Returns: @@ -232,7 +232,7 @@ def flavor_class(self) -> Type[BaseWebhookEventSourceFlavor]: return GithubWebhookEventSourceFlavor - def _interpret_event(self, event: Dict[str, Any]) -> GithubEvent: + def _interpret_event(self, event: dict[str, Any]) -> GithubEvent: """Converts the generic event body into a event-source specific pydantic model. Args: @@ -252,8 +252,8 @@ def _interpret_event(self, event: Dict[str, Any]) -> GithubEvent: return github_event def _load_payload( - self, raw_body: bytes, headers: Dict[str, str] - ) -> Dict[str, Any]: + self, raw_body: bytes, headers: dict[str, str] + ) -> dict[str, Any]: """Converts the raw body of the request into a python dictionary. For github webhooks users can optionally choose to urlencode the @@ -278,7 +278,7 @@ def _load_payload( def _get_webhook_secret( self, event_source: EventSourceResponse - ) -> Optional[str]: + ) -> str | None: """Get the webhook secret for the event source. Args: @@ -464,7 +464,7 @@ def _process_event_source_delete( self, event_source: EventSourceResponse, config: EventSourceConfig, - force: Optional[bool] = False, + force: bool | None = False, ) -> None: """Process an event source before it is deleted from the database. 
diff --git a/src/zenml/integrations/github/plugins/github_webhook_event_source_flavor.py b/src/zenml/integrations/github/plugins/github_webhook_event_source_flavor.py index 5b321911ad5..5f276a9d82d 100644 --- a/src/zenml/integrations/github/plugins/github_webhook_event_source_flavor.py +++ b/src/zenml/integrations/github/plugins/github_webhook_event_source_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Github webhook event source flavor.""" -from typing import ClassVar, Type +from typing import ClassVar from zenml.event_sources.webhooks.base_webhook_event_source import ( BaseWebhookEventSourceFlavor, @@ -30,14 +30,14 @@ class GithubWebhookEventSourceFlavor(BaseWebhookEventSourceFlavor): """Enables users to configure github event sources.""" FLAVOR: ClassVar[str] = GITHUB_EVENT_FLAVOR - PLUGIN_CLASS: ClassVar[Type[GithubWebhookEventSourceHandler]] = ( + PLUGIN_CLASS: ClassVar[type[GithubWebhookEventSourceHandler]] = ( GithubWebhookEventSourceHandler ) # EventPlugin specific EVENT_SOURCE_CONFIG_CLASS: ClassVar[ - Type[GithubWebhookEventSourceConfiguration] + type[GithubWebhookEventSourceConfiguration] ] = GithubWebhookEventSourceConfiguration EVENT_FILTER_CONFIG_CLASS: ClassVar[ - Type[GithubWebhookEventFilterConfiguration] + type[GithubWebhookEventFilterConfiguration] ] = GithubWebhookEventFilterConfiguration diff --git a/src/zenml/integrations/gitlab/__init__.py b/src/zenml/integrations/gitlab/__init__.py index f342abc1b27..9b491f5e806 100644 --- a/src/zenml/integrations/gitlab/__init__.py +++ b/src/zenml/integrations/gitlab/__init__.py @@ -12,7 +12,6 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. 
"""Initialization of the GitLab ZenML integration.""" -from typing import List, Type from zenml.integrations.constants import GITLAB from zenml.integrations.integration import Integration @@ -22,6 +21,6 @@ class GitLabIntegration(Integration): """Definition of GitLab integration for ZenML.""" NAME = GITLAB - REQUIREMENTS: List[str] = ["python-gitlab"] + REQUIREMENTS: list[str] = ["python-gitlab"] diff --git a/src/zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py b/src/zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py index f40256f13b9..3ddc08e01a2 100644 --- a/src/zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py +++ b/src/zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py @@ -15,7 +15,7 @@ import os import re -from typing import Any, Dict, Optional +from typing import Any from urllib.parse import urlparse from uuid import uuid4 @@ -50,13 +50,13 @@ class GitLabCodeRepositoryConfig(BaseCodeRepositoryConfig): token: The token to access the repository. """ - instance_url: Optional[str] = None + instance_url: str | None = None group: str project: str - host: Optional[str] = "gitlab.com" - token: Optional[str] = SecretField(default=None) + host: str | None = "gitlab.com" + token: str | None = SecretField(default=None) - url: Optional[str] = None + url: str | None = None _deprecation_validator = deprecation_utils.deprecate_pydantic_attributes( ("url", "instance_url") ) @@ -66,7 +66,7 @@ class GitLabCodeRepository(BaseCodeRepository): """GitLab code repository.""" @classmethod - def validate_config(cls, config: Dict[str, Any]) -> None: + def validate_config(cls, config: dict[str, Any]) -> None: """Validate the code repository config. 
This method should check that the config/credentials are valid and @@ -118,7 +118,7 @@ def login(self) -> None: raise RuntimeError(f"An error occurred while logging in: {str(e)}") def download_files( - self, commit: str, directory: str, repo_sub_directory: Optional[str] + self, commit: str, directory: str, repo_sub_directory: str | None ) -> None: """Downloads files from a commit to a local directory. @@ -152,7 +152,7 @@ def download_files( except Exception as e: logger.error("Error processing %s: %s", content["path"], e) - def get_local_context(self, path: str) -> Optional[LocalRepositoryContext]: + def get_local_context(self, path: str) -> LocalRepositoryContext | None: """Gets the local repository context. Args: @@ -196,7 +196,7 @@ def check_remote_url(self, url: str) -> bool: f"@{host}:" r"(?P<port>\d+)?" r"(?(scheme_with_delimiter)/|/?)" - f"{group}/{project}(\.git)?$", + fr"{group}/{project}(\.git)?$", ) if ssh_regex.fullmatch(url): return True diff --git a/src/zenml/integrations/great_expectations/__init__.py b/src/zenml/integrations/great_expectations/__init__.py index 0379d5863fc..02b26e3f51e 100644 --- a/src/zenml/integrations/great_expectations/__init__.py +++ b/src/zenml/integrations/great_expectations/__init__.py @@ -17,7 +17,6 @@ way of profiling and validating your data. """ -from typing import List, Type, Optional from zenml.integrations.constants import GREAT_EXPECTATIONS from zenml.integrations.integration import Integration @@ -40,7 +39,7 @@ def activate(cls) -> None: from zenml.integrations.great_expectations import materializers # noqa @classmethod
Returns: @@ -53,8 +52,8 @@ def flavors(cls) -> List[Type[Flavor]]: return [GreatExpectationsDataValidatorFlavor] @classmethod - def get_requirements(cls, target_os: Optional[str] = None, python_version: Optional[str] = None - ) -> List[str]: + def get_requirements(cls, target_os: str | None = None, python_version: str | None = None + ) -> list[str]: """Method to get the requirements for the integration. Args: diff --git a/src/zenml/integrations/great_expectations/data_validators/ge_data_validator.py b/src/zenml/integrations/great_expectations/data_validators/ge_data_validator.py index 0eb8f7a590c..11f755eb904 100644 --- a/src/zenml/integrations/great_expectations/data_validators/ge_data_validator.py +++ b/src/zenml/integrations/great_expectations/data_validators/ge_data_validator.py @@ -14,7 +14,8 @@ """Implementation of the Great Expectations data validator.""" import os -from typing import Any, ClassVar, Dict, List, Optional, Sequence, Type, cast +from typing import Any, ClassVar, cast +from collections.abc import Sequence import pandas as pd from great_expectations.checkpoint.types.checkpoint_result import ( # type: ignore[import-untyped] @@ -65,12 +66,12 @@ class GreatExpectationsDataValidator(BaseDataValidator): """Great Expectations data validator stack component.""" NAME: ClassVar[str] = "Great Expectations" - FLAVOR: ClassVar[Type[BaseDataValidatorFlavor]] = ( + FLAVOR: ClassVar[type[BaseDataValidatorFlavor]] = ( GreatExpectationsDataValidatorFlavor ) - _context: Optional[AbstractDataContext] = None - _context_config: Optional[DataContextConfig] = None + _context: AbstractDataContext | None = None + _context_config: DataContextConfig | None = None @property def config(self) -> GreatExpectationsDataValidatorConfig: @@ -98,7 +99,7 @@ def get_data_context(cls) -> AbstractDataContext: return data_validator.data_context @property - def context_config(self) -> Optional[DataContextConfig]: + def context_config(self) -> DataContextConfig | None: """Get the Great 
Expectations data context configuration. Raises: @@ -126,7 +127,7 @@ def context_config(self) -> Optional[DataContextConfig]: return self._context_config @property - def local_path(self) -> Optional[str]: + def local_path(self) -> str | None: """Return a local path where this component stores information. If an existing local GE data context is used, it is @@ -138,7 +139,7 @@ def local_path(self) -> Optional[str]: """ return self.config.context_root_dir - def get_store_config(self, class_name: str, prefix: str) -> Dict[str, Any]: + def get_store_config(self, class_name: str, prefix: str) -> dict[str, Any]: """Generate a Great Expectations store configuration. Args: @@ -159,7 +160,7 @@ def get_store_config(self, class_name: str, prefix: str) -> Dict[str, Any]: def get_data_docs_config( self, prefix: str, local: bool = False - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Generate Great Expectations data docs configuration. Args: @@ -205,7 +206,7 @@ def data_context(self) -> AbstractDataContext: # Define default configuration options that plug the GX stores # in the active ZenML artifact store - zenml_context_config: Dict[str, Any] = dict( + zenml_context_config: dict[str, Any] = dict( stores={ expectations_store_name: self.get_store_config( "ExpectationsStore", "expectations" @@ -326,11 +327,11 @@ def root_directory(self) -> str: def data_profiling( self, dataset: pd.DataFrame, - comparison_dataset: Optional[Any] = None, - profile_list: Optional[Sequence[str]] = None, - expectation_suite_name: Optional[str] = None, - data_asset_name: Optional[str] = None, - profiler_kwargs: Optional[Dict[str, Any]] = None, + comparison_dataset: Any | None = None, + profile_list: Sequence[str] | None = None, + expectation_suite_name: str | None = None, + data_asset_name: str | None = None, + profiler_kwargs: dict[str, Any] | None = None, overwrite_existing_suite: bool = True, **kwargs: Any, ) -> ExpectationSuite: @@ -438,11 +439,11 @@ def data_profiling( def data_validation( self, 
dataset: pd.DataFrame, - comparison_dataset: Optional[Any] = None, - check_list: Optional[Sequence[str]] = None, - expectation_suite_name: Optional[str] = None, - data_asset_name: Optional[str] = None, - action_list: Optional[List[Dict[str, Any]]] = None, + comparison_dataset: Any | None = None, + check_list: Sequence[str] | None = None, + expectation_suite_name: str | None = None, + data_asset_name: str | None = None, + action_list: list[dict[str, Any]] | None = None, **kwargs: Any, ) -> CheckpointResult: """Great Expectations data validation. @@ -513,7 +514,7 @@ def data_validation( }, ] - checkpoint_config: Dict[str, Any] = { + checkpoint_config: dict[str, Any] = { "name": checkpoint_name, "run_name_template": run_name, "config_version": 1, diff --git a/src/zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py b/src/zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py index fdbda159fbe..ff17d4bc9ef 100644 --- a/src/zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py +++ b/src/zenml/integrations/great_expectations/flavors/great_expectations_data_validator_flavor.py @@ -14,7 +14,7 @@ """Great Expectations data validator flavor.""" import os -from typing import TYPE_CHECKING, Any, Dict, Optional, Type +from typing import TYPE_CHECKING, Any import yaml from pydantic import field_validator, model_validator @@ -54,16 +54,16 @@ class GreatExpectationsDataValidatorConfig(BaseDataValidatorConfig): Expectations docs are generated and can be visualized locally. 
""" - context_root_dir: Optional[str] = None - context_config: Optional[Dict[str, Any]] = None + context_root_dir: str | None = None + context_config: dict[str, Any] | None = None configure_zenml_stores: bool = False configure_local_docs: bool = True @field_validator("context_root_dir") @classmethod def _ensure_valid_context_root_dir( - cls, context_root_dir: Optional[str] = None - ) -> Optional[str]: + cls, context_root_dir: str | None = None + ) -> str | None: """Ensures that the root directory is an absolute path and points to an existing path. Args: @@ -87,7 +87,7 @@ def _ensure_valid_context_root_dir( @model_validator(mode="before") @classmethod @before_validator_handler - def validate_context_config(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def validate_context_config(cls, data: dict[str, Any]) -> dict[str, Any]: """Convert the context configuration if given in JSON/YAML format. Args: @@ -137,7 +137,7 @@ def name(self) -> str: return GREAT_EXPECTATIONS_DATA_VALIDATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -146,7 +146,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -164,7 +164,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/data_validator/greatexpectations.jpeg" @property - def config_class(self) -> Type[GreatExpectationsDataValidatorConfig]: + def config_class(self) -> type[GreatExpectationsDataValidatorConfig]: """Returns `GreatExpectationsDataValidatorConfig` config class. 
Returns: @@ -173,7 +173,7 @@ def config_class(self) -> Type[GreatExpectationsDataValidatorConfig]: return GreatExpectationsDataValidatorConfig @property - def implementation_class(self) -> Type["GreatExpectationsDataValidator"]: + def implementation_class(self) -> type["GreatExpectationsDataValidator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/great_expectations/ge_store_backend.py b/src/zenml/integrations/great_expectations/ge_store_backend.py index 6a5f2d89bfb..b6a6c634e97 100644 --- a/src/zenml/integrations/great_expectations/ge_store_backend.py +++ b/src/zenml/integrations/great_expectations/ge_store_backend.py @@ -15,7 +15,7 @@ import os from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Any, cast from great_expectations.data_context.store.tuple_store_backend import ( TupleStoreBackend, @@ -92,7 +92,7 @@ def __init__( ) def _build_object_path( - self, key: Tuple[str, ...], is_prefix: bool = False + self, key: tuple[str, ...], is_prefix: bool = False ) -> str: """Build a filepath corresponding to an object key. @@ -118,7 +118,7 @@ def _build_object_path( object_key = object_relative_path return os.path.join(self.root_path, object_key) - def _get(self, key: Tuple[str, ...]) -> str: # type: ignore[override] + def _get(self, key: tuple[str, ...]) -> str: # type: ignore[override] """Get the value of an object from the store. Args: @@ -142,7 +142,7 @@ def _get(self, key: Tuple[str, ...]) -> str: # type: ignore[override] ) return contents - def _get_all(self) -> List[Any]: + def _get_all(self) -> list[Any]: """Get all objects in the store. Raises: @@ -153,7 +153,7 @@ def _get_all(self) -> List[Any]: "Method `_get_all` is not implemented for this store backend." 
) - def _set(self, key: Tuple[str, ...], value: str, **kwargs: Any) -> str: # type: ignore[override] + def _set(self, key: tuple[str, ...], value: str, **kwargs: Any) -> str: # type: ignore[override] """Set the value of an object in the store. Args: @@ -178,8 +178,8 @@ def _set(self, key: Tuple[str, ...], value: str, **kwargs: Any) -> str: # type: def _move( self, - source_key: Tuple[str, ...], - dest_key: Tuple[str, ...], + source_key: tuple[str, ...], + dest_key: tuple[str, ...], **kwargs: Any, ) -> None: """Associate an object with a different key in the store. @@ -198,7 +198,7 @@ def _move( os.makedirs(parent_dir, exist_ok=True) fileio.rename(source_path, dest_path, overwrite=True) - def list_keys(self, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, ...]]: + def list_keys(self, prefix: tuple[str, ...] = ()) -> list[tuple[str, ...]]: """List the keys of all objects identified by a partial key. Args: @@ -230,7 +230,7 @@ def list_keys(self, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, ...]]: key_list.append(key) return key_list - def remove_key(self, key: Tuple[str, ...]) -> bool: # type: ignore[override] + def remove_key(self, key: tuple[str, ...]) -> bool: # type: ignore[override] """Delete an object from the store. Args: @@ -250,7 +250,7 @@ def remove_key(self, key: Tuple[str, ...]) -> bool: # type: ignore[override] return True return False - def _has_key(self, key: Tuple[str, ...]) -> bool: + def _has_key(self, key: tuple[str, ...]) -> bool: """Check if an object is present in the store. Args: @@ -264,7 +264,7 @@ def _has_key(self, key: Tuple[str, ...]) -> bool: return result def get_url_for_key( # type: ignore[override] - self, key: Tuple[str, ...], protocol: Optional[str] = None + self, key: tuple[str, ...], protocol: str | None = None ) -> str: """Get the URL of an object in the store. 
@@ -284,7 +284,7 @@ def get_url_for_key( # type: ignore[override] return filepath def get_public_url_for_key( - self, key: str, protocol: Optional[str] = None + self, key: str, protocol: str | None = None ) -> str: """Get the public URL of an object in the store. @@ -322,7 +322,7 @@ def rrmdir(start_path: str, end_path: str) -> None: end_path = os.path.dirname(end_path) @property - def config(self) -> Dict[str, Any]: + def config(self) -> dict[str, Any]: """Get the store configuration. Returns: diff --git a/src/zenml/integrations/great_expectations/materializers/ge_materializer.py b/src/zenml/integrations/great_expectations/materializers/ge_materializer.py index 88fd66364f4..49d50bc09b1 100644 --- a/src/zenml/integrations/great_expectations/materializers/ge_materializer.py +++ b/src/zenml/integrations/great_expectations/materializers/ge_materializer.py @@ -14,7 +14,7 @@ """Implementation of the Great Expectations materializers.""" import os -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast from great_expectations.checkpoint.types.checkpoint_result import ( # type: ignore[import-untyped] CheckpointResult, @@ -52,7 +52,7 @@ class GreatExpectationsMaterializer(BaseMaterializer): """Materializer to read/write Great Expectation objects.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = ( + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = ( ExpectationSuite, CheckpointResult, ) @@ -62,7 +62,7 @@ class GreatExpectationsMaterializer(BaseMaterializer): @staticmethod def preprocess_checkpoint_result_dict( - artifact_dict: Dict[str, Any], + artifact_dict: dict[str, Any], ) -> None: """Pre-processes a GE checkpoint dict before it is used to de-serialize a GE CheckpointResult object. 
@@ -97,7 +97,7 @@ def preprocess_run_result(key: str, value: Any) -> Any: validation_dict[validation_ident] = validation_results artifact_dict["run_results"] = validation_dict - def load(self, data_type: Type[Any]) -> SerializableDictDot: + def load(self, data_type: type[Any]) -> SerializableDictDot: """Reads and returns a Great Expectations object. Args: @@ -130,8 +130,8 @@ def save(self, obj: SerializableDictDot) -> None: yaml_utils.write_json(filepath, artifact_dict) def save_visualizations( - self, data: Union[ExpectationSuite, CheckpointResult] - ) -> Dict[str, VisualizationType]: + self, data: ExpectationSuite | CheckpointResult + ) -> dict[str, VisualizationType]: """Saves visualizations for the given Great Expectations object. Args: @@ -160,8 +160,8 @@ def save_visualizations( return visualizations def extract_metadata( - self, data: Union[ExpectationSuite, CheckpointResult] - ) -> Dict[str, "MetadataType"]: + self, data: ExpectationSuite | CheckpointResult + ) -> dict[str, "MetadataType"]: """Extract metadata from the given Great Expectations object. 
Args: diff --git a/src/zenml/integrations/great_expectations/steps/__init__.py b/src/zenml/integrations/great_expectations/steps/__init__.py index eca1af5290c..728cb292d06 100644 --- a/src/zenml/integrations/great_expectations/steps/__init__.py +++ b/src/zenml/integrations/great_expectations/steps/__init__.py @@ -14,9 +14,3 @@ """Great Expectations data profiling and validation standard steps.""" -from zenml.integrations.great_expectations.steps.ge_profiler import ( - great_expectations_profiler_step, -) -from zenml.integrations.great_expectations.steps.ge_validator import ( - great_expectations_validator_step, -) diff --git a/src/zenml/integrations/great_expectations/steps/ge_profiler.py b/src/zenml/integrations/great_expectations/steps/ge_profiler.py index cc1edd2d187..7808faa87ce 100644 --- a/src/zenml/integrations/great_expectations/steps/ge_profiler.py +++ b/src/zenml/integrations/great_expectations/steps/ge_profiler.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Great Expectations data profiling standard step.""" -from typing import Any, Dict, Optional +from typing import Any import pandas as pd from great_expectations.core import ( # type: ignore[import-untyped] @@ -30,8 +30,8 @@ def great_expectations_profiler_step( dataset: pd.DataFrame, expectation_suite_name: str, - data_asset_name: Optional[str] = None, - profiler_kwargs: Optional[Dict[str, Any]] = None, + data_asset_name: str | None = None, + profiler_kwargs: dict[str, Any] | None = None, overwrite_existing_suite: bool = True, ) -> ExpectationSuite: """Infer data validation rules from a pandas dataset. 
diff --git a/src/zenml/integrations/great_expectations/steps/ge_validator.py b/src/zenml/integrations/great_expectations/steps/ge_validator.py index e23c9cd0e0a..b636ab4b868 100644 --- a/src/zenml/integrations/great_expectations/steps/ge_validator.py +++ b/src/zenml/integrations/great_expectations/steps/ge_validator.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Great Expectations data validation standard step.""" -from typing import Any, Dict, List, Optional +from typing import Any import pandas as pd from great_expectations.checkpoint.types.checkpoint_result import ( # type: ignore[import-untyped] @@ -30,8 +30,8 @@ def great_expectations_validator_step( dataset: pd.DataFrame, expectation_suite_name: str, - data_asset_name: Optional[str] = None, - action_list: Optional[List[Dict[str, Any]]] = None, + data_asset_name: str | None = None, + action_list: list[dict[str, Any]] | None = None, exit_on_error: bool = False, ) -> CheckpointResult: """Shortcut function to create a new instance of the GreatExpectationsValidatorStep step. diff --git a/src/zenml/integrations/great_expectations/utils.py b/src/zenml/integrations/great_expectations/utils.py index aeeeec94ff8..9a4abedd7fe 100644 --- a/src/zenml/integrations/great_expectations/utils.py +++ b/src/zenml/integrations/great_expectations/utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Great Expectations data profiling standard step.""" -from typing import Any, Dict, Optional +from typing import Any import pandas as pd from great_expectations.core.batch import ( # type: ignore[import-untyped] @@ -33,7 +33,7 @@ def create_batch_request( context: AbstractDataContext, dataset: pd.DataFrame, - data_asset_name: Optional[str], + data_asset_name: str | None, ) -> RuntimeBatchRequest: """Create a temporary runtime GE batch request from a dataset step artifact. 
@@ -62,7 +62,7 @@ def create_batch_request( data_asset_name = data_asset_name or f"{pipeline_name}_{step_name}" batch_identifier = "default" - datasource_config: Dict[str, Any] = { + datasource_config: dict[str, Any] = { "name": datasource_name, "class_name": "Datasource", "module_name": "great_expectations.datasource", diff --git a/src/zenml/integrations/huggingface/__init__.py b/src/zenml/integrations/huggingface/__init__.py index 0fb3ce74214..6aeb3ad16e0 100644 --- a/src/zenml/integrations/huggingface/__init__.py +++ b/src/zenml/integrations/huggingface/__init__.py @@ -12,8 +12,6 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Initialization of the Huggingface integration.""" -import sys -from typing import List, Type, Optional from zenml.integrations.constants import HUGGINGFACE from zenml.integrations.integration import Integration @@ -34,11 +32,10 @@ class HuggingfaceIntegration(Integration): def activate(cls) -> None: """Activates the integration.""" from zenml.integrations.huggingface import materializers # noqa - from zenml.integrations.huggingface import services @classmethod - def get_requirements(cls, target_os: Optional[str] = None, python_version: Optional[str] = None - ) -> List[str]: + def get_requirements(cls, target_os: str | None = None, python_version: str | None = None + ) -> list[str]: """Defines platform specific requirements for the integration. Args: @@ -64,7 +61,7 @@ def get_requirements(cls, target_os: Optional[str] = None, python_version: Optio PandasIntegration.get_requirements(target_os=target_os, python_version=python_version) @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Huggingface integration. 
Returns: diff --git a/src/zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py b/src/zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py index 88ab3073070..c5e25ec7f17 100644 --- a/src/zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py +++ b/src/zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Hugging Face model deployer flavor.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type +from typing import TYPE_CHECKING, Any from pydantic import BaseModel @@ -33,22 +33,22 @@ class HuggingFaceBaseConfig(BaseModel): """Hugging Face Inference Endpoint configuration.""" - repository: Optional[str] = None - framework: Optional[str] = None - accelerator: Optional[str] = None - instance_size: Optional[str] = None - instance_type: Optional[str] = None - region: Optional[str] = None - vendor: Optional[str] = None - account_id: Optional[str] = None + repository: str | None = None + framework: str | None = None + accelerator: str | None = None + instance_size: str | None = None + instance_type: str | None = None + region: str | None = None + vendor: str | None = None + account_id: str | None = None min_replica: int = 0 max_replica: int = 1 - revision: Optional[str] = None - task: Optional[str] = None - custom_image: Optional[Dict[str, Any]] = None + revision: str | None = None + task: str | None = None + custom_image: dict[str, Any] | None = None endpoint_type: str = "public" - secret_name: Optional[str] = None - namespace: Optional[str] = None + secret_name: str | None = None + namespace: str | None = None class HuggingFaceModelDeployerConfig( @@ -61,7 +61,7 @@ class HuggingFaceModelDeployerConfig( namespace: Hugging Face namespace used to list endpoints """ - token: Optional[str] = SecretField(default=None) + token: str | None = SecretField(default=None) # The namespace to list endpoints for. 
Set to `"*"` to list all endpoints # from all namespaces (i.e. personal namespace and all orgs the user belongs to). @@ -81,7 +81,7 @@ def name(self) -> str: return HUGGINGFACE_MODEL_DEPLOYER_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -90,7 +90,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -108,7 +108,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_registry/huggingface.png" @property - def config_class(self) -> Type[HuggingFaceModelDeployerConfig]: + def config_class(self) -> type[HuggingFaceModelDeployerConfig]: """Returns `HuggingFaceModelDeployerConfig` config class. Returns: @@ -117,7 +117,7 @@ def config_class(self) -> Type[HuggingFaceModelDeployerConfig]: return HuggingFaceModelDeployerConfig @property - def implementation_class(self) -> Type["HuggingFaceModelDeployer"]: + def implementation_class(self) -> type["HuggingFaceModelDeployer"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/huggingface/materializers/__init__.py b/src/zenml/integrations/huggingface/materializers/__init__.py index 117a009a303..d76ffc1d147 100644 --- a/src/zenml/integrations/huggingface/materializers/__init__.py +++ b/src/zenml/integrations/huggingface/materializers/__init__.py @@ -13,18 +13,3 @@ # permissions and limitations under the License. 
"""Initialization of Huggingface materializers.""" -from zenml.integrations.huggingface.materializers.huggingface_datasets_materializer import ( - HFDatasetMaterializer, -) -from zenml.integrations.huggingface.materializers.huggingface_pt_model_materializer import ( - HFPTModelMaterializer, -) -from zenml.integrations.huggingface.materializers.huggingface_tf_model_materializer import ( - HFTFModelMaterializer, -) -from zenml.integrations.huggingface.materializers.huggingface_tokenizer_materializer import ( - HFTokenizerMaterializer, -) -from zenml.integrations.huggingface.materializers.huggingface_t5_materializer import ( - HFT5Materializer, -) diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py index a08314a40ce..b3a316c36a2 100644 --- a/src/zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py +++ b/src/zenml/integrations/huggingface/materializers/huggingface_datasets_materializer.py @@ -19,11 +19,6 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - Optional, - Tuple, - Type, - Union, ) from datasets import Dataset, load_from_disk @@ -43,7 +38,7 @@ DEFAULT_DATASET_DIR = "hf_datasets" -def extract_repo_name(checksum_str: str) -> Optional[str]: +def extract_repo_name(checksum_str: str) -> str | None: """Extracts the repo name from the checksum string. 
An example of a checksum_str is: @@ -71,14 +66,14 @@ def extract_repo_name(checksum_str: str) -> Optional[str]: class HFDatasetMaterializer(BaseMaterializer): """Materializer to read data to and from huggingface datasets.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Dataset, DatasetDict) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (Dataset, DatasetDict) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ( ArtifactType.DATA_ANALYSIS ) def load( - self, data_type: Union[Type[Dataset], Type[DatasetDict]] - ) -> Union[Dataset, DatasetDict]: + self, data_type: type[Dataset] | type[DatasetDict] + ) -> Dataset | DatasetDict: """Reads Dataset. Args: @@ -94,7 +89,7 @@ def load( ) return load_from_disk(temp_dir) - def save(self, ds: Union[Dataset, DatasetDict]) -> None: + def save(self, ds: Dataset | DatasetDict) -> None: """Writes a Dataset to the specified dir. Args: @@ -109,8 +104,8 @@ def save(self, ds: Union[Dataset, DatasetDict]) -> None: ) def extract_metadata( - self, ds: Union[Dataset, DatasetDict] - ) -> Dict[str, "MetadataType"]: + self, ds: Dataset | DatasetDict + ) -> dict[str, "MetadataType"]: """Extract metadata from the given `Dataset` object. Args: @@ -126,7 +121,7 @@ def extract_metadata( if isinstance(ds, Dataset): return pandas_materializer.extract_metadata(ds.to_pandas()) elif isinstance(ds, DatasetDict): - metadata: Dict[str, Dict[str, "MetadataType"]] = defaultdict(dict) + metadata: dict[str, dict[str, "MetadataType"]] = defaultdict(dict) for dataset_name, dataset in ds.items(): dataset_metadata = pandas_materializer.extract_metadata( dataset.to_pandas() @@ -137,8 +132,8 @@ def extract_metadata( raise ValueError(f"Unsupported type {type(ds)}") def save_visualizations( - self, ds: Union[Dataset, DatasetDict] - ) -> Dict[str, VisualizationType]: + self, ds: Dataset | DatasetDict + ) -> dict[str, VisualizationType]: """Save visualizations for the dataset. 
Args: diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py index dc48626bf48..ed19b01a897 100644 --- a/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py +++ b/src/zenml/integrations/huggingface/materializers/huggingface_pt_model_materializer.py @@ -15,7 +15,7 @@ import importlib import os -from typing import Any, ClassVar, Dict, Tuple, Type +from typing import Any, ClassVar from transformers import ( AutoConfig, @@ -33,10 +33,10 @@ class HFPTModelMaterializer(BaseMaterializer): """Materializer to read torch model to and from huggingface pretrained model.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (PreTrainedModel,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (PreTrainedModel,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL - def load(self, data_type: Type[PreTrainedModel]) -> PreTrainedModel: + def load(self, data_type: type[PreTrainedModel]) -> PreTrainedModel: """Reads HFModel. Args: @@ -72,7 +72,7 @@ def save(self, model: PreTrainedModel) -> None: def extract_metadata( self, model: PreTrainedModel - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Extract metadata from the given `PreTrainedModel` object. 
Args: diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_t5_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_t5_materializer.py index f596d1ee0ae..851cb461c0f 100644 --- a/src/zenml/integrations/huggingface/materializers/huggingface_t5_materializer.py +++ b/src/zenml/integrations/huggingface/materializers/huggingface_t5_materializer.py @@ -14,7 +14,7 @@ """Implementation of the Huggingface t5 materializer.""" import os -from typing import Any, ClassVar, Type, Union +from typing import Any, ClassVar from transformers import ( T5ForConditionalGeneration, @@ -37,8 +37,8 @@ class HFT5Materializer(BaseMaterializer): ) def load( - self, data_type: Type[Any] - ) -> Union[T5ForConditionalGeneration, T5Tokenizer, T5TokenizerFast]: + self, data_type: type[Any] + ) -> T5ForConditionalGeneration | T5Tokenizer | T5TokenizerFast: """Reads a T5ForConditionalGeneration model or T5Tokenizer from a serialized zip file. Args: @@ -77,7 +77,7 @@ def load( def save( self, - obj: Union[T5ForConditionalGeneration, T5Tokenizer, T5TokenizerFast], + obj: T5ForConditionalGeneration | T5Tokenizer | T5TokenizerFast, ) -> None: """Creates a serialization for a T5ForConditionalGeneration model or T5Tokenizer. 
diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py index e604f3d4ef9..dc0e0e12e2e 100644 --- a/src/zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py +++ b/src/zenml/integrations/huggingface/materializers/huggingface_tf_model_materializer.py @@ -15,7 +15,7 @@ import importlib import os -from typing import Any, ClassVar, Dict, Tuple, Type +from typing import Any, ClassVar from transformers import ( AutoConfig, @@ -33,10 +33,10 @@ class HFTFModelMaterializer(BaseMaterializer): """Materializer to read Tensorflow model to and from huggingface pretrained model.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (TFPreTrainedModel,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (TFPreTrainedModel,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL - def load(self, data_type: Type[TFPreTrainedModel]) -> TFPreTrainedModel: + def load(self, data_type: type[TFPreTrainedModel]) -> TFPreTrainedModel: """Reads HFModel. Args: @@ -72,7 +72,7 @@ def save(self, model: TFPreTrainedModel) -> None: def extract_metadata( self, model: TFPreTrainedModel - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Extract metadata from the given `PreTrainedModel` object. 
Args: diff --git a/src/zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py b/src/zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py index 02db6f07dd3..b5d41f9080c 100644 --- a/src/zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py +++ b/src/zenml/integrations/huggingface/materializers/huggingface_tokenizer_materializer.py @@ -14,7 +14,7 @@ """Implementation of the Huggingface tokenizer materializer.""" import os -from typing import Any, ClassVar, Tuple, Type +from typing import Any, ClassVar from transformers import AutoTokenizer from transformers.tokenization_utils_base import ( @@ -31,12 +31,12 @@ class HFTokenizerMaterializer(BaseMaterializer): """Materializer to read tokenizer to and from huggingface tokenizer.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = ( + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = ( PreTrainedTokenizerBase, ) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL - def load(self, data_type: Type[Any]) -> PreTrainedTokenizerBase: + def load(self, data_type: type[Any]) -> PreTrainedTokenizerBase: """Reads Tokenizer. Args: @@ -51,7 +51,7 @@ def load(self, data_type: Type[Any]) -> PreTrainedTokenizerBase: ) return AutoTokenizer.from_pretrained(temp_dir) - def save(self, tokenizer: Type[Any]) -> None: + def save(self, tokenizer: type[Any]) -> None: """Writes a Tokenizer to the specified dir. Args: diff --git a/src/zenml/integrations/huggingface/model_deployers/huggingface_model_deployer.py b/src/zenml/integrations/huggingface/model_deployers/huggingface_model_deployer.py index eb551d5051a..1277af7af6f 100644 --- a/src/zenml/integrations/huggingface/model_deployers/huggingface_model_deployer.py +++ b/src/zenml/integrations/huggingface/model_deployers/huggingface_model_deployer.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Implementation of the Hugging Face Model Deployer.""" -from typing import ClassVar, Dict, Optional, Tuple, Type, cast +from typing import ClassVar, cast from uuid import UUID from zenml.analytics.enums import AnalyticsEvent @@ -45,7 +45,7 @@ class HuggingFaceModelDeployer(BaseModelDeployer): """Hugging Face endpoint model deployer.""" NAME: ClassVar[str] = "HuggingFace" - FLAVOR: ClassVar[Type[BaseModelDeployerFlavor]] = ( + FLAVOR: ClassVar[type[BaseModelDeployerFlavor]] = ( HuggingFaceModelDeployerFlavor ) @@ -59,7 +59,7 @@ def config(self) -> HuggingFaceModelDeployerConfig: return cast(HuggingFaceModelDeployerConfig, self._config) @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates the stack. Returns: @@ -69,7 +69,7 @@ def validator(self) -> Optional[StackValidator]: def _validate_if_secret_or_token_is_present( stack: "Stack", - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: """Check if secret or token is present in the stack. Args: @@ -232,7 +232,7 @@ def perform_delete_model( @staticmethod def get_model_server_info( # type: ignore[override] service_instance: "HuggingFaceDeploymentService", - ) -> Dict[str, Optional[str]]: + ) -> dict[str, str | None]: """Return implementation specific information that might be relevant to the user. Args: diff --git a/src/zenml/integrations/huggingface/services/huggingface_deployment.py b/src/zenml/integrations/huggingface/services/huggingface_deployment.py index d3596ff6634..436c544145e 100644 --- a/src/zenml/integrations/huggingface/services/huggingface_deployment.py +++ b/src/zenml/integrations/huggingface/services/huggingface_deployment.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. 
"""Implementation of the Hugging Face Deployment service.""" -from typing import Any, Dict, Generator, Optional, Tuple +from typing import Any +from collections.abc import Generator from huggingface_hub import ( InferenceClient, @@ -123,7 +124,7 @@ def hf_endpoint(self) -> InferenceEndpoint: ) @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """The prediction URI exposed by the prediction service. Returns: @@ -141,7 +142,7 @@ def inference_client(self) -> InferenceClient: """ return self.hf_endpoint.client - def _validate_endpoint_configuration(self) -> Dict[str, str]: + def _validate_endpoint_configuration(self) -> dict[str, str]: """Validates the configuration to provision a Huggingface service. Raises: @@ -229,7 +230,7 @@ def provision(self) -> None: "Face console for more details." ) - def check_status(self) -> Tuple[ServiceState, str]: + def check_status(self) -> tuple[ServiceState, str]: """Check the current operational state of the Hugging Face deployment. Returns: @@ -309,7 +310,7 @@ def predict(self, data: "Any", max_new_tokens: int) -> "Any": ) def get_logs( - self, follow: bool = False, tail: Optional[int] = None + self, follow: bool = False, tail: int | None = None ) -> Generator[str, bool, None]: """Retrieve the service logs. diff --git a/src/zenml/integrations/huggingface/steps/__init__.py b/src/zenml/integrations/huggingface/steps/__init__.py index 99365cd9569..96507341aee 100644 --- a/src/zenml/integrations/huggingface/steps/__init__.py +++ b/src/zenml/integrations/huggingface/steps/__init__.py @@ -13,9 +13,3 @@ # permissions and limitations under the License. 
"""Initialization for Hugging Face model deployer step.""" -from zenml.integrations.huggingface.steps.huggingface_deployer import ( - huggingface_model_deployer_step, -) -from zenml.integrations.huggingface.steps.accelerate_runner import ( - run_with_accelerate, -) \ No newline at end of file diff --git a/src/zenml/integrations/huggingface/steps/accelerate_runner.py b/src/zenml/integrations/huggingface/steps/accelerate_runner.py index c1cabc4a0f9..5534721c8f5 100644 --- a/src/zenml/integrations/huggingface/steps/accelerate_runner.py +++ b/src/zenml/integrations/huggingface/steps/accelerate_runner.py @@ -17,7 +17,8 @@ """Step function to run any ZenML step using Accelerate.""" import functools -from typing import Any, Callable, Dict, Optional, TypeVar, Union, cast +from typing import Any, TypeVar, cast +from collections.abc import Callable import cloudpickle as pickle from accelerate.commands.launch import ( @@ -35,9 +36,9 @@ def run_with_accelerate( - step_function_top_level: Optional[BaseStep] = None, + step_function_top_level: BaseStep | None = None, **accelerate_launch_kwargs: Any, -) -> Union[Callable[[BaseStep], BaseStep], BaseStep]: +) -> Callable[[BaseStep], BaseStep] | BaseStep: """Run a function with accelerate. Accelerate package: https://huggingface.co/docs/accelerate/en/index @@ -72,7 +73,7 @@ def training_pipeline(some_param: int, ...): def _decorator(step_function: BaseStep) -> BaseStep: def _wrapper( - entrypoint: F, accelerate_launch_kwargs: Dict[str, Any] + entrypoint: F, accelerate_launch_kwargs: dict[str, Any] ) -> F: @functools.wraps(entrypoint) def inner(*args: Any, **kwargs: Any) -> Any: diff --git a/src/zenml/integrations/hyperai/__init__.py b/src/zenml/integrations/hyperai/__init__.py index a46040e6a4e..b14b6870b49 100644 --- a/src/zenml/integrations/hyperai/__init__.py +++ b/src/zenml/integrations/hyperai/__init__.py @@ -12,7 +12,6 @@ # or implied. 
See the License for the specific language governing # permissions and limitations under the License. """Initialization of the HyperAI integration.""" -from typing import List, Type from zenml.integrations.constants import HYPERAI from zenml.integrations.integration import Integration @@ -37,7 +36,7 @@ def activate(cls) -> None: from zenml.integrations.hyperai import service_connectors # noqa @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the HyperAI integration. Returns: diff --git a/src/zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py b/src/zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py index be7b19d33cc..9082f27b939 100644 --- a/src/zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py +++ b/src/zenml/integrations/hyperai/flavors/hyperai_orchestrator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of the ZenML HyperAI orchestrator.""" -from typing import TYPE_CHECKING, Dict, Optional, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -35,7 +35,7 @@ class HyperAIOrchestratorSettings(BaseSettings): """HyperAI orchestrator settings.""" - mounts_from_to: Dict[str, str] = Field( + mounts_from_to: dict[str, str] = Field( default_factory=dict, description="A dictionary mapping from paths on the HyperAI instance " "to paths within the Docker container. This allows users to mount " @@ -104,7 +104,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. 
Specifies resource requirements that are used to filter the available @@ -119,7 +119,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -128,7 +128,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -146,7 +146,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/connectors/hyperai/hyperai.png" @property - def config_class(self) -> Type[BaseOrchestratorConfig]: + def config_class(self) -> type[BaseOrchestratorConfig]: """Config class for the base orchestrator flavor. Returns: @@ -155,7 +155,7 @@ def config_class(self) -> Type[BaseOrchestratorConfig]: return HyperAIOrchestratorConfig @property - def implementation_class(self) -> Type["HyperAIOrchestrator"]: + def implementation_class(self) -> type["HyperAIOrchestrator"]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py b/src/zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py index 2755a6d9e63..3e65a000b4a 100644 --- a/src/zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py +++ b/src/zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py @@ -17,7 +17,7 @@ import re import tempfile from shlex import quote -from typing import IO, TYPE_CHECKING, Any, Dict, Optional, Type, cast +from typing import IO, TYPE_CHECKING, Any, Optional, cast import paramiko import yaml @@ -54,7 +54,7 @@ def config(self) -> HyperAIOrchestratorConfig: return cast(HyperAIOrchestratorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the HyperAI orchestrator. Returns: @@ -63,7 +63,7 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: return HyperAIOrchestratorSettings @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Ensures there is an image builder in the stack. Returns: @@ -161,10 +161,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. 
This method should only submit the pipeline and not wait for it to @@ -202,7 +202,7 @@ def submit_pipeline( HyperAIServiceConnector, ) - compose_definition: Dict[str, Any] = {"version": "3", "services": {}} + compose_definition: dict[str, Any] = {"version": "3", "services": {}} snapshot_id = snapshot.id os.environ[ENV_ZENML_HYPERAI_RUN_ID] = str(snapshot_id) @@ -362,7 +362,7 @@ def submit_pipeline( ) # Send the password to stdin stdin.channel.send( - f"{container_registry_password}\n".encode("utf-8") + f"{container_registry_password}\n".encode() ) stdin.channel.shutdown_write() diff --git a/src/zenml/integrations/hyperai/service_connectors/hyperai_service_connector.py b/src/zenml/integrations/hyperai/service_connectors/hyperai_service_connector.py index 8f68470cacc..4825928a48b 100644 --- a/src/zenml/integrations/hyperai/service_connectors/hyperai_service_connector.py +++ b/src/zenml/integrations/hyperai/service_connectors/hyperai_service_connector.py @@ -19,7 +19,7 @@ import base64 import io -from typing import Any, List, Optional, Type +from typing import Any import paramiko from pydantic import Field @@ -51,7 +51,7 @@ class HyperAICredentials(AuthenticationConfig): base64_ssh_key: PlainSerializedSecretStr = Field( title="SSH key (base64)", ) - ssh_passphrase: Optional[PlainSerializedSecretStr] = Field( + ssh_passphrase: PlainSerializedSecretStr | None = Field( default=None, title="SSH key passphrase", ) @@ -60,7 +60,7 @@ class HyperAICredentials(AuthenticationConfig): class HyperAIConfiguration(HyperAICredentials): """HyperAI client configuration.""" - hostnames: List[str] = Field( + hostnames: list[str] = Field( title="Hostnames of the supported HyperAI instances.", ) @@ -169,7 +169,7 @@ def _get_connector_type(cls) -> ServiceConnectorTypeModel: """ return HYPERAI_SERVICE_CONNECTOR_TYPE_SPEC - def _paramiko_key_type_given_auth_method(self) -> Type[paramiko.PKey]: + def _paramiko_key_type_given_auth_method(self) -> type[paramiko.PKey]: """Get the Paramiko key 
type given the authentication method. Returns: @@ -309,9 +309,9 @@ def _configure_local_client( @classmethod def _auto_configure( cls, - auth_method: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + auth_method: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, **kwargs: Any, ) -> "HyperAIServiceConnector": """Auto-configure the connector. @@ -339,9 +339,9 @@ def _auto_configure( def _verify( self, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> List[str]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> list[str]: """Verify that a connection can be established to the HyperAI instance. Args: diff --git a/src/zenml/integrations/integration.py b/src/zenml/integrations/integration.py index c37d17c5d34..dbf4ead889d 100644 --- a/src/zenml/integrations/integration.py +++ b/src/zenml/integrations/integration.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Base and meta classes for ZenML integrations.""" -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast +from typing import TYPE_CHECKING, Any, cast from packaging.requirements import Requirement @@ -33,7 +33,7 @@ class IntegrationMeta(type): """Metaclass responsible for registering different Integration subclasses.""" def __new__( - mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any] + mcs, name: str, bases: tuple[type[Any], ...], dct: dict[str, Any] ) -> "IntegrationMeta": """Hook into creation of an Integration class. @@ -45,7 +45,7 @@ def __new__( Returns: The newly created class. 
""" - cls = cast(Type["Integration"], super().__new__(mcs, name, bases, dct)) + cls = cast(type["Integration"], super().__new__(mcs, name, bases, dct)) if name != "Integration": integration_registry.register_integration(cls.NAME, cls) return cls @@ -56,9 +56,9 @@ class Integration(metaclass=IntegrationMeta): NAME = "base_integration" - REQUIREMENTS: List[str] = [] - APT_PACKAGES: List[str] = [] - REQUIREMENTS_IGNORED_ON_UNINSTALL: List[str] = [] + REQUIREMENTS: list[str] = [] + APT_PACKAGES: list[str] = [] + REQUIREMENTS_IGNORED_ON_UNINSTALL: list[str] = [] @classmethod def check_installation(cls) -> bool: @@ -100,9 +100,9 @@ def check_installation(cls) -> bool: @classmethod def get_requirements( cls, - target_os: Optional[str] = None, - python_version: Optional[str] = None, - ) -> List[str]: + target_os: str | None = None, + python_version: str | None = None, + ) -> list[str]: """Method to get the requirements for the integration. Args: @@ -116,8 +116,8 @@ def get_requirements( @classmethod def get_uninstall_requirements( - cls, target_os: Optional[str] = None - ) -> List[str]: + cls, target_os: str | None = None + ) -> list[str]: """Method to get the uninstall requirements for the integration. Args: @@ -142,7 +142,7 @@ def activate(cls) -> None: """Abstract method to activate the integration.""" @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Abstract method to declare new stack component flavors. Returns: @@ -151,7 +151,7 @@ def flavors(cls) -> List[Type[Flavor]]: return [] @classmethod - def plugin_flavors(cls) -> List[Type["BasePluginFlavor"]]: + def plugin_flavors(cls) -> list[type["BasePluginFlavor"]]: """Abstract method to declare new plugin flavors. 
Returns: diff --git a/src/zenml/integrations/jax/materializer.py b/src/zenml/integrations/jax/materializer.py index df311b5835a..121178530f1 100644 --- a/src/zenml/integrations/jax/materializer.py +++ b/src/zenml/integrations/jax/materializer.py @@ -17,8 +17,6 @@ from typing import ( Any, ClassVar, - Tuple, - Type, ) import jax @@ -34,10 +32,10 @@ class JAXArrayMaterializer(BaseMaterializer): """A materializer for JAX arrays.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (jax.Array,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (jax.Array,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - def load(self, data_type: Type[Any]) -> jax.Array: + def load(self, data_type: type[Any]) -> jax.Array: """Reads data from a `.npy` file, and returns a JAX array. Args: diff --git a/src/zenml/integrations/kaniko/__init__.py b/src/zenml/integrations/kaniko/__init__.py index b1f9ea09d22..626cf17040a 100644 --- a/src/zenml/integrations/kaniko/__init__.py +++ b/src/zenml/integrations/kaniko/__init__.py @@ -12,7 +12,6 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Kaniko integration for image building.""" -from typing import List, Type from zenml.integrations.constants import KANIKO from zenml.integrations.integration import Integration @@ -28,7 +27,7 @@ class KanikoIntegration(Integration): REQUIREMENTS = [] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Kaniko integration. 
Returns: diff --git a/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py b/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py index b7733bd29a4..5245820c38c 100644 --- a/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py +++ b/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Kaniko image builder flavor.""" -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any from pydantic import Field, PositiveInt @@ -59,27 +59,27 @@ class KanikoImageBuilderConfig(BaseImageBuilderConfig): description="The timeout to wait until the pod is running in seconds.", ) - env: List[Dict[str, Any]] = Field( + env: list[dict[str, Any]] = Field( default_factory=list, description="Environment variables section of the Kubernetes container spec. " "Used to configure secrets and environment variables for registry access.", ) - env_from: List[Dict[str, Any]] = Field( + env_from: list[dict[str, Any]] = Field( default_factory=list, description="EnvFrom section of the Kubernetes container spec. " "Used to load environment variables from ConfigMaps or Secrets.", ) - volume_mounts: List[Dict[str, Any]] = Field( + volume_mounts: list[dict[str, Any]] = Field( default_factory=list, description="VolumeMounts section of the Kubernetes container spec. " "Used to mount volumes containing credentials or other data.", ) - volumes: List[Dict[str, Any]] = Field( + volumes: list[dict[str, Any]] = Field( default_factory=list, description="Volumes section of the Kubernetes pod spec. " "Used to define volumes for credentials or other data.", ) - service_account_name: Optional[str] = Field( + service_account_name: str | None = Field( None, description="Name of the Kubernetes service account to use for the Kaniko pod. 
" "This service account should have the necessary permissions for building and pushing images.", @@ -91,7 +91,7 @@ class KanikoImageBuilderConfig(BaseImageBuilderConfig): "If `False`, the build context will be streamed over stdin of the kubectl process.", ) - executor_args: List[str] = Field( + executor_args: list[str] = Field( default_factory=list, description="Additional arguments to forward to the Kaniko executor. " "See Kaniko documentation for available flags, e.g. ['--compressed-caching=false'].", @@ -111,7 +111,7 @@ def name(self) -> str: return KANIKO_IMAGE_BUILDER_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -120,7 +120,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -138,7 +138,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/image_builder/kaniko.png" @property - def config_class(self) -> Type[KanikoImageBuilderConfig]: + def config_class(self) -> type[KanikoImageBuilderConfig]: """Config class. Returns: @@ -147,7 +147,7 @@ def config_class(self) -> Type[KanikoImageBuilderConfig]: return KanikoImageBuilderConfig @property - def implementation_class(self) -> Type["KanikoImageBuilder"]: + def implementation_class(self) -> type["KanikoImageBuilder"]: """Implementation class. 
Returns: diff --git a/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py b/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py index 22074472521..68813c052ad 100644 --- a/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py +++ b/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py @@ -18,7 +18,7 @@ import shutil import subprocess import tempfile -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Optional, cast from zenml.enums import StackComponentType from zenml.image_builders import BaseImageBuilder @@ -62,7 +62,7 @@ def is_building_locally(self) -> bool: return False @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates that the stack contains a container registry. Returns: @@ -71,7 +71,7 @@ def validator(self) -> Optional[StackValidator]: def _validate_remote_components( stack: "Stack", - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: assert stack.container_registry if stack.container_registry.config.is_local: @@ -106,7 +106,7 @@ def build( self, image_name: str, build_context: "BuildContext", - docker_build_options: Dict[str, Any], + docker_build_options: dict[str, Any], container_registry: Optional["BaseContainerRegistry"] = None, ) -> str: """Builds and pushes a Docker image. @@ -175,7 +175,7 @@ def build( def _generate_spec_overrides( self, pod_name: str, image_name: str, context: str - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Generates Kubernetes spec overrides for the Kaniko build Pod. 
These values are used to override the default specification of the @@ -200,7 +200,7 @@ def _generate_spec_overrides( "--image-name-with-digest-file=/dev/termination-log", ] + self.config.executor_args - optional_spec_args: Dict[str, Any] = {} + optional_spec_args: dict[str, Any] = {} if self.config.service_account_name: optional_spec_args["serviceAccountName"] = ( self.config.service_account_name @@ -229,7 +229,7 @@ def _generate_spec_overrides( def _run_kaniko_build( self, pod_name: str, - spec_overrides: Dict[str, Any], + spec_overrides: dict[str, Any], build_context: "BuildContext", ) -> None: """Runs the Kaniko build in Kubernetes. diff --git a/src/zenml/integrations/kubeflow/__init__.py b/src/zenml/integrations/kubeflow/__init__.py index 600c7e587a6..fb9e6368c65 100644 --- a/src/zenml/integrations/kubeflow/__init__.py +++ b/src/zenml/integrations/kubeflow/__init__.py @@ -17,7 +17,6 @@ orchestrator. You can enable it by registering the Kubeflow orchestrator with the CLI tool. """ -from typing import List, Type from zenml.integrations.constants import KUBEFLOW from zenml.integrations.integration import Integration @@ -36,7 +35,7 @@ class KubeflowIntegration(Integration): ] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Kubeflow integration. Returns: diff --git a/src/zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py b/src/zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py index eb218a283e5..f559d08fd90 100644 --- a/src/zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py +++ b/src/zenml/integrations/kubeflow/flavors/kubeflow_orchestrator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Kubeflow orchestrator flavor.""" -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, cast +from typing import TYPE_CHECKING, Any, cast from pydantic import Field, model_validator @@ -49,27 +49,27 @@ class KubeflowOrchestratorSettings(BaseSettings): 1200, description="How many seconds to wait for synchronous runs." ) - client_args: Dict[str, Any] = Field( + client_args: dict[str, Any] = Field( default_factory=dict, description="Arguments to pass when initializing the KFP client. " "Example: {'host': 'https://kubeflow.example.com', 'client_id': 'kubeflow-oidc-authservice', 'existing_token': 'your-auth-token'}", ) - client_username: Optional[str] = SecretField( + client_username: str | None = SecretField( default=None, description="Username to generate a session cookie for the kubeflow client. " "Both `client_username` and `client_password` need to be set together.", ) - client_password: Optional[str] = SecretField( + client_password: str | None = SecretField( default=None, description="Password to generate a session cookie for the kubeflow client. " "Both `client_username` and `client_password` need to be set together.", ) - user_namespace: Optional[str] = Field( + user_namespace: str | None = Field( None, description="The user namespace to use when creating experiments and runs. " "Example: 'my-experiments' or 'team-alpha'", ) - pod_settings: Optional[KubernetesPodSettings] = Field( + pod_settings: KubernetesPodSettings | None = Field( None, description="Pod settings to apply to the orchestrator and step pods.", ) @@ -78,8 +78,8 @@ class KubeflowOrchestratorSettings(BaseSettings): @classmethod @before_validator_handler def _validate_and_migrate_pod_settings( - cls, data: Dict[str, Any] - ) -> Dict[str, Any]: + cls, data: dict[str, Any] + ) -> dict[str, Any]: """Validates settings and migrates pod settings from older version. 
Args: @@ -91,9 +91,9 @@ def _validate_and_migrate_pod_settings( Raises: ValueError: If username and password are not specified together. """ - node_selectors = cast(Dict[str, str], data.get("node_selectors") or {}) + node_selectors = cast(dict[str, str], data.get("node_selectors") or {}) node_affinity = cast( - Dict[str, List[str]], data.get("node_affinity") or {} + dict[str, list[str]], data.get("node_affinity") or {} ) affinity = {} @@ -142,7 +142,7 @@ class KubeflowOrchestratorConfig( ): """Configuration for the Kubeflow orchestrator.""" - kubeflow_hostname: Optional[str] = Field( + kubeflow_hostname: str | None = Field( None, description="The hostname to use to talk to the Kubeflow Pipelines API. " "If not set, the hostname will be derived from the Kubernetes API proxy. " @@ -153,7 +153,7 @@ class KubeflowOrchestratorConfig( "kubeflow", description="The Kubernetes namespace in which Kubeflow Pipelines is deployed.", ) - kubernetes_context: Optional[str] = Field( + kubernetes_context: str | None = Field( None, description="Name of a kubernetes context to run pipelines in. " "Not applicable when connecting to a multi-tenant Kubeflow Pipelines " @@ -166,8 +166,8 @@ class KubeflowOrchestratorConfig( @classmethod @before_validator_handler def _validate_deprecated_attrs( - cls, data: Dict[str, Any] - ) -> Dict[str, Any]: + cls, data: dict[str, Any] + ) -> dict[str, Any]: """Pydantic root_validator for deprecated attributes. This root validator is used for backwards compatibility purposes. E.g. @@ -249,7 +249,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. 
Specifies resource requirements that are used to filter the available @@ -264,7 +264,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -273,7 +273,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -291,7 +291,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/kubeflow.png" @property - def config_class(self) -> Type[KubeflowOrchestratorConfig]: + def config_class(self) -> type[KubeflowOrchestratorConfig]: """Returns `KubeflowOrchestratorConfig` config class. Returns: @@ -300,7 +300,7 @@ def config_class(self) -> Type[KubeflowOrchestratorConfig]: return KubeflowOrchestratorConfig @property - def implementation_class(self) -> Type["KubeflowOrchestrator"]: + def implementation_class(self) -> type["KubeflowOrchestrator"]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py b/src/zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py index 875518d5dd4..c02438f591e 100644 --- a/src/zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py +++ b/src/zenml/integrations/kubeflow/orchestrators/kubeflow_orchestrator.py @@ -35,11 +35,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Tuple, - Type, cast, ) from uuid import UUID @@ -153,7 +149,7 @@ def _load_config(self, *args: Any, **kwargs: Any) -> Any: class KubeflowOrchestrator(ContainerizedOrchestrator): """Orchestrator responsible for running pipelines using Kubeflow.""" - _k8s_client: Optional[k8s_client.ApiClient] = None + _k8s_client: k8s_client.ApiClient | None = None def _get_kfp_client( self, @@ -226,7 +222,7 @@ def config(self) -> KubeflowOrchestratorConfig: """ return cast(KubeflowOrchestratorConfig, self._config) - def get_kubernetes_contexts(self) -> Tuple[List[str], Optional[str]]: + def get_kubernetes_contexts(self) -> tuple[list[str], str | None]: """Get the list of configured Kubernetes contexts and the active context. Returns: @@ -243,7 +239,7 @@ def get_kubernetes_contexts(self) -> Tuple[List[str], Optional[str]]: return context_names, active_context_name @property - def settings_class(self) -> Type[KubeflowOrchestratorSettings]: + def settings_class(self) -> type[KubeflowOrchestratorSettings]: """Settings class for the Kubeflow orchestrator. Returns: @@ -252,7 +248,7 @@ def settings_class(self) -> Type[KubeflowOrchestratorSettings]: return KubeflowOrchestratorSettings @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates that the stack contains a container registry. Also check that requirements are met for local components. 
@@ -264,7 +260,7 @@ def validator(self) -> Optional[StackValidator]: def _validate_kube_context( kubernetes_context: str, - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: contexts, active_context = self.get_kubernetes_contexts() if kubernetes_context and kubernetes_context not in contexts: @@ -303,7 +299,7 @@ def _validate_kube_context( return True, "" - def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]: + def _validate_local_requirements(stack: "Stack") -> tuple[bool, str]: container_registry = stack.container_registry # should not happen, because the stack validation takes care of @@ -426,8 +422,8 @@ def pipeline_directory(self) -> str: def _create_dynamic_component( self, image: str, - command: List[str], - arguments: List[str], + command: list[str], + arguments: list[str], component_name: str, ) -> dsl.PipelineTask: """Creates a dynamic container component for a Kubeflow pipeline. @@ -470,10 +466,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. 
This method should only submit the pipeline and not wait for it to @@ -522,8 +518,8 @@ def _create_dynamic_pipeline() -> Any: Returns: pipeline_func """ - step_name_to_dynamic_component: Dict[str, Any] = {} - node_selector_constraint: Optional[Tuple[str, str]] = None + step_name_to_dynamic_component: dict[str, Any] = {} + node_selector_constraint: tuple[str, str] | None = None for step_name, step in snapshot.step_configurations.items(): image = self.get_image( @@ -641,7 +637,7 @@ def _upload_and_run_pipeline( snapshot: "PipelineSnapshotResponse", pipeline_file_path: str, run_name: str, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Tries to upload and run a KFP pipeline. Args: @@ -852,7 +848,7 @@ def _get_session_cookie(self, username: str, password: str) -> str: f"Error while trying to fetch kubeflow cookie: {errh}" ) - cookie_dict: Dict[str, str] = session.cookies.get_dict() # type: ignore[no-untyped-call, unused-ignore] + cookie_dict: dict[str, str] = session.cookies.get_dict() # type: ignore[no-untyped-call, unused-ignore] if "authservice_session" not in cookie_dict: raise RuntimeError("Invalid username and/or password!") @@ -863,7 +859,7 @@ def _get_session_cookie(self, username: str, password: str) -> str: def get_pipeline_run_metadata( self, run_id: UUID - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get general component-specific metadata for a pipeline run. Args: @@ -906,7 +902,7 @@ def _configure_container_resources( self, dynamic_component: dsl.PipelineTask, resource_settings: "ResourceSettings", - node_selector_constraint: Optional[Tuple[str, str]] = None, + node_selector_constraint: tuple[str, str] | None = None, ) -> dsl.PipelineTask: """Adds resource requirements to the container. 
diff --git a/src/zenml/integrations/kubernetes/__init__.py b/src/zenml/integrations/kubernetes/__init__.py index 6a9d9bd07e3..343242b7c4c 100644 --- a/src/zenml/integrations/kubernetes/__init__.py +++ b/src/zenml/integrations/kubernetes/__init__.py @@ -17,7 +17,6 @@ orchestrator. You can enable it by registering the Kubernetes orchestrator with the CLI tool. """ -from typing import List, Type from zenml.integrations.constants import KUBERNETES from zenml.integrations.integration import Integration @@ -36,7 +35,7 @@ class KubernetesIntegration(Integration): "kfp", # it is used by many others ] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Kubernetes integration. Returns: diff --git a/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py b/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py index 35faabe7fb0..672b77e4ee9 100644 --- a/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py +++ b/src/zenml/integrations/kubernetes/flavors/kubernetes_orchestrator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Kubernetes orchestrator flavor.""" -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any from pydantic import ( Field, @@ -49,12 +49,12 @@ class KubernetesOrchestratorSettings(BaseSettings): description="Whether to wait for all pipeline steps to complete. " "When `False`, the client returns immediately and execution continues asynchronously.", ) - service_account_name: Optional[str] = Field( + service_account_name: str | None = Field( default=None, description="Kubernetes service account for the orchestrator pod. 
" "If not specified, creates a new account with 'edit' permissions.", ) - step_pod_service_account_name: Optional[str] = Field( + step_pod_service_account_name: str | None = Field( default=None, description="Kubernetes service account for step execution pods. " "Uses the default service account if not specified.", @@ -63,35 +63,35 @@ class KubernetesOrchestratorSettings(BaseSettings): default=False, description="Whether to run containers in privileged mode with extended permissions.", ) - pod_settings: Optional[KubernetesPodSettings] = Field( + pod_settings: KubernetesPodSettings | None = Field( default=None, description="Pod configuration for step execution containers.", ) - orchestrator_pod_settings: Optional[KubernetesPodSettings] = Field( + orchestrator_pod_settings: KubernetesPodSettings | None = Field( default=None, description="Pod configuration for the orchestrator container that launches step pods.", ) - job_name_prefix: Optional[str] = Field( + job_name_prefix: str | None = Field( default=None, description="Prefix for the job name.", ) - max_parallelism: Optional[PositiveInt] = Field( + max_parallelism: PositiveInt | None = Field( default=None, description="Maximum number of step pods to run concurrently. 
No limit if not specified.", ) - successful_jobs_history_limit: Optional[NonNegativeInt] = Field( + successful_jobs_history_limit: NonNegativeInt | None = Field( default=None, description="Number of successful scheduled jobs to retain in history.", ) - failed_jobs_history_limit: Optional[NonNegativeInt] = Field( + failed_jobs_history_limit: NonNegativeInt | None = Field( default=None, description="Number of failed scheduled jobs to retain in history.", ) - ttl_seconds_after_finished: Optional[NonNegativeInt] = Field( + ttl_seconds_after_finished: NonNegativeInt | None = Field( default=None, description="Seconds to keep finished jobs before automatic cleanup.", ) - active_deadline_seconds: Optional[NonNegativeInt] = Field( + active_deadline_seconds: NonNegativeInt | None = Field( default=None, description="Job deadline in seconds. If the job doesn't finish " "within this time, it will be terminated.", @@ -117,7 +117,7 @@ class KubernetesOrchestratorSettings(BaseSettings): default=3, description="The backoff limit for the orchestrator job.", ) - fail_on_container_waiting_reasons: Optional[List[str]] = Field( + fail_on_container_waiting_reasons: list[str] | None = Field( default=[ "InvalidImageName", "ErrImagePull", @@ -146,7 +146,7 @@ class KubernetesOrchestratorSettings(BaseSettings): description="The interval in seconds to check for run interruptions.", ge=0.5, ) - pod_failure_policy: Optional[Dict[str, Any]] = Field( + pod_failure_policy: dict[str, Any] | None = Field( default=None, description="The pod failure policy to use for the job that is " "executing the step.", @@ -166,37 +166,37 @@ class KubernetesOrchestratorSettings(BaseSettings): ) # Deprecated fields - timeout: Optional[int] = Field( + timeout: int | None = Field( default=None, deprecated=True, description="DEPRECATED/UNUSED.", ) - stream_step_logs: Optional[bool] = Field( + stream_step_logs: bool | None = Field( default=None, deprecated=True, description="DEPRECATED/UNUSED.", ) - pod_startup_timeout: 
Optional[int] = Field( + pod_startup_timeout: int | None = Field( default=None, description="DEPRECATED/UNUSED.", deprecated=True, ) - pod_failure_max_retries: Optional[int] = Field( + pod_failure_max_retries: int | None = Field( default=None, description="DEPRECATED/UNUSED.", deprecated=True, ) - pod_failure_retry_delay: Optional[int] = Field( + pod_failure_retry_delay: int | None = Field( default=None, description="DEPRECATED/UNUSED.", deprecated=True, ) - pod_failure_backoff: Optional[float] = Field( + pod_failure_backoff: float | None = Field( default=None, description="DEPRECATED/UNUSED.", deprecated=True, ) - pod_name_prefix: Optional[str] = Field( + pod_name_prefix: str | None = Field( default=None, deprecated=True, description="DEPRECATED/UNUSED.", @@ -246,7 +246,7 @@ class KubernetesOrchestratorConfig( "config option is ignored. If the stack component is linked to a " "Kubernetes service connector, this field is ignored.", ) - kubernetes_context: Optional[str] = Field( + kubernetes_context: str | None = Field( None, description="Name of a Kubernetes context to run pipelines in. " "If the stack component is linked to a Kubernetes service connector, " @@ -268,7 +268,7 @@ class KubernetesOrchestratorConfig( skip_local_validations: bool = Field( False, description="If `True`, the local validations will be skipped." ) - parallel_step_startup_waiting_period: Optional[float] = Field( + parallel_step_startup_waiting_period: float | None = Field( None, description="How long to wait in between starting parallel steps. " "This can be used to distribute server load when running pipelines " @@ -357,7 +357,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. 
Specifies resource requirements that are used to filter the available @@ -372,7 +372,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -381,7 +381,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -399,7 +399,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/kubernetes.png" @property - def config_class(self) -> Type[KubernetesOrchestratorConfig]: + def config_class(self) -> type[KubernetesOrchestratorConfig]: """Returns `KubernetesOrchestratorConfig` config class. Returns: @@ -408,7 +408,7 @@ def config_class(self) -> Type[KubernetesOrchestratorConfig]: return KubernetesOrchestratorConfig @property - def implementation_class(self) -> Type["KubernetesOrchestrator"]: + def implementation_class(self) -> type["KubernetesOrchestrator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py b/src/zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py index 6dcbb6ae901..59b08fe23a0 100644 --- a/src/zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py +++ b/src/zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Kubernetes step operator flavor.""" -from typing import TYPE_CHECKING, List, Optional, Type +from typing import TYPE_CHECKING from pydantic import Field, NonNegativeInt @@ -34,11 +34,11 @@ class KubernetesStepOperatorSettings(BaseSettings): """Settings for the Kubernetes step operator.""" - pod_settings: Optional[KubernetesPodSettings] = Field( + pod_settings: KubernetesPodSettings | None = Field( default=None, description="Pod configuration for step execution containers.", ) - service_account_name: Optional[str] = Field( + service_account_name: str | None = Field( default=None, description="Kubernetes service account for step pods. Uses default account if not specified.", ) @@ -46,20 +46,20 @@ class KubernetesStepOperatorSettings(BaseSettings): default=False, description="Whether to run step containers in privileged mode with extended permissions.", ) - job_name_prefix: Optional[str] = Field( + job_name_prefix: str | None = Field( default=None, description="Prefix for the job name.", ) - ttl_seconds_after_finished: Optional[NonNegativeInt] = Field( + ttl_seconds_after_finished: NonNegativeInt | None = Field( default=None, description="Seconds to keep finished jobs before automatic cleanup.", ) - active_deadline_seconds: Optional[NonNegativeInt] = Field( + active_deadline_seconds: NonNegativeInt | None = Field( default=None, description="Job deadline in seconds. 
If the job doesn't finish " "within this time, it will be terminated.", ) - fail_on_container_waiting_reasons: Optional[List[str]] = Field( + fail_on_container_waiting_reasons: list[str] | None = Field( default=[ "InvalidImageName", "ErrImagePull", @@ -74,22 +74,22 @@ class KubernetesStepOperatorSettings(BaseSettings): ) # Deprecated fields - pod_startup_timeout: Optional[int] = Field( + pod_startup_timeout: int | None = Field( default=None, deprecated=True, description="DEPRECATED/UNUSED.", ) - pod_failure_max_retries: Optional[int] = Field( + pod_failure_max_retries: int | None = Field( default=None, deprecated=True, description="DEPRECATED/UNUSED.", ) - pod_failure_retry_delay: Optional[int] = Field( + pod_failure_retry_delay: int | None = Field( default=None, deprecated=True, description="DEPRECATED/UNUSED.", ) - pod_failure_backoff: Optional[float] = Field( + pod_failure_backoff: float | None = Field( default=None, deprecated=True, description="DEPRECATED/UNUSED.", @@ -116,7 +116,7 @@ class KubernetesStepOperatorConfig( description="Whether to execute within the same cluster as the orchestrator. " "Requires appropriate pod creation permissions.", ) - kubernetes_context: Optional[str] = Field( + kubernetes_context: str | None = Field( default=None, description="Kubernetes context name for cluster connection. " "Ignored when using service connectors or in-cluster execution.", @@ -160,7 +160,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -175,7 +175,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. 
Returns: @@ -184,7 +184,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -202,7 +202,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/kubernetes.png" @property - def config_class(self) -> Type[KubernetesStepOperatorConfig]: + def config_class(self) -> type[KubernetesStepOperatorConfig]: """Returns `KubernetesStepOperatorConfig` config class. Returns: @@ -211,7 +211,7 @@ def config_class(self) -> Type[KubernetesStepOperatorConfig]: return KubernetesStepOperatorConfig @property - def implementation_class(self) -> Type["KubernetesStepOperator"]: + def implementation_class(self) -> type["KubernetesStepOperator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/kubernetes/orchestrators/dag_runner.py b/src/zenml/integrations/kubernetes/orchestrators/dag_runner.py index c048538017d..bba278afa6a 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/dag_runner.py +++ b/src/zenml/integrations/kubernetes/orchestrators/dag_runner.py @@ -17,7 +17,8 @@ import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import Any, Callable, Dict, List, Optional +from typing import Any +from collections.abc import Callable from pydantic import BaseModel @@ -52,8 +53,8 @@ class Node(BaseModel): id: str status: NodeStatus = NodeStatus.NOT_READY - upstream_nodes: List[str] = [] - metadata: Dict[str, Any] = {} + upstream_nodes: list[str] = [] + metadata: dict[str, Any] = {} @property def is_finished(self) -> bool: @@ -88,17 +89,17 @@ class DagRunner: def __init__( self, - nodes: List[Node], + nodes: list[Node], node_startup_function: Callable[[Node], NodeStatus], node_monitoring_function: Callable[[Node], NodeStatus], - node_stop_function: Optional[Callable[[Node], 
None]] = None, - interrupt_function: Optional[ - Callable[[], Optional[InterruptMode]] - ] = None, + node_stop_function: Callable[[Node], None] | None = None, + interrupt_function: None | ( + Callable[[], InterruptMode | None] + ) = None, monitoring_interval: float = 1.0, monitoring_delay: float = 0.0, interrupt_check_interval: float = 1.0, - max_parallelism: Optional[int] = None, + max_parallelism: int | None = None, ) -> None: """Initialize the DAG runner. @@ -137,7 +138,7 @@ def __init__( ) @property - def running_nodes(self) -> List[Node]: + def running_nodes(self) -> list[Node]: """Running nodes. Returns: @@ -150,7 +151,7 @@ def running_nodes(self) -> List[Node]: ] @property - def active_nodes(self) -> List[Node]: + def active_nodes(self) -> list[Node]: """Active nodes. Active nodes are nodes that are either running or starting. @@ -322,7 +323,7 @@ def _monitoring_loop(self) -> None: time_to_sleep = max(0, self.monitoring_interval - duration) self.shutdown_event.wait(timeout=time_to_sleep) - def run(self) -> Dict[str, NodeStatus]: + def run(self) -> dict[str, NodeStatus]: """Run the DAG. Returns: diff --git a/src/zenml/integrations/kubernetes/orchestrators/kube_utils.py b/src/zenml/integrations/kubernetes/orchestrators/kube_utils.py index 0e3674c4a64..fcc4d2f2515 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kube_utils.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kube_utils.py @@ -37,7 +37,8 @@ import re import time from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast +from typing import Any, TypeVar, cast +from collections.abc import Callable from kubernetes import client as k8s_client from kubernetes import config as k8s_config @@ -126,7 +127,7 @@ def is_inside_kubernetes() -> bool: def load_kube_config( - incluster: bool = False, context: Optional[str] = None + incluster: bool = False, context: str | None = None ) -> None: """Load the Kubernetes client config. 
@@ -200,7 +201,7 @@ def pod_is_done(pod: k8s_client.V1Pod) -> bool: def get_pod( core_api: k8s_client.CoreV1Api, pod_name: str, namespace: str -) -> Optional[k8s_client.V1Pod]: +) -> k8s_client.V1Pod | None: """Get a pod from Kubernetes metadata API. Args: @@ -396,7 +397,7 @@ def create_secret( core_api: k8s_client.CoreV1Api, namespace: str, secret_name: str, - data: Dict[str, Optional[str]], + data: dict[str, str | None], ) -> None: """Create a Kubernetes secret. @@ -416,7 +417,7 @@ def update_secret( core_api: k8s_client.CoreV1Api, namespace: str, secret_name: str, - data: Dict[str, Optional[str]], + data: dict[str, str | None], ) -> None: """Update a Kubernetes secret. @@ -438,7 +439,7 @@ def create_or_update_secret( core_api: k8s_client.CoreV1Api, namespace: str, secret_name: str, - data: Dict[str, Optional[str]], + data: dict[str, str | None], ) -> None: """Create a Kubernetes secret if it doesn't exist, or update it if it does. @@ -583,7 +584,7 @@ def create_and_wait_for_pod_to_start( def get_pod_owner_references( core_api: k8s_client.CoreV1Api, pod_name: str, namespace: str -) -> List[k8s_client.V1OwnerReference]: +) -> list[k8s_client.V1OwnerReference]: """Get owner references for a pod. Args: @@ -600,7 +601,7 @@ def get_pod_owner_references( return [] return cast( - List[k8s_client.V1OwnerReference], pod.metadata.owner_references + list[k8s_client.V1OwnerReference], pod.metadata.owner_references ) @@ -609,7 +610,7 @@ def retry_on_api_exception( max_retries: int = 3, delay: float = 1, backoff: float = 1, - fail_on_status_codes: Tuple[int, ...] = (404,), + fail_on_status_codes: tuple[int, ...] = (404,), ) -> Callable[..., R]: """Retry a function on API exceptions. @@ -691,7 +692,7 @@ def get_job( def list_jobs( batch_api: k8s_client.BatchV1Api, namespace: str, - label_selector: Optional[str] = None, + label_selector: str | None = None, ) -> k8s_client.V1JobList: """List jobs in a namespace. 
@@ -713,7 +714,7 @@ def update_job( batch_api: k8s_client.BatchV1Api, namespace: str, job_name: str, - annotations: Dict[str, str], + annotations: dict[str, str], ) -> k8s_client.V1Job: """Update a job. @@ -750,7 +751,7 @@ def is_step_job(job: k8s_client.V1Job) -> bool: def get_container_status( pod: k8s_client.V1Pod, container_name: str -) -> Optional[k8s_client.V1ContainerState]: +) -> k8s_client.V1ContainerState | None: """Get the status of a container. Args: @@ -772,7 +773,7 @@ def get_container_status( def get_container_termination_reason( pod: k8s_client.V1Pod, container_name: str -) -> Optional[Tuple[int, str]]: +) -> tuple[int, str] | None: """Get the termination reason for a container. Args: @@ -800,9 +801,9 @@ def wait_for_job_to_finish( backoff_interval: float = 1, maximum_backoff: float = 32, exponential_backoff: bool = False, - fail_on_container_waiting_reasons: Optional[List[str]] = None, + fail_on_container_waiting_reasons: list[str] | None = None, stream_logs: bool = True, - container_name: Optional[str] = None, + container_name: str | None = None, ) -> None: """Wait for a job to finish. @@ -823,7 +824,7 @@ def wait_for_job_to_finish( Raises: RuntimeError: If the job failed or timed out. """ - logged_lines_per_pod: Dict[str, int] = defaultdict(int) + logged_lines_per_pod: dict[str, int] = defaultdict(int) finished_pods = set() while True: @@ -939,9 +940,9 @@ def check_job_status( core_api: k8s_client.CoreV1Api, namespace: str, job_name: str, - fail_on_container_waiting_reasons: Optional[List[str]] = None, - container_name: Optional[str] = None, -) -> Tuple[JobStatus, Optional[str]]: + fail_on_container_waiting_reasons: list[str] | None = None, + container_name: str | None = None, +) -> tuple[JobStatus, str | None]: """Check the status of a job. Args: @@ -1031,7 +1032,7 @@ def create_config_map( core_api: k8s_client.CoreV1Api, namespace: str, name: str, - data: Dict[str, str], + data: dict[str, str], ) -> None: """Create a Kubernetes config map. 
@@ -1051,7 +1052,7 @@ def update_config_map( core_api: k8s_client.CoreV1Api, namespace: str, name: str, - data: Dict[str, str], + data: dict[str, str], ) -> None: """Update a Kubernetes config map. @@ -1111,7 +1112,7 @@ def get_parent_job_name( core_api: k8s_client.CoreV1Api, pod_name: str, namespace: str, -) -> Optional[str]: +) -> str | None: """Get the name of the job that created a pod. Args: @@ -1137,8 +1138,8 @@ def get_parent_job_name( def apply_default_resource_requests( memory: str, - cpu: Optional[str] = None, - pod_settings: Optional[KubernetesPodSettings] = None, + cpu: str | None = None, + pod_settings: KubernetesPodSettings | None = None, ) -> KubernetesPodSettings: """Applies default resource requests to a pod settings object. diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py index 384a80bece6..b8cbf1e5a86 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py @@ -34,11 +34,7 @@ import random from typing import ( TYPE_CHECKING, - Dict, - List, Optional, - Tuple, - Type, cast, ) from uuid import UUID @@ -94,7 +90,7 @@ class KubernetesOrchestrator(ContainerizedOrchestrator): """Orchestrator for running ZenML pipelines using native Kubernetes.""" - _k8s_client: Optional[k8s_client.ApiClient] = None + _k8s_client: k8s_client.ApiClient | None = None def should_build_pipeline_image( self, snapshot: "PipelineSnapshotBase" @@ -113,7 +109,7 @@ def should_build_pipeline_image( return settings.always_build_pipeline_image def get_kube_client( - self, incluster: Optional[bool] = None + self, incluster: bool | None = None ) -> k8s_client.ApiClient: """Getter for the Kubernetes API client. 
@@ -197,7 +193,7 @@ def config(self) -> KubernetesOrchestratorConfig: return cast(KubernetesOrchestratorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Kubernetes orchestrator. Returns: @@ -205,7 +201,7 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: """ return KubernetesOrchestratorSettings - def get_kubernetes_contexts(self) -> Tuple[List[str], str]: + def get_kubernetes_contexts(self) -> tuple[list[str], str]: """Get list of configured Kubernetes contexts and the active context. Raises: @@ -227,14 +223,14 @@ def get_kubernetes_contexts(self) -> Tuple[List[str], str]: return context_names, active_context_name @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Defines the validator that checks whether the stack is valid. Returns: Stack validator. """ - def _validate_local_requirements(stack: "Stack") -> Tuple[bool, str]: + def _validate_local_requirements(stack: "Stack") -> tuple[bool, str]: """Validates that the stack contains no local components. Args: @@ -381,7 +377,7 @@ def get_token_secret_name(self, snapshot_id: UUID) -> str: return f"zenml-token-{snapshot_id}" @property - def supported_execution_modes(self) -> List[ExecutionMode]: + def supported_execution_modes(self) -> list[ExecutionMode]: """Returns the supported execution modes for this flavor. Returns: @@ -397,10 +393,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. 
This method should only submit the pipeline and not wait for it to @@ -800,8 +796,8 @@ def _stop_run( def fetch_status( self, run: "PipelineRunResponse", include_steps: bool = False - ) -> Tuple[ - Optional[ExecutionStatus], Optional[Dict[str, ExecutionStatus]] + ) -> tuple[ + ExecutionStatus | None, dict[str, ExecutionStatus] | None ]: """Refreshes the status of a specific pipeline run. @@ -869,7 +865,7 @@ def fetch_status( def _map_job_status_to_execution_status( self, job: k8s_client.V1Job - ) -> Optional[ExecutionStatus]: + ) -> ExecutionStatus | None: """Map Kubernetes job status to ZenML execution status. Args: @@ -890,7 +886,7 @@ def _map_job_status_to_execution_status( def get_pipeline_run_metadata( self, run_id: UUID - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get general component-specific metadata for a pipeline run. Args: diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py index bfba199e407..c9d897493f1 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py @@ -18,7 +18,7 @@ import socket import threading import time -from typing import List, Optional, Tuple, cast +from typing import cast from uuid import UUID from kubernetes import client as k8s_client @@ -91,7 +91,7 @@ def parse_args() -> argparse.Namespace: def _get_orchestrator_job_state( batch_api: k8s_client.BatchV1Api, namespace: str, job_name: str -) -> Tuple[Optional[UUID], Optional[str]]: +) -> tuple[UUID | None, str | None]: """Get the existing status of the orchestrator job. Args: @@ -127,7 +127,7 @@ def _reconstruct_nodes( pipeline_run: PipelineRunResponse, namespace: str, batch_api: k8s_client.BatchV1Api, -) -> List[Node]: +) -> list[Node]: """Reconstruct the nodes from the pipeline run. 
Args: @@ -642,7 +642,7 @@ def check_job_status(node: Node) -> NodeStatus: else: return NodeStatus.RUNNING - def should_interrupt_execution() -> Optional[InterruptMode]: + def should_interrupt_execution() -> InterruptMode | None: """Check if the DAG execution should be interrupted. Returns: diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py index 4dec51e0c76..5b9761c6551 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Entrypoint configuration for the Kubernetes master/orchestrator pod.""" -from typing import TYPE_CHECKING, List, Optional, Set +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from uuid import UUID @@ -26,7 +26,7 @@ class KubernetesOrchestratorEntrypointConfiguration: """Entrypoint configuration for the k8s master/orchestrator pod.""" @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> set[str]: """Gets all the options required for running this entrypoint. Returns: @@ -38,7 +38,7 @@ def get_entrypoint_options(cls) -> Set[str]: return options @classmethod - def get_entrypoint_command(cls) -> List[str]: + def get_entrypoint_command(cls) -> list[str]: """Returns a command that runs the entrypoint module. Returns: @@ -56,7 +56,7 @@ def get_entrypoint_arguments( cls, snapshot_id: "UUID", run_id: Optional["UUID"] = None, - ) -> List[str]: + ) -> list[str]: """Gets all arguments that the entrypoint command should be called with. 
Args: diff --git a/src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py b/src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py index c52ffb5e9fc..9fa99b0c246 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py +++ b/src/zenml/integrations/kubernetes/orchestrators/manifest_utils.py @@ -16,7 +16,8 @@ import base64 import os import sys -from typing import Any, Dict, List, Mapping, Optional +from typing import Any +from collections.abc import Mapping from kubernetes import client as k8s_client @@ -95,18 +96,18 @@ def add_local_stores_mount( def build_pod_manifest( - pod_name: Optional[str], + pod_name: str | None, image_name: str, - command: List[str], - args: List[str], + command: list[str], + args: list[str], privileged: bool, - pod_settings: Optional[KubernetesPodSettings] = None, - service_account_name: Optional[str] = None, - env: Optional[Dict[str, str]] = None, - labels: Optional[Dict[str, str]] = None, + pod_settings: KubernetesPodSettings | None = None, + service_account_name: str | None = None, + env: dict[str, str] | None = None, + labels: dict[str, str] | None = None, mount_local_stores: bool = False, - owner_references: Optional[List[k8s_client.V1OwnerReference]] = None, - termination_grace_period_seconds: Optional[int] = 30, + owner_references: list[k8s_client.V1OwnerReference] | None = None, + termination_grace_period_seconds: int | None = 30, ) -> k8s_client.V1Pod: """Build a Kubernetes pod manifest for a ZenML run or step. @@ -267,7 +268,7 @@ def build_role_binding_manifest_for_service_account( role_name: str, service_account_name: str, namespace: str = "default", -) -> Dict[str, Any]: +) -> dict[str, Any]: """Build a manifest for a role binding of a service account. 
Args: @@ -300,7 +301,7 @@ def build_role_binding_manifest_for_service_account( def build_service_account_manifest( name: str, namespace: str = "default" -) -> Dict[str, Any]: +) -> dict[str, Any]: """Build the manifest for a service account. Args: @@ -319,7 +320,7 @@ def build_service_account_manifest( } -def build_namespace_manifest(namespace: str) -> Dict[str, Any]: +def build_namespace_manifest(namespace: str) -> dict[str, Any]: """Build the manifest for a new namespace. Args: @@ -339,9 +340,9 @@ def build_namespace_manifest(namespace: str) -> Dict[str, Any]: def build_secret_manifest( name: str, - data: Mapping[str, Optional[str]], + data: Mapping[str, str | None], secret_type: str = "Opaque", -) -> Dict[str, Any]: +) -> dict[str, Any]: """Builds a Kubernetes secret manifest. Args: @@ -388,13 +389,13 @@ def pod_template_manifest_from_pod( def build_job_manifest( job_name: str, pod_template: k8s_client.V1PodTemplateSpec, - backoff_limit: Optional[int] = None, - ttl_seconds_after_finished: Optional[int] = None, - labels: Optional[Dict[str, str]] = None, - annotations: Optional[Dict[str, str]] = None, - active_deadline_seconds: Optional[int] = None, - pod_failure_policy: Optional[Dict[str, Any]] = None, - owner_references: Optional[List[k8s_client.V1OwnerReference]] = None, + backoff_limit: int | None = None, + ttl_seconds_after_finished: int | None = None, + labels: dict[str, str] | None = None, + annotations: dict[str, str] | None = None, + active_deadline_seconds: int | None = None, + pod_failure_policy: dict[str, Any] | None = None, + owner_references: list[k8s_client.V1OwnerReference] | None = None, ) -> k8s_client.V1Job: """Build a Kubernetes job manifest. 
@@ -450,8 +451,8 @@ def job_template_manifest_from_job( def build_cron_job_manifest( job_template: k8s_client.V1JobTemplateSpec, cron_expression: str, - successful_jobs_history_limit: Optional[int] = None, - failed_jobs_history_limit: Optional[int] = None, + successful_jobs_history_limit: int | None = None, + failed_jobs_history_limit: int | None = None, ) -> k8s_client.V1CronJob: """Build a Kubernetes cron job manifest. diff --git a/src/zenml/integrations/kubernetes/pod_settings.py b/src/zenml/integrations/kubernetes/pod_settings.py index 99fb7cc4719..b1c9da748e6 100644 --- a/src/zenml/integrations/kubernetes/pod_settings.py +++ b/src/zenml/integrations/kubernetes/pod_settings.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Kubernetes pod settings.""" -from typing import Any, Dict, List, Optional +from typing import Any from pydantic import field_validator @@ -72,20 +72,20 @@ class KubernetesPodSettings(BaseSettings): will be applied to the pod spec. """ - node_selectors: Dict[str, str] = {} - affinity: Dict[str, Any] = {} - tolerations: List[Dict[str, Any]] = [] - resources: Dict[str, Dict[str, str]] = {} - annotations: Dict[str, str] = {} - volumes: List[Dict[str, Any]] = [] - volume_mounts: List[Dict[str, Any]] = [] + node_selectors: dict[str, str] = {} + affinity: dict[str, Any] = {} + tolerations: list[dict[str, Any]] = [] + resources: dict[str, dict[str, str]] = {} + annotations: dict[str, str] = {} + volumes: list[dict[str, Any]] = [] + volume_mounts: list[dict[str, Any]] = [] host_ipc: bool = False - scheduler_name: Optional[str] = None - image_pull_secrets: List[str] = [] - labels: Dict[str, str] = {} - env: List[Dict[str, Any]] = [] - env_from: List[Dict[str, Any]] = [] - additional_pod_spec_args: Dict[str, Any] = {} + scheduler_name: str | None = None + image_pull_secrets: list[str] = [] + labels: dict[str, str] = {} + env: list[dict[str, Any]] = [] + env_from: list[dict[str, Any]] = [] + additional_pod_spec_args: dict[str, Any] 
= {} @field_validator("volumes", mode="before") @classmethod diff --git a/src/zenml/integrations/kubernetes/serialization_utils.py b/src/zenml/integrations/kubernetes/serialization_utils.py index 4a5a8c49681..72738ed39a1 100644 --- a/src/zenml/integrations/kubernetes/serialization_utils.py +++ b/src/zenml/integrations/kubernetes/serialization_utils.py @@ -15,10 +15,10 @@ import re from datetime import date, datetime -from typing import Any, Dict, List, Type, cast +from typing import Any, cast -def serialize_kubernetes_model(model: Any) -> Dict[str, Any]: +def serialize_kubernetes_model(model: Any) -> dict[str, Any]: """Serializes a Kubernetes model. Args: @@ -34,7 +34,7 @@ def serialize_kubernetes_model(model: Any) -> Dict[str, Any]: raise TypeError(f"Unable to serialize non-kubernetes model {model}.") assert hasattr(model, "attribute_map") - attribute_mapping = cast(Dict[str, str], model.attribute_map) + attribute_mapping = cast(dict[str, str], model.attribute_map) model_attributes = { serialized_attribute_name: getattr(model, attribute_name) @@ -43,7 +43,7 @@ def serialize_kubernetes_model(model: Any) -> Dict[str, Any]: return _serialize_dict(model_attributes) -def _serialize_dict(dict_: Dict[str, Any]) -> Dict[str, Any]: +def _serialize_dict(dict_: dict[str, Any]) -> dict[str, Any]: """Serializes a dictionary. 
Args: @@ -75,11 +75,11 @@ def _serialize(value: Any) -> Any: return value elif isinstance(value, (datetime, date)): return value.isoformat() - elif isinstance(value, List): + elif isinstance(value, list): return [_serialize(item) for item in value] elif isinstance(value, tuple): return tuple(_serialize(item) for item in value) - elif isinstance(value, Dict): + elif isinstance(value, dict): return _serialize_dict(value) elif is_model_class(value.__class__.__name__): return serialize_kubernetes_model(value) @@ -87,7 +87,7 @@ def _serialize(value: Any) -> Any: raise TypeError(f"Failed to serialize unknown object {value}") -def deserialize_kubernetes_model(data: Dict[str, Any], class_name: str) -> Any: +def deserialize_kubernetes_model(data: dict[str, Any], class_name: str) -> Any: """Deserializes a Kubernetes model. Args: @@ -104,14 +104,14 @@ def deserialize_kubernetes_model(data: Dict[str, Any], class_name: str) -> Any: assert hasattr(model_class, "openapi_types") assert hasattr(model_class, "attribute_map") # Mapping of the attribute name of the model class to the attribute type - type_mapping = cast(Dict[str, str], model_class.openapi_types) - reverse_attribute_mapping = cast(Dict[str, str], model_class.attribute_map) + type_mapping = cast(dict[str, str], model_class.openapi_types) + reverse_attribute_mapping = cast(dict[str, str], model_class.attribute_map) # Mapping of the serialized key to the attribute name of the model class attribute_mapping = { value: key for key, value in reverse_attribute_mapping.items() } - deserialized_attributes: Dict[str, Any] = {} + deserialized_attributes: dict[str, Any] = {} for key, value in data.items(): if key not in attribute_mapping: @@ -164,7 +164,7 @@ def is_model_class(class_name: str) -> bool: return hasattr(kubernetes.client.models, class_name) -def get_model_class(class_name: str) -> Type[Any]: +def get_model_class(class_name: str) -> type[Any]: """Gets a Kubernetes model class. 
Args: @@ -189,7 +189,7 @@ def get_model_class(class_name: str) -> Type[Any]: return class_ -def _deserialize_list(data: Any, class_name: str) -> List[Any]: +def _deserialize_list(data: Any, class_name: str) -> list[Any]: """Deserializes a list of potential Kubernetes models. Args: @@ -199,7 +199,7 @@ def _deserialize_list(data: Any, class_name: str) -> List[Any]: Returns: The deserialized list. """ - assert isinstance(data, List) + assert isinstance(data, list) if is_model_class(class_name): return [ deserialize_kubernetes_model(element, class_name) @@ -209,7 +209,7 @@ def _deserialize_list(data: Any, class_name: str) -> List[Any]: return data -def _deserialize_dict(data: Any, class_name: str) -> Dict[str, Any]: +def _deserialize_dict(data: Any, class_name: str) -> dict[str, Any]: """Deserializes a dict of potential Kubernetes models. Args: @@ -219,7 +219,7 @@ def _deserialize_dict(data: Any, class_name: str) -> Dict[str, Any]: Returns: The deserialized dict. """ - assert isinstance(data, Dict) + assert isinstance(data, dict) if is_model_class(class_name): return { key: deserialize_kubernetes_model(value, class_name) diff --git a/src/zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py b/src/zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py index cc9cd9c2c08..806b9394ddc 100644 --- a/src/zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py +++ b/src/zenml/integrations/kubernetes/service_connectors/kubernetes_service_connector.py @@ -21,7 +21,7 @@ import os import subprocess import tempfile -from typing import Any, List, Optional +from typing import Any from kubernetes import client as k8s_client from kubernetes import config as k8s_config @@ -49,7 +49,7 @@ class KubernetesServerCredentials(AuthenticationConfig): """Kubernetes server authentication config.""" - certificate_authority: Optional[PlainSerializedSecretStr] = Field( + certificate_authority: PlainSerializedSecretStr | 
None = Field( default=None, title="Kubernetes CA Certificate (base64 encoded)", ) @@ -96,15 +96,15 @@ class KubernetesUserPasswordConfig( class KubernetesTokenCredentials(AuthenticationConfig): """Kubernetes token authentication config.""" - client_certificate: Optional[PlainSerializedSecretStr] = Field( + client_certificate: PlainSerializedSecretStr | None = Field( default=None, title="Kubernetes Client Certificate (base64 encoded)", ) - client_key: Optional[PlainSerializedSecretStr] = Field( + client_key: PlainSerializedSecretStr | None = Field( default=None, title="Kubernetes Client Key (base64 encoded)", ) - token: Optional[PlainSerializedSecretStr] = Field( + token: PlainSerializedSecretStr | None = Field( default=None, title="Kubernetes Token", ) @@ -323,7 +323,7 @@ def _configure_local_client( """ cfg = self.config cluster_name = cfg.cluster_name - delete_files: List[str] = [] + delete_files: list[str] = [] if self.auth_method == KubernetesAuthenticationMethods.PASSWORD: assert isinstance(cfg, KubernetesUserPasswordConfig) @@ -454,10 +454,10 @@ def _configure_local_client( @classmethod def _auto_configure( cls, - auth_method: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - kubernetes_context: Optional[str] = None, + auth_method: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, + kubernetes_context: str | None = None, **kwargs: Any, ) -> "KubernetesServiceConnector": """Auto-configure the connector. 
@@ -510,7 +510,7 @@ def _auto_configure( insecure=kube_config.verify_ssl is False, ) else: - token: Optional[str] = None + token: str | None = None if kube_config.api_key: token = kube_config.api_key["authorization"].removeprefix( "Bearer " @@ -545,9 +545,9 @@ def _auto_configure( def _verify( self, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> List[str]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> list[str]: """Verify and list all the resources that the connector can access. Args: diff --git a/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py b/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py index 5a22df4c558..1563b12eb00 100644 --- a/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +++ b/src/zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py @@ -14,7 +14,7 @@ """Kubernetes step operator implementation.""" import random -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, cast +from typing import TYPE_CHECKING, cast from kubernetes import client as k8s_client @@ -53,7 +53,7 @@ class KubernetesStepOperator(BaseStepOperator): """Step operator to run on Kubernetes.""" - _k8s_client: Optional[k8s_client.ApiClient] = None + _k8s_client: k8s_client.ApiClient | None = None @property def config(self) -> KubernetesStepOperatorConfig: @@ -65,7 +65,7 @@ def config(self) -> KubernetesStepOperatorConfig: return cast(KubernetesStepOperatorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Kubernetes step operator. 
Returns: @@ -74,7 +74,7 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: return KubernetesStepOperatorSettings @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates the stack. Returns: @@ -82,7 +82,7 @@ def validator(self) -> Optional[StackValidator]: registry and a remote artifact store. """ - def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + def _validate_remote_components(stack: "Stack") -> tuple[bool, str]: if stack.artifact_store.config.is_local: return False, ( "The Kubernetes step operator runs code remotely and " @@ -118,7 +118,7 @@ def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: def get_docker_builds( self, snapshot: "PipelineSnapshotBase" - ) -> List["BuildConfiguration"]: + ) -> list["BuildConfiguration"]: """Gets the Docker builds required for the component. Args: @@ -195,8 +195,8 @@ def _k8s_batch_api(self) -> k8s_client.BatchV1Api: def launch( self, info: "StepRunInfo", - entrypoint_command: List[str], - environment: Dict[str, str], + entrypoint_command: list[str], + environment: dict[str, str], ) -> None: """Launches a step on Kubernetes. diff --git a/src/zenml/integrations/label_studio/__init__.py b/src/zenml/integrations/label_studio/__init__.py index e4a0fcf3081..a4042723f4a 100644 --- a/src/zenml/integrations/label_studio/__init__.py +++ b/src/zenml/integrations/label_studio/__init__.py @@ -12,7 +12,6 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. 
"""Initialization of the Label Studio integration.""" -from typing import List, Type from zenml.integrations.constants import LABEL_STUDIO from zenml.integrations.integration import Integration @@ -30,7 +29,7 @@ class LabelStudioIntegration(Integration): ] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Label Studio integration. Returns: diff --git a/src/zenml/integrations/label_studio/annotators/label_studio_annotator.py b/src/zenml/integrations/label_studio/annotators/label_studio_annotator.py index 77d0b38fc56..3e201e2bd62 100644 --- a/src/zenml/integrations/label_studio/annotators/label_studio_annotator.py +++ b/src/zenml/integrations/label_studio/annotators/label_studio_annotator.py @@ -17,7 +17,7 @@ import os import webbrowser from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Type, cast +from typing import Any, cast from label_studio_sdk import Client, Project @@ -54,7 +54,7 @@ def config(self) -> LabelStudioAnnotatorConfig: return cast(LabelStudioAnnotatorConfig, self._config) @property - def settings_class(self) -> Type[LabelStudioAnnotatorSettings]: + def settings_class(self) -> type[LabelStudioAnnotatorSettings]: """Settings class for the Label Studio annotator. Returns: @@ -86,7 +86,7 @@ def get_url_for_dataset(self, dataset_name: str) -> str: project_id = self.get_id_from_name(dataset_name) return f"{self.get_url()}/projects/{project_id}/" - def get_id_from_name(self, dataset_name: str) -> Optional[int]: + def get_id_from_name(self, dataset_name: str) -> int | None: """Gets the ID of the given dataset. Args: @@ -101,16 +101,16 @@ def get_id_from_name(self, dataset_name: str) -> Optional[int]: return cast(int, project.get_params()["id"]) return None - def get_datasets(self) -> List[Any]: + def get_datasets(self) -> list[Any]: """Gets the datasets currently available for annotation. Returns: A list of datasets. 
""" datasets = self._get_client().get_projects() - return cast(List[Any], datasets) + return cast(list[Any], datasets) - def get_dataset_names(self) -> List[str]: + def get_dataset_names(self) -> list[str]: """Gets the names of the datasets. Returns: @@ -120,7 +120,7 @@ def get_dataset_names(self) -> List[str]: dataset.get_params()["title"] for dataset in self.get_datasets() ] - def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]: + def get_dataset_stats(self, dataset_name: str) -> tuple[int, int]: """Gets the statistics of the given dataset. Args: @@ -280,7 +280,7 @@ def get_dataset(self, **kwargs: Any) -> Any: def get_converted_dataset( self, dataset_name: str, output_format: str - ) -> Dict[Any, Any]: + ) -> dict[Any, Any]: """Extract annotated tasks in a specific converted format. Args: @@ -367,7 +367,7 @@ def register_dataset_for_annotation( def _get_azure_import_storage_sources( self, dataset_id: int - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Gets a list of all Azure import storage sources. Args: @@ -383,7 +383,7 @@ def _get_azure_import_storage_sources( query_url = f"/api/storages/azure?project={dataset_id}" response = self._get_client().make_request(method="GET", url=query_url) if response.status_code == 200: - return cast(List[Dict[str, Any]], response.json()) + return cast(list[dict[str, Any]], response.json()) else: raise ConnectionError( f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}." @@ -391,7 +391,7 @@ def _get_azure_import_storage_sources( def _get_gcs_import_storage_sources( self, dataset_id: int - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Gets a list of all Google Cloud Storage import storage sources. 
Args: @@ -407,7 +407,7 @@ def _get_gcs_import_storage_sources( query_url = f"/api/storages/gcs?project={dataset_id}" response = self._get_client().make_request(method="GET", url=query_url) if response.status_code == 200: - return cast(List[Dict[str, Any]], response.json()) + return cast(list[dict[str, Any]], response.json()) else: raise ConnectionError( f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}." @@ -415,7 +415,7 @@ def _get_gcs_import_storage_sources( def _get_s3_import_storage_sources( self, dataset_id: int - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Gets a list of all AWS S3 import storage sources. Args: @@ -431,7 +431,7 @@ def _get_s3_import_storage_sources( query_url = f"/api/storages/s3?project={dataset_id}" response = self._get_client().make_request(method="GET", url=query_url) if response.status_code == 200: - return cast(List[Dict[str, Any]], response.json()) + return cast(list[dict[str, Any]], response.json()) else: raise ConnectionError( f"Unable to get list of import storage sources. Client raised HTTP error {response.status_code}." @@ -484,7 +484,7 @@ def _storage_source_already_exists( for source in storage_sources ) - def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]: + def get_parsed_label_config(self, dataset_id: int) -> dict[str, Any]: """Returns the parsed Label Studio label config for a dataset. 
Args: @@ -499,7 +499,7 @@ def get_parsed_label_config(self, dataset_id: int) -> Dict[str, Any]: # TODO: check if client actually is connected etc dataset = self._get_client().get_project(dataset_id) if dataset: - return cast(Dict[str, Any], dataset.parsed_label_config) + return cast(dict[str, Any], dataset.parsed_label_config) raise ValueError("No dataset found for the given id.") def populate_artifact_store_parameters( @@ -682,7 +682,7 @@ def connect_and_sync_external_storage( uri: str, params: LabelStudioDatasetSyncParameters, dataset: Project, - ) -> Optional[Dict[str, Any]]: + ) -> dict[str, Any] | None: """Syncs the external storage for the given project. Args: @@ -813,4 +813,4 @@ def connect_and_sync_external_storage( synced_storage = self._get_client().sync_storage( storage_id=storage["id"], storage_type=storage["type"] ) - return cast(Dict[str, Any], synced_storage) + return cast(dict[str, Any], synced_storage) diff --git a/src/zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py b/src/zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py index e93594bd79a..e37d16ca673 100644 --- a/src/zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py +++ b/src/zenml/integrations/label_studio/flavors/label_studio_annotator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Label Studio annotator flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.annotators.base_annotator import ( BaseAnnotatorConfig, @@ -42,8 +42,8 @@ class LabelStudioAnnotatorSettings(BaseSettings): """ instance_url: str = DEFAULT_LOCAL_INSTANCE_URL - port: Optional[int] = DEFAULT_LOCAL_LABEL_STUDIO_PORT - api_key: Optional[str] = SecretField(default=None) + port: int | None = DEFAULT_LOCAL_LABEL_STUDIO_PORT + api_key: str | None = SecretField(default=None) class LabelStudioAnnotatorConfig( @@ -72,7 +72,7 @@ def name(self) -> str: return LABEL_STUDIO_ANNOTATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -81,7 +81,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -99,7 +99,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/annotator/label_studio.png" @property - def config_class(self) -> Type[LabelStudioAnnotatorConfig]: + def config_class(self) -> type[LabelStudioAnnotatorConfig]: """Returns `LabelStudioAnnotatorConfig` config class. Returns: @@ -108,7 +108,7 @@ def config_class(self) -> Type[LabelStudioAnnotatorConfig]: return LabelStudioAnnotatorConfig @property - def implementation_class(self) -> Type["LabelStudioAnnotator"]: + def implementation_class(self) -> type["LabelStudioAnnotator"]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/integrations/label_studio/label_config_generators/__init__.py b/src/zenml/integrations/label_studio/label_config_generators/__init__.py index ba64ba54ca9..1f81fda5da6 100644 --- a/src/zenml/integrations/label_studio/label_config_generators/__init__.py +++ b/src/zenml/integrations/label_studio/label_config_generators/__init__.py @@ -14,8 +14,6 @@ """Initialization of the Label Studio config generators submodule.""" from zenml.integrations.label_studio.label_config_generators.label_config_generators import ( - generate_basic_ocr_label_config, - generate_basic_object_detection_bounding_boxes_label_config, generate_image_classification_label_config, generate_text_classification_label_config, TASK_TO_FILENAME_REFERENCE_MAPPING, diff --git a/src/zenml/integrations/label_studio/label_config_generators/label_config_generators.py b/src/zenml/integrations/label_studio/label_config_generators/label_config_generators.py index 1330b14c047..ef2385bf359 100644 --- a/src/zenml/integrations/label_studio/label_config_generators/label_config_generators.py +++ b/src/zenml/integrations/label_studio/label_config_generators/label_config_generators.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Implementation of label config generators for Label Studio.""" -from typing import List, Tuple from zenml.enums import AnnotationTasks from zenml.logger import get_logger @@ -36,8 +35,8 @@ def _generate_label_config() -> str: def generate_text_classification_label_config( - labels: List[str], -) -> Tuple[str, str]: + labels: list[str], +) -> tuple[str, str]: """Generates a Label Studio label config for text classification. This is based on the basic config example shown at @@ -75,8 +74,8 @@ def generate_text_classification_label_config( def generate_image_classification_label_config( - labels: List[str], -) -> Tuple[str, str]: + labels: list[str], +) -> tuple[str, str]: """Generates a Label Studio label config for image classification. 
This is based on the basic config example shown at @@ -113,8 +112,8 @@ def generate_image_classification_label_config( def generate_basic_object_detection_bounding_boxes_label_config( - labels: List[str], -) -> Tuple[str, str]: + labels: list[str], +) -> tuple[str, str]: """Generates a Label Studio config for object detection with bounding boxes. This is based on the basic config example shown at @@ -151,8 +150,8 @@ def generate_basic_object_detection_bounding_boxes_label_config( def generate_basic_ocr_label_config( - labels: List[str], -) -> Tuple[str, str]: + labels: list[str], +) -> tuple[str, str]: """Generates a Label Studio config for optical character recognition (OCR) labeling task. This is based on the basic config example shown at diff --git a/src/zenml/integrations/label_studio/label_studio_utils.py b/src/zenml/integrations/label_studio/label_studio_utils.py index 26a80541284..9b0682a8d55 100644 --- a/src/zenml/integrations/label_studio/label_studio_utils.py +++ b/src/zenml/integrations/label_studio/label_studio_utils.py @@ -14,7 +14,7 @@ """Utility functions for the Label Studio annotator integration.""" import os -from typing import Any, Dict, List +from typing import Any from urllib.parse import quote, urlparse @@ -39,9 +39,9 @@ def clean_url(url: str) -> str: def convert_pred_filenames_to_task_ids( - preds: List[Dict[str, Any]], - tasks: List[Dict[str, Any]], -) -> List[Dict[str, Any]]: + preds: list[dict[str, Any]], + tasks: list[dict[str, Any]], +) -> list[dict[str, Any]]: """Converts a list of predictions from local file references to task id. Args: diff --git a/src/zenml/integrations/label_studio/steps/__init__.py b/src/zenml/integrations/label_studio/steps/__init__.py index 316f9923a13..b00d7170a06 100644 --- a/src/zenml/integrations/label_studio/steps/__init__.py +++ b/src/zenml/integrations/label_studio/steps/__init__.py @@ -13,9 +13,3 @@ # permissions and limitations under the License. 
"""Standard steps to be used with the Label Studio annotator integration.""" -from zenml.integrations.label_studio.steps.label_studio_standard_steps import ( - LabelStudioDatasetSyncParameters, - get_labeled_data, - get_or_create_dataset, - sync_new_data_to_label_studio, -) diff --git a/src/zenml/integrations/label_studio/steps/label_studio_standard_steps.py b/src/zenml/integrations/label_studio/steps/label_studio_standard_steps.py index a71639c606b..6c13f60fcbd 100644 --- a/src/zenml/integrations/label_studio/steps/label_studio_standard_steps.py +++ b/src/zenml/integrations/label_studio/steps/label_studio_standard_steps.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of standard steps for the Label Studio annotator integration.""" -from typing import Any, Dict, List, Optional, cast +from typing import Any, cast from urllib.parse import urlparse from pydantic import BaseModel @@ -65,22 +65,22 @@ class LabelStudioDatasetSyncParameters(BaseModel): storage_type: str = "local" label_config_type: str - prefix: Optional[str] = None - regex_filter: Optional[str] = ".*" - use_blob_urls: Optional[bool] = True - presign: Optional[bool] = True - presign_ttl: Optional[int] = 1 - description: Optional[str] = "" + prefix: str | None = None + regex_filter: str | None = ".*" + use_blob_urls: bool | None = True + presign: bool | None = True + presign_ttl: int | None = 1 + description: str | None = "" # credentials specific to the main cloud providers - azure_account_name: Optional[str] = None - azure_account_key: Optional[str] = None - google_application_credentials: Optional[str] = None - aws_access_key_id: Optional[str] = None - aws_secret_access_key: Optional[str] = None - aws_session_token: Optional[str] = None - s3_region_name: Optional[str] = None - s3_endpoint: Optional[str] = None + azure_account_name: str | None = None + azure_account_key: str | None = None + google_application_credentials: str | None = None + aws_access_key_id: str 
| None = None + aws_secret_access_key: str | None = None + aws_session_token: str | None = None + s3_region_name: str | None = None + s3_endpoint: str | None = None @step(enable_cache=False) @@ -127,7 +127,7 @@ def get_or_create_dataset( @step(enable_cache=False) -def get_labeled_data(dataset_name: str) -> List: # type: ignore[type-arg] +def get_labeled_data(dataset_name: str) -> list: # type: ignore[type-arg] """Gets labeled data from the dataset. Args: @@ -166,7 +166,7 @@ def get_labeled_data(dataset_name: str) -> List: # type: ignore[type-arg] def sync_new_data_to_label_studio( uri: str, dataset_name: str, - predictions: List[Dict[str, Any]], + predictions: list[dict[str, Any]], params: LabelStudioDatasetSyncParameters, ) -> None: """Syncs new data to Label Studio. diff --git a/src/zenml/integrations/langchain/materializers/__init__.py b/src/zenml/integrations/langchain/materializers/__init__.py index 1f1da6a9e3a..f7db8fe8d51 100644 --- a/src/zenml/integrations/langchain/materializers/__init__.py +++ b/src/zenml/integrations/langchain/materializers/__init__.py @@ -12,14 +12,5 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. 
"""Initialization of the langchain materializer.""" -from zenml.integrations.langchain.materializers.vector_store_materializer import ( - LangchainVectorStoreMaterializer, -) -from zenml.integrations.langchain.materializers.document_materializer import ( - LangchainDocumentMaterializer, -) -from zenml.integrations.langchain.materializers.openai_embedding_materializer import ( - LangchainOpenaiEmbeddingMaterializer, -) diff --git a/src/zenml/integrations/langchain/materializers/document_materializer.py b/src/zenml/integrations/langchain/materializers/document_materializer.py index c95ebd7252a..0d631931a16 100644 --- a/src/zenml/integrations/langchain/materializers/document_materializer.py +++ b/src/zenml/integrations/langchain/materializers/document_materializer.py @@ -14,7 +14,7 @@ """Implementation of ZenML's Langchain Document materializer.""" import os -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar from langchain.docstore.document import Document @@ -32,9 +32,9 @@ class LangchainDocumentMaterializer(BaseMaterializer): """Handle Langchain Document objects.""" ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Document,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (Document,) - def load(self, data_type: Type["Document"]) -> Any: + def load(self, data_type: type["Document"]) -> Any: """Reads BaseModel from JSON. Args: @@ -56,7 +56,7 @@ def save(self, data: "Document") -> None: data_path = os.path.join(self.uri, DEFAULT_FILENAME) yaml_utils.write_json(data_path, data.json()) - def extract_metadata(self, data: Document) -> Dict[str, "MetadataType"]: + def extract_metadata(self, data: Document) -> dict[str, "MetadataType"]: """Extract metadata from the given BaseModel object. 
Args: diff --git a/src/zenml/integrations/langchain/materializers/openai_embedding_materializer.py b/src/zenml/integrations/langchain/materializers/openai_embedding_materializer.py index d9e7d6b10a4..e04a944e71e 100644 --- a/src/zenml/integrations/langchain/materializers/openai_embedding_materializer.py +++ b/src/zenml/integrations/langchain/materializers/openai_embedding_materializer.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of the Langchain OpenAI embedding materializer.""" -from typing import Any, ClassVar, Tuple, Type +from typing import Any, ClassVar from langchain_community.embeddings import ( OpenAIEmbeddings, @@ -29,7 +29,7 @@ class LangchainOpenaiEmbeddingMaterializer(CloudpickleMaterializer): """Materializer for Langchain OpenAI Embeddings.""" ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (OpenAIEmbeddings,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (OpenAIEmbeddings,) def save(self, embeddings: Any) -> None: """Saves the embeddings model after clearing non-picklable clients. @@ -44,7 +44,7 @@ def save(self, embeddings: Any) -> None: # Use the parent class's save implementation which uses cloudpickle super().save(embeddings) - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: type[Any]) -> Any: """Loads the embeddings model and lets it recreate clients when needed. Args: diff --git a/src/zenml/integrations/langchain/materializers/vector_store_materializer.py b/src/zenml/integrations/langchain/materializers/vector_store_materializer.py index 7d49423c2c6..433a25b1c02 100644 --- a/src/zenml/integrations/langchain/materializers/vector_store_materializer.py +++ b/src/zenml/integrations/langchain/materializers/vector_store_materializer.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Implementation of the langchain vector store materializer.""" -from typing import Any, ClassVar, Tuple, Type +from typing import Any, ClassVar from langchain.vectorstores.base import VectorStore @@ -27,4 +27,4 @@ class LangchainVectorStoreMaterializer(CloudpickleMaterializer): """Handle langchain vector store objects.""" ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (VectorStore,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (VectorStore,) diff --git a/src/zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py b/src/zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py index c420e27a666..26e3a741f1c 100644 --- a/src/zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py +++ b/src/zenml/integrations/lightgbm/materializers/lightgbm_booster_materializer.py @@ -14,7 +14,7 @@ """Implementation of the LightGBM booster materializer.""" import os -from typing import Any, ClassVar, Tuple, Type +from typing import Any, ClassVar import lightgbm as lgb @@ -28,10 +28,10 @@ class LightGBMBoosterMaterializer(BaseMaterializer): """Materializer to read data to and from lightgbm.Booster.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (lgb.Booster,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (lgb.Booster,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL - def load(self, data_type: Type[Any]) -> lgb.Booster: + def load(self, data_type: type[Any]) -> lgb.Booster: """Reads a lightgbm Booster model from a serialized JSON file. 
Args: diff --git a/src/zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py b/src/zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py index 9269d4c15ca..43f9374785d 100644 --- a/src/zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py +++ b/src/zenml/integrations/lightgbm/materializers/lightgbm_dataset_materializer.py @@ -14,7 +14,7 @@ """Implementation of the LightGBM materializer.""" import os -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar import lightgbm as lgb @@ -31,10 +31,10 @@ class LightGBMDatasetMaterializer(BaseMaterializer): """Materializer to read data to and from lightgbm.Dataset.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (lgb.Dataset,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (lgb.Dataset,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - def load(self, data_type: Type[Any]) -> lgb.Dataset: + def load(self, data_type: type[Any]) -> lgb.Dataset: """Reads a lightgbm.Dataset binary file and loads it. Args: @@ -71,7 +71,7 @@ def save(self, matrix: lgb.Dataset) -> None: def extract_metadata( self, matrix: lgb.Dataset - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Extract metadata from the given `Dataset` object. Args: diff --git a/src/zenml/integrations/lightning/__init__.py b/src/zenml/integrations/lightning/__init__.py index f5241485f9b..7cd3da5358e 100644 --- a/src/zenml/integrations/lightning/__init__.py +++ b/src/zenml/integrations/lightning/__init__.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Initialization of the Lightning integration for ZenML.""" -from typing import List, Type from zenml.integrations.constants import ( LIGHTNING, @@ -31,7 +30,7 @@ class LightningIntegration(Integration): REQUIREMENTS = ["lightning-sdk>=0.1.17"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Lightning integration. Returns: diff --git a/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py b/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py index 425b5616f45..76c45ca905d 100644 --- a/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py +++ b/src/zenml/integrations/lightning/flavors/lightning_orchestrator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Lightning orchestrator base config and settings.""" -from typing import TYPE_CHECKING, List, Optional, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -40,34 +40,34 @@ class LightningOrchestratorSettings(BaseSettings): """ # Lightning AI Platform Configuration - main_studio_name: Optional[str] = Field( + main_studio_name: str | None = Field( default=None, description="Lightning AI studio instance name where the pipeline will execute.", ) - machine_type: Optional[str] = Field( + machine_type: str | None = Field( default=None, description="Compute instance type for pipeline execution. " "Refer to Lightning AI documentation for available options.", ) - user_id: Optional[str] = SecretField( + user_id: str | None = SecretField( default=None, description="Lightning AI user ID for authentication." ) - api_key: Optional[str] = SecretField( + api_key: str | None = SecretField( default=None, description="Lightning AI API key for platform authentication.", ) - username: Optional[str] = Field( + username: str | None = Field( default=None, description="Lightning AI platform username." 
) - teamspace: Optional[str] = Field( + teamspace: str | None = Field( default=None, description="Lightning AI teamspace for collaborative pipeline execution.", ) - organization: Optional[str] = Field( + organization: str | None = Field( default=None, description="Lightning AI organization name for enterprise accounts.", ) - custom_commands: Optional[List[str]] = Field( + custom_commands: list[str] | None = Field( default=None, description="Additional shell commands to execute in the Lightning AI environment.", ) @@ -135,7 +135,7 @@ def name(self) -> str: return LIGHTNING_ORCHESTRATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -144,7 +144,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -162,7 +162,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/lightning.png" @property - def config_class(self) -> Type[LightningOrchestratorConfig]: + def config_class(self) -> type[LightningOrchestratorConfig]: """Returns `KubeflowOrchestratorConfig` config class. Returns: @@ -171,7 +171,7 @@ def config_class(self) -> Type[LightningOrchestratorConfig]: return LightningOrchestratorConfig @property - def implementation_class(self) -> Type["LightningOrchestrator"]: + def implementation_class(self) -> type["LightningOrchestrator"]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator.py b/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator.py index cc72c785c9a..480ab740df3 100644 --- a/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator.py +++ b/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator.py @@ -16,7 +16,7 @@ import os import tempfile import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, cast +from typing import TYPE_CHECKING, Any, Optional, cast from uuid import uuid4 from lightning_sdk import Machine, Studio @@ -99,7 +99,7 @@ def config(self) -> LightningOrchestratorConfig: return cast(LightningOrchestratorConfig, self._config) @property - def settings_class(self) -> Type[LightningOrchestratorSettings]: + def settings_class(self) -> type[LightningOrchestratorSettings]: """Settings class for the Lightning orchestrator. Returns: @@ -160,10 +160,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. This method should only submit the pipeline and not wait for it to @@ -241,7 +241,7 @@ def submit_pipeline( def _construct_lightning_steps( snapshot: "PipelineSnapshotResponse", - ) -> Dict[str, Dict[str, Any]]: + ) -> dict[str, dict[str, Any]]: """Construct the steps for the pipeline. 
Args: @@ -367,7 +367,7 @@ def _upload_and_run_pipeline( orchestrator_run_id: str, requirements: str, settings: LightningOrchestratorSettings, - steps_commands: Dict[str, Dict[str, Any]], + steps_commands: dict[str, dict[str, Any]], code_path: str, filename: str, env_file_path: str, @@ -474,11 +474,11 @@ def _run_step_in_new_studio( self, orchestrator_run_id: str, step_name: str, - details: Dict[str, Any], + details: dict[str, Any], code_path: str, filename: str, env_file_path: str, - custom_commands: Optional[List[str]] = None, + custom_commands: list[str] | None = None, ) -> None: """Run a step in a new studio. @@ -531,7 +531,7 @@ def _run_step_in_new_studio( studio.delete() def _run_step_in_main_studio( - self, studio: Studio, details: Dict[str, Any], filename: str + self, studio: Studio, details: dict[str, Any], filename: str ) -> None: """Run a step in the main studio. diff --git a/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py b/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py index c38cab7f346..39d56ab2788 100644 --- a/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py +++ b/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint.py @@ -15,7 +15,7 @@ import argparse import os -from typing import Dict, cast +from typing import cast from lightning_sdk import Machine, Studio @@ -115,7 +115,7 @@ def main() -> None: f'"{req}"' for req in pipeline_requirements ) - unique_resource_configs: Dict[str, str] = {} + unique_resource_configs: dict[str, str] = {} main_studio_name = sanitize_studio_name( f"zenml_{orchestrator_run_id}_pipeline" ) diff --git a/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint_configuration.py b/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint_configuration.py index 688de9dc3c4..e4fc6454972 100644 --- 
a/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint_configuration.py +++ b/src/zenml/integrations/lightning/orchestrators/lightning_orchestrator_entrypoint_configuration.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Entrypoint configuration for the Lightning master/orchestrator VM.""" -from typing import TYPE_CHECKING, List, Set +from typing import TYPE_CHECKING if TYPE_CHECKING: from uuid import UUID @@ -26,7 +26,7 @@ class LightningOrchestratorEntrypointConfiguration: """Entrypoint configuration for the Lightning master/orchestrator VM.""" @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> set[str]: """Gets all the options required for running this entrypoint. Returns: @@ -39,7 +39,7 @@ def get_entrypoint_options(cls) -> Set[str]: return options @classmethod - def get_entrypoint_command(cls) -> List[str]: + def get_entrypoint_command(cls) -> list[str]: """Returns a command that runs the entrypoint module. Returns: @@ -57,7 +57,7 @@ def get_entrypoint_arguments( cls, run_name: str, snapshot_id: "UUID", - ) -> List[str]: + ) -> list[str]: """Gets all arguments that the entrypoint command should be called with. Args: diff --git a/src/zenml/integrations/lightning/orchestrators/utils.py b/src/zenml/integrations/lightning/orchestrators/utils.py index 27c1091d679..c91884ca4aa 100644 --- a/src/zenml/integrations/lightning/orchestrators/utils.py +++ b/src/zenml/integrations/lightning/orchestrators/utils.py @@ -15,7 +15,6 @@ import itertools import re -from typing import List from zenml.client import Client from zenml.config import DockerSettings @@ -38,7 +37,7 @@ def sanitize_studio_name(studio_name: str) -> str: return re.sub(r"[-]+", "-", studio_name) -def gather_requirements(docker_settings: "DockerSettings") -> List[str]: +def gather_requirements(docker_settings: "DockerSettings") -> list[str]: """Gather the requirements files. 
Args: diff --git a/src/zenml/integrations/mlflow/__init__.py b/src/zenml/integrations/mlflow/__init__.py index 9671011bdff..1a836a152ba 100644 --- a/src/zenml/integrations/mlflow/__init__.py +++ b/src/zenml/integrations/mlflow/__init__.py @@ -16,13 +16,10 @@ The MLflow integrations currently enables you to use MLflow tracking as a convenient way to visualize your experiment runs within the MLflow UI. """ -from packaging import version -from typing import List, Type, Optional from zenml.integrations.constants import MLFLOW from zenml.integrations.integration import Integration from zenml.stack import Flavor -import sys from zenml.logger import get_logger @@ -47,8 +44,8 @@ class MlflowIntegration(Integration): @classmethod def get_requirements( - cls, target_os: Optional[str] = None, python_version: Optional[str] = None - ) -> List[str]: + cls, target_os: str | None = None, python_version: str | None = None + ) -> list[str]: """Method to get the requirements for the integration. Args: @@ -77,7 +74,7 @@ def activate(cls) -> None: from zenml.integrations.mlflow import services # noqa @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the MLflow integration. 
Returns: diff --git a/src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py b/src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py index e33e39e5479..bcdb5fd981e 100644 --- a/src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py +++ b/src/zenml/integrations/mlflow/experiment_trackers/mlflow_experiment_tracker.py @@ -15,7 +15,7 @@ import importlib import os -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast +from typing import TYPE_CHECKING, Any, Optional, cast import mlflow from mlflow.entities import Experiment, Run @@ -103,7 +103,7 @@ def config(self) -> MLFlowExperimentTrackerConfig: return cast(MLFlowExperimentTrackerConfig, self._config) @property - def local_path(self) -> Optional[str]: + def local_path(self) -> str | None: """Path to the local directory where the MLflow artifacts are stored. Returns: @@ -140,7 +140,7 @@ def validator(self) -> Optional["StackValidator"]: ) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Mlflow experiment tracker. Returns: @@ -212,7 +212,7 @@ def prepare_step_run(self, info: "StepRunInfo") -> None: def get_step_run_metadata( self, info: "StepRunInfo" - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get component- and step-specific metadata after a step ran. Args: @@ -221,7 +221,7 @@ def get_step_run_metadata( Returns: A dictionary of metadata. 
""" - metadata: Dict[str, Any] = { + metadata: dict[str, Any] = { METADATA_EXPERIMENT_TRACKER_URL: Uri( self.get_tracking_uri(as_plain_text=False) ), @@ -317,7 +317,7 @@ def configure_mlflow(self) -> None: "true" if self.config.tracking_insecure_tls else "false" ) - def get_run_id(self, experiment_name: str, run_name: str) -> Optional[str]: + def get_run_id(self, experiment_name: str, run_name: str) -> str | None: """Gets the if of a run with the given name and experiment. Args: @@ -393,7 +393,7 @@ def _adjust_experiment_name(self, experiment_name: str) -> str: return experiment_name @staticmethod - def _get_internal_tags() -> Dict[str, Any]: + def _get_internal_tags() -> dict[str, Any]: """Gets ZenML internal tags for MLflow runs. Returns: diff --git a/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py b/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py index 13160f7fe7e..24c57c0729d 100644 --- a/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py +++ b/src/zenml/integrations/mlflow/flavors/mlflow_experiment_tracker_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""MLflow experiment tracker flavor.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type +from typing import TYPE_CHECKING, Any from pydantic import Field, model_validator @@ -61,7 +61,7 @@ def is_databricks_tracking_uri(tracking_uri: str) -> bool: class MLFlowExperimentTrackerSettings(BaseSettings): """Settings for the MLflow experiment tracker.""" - experiment_name: Optional[str] = Field( + experiment_name: str | None = Field( None, description="The MLflow experiment name to use for tracking runs.", ) @@ -69,7 +69,7 @@ class MLFlowExperimentTrackerSettings(BaseSettings): False, description="If `True`, will create a nested sub-run for the step.", ) - tags: Dict[str, Any] = Field( + tags: dict[str, Any] = Field( default_factory=dict, description="Tags to attach to the MLflow run for categorization and filtering.", ) @@ -80,23 +80,23 @@ class MLFlowExperimentTrackerConfig( ): """Config for the MLflow experiment tracker.""" - tracking_uri: Optional[str] = Field( + tracking_uri: str | None = Field( None, description="The URI of the MLflow tracking server. If no URI is set, " "your stack must contain a LocalArtifactStore and ZenML will point " "MLflow to a subdirectory of your artifact store instead.", ) - tracking_username: Optional[str] = SecretField( + tracking_username: str | None = SecretField( default=None, description="Username for authenticating with the MLflow tracking server. " "Required when using a remote tracking URI along with tracking_password.", ) - tracking_password: Optional[str] = SecretField( + tracking_password: str | None = SecretField( default=None, description="Password for authenticating with the MLflow tracking server. " "Required when using a remote tracking URI along with tracking_username.", ) - tracking_token: Optional[str] = SecretField( + tracking_token: str | None = SecretField( default=None, description="Token for authenticating with the MLflow tracking server. 
" "Alternative to username/password authentication for remote tracking URIs.", @@ -106,7 +106,7 @@ class MLFlowExperimentTrackerConfig( description="Skips verification of TLS connection to the MLflow tracking " "server if set to `True`. Use with caution in production environments.", ) - databricks_host: Optional[str] = Field( + databricks_host: str | None = Field( None, description="The host of the Databricks workspace with the MLflow managed " "server to connect to. Required when tracking_uri is set to 'databricks'.", @@ -194,7 +194,7 @@ def name(self) -> str: return MLFLOW_MODEL_EXPERIMENT_TRACKER_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -203,7 +203,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -221,7 +221,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/experiment_tracker/mlflow.png" @property - def config_class(self) -> Type[MLFlowExperimentTrackerConfig]: + def config_class(self) -> type[MLFlowExperimentTrackerConfig]: """Returns `MLFlowExperimentTrackerConfig` config class. Returns: @@ -230,7 +230,7 @@ def config_class(self) -> Type[MLFlowExperimentTrackerConfig]: return MLFlowExperimentTrackerConfig @property - def implementation_class(self) -> Type["MLFlowExperimentTracker"]: + def implementation_class(self) -> type["MLFlowExperimentTracker"]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py b/src/zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py index 483304e6548..f36260e0118 100644 --- a/src/zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py +++ b/src/zenml/integrations/mlflow/flavors/mlflow_model_deployer_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """MLflow model deployer flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -63,7 +63,7 @@ def name(self) -> str: return MLFLOW_MODEL_DEPLOYER_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -72,7 +72,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -90,7 +90,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png" @property - def config_class(self) -> Type[MLFlowModelDeployerConfig]: + def config_class(self) -> type[MLFlowModelDeployerConfig]: """Returns `MLFlowModelDeployerConfig` config class. Returns: @@ -99,7 +99,7 @@ def config_class(self) -> Type[MLFlowModelDeployerConfig]: return MLFlowModelDeployerConfig @property - def implementation_class(self) -> Type["MLFlowModelDeployer"]: + def implementation_class(self) -> type["MLFlowModelDeployer"]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/integrations/mlflow/flavors/mlflow_model_registry_flavor.py b/src/zenml/integrations/mlflow/flavors/mlflow_model_registry_flavor.py index dcbcfc6f146..3d91866b32b 100644 --- a/src/zenml/integrations/mlflow/flavors/mlflow_model_registry_flavor.py +++ b/src/zenml/integrations/mlflow/flavors/mlflow_model_registry_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """MLflow model registry flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.integrations.mlflow import MLFLOW_MODEL_REGISTRY_FLAVOR from zenml.model_registries.base_model_registry import ( @@ -44,7 +44,7 @@ def name(self) -> str: return MLFLOW_MODEL_REGISTRY_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -53,7 +53,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -71,7 +71,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png" @property - def config_class(self) -> Type[MLFlowModelRegistryConfig]: + def config_class(self) -> type[MLFlowModelRegistryConfig]: """Returns `MLFlowModelRegistryConfig` config class. Returns: @@ -80,7 +80,7 @@ def config_class(self) -> Type[MLFlowModelRegistryConfig]: return MLFlowModelRegistryConfig @property - def implementation_class(self) -> Type["MLFlowModelRegistry"]: + def implementation_class(self) -> type["MLFlowModelRegistry"]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py b/src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py index 1c0de9b43d2..17644adc54d 100644 --- a/src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py +++ b/src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py @@ -15,7 +15,7 @@ import os import shutil -from typing import ClassVar, Dict, Optional, Type, cast +from typing import ClassVar, cast from uuid import UUID from zenml.config.global_config import GlobalConfiguration @@ -40,9 +40,9 @@ class MLFlowModelDeployer(BaseModelDeployer): """MLflow implementation of the BaseModelDeployer.""" NAME: ClassVar[str] = "MLflow" - FLAVOR: ClassVar[Type[BaseModelDeployerFlavor]] = MLFlowModelDeployerFlavor + FLAVOR: ClassVar[type[BaseModelDeployerFlavor]] = MLFlowModelDeployerFlavor - _service_path: Optional[str] = None + _service_path: str | None = None @property def config(self) -> MLFlowModelDeployerConfig: @@ -100,7 +100,7 @@ def local_path(self) -> str: @staticmethod def get_model_server_info( # type: ignore[override] service_instance: "MLFlowDeploymentService", - ) -> Dict[str, Optional[str]]: + ) -> dict[str, str | None]: """Return implementation specific information relevant to the user. 
Args: diff --git a/src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py b/src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py index 455236ac6d0..ca1e4a1d991 100644 --- a/src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py +++ b/src/zenml/integrations/mlflow/model_registries/mlflow_model_registry.py @@ -15,7 +15,7 @@ import os from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Any, cast from urllib.parse import unquote, urlparse import mlflow @@ -97,7 +97,7 @@ def _remove_file_scheme(uri: str) -> str: class MLFlowModelRegistry(BaseModelRegistry): """Register models using MLflow.""" - _client: Optional[MlflowClient] = None + _client: MlflowClient | None = None @property def config(self) -> MLFlowModelRegistryConfig: @@ -127,14 +127,14 @@ def mlflow_client(self) -> MlflowClient: return self._client @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates that the stack contains an mlflow experiment tracker. Returns: A StackValidator instance. """ - def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]: + def _validate_stack_requirements(stack: "Stack") -> tuple[bool, str]: """Validates that all the requirements are met for the stack. Args: @@ -181,8 +181,8 @@ def _validate_stack_requirements(stack: "Stack") -> Tuple[bool, str]: def register_model( self, name: str, - description: Optional[str] = None, - metadata: Optional[Dict[str, str]] = None, + description: str | None = None, + metadata: dict[str, str] | None = None, ) -> RegisteredModel: """Register a model to the MLflow model registry. 
@@ -254,9 +254,9 @@ def delete_model( def update_model( self, name: str, - description: Optional[str] = None, - metadata: Optional[Dict[str, str]] = None, - remove_metadata: Optional[List[str]] = None, + description: str | None = None, + metadata: dict[str, str] | None = None, + remove_metadata: list[str] | None = None, ) -> RegisteredModel: """Update a model in the MLflow model registry. @@ -347,9 +347,9 @@ def get_model(self, name: str) -> RegisteredModel: def list_models( self, - name: Optional[str] = None, - metadata: Optional[Dict[str, str]] = None, - ) -> List[RegisteredModel]: + name: str | None = None, + metadata: dict[str, str] | None = None, + ) -> list[RegisteredModel]: """List models in the MLflow model registry. Args: @@ -391,10 +391,10 @@ def list_models( def register_model_version( self, name: str, - version: Optional[str] = None, - model_source_uri: Optional[str] = None, - description: Optional[str] = None, - metadata: Optional[ModelRegistryModelMetadata] = None, + version: str | None = None, + model_source_uri: str | None = None, + description: str | None = None, + metadata: ModelRegistryModelMetadata | None = None, **kwargs: Any, ) -> RegistryModelVersion: """Register a model version to the MLflow model registry. @@ -492,10 +492,10 @@ def update_model_version( self, name: str, version: str, - description: Optional[str] = None, - metadata: Optional[ModelRegistryModelMetadata] = None, - remove_metadata: Optional[List[str]] = None, - stage: Optional[ModelVersionStage] = None, + description: str | None = None, + metadata: ModelRegistryModelMetadata | None = None, + remove_metadata: list[str] | None = None, + stage: ModelVersionStage | None = None, ) -> RegistryModelVersion: """Update a model version in the MLflow model registry. 
@@ -606,16 +606,16 @@ def get_model_version( def list_model_versions( self, - name: Optional[str] = None, - model_source_uri: Optional[str] = None, - metadata: Optional[ModelRegistryModelMetadata] = None, - stage: Optional[ModelVersionStage] = None, - count: Optional[int] = None, - created_after: Optional[datetime] = None, - created_before: Optional[datetime] = None, - order_by_date: Optional[str] = None, + name: str | None = None, + model_source_uri: str | None = None, + metadata: ModelRegistryModelMetadata | None = None, + stage: ModelVersionStage | None = None, + count: int | None = None, + created_after: datetime | None = None, + created_before: datetime | None = None, + order_by_date: str | None = None, **kwargs: Any, - ) -> List[RegistryModelVersion]: + ) -> list[RegistryModelVersion]: """List model versions from the MLflow model registry. Args: diff --git a/src/zenml/integrations/mlflow/services/mlflow_deployment.py b/src/zenml/integrations/mlflow/services/mlflow_deployment.py index 2c147155e1e..1c6711e06c7 100644 --- a/src/zenml/integrations/mlflow/services/mlflow_deployment.py +++ b/src/zenml/integrations/mlflow/services/mlflow_deployment.py @@ -16,7 +16,7 @@ import importlib import os import sys -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Union import numpy as np import pandas as pd @@ -77,7 +77,7 @@ class MLFlowDeploymentEndpoint(LocalDaemonServiceEndpoint): monitor: HTTPEndpointHealthMonitor @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """Gets the prediction URL for the endpoint. 
Returns: @@ -106,9 +106,9 @@ class MLFlowDeploymentConfig(LocalDaemonServiceConfig): model_uri: str model_name: str - registry_model_name: Optional[str] = None - registry_model_version: Optional[str] = None - registry_model_stage: Optional[str] = None + registry_model_name: str | None = None + registry_model_version: str | None = None + registry_model_stage: str | None = None workers: int = 1 mlserver: bool = False timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT @@ -178,7 +178,7 @@ class MLFlowDeploymentService(LocalDaemonService, BaseDeploymentService): def __init__( self, - config: Union[MLFlowDeploymentConfig, Dict[str, Any]], + config: MLFlowDeploymentConfig | dict[str, Any], **attrs: Any, ) -> None: """Initialize the MLflow deployment service. @@ -232,8 +232,8 @@ def run(self) -> None: self.endpoint.prepare_for_start() try: - backend_kwargs: Dict[str, Any] = {} - serve_kwargs: Dict[str, Any] = {} + backend_kwargs: dict[str, Any] = {} + serve_kwargs: dict[str, Any] = {} mlflow_version = MLFLOW_VERSION.split(".") # MLflow version 1.26 introduces an additional mandatory # `timeout` argument to the `PyFuncBackend.serve` function @@ -271,7 +271,7 @@ def run(self) -> None: ) @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """Get the URI where the prediction service is answering requests. Returns: diff --git a/src/zenml/integrations/mlflow/steps/__init__.py b/src/zenml/integrations/mlflow/steps/__init__.py index 5219ad52c61..3a64c38cf2e 100644 --- a/src/zenml/integrations/mlflow/steps/__init__.py +++ b/src/zenml/integrations/mlflow/steps/__init__.py @@ -13,6 +13,3 @@ # permissions and limitations under the License. 
"""Initialization of the MLflow standard interface steps.""" -from zenml.integrations.mlflow.steps.mlflow_deployer import ( - mlflow_model_deployer_step, -) diff --git a/src/zenml/integrations/mlflow/steps/mlflow_deployer.py b/src/zenml/integrations/mlflow/steps/mlflow_deployer.py index 7ca77d7e7a9..b9d0723e383 100644 --- a/src/zenml/integrations/mlflow/steps/mlflow_deployer.py +++ b/src/zenml/integrations/mlflow/steps/mlflow_deployer.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of the MLflow model deployer pipeline step.""" -from typing import Optional, cast +from typing import cast from mlflow.tracking import MlflowClient, artifact_utils @@ -50,13 +50,13 @@ def mlflow_model_deployer_step( model: UnmaterializedArtifact, deploy_decision: bool = True, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, + experiment_name: str | None = None, + run_name: str | None = None, model_name: str = "model", workers: int = 1, mlserver: bool = False, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, -) -> Optional[MLFlowDeploymentService]: +) -> MLFlowDeploymentService | None: """Model deployer pipeline step for MLflow. 
This step deploys a model logged in the MLflow artifact store to a @@ -191,8 +191,8 @@ def mlflow_model_deployer_step( @step(enable_cache=False) def mlflow_model_registry_deployer_step( registry_model_name: str, - registry_model_version: Optional[str] = None, - registry_model_stage: Optional[ModelVersionStage] = None, + registry_model_version: str | None = None, + registry_model_stage: ModelVersionStage | None = None, replace_existing: bool = True, model_name: str = "model", workers: int = 1, diff --git a/src/zenml/integrations/mlflow/steps/mlflow_registry.py b/src/zenml/integrations/mlflow/steps/mlflow_registry.py index c89d6fcb43d..9dae280e5bb 100644 --- a/src/zenml/integrations/mlflow/steps/mlflow_registry.py +++ b/src/zenml/integrations/mlflow/steps/mlflow_registry.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Implementation of the MLflow model registration pipeline step.""" -from typing import Optional from mlflow.tracking import artifact_utils @@ -38,14 +37,14 @@ def mlflow_register_model_step( model: UnmaterializedArtifact, name: str, - version: Optional[str] = None, - trained_model_name: Optional[str] = "model", - model_source_uri: Optional[str] = None, - experiment_name: Optional[str] = None, - run_name: Optional[str] = None, - run_id: Optional[str] = None, - description: Optional[str] = None, - metadata: Optional[ModelRegistryModelMetadata] = None, + version: str | None = None, + trained_model_name: str | None = "model", + model_source_uri: str | None = None, + experiment_name: str | None = None, + run_name: str | None = None, + run_id: str | None = None, + description: str | None = None, + metadata: ModelRegistryModelMetadata | None = None, ) -> None: """MLflow model registry step. 
diff --git a/src/zenml/integrations/mlx/__init__.py b/src/zenml/integrations/mlx/__init__.py index 93d4de00fa6..fe532d3723e 100644 --- a/src/zenml/integrations/mlx/__init__.py +++ b/src/zenml/integrations/mlx/__init__.py @@ -15,7 +15,6 @@ import platform import sys -from typing import List, Optional from zenml.integrations.constants import MLX from zenml.integrations.integration import Integration @@ -51,9 +50,9 @@ def check_installation(cls) -> bool: @classmethod def get_requirements( cls, - target_os: Optional[str] = None, - python_version: Optional[str] = None, - ) -> List[str]: + target_os: str | None = None, + python_version: str | None = None, + ) -> list[str]: # sys.platform is "darwin", while platform.system() is "Darwin", # similarly on Linux. target_os = (target_os or sys.platform).lower() diff --git a/src/zenml/integrations/mlx/materializer.py b/src/zenml/integrations/mlx/materializer.py index db9a5d667eb..6182323ad73 100644 --- a/src/zenml/integrations/mlx/materializer.py +++ b/src/zenml/integrations/mlx/materializer.py @@ -17,8 +17,6 @@ from typing import ( Any, ClassVar, - Tuple, - Type, ) import mlx.core as mx @@ -32,10 +30,10 @@ class MLXArrayMaterializer(BaseMaterializer): """A materializer for MLX arrays.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (mx.array,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (mx.array,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - def load(self, data_type: Type[Any]) -> mx.array: + def load(self, data_type: type[Any]) -> mx.array: """Reads data from a `.npy` file, and returns an MLX array. Args: diff --git a/src/zenml/integrations/modal/__init__.py b/src/zenml/integrations/modal/__init__.py index 081628cb035..4d33bb6ff60 100644 --- a/src/zenml/integrations/modal/__init__.py +++ b/src/zenml/integrations/modal/__init__.py @@ -16,7 +16,6 @@ The Modal integration sub-module provides a step operator flavor that allows executing steps on Modal's cloud infrastructure. 
""" -from typing import List, Type from zenml.integrations.constants import MODAL from zenml.integrations.integration import Integration @@ -32,7 +31,7 @@ class ModalIntegration(Integration): REQUIREMENTS = ["modal>=0.64.49,<1"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Modal integration. Returns: diff --git a/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py b/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py index a66580faa54..cefd5784c25 100644 --- a/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py +++ b/src/zenml/integrations/modal/flavors/modal_step_operator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Modal step operator flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.config.base_settings import BaseSettings from zenml.integrations.modal import MODAL_STEP_OPERATOR_FLAVOR @@ -41,9 +41,9 @@ class ModalStepOperatorSettings(BaseSettings): cloud: The cloud provider to use for the step execution. """ - gpu: Optional[str] = None - region: Optional[str] = None - cloud: Optional[str] = None + gpu: str | None = None + region: str | None = None + cloud: str | None = None class ModalStepOperatorConfig( @@ -78,7 +78,7 @@ def name(self) -> str: return MODAL_STEP_OPERATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -87,7 +87,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. 
Returns: @@ -105,7 +105,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/modal.png" @property - def config_class(self) -> Type[ModalStepOperatorConfig]: + def config_class(self) -> type[ModalStepOperatorConfig]: """Returns `ModalStepOperatorConfig` config class. Returns: @@ -114,7 +114,7 @@ def config_class(self) -> Type[ModalStepOperatorConfig]: return ModalStepOperatorConfig @property - def implementation_class(self) -> Type["ModalStepOperator"]: + def implementation_class(self) -> type["ModalStepOperator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/modal/step_operators/modal_step_operator.py b/src/zenml/integrations/modal/step_operators/modal_step_operator.py index 9c654ca6541..10158bad8c2 100644 --- a/src/zenml/integrations/modal/step_operators/modal_step_operator.py +++ b/src/zenml/integrations/modal/step_operators/modal_step_operator.py @@ -14,7 +14,7 @@ """Modal step operator implementation.""" import asyncio -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, cast +from typing import TYPE_CHECKING, cast import modal from modal_proto import api_pb2 @@ -43,7 +43,7 @@ def get_gpu_values( settings: ModalStepOperatorSettings, resource_settings: ResourceSettings -) -> Optional[str]: +) -> str | None: """Get the GPU values for the Modal step operator. Args: @@ -76,7 +76,7 @@ def config(self) -> ModalStepOperatorConfig: return cast(ModalStepOperatorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Get the settings class for the Modal step operator. Returns: @@ -85,14 +85,14 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: return ModalStepOperatorSettings @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Get the stack validator for the Modal step operator. 
Returns: The stack validator. """ - def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + def _validate_remote_components(stack: "Stack") -> tuple[bool, str]: if stack.artifact_store.config.is_local: return False, ( "The Modal step operator runs code remotely and " @@ -128,7 +128,7 @@ def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: def get_docker_builds( self, snapshot: "PipelineSnapshotBase" - ) -> List["BuildConfiguration"]: + ) -> list["BuildConfiguration"]: """Get the Docker build configurations for the Modal step operator. Args: @@ -152,8 +152,8 @@ def get_docker_builds( def launch( self, info: "StepRunInfo", - entrypoint_command: List[str], - environment: Dict[str, str], + entrypoint_command: list[str], + environment: dict[str, str], ) -> None: """Launch a step run on Modal. diff --git a/src/zenml/integrations/neptune/__init__.py b/src/zenml/integrations/neptune/__init__.py index 579f52b7254..1dde063c3f0 100644 --- a/src/zenml/integrations/neptune/__init__.py +++ b/src/zenml/integrations/neptune/__init__.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Module containing Neptune integration.""" -from typing import List, Type from zenml.integrations.constants import NEPTUNE from zenml.integrations.integration import Integration @@ -32,7 +31,7 @@ class NeptuneIntegration(Integration): REQUIREMENTS = ["neptune"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Neptune integration. 
Returns: diff --git a/src/zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py b/src/zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py index e71d84d2e77..cfdb7661529 100644 --- a/src/zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py +++ b/src/zenml/integrations/neptune/experiment_trackers/neptune_experiment_tracker.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of Neptune Experiment Tracker.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast +from typing import TYPE_CHECKING, Any, cast from zenml.constants import METADATA_EXPERIMENT_TRACKER_URL from zenml.experiment_trackers.base_experiment_tracker import ( @@ -57,7 +57,7 @@ def config(self) -> NeptuneExperimentTrackerConfig: return cast(NeptuneExperimentTrackerConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Neptune experiment tracker. Returns: @@ -86,7 +86,7 @@ def prepare_step_run(self, info: "StepRunInfo") -> None: def get_step_run_metadata( self, info: "StepRunInfo" - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get component- and step-specific metadata after a step ran. 
Args: diff --git a/src/zenml/integrations/neptune/experiment_trackers/run_state.py b/src/zenml/integrations/neptune/experiment_trackers/run_state.py index d71d9ed8100..a554c598429 100644 --- a/src/zenml/integrations/neptune/experiment_trackers/run_state.py +++ b/src/zenml/integrations/neptune/experiment_trackers/run_state.py @@ -14,7 +14,7 @@ """Contains objects that create a Neptune run and store its state throughout the pipeline.""" from hashlib import md5 -from typing import TYPE_CHECKING, Any, List, Optional +from typing import TYPE_CHECKING, Any, Optional import neptune @@ -34,18 +34,18 @@ class RunProvider(metaclass=SingletonMetaClass): def __init__(self) -> None: """Initialize RunProvider. Called with no arguments.""" self._active_run: Optional["Run"] = None - self._project: Optional[str] = None - self._run_name: Optional[str] = None - self._token: Optional[str] = None - self._tags: Optional[List[str]] = None + self._project: str | None = None + self._run_name: str | None = None + self._token: str | None = None + self._tags: list[str] | None = None self._initialized = False def initialize( self, - project: Optional[str] = None, - token: Optional[str] = None, - run_name: Optional[str] = None, - tags: Optional[List[str]] = None, + project: str | None = None, + token: str | None = None, + run_name: str | None = None, + tags: list[str] | None = None, ) -> None: """Initialize the run state. @@ -62,7 +62,7 @@ def initialize( self._initialized = True @property - def project(self) -> Optional[Any]: + def project(self) -> Any | None: """Getter for project name. Returns: @@ -71,7 +71,7 @@ def project(self) -> Optional[Any]: return self._project @property - def token(self) -> Optional[Any]: + def token(self) -> Any | None: """Getter for API token. Returns: @@ -80,7 +80,7 @@ def token(self) -> Optional[Any]: return self._token @property - def run_name(self) -> Optional[Any]: + def run_name(self) -> Any | None: """Getter for run name. 
Returns: @@ -89,7 +89,7 @@ def run_name(self) -> Optional[Any]: return self._run_name @property - def tags(self) -> Optional[Any]: + def tags(self) -> Any | None: """Getter for run tags. Returns: diff --git a/src/zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py b/src/zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py index dd10c6a2dbe..a5e9dd54cd4 100644 --- a/src/zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py +++ b/src/zenml/integrations/neptune/flavors/neptune_experiment_tracker_flavor.py @@ -19,7 +19,7 @@ "NeptuneExperimentTrackerSettings", ] -from typing import TYPE_CHECKING, Optional, Set, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -44,11 +44,11 @@ class NeptuneExperimentTrackerConfig(BaseExperimentTrackerConfig): will try to find the relevant values in the environment """ - project: Optional[str] = Field( + project: str | None = Field( None, description="Name of the Neptune project you want to log the metadata to.", ) - api_token: Optional[str] = SecretField( + api_token: str | None = SecretField( default=None, description="Your Neptune API token for authentication." ) @@ -56,7 +56,7 @@ class NeptuneExperimentTrackerConfig(BaseExperimentTrackerConfig): class NeptuneExperimentTrackerSettings(BaseSettings): """Settings for the Neptune experiment tracker.""" - tags: Set[str] = Field( + tags: set[str] = Field( default_factory=set, description="Tags for the Neptune run." ) @@ -74,7 +74,7 @@ def name(self) -> str: return NEPTUNE_MODEL_EXPERIMENT_TRACKER_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -83,7 +83,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. 
Returns: @@ -101,7 +101,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/experiment_tracker/neptune.png" @property - def config_class(self) -> Type[NeptuneExperimentTrackerConfig]: + def config_class(self) -> type[NeptuneExperimentTrackerConfig]: """Returns `NeptuneExperimentTrackerConfig` config class. Returns: @@ -110,7 +110,7 @@ def config_class(self) -> Type[NeptuneExperimentTrackerConfig]: return NeptuneExperimentTrackerConfig @property - def implementation_class(self) -> Type["NeptuneExperimentTracker"]: + def implementation_class(self) -> type["NeptuneExperimentTracker"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py b/src/zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py index 2142d5ca549..ed6c03db377 100644 --- a/src/zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py +++ b/src/zenml/integrations/neural_prophet/materializers/neural_prophet_materializer.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Implementation of the Neural Prophet materializer.""" -from typing import Any, ClassVar, Tuple, Type +from typing import Any, ClassVar from neuralprophet import NeuralProphet @@ -33,6 +33,6 @@ class NeuralProphetMaterializer(BasePyTorchMaterializer): """Materializer to read/write NeuralProphet models.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (NeuralProphet,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (NeuralProphet,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL FILENAME: ClassVar[str] = DEFAULT_FILENAME diff --git a/src/zenml/integrations/numpy/materializers/numpy_materializer.py b/src/zenml/integrations/numpy/materializers/numpy_materializer.py index e9e2ad671dd..1705a9a94d0 100644 --- a/src/zenml/integrations/numpy/materializers/numpy_materializer.py +++ b/src/zenml/integrations/numpy/materializers/numpy_materializer.py @@ -19,10 +19,6 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - Optional, - Tuple, - Type, Union, ) @@ -70,7 +66,7 @@ def _ensure_dtype_compatibility(arr: "NDArray[Any]") -> "NDArray[Any]": def _create_array( - data: Any, dtype: Optional[Union["np.dtype[Any]", Type[Any]]] = None + data: Any, dtype: Union["np.dtype[Any]", type[Any]] | None = None ) -> "NDArray[Any]": """Create arrays with consistent behavior across NumPy versions. @@ -91,10 +87,10 @@ def _create_array( class NumpyMaterializer(BaseMaterializer): """Materializer to read data to and from pandas.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (np.ndarray,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (np.ndarray,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - def load(self, data_type: Type[Any]) -> "Any": + def load(self, data_type: type[Any]) -> "Any": """Reads a numpy array from a `.npy` file. 
Args: @@ -166,7 +162,7 @@ def save(self, arr: "NDArray[Any]") -> None: def save_visualizations( self, arr: "NDArray[Any]" - ) -> Dict[str, VisualizationType]: + ) -> dict[str, VisualizationType]: """Saves visualizations for a numpy array. If the array is 1D, a histogram is saved. If the array is 2D or 3D with @@ -233,7 +229,7 @@ def _save_image(self, output_path: str, arr: "NDArray[Any]") -> None: def extract_metadata( self, arr: "NDArray[Any]" - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Extract metadata from the given numpy array. Args: @@ -253,7 +249,7 @@ def extract_metadata( def _extract_numeric_metadata( self, arr: "NDArray[Any]" - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Extracts numeric metadata from a numpy array. Args: @@ -268,7 +264,7 @@ def _extract_numeric_metadata( min_val = np.min(arr).item() max_val = np.max(arr).item() - numpy_metadata: Dict[str, "MetadataType"] = { + numpy_metadata: dict[str, "MetadataType"] = { "shape": tuple(arr.shape), "dtype": DType(arr.dtype.type), "mean": np.mean(arr).item(), @@ -280,7 +276,7 @@ def _extract_numeric_metadata( def _extract_text_metadata( self, arr: "NDArray[Any]" - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Extracts text metadata from a numpy array. Args: @@ -304,7 +300,7 @@ def _extract_text_metadata( total_words = len(words) most_common_word, most_common_count = word_counts.most_common(1)[0] - text_metadata: Dict[str, "MetadataType"] = { + text_metadata: dict[str, "MetadataType"] = { "shape": tuple(arr.shape), "dtype": DType(arr.dtype.type), "unique_words": unique_words, diff --git a/src/zenml/integrations/openai/hooks/__init__.py b/src/zenml/integrations/openai/hooks/__init__.py index 7ad2bbda9bf..4dd1ddce744 100644 --- a/src/zenml/integrations/openai/hooks/__init__.py +++ b/src/zenml/integrations/openai/hooks/__init__.py @@ -13,7 +13,3 @@ # permissions and limitations under the License. 
"""Initialization of the OpenAI hooks module.""" -from zenml.integrations.openai.hooks.open_ai_failure_hook import ( - openai_chatgpt_alerter_failure_hook, - openai_gpt4_alerter_failure_hook, -) diff --git a/src/zenml/integrations/openai/hooks/open_ai_failure_hook.py b/src/zenml/integrations/openai/hooks/open_ai_failure_hook.py index d9b497fcf63..af7b1bc706b 100644 --- a/src/zenml/integrations/openai/hooks/open_ai_failure_hook.py +++ b/src/zenml/integrations/openai/hooks/open_ai_failure_hook.py @@ -15,7 +15,6 @@ import io import sys -from typing import Optional from openai import OpenAI from rich.console import Console @@ -50,7 +49,7 @@ def openai_alerter_failure_hook_helper( openai_secret = client.get_secret( "openai", allow_partial_name_match=False ) - openai_api_key: Optional[str] = openai_secret.secret_values.get( + openai_api_key: str | None = openai_secret.secret_values.get( "api_key" ) except (KeyError, NotImplementedError): diff --git a/src/zenml/integrations/pandas/materializers/pandas_materializer.py b/src/zenml/integrations/pandas/materializers/pandas_materializer.py index aa1b6f2cc4f..2dac50af6aa 100644 --- a/src/zenml/integrations/pandas/materializers/pandas_materializer.py +++ b/src/zenml/integrations/pandas/materializers/pandas_materializer.py @@ -26,7 +26,7 @@ """ import os -from typing import Any, ClassVar, Dict, Optional, Tuple, Type, Union +from typing import Any, ClassVar import pandas as pd @@ -76,14 +76,14 @@ def is_standard_dtype(dtype_str: str) -> bool: class PandasMaterializer(BaseMaterializer): """Materializer to read data to and from pandas.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = ( + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = ( pd.DataFrame, pd.Series, ) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA def __init__( - self, uri: str, artifact_store: Optional[BaseArtifactStore] = None + self, uri: str, artifact_store: BaseArtifactStore | None = None ): """Define `self.data_path`. 
@@ -109,7 +109,7 @@ def __init__( self.parquet_path = os.path.join(self.uri, PARQUET_FILENAME) self.csv_path = os.path.join(self.uri, CSV_FILENAME) - def load(self, data_type: Type[Any]) -> Union[pd.DataFrame, pd.Series]: + def load(self, data_type: type[Any]) -> pd.DataFrame | pd.Series: """Reads `pd.DataFrame` or `pd.Series` from a `.parquet` or `.csv` file. Args: @@ -174,8 +174,8 @@ def load(self, data_type: Type[Any]) -> Union[pd.DataFrame, pd.Series]: # validate the type of the data. def is_dataframe_or_series( - df: Union[pd.DataFrame, pd.Series], - ) -> Union[pd.DataFrame, pd.Series]: + df: pd.DataFrame | pd.Series, + ) -> pd.DataFrame | pd.Series: """Checks if the data is a `pd.DataFrame` or `pd.Series`. Args: @@ -195,7 +195,7 @@ def is_dataframe_or_series( return is_dataframe_or_series(df) - def save(self, df: Union[pd.DataFrame, pd.Series]) -> None: + def save(self, df: pd.DataFrame | pd.Series) -> None: """Writes a pandas dataframe or series to the specified filename. Args: @@ -212,8 +212,8 @@ def save(self, df: Union[pd.DataFrame, pd.Series]) -> None: df.to_csv(f, index=True) def save_visualizations( - self, df: Union[pd.DataFrame, pd.Series] - ) -> Dict[str, VisualizationType]: + self, df: pd.DataFrame | pd.Series + ) -> dict[str, VisualizationType]: """Save visualizations of the given pandas dataframe or series. Creates two visualizations: @@ -258,8 +258,8 @@ def save_visualizations( return visualizations def extract_metadata( - self, df: Union[pd.DataFrame, pd.Series] - ) -> Dict[str, "MetadataType"]: + self, df: pd.DataFrame | pd.Series + ) -> dict[str, "MetadataType"]: """Extract metadata from the given pandas dataframe or series. 
Args: @@ -279,7 +279,7 @@ def extract_metadata( series_obj = df # Keep original Series for some calculations df = df.to_frame(name="series") - pandas_metadata: Dict[str, "MetadataType"] = {"shape": original_shape} + pandas_metadata: dict[str, "MetadataType"] = {"shape": original_shape} # Add information about custom data types to metadata custom_types = {} diff --git a/src/zenml/integrations/pigeon/__init__.py b/src/zenml/integrations/pigeon/__init__.py index a21bdd8bf84..0195ba7d530 100644 --- a/src/zenml/integrations/pigeon/__init__.py +++ b/src/zenml/integrations/pigeon/__init__.py @@ -12,7 +12,6 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Initialization of the Pigeon integration.""" -from typing import List, Type from zenml.integrations.constants import PIGEON from zenml.integrations.integration import Integration @@ -28,7 +27,7 @@ class PigeonIntegration(Integration): REQUIREMENTS = ["ipywidgets>=8.0.0"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Pigeon integration. 
Returns: diff --git a/src/zenml/integrations/pigeon/annotators/pigeon_annotator.py b/src/zenml/integrations/pigeon/annotators/pigeon_annotator.py index f24343300d4..f153861a047 100644 --- a/src/zenml/integrations/pigeon/annotators/pigeon_annotator.py +++ b/src/zenml/integrations/pigeon/annotators/pigeon_annotator.py @@ -25,7 +25,7 @@ import json import os from datetime import datetime -from typing import Any, List, Optional, Tuple, cast +from typing import Any, cast import ipywidgets as widgets # type: ignore from IPython.core.display_functions import clear_output, display @@ -70,7 +70,7 @@ def get_url_for_dataset(self, dataset_name: str) -> str: """ raise NotImplementedError("Pigeon annotator does not have a URL.") - def get_datasets(self) -> List[str]: + def get_datasets(self) -> list[str]: """Get a list of datasets (annotation files) in the output directory. Returns: @@ -82,7 +82,7 @@ def get_datasets(self) -> List[str]: except FileNotFoundError: return [] - def get_dataset_names(self) -> List[str]: + def get_dataset_names(self) -> list[str]: """List dataset names (annotation file names) in the output directory. Returns: @@ -90,7 +90,7 @@ def get_dataset_names(self) -> List[str]: """ return self.get_datasets() - def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]: + def get_dataset_stats(self, dataset_name: str) -> tuple[int, int]: """List labeled and unlabeled examples in a dataset (annotation file). 
Args: @@ -105,7 +105,7 @@ def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]: num_unlabeled_examples = 0 try: - with open(dataset_path, "r") as file: + with open(dataset_path) as file: num_labeled_examples = sum(1 for _ in file) except FileNotFoundError: logger.error(f"File not found: {dataset_path}") @@ -114,10 +114,10 @@ def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]: def _annotate( self, - data: List[Any], - options: List[str], - display_fn: Optional[Any] = None, - ) -> List[Tuple[Any, Any]]: + data: list[Any], + options: list[str], + display_fn: Any | None = None, + ) -> list[tuple[Any, Any]]: """Internal method to build an interactive widget for annotating. Args: @@ -207,10 +207,10 @@ def launch(self, **kwargs: Any) -> None: def annotate( self, - data: List[Any], - options: List[str], - display_fn: Optional[Any] = None, - ) -> List[Tuple[Any, Any]]: + data: list[Any], + options: list[str], + display_fn: Any | None = None, + ) -> list[tuple[Any, Any]]: """Annotate with the Pigeon annotator in the Jupyter notebook. Args: @@ -224,7 +224,7 @@ def annotate( annotations = self._annotate(data, options, display_fn) return annotations - def _save_annotations(self, annotations: List[Tuple[Any, Any]]) -> None: + def _save_annotations(self, annotations: list[tuple[Any, Any]]) -> None: """Save annotations to a file with a unique date-time suffix. Args: @@ -269,7 +269,7 @@ def delete_dataset(self, **kwargs: Any) -> None: dataset_path = os.path.join(self.config.output_dir, dataset_name) os.remove(dataset_path) - def get_dataset(self, **kwargs: Any) -> List[Tuple[Any, Any]]: + def get_dataset(self, **kwargs: Any) -> list[tuple[Any, Any]]: """Get the annotated examples from a dataset (annotation file). Takes the `dataset_name` argument from the kwargs. @@ -290,11 +290,11 @@ def get_dataset(self, **kwargs: Any) -> List[Tuple[Any, Any]]: "Dataset name (`dataset_name`) is required to retrieve a dataset." 
) dataset_path = os.path.join(self.config.output_dir, dataset_name) - with open(dataset_path, "r") as f: + with open(dataset_path) as f: annotations = json.load(f) - return cast(List[Tuple[Any, Any]], annotations) + return cast(list[tuple[Any, Any]], annotations) - def get_labeled_data(self, **kwargs: Any) -> List[Tuple[Any, Any]]: + def get_labeled_data(self, **kwargs: Any) -> list[tuple[Any, Any]]: """Get the labeled examples from a dataset (annotation file). Takes the `dataset_name` argument from the kwargs. diff --git a/src/zenml/integrations/pigeon/flavors/pigeon_annotator_flavor.py b/src/zenml/integrations/pigeon/flavors/pigeon_annotator_flavor.py index 69e1bea3e2e..b195df20870 100644 --- a/src/zenml/integrations/pigeon/flavors/pigeon_annotator_flavor.py +++ b/src/zenml/integrations/pigeon/flavors/pigeon_annotator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Pigeon annotator flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.annotators.base_annotator import ( BaseAnnotatorConfig, @@ -57,7 +57,7 @@ def name(self) -> str: return PIGEON_ANNOTATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -66,7 +66,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -84,7 +84,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/annotator/pigeon.png" @property - def config_class(self) -> Type[PigeonAnnotatorConfig]: + def config_class(self) -> type[PigeonAnnotatorConfig]: """Returns `PigeonAnnotatorConfig` config class. 
Returns: @@ -93,7 +93,7 @@ def config_class(self) -> Type[PigeonAnnotatorConfig]: return PigeonAnnotatorConfig @property - def implementation_class(self) -> Type["PigeonAnnotator"]: + def implementation_class(self) -> type["PigeonAnnotator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/pillow/materializers/pillow_image_materializer.py b/src/zenml/integrations/pillow/materializers/pillow_image_materializer.py index fc4029ebaca..8c032415720 100644 --- a/src/zenml/integrations/pillow/materializers/pillow_image_materializer.py +++ b/src/zenml/integrations/pillow/materializers/pillow_image_materializer.py @@ -14,7 +14,7 @@ """Materializer for Pillow Image objects.""" import os -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar from PIL import Image @@ -41,10 +41,10 @@ class PillowImageMaterializer(BaseMaterializer): https://pillow.readthedocs.io/en/stable/handbook/image-file-formats.html. """ - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Image.Image,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (Image.Image,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - def load(self, data_type: Type[Image.Image]) -> Image.Image: + def load(self, data_type: type[Image.Image]) -> Image.Image: """Read from artifact store. Args: @@ -86,7 +86,7 @@ def save(self, image: Image.Image) -> None: def save_visualizations( self, image: Image.Image - ) -> Dict[str, VisualizationType]: + ) -> dict[str, VisualizationType]: """Finds and saves the given image as a visualization. Args: @@ -103,7 +103,7 @@ def save_visualizations( def extract_metadata( self, image: Image.Image - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Extract metadata from the given `Image` object. 
Args: diff --git a/src/zenml/integrations/polars/materializers/dataframe_materializer.py b/src/zenml/integrations/polars/materializers/dataframe_materializer.py index 5a888f9e671..c30242b03d9 100644 --- a/src/zenml/integrations/polars/materializers/dataframe_materializer.py +++ b/src/zenml/integrations/polars/materializers/dataframe_materializer.py @@ -14,7 +14,7 @@ """Polars materializer.""" import os -from typing import Any, ClassVar, Tuple, Type, Union +from typing import Any, ClassVar import polars as pl import pyarrow as pa # type: ignore @@ -28,13 +28,13 @@ class PolarsMaterializer(BaseMaterializer): """Materializer to read/write Polars dataframes.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = ( + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = ( pl.DataFrame, pl.Series, ) ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: type[Any]) -> Any: """Reads and returns Polars data after copying it to temporary path. Args: @@ -67,7 +67,7 @@ def load(self, data_type: Type[Any]) -> Any: return data - def save(self, data: Union[pl.DataFrame, pl.Series]) -> None: + def save(self, data: pl.DataFrame | pl.Series) -> None: """Writes Polars data to the artifact store. Args: diff --git a/src/zenml/integrations/prodigy/__init__.py b/src/zenml/integrations/prodigy/__init__.py index 868e7ffdf2b..e52a8a10458 100644 --- a/src/zenml/integrations/prodigy/__init__.py +++ b/src/zenml/integrations/prodigy/__init__.py @@ -12,7 +12,6 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. 
"""Initialization of the Prodigy integration.""" -from typing import List, Type from zenml.integrations.constants import PRODIGY from zenml.integrations.integration import Integration @@ -32,7 +31,7 @@ class ProdigyIntegration(Integration): REQUIREMENTS_IGNORED_ON_UNINSTALL = ["urllib3"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Prodigy integration. Returns: diff --git a/src/zenml/integrations/prodigy/annotators/prodigy_annotator.py b/src/zenml/integrations/prodigy/annotators/prodigy_annotator.py index 035d5775c58..245c9576256 100644 --- a/src/zenml/integrations/prodigy/annotators/prodigy_annotator.py +++ b/src/zenml/integrations/prodigy/annotators/prodigy_annotator.py @@ -14,7 +14,7 @@ """Implementation of the Prodigy annotation integration.""" import json -from typing import Any, List, Optional, Tuple, cast +from typing import Any, cast import prodigy from peewee import Database as PeeweeDatabase @@ -58,7 +58,7 @@ def get_url(self) -> str: instance_url = DEFAULT_LOCAL_INSTANCE_HOST port = DEFAULT_LOCAL_PRODIGY_PORT if self.config.custom_config_path: - with open(self.config.custom_config_path, "r") as f: + with open(self.config.custom_config_path) as f: config = json.load(f) instance_url = config.get("instance_url", instance_url) port = config.get("port", port) @@ -78,16 +78,16 @@ def get_url_for_dataset(self, dataset_name: str) -> str: """ return self.get_url() - def get_datasets(self) -> List[Any]: + def get_datasets(self) -> list[Any]: """Gets the datasets currently available for annotation. Returns: A list of datasets (str). """ datasets = self._get_db().datasets - return cast(List[Any], datasets) + return cast(list[Any], datasets) - def get_dataset_names(self) -> List[str]: + def get_dataset_names(self) -> list[str]: """Gets the names of the datasets. 
Returns: @@ -95,7 +95,7 @@ def get_dataset_names(self) -> List[str]: """ return self.get_datasets() - def get_dataset_stats(self, dataset_name: str) -> Tuple[int, int]: + def get_dataset_stats(self, dataset_name: str) -> tuple[int, int]: """Gets the statistics of the given dataset. Args: @@ -148,8 +148,8 @@ def launch(self, **kwargs: Any) -> None: def _get_db( self, custom_database: PeeweeDatabase = None, - display_id: Optional[str] = None, - display_name: Optional[str] = None, + display_id: str | None = None, + display_name: str | None = None, ) -> ProdigyDatabase: """Gets Prodigy database / client. diff --git a/src/zenml/integrations/prodigy/flavors/prodigy_annotator_flavor.py b/src/zenml/integrations/prodigy/flavors/prodigy_annotator_flavor.py index c09648fb4d6..9b9030a3dce 100644 --- a/src/zenml/integrations/prodigy/flavors/prodigy_annotator_flavor.py +++ b/src/zenml/integrations/prodigy/flavors/prodigy_annotator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Prodigy annotator flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.annotators.base_annotator import ( BaseAnnotatorConfig, @@ -36,7 +36,7 @@ class ProdigyAnnotatorConfig(BaseAnnotatorConfig, AuthenticationConfigMixin): custom_config_path: The path to a custom config file for Prodigy. """ - custom_config_path: Optional[str] = None + custom_config_path: str | None = None class ProdigyAnnotatorFlavor(BaseAnnotatorFlavor): @@ -52,7 +52,7 @@ def name(self) -> str: return PRODIGY_ANNOTATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -61,7 +61,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. 
Returns: @@ -79,7 +79,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/annotator/prodigy.png" @property - def config_class(self) -> Type[ProdigyAnnotatorConfig]: + def config_class(self) -> type[ProdigyAnnotatorConfig]: """Returns `ProdigyAnnotatorConfig` config class. Returns: @@ -88,7 +88,7 @@ def config_class(self) -> Type[ProdigyAnnotatorConfig]: return ProdigyAnnotatorConfig @property - def implementation_class(self) -> Type["ProdigyAnnotator"]: + def implementation_class(self) -> type["ProdigyAnnotator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/pycaret/materializers/model_materializer.py b/src/zenml/integrations/pycaret/materializers/model_materializer.py index 648de8027f5..a4323d8426d 100644 --- a/src/zenml/integrations/pycaret/materializers/model_materializer.py +++ b/src/zenml/integrations/pycaret/materializers/model_materializer.py @@ -15,7 +15,6 @@ from typing import ( Any, - Type, ) from catboost import CatBoostClassifier, CatBoostRegressor # type: ignore @@ -122,7 +121,7 @@ class PyCaretMaterializer(BaseMaterializer): ) ASSOCIATED_ARTIFACT_TYPE = ArtifactType.MODEL - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: type[Any]) -> Any: """Reads and returns a PyCaret model after copying it to temporary path. Args: diff --git a/src/zenml/integrations/pytorch/__init__.py b/src/zenml/integrations/pytorch/__init__.py index e97043f9bd8..2a86abf83b9 100644 --- a/src/zenml/integrations/pytorch/__init__.py +++ b/src/zenml/integrations/pytorch/__init__.py @@ -13,8 +13,6 @@ # permissions and limitations under the License. 
"""Initialization of the PyTorch integration.""" -import platform -from typing import List, Optional from zenml.integrations.constants import PYTORCH from zenml.integrations.integration import Integration diff --git a/src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py b/src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py index 0f37a1ec38c..3232c3f5b74 100644 --- a/src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/base_pytorch_materializer.py @@ -14,7 +14,7 @@ """Implementation of the PyTorch DataLoader materializer.""" import os -from typing import Any, ClassVar, Type +from typing import Any, ClassVar import cloudpickle import torch @@ -31,7 +31,7 @@ class BasePyTorchMaterializer(BaseMaterializer): FILENAME: ClassVar[str] = DEFAULT_FILENAME SKIP_REGISTRATION: ClassVar[bool] = True - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: type[Any]) -> Any: """Uses `torch.load` to load a PyTorch object. Args: diff --git a/src/zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py b/src/zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py index 3368fd7a8ca..6ac677eec68 100644 --- a/src/zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/pytorch_dataloader_materializer.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Implementation of the PyTorch DataLoader materializer.""" -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar from torch.utils.data import Dataset from torch.utils.data.dataloader import DataLoader @@ -32,11 +32,11 @@ class PyTorchDataLoaderMaterializer(BasePyTorchMaterializer): """Materializer to read/write PyTorch dataloaders and datasets.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (DataLoader, Dataset) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (DataLoader, Dataset) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA FILENAME: ClassVar[str] = DEFAULT_FILENAME - def extract_metadata(self, dataloader: Any) -> Dict[str, "MetadataType"]: + def extract_metadata(self, dataloader: Any) -> dict[str, "MetadataType"]: """Extract metadata from the given dataloader or dataset. Args: @@ -45,7 +45,7 @@ def extract_metadata(self, dataloader: Any) -> Dict[str, "MetadataType"]: Returns: The extracted metadata as a dictionary. 
""" - metadata: Dict[str, "MetadataType"] = {} + metadata: dict[str, "MetadataType"] = {} if isinstance(dataloader, DataLoader): if hasattr(dataloader.dataset, "__len__"): metadata["num_samples"] = len(dataloader.dataset) diff --git a/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py b/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py index 86e36572c8d..51ecea37827 100644 --- a/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py +++ b/src/zenml/integrations/pytorch/materializers/pytorch_module_materializer.py @@ -14,7 +14,7 @@ """Implementation of the PyTorch Module materializer.""" import os -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar import cloudpickle import torch @@ -41,7 +41,7 @@ class PyTorchModuleMaterializer(BasePyTorchMaterializer): https://pytorch.org/tutorials/beginner/saving_loading_models.html """ - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Module,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (Module,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL FILENAME: ClassVar[str] = DEFAULT_FILENAME @@ -66,7 +66,7 @@ def save(self, model: Module) -> None: # is intended for use with trusted data sources. torch.save(model.state_dict(), f, pickle_module=cloudpickle) # nosec - def extract_metadata(self, model: Module) -> Dict[str, "MetadataType"]: + def extract_metadata(self, model: Module) -> dict[str, "MetadataType"]: """Extract metadata from the given `Model` object. Args: diff --git a/src/zenml/integrations/pytorch/utils.py b/src/zenml/integrations/pytorch/utils.py index f723a2cb541..edaf89a22c5 100644 --- a/src/zenml/integrations/pytorch/utils.py +++ b/src/zenml/integrations/pytorch/utils.py @@ -13,12 +13,11 @@ # permissions and limitations under the License. 
"""PyTorch utils.""" -from typing import Dict import torch -def count_module_params(module: torch.nn.Module) -> Dict[str, int]: +def count_module_params(module: torch.nn.Module) -> dict[str, int]: """Get the total and trainable parameters of a module. Args: diff --git a/src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py b/src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py index cd58e3adcc1..c8b50496559 100644 --- a/src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py +++ b/src/zenml/integrations/pytorch_lightning/materializers/pytorch_lightning_materializer.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of the PyTorch Lightning Materializer.""" -from typing import Any, ClassVar, Tuple, Type +from typing import Any, ClassVar from torch.nn import Module @@ -28,6 +28,6 @@ class PyTorchLightningMaterializer(BasePyTorchMaterializer): """Materializer to read/write PyTorch models.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Module,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (Module,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL FILENAME: ClassVar[str] = CHECKPOINT_NAME diff --git a/src/zenml/integrations/registry.py b/src/zenml/integrations/registry.py index 46f3b7b0d83..67cc9703e94 100644 --- a/src/zenml/integrations/registry.py +++ b/src/zenml/integrations/registry.py @@ -15,7 +15,7 @@ import importlib import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any from zenml.exceptions import IntegrationError from zenml.logger import get_logger @@ -26,16 +26,16 @@ logger = get_logger(__name__) -class IntegrationRegistry(object): +class IntegrationRegistry: """Registry to keep track of ZenML Integrations.""" def __init__(self) -> None: """Initializing the integration registry.""" - self._integrations: Dict[str, 
Type["Integration"]] = {} + self._integrations: dict[str, type["Integration"]] = {} self._initialized = False @property - def integrations(self) -> Dict[str, Type["Integration"]]: + def integrations(self) -> dict[str, type["Integration"]]: """Method to get integrations dictionary. Returns: @@ -62,7 +62,7 @@ def integrations(self, i: Any) -> None: ) def register_integration( - self, key: str, type_: Type["Integration"] + self, key: str, type_: type["Integration"] ) -> None: """Method to register an integration with a given name. @@ -108,7 +108,7 @@ def activate_integrations(self) -> None: logger.debug(f"Integration `{name}` could not be activated.") @property - def list_integration_names(self) -> List[str]: + def list_integration_names(self) -> list[str]: """Get a list of all possible integrations. Returns: @@ -119,9 +119,9 @@ def list_integration_names(self) -> List[str]: def select_integration_requirements( self, - integration_name: Optional[str] = None, - target_os: Optional[str] = None, - ) -> List[str]: + integration_name: str | None = None, + target_os: str | None = None, + ) -> list[str]: """Select the requirements for a given integration or all integrations. Args: @@ -157,9 +157,9 @@ def select_integration_requirements( def select_uninstall_requirements( self, - integration_name: Optional[str] = None, - target_os: Optional[str] = None, - ) -> List[str]: + integration_name: str | None = None, + target_os: str | None = None, + ) -> list[str]: """Select the uninstall requirements for a given integration or all integrations. Args: @@ -193,7 +193,7 @@ def select_uninstall_requirements( ].get_uninstall_requirements(target_os=target_os) ] - def is_installed(self, integration_name: Optional[str] = None) -> bool: + def is_installed(self, integration_name: str | None = None) -> bool: """Checks if all requirements for an integration are installed. 
Args: @@ -221,7 +221,7 @@ def is_installed(self, integration_name: Optional[str] = None) -> bool: f"{self.list_integration_names}" ) - def get_installed_integrations(self) -> List[str]: + def get_installed_integrations(self) -> list[str]: """Returns list of installed integrations. Returns: diff --git a/src/zenml/integrations/s3/__init__.py b/src/zenml/integrations/s3/__init__.py index 38f5f890b02..5d6fe662230 100644 --- a/src/zenml/integrations/s3/__init__.py +++ b/src/zenml/integrations/s3/__init__.py @@ -16,7 +16,6 @@ The S3 integration allows the use of cloud artifact stores and file operations on S3 buckets. """ -from typing import List, Type from zenml.integrations.constants import S3 from zenml.integrations.integration import Integration @@ -41,7 +40,7 @@ class S3Integration(Integration): ] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the s3 integration. Returns: diff --git a/src/zenml/integrations/s3/artifact_stores/s3_artifact_store.py b/src/zenml/integrations/s3/artifact_stores/s3_artifact_store.py index 18513ad5835..d7ff85022ed 100644 --- a/src/zenml/integrations/s3/artifact_stores/s3_artifact_store.py +++ b/src/zenml/integrations/s3/artifact_stores/s3_artifact_store.py @@ -16,16 +16,10 @@ from contextlib import contextmanager from typing import ( Any, - Callable, - Dict, - Generator, - Iterable, - List, - Optional, - Tuple, Union, cast, ) +from collections.abc import Callable, Generator, Iterable import boto3 import s3fs @@ -117,7 +111,7 @@ def close_session(loop: Any, s3: Any) -> None: class S3ArtifactStore(BaseArtifactStore, AuthenticationMixin): """Artifact Store for S3 based artifacts.""" - _filesystem: Optional[ZenMLS3Filesystem] = None + _filesystem: ZenMLS3Filesystem | None = None is_versioned: bool = False @@ -158,7 +152,7 @@ def config(self) -> S3ArtifactStoreConfig: def get_credentials( self, - ) -> Tuple[Optional[str], Optional[str], Optional[str], 
Optional[str]]: + ) -> tuple[str | None, str | None, str | None, str | None]: """Gets authentication credentials. If an authentication secret is configured, the secret values are @@ -290,7 +284,7 @@ def exists(self, path: PathType) -> bool: """ return self.filesystem.exists(path=path) # type: ignore[no-any-return] - def glob(self, pattern: PathType) -> List[PathType]: + def glob(self, pattern: PathType) -> list[PathType]: """Return all paths that match the given glob pattern. The glob pattern may include: @@ -319,7 +313,7 @@ def isdir(self, path: PathType) -> bool: """ return self.filesystem.isdir(path=path) # type: ignore[no-any-return] - def listdir(self, path: PathType) -> List[PathType]: + def listdir(self, path: PathType) -> list[PathType]: """Return a list of files in a directory. Args: @@ -334,7 +328,7 @@ def listdir(self, path: PathType) -> List[PathType]: if path.startswith("s3://"): path = path[5:] - def _extract_basename(file_dict: Dict[str, Any]) -> str: + def _extract_basename(file_dict: dict[str, Any]) -> str: """Extracts the basename from a file info dict returned by the S3 filesystem. Args: @@ -415,7 +409,7 @@ def rmtree(self, path: PathType) -> None: """ self.filesystem.delete(path=path, recursive=True) - def stat(self, path: PathType) -> Dict[str, Any]: + def stat(self, path: PathType) -> dict[str, Any]: """Return stat info for the given path. Args: @@ -441,8 +435,8 @@ def walk( self, top: PathType, topdown: bool = True, - onerror: Optional[Callable[..., None]] = None, - ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]: + onerror: Callable[..., None] | None = None, + ) -> Iterable[tuple[PathType, list[PathType], list[PathType]]]: """Return an iterator that walks the contents of the given directory. 
Args: diff --git a/src/zenml/integrations/s3/flavors/s3_artifact_store_flavor.py b/src/zenml/integrations/s3/flavors/s3_artifact_store_flavor.py index a956c47fd7a..ecfbd4fcf9b 100644 --- a/src/zenml/integrations/s3/flavors/s3_artifact_store_flavor.py +++ b/src/zenml/integrations/s3/flavors/s3_artifact_store_flavor.py @@ -18,10 +18,6 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - Optional, - Set, - Type, ) from pydantic import Field, field_validator @@ -62,47 +58,47 @@ class S3ArtifactStoreConfig( ``` """ - SUPPORTED_SCHEMES: ClassVar[Set[str]] = {"s3://"} + SUPPORTED_SCHEMES: ClassVar[set[str]] = {"s3://"} - key: Optional[str] = SecretField( + key: str | None = SecretField( default=None, description="AWS access key ID for authentication. " "If not provided, credentials will be inferred from the environment.", ) - secret: Optional[str] = SecretField( + secret: str | None = SecretField( default=None, description="AWS secret access key for authentication. " "If not provided, credentials will be inferred from the environment.", ) - token: Optional[str] = SecretField( + token: str | None = SecretField( default=None, description="AWS session token for temporary credentials. " "If not provided, credentials will be inferred from the environment.", ) - client_kwargs: Optional[Dict[str, Any]] = Field( + client_kwargs: dict[str, Any] | None = Field( None, description="Additional keyword arguments to pass to the S3 client. " "For example, to connect to a custom S3-compatible endpoint: " "{'endpoint_url': 'http://minio:9000'}", ) - config_kwargs: Optional[Dict[str, Any]] = Field( + config_kwargs: dict[str, Any] | None = Field( None, description="Additional keyword arguments to pass to the S3 client configuration. " "For example: {'region_name': 'us-west-2', 'signature_version': 's3v4'}", ) - s3_additional_kwargs: Optional[Dict[str, Any]] = Field( + s3_additional_kwargs: dict[str, Any] | None = Field( None, description="Additional keyword arguments for S3 operations. 
" "For example: {'ACL': 'bucket-owner-full-control'}", ) - _bucket: Optional[str] = None + _bucket: str | None = None @field_validator("client_kwargs") @classmethod def _validate_client_kwargs( - cls, value: Optional[Dict[str, Any]] - ) -> Optional[Dict[str, Any]]: + cls, value: dict[str, Any] | None + ) -> dict[str, Any] | None: """Validates the `client_kwargs` attribute. Args: @@ -161,7 +157,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -177,7 +173,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A URL to point at docs explaining this flavor. Returns: @@ -186,7 +182,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A URL to point at SDK docs explaining this flavor. Returns: @@ -204,7 +200,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/artifact_store/aws.png" @property - def config_class(self) -> Type[S3ArtifactStoreConfig]: + def config_class(self) -> type[S3ArtifactStoreConfig]: """The config class of the flavor. Returns: @@ -213,7 +209,7 @@ def config_class(self) -> Type[S3ArtifactStoreConfig]: return S3ArtifactStoreConfig @property - def implementation_class(self) -> Type["S3ArtifactStore"]: + def implementation_class(self) -> type["S3ArtifactStore"]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/integrations/s3/utils.py b/src/zenml/integrations/s3/utils.py index e481f936283..e4f23a73976 100644 --- a/src/zenml/integrations/s3/utils.py +++ b/src/zenml/integrations/s3/utils.py @@ -13,10 +13,9 @@ # permissions and limitations under the License. """Utility methods for S3.""" -from typing import Tuple -def split_s3_path(s3_path: str) -> Tuple[str, str]: +def split_s3_path(s3_path: str) -> tuple[str, str]: """Split S3 URI into bucket and key. Args: diff --git a/src/zenml/integrations/scipy/materializers/sparse_materializer.py b/src/zenml/integrations/scipy/materializers/sparse_materializer.py index ce30793a9c7..629905b234d 100644 --- a/src/zenml/integrations/scipy/materializers/sparse_materializer.py +++ b/src/zenml/integrations/scipy/materializers/sparse_materializer.py @@ -14,7 +14,7 @@ """Implementation of the Scipy Sparse Materializer.""" import os -from typing import Any, ClassVar, Dict, Tuple, Type +from typing import Any, ClassVar from scipy.sparse import load_npz, save_npz, spmatrix @@ -29,10 +29,10 @@ class SparseMaterializer(BaseMaterializer): """Materializer to read and write scipy sparse matrices.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (spmatrix,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (spmatrix,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - def load(self, data_type: Type[Any]) -> spmatrix: + def load(self, data_type: type[Any]) -> spmatrix: """Reads spmatrix from npz file. Args: @@ -54,7 +54,7 @@ def save(self, mat: spmatrix) -> None: with fileio.open(os.path.join(self.uri, DATA_FILENAME), "wb") as f: save_npz(f, mat) - def extract_metadata(self, mat: spmatrix) -> Dict[str, "MetadataType"]: + def extract_metadata(self, mat: spmatrix) -> dict[str, "MetadataType"]: """Extract metadata from the given `spmatrix` object. 
Args: diff --git a/src/zenml/integrations/seldon/__init__.py b/src/zenml/integrations/seldon/__init__.py index 5497b38f1ec..2c6a161d4e3 100644 --- a/src/zenml/integrations/seldon/__init__.py +++ b/src/zenml/integrations/seldon/__init__.py @@ -16,7 +16,6 @@ The Seldon Core integration allows you to use the Seldon Core model serving platform to implement continuous model deployment. """ -from typing import List, Type, Optional from zenml.integrations.constants import SELDON from zenml.integrations.integration import Integration @@ -42,7 +41,7 @@ def activate(cls) -> None: from zenml.integrations.seldon import services # noqa @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Seldon Core. Returns: @@ -53,8 +52,8 @@ def flavors(cls) -> List[Type[Flavor]]: return [SeldonModelDeployerFlavor] @classmethod - def get_requirements(cls, target_os: Optional[str] = None, python_version: Optional[str] = None - ) -> List[str]: + def get_requirements(cls, target_os: str | None = None, python_version: str | None = None + ) -> list[str]: """Method to get the requirements for the integration. 
Args: diff --git a/src/zenml/integrations/seldon/custom_deployer/zenml_custom_model.py b/src/zenml/integrations/seldon/custom_deployer/zenml_custom_model.py index 0d844474acd..e3d5af9e24a 100644 --- a/src/zenml/integrations/seldon/custom_deployer/zenml_custom_model.py +++ b/src/zenml/integrations/seldon/custom_deployer/zenml_custom_model.py @@ -14,7 +14,7 @@ """Implements a custom model for the Seldon integration.""" import subprocess -from typing import Any, Dict, List, Optional, Union +from typing import Any, Union import click import numpy as np @@ -27,7 +27,7 @@ DEFAULT_MODEL_NAME = "model" DEFAULT_LOCAL_MODEL_DIR = "/mnt/models" -Array_Like = Union[np.ndarray[Any, Any], List[Any], str, bytes, Dict[str, Any]] +Array_Like = Union[np.ndarray[Any, Any], list[Any], str, bytes, dict[str, Any]] class ZenMLCustomModel: @@ -78,7 +78,7 @@ def load(self) -> bool: self.model = load_model_from_metadata(self.model_uri) except Exception as e: - logger.error("Failed to load model: {}".format(e)) + logger.error(f"Failed to load model: {e}") return False self.ready = True return self.ready @@ -86,7 +86,7 @@ def load(self) -> bool: def predict( self, X: Array_Like, - features_names: Optional[List[str]], + features_names: list[str] | None, **kwargs: Any, ) -> Array_Like: """Predict the given request. @@ -112,7 +112,7 @@ def predict( try: prediction = {"predictions": self.predict_func(self.model, X)} except Exception as e: - raise Exception("Failed to predict: {}".format(e)) + raise Exception(f"Failed to predict: {e}") if isinstance(prediction, dict): return prediction else: diff --git a/src/zenml/integrations/seldon/flavors/seldon_model_deployer_flavor.py b/src/zenml/integrations/seldon/flavors/seldon_model_deployer_flavor.py index cd458ca1efb..ebaed9ac69f 100644 --- a/src/zenml/integrations/seldon/flavors/seldon_model_deployer_flavor.py +++ b/src/zenml/integrations/seldon/flavors/seldon_model_deployer_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Seldon model deployer flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.constants import KUBERNETES_CLUSTER_RESOURCE_TYPE from zenml.integrations.seldon import SELDON_MODEL_DEPLOYER_FLAVOR @@ -60,13 +60,13 @@ class SeldonModelDeployerConfig(BaseModelDeployerConfig): ZenML and is already present in the Kubernetes cluster. """ - kubernetes_context: Optional[str] = None - kubernetes_namespace: Optional[str] = None + kubernetes_context: str | None = None + kubernetes_namespace: str | None = None base_url: str # TODO: unused? - secret: Optional[str] - kubernetes_secret_name: Optional[ + secret: str | None + kubernetes_secret_name: None | ( str - ] # TODO: Add full documentation section on this + ) # TODO: Add full documentation section on this class SeldonModelDeployerFlavor(BaseModelDeployerFlavor): @@ -84,7 +84,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -99,7 +99,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -108,7 +108,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -126,7 +126,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/seldon.png" @property - def config_class(self) -> Type[SeldonModelDeployerConfig]: + def config_class(self) -> type[SeldonModelDeployerConfig]: """Returns `SeldonModelDeployerConfig` config class. 
Returns: @@ -135,7 +135,7 @@ def config_class(self) -> Type[SeldonModelDeployerConfig]: return SeldonModelDeployerConfig @property - def implementation_class(self) -> Type["SeldonModelDeployer"]: + def implementation_class(self) -> type["SeldonModelDeployer"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/seldon/model_deployers/seldon_model_deployer.py b/src/zenml/integrations/seldon/model_deployers/seldon_model_deployer.py index 07878fc17e0..fde00815394 100644 --- a/src/zenml/integrations/seldon/model_deployers/seldon_model_deployer.py +++ b/src/zenml/integrations/seldon/model_deployers/seldon_model_deployer.py @@ -15,7 +15,7 @@ import json import re -from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Type, cast +from typing import TYPE_CHECKING, ClassVar, cast from uuid import UUID from zenml.analytics.enums import AnalyticsEvent @@ -59,9 +59,9 @@ class SeldonModelDeployer(BaseModelDeployer): """Seldon Core model deployer stack component implementation.""" NAME: ClassVar[str] = "Seldon Core" - FLAVOR: ClassVar[Type[BaseModelDeployerFlavor]] = SeldonModelDeployerFlavor + FLAVOR: ClassVar[type[BaseModelDeployerFlavor]] = SeldonModelDeployerFlavor - _client: Optional[SeldonClient] = None + _client: SeldonClient | None = None @property def config(self) -> SeldonModelDeployerConfig: @@ -73,7 +73,7 @@ def config(self) -> SeldonModelDeployerConfig: return cast(SeldonModelDeployerConfig, self._config) @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Ensures there is a container registry and image builder in the stack. Returns: @@ -88,7 +88,7 @@ def validator(self) -> Optional[StackValidator]: @staticmethod def get_model_server_info( # type: ignore[override] service_instance: "SeldonDeploymentService", - ) -> Dict[str, Optional[str]]: + ) -> dict[str, str | None]: """Return implementation specific information that might be relevant to the user. 
Args: @@ -122,7 +122,7 @@ def seldon_client(self) -> SeldonClient: return self._client connector = self.get_connector() - kube_client: Optional[k8s_client.ApiClient] = None + kube_client: k8s_client.ApiClient | None = None if connector: if not self.config.kubernetes_namespace: raise RuntimeError( @@ -180,7 +180,7 @@ def kubernetes_secret_name(self) -> str: def get_docker_builds( self, snapshot: "PipelineSnapshotBase" - ) -> List["BuildConfiguration"]: + ) -> list["BuildConfiguration"]: """Gets the Docker builds required for the component. Args: @@ -201,7 +201,7 @@ def get_docker_builds( return builds - def _create_or_update_kubernetes_secret(self) -> Optional[str]: + def _create_or_update_kubernetes_secret(self) -> str | None: """Create or update the Kubernetes secret used to access the artifact store. Uses the information stored in the ZenML secret configured for the model deployer. diff --git a/src/zenml/integrations/seldon/secret_schemas/secret_schemas.py b/src/zenml/integrations/seldon/secret_schemas/secret_schemas.py index 90d593cffc3..b8e6f2d931e 100644 --- a/src/zenml/integrations/seldon/secret_schemas/secret_schemas.py +++ b/src/zenml/integrations/seldon/secret_schemas/secret_schemas.py @@ -13,9 +13,8 @@ # permissions and limitations under the License. 
"""Implementation for Seldon secret schemas.""" -from typing import Optional -from typing_extensions import Literal +from typing import Literal from zenml.secret.base_secret import BaseSecretSchema @@ -43,11 +42,11 @@ class SeldonS3SecretSchema(BaseSecretSchema): rclone_config_s3_type: Literal["s3"] = "s3" rclone_config_s3_provider: str = "aws" rclone_config_s3_env_auth: bool = False - rclone_config_s3_access_key_id: Optional[str] = None - rclone_config_s3_secret_access_key: Optional[str] = None - rclone_config_s3_session_token: Optional[str] = None - rclone_config_s3_region: Optional[str] = None - rclone_config_s3_endpoint: Optional[str] = None + rclone_config_s3_access_key_id: str | None = None + rclone_config_s3_secret_access_key: str | None = None + rclone_config_s3_session_token: str | None = None + rclone_config_s3_region: str | None = None + rclone_config_s3_endpoint: str | None = None class SeldonGSSecretSchema(BaseSecretSchema): @@ -73,14 +72,14 @@ class SeldonGSSecretSchema(BaseSecretSchema): rclone_config_gs_type: Literal["google cloud storage"] = ( "google cloud storage" ) - rclone_config_gs_client_id: Optional[str] = None - rclone_config_gs_client_secret: Optional[str] = None - rclone_config_gs_project_number: Optional[str] = None - rclone_config_gs_service_account_credentials: Optional[str] = None + rclone_config_gs_client_id: str | None = None + rclone_config_gs_client_secret: str | None = None + rclone_config_gs_project_number: str | None = None + rclone_config_gs_service_account_credentials: str | None = None rclone_config_gs_anonymous: bool = False - rclone_config_gs_token: Optional[str] = None - rclone_config_gs_auth_url: Optional[str] = None - rclone_config_gs_token_url: Optional[str] = None + rclone_config_gs_token: str | None = None + rclone_config_gs_auth_url: str | None = None + rclone_config_gs_token_url: str | None = None class SeldonAzureSecretSchema(BaseSecretSchema): @@ -111,10 +110,10 @@ class SeldonAzureSecretSchema(BaseSecretSchema): 
rclone_config_az_type: Literal["azureblob"] = "azureblob" rclone_config_az_env_auth: bool = False - rclone_config_az_account: Optional[str] = None - rclone_config_az_key: Optional[str] = None - rclone_config_az_sas_url: Optional[str] = None + rclone_config_az_account: str | None = None + rclone_config_az_key: str | None = None + rclone_config_az_sas_url: str | None = None rclone_config_az_use_msi: bool = False - rclone_config_az_client_secret: Optional[str] = None - rclone_config_az_client_id: Optional[str] = None - rclone_config_az_tenant: Optional[str] = None + rclone_config_az_client_secret: str | None = None + rclone_config_az_client_id: str | None = None + rclone_config_az_tenant: str | None = None diff --git a/src/zenml/integrations/seldon/seldon_client.py b/src/zenml/integrations/seldon/seldon_client.py index 8e040096104..4b7cf934344 100644 --- a/src/zenml/integrations/seldon/seldon_client.py +++ b/src/zenml/integrations/seldon/seldon_client.py @@ -17,7 +17,8 @@ import json import re import time -from typing import Any, Dict, Generator, List, Literal, Optional +from typing import Any, Literal +from collections.abc import Generator from kubernetes import client as k8s_client from kubernetes import config as k8s_config @@ -59,8 +60,8 @@ class SeldonResourceRequirements(BaseModel): requests: resources requested by the model """ - limits: Dict[str, str] = Field(default_factory=dict) - requests: Dict[str, str] = Field(default_factory=dict) + limits: dict[str, str] = Field(default_factory=dict) + requests: dict[str, str] = Field(default_factory=dict) class SeldonDeploymentMetadata(BaseModel): @@ -74,9 +75,9 @@ class SeldonDeploymentMetadata(BaseModel): """ name: str - labels: Dict[str, str] = Field(default_factory=dict) - annotations: Dict[str, str] = Field(default_factory=dict) - creationTimestamp: Optional[str] = None + labels: dict[str, str] = Field(default_factory=dict) + annotations: dict[str, str] = Field(default_factory=dict) + creationTimestamp: str | None 
= None model_config = ConfigDict( # validate attribute assignments validate_assignment=True, @@ -114,15 +115,15 @@ class SeldonDeploymentPredictiveUnit(BaseModel): """ name: str - type: Optional[SeldonDeploymentPredictiveUnitType] = ( + type: SeldonDeploymentPredictiveUnitType | None = ( SeldonDeploymentPredictiveUnitType.MODEL ) - implementation: Optional[str] = None - modelUri: Optional[str] = None - parameters: Optional[List[SeldonDeploymentPredictorParameter]] = None - serviceAccountName: Optional[str] = None - envSecretRefName: Optional[str] = None - children: Optional[List["SeldonDeploymentPredictiveUnit"]] = None + implementation: str | None = None + modelUri: str | None = None + parameters: list[SeldonDeploymentPredictorParameter] | None = None + serviceAccountName: str | None = None + envSecretRefName: str | None = None + children: list["SeldonDeploymentPredictiveUnit"] | None = None model_config = ConfigDict( # validate attribute assignments validate_assignment=True, @@ -138,7 +139,7 @@ class SeldonDeploymentComponentSpecs(BaseModel): spec: the component spec. 
""" - spec: Optional[Dict[str, Any]] = None + spec: dict[str, Any] | None = None model_config = ConfigDict( # validate attribute assignments validate_assignment=True, @@ -159,10 +160,10 @@ class SeldonDeploymentPredictor(BaseModel): name: str replicas: int = 1 graph: SeldonDeploymentPredictiveUnit - engineResources: Optional[SeldonResourceRequirements] = Field( + engineResources: SeldonResourceRequirements | None = Field( default_factory=SeldonResourceRequirements ) - componentSpecs: Optional[List[SeldonDeploymentComponentSpecs]] = None + componentSpecs: list[SeldonDeploymentComponentSpecs] | None = None model_config = ConfigDict( # validate attribute assignments validate_assignment=True, @@ -182,8 +183,8 @@ class SeldonDeploymentSpec(BaseModel): """ name: str - protocol: Optional[str] = None - predictors: List[SeldonDeploymentPredictor] + protocol: str | None = None + predictors: list[SeldonDeploymentPredictor] replicas: int = 1 model_config = ConfigDict( # validate attribute assignments @@ -226,8 +227,8 @@ class SeldonDeploymentStatusCondition(BaseModel): type: str status: bool - reason: Optional[str] = None - message: Optional[str] = None + reason: str | None = None + message: str | None = None class SeldonDeploymentStatus(BaseModel): @@ -242,10 +243,10 @@ class SeldonDeploymentStatus(BaseModel): """ state: SeldonDeploymentStatusState = SeldonDeploymentStatusState.UNKNOWN - description: Optional[str] = None - replicas: Optional[int] = None - address: Optional[SeldonDeploymentStatusAddress] = None - conditions: List[SeldonDeploymentStatusCondition] + description: str | None = None + replicas: int | None = None + address: SeldonDeploymentStatusAddress | None = None + conditions: list[SeldonDeploymentStatusCondition] model_config = ConfigDict( # validate attribute assignments validate_assignment=True, @@ -280,7 +281,7 @@ class SeldonDeployment(BaseModel): ) metadata: SeldonDeploymentMetadata spec: SeldonDeploymentSpec - status: Optional[SeldonDeploymentStatus] = 
None + status: SeldonDeploymentStatus | None = None def __str__(self) -> str: """Returns a string representation of the Seldon Deployment. @@ -293,18 +294,18 @@ def __str__(self) -> str: @classmethod def build( cls, - name: Optional[str] = None, - model_uri: Optional[str] = None, - model_name: Optional[str] = None, - implementation: Optional[str] = None, - parameters: Optional[List[SeldonDeploymentPredictorParameter]] = None, - engineResources: Optional[SeldonResourceRequirements] = None, - secret_name: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - annotations: Optional[Dict[str, str]] = None, - is_custom_deployment: Optional[bool] = False, - spec: Optional[Dict[Any, Any]] = None, - serviceAccountName: Optional[str] = None, + name: str | None = None, + model_uri: str | None = None, + model_name: str | None = None, + implementation: str | None = None, + parameters: list[SeldonDeploymentPredictorParameter] | None = None, + engineResources: SeldonResourceRequirements | None = None, + secret_name: str | None = None, + labels: dict[str, str] | None = None, + annotations: dict[str, str] | None = None, + is_custom_deployment: bool | None = False, + spec: dict[Any, Any] | None = None, + serviceAccountName: str | None = None, ) -> "SeldonDeployment": """Build a basic Seldon Deployment object. @@ -452,7 +453,7 @@ def is_failed(self) -> bool: """ return self.state == SeldonDeploymentStatusState.FAILED - def get_error(self) -> Optional[str]: + def get_error(self) -> str | None: """Get a message describing the error, if in an error state. Returns: @@ -463,7 +464,7 @@ def get_error(self) -> Optional[str]: return self.status.description return None - def get_pending_message(self) -> Optional[str]: + def get_pending_message(self) -> str | None: """Get a message describing the pending conditions of the Seldon Deployment. 
Returns: @@ -510,9 +511,9 @@ class SeldonClient: def __init__( self, - context: Optional[str], - namespace: Optional[str], - kube_client: Optional[k8s_client.ApiClient] = None, + context: str | None, + namespace: str | None, + kube_client: k8s_client.ApiClient | None = None, ): """Initialize a Seldon Core client. @@ -526,8 +527,8 @@ def __init__( def _initialize_k8s_clients( self, - context: Optional[str], - kube_client: Optional[k8s_client.ApiClient] = None, + context: str | None, + kube_client: k8s_client.ApiClient | None = None, ) -> None: """Initialize the Kubernetes clients. @@ -571,7 +572,7 @@ def _initialize_k8s_clients( self._custom_objects_api = k8s_client.CustomObjectsApi() @staticmethod - def sanitize_labels(labels: Dict[str, str]) -> None: + def sanitize_labels(labels: dict[str, str]) -> None: """Update the label values to be valid Kubernetes labels. See: @@ -856,10 +857,10 @@ def get_deployment(self, name: str) -> SeldonDeployment: def find_deployments( self, - name: Optional[str] = None, - labels: Optional[Dict[str, str]] = None, - fields: Optional[Dict[str, str]] = None, - ) -> List[SeldonDeployment]: + name: str | None = None, + labels: dict[str, str] | None = None, + fields: dict[str, str] | None = None, + ) -> list[SeldonDeployment]: """Find all ZenML-managed Seldon Core deployment resources matching the given criteria. Args: @@ -932,7 +933,7 @@ def get_deployment_logs( self, name: str, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> Generator[str, bool, None]: """Get the logs of a Seldon Core deployment resource. @@ -1020,7 +1021,7 @@ def get_deployment_logs( def create_or_update_secret( self, name: str, - secret_values: Dict[str, Any], + secret_values: dict[str, Any], ) -> None: """Create or update a Kubernetes Secret resource. 
@@ -1119,11 +1120,11 @@ def delete_secret( def create_seldon_core_custom_spec( - model_uri: Optional[str], - custom_docker_image: Optional[str], - secret_name: Optional[str], - command: Optional[List[str]], - container_registry_secret_name: Optional[str] = None, + model_uri: str | None, + custom_docker_image: str | None, + secret_name: str | None, + command: list[str] | None, + container_registry_secret_name: str | None = None, ) -> k8s_client.V1PodSpec: """Create a custom pod spec for the seldon core container. diff --git a/src/zenml/integrations/seldon/services/seldon_deployment.py b/src/zenml/integrations/seldon/services/seldon_deployment.py index 4594764892c..f530c1c2271 100644 --- a/src/zenml/integrations/seldon/services/seldon_deployment.py +++ b/src/zenml/integrations/seldon/services/seldon_deployment.py @@ -15,7 +15,8 @@ import json import os -from typing import Any, Dict, Generator, List, Optional, Tuple, cast +from typing import Any, cast +from collections.abc import Generator from uuid import UUID import requests @@ -65,17 +66,17 @@ class SeldonDeploymentConfig(ServiceConfig): model_name: str = "default" # TODO [ENG-775]: have an enum of all supported Seldon Core implementations implementation: str - parameters: Optional[List[SeldonDeploymentPredictorParameter]] - resources: Optional[SeldonResourceRequirements] + parameters: list[SeldonDeploymentPredictorParameter] | None + resources: SeldonResourceRequirements | None replicas: int = 1 - secret_name: Optional[str] - model_metadata: Dict[str, Any] = Field(default_factory=dict) - extra_args: Dict[str, Any] = Field(default_factory=dict) - is_custom_deployment: Optional[bool] = False - spec: Optional[Dict[Any, Any]] = Field(default_factory=dict) - serviceAccountName: Optional[str] = None - - def get_seldon_deployment_labels(self) -> Dict[str, str]: + secret_name: str | None + model_metadata: dict[str, Any] = Field(default_factory=dict) + extra_args: dict[str, Any] = Field(default_factory=dict) + 
is_custom_deployment: bool | None = False + spec: dict[Any, Any] | None = Field(default_factory=dict) + serviceAccountName: str | None = None + + def get_seldon_deployment_labels(self) -> dict[str, str]: """Generate labels for the Seldon Core deployment from the service configuration. These labels are attached to the Seldon Core deployment resource @@ -101,7 +102,7 @@ def get_seldon_deployment_labels(self) -> Dict[str, str]: SeldonClient.sanitize_labels(labels) return labels - def get_seldon_deployment_annotations(self) -> Dict[str, str]: + def get_seldon_deployment_annotations(self) -> dict[str, str]: """Generate annotations for the Seldon Core deployment from the service configuration. The annotations are used to store additional information about the @@ -197,7 +198,7 @@ def _get_client(self) -> SeldonClient: ) return model_deployer.seldon_client - def check_status(self) -> Tuple[ServiceState, str]: + def check_status(self) -> tuple[ServiceState, str]: """Check the the current operational state of the Seldon Core deployment. Returns: @@ -242,7 +243,7 @@ def seldon_deployment_name(self) -> str: """ return f"zenml-{str(self.uuid)}" - def _get_seldon_deployment_labels(self) -> Dict[str, str]: + def _get_seldon_deployment_labels(self) -> dict[str, str]: """Generate the labels for the Seldon Core deployment from the service configuration. Returns: @@ -335,7 +336,7 @@ def deprovision(self, force: bool = False) -> None: def get_logs( self, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> Generator[str, bool, None]: """Get the logs of a Seldon Core model deployment. @@ -353,7 +354,7 @@ def get_logs( ) @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """The prediction URI exposed by the prediction service. 
Returns: diff --git a/src/zenml/integrations/seldon/steps/__init__.py b/src/zenml/integrations/seldon/steps/__init__.py index 1bbc434f882..cabb5bde538 100644 --- a/src/zenml/integrations/seldon/steps/__init__.py +++ b/src/zenml/integrations/seldon/steps/__init__.py @@ -13,7 +13,3 @@ # permissions and limitations under the License. """Initialization for Seldon steps.""" -from zenml.integrations.seldon.steps.seldon_deployer import ( - seldon_custom_model_deployer_step, - seldon_model_deployer_step, -) diff --git a/src/zenml/integrations/seldon/steps/seldon_deployer.py b/src/zenml/integrations/seldon/steps/seldon_deployer.py index 9a77da9a6f5..e333c86bc11 100644 --- a/src/zenml/integrations/seldon/steps/seldon_deployer.py +++ b/src/zenml/integrations/seldon/steps/seldon_deployer.py @@ -14,7 +14,7 @@ """Implementation of the Seldon Deployer step.""" import os -from typing import Optional, cast +from typing import cast from zenml import get_step_context, step from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact @@ -338,9 +338,9 @@ def seldon_custom_model_deployer_step( @step(enable_cache=False) def seldon_mlflow_registry_deployer_step( service_config: SeldonDeploymentConfig, - registry_model_name: Optional[str] = None, - registry_model_version: Optional[str] = None, - registry_model_stage: Optional[ModelVersionStage] = None, + registry_model_name: str | None = None, + registry_model_version: str | None = None, + registry_model_stage: ModelVersionStage | None = None, replace_existing: bool = True, timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT, ) -> SeldonDeploymentService: diff --git a/src/zenml/integrations/sklearn/materializers/sklearn_materializer.py b/src/zenml/integrations/sklearn/materializers/sklearn_materializer.py index b11f7fe7080..3beaa583031 100644 --- a/src/zenml/integrations/sklearn/materializers/sklearn_materializer.py +++ b/src/zenml/integrations/sklearn/materializers/sklearn_materializer.py @@ -13,7 +13,7 @@ # 
permissions and limitations under the License. """Implementation of the sklearn materializer.""" -from typing import Any, ClassVar, Tuple, Type +from typing import Any, ClassVar from sklearn.base import ( BaseEstimator, @@ -37,7 +37,7 @@ class SklearnMaterializer(CloudpickleMaterializer): """Materializer to read data to and from sklearn.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = ( + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = ( BaseEstimator, ClassifierMixin, ClusterMixin, diff --git a/src/zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py b/src/zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py index 8b5df841ed6..f01fc8b913f 100644 --- a/src/zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py +++ b/src/zenml/integrations/skypilot/flavors/skypilot_orchestrator_base_vm_config.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Skypilot orchestrator base config and settings.""" -from typing import Any, Dict, List, Literal, Optional, Union +from typing import Any, Literal from pydantic import Field @@ -112,55 +112,55 @@ class SkypilotBaseOrchestratorSettings(BaseSettings): """ # Resources - instance_type: Optional[str] = None - cpus: Union[None, int, float, str] = Field( + instance_type: str | None = None + cpus: None | int | float | str = Field( default=None, union_mode="left_to_right" ) - memory: Union[None, int, float, str] = Field( + memory: None | int | float | str = Field( default=None, union_mode="left_to_right" ) - accelerators: Union[None, str, Dict[str, int]] = Field( + accelerators: None | str | dict[str, int] = Field( default=None, union_mode="left_to_right" ) - accelerator_args: Optional[Dict[str, str]] = None - use_spot: Optional[bool] = None - job_recovery: Union[None, str, Dict[str, Any]] = Field( + accelerator_args: dict[str, str] | None = None + use_spot: bool | None = None + job_recovery: None | str | dict[str, Any] = Field( 
default=None, union_mode="left_to_right" ) - region: Optional[str] = None - zone: Optional[str] = None - image_id: Union[Dict[str, str], str, None] = Field( + region: str | None = None + zone: str | None = None + image_id: dict[str, str] | str | None = Field( default=None, union_mode="left_to_right" ) - disk_size: Optional[int] = None - disk_tier: Optional[Literal["high", "medium", "low", "ultra", "best"]] = ( + disk_size: int | None = None + disk_tier: Literal["high", "medium", "low", "ultra", "best"] | None = ( None ) # Run settings - cluster_name: Optional[str] = None + cluster_name: str | None = None retry_until_up: bool = False - idle_minutes_to_autostop: Optional[int] = 30 + idle_minutes_to_autostop: int | None = 30 down: bool = True stream_logs: bool = True - docker_run_args: List[str] = [] + docker_run_args: list[str] = [] # Additional SkyPilot features - ports: Union[None, int, str, List[Union[int, str]]] = Field( + ports: None | int | str | list[int | str] = Field( default=None, union_mode="left_to_right" ) - labels: Optional[Dict[str, str]] = None - any_of: Optional[List[Dict[str, Any]]] = None - ordered: Optional[List[Dict[str, Any]]] = None - workdir: Optional[str] = None - task_name: Optional[str] = None - file_mounts: Optional[Dict[str, Any]] = None - envs: Optional[Dict[str, str]] = None + labels: dict[str, str] | None = None + any_of: list[dict[str, Any]] | None = None + ordered: list[dict[str, Any]] | None = None + workdir: str | None = None + task_name: str | None = None + file_mounts: dict[str, Any] | None = None + envs: dict[str, str] | None = None # Future-proofing settings dictionaries - task_settings: Dict[str, Any] = {} - resources_settings: Dict[str, Any] = {} - launch_settings: Dict[str, Any] = {} + task_settings: dict[str, Any] = {} + resources_settings: dict[str, Any] = {} + launch_settings: dict[str, Any] = {} class SkypilotBaseOrchestratorConfig( diff --git 
a/src/zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py b/src/zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py index 747bc047708..d6f1f109511 100644 --- a/src/zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py +++ b/src/zenml/integrations/skypilot/orchestrators/skypilot_base_vm_orchestrator.py @@ -15,7 +15,7 @@ import os from abc import abstractmethod -from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast +from typing import TYPE_CHECKING, Optional, cast from uuid import uuid4 import sky @@ -65,10 +65,10 @@ class SkypilotBaseOrchestrator(ContainerizedOrchestrator): """ # The default instance type to use if none is specified in settings - DEFAULT_INSTANCE_TYPE: Optional[str] = None + DEFAULT_INSTANCE_TYPE: str | None = None @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates the stack. In the remote case, checks that the stack contains a container registry, @@ -80,7 +80,7 @@ def validator(self) -> Optional[StackValidator]: def _validate_remote_components( stack: "Stack", - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: for component in stack.components.values(): if not component.config.is_local: continue @@ -158,10 +158,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. 
This method should only submit the pipeline and not wait for it to diff --git a/src/zenml/integrations/skypilot/orchestrators/skypilot_orchestrator_entrypoint.py b/src/zenml/integrations/skypilot/orchestrators/skypilot_orchestrator_entrypoint.py index 54f7e2c7d20..1dcbffcc499 100644 --- a/src/zenml/integrations/skypilot/orchestrators/skypilot_orchestrator_entrypoint.py +++ b/src/zenml/integrations/skypilot/orchestrators/skypilot_orchestrator_entrypoint.py @@ -16,7 +16,7 @@ import argparse import socket import time -from typing import Dict, cast +from typing import cast import sky @@ -125,7 +125,7 @@ def main() -> None: use_sudo=False, # Entrypoint doesn't use sudo ) - unique_resource_configs: Dict[str, str] = {} + unique_resource_configs: dict[str, str] = {} for step_name, step in snapshot.step_configurations.items(): settings = cast( SkypilotBaseOrchestratorSettings, diff --git a/src/zenml/integrations/skypilot/orchestrators/skypilot_orchestrator_entrypoint_configuration.py b/src/zenml/integrations/skypilot/orchestrators/skypilot_orchestrator_entrypoint_configuration.py index 3caf2cf66ea..87414e62246 100644 --- a/src/zenml/integrations/skypilot/orchestrators/skypilot_orchestrator_entrypoint_configuration.py +++ b/src/zenml/integrations/skypilot/orchestrators/skypilot_orchestrator_entrypoint_configuration.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Entrypoint configuration for the Skypilot master/orchestrator VM.""" -from typing import TYPE_CHECKING, List, Set +from typing import TYPE_CHECKING if TYPE_CHECKING: from uuid import UUID @@ -26,7 +26,7 @@ class SkypilotOrchestratorEntrypointConfiguration: """Entrypoint configuration for the Skypilot master/orchestrator VM.""" @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> set[str]: """Gets all the options required for running this entrypoint. 
Returns: @@ -39,7 +39,7 @@ def get_entrypoint_options(cls) -> Set[str]: return options @classmethod - def get_entrypoint_command(cls) -> List[str]: + def get_entrypoint_command(cls) -> list[str]: """Returns a command that runs the entrypoint module. Returns: @@ -57,7 +57,7 @@ def get_entrypoint_arguments( cls, run_name: str, snapshot_id: "UUID", - ) -> List[str]: + ) -> list[str]: """Gets all arguments that the entrypoint command should be called with. Args: diff --git a/src/zenml/integrations/skypilot/utils.py b/src/zenml/integrations/skypilot/utils.py index 5b88a7e6227..d2738e7ca1f 100644 --- a/src/zenml/integrations/skypilot/utils.py +++ b/src/zenml/integrations/skypilot/utils.py @@ -1,7 +1,7 @@ """Utility functions for Skypilot orchestrators.""" import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any import sky @@ -36,9 +36,9 @@ def sanitize_cluster_name(name: str) -> str: def prepare_docker_setup( container_registry_uri: str, - credentials: Optional[Tuple[str, str]] = None, + credentials: tuple[str, str] | None = None, use_sudo: bool = True, -) -> Tuple[Optional[str], Dict[str, str]]: +) -> tuple[str | None, dict[str, str]]: """Prepare Docker login setup command and environment variables. Args: @@ -71,8 +71,8 @@ def create_docker_run_command( image: str, entrypoint_str: str, arguments_str: str, - environment: Dict[str, str], - docker_run_args: List[str], + environment: dict[str, str], + docker_run_args: list[str], use_sudo: bool = True, ) -> str: """Create a Docker run command string. @@ -102,10 +102,10 @@ def create_docker_run_command( def prepare_task_kwargs( settings: SkypilotBaseOrchestratorSettings, run_command: str, - setup: Optional[str], - task_envs: Dict[str, str], + setup: str | None, + task_envs: dict[str, str], task_name: str, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Prepare task keyword arguments for sky.Task. 
Args: @@ -146,9 +146,9 @@ def prepare_task_kwargs( def prepare_resources_kwargs( cloud: "Cloud", settings: SkypilotBaseOrchestratorSettings, - default_instance_type: Optional[str] = None, - kubernetes_image: Optional[str] = None, -) -> Dict[str, Any]: + default_instance_type: str | None = None, + kubernetes_image: str | None = None, +) -> dict[str, Any]: """Prepare resources keyword arguments for sky.Resources. Args: @@ -189,9 +189,9 @@ def prepare_resources_kwargs( def prepare_launch_kwargs( settings: SkypilotBaseOrchestratorSettings, - down: Optional[bool] = None, - idle_minutes_to_autostop: Optional[int] = None, -) -> Dict[str, Any]: + down: bool | None = None, + idle_minutes_to_autostop: int | None = None, +) -> dict[str, Any]: """Prepare launch keyword arguments for sky.launch. Args: @@ -238,7 +238,7 @@ def prepare_launch_kwargs( def sky_job_get( request_id: str, stream_logs: bool, cluster_name: str -) -> Optional[SubmissionResult]: +) -> SubmissionResult | None: """Handle SkyPilot request results based on stream_logs setting. SkyPilot API exec and launch methods are asynchronous and return a request ID. diff --git a/src/zenml/integrations/skypilot_aws/__init__.py b/src/zenml/integrations/skypilot_aws/__init__.py index 10483dafcf9..3140b945123 100644 --- a/src/zenml/integrations/skypilot_aws/__init__.py +++ b/src/zenml/integrations/skypilot_aws/__init__.py @@ -16,7 +16,6 @@ The Skypilot integration sub-module powers an alternative to the local orchestrator for a remote orchestration of ZenML pipelines on VMs. """ -from typing import List, Type from zenml.integrations.constants import ( SKYPILOT_AWS, @@ -36,7 +35,7 @@ class SkypilotAWSIntegration(Integration): APT_PACKAGES = ["openssh-client", "rsync"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Skypilot AWS integration. 
Returns: diff --git a/src/zenml/integrations/skypilot_aws/flavors/skypilot_orchestrator_aws_vm_flavor.py b/src/zenml/integrations/skypilot_aws/flavors/skypilot_orchestrator_aws_vm_flavor.py index c92678cca99..c1ec7e713a4 100644 --- a/src/zenml/integrations/skypilot_aws/flavors/skypilot_orchestrator_aws_vm_flavor.py +++ b/src/zenml/integrations/skypilot_aws/flavors/skypilot_orchestrator_aws_vm_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Skypilot orchestrator AWS flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.integrations.skypilot.flavors.skypilot_orchestrator_base_vm_config import ( SkypilotBaseOrchestratorConfig, @@ -58,7 +58,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -73,7 +73,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -82,7 +82,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -100,7 +100,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/aws-skypilot.png" @property - def config_class(self) -> Type[BaseOrchestratorConfig]: + def config_class(self) -> type[BaseOrchestratorConfig]: """Config class for the base orchestrator flavor. 
Returns: @@ -109,7 +109,7 @@ def config_class(self) -> Type[BaseOrchestratorConfig]: return SkypilotAWSOrchestratorConfig @property - def implementation_class(self) -> Type["SkypilotAWSOrchestrator"]: + def implementation_class(self) -> type["SkypilotAWSOrchestrator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/skypilot_aws/orchestrators/skypilot_aws_vm_orchestrator.py b/src/zenml/integrations/skypilot_aws/orchestrators/skypilot_aws_vm_orchestrator.py index 28a23f9972c..645821594b7 100644 --- a/src/zenml/integrations/skypilot_aws/orchestrators/skypilot_aws_vm_orchestrator.py +++ b/src/zenml/integrations/skypilot_aws/orchestrators/skypilot_aws_vm_orchestrator.py @@ -14,7 +14,7 @@ """Implementation of the a Skypilot based AWS VM orchestrator.""" import os -from typing import TYPE_CHECKING, Optional, Type, cast +from typing import TYPE_CHECKING, cast import sky @@ -62,7 +62,7 @@ def config(self) -> SkypilotAWSOrchestratorConfig: return cast(SkypilotAWSOrchestratorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Skypilot orchestrator. Returns: diff --git a/src/zenml/integrations/skypilot_azure/__init__.py b/src/zenml/integrations/skypilot_azure/__init__.py index 53a0a63c9e0..71be8fc5900 100644 --- a/src/zenml/integrations/skypilot_azure/__init__.py +++ b/src/zenml/integrations/skypilot_azure/__init__.py @@ -16,7 +16,6 @@ The Skypilot integration sub-module powers an alternative to the local orchestrator for a remote orchestration of ZenML pipelines on VMs. 
""" -from typing import List, Type from zenml.integrations.constants import ( SKYPILOT_AZURE, @@ -35,7 +34,7 @@ class SkypilotAzureIntegration(Integration): APT_PACKAGES = ["openssh-client", "rsync"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Skypilot Azure integration. Returns: diff --git a/src/zenml/integrations/skypilot_azure/flavors/skypilot_orchestrator_azure_vm_flavor.py b/src/zenml/integrations/skypilot_azure/flavors/skypilot_orchestrator_azure_vm_flavor.py index 5f15aebd1dd..e30151186ef 100644 --- a/src/zenml/integrations/skypilot_azure/flavors/skypilot_orchestrator_azure_vm_flavor.py +++ b/src/zenml/integrations/skypilot_azure/flavors/skypilot_orchestrator_azure_vm_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Skypilot orchestrator Azure flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.integrations.skypilot.flavors.skypilot_orchestrator_base_vm_config import ( SkypilotBaseOrchestratorConfig, @@ -60,7 +60,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -75,7 +75,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -84,7 +84,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. 
Returns: @@ -102,7 +102,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/azure-skypilot.png" @property - def config_class(self) -> Type[BaseOrchestratorConfig]: + def config_class(self) -> type[BaseOrchestratorConfig]: """Config class for the base orchestrator flavor. Returns: @@ -111,7 +111,7 @@ def config_class(self) -> Type[BaseOrchestratorConfig]: return SkypilotAzureOrchestratorConfig @property - def implementation_class(self) -> Type["SkypilotAzureOrchestrator"]: + def implementation_class(self) -> type["SkypilotAzureOrchestrator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/skypilot_azure/orchestrators/skypilot_azure_vm_orchestrator.py b/src/zenml/integrations/skypilot_azure/orchestrators/skypilot_azure_vm_orchestrator.py index 30669d44733..a22bca21544 100644 --- a/src/zenml/integrations/skypilot_azure/orchestrators/skypilot_azure_vm_orchestrator.py +++ b/src/zenml/integrations/skypilot_azure/orchestrators/skypilot_azure_vm_orchestrator.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of the a Skypilot based Azure VM orchestrator.""" -from typing import TYPE_CHECKING, Optional, Type, cast +from typing import TYPE_CHECKING, cast import sky @@ -59,7 +59,7 @@ def config(self) -> SkypilotAzureOrchestratorConfig: return cast(SkypilotAzureOrchestratorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Skypilot orchestrator. Returns: @@ -73,4 +73,3 @@ def prepare_environment_variable(self, set: bool = True) -> None: Args: set: Whether to set the environment variables or not. 
""" - pass diff --git a/src/zenml/integrations/skypilot_gcp/__init__.py b/src/zenml/integrations/skypilot_gcp/__init__.py index 39238634bee..47e4c53b30c 100644 --- a/src/zenml/integrations/skypilot_gcp/__init__.py +++ b/src/zenml/integrations/skypilot_gcp/__init__.py @@ -16,7 +16,6 @@ The Skypilot integration sub-module powers an alternative to the local orchestrator for a remote orchestration of ZenML pipelines on VMs. """ -from typing import List, Type from zenml.integrations.constants import ( SKYPILOT_GCP, @@ -43,7 +42,7 @@ class SkypilotGCPIntegration(Integration): APT_PACKAGES = ["openssh-client", "rsync"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Skypilot GCP integration. Returns: diff --git a/src/zenml/integrations/skypilot_gcp/flavors/skypilot_orchestrator_gcp_vm_flavor.py b/src/zenml/integrations/skypilot_gcp/flavors/skypilot_orchestrator_gcp_vm_flavor.py index b9a6dc25c5f..6d3a196c5cd 100644 --- a/src/zenml/integrations/skypilot_gcp/flavors/skypilot_orchestrator_gcp_vm_flavor.py +++ b/src/zenml/integrations/skypilot_gcp/flavors/skypilot_orchestrator_gcp_vm_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Skypilot orchestrator GCP flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.integrations.gcp.google_credentials_mixin import ( GoogleCredentialsConfigMixin, @@ -63,7 +63,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -78,7 +78,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. 
Returns: @@ -87,7 +87,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -105,7 +105,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/gcp-skypilot.png" @property - def config_class(self) -> Type[BaseOrchestratorConfig]: + def config_class(self) -> type[BaseOrchestratorConfig]: """Config class for the base orchestrator flavor. Returns: @@ -114,7 +114,7 @@ def config_class(self) -> Type[BaseOrchestratorConfig]: return SkypilotGCPOrchestratorConfig @property - def implementation_class(self) -> Type["SkypilotGCPOrchestrator"]: + def implementation_class(self) -> type["SkypilotGCPOrchestrator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/skypilot_gcp/orchestrators/skypilot_gcp_vm_orchestrator.py b/src/zenml/integrations/skypilot_gcp/orchestrators/skypilot_gcp_vm_orchestrator.py index 3086099981d..fcd7731bf0f 100644 --- a/src/zenml/integrations/skypilot_gcp/orchestrators/skypilot_gcp_vm_orchestrator.py +++ b/src/zenml/integrations/skypilot_gcp/orchestrators/skypilot_gcp_vm_orchestrator.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of the a Skypilot-based GCP VM orchestrator.""" -from typing import TYPE_CHECKING, Optional, Type, cast +from typing import TYPE_CHECKING, cast import sky @@ -64,7 +64,7 @@ def config(self) -> SkypilotGCPOrchestratorConfig: return cast(SkypilotGCPOrchestratorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Skypilot orchestrator. Returns: @@ -78,4 +78,3 @@ def prepare_environment_variable(self, set: bool = True) -> None: Args: set: Whether to set the environment variables or not. 
""" - pass diff --git a/src/zenml/integrations/skypilot_kubernetes/__init__.py b/src/zenml/integrations/skypilot_kubernetes/__init__.py index a07a2a216f5..e0a02d74664 100644 --- a/src/zenml/integrations/skypilot_kubernetes/__init__.py +++ b/src/zenml/integrations/skypilot_kubernetes/__init__.py @@ -16,7 +16,6 @@ The Skypilot integration sub-module powers an alternative to the local orchestrator for a remote orchestration of ZenML pipelines on VMs. """ -from typing import List, Type from zenml.integrations.constants import ( SKYPILOT_KUBERNETES, @@ -36,7 +35,7 @@ class SkypilotKubernetesIntegration(Integration): APT_PACKAGES = ["openssh-client", "rsync"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Skypilot Kubernetes integration. Returns: diff --git a/src/zenml/integrations/skypilot_kubernetes/flavors/skypilot_orchestrator_kubernetes_vm_flavor.py b/src/zenml/integrations/skypilot_kubernetes/flavors/skypilot_orchestrator_kubernetes_vm_flavor.py index 7334750f6fd..9329a332126 100644 --- a/src/zenml/integrations/skypilot_kubernetes/flavors/skypilot_orchestrator_kubernetes_vm_flavor.py +++ b/src/zenml/integrations/skypilot_kubernetes/flavors/skypilot_orchestrator_kubernetes_vm_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Skypilot orchestrator Kubernetes flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.constants import KUBERNETES_CLUSTER_RESOURCE_TYPE from zenml.integrations.skypilot.flavors.skypilot_orchestrator_base_vm_config import ( @@ -61,7 +61,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. 
Specifies resource requirements that are used to filter the available @@ -76,7 +76,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -85,7 +85,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -103,7 +103,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/kubernetes-skypilot.png" @property - def config_class(self) -> Type[BaseOrchestratorConfig]: + def config_class(self) -> type[BaseOrchestratorConfig]: """Config class for the base orchestrator flavor. Returns: @@ -112,7 +112,7 @@ def config_class(self) -> Type[BaseOrchestratorConfig]: return SkypilotKubernetesOrchestratorConfig @property - def implementation_class(self) -> Type["SkypilotKubernetesOrchestrator"]: + def implementation_class(self) -> type["SkypilotKubernetesOrchestrator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/skypilot_kubernetes/orchestrators/skypilot_kubernetes_vm_orchestrator.py b/src/zenml/integrations/skypilot_kubernetes/orchestrators/skypilot_kubernetes_vm_orchestrator.py index e1ce34cd6de..2d0eccefe01 100644 --- a/src/zenml/integrations/skypilot_kubernetes/orchestrators/skypilot_kubernetes_vm_orchestrator.py +++ b/src/zenml/integrations/skypilot_kubernetes/orchestrators/skypilot_kubernetes_vm_orchestrator.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Implementation of the a Skypilot based Kubernetes VM orchestrator.""" -from typing import TYPE_CHECKING, Optional, Type, cast +from typing import TYPE_CHECKING, cast import sky @@ -57,7 +57,7 @@ def config(self) -> SkypilotKubernetesOrchestratorConfig: return cast(SkypilotKubernetesOrchestratorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Skypilot orchestrator. Returns: @@ -71,4 +71,3 @@ def prepare_environment_variable(self, set: bool = True) -> None: Args: set: Whether to set the environment variables or not. """ - pass diff --git a/src/zenml/integrations/skypilot_lambda/__init__.py b/src/zenml/integrations/skypilot_lambda/__init__.py index 7ebb7b56b40..af6ece91d21 100644 --- a/src/zenml/integrations/skypilot_lambda/__init__.py +++ b/src/zenml/integrations/skypilot_lambda/__init__.py @@ -16,7 +16,6 @@ The Skypilot integration sub-module powers an alternative to the local orchestrator for a remote orchestration of ZenML pipelines on VMs. """ -from typing import List, Type from zenml.integrations.constants import ( SKYPILOT_LAMBDA, @@ -34,7 +33,7 @@ class SkypilotLambdaIntegration(Integration): REQUIREMENTS = ["skypilot[lambda]==0.9.3"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Skypilot Lambda integration. Returns: diff --git a/src/zenml/integrations/skypilot_lambda/flavors/skypilot_orchestrator_lambda_vm_flavor.py b/src/zenml/integrations/skypilot_lambda/flavors/skypilot_orchestrator_lambda_vm_flavor.py index 044c7986d50..f70569c7742 100644 --- a/src/zenml/integrations/skypilot_lambda/flavors/skypilot_orchestrator_lambda_vm_flavor.py +++ b/src/zenml/integrations/skypilot_lambda/flavors/skypilot_orchestrator_lambda_vm_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Skypilot orchestrator Lambda flavor.""" -from typing import TYPE_CHECKING, Any, Optional, Type +from typing import TYPE_CHECKING, Any from zenml.integrations.skypilot.flavors.skypilot_orchestrator_base_vm_config import ( SkypilotBaseOrchestratorConfig, @@ -65,7 +65,7 @@ class SkypilotLambdaOrchestratorConfig( ): """Skypilot orchestrator config.""" - api_key: Optional[str] = SecretField(default=None) + api_key: str | None = SecretField(default=None) class SkypilotLambdaOrchestratorFlavor(BaseOrchestratorFlavor): @@ -81,7 +81,7 @@ def name(self) -> str: return SKYPILOT_LAMBDA_ORCHESTRATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -90,7 +90,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -108,7 +108,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/lambda.png" @property - def config_class(self) -> Type[BaseOrchestratorConfig]: + def config_class(self) -> type[BaseOrchestratorConfig]: """Config class for the base orchestrator flavor. Returns: @@ -117,7 +117,7 @@ def config_class(self) -> Type[BaseOrchestratorConfig]: return SkypilotLambdaOrchestratorConfig @property - def implementation_class(self) -> Type["SkypilotLambdaOrchestrator"]: + def implementation_class(self) -> type["SkypilotLambdaOrchestrator"]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/integrations/skypilot_lambda/orchestrators/skypilot_lambda_vm_orchestrator.py b/src/zenml/integrations/skypilot_lambda/orchestrators/skypilot_lambda_vm_orchestrator.py index 5f948b372e4..c694330a2be 100644 --- a/src/zenml/integrations/skypilot_lambda/orchestrators/skypilot_lambda_vm_orchestrator.py +++ b/src/zenml/integrations/skypilot_lambda/orchestrators/skypilot_lambda_vm_orchestrator.py @@ -14,7 +14,7 @@ """Implementation of the a Skypilot based Lambda VM orchestrator.""" import os -from typing import TYPE_CHECKING, Optional, Type, cast +from typing import TYPE_CHECKING, cast import sky @@ -61,7 +61,7 @@ def config(self) -> SkypilotLambdaOrchestratorConfig: return cast(SkypilotLambdaOrchestratorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Skypilot orchestrator. Returns: @@ -75,7 +75,6 @@ def prepare_environment_variable(self, set: bool = True) -> None: Args: set: Whether to set the environment variables or not. """ - pass def setup_credentials(self) -> None: """Set up credentials for the orchestrator.""" diff --git a/src/zenml/integrations/slack/__init__.py b/src/zenml/integrations/slack/__init__.py index 4187d4c6803..16ac8557f3d 100644 --- a/src/zenml/integrations/slack/__init__.py +++ b/src/zenml/integrations/slack/__init__.py @@ -13,9 +13,7 @@ # permissions and limitations under the License. """Slack integration for alerter components.""" -from typing import List, Type -from zenml.enums import StackComponentType from zenml.integrations.constants import SLACK from zenml.integrations.integration import Integration from zenml.stack import Flavor @@ -33,7 +31,7 @@ class SlackIntegration(Integration): REQUIREMENTS = ["slack-sdk==3.30.0"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Slack integration. 
Returns: diff --git a/src/zenml/integrations/slack/alerters/slack_alerter.py b/src/zenml/integrations/slack/alerters/slack_alerter.py index 192f9e4c2af..5fcab60e916 100644 --- a/src/zenml/integrations/slack/alerters/slack_alerter.py +++ b/src/zenml/integrations/slack/alerters/slack_alerter.py @@ -15,7 +15,7 @@ # permissions and limitations under the License. import time -from typing import Dict, List, Optional, Type, cast +from typing import cast from pydantic import BaseModel from slack_sdk import WebClient @@ -39,27 +39,27 @@ class SlackAlerterPayload(BaseModel): """Slack alerter payload implementation.""" - pipeline_name: Optional[str] = None - step_name: Optional[str] = None - stack_name: Optional[str] = None + pipeline_name: str | None = None + step_name: str | None = None + stack_name: str | None = None class SlackAlerterParameters(BaseAlerterStepParameters): """Slack alerter parameters.""" # The ID of the Slack channel to use for communication. - slack_channel_id: Optional[str] = None + slack_channel_id: str | None = None # Set of messages that lead to approval in alerter.ask() - approve_msg_options: Optional[List[str]] = None + approve_msg_options: list[str] | None = None # Set of messages that lead to disapproval in alerter.ask() - disapprove_msg_options: Optional[List[str]] = None - payload: Optional[SlackAlerterPayload] = None - include_format_blocks: Optional[bool] = True + disapprove_msg_options: list[str] | None = None + payload: SlackAlerterPayload | None = None + include_format_blocks: bool | None = True # Allowing user to use their own custom blocks in the Slack post message - blocks: Optional[List[Dict]] = None # type: ignore + blocks: list[dict] | None = None # type: ignore class SlackAlerter(BaseAlerter): @@ -75,7 +75,7 @@ def config(self) -> SlackAlerterConfig: return cast(SlackAlerterConfig, self._config) @property - def settings_class(self) -> Type[SlackAlerterSettings]: + def settings_class(self) -> type[SlackAlerterSettings]: """Settings 
class for the Slack alerter. Returns: @@ -84,7 +84,7 @@ def settings_class(self) -> Type[SlackAlerterSettings]: return SlackAlerterSettings def _get_channel_id( - self, params: Optional[BaseAlerterStepParameters] = None + self, params: BaseAlerterStepParameters | None = None ) -> str: """Get the Slack channel ID to be used by post/ask. @@ -152,8 +152,8 @@ def _get_timeout_duration(self) -> int: @staticmethod def _get_approve_msg_options( - params: Optional[BaseAlerterStepParameters], - ) -> List[str]: + params: BaseAlerterStepParameters | None, + ) -> list[str]: """Define which messages will lead to approval during ask(). Args: @@ -172,8 +172,8 @@ def _get_approve_msg_options( @staticmethod def _get_disapprove_msg_options( - params: Optional[BaseAlerterStepParameters], - ) -> List[str]: + params: BaseAlerterStepParameters | None, + ) -> list[str]: """Define which messages will lead to disapproval during ask(). Args: @@ -192,9 +192,9 @@ def _get_disapprove_msg_options( @staticmethod def _create_blocks( - message: Optional[str], - params: Optional[BaseAlerterStepParameters], - ) -> List[Dict]: # type: ignore + message: str | None, + params: BaseAlerterStepParameters | None, + ) -> list[dict]: # type: ignore """Helper function to create slack blocks. Args: @@ -261,8 +261,8 @@ def _create_blocks( def post( self, - message: Optional[str] = None, - params: Optional[BaseAlerterStepParameters] = None, + message: str | None = None, + params: BaseAlerterStepParameters | None = None, ) -> bool: """Post a message to a Slack channel. @@ -300,7 +300,7 @@ def post( return False def ask( - self, question: str, params: Optional[BaseAlerterStepParameters] = None + self, question: str, params: BaseAlerterStepParameters | None = None ) -> bool: """Post a message to a Slack channel and wait for approval. 
diff --git a/src/zenml/integrations/slack/flavors/slack_alerter_flavor.py b/src/zenml/integrations/slack/flavors/slack_alerter_flavor.py index 364533210df..d62b780c159 100644 --- a/src/zenml/integrations/slack/flavors/slack_alerter_flavor.py +++ b/src/zenml/integrations/slack/flavors/slack_alerter_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Slack alerter flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.alerter.base_alerter import BaseAlerterConfig, BaseAlerterFlavor from zenml.config.base_settings import BaseSettings @@ -36,7 +36,7 @@ class SlackAlerterSettings(BaseSettings): timeout: The amount of seconds to wait for the ask method. """ - slack_channel_id: Optional[str] = None + slack_channel_id: str | None = None timeout: int = 300 @@ -51,7 +51,7 @@ class SlackAlerterConfig(BaseAlerterConfig, SlackAlerterSettings): slack_token: str = SecretField() - default_slack_channel_id: Optional[str] = None + default_slack_channel_id: str | None = None _deprecation_validator = deprecate_pydantic_attributes( ("default_slack_channel_id", "slack_channel_id") ) @@ -106,7 +106,7 @@ def name(self) -> str: return SLACK_ALERTER_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A URL to point at docs explaining this flavor. Returns: @@ -115,7 +115,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A URL to point at SDK docs explaining this flavor. Returns: @@ -133,7 +133,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/alerter/slack.png" @property - def config_class(self) -> Type[SlackAlerterConfig]: + def config_class(self) -> type[SlackAlerterConfig]: """Returns `SlackAlerterConfig` config class. 
Returns: @@ -142,7 +142,7 @@ def config_class(self) -> Type[SlackAlerterConfig]: return SlackAlerterConfig @property - def implementation_class(self) -> Type["SlackAlerter"]: + def implementation_class(self) -> type["SlackAlerter"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/slack/steps/slack_alerter_ask_step.py b/src/zenml/integrations/slack/steps/slack_alerter_ask_step.py index 90dc9277bc8..25e531acd5f 100644 --- a/src/zenml/integrations/slack/steps/slack_alerter_ask_step.py +++ b/src/zenml/integrations/slack/steps/slack_alerter_ask_step.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Step that allows you to send messages to Slack and wait for a response.""" -from typing import Optional from zenml import get_step_context, step from zenml.client import Client @@ -27,7 +26,7 @@ @step def slack_alerter_ask_step( message: str, - params: Optional[SlackAlerterParameters] = None, + params: SlackAlerterParameters | None = None, ) -> bool: """Posts a message to the Slack alerter component and waits for approval. diff --git a/src/zenml/integrations/slack/steps/slack_alerter_post_step.py b/src/zenml/integrations/slack/steps/slack_alerter_post_step.py index 261774a8279..038cedbc1dc 100644 --- a/src/zenml/integrations/slack/steps/slack_alerter_post_step.py +++ b/src/zenml/integrations/slack/steps/slack_alerter_post_step.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Step that allows you to post messages to Slack.""" -from typing import Optional from zenml import get_step_context, step from zenml.client import Client @@ -26,8 +25,8 @@ @step def slack_alerter_post_step( - message: Optional[str] = None, - params: Optional[SlackAlerterParameters] = None, + message: str | None = None, + params: SlackAlerterParameters | None = None, ) -> bool: """Post a message to the Slack alerter component of the active stack. 
diff --git a/src/zenml/integrations/spark/__init__.py b/src/zenml/integrations/spark/__init__.py index f44308340a0..c67297c8fe2 100644 --- a/src/zenml/integrations/spark/__init__.py +++ b/src/zenml/integrations/spark/__init__.py @@ -14,9 +14,7 @@ """The Spark integration module to enable distributed processing for steps.""" -from typing import List, Type -from zenml.enums import StackComponentType from zenml.integrations.constants import SPARK from zenml.integrations.integration import Integration from zenml.stack import Flavor @@ -36,7 +34,7 @@ def activate(cls) -> None: from zenml.integrations.spark import materializers # noqa @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Spark integration. Returns: diff --git a/src/zenml/integrations/spark/flavors/spark_on_kubernetes_step_operator_flavor.py b/src/zenml/integrations/spark/flavors/spark_on_kubernetes_step_operator_flavor.py index 68610ed043b..23c84052759 100644 --- a/src/zenml/integrations/spark/flavors/spark_on_kubernetes_step_operator_flavor.py +++ b/src/zenml/integrations/spark/flavors/spark_on_kubernetes_step_operator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Spark on Kubernetes step operator flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from zenml.integrations.spark import SPARK_KUBERNETES_STEP_OPERATOR from zenml.integrations.spark.flavors.spark_step_operator_flavor import ( @@ -37,8 +37,8 @@ class KubernetesSparkStepOperatorConfig(SparkStepOperatorConfig): components (to create and watch the pods). 
""" - namespace: Optional[str] = None - service_account: Optional[str] = None + namespace: str | None = None + service_account: str | None = None class KubernetesSparkStepOperatorFlavor(SparkStepOperatorFlavor): @@ -54,7 +54,7 @@ def name(self) -> str: return SPARK_KUBERNETES_STEP_OPERATOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -63,7 +63,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -81,7 +81,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/step_operator/spark.png" @property - def config_class(self) -> Type[KubernetesSparkStepOperatorConfig]: + def config_class(self) -> type[KubernetesSparkStepOperatorConfig]: """Returns `KubernetesSparkStepOperatorConfig` config class. Returns: @@ -90,7 +90,7 @@ def config_class(self) -> Type[KubernetesSparkStepOperatorConfig]: return KubernetesSparkStepOperatorConfig @property - def implementation_class(self) -> Type["KubernetesSparkStepOperator"]: + def implementation_class(self) -> type["KubernetesSparkStepOperator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/spark/flavors/spark_step_operator_flavor.py b/src/zenml/integrations/spark/flavors/spark_step_operator_flavor.py index 1801eec24f0..d8576aca825 100644 --- a/src/zenml/integrations/spark/flavors/spark_step_operator_flavor.py +++ b/src/zenml/integrations/spark/flavors/spark_step_operator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Spark step operator flavor.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type +from typing import TYPE_CHECKING, Any from zenml.config.base_settings import BaseSettings from zenml.step_operators.base_step_operator import ( @@ -40,7 +40,7 @@ class SparkStepOperatorSettings(BaseSettings): """ deploy_mode: str = "cluster" - submit_kwargs: Optional[Dict[str, Any]] = None + submit_kwargs: dict[str, Any] | None = None class SparkStepOperatorConfig( @@ -71,7 +71,7 @@ def name(self) -> str: return "spark" @property - def config_class(self) -> Type[SparkStepOperatorConfig]: + def config_class(self) -> type[SparkStepOperatorConfig]: """Returns `SparkStepOperatorConfig` config class. Returns: @@ -80,7 +80,7 @@ def config_class(self) -> Type[SparkStepOperatorConfig]: return SparkStepOperatorConfig @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -89,7 +89,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -98,7 +98,7 @@ def sdk_docs_url(self) -> Optional[str]: return self.generate_default_sdk_docs_url() @property - def implementation_class(self) -> Type["SparkStepOperator"]: + def implementation_class(self) -> type["SparkStepOperator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/spark/materializers/__init__.py b/src/zenml/integrations/spark/materializers/__init__.py index 3275916d736..2fa6291c36c 100644 --- a/src/zenml/integrations/spark/materializers/__init__.py +++ b/src/zenml/integrations/spark/materializers/__init__.py @@ -13,9 +13,3 @@ # permissions and limitations under the License. 
"""Spark Materializers.""" -from zenml.integrations.spark.materializers.spark_dataframe_materializer import ( - SparkDataFrameMaterializer, -) -from zenml.integrations.spark.materializers.spark_model_materializer import ( - SparkModelMaterializer, -) diff --git a/src/zenml/integrations/spark/materializers/spark_dataframe_materializer.py b/src/zenml/integrations/spark/materializers/spark_dataframe_materializer.py index 8f4beb20bcd..e7f25fe37dd 100644 --- a/src/zenml/integrations/spark/materializers/spark_dataframe_materializer.py +++ b/src/zenml/integrations/spark/materializers/spark_dataframe_materializer.py @@ -14,7 +14,7 @@ """Implementation of the Spark Dataframe Materializer.""" import os.path -from typing import Any, ClassVar, Dict, Tuple, Type +from typing import Any, ClassVar from pyspark.sql import DataFrame, SparkSession @@ -28,10 +28,10 @@ class SparkDataFrameMaterializer(BaseMaterializer): """Materializer to read/write Spark dataframes.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (DataFrame,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (DataFrame,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - def load(self, data_type: Type[Any]) -> DataFrame: + def load(self, data_type: type[Any]) -> DataFrame: """Reads and returns a spark dataframe. Args: @@ -57,7 +57,7 @@ def save(self, df: DataFrame) -> None: path = os.path.join(self.uri, DEFAULT_FILEPATH) df.write.parquet(path) - def extract_metadata(self, df: DataFrame) -> Dict[str, "MetadataType"]: + def extract_metadata(self, df: DataFrame) -> dict[str, "MetadataType"]: """Extract metadata from the given `DataFrame` object. 
Args: diff --git a/src/zenml/integrations/spark/materializers/spark_model_materializer.py b/src/zenml/integrations/spark/materializers/spark_model_materializer.py index 0911d9ade3e..3c6496cd763 100644 --- a/src/zenml/integrations/spark/materializers/spark_model_materializer.py +++ b/src/zenml/integrations/spark/materializers/spark_model_materializer.py @@ -14,7 +14,7 @@ """Implementation of the Spark Model Materializer.""" import os -from typing import Any, ClassVar, Tuple, Type, Union +from typing import Any, ClassVar from pyspark.ml import Estimator, Model, Transformer @@ -27,7 +27,7 @@ class SparkModelMaterializer(BaseMaterializer): """Materializer to read/write Spark models.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = ( + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = ( Transformer, Estimator, Model, @@ -35,8 +35,8 @@ class SparkModelMaterializer(BaseMaterializer): ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL def load( - self, model_type: Type[Any] - ) -> Union[Transformer, Estimator, Model]: # type: ignore[type-arg] + self, model_type: type[Any] + ) -> Transformer | Estimator | Model: # type: ignore[type-arg] """Reads and returns a Spark ML model. Args: @@ -50,7 +50,7 @@ def load( def save( self, - model: Union[Transformer, Estimator, Model], # type: ignore[type-arg] + model: Transformer | Estimator | Model, # type: ignore[type-arg] ) -> None: """Writes a spark model. diff --git a/src/zenml/integrations/spark/step_operators/kubernetes_step_operator.py b/src/zenml/integrations/spark/step_operators/kubernetes_step_operator.py index 4b5d8717d0a..60b72330414 100644 --- a/src/zenml/integrations/spark/step_operators/kubernetes_step_operator.py +++ b/src/zenml/integrations/spark/step_operators/kubernetes_step_operator.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Implementation of the Kubernetes Spark Step Operator.""" -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, cast from pyspark.conf import SparkConf @@ -57,7 +57,7 @@ def config(self) -> KubernetesSparkStepOperatorConfig: return cast(KubernetesSparkStepOperatorConfig, self._config) @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Validates the stack. Returns: @@ -65,7 +65,7 @@ def validator(self) -> Optional[StackValidator]: registry and a remote artifact store. """ - def _validate_remote_components(stack: "Stack") -> Tuple[bool, str]: + def _validate_remote_components(stack: "Stack") -> tuple[bool, str]: if stack.artifact_store.config.is_local: return False, ( "The Spark step operator runs code remotely and " @@ -110,7 +110,7 @@ def application_path(self) -> Any: def get_docker_builds( self, snapshot: "PipelineSnapshotBase" - ) -> List["BuildConfiguration"]: + ) -> list["BuildConfiguration"]: """Gets the Docker builds required for the component. Args: @@ -139,7 +139,7 @@ def _backend_configuration( self, spark_config: SparkConf, info: "StepRunInfo", - environment: Dict[str, str], + environment: dict[str, str], ) -> None: """Configures Spark to run on Kubernetes. 
diff --git a/src/zenml/integrations/spark/step_operators/spark_step_operator.py b/src/zenml/integrations/spark/step_operators/spark_step_operator.py index f325b5bb7ef..30c122a8bb3 100644 --- a/src/zenml/integrations/spark/step_operators/spark_step_operator.py +++ b/src/zenml/integrations/spark/step_operators/spark_step_operator.py @@ -14,7 +14,7 @@ """Implementation of the Spark Step Operator.""" import subprocess -from typing import TYPE_CHECKING, Dict, List, Optional, Type, cast +from typing import TYPE_CHECKING, cast from pyspark.conf import SparkConf @@ -49,7 +49,7 @@ def config(self) -> SparkStepOperatorConfig: return cast(SparkStepOperatorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Spark step operator. Returns: @@ -58,7 +58,7 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: return SparkStepOperatorSettings @property - def application_path(self) -> Optional[str]: + def application_path(self) -> str | None: """Optional method for providing the application path. This is especially critical when using 'spark-submit' as it defines the @@ -111,7 +111,7 @@ def _backend_configuration( self, spark_config: SparkConf, info: "StepRunInfo", - environment: Dict[str, str], + environment: dict[str, str], ) -> None: """Configures Spark to handle backends like YARN, Mesos or Kubernetes. @@ -215,7 +215,7 @@ def _launch_spark_job( self, spark_config: SparkConf, deploy_mode: str, - entrypoint_command: List[str], + entrypoint_command: list[str], ) -> None: """Generates and executes a spark-submit command. @@ -269,8 +269,8 @@ def _launch_spark_job( def launch( self, info: "StepRunInfo", - entrypoint_command: List[str], - environment: Dict[str, str], + entrypoint_command: list[str], + environment: dict[str, str], ) -> None: """Launches a step on Spark. 
diff --git a/src/zenml/integrations/tekton/__init__.py b/src/zenml/integrations/tekton/__init__.py index 99aecf7ec8f..2d0287cc0e1 100644 --- a/src/zenml/integrations/tekton/__init__.py +++ b/src/zenml/integrations/tekton/__init__.py @@ -17,9 +17,7 @@ orchestrator. You can enable it by registering the Tekton orchestrator with the CLI tool. """ -from typing import List, Type -from zenml.enums import StackComponentType from zenml.integrations.constants import TEKTON from zenml.integrations.integration import Integration from zenml.stack import Flavor @@ -35,7 +33,7 @@ class TektonIntegration(Integration): REQUIREMENTS_IGNORED_ON_UNINSTALL = ["kfp"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Tekton integration. Returns: diff --git a/src/zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py b/src/zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py index 84f119d80a3..8bcd8288ac9 100644 --- a/src/zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py +++ b/src/zenml/integrations/tekton/flavors/tekton_orchestrator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Tekton orchestrator flavor.""" -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any from pydantic import model_validator @@ -57,20 +57,20 @@ class TektonOrchestratorSettings(BaseSettings): synchronous: bool = True timeout: int = 1200 - client_args: Dict[str, Any] = {} - client_username: Optional[str] = SecretField(default=None) - client_password: Optional[str] = SecretField(default=None) - user_namespace: Optional[str] = None - node_selectors: Dict[str, str] = {} - node_affinity: Dict[str, List[str]] = {} - pod_settings: Optional[KubernetesPodSettings] = None + client_args: dict[str, Any] = {} + client_username: str | None = SecretField(default=None) + client_password: str | None = SecretField(default=None) + user_namespace: str | None = None + node_selectors: dict[str, str] = {} + node_affinity: dict[str, list[str]] = {} + pod_settings: KubernetesPodSettings | None = None @model_validator(mode="before") @classmethod @before_validator_handler def _validate_and_migrate_pod_settings( - cls, data: Dict[str, Any] - ) -> Dict[str, Any]: + cls, data: dict[str, Any] + ) -> dict[str, Any]: """Validates settings and migrates pod settings from older version. Args: @@ -106,16 +106,16 @@ class TektonOrchestratorConfig( pods that run the pipeline steps should be running. """ - tekton_hostname: Optional[str] = None - kubernetes_context: Optional[str] = None + tekton_hostname: str | None = None + kubernetes_context: str | None = None kubernetes_namespace: str = "kubeflow" @model_validator(mode="before") @classmethod @before_validator_handler def _validate_deprecated_attrs( - cls, data: Dict[str, Any] - ) -> Dict[str, Any]: + cls, data: dict[str, Any] + ) -> dict[str, Any]: """Pydantic root_validator for deprecated attributes. This root validator is used for backwards compatibility purposes. E.g. 
@@ -179,7 +179,7 @@ def name(self) -> str: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -194,7 +194,7 @@ def service_connector_requirements( ) @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -203,7 +203,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -221,7 +221,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/tekton.png" @property - def config_class(self) -> Type[TektonOrchestratorConfig]: + def config_class(self) -> type[TektonOrchestratorConfig]: """Returns `TektonOrchestratorConfig` config class. Returns: @@ -230,7 +230,7 @@ def config_class(self) -> Type[TektonOrchestratorConfig]: return TektonOrchestratorConfig @property - def implementation_class(self) -> Type["TektonOrchestrator"]: + def implementation_class(self) -> type["TektonOrchestrator"]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/integrations/tekton/orchestrators/tekton_orchestrator.py b/src/zenml/integrations/tekton/orchestrators/tekton_orchestrator.py index bb864b564f9..eb6c59d1ebf 100644 --- a/src/zenml/integrations/tekton/orchestrators/tekton_orchestrator.py +++ b/src/zenml/integrations/tekton/orchestrators/tekton_orchestrator.py @@ -18,11 +18,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Tuple, - Type, cast, ) @@ -126,7 +122,7 @@ def _load_config(self, *args: Any, **kwargs: Any) -> Any: class TektonOrchestrator(ContainerizedOrchestrator): """Orchestrator responsible for running pipelines using Tekton.""" - _k8s_client: Optional[k8s_client.ApiClient] = None + _k8s_client: k8s_client.ApiClient | None = None def _get_kfp_client( self, @@ -243,7 +239,7 @@ def _get_session_cookie(self, username: str, password: str) -> str: raise RuntimeError( f"Error while trying to fetch tekoton cookie: {errh}" ) - cookie_dict: Dict[str, str] = session.cookies.get_dict() # type: ignore[no-untyped-call, unused-ignore] + cookie_dict: dict[str, str] = session.cookies.get_dict() # type: ignore[no-untyped-call, unused-ignore] if "authservice_session" not in cookie_dict: raise RuntimeError("Invalid username and/or password!") @@ -262,7 +258,7 @@ def config(self) -> TektonOrchestratorConfig: return cast(TektonOrchestratorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Tekton orchestrator. Returns: @@ -270,7 +266,7 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: """ return TektonOrchestratorSettings - def get_kubernetes_contexts(self) -> Tuple[List[str], Optional[str]]: + def get_kubernetes_contexts(self) -> tuple[list[str], str | None]: """Get the list of configured Kubernetes contexts and the active context. 
Returns: @@ -287,14 +283,14 @@ def get_kubernetes_contexts(self) -> Tuple[List[str], Optional[str]]: return context_names, active_context_name @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Ensures a stack with only remote components and a container registry. Returns: A `StackValidator` instance. """ - def _validate(stack: "Stack") -> Tuple[bool, str]: + def _validate(stack: "Stack") -> tuple[bool, str]: container_registry = stack.container_registry # should not happen, because the stack validation takes care of @@ -414,8 +410,8 @@ def _validate(stack: "Stack") -> Tuple[bool, str]: def _create_dynamic_component( self, image: str, - command: List[str], - arguments: List[str], + command: list[str], + arguments: list[str], component_name: str, ) -> dsl.PipelineTask: """Creates a dynamic container component for a Tekton pipeline. @@ -459,10 +455,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. 
This method should only submit the pipeline and not wait for it to @@ -510,7 +506,7 @@ def _create_dynamic_pipeline() -> Any: Returns: pipeline_func """ - step_name_to_dynamic_component: Dict[str, Any] = {} + step_name_to_dynamic_component: dict[str, Any] = {} for step_name, step in snapshot.step_configurations.items(): image = self.get_image( @@ -530,7 +526,7 @@ def _create_dynamic_pipeline() -> Any: step_settings = cast( TektonOrchestratorSettings, self.get_settings(step) ) - node_selector_constraint: Optional[Tuple[str, str]] = None + node_selector_constraint: tuple[str, str] | None = None pod_settings = step_settings.pod_settings if pod_settings: ignored_fields = pod_settings.model_fields_set - { @@ -627,7 +623,7 @@ def _upload_and_run_pipeline( snapshot: "PipelineSnapshotResponse", pipeline_file_path: str, run_name: str, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Tries to upload and run a KFP pipeline. Args: @@ -828,7 +824,7 @@ def _configure_container_resources( self, dynamic_component: dsl.PipelineTask, resource_settings: "ResourceSettings", - node_selector_constraint: Optional[Tuple[str, str]] = None, + node_selector_constraint: tuple[str, str] | None = None, ) -> dsl.PipelineTask: """Adds resource requirements to the container. diff --git a/src/zenml/integrations/tensorboard/__init__.py b/src/zenml/integrations/tensorboard/__init__.py index 6c86fa10f85..d374b6de4e5 100644 --- a/src/zenml/integrations/tensorboard/__init__.py +++ b/src/zenml/integrations/tensorboard/__init__.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Initialization for TensorBoard integration.""" -from typing import List, Optional from zenml.integrations.constants import TENSORBOARD from zenml.integrations.integration import Integration @@ -25,8 +24,8 @@ class TensorBoardIntegration(Integration): REQUIREMENTS = [] @classmethod - def get_requirements(cls, target_os: Optional[str] = None, python_version: Optional[str] = None - ) -> List[str]: + def get_requirements(cls, target_os: str | None = None, python_version: str | None = None + ) -> list[str]: """Defines platform specific requirements for the integration. Args: diff --git a/src/zenml/integrations/tensorboard/services/tensorboard_service.py b/src/zenml/integrations/tensorboard/services/tensorboard_service.py index 9d78b44e322..5415ce0f8e5 100644 --- a/src/zenml/integrations/tensorboard/services/tensorboard_service.py +++ b/src/zenml/integrations/tensorboard/services/tensorboard_service.py @@ -14,7 +14,7 @@ """Implementation of the TensorBoard service.""" import uuid -from typing import Any, Dict, Union +from typing import Any from zenml.logger import get_logger from zenml.models.v2.misc.service import ServiceType @@ -71,7 +71,7 @@ class TensorboardService(LocalDaemonService): def __init__( self, - config: Union[TensorboardServiceConfig, Dict[str, Any]], + config: TensorboardServiceConfig | dict[str, Any], **attrs: Any, ) -> None: """Initialization for TensorBoard service. 
diff --git a/src/zenml/integrations/tensorflow/__init__.py b/src/zenml/integrations/tensorflow/__init__.py index d84c0a5de80..c093ffd82f4 100644 --- a/src/zenml/integrations/tensorflow/__init__.py +++ b/src/zenml/integrations/tensorflow/__init__.py @@ -14,8 +14,6 @@ """Initialization for TensorFlow integration.""" import platform -import sys -from typing import List, Optional from zenml.integrations.constants import TENSORFLOW from zenml.integrations.integration import Integration from zenml.logger import get_logger @@ -41,8 +39,8 @@ def activate(cls) -> None: from zenml.integrations.tensorflow import materializers # noqa @classmethod - def get_requirements(cls, target_os: Optional[str] = None, python_version: Optional[str] = None - ) -> List[str]: + def get_requirements(cls, target_os: str | None = None, python_version: str | None = None + ) -> list[str]: """Defines platform specific requirements for the integration. Args: diff --git a/src/zenml/integrations/tensorflow/materializers/keras_materializer.py b/src/zenml/integrations/tensorflow/materializers/keras_materializer.py index 4577b867f41..9e7a6a2d697 100644 --- a/src/zenml/integrations/tensorflow/materializers/keras_materializer.py +++ b/src/zenml/integrations/tensorflow/materializers/keras_materializer.py @@ -14,7 +14,7 @@ """Implementation of the TensorFlow Keras materializer.""" import os -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar import tensorflow as tf from tensorflow.python import keras as tf_keras @@ -31,14 +31,14 @@ class KerasMaterializer(BaseMaterializer): """Materializer to read/write Keras models.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = ( + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = ( tf.keras.Model, tf_keras.Model, ) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL MODEL_FILE_NAME = "model.keras" - def load(self, data_type: Type[Any]) -> tf_keras.Model: + def load(self, 
data_type: type[Any]) -> tf_keras.Model: """Reads and returns a Keras model after copying it to temporary path. Args: @@ -70,7 +70,7 @@ def save(self, model: tf_keras.Model) -> None: def extract_metadata( self, model: tf_keras.Model - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Extract metadata from the given `Model` object. Args: diff --git a/src/zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py b/src/zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py index 73ba9bf1f14..92aaa26dc86 100644 --- a/src/zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py +++ b/src/zenml/integrations/tensorflow/materializers/tf_dataset_materializer.py @@ -14,7 +14,7 @@ """Implementation of the TensorFlow dataset materializer.""" import os -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar import tensorflow as tf @@ -31,10 +31,10 @@ class TensorflowDatasetMaterializer(BaseMaterializer): """Materializer to read data to and from tf.data.Dataset.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (tf.data.Dataset,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (tf.data.Dataset,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: type[Any]) -> Any: """Reads data into tf.data.Dataset. Args: @@ -64,7 +64,7 @@ def save(self, dataset: tf.data.Dataset) -> None: def extract_metadata( self, dataset: tf.data.Dataset - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Extract metadata from the given `Dataset` object. 
Args: diff --git a/src/zenml/integrations/utils.py b/src/zenml/integrations/utils.py index 4e06f4d0e92..8913a02e64b 100644 --- a/src/zenml/integrations/utils.py +++ b/src/zenml/integrations/utils.py @@ -16,14 +16,13 @@ import importlib import inspect import sys -from typing import List, Optional, Type from zenml.integrations.integration import Integration, IntegrationMeta def get_integration_for_module( module_name: str, -) -> Optional[Type[Integration]]: +) -> type[Integration] | None: """Gets the integration class for a module inside an integration. If the module given by `module_name` is not part of a ZenML integration, @@ -58,7 +57,7 @@ def get_integration_for_module( return None -def get_requirements_for_module(module_name: str) -> List[str]: +def get_requirements_for_module(module_name: str) -> list[str]: """Gets requirements for a module inside an integration. If the module given by `module_name` is not part of a ZenML integration, diff --git a/src/zenml/integrations/vllm/__init__.py b/src/zenml/integrations/vllm/__init__.py index 2c6a3a25bc1..6132f21c132 100644 --- a/src/zenml/integrations/vllm/__init__.py +++ b/src/zenml/integrations/vllm/__init__.py @@ -12,7 +12,6 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """Initialization for the ZenML vLLM integration.""" -from typing import List, Type from zenml.integrations.integration import Integration from zenml.stack import Flavor from zenml.logger import get_logger @@ -33,10 +32,9 @@ class VLLMIntegration(Integration): @classmethod def activate(cls) -> None: """Activates the integration.""" - from zenml.integrations.vllm import services @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the vLLM integration. 
Returns: diff --git a/src/zenml/integrations/vllm/flavors/vllm_model_deployer_flavor.py b/src/zenml/integrations/vllm/flavors/vllm_model_deployer_flavor.py index c21c76fcff8..0469ea65543 100644 --- a/src/zenml/integrations/vllm/flavors/vllm_model_deployer_flavor.py +++ b/src/zenml/integrations/vllm/flavors/vllm_model_deployer_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """vLLM model deployer flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -50,7 +50,7 @@ def name(self) -> str: return VLLM_MODEL_DEPLOYER @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -59,7 +59,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -77,7 +77,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/vllm.png" @property - def config_class(self) -> Type[VLLMModelDeployerConfig]: + def config_class(self) -> type[VLLMModelDeployerConfig]: """Returns `VLLMModelDeployerConfig` config class. Returns: @@ -86,7 +86,7 @@ def config_class(self) -> Type[VLLMModelDeployerConfig]: return VLLMModelDeployerConfig @property - def implementation_class(self) -> Type["VLLMModelDeployer"]: + def implementation_class(self) -> type["VLLMModelDeployer"]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/integrations/vllm/model_deployers/vllm_model_deployer.py b/src/zenml/integrations/vllm/model_deployers/vllm_model_deployer.py index 708949d7878..72ef7935d3d 100644 --- a/src/zenml/integrations/vllm/model_deployers/vllm_model_deployer.py +++ b/src/zenml/integrations/vllm/model_deployers/vllm_model_deployer.py @@ -15,7 +15,7 @@ import os import shutil -from typing import ClassVar, Dict, Optional, Type, cast +from typing import ClassVar, cast from uuid import UUID from zenml.config.global_config import GlobalConfiguration @@ -40,9 +40,9 @@ class VLLMModelDeployer(BaseModelDeployer): """vLLM Inference Server.""" NAME: ClassVar[str] = "VLLM" - FLAVOR: ClassVar[Type[BaseModelDeployerFlavor]] = VLLMModelDeployerFlavor + FLAVOR: ClassVar[type[BaseModelDeployerFlavor]] = VLLMModelDeployerFlavor - _service_path: Optional[str] = None + _service_path: str | None = None @property def config(self) -> VLLMModelDeployerConfig: @@ -100,7 +100,7 @@ def local_path(self) -> str: @staticmethod def get_model_server_info( # type: ignore[override] service_instance: "VLLMDeploymentService", - ) -> Dict[str, Optional[str]]: + ) -> dict[str, str | None]: """Return implementation specific information on the model server. 
Args: diff --git a/src/zenml/integrations/vllm/services/vllm_deployment.py b/src/zenml/integrations/vllm/services/vllm_deployment.py index 0ff4617f20e..63b41d87fa0 100644 --- a/src/zenml/integrations/vllm/services/vllm_deployment.py +++ b/src/zenml/integrations/vllm/services/vllm_deployment.py @@ -15,7 +15,7 @@ import argparse import os -from typing import Any, List, Optional, Union +from typing import Any from zenml.constants import DEFAULT_LOCAL_SERVICE_IP_ADDRESS from zenml.logger import get_logger @@ -59,7 +59,7 @@ class VLLMDeploymentEndpoint(LocalDaemonServiceEndpoint): monitor: HTTPEndpointHealthMonitor @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """Gets the prediction URL for the endpoint. Returns: @@ -76,20 +76,20 @@ class VLLMServiceConfig(LocalDaemonServiceConfig): model: str port: int - host: Optional[str] = None + host: str | None = None blocking: bool = True # If unspecified, model name or path will be used. - tokenizer: Optional[str] = None - served_model_name: Optional[Union[str, List[str]]] = None + tokenizer: str | None = None + served_model_name: str | list[str] | None = None # Trust remote code from huggingface. - trust_remote_code: Optional[bool] = False + trust_remote_code: bool | None = False # ['auto', 'slow', 'mistral'] - tokenizer_mode: Optional[str] = "auto" + tokenizer_mode: str | None = "auto" # ['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'] - dtype: Optional[str] = "auto" + dtype: str | None = "auto" # The specific model version to use. It can be a branch name, a tag name, or a commit id. # If unspecified, will use the default version. 
- revision: Optional[str] = None + revision: str | None = None class VLLMDeploymentService(LocalDaemonService, BaseDeploymentService): @@ -167,7 +167,7 @@ def run(self) -> None: logger.info("Stopping vLLM prediction service...") @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """Gets the prediction URL for the endpoint. Returns: diff --git a/src/zenml/integrations/wandb/__init__.py b/src/zenml/integrations/wandb/__init__.py index dfed8c6a8fc..610c9946108 100644 --- a/src/zenml/integrations/wandb/__init__.py +++ b/src/zenml/integrations/wandb/__init__.py @@ -16,9 +16,7 @@ The wandb integrations currently enables you to use wandb tracking as a convenient way to visualize your experiment runs within the wandb ui. """ -from typing import List, Type -from zenml.enums import StackComponentType from zenml.integrations.constants import WANDB from zenml.integrations.integration import Integration from zenml.stack import Flavor @@ -34,7 +32,7 @@ class WandbIntegration(Integration): REQUIREMENTS_IGNORED_ON_UNINSTALL = ["Pillow"] @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Weights and Biases integration. 
Returns: diff --git a/src/zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py b/src/zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py index 16899263206..a055bc65295 100644 --- a/src/zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py +++ b/src/zenml/integrations/wandb/experiment_trackers/wandb_experiment_tracker.py @@ -14,7 +14,7 @@ """Implementation for the wandb experiment tracker.""" import os -from typing import TYPE_CHECKING, Dict, List, Optional, Type, cast +from typing import TYPE_CHECKING, cast import wandb @@ -52,7 +52,7 @@ def config(self) -> WandbExperimentTrackerConfig: return cast(WandbExperimentTrackerConfig, self._config) @property - def settings_class(self) -> Type[WandbExperimentTrackerSettings]: + def settings_class(self) -> type[WandbExperimentTrackerSettings]: """Settings class for the Wandb experiment tracker. Returns: @@ -78,7 +78,7 @@ def prepare_step_run(self, info: "StepRunInfo") -> None: def get_step_run_metadata( self, info: "StepRunInfo" - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get component- and step-specific metadata after a step ran. Args: @@ -87,8 +87,8 @@ def get_step_run_metadata( Returns: A dictionary of metadata. """ - run_url: Optional[str] = None - run_name: Optional[str] = None + run_url: str | None = None + run_name: str | None = None # Try to get the run name and URL from WandB directly current_wandb_run = wandb.run @@ -129,7 +129,7 @@ def _initialize_wandb( self, info: "StepRunInfo", run_name: str, - tags: List[str], + tags: list[str], ) -> None: """Initializes a wandb run. 
diff --git a/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py b/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py index 3c14af29572..949d9965a1a 100644 --- a/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py +++ b/src/zenml/integrations/wandb/flavors/wandb_experiment_tracker_flavor.py @@ -16,10 +16,6 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, - Optional, - Type, cast, ) @@ -42,14 +38,14 @@ class WandbExperimentTrackerSettings(BaseSettings): """Settings for the Wandb experiment tracker.""" - run_name: Optional[str] = Field( + run_name: str | None = Field( None, description="The Wandb run name to use for tracking experiments." ) - tags: List[str] = Field( + tags: list[str] = Field( default_factory=list, description="Tags to attach to the Wandb run for categorization and filtering.", ) - settings: Dict[str, Any] = Field( + settings: dict[str, Any] = Field( default_factory=dict, description="Additional settings for the Wandb run configuration.", ) @@ -84,7 +80,7 @@ def _convert_settings(cls, value: Any) -> Any: if isinstance(value, BaseModel): return value.model_dump() # type: ignore[no-untyped-call] elif hasattr(value, "make_static"): - return cast(Dict[str, Any], value.make_static()) + return cast(dict[str, Any], value.make_static()) elif hasattr(value, "to_dict"): return value.to_dict() else: @@ -102,12 +98,12 @@ class WandbExperimentTrackerConfig( description="API key that should be authorized to log to the configured " "Wandb entity and project. Required for authentication." ) - entity: Optional[str] = Field( + entity: str | None = Field( None, description="Name of an existing Wandb entity (team or user account) " "to log experiments to.", ) - project_name: Optional[str] = Field( + project_name: str | None = Field( None, description="Name of an existing Wandb project to log experiments to. 
" "If not specified, a default project will be used.", @@ -127,7 +123,7 @@ def name(self) -> str: return WANDB_EXPERIMENT_TRACKER_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A URL to point at docs explaining this flavor. Returns: @@ -136,7 +132,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A URL to point at SDK docs explaining this flavor. Returns: @@ -154,7 +150,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/experiment_tracker/wandb.png" @property - def config_class(self) -> Type[WandbExperimentTrackerConfig]: + def config_class(self) -> type[WandbExperimentTrackerConfig]: """Returns `WandbExperimentTrackerConfig` config class. Returns: @@ -163,7 +159,7 @@ def config_class(self) -> Type[WandbExperimentTrackerConfig]: return WandbExperimentTrackerConfig @property - def implementation_class(self) -> Type["WandbExperimentTracker"]: + def implementation_class(self) -> type["WandbExperimentTracker"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/whylogs/__init__.py b/src/zenml/integrations/whylogs/__init__.py index dd8b8e590b3..35e21148d4d 100644 --- a/src/zenml/integrations/whylogs/__init__.py +++ b/src/zenml/integrations/whylogs/__init__.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Initialization of the whylogs integration.""" -from typing import List, Type, Optional from zenml.integrations.constants import WHYLOGS from zenml.integrations.integration import Integration @@ -37,7 +36,7 @@ def activate(cls) -> None: from zenml.integrations.whylogs import secret_schemas # noqa @classmethod - def flavors(cls) -> List[Type[Flavor]]: + def flavors(cls) -> list[type[Flavor]]: """Declare the stack component flavors for the Great Expectations integration. 
Returns: @@ -50,8 +49,8 @@ def flavors(cls) -> List[Type[Flavor]]: return [WhylogsDataValidatorFlavor] @classmethod - def get_requirements(cls, target_os: Optional[str] = None, python_version: Optional[str] = None - ) -> List[str]: + def get_requirements(cls, target_os: str | None = None, python_version: str | None = None + ) -> list[str]: """Method to get the requirements for the integration. Args: diff --git a/src/zenml/integrations/whylogs/data_validators/whylogs_data_validator.py b/src/zenml/integrations/whylogs/data_validators/whylogs_data_validator.py index d35828ab1be..0872b096e98 100644 --- a/src/zenml/integrations/whylogs/data_validators/whylogs_data_validator.py +++ b/src/zenml/integrations/whylogs/data_validators/whylogs_data_validator.py @@ -14,7 +14,8 @@ """Implementation of the whylogs data validator.""" import datetime -from typing import Any, ClassVar, Optional, Sequence, Type, cast +from typing import Any, ClassVar, cast +from collections.abc import Sequence import pandas as pd import whylogs as why # type: ignore @@ -50,7 +51,7 @@ class WhylogsDataValidator(BaseDataValidator, AuthenticationMixin): """ NAME: ClassVar[str] = "whylogs" - FLAVOR: ClassVar[Type[BaseDataValidatorFlavor]] = ( + FLAVOR: ClassVar[type[BaseDataValidatorFlavor]] = ( WhylogsDataValidatorFlavor ) @@ -64,7 +65,7 @@ def config(self) -> WhylogsDataValidatorConfig: return cast(WhylogsDataValidatorConfig, self._config) @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Whylogs data validator. 
Returns: @@ -75,9 +76,9 @@ def settings_class(self) -> Optional[Type["BaseSettings"]]: def data_profiling( self, dataset: pd.DataFrame, - comparison_dataset: Optional[pd.DataFrame] = None, - profile_list: Optional[Sequence[str]] = None, - dataset_timestamp: Optional[datetime.datetime] = None, + comparison_dataset: pd.DataFrame | None = None, + profile_list: Sequence[str] | None = None, + dataset_timestamp: datetime.datetime | None = None, **kwargs: Any, ) -> DatasetProfileView: """Analyze a dataset and generate a data profile with whylogs. @@ -105,7 +106,7 @@ def data_profiling( def upload_profile_view( self, profile_view: DatasetProfileView, - dataset_id: Optional[str] = None, + dataset_id: str | None = None, ) -> None: """Upload a whylogs data profile view to Whylabs, if configured to do so. diff --git a/src/zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py b/src/zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py index 1696f7779c6..f0ecd86ceb4 100644 --- a/src/zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py +++ b/src/zenml/integrations/whylogs/flavors/whylogs_data_validator_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """WhyLabs whylogs data validator flavor.""" -from typing import TYPE_CHECKING, Optional, Type +from typing import TYPE_CHECKING from pydantic import Field @@ -38,7 +38,7 @@ class WhylogsDataValidatorSettings(BaseSettings): "profile views returned by the step will automatically be uploaded " "to the Whylabs platform if Whylabs credentials are configured.", ) - dataset_id: Optional[str] = Field( + dataset_id: str | None = Field( None, description="Dataset ID to use when uploading profiles to Whylabs.", ) @@ -65,7 +65,7 @@ def name(self) -> str: return WHYLOGS_DATA_VALIDATOR_FLAVOR @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. 
Returns: @@ -74,7 +74,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -92,7 +92,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/data_validator/whylogs.png" @property - def config_class(self) -> Type[WhylogsDataValidatorConfig]: + def config_class(self) -> type[WhylogsDataValidatorConfig]: """Returns `WhylogsDataValidatorConfig` config class. Returns: @@ -101,7 +101,7 @@ def config_class(self) -> Type[WhylogsDataValidatorConfig]: return WhylogsDataValidatorConfig @property - def implementation_class(self) -> Type["WhylogsDataValidator"]: + def implementation_class(self) -> type["WhylogsDataValidator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/integrations/whylogs/materializers/whylogs_materializer.py b/src/zenml/integrations/whylogs/materializers/whylogs_materializer.py index 2da1ed8c649..42761b8a253 100644 --- a/src/zenml/integrations/whylogs/materializers/whylogs_materializer.py +++ b/src/zenml/integrations/whylogs/materializers/whylogs_materializer.py @@ -14,7 +14,7 @@ """Implementation of the whylogs materializer.""" import os -from typing import Any, ClassVar, Dict, Tuple, Type, cast +from typing import Any, ClassVar, cast from whylogs.core import DatasetProfileView # type: ignore from whylogs.viz import NotebookProfileVisualizer # type: ignore @@ -34,12 +34,12 @@ class WhylogsMaterializer(BaseMaterializer): """Materializer to read/write whylogs dataset profile views.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (DatasetProfileView,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (DatasetProfileView,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ( ArtifactType.DATA_ANALYSIS ) - def load(self, data_type: Type[Any]) -> DatasetProfileView: + def load(self, data_type: 
type[Any]) -> DatasetProfileView: """Reads and returns a whylogs dataset profile view. Args: @@ -85,7 +85,7 @@ def save(self, profile_view: DatasetProfileView) -> None: def save_visualizations( self, profile_view: DatasetProfileView, - ) -> Dict[str, VisualizationType]: + ) -> dict[str, VisualizationType]: """Saves visualizations for the given whylogs dataset profile view. Args: diff --git a/src/zenml/integrations/whylogs/secret_schemas/whylabs_secret_schema.py b/src/zenml/integrations/whylogs/secret_schemas/whylabs_secret_schema.py index 12c1c8550a1..d4a1713c657 100644 --- a/src/zenml/integrations/whylogs/secret_schemas/whylabs_secret_schema.py +++ b/src/zenml/integrations/whylogs/secret_schemas/whylabs_secret_schema.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Implementation for Seldon secret schemas.""" -from typing import Optional from zenml.secret.base_secret import BaseSecretSchema @@ -30,4 +29,4 @@ class WhylabsSecretSchema(BaseSecretSchema): whylabs_default_org_id: str whylabs_api_key: str - whylabs_default_dataset_id: Optional[str] = None + whylabs_default_dataset_id: str | None = None diff --git a/src/zenml/integrations/whylogs/steps/__init__.py b/src/zenml/integrations/whylogs/steps/__init__.py index 18c7e13f8cf..2b54b72f223 100644 --- a/src/zenml/integrations/whylogs/steps/__init__.py +++ b/src/zenml/integrations/whylogs/steps/__init__.py @@ -13,7 +13,3 @@ # permissions and limitations under the License. 
"""Initialization of the whylogs steps.""" -from zenml.integrations.whylogs.steps.whylogs_profiler import ( - whylogs_profiler_step, - get_whylogs_profiler_step, -) diff --git a/src/zenml/integrations/whylogs/steps/whylogs_profiler.py b/src/zenml/integrations/whylogs/steps/whylogs_profiler.py index cad37bd91e1..4810a416251 100644 --- a/src/zenml/integrations/whylogs/steps/whylogs_profiler.py +++ b/src/zenml/integrations/whylogs/steps/whylogs_profiler.py @@ -14,7 +14,7 @@ """Implementation of the whylogs profiler step.""" import datetime -from typing import Optional, cast +from typing import cast import pandas as pd from whylogs.core import DatasetProfileView # type: ignore @@ -34,7 +34,7 @@ @step def whylogs_profiler_step( dataset: pd.DataFrame, - dataset_timestamp: Optional[datetime.datetime] = None, + dataset_timestamp: datetime.datetime | None = None, ) -> DatasetProfileView: """Generate a whylogs `DatasetProfileView` from a given `pd.DataFrame`. @@ -55,8 +55,8 @@ def whylogs_profiler_step( def get_whylogs_profiler_step( - dataset_timestamp: Optional[datetime.datetime] = None, - dataset_id: Optional[str] = None, + dataset_timestamp: datetime.datetime | None = None, + dataset_id: str | None = None, enable_whylabs: bool = True, ) -> BaseStep: """Shortcut function to create a new instance of the WhylogsProfilerStep step. 
diff --git a/src/zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py b/src/zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py index 3c36bea0cce..d85481ec201 100644 --- a/src/zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py +++ b/src/zenml/integrations/xgboost/materializers/xgboost_booster_materializer.py @@ -14,7 +14,7 @@ """Implementation of an XGBoost booster materializer.""" import os -from typing import Any, ClassVar, Tuple, Type +from typing import Any, ClassVar import xgboost as xgb @@ -28,10 +28,10 @@ class XgboostBoosterMaterializer(BaseMaterializer): """Materializer to read data to and from xgboost.Booster.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (xgb.Booster,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (xgb.Booster,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.MODEL - def load(self, data_type: Type[Any]) -> xgb.Booster: + def load(self, data_type: type[Any]) -> xgb.Booster: """Reads a xgboost Booster model from a serialized JSON file. 
Args: diff --git a/src/zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py b/src/zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py index 6ee8d687211..d2d2eb423c2 100644 --- a/src/zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py +++ b/src/zenml/integrations/xgboost/materializers/xgboost_dmatrix_materializer.py @@ -14,7 +14,7 @@ """Implementation of the XGBoost dmatrix materializer.""" import os -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar import xgboost as xgb @@ -31,10 +31,10 @@ class XgboostDMatrixMaterializer(BaseMaterializer): """Materializer to read data to and from xgboost.DMatrix.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (xgb.DMatrix,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (xgb.DMatrix,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - def load(self, data_type: Type[Any]) -> xgb.DMatrix: + def load(self, data_type: type[Any]) -> xgb.DMatrix: """Reads a xgboost.DMatrix binary file and loads it. Args: @@ -69,7 +69,7 @@ def save(self, matrix: xgb.DMatrix) -> None: def extract_metadata( self, dataset: xgb.DMatrix - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Extract metadata from the given `Dataset` object. 
Args: diff --git a/src/zenml/io/fileio.py b/src/zenml/io/fileio.py index b639c9df8ce..b60192e03c8 100644 --- a/src/zenml/io/fileio.py +++ b/src/zenml/io/fileio.py @@ -14,7 +14,8 @@ """Functionality for reading, writing and managing files.""" import os -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type +from typing import Any +from collections.abc import Callable, Iterable # this import required for CI to get local filesystem from zenml.io import local_filesystem # noqa @@ -25,7 +26,7 @@ logger = get_logger(__name__) -def _get_filesystem(path: "PathType") -> Type["BaseFilesystem"]: +def _get_filesystem(path: "PathType") -> type["BaseFilesystem"]: """Returns a filesystem class for a given path from the registry. Args: @@ -106,7 +107,7 @@ def exists(path: "PathType") -> bool: return _get_filesystem(path).exists(path) -def glob(pattern: "PathType") -> List["PathType"]: +def glob(pattern: "PathType") -> list["PathType"]: """Find all files matching the given pattern. Args: @@ -130,7 +131,7 @@ def isdir(path: "PathType") -> bool: return _get_filesystem(path).isdir(path) -def listdir(path: str, only_file_names: bool = True) -> List[str]: +def listdir(path: str, only_file_names: bool = True) -> list[str]: """Lists all files in a directory. Args: @@ -147,7 +148,7 @@ def listdir(path: str, only_file_names: bool = True) -> List[str]: else convert_to_str(f) for f in _get_filesystem(path).listdir(path) ] - except IOError: + except OSError: logger.debug(f"Dir {path} not found.") return [] @@ -236,7 +237,7 @@ def stat(path: "PathType") -> Any: return _get_filesystem(path).stat(path) -def size(path: "PathType") -> Optional[int]: +def size(path: "PathType") -> int | None: """Get the size of a file or directory in bytes. 
Args: @@ -277,8 +278,8 @@ def size(path: "PathType") -> Optional[int]: def walk( top: "PathType", topdown: bool = True, - onerror: Optional[Callable[..., None]] = None, -) -> Iterable[Tuple["PathType", List["PathType"], List["PathType"]]]: + onerror: Callable[..., None] | None = None, +) -> Iterable[tuple["PathType", list["PathType"], list["PathType"]]]: """Return an iterator that walks the contents of the given directory. Args: diff --git a/src/zenml/io/filesystem.py b/src/zenml/io/filesystem.py index 744bcc48387..3e6648ef94e 100644 --- a/src/zenml/io/filesystem.py +++ b/src/zenml/io/filesystem.py @@ -30,15 +30,10 @@ from abc import ABC, abstractmethod from typing import ( Any, - Callable, ClassVar, - Iterable, - List, - Optional, - Set, - Tuple, Union, ) +from collections.abc import Callable, Iterable PathType = Union[bytes, str] @@ -50,7 +45,7 @@ class BaseFilesystem(ABC): https://github.com/tensorflow/tfx/blob/master/tfx/dsl/io/filesystem.py """ - SUPPORTED_SCHEMES: ClassVar[Set[str]] + SUPPORTED_SCHEMES: ClassVar[set[str]] @staticmethod @abstractmethod @@ -96,7 +91,7 @@ def exists(path: PathType) -> bool: @staticmethod @abstractmethod - def glob(pattern: PathType) -> List[PathType]: + def glob(pattern: PathType) -> list[PathType]: """Find all files matching the given pattern. Args: @@ -120,7 +115,7 @@ def isdir(path: PathType) -> bool: @staticmethod @abstractmethod - def listdir(path: PathType) -> List[PathType]: + def listdir(path: PathType) -> list[PathType]: """Lists all files in a directory. Args: @@ -218,8 +213,8 @@ def size(path: PathType) -> int: def walk( top: PathType, topdown: bool = True, - onerror: Optional[Callable[..., None]] = None, - ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]: + onerror: Callable[..., None] | None = None, + ) -> Iterable[tuple[PathType, list[PathType], list[PathType]]]: """Return an iterator that walks the contents of the given directory. 
Args: diff --git a/src/zenml/io/filesystem_registry.py b/src/zenml/io/filesystem_registry.py index 33194ac2e24..1af49e1cda6 100644 --- a/src/zenml/io/filesystem_registry.py +++ b/src/zenml/io/filesystem_registry.py @@ -29,7 +29,7 @@ import re from threading import Lock -from typing import TYPE_CHECKING, Dict, Type +from typing import TYPE_CHECKING from zenml.logger import get_logger @@ -45,10 +45,10 @@ class FileIORegistry: def __init__(self) -> None: """Initialize the registry.""" - self._filesystems: Dict["PathType", Type["BaseFilesystem"]] = {} + self._filesystems: dict["PathType", type["BaseFilesystem"]] = {} self._registration_lock = Lock() - def register(self, filesystem_cls: Type["BaseFilesystem"]) -> None: + def register(self, filesystem_cls: type["BaseFilesystem"]) -> None: """Register a filesystem implementation. Args: @@ -70,7 +70,7 @@ def register(self, filesystem_cls: Type["BaseFilesystem"]) -> None: def get_filesystem_for_scheme( self, scheme: "PathType" - ) -> Type["BaseFilesystem"]: + ) -> type["BaseFilesystem"]: """Get filesystem plugin for given scheme string. Args: @@ -96,7 +96,7 @@ def get_filesystem_for_scheme( def get_filesystem_for_path( self, path: "PathType" - ) -> Type["BaseFilesystem"]: + ) -> type["BaseFilesystem"]: """Get filesystem plugin for given path. 
Args: diff --git a/src/zenml/io/local_filesystem.py b/src/zenml/io/local_filesystem.py index efc428e429d..696788da18e 100644 --- a/src/zenml/io/local_filesystem.py +++ b/src/zenml/io/local_filesystem.py @@ -32,14 +32,9 @@ import shutil from typing import ( Any, - Callable, ClassVar, - Iterable, - List, - Optional, - Set, - Tuple, ) +from collections.abc import Callable, Iterable from zenml.io.filesystem import BaseFilesystem, PathType from zenml.io.filesystem_registry import default_filesystem_registry @@ -52,7 +47,7 @@ class LocalFilesystem(BaseFilesystem): https://github.com/tensorflow/tfx/blob/master/tfx/dsl/io/plugins/local.py """ - SUPPORTED_SCHEMES: ClassVar[Set[str]] = {""} + SUPPORTED_SCHEMES: ClassVar[set[str]] = {""} @staticmethod def open(path: PathType, mode: str = "r") -> Any: @@ -103,7 +98,7 @@ def exists(path: PathType) -> bool: return os.path.exists(path) @staticmethod - def glob(pattern: PathType) -> List[PathType]: + def glob(pattern: PathType) -> list[PathType]: """Return the paths that match a glob pattern. Args: @@ -127,7 +122,7 @@ def isdir(path: PathType) -> bool: return os.path.isdir(path) @staticmethod - def listdir(path: PathType) -> List[PathType]: + def listdir(path: PathType) -> list[PathType]: """Returns a list of files under a given directory in the filesystem. Args: @@ -223,8 +218,8 @@ def size(path: PathType) -> int: def walk( top: PathType, topdown: bool = True, - onerror: Optional[Callable[..., None]] = None, - ) -> Iterable[Tuple[PathType, List[PathType], List[PathType]]]: + onerror: Callable[..., None] | None = None, + ) -> Iterable[tuple[PathType, list[PathType], list[PathType]]]: """Return an iterator that walks the contents of the given directory. 
Args: diff --git a/src/zenml/logger.py b/src/zenml/logger.py index 24bdd1ef7c8..9f3d53eb6f5 100644 --- a/src/zenml/logger.py +++ b/src/zenml/logger.py @@ -19,7 +19,7 @@ import re import sys from contextvars import ContextVar -from typing import TYPE_CHECKING, Any, Dict +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from zenml.logging.step_logging import ArtifactStoreHandler @@ -112,7 +112,7 @@ def _get_format_template(self, record: logging.LogRecord) -> str: else: return "%(message)s" - COLORS: Dict[LoggingLevels, str] = { + COLORS: dict[LoggingLevels, str] = { LoggingLevels.DEBUG: grey, LoggingLevels.INFO: white, LoggingLevels.WARN: yellow, diff --git a/src/zenml/logging/step_logging.py b/src/zenml/logging/step_logging.py index 52c27aeafec..77db26e8b80 100644 --- a/src/zenml/logging/step_logging.py +++ b/src/zenml/logging/step_logging.py @@ -26,12 +26,8 @@ from types import TracebackType from typing import ( Any, - Iterator, - List, - Optional, - Type, - Union, ) +from collections.abc import Iterator from uuid import UUID, uuid4 from pydantic import BaseModel, Field @@ -86,26 +82,26 @@ class LogEntry(BaseModel): """A structured log entry with parsed information.""" message: str = Field(description="The log message content") - name: Optional[str] = Field( + name: str | None = Field( default=None, description="The name of the logger", ) - level: Optional[LoggingLevels] = Field( + level: LoggingLevels | None = Field( default=None, description="The log level", ) - timestamp: Optional[datetime] = Field( + timestamp: datetime | None = Field( default=None, description="When the log was created", ) - module: Optional[str] = Field( + module: str | None = Field( default=None, description="The module that generated this log entry" ) - filename: Optional[str] = Field( + filename: str | None = Field( default=None, description="The name of the file that generated this log entry", ) - lineno: Optional[int] = Field( + lineno: int | None = Field( default=None, 
description="The fileno that generated this log entry" ) chunk_index: int = Field( @@ -189,7 +185,7 @@ def emit(self, record: logging.LogRecord) -> None: except Exception: pass - def _split_to_chunks(self, message: str) -> List[str]: + def _split_to_chunks(self, message: str) -> list[str]: """Split a large message into chunks. Args: @@ -243,7 +239,7 @@ def remove_ansi_escape_codes(text: str) -> str: return ansi_escape.sub("", text) -def parse_log_entry(log_line: str) -> Optional[LogEntry]: +def parse_log_entry(log_line: str) -> LogEntry | None: """Parse a single log entry into a LogEntry object. Handles two formats: @@ -288,8 +284,8 @@ def parse_log_entry(log_line: str) -> Optional[LogEntry]: def prepare_logs_uri( artifact_store: "BaseArtifactStore", - step_name: Optional[str] = None, - log_key: Optional[str] = None, + step_name: str | None = None, + log_key: str | None = None, ) -> str: """Generates and prepares a URI for the log file or folder for a step. @@ -334,9 +330,9 @@ def prepare_logs_uri( def fetch_log_records( zen_store: "BaseZenStore", - artifact_store_id: Union[str, UUID], + artifact_store_id: str | UUID, logs_uri: str, -) -> List[LogEntry]: +) -> list[LogEntry]: """Fetches log entries. Args: @@ -363,7 +359,7 @@ def fetch_log_records( def _stream_logs_line_by_line( zen_store: "BaseZenStore", - artifact_store_id: Union[str, UUID], + artifact_store_id: str | UUID, logs_uri: str, ) -> Iterator[str]: """Stream logs line by line without loading the entire file into memory. 
@@ -452,7 +448,7 @@ def __init__( # Queue and log storage thread for async processing self.log_queue: queue.Queue[str] = queue.Queue(maxsize=max_queue_size) - self.log_storage_thread: Optional[threading.Thread] = None + self.log_storage_thread: threading.Thread | None = None self.shutdown_event = threading.Event() self.merge_event = threading.Event() @@ -632,7 +628,7 @@ def _get_timestamped_filename(self, suffix: str = "") -> str: """ return f"{time.time()}{suffix}{LOGS_EXTENSION}" - def write_buffer(self, buffer_to_write: List[str]) -> None: + def write_buffer(self, buffer_to_write: list[str]) -> None: """Write the given buffer to file. This runs in the log storage thread. Args: @@ -757,8 +753,8 @@ def __init__( # Additional configuration self.prepend_step_name = prepend_step_name - self.original_step_names_in_console: Optional[bool] = None - self._original_root_level: Optional[int] = None + self.original_step_names_in_console: bool | None = None + self._original_root_level: int | None = None def __enter__(self) -> "PipelineLogsStorageContext": """Enter condition of the context manager. @@ -805,9 +801,9 @@ def __enter__(self) -> "PipelineLogsStorageContext": def __exit__( self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, ) -> None: """Exit condition of the context manager. @@ -860,7 +856,7 @@ def __exit__( def setup_orchestrator_logging( run_id: UUID, snapshot: "PipelineSnapshotResponse", - logs_response: Optional[LogsResponse] = None, + logs_response: LogsResponse | None = None, ) -> Any: """Set up logging for an orchestrator environment. 
diff --git a/src/zenml/login/credentials.py b/src/zenml/login/credentials.py index da203a14639..980a0aefd35 100644 --- a/src/zenml/login/credentials.py +++ b/src/zenml/login/credentials.py @@ -14,7 +14,7 @@ """ZenML login credentials models.""" from datetime import datetime, timedelta -from typing import Any, Dict, Optional, Union +from typing import Any from urllib.parse import urlparse from uuid import UUID @@ -43,14 +43,14 @@ class APIToken(BaseModel): """Cached API Token.""" access_token: str - expires_in: Optional[int] = None - expires_at: Optional[datetime] = None - leeway: Optional[int] = None - device_id: Optional[UUID] = None - device_metadata: Optional[Dict[str, Any]] = None + expires_in: int | None = None + expires_at: datetime | None = None + leeway: int | None = None + device_id: UUID | None = None + device_metadata: dict[str, Any] | None = None @property - def expires_at_with_leeway(self) -> Optional[datetime]: + def expires_at_with_leeway(self) -> datetime | None: """Get the token expiration time with leeway. 
Returns: @@ -84,25 +84,25 @@ class ServerCredentials(BaseModel): """Cached Server Credentials.""" url: str - api_key: Optional[str] = None - api_token: Optional[APIToken] = None - username: Optional[str] = None - password: Optional[str] = None + api_key: str | None = None + api_token: APIToken | None = None + username: str | None = None + password: str | None = None # Extra server attributes - deployment_type: Optional[ServerDeploymentType] = None - server_id: Optional[UUID] = None - server_name: Optional[str] = None - status: Optional[str] = None - version: Optional[str] = None + deployment_type: ServerDeploymentType | None = None + server_id: UUID | None = None + server_name: str | None = None + status: str | None = None + version: str | None = None # Pro server attributes - organization_name: Optional[str] = None - organization_id: Optional[UUID] = None - workspace_name: Optional[str] = None - workspace_id: Optional[UUID] = None - pro_api_url: Optional[str] = None - pro_dashboard_url: Optional[str] = None + organization_name: str | None = None + organization_id: UUID | None = None + workspace_name: str | None = None + workspace_id: UUID | None = None + pro_api_url: str | None = None + pro_dashboard_url: str | None = None @property def id(self) -> str: @@ -166,7 +166,7 @@ def has_valid_token(self) -> bool: return self.api_token is not None and not self.api_token.expired def update_server_info( - self, server_info: Union[ServerModel, WorkspaceRead] + self, server_info: ServerModel | WorkspaceRead ) -> None: """Update with server information received from the server itself or from a ZenML Pro workspace descriptor. 
diff --git a/src/zenml/login/credentials_store.py b/src/zenml/login/credentials_store.py index 218ff2b230a..d1145e13240 100644 --- a/src/zenml/login/credentials_store.py +++ b/src/zenml/login/credentials_store.py @@ -15,7 +15,7 @@ import os from datetime import timedelta -from typing import Dict, List, Optional, Tuple, Union, cast +from typing import Optional, cast from zenml.config.global_config import GlobalConfiguration from zenml.constants import ( @@ -89,8 +89,8 @@ class CredentialsStore(metaclass=SingletonMetaClass): """ - credentials: Dict[str, ServerCredentials] - last_modified_time: Optional[float] = None + credentials: dict[str, ServerCredentials] + last_modified_time: float | None = None def __init__(self) -> None: """Initializes the login credentials store with values loaded from the credentials YAML file. @@ -232,7 +232,7 @@ def check_and_reload_from_file(self) -> None: def get_password( self, server_url: str - ) -> Tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """Retrieve the username and password from the credentials store for a specific server URL. Args: @@ -248,7 +248,7 @@ def get_password( return credential.username, credential.password return None, None - def get_api_key(self, server_url: str) -> Optional[str]: + def get_api_key(self, server_url: str) -> str | None: """Retrieve an API key from the credentials store for a specific server URL. Args: @@ -265,7 +265,7 @@ def get_api_key(self, server_url: str) -> Optional[str]: def get_token( self, server_url: str, allow_expired: bool = False - ) -> Optional[APIToken]: + ) -> APIToken | None: """Retrieve a valid token from the credentials store for a specific server URL. Args: @@ -285,7 +285,7 @@ def get_token( return token return None - def get_credentials(self, server_url: str) -> Optional[ServerCredentials]: + def get_credentials(self, server_url: str) -> ServerCredentials | None: """Retrieve the credentials for a specific server URL. 
Args: @@ -314,7 +314,7 @@ def has_valid_credentials(self, server_url: str) -> bool: return True return False - def get_pro_api_key(self, pro_api_url: str) -> Optional[str]: + def get_pro_api_key(self, pro_api_url: str) -> str | None: """Retrieve an API key from the credentials store for a ZenML Pro API server. Args: @@ -330,7 +330,7 @@ def get_pro_api_key(self, pro_api_url: str) -> Optional[str]: def get_pro_token( self, pro_api_url: str, allow_expired: bool = False - ) -> Optional[APIToken]: + ) -> APIToken | None: """Retrieve a valid token from the credentials store for a ZenML Pro API server. Args: @@ -351,7 +351,7 @@ def get_pro_token( def get_pro_credentials( self, pro_api_url: str - ) -> Optional[ServerCredentials]: + ) -> ServerCredentials | None: """Retrieve valid credentials from the credentials store for a ZenML Pro API server. Args: @@ -587,7 +587,7 @@ def set_bare_token( def update_server_info( self, server_url: str, - server_info: Union[ServerModel, WorkspaceRead], + server_info: ServerModel | WorkspaceRead, ) -> None: """Update the server information stored for a specific server URL. @@ -641,8 +641,8 @@ def clear_credentials(self, server_url: str) -> None: self._save_credentials() def list_credentials( - self, type: Optional[ServerType] = None - ) -> List[ServerCredentials]: + self, type: ServerType | None = None + ) -> list[ServerCredentials]: """Get all credentials stored in the credentials store. 
Args: diff --git a/src/zenml/login/pro/client.py b/src/zenml/login/pro/client.py index d5562cc89c6..bec0bf29f54 100644 --- a/src/zenml/login/pro/client.py +++ b/src/zenml/login/pro/client.py @@ -16,10 +16,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Type, TypeVar, Union, ) @@ -45,7 +42,7 @@ from zenml.login.pro.workspace.client import WorkspaceClient # type alias for possible json payloads (the Anys are recursive Json instances) -Json = Union[Dict[str, Any], List[Any], str, int, float, bool, None] +Json = Union[dict[str, Any], list[Any], str, int, float, bool, None] AnyResponse = TypeVar("AnyResponse", bound=BaseRestAPIModel) @@ -55,8 +52,8 @@ class ZenMLProClient(metaclass=SingletonMetaClass): """ZenML Pro client.""" _url: str - _api_token: Optional[APIToken] = None - _session: Optional[requests.Session] = None + _api_token: APIToken | None = None + _session: requests.Session | None = None _workspace: Optional["WorkspaceClient"] = None _organization: Optional["OrganizationClient"] = None @@ -233,7 +230,7 @@ def _request( self, method: str, url: str, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, **kwargs: Any, ) -> Json: """Make a request to the REST API. @@ -265,7 +262,7 @@ def _request( def get( self, path: str, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, **kwargs: Any, ) -> Json: """Make a GET request to the given endpoint path. @@ -289,7 +286,7 @@ def get( def delete( self, path: str, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, **kwargs: Any, ) -> Json: """Make a DELETE request to the given endpoint path. @@ -314,7 +311,7 @@ def post( self, path: str, body: BaseRestAPIModel, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, **kwargs: Any, ) -> Json: """Make a POST request to the given endpoint path. 
@@ -340,8 +337,8 @@ def post( def put( self, path: str, - body: Optional[BaseRestAPIModel] = None, - params: Optional[Dict[str, Any]] = None, + body: BaseRestAPIModel | None = None, + params: dict[str, Any] | None = None, **kwargs: Any, ) -> Json: """Make a PUT request to the given endpoint path. @@ -370,8 +367,8 @@ def put( def patch( self, path: str, - body: Optional[BaseRestAPIModel] = None, - params: Optional[Dict[str, Any]] = None, + body: BaseRestAPIModel | None = None, + params: dict[str, Any] | None = None, **kwargs: Any, ) -> Json: """Make a PATCH request to the given endpoint path. @@ -400,9 +397,9 @@ def patch( def _create_resource( self, resource: BaseRestAPIModel, - response_model: Type[AnyResponse], + response_model: type[AnyResponse], route: str, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> AnyResponse: """Create a new resource. @@ -422,9 +419,9 @@ def _create_resource( def _get_resource( self, - resource_id: Union[str, int, UUID], + resource_id: str | int | UUID, route: str, - response_model: Type[AnyResponse], + response_model: type[AnyResponse], **params: Any, ) -> AnyResponse: """Retrieve a single resource. @@ -446,9 +443,9 @@ def _get_resource( def _list_resources( self, route: str, - response_model: Type[AnyResponse], + response_model: type[AnyResponse], **params: Any, - ) -> List[AnyResponse]: + ) -> list[AnyResponse]: """Retrieve a list of resources filtered by some criteria. 
Args: @@ -473,9 +470,9 @@ def _list_resources( def _update_resource( self, - resource_id: Union[str, int, UUID], + resource_id: str | int | UUID, resource_update: BaseRestAPIModel, - response_model: Type[AnyResponse], + response_model: type[AnyResponse], route: str, **params: Any, ) -> AnyResponse: @@ -501,7 +498,7 @@ def _update_resource( return response_model.model_validate(response_body) def _delete_resource( - self, resource_id: Union[str, UUID], route: str + self, resource_id: str | UUID, route: str ) -> None: """Delete a resource. diff --git a/src/zenml/login/pro/organization/client.py b/src/zenml/login/pro/organization/client.py index 6595114232e..e4e5daa81ec 100644 --- a/src/zenml/login/pro/organization/client.py +++ b/src/zenml/login/pro/organization/client.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """ZenML Pro organization client.""" -from typing import List, Union from uuid import UUID from zenml.logger import get_logger @@ -41,7 +40,7 @@ def __init__( def get( self, - id_or_name: Union[UUID, str], + id_or_name: UUID | str, ) -> OrganizationRead: """Get an organization by id or name. @@ -61,7 +60,7 @@ async def list( self, offset: int = 0, limit: int = 20, - ) -> List[OrganizationRead]: + ) -> list[OrganizationRead]: """List organizations. 
Args: diff --git a/src/zenml/login/pro/organization/models.py b/src/zenml/login/pro/organization/models.py index 9a03e7002a5..7e4b3e8f034 100644 --- a/src/zenml/login/pro/organization/models.py +++ b/src/zenml/login/pro/organization/models.py @@ -14,7 +14,6 @@ """ZenML Pro organization models.""" from datetime import datetime -from typing import Optional from uuid import UUID from zenml.login.pro.models import BaseRestAPIModel @@ -26,7 +25,7 @@ class OrganizationRead(BaseRestAPIModel): id: UUID name: str - description: Optional[str] = None + description: str | None = None created: datetime updated: datetime diff --git a/src/zenml/login/pro/workspace/client.py b/src/zenml/login/pro/workspace/client.py index cd1227b7c87..e30d5a3348e 100644 --- a/src/zenml/login/pro/workspace/client.py +++ b/src/zenml/login/pro/workspace/client.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """ZenML Pro workspace client.""" -from typing import List, Optional from uuid import UUID from zenml.logger import get_logger @@ -58,12 +57,12 @@ def list( self, offset: int = 0, limit: int = 20, - workspace_name: Optional[str] = None, - url: Optional[str] = None, - organization_id: Optional[UUID] = None, - status: Optional[WorkspaceStatus] = None, + workspace_name: str | None = None, + url: str | None = None, + organization_id: UUID | None = None, + status: WorkspaceStatus | None = None, member_only: bool = False, - ) -> List[WorkspaceRead]: + ) -> list[WorkspaceRead]: """List workspaces. Args: diff --git a/src/zenml/login/pro/workspace/models.py b/src/zenml/login/pro/workspace/models.py index 6b7e67a6196..69e1e881f24 100644 --- a/src/zenml/login/pro/workspace/models.py +++ b/src/zenml/login/pro/workspace/models.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""ZenML Pro workspace models.""" -from typing import Optional from uuid import UUID from pydantic import Field @@ -67,7 +66,7 @@ class ZenMLServiceStatus(BaseRestAPIModel): server_url: str = Field( description="The ZenML server URL.", ) - version: Optional[str] = Field( + version: str | None = Field( default=None, description="The ZenML server version.", ) @@ -76,11 +75,11 @@ class ZenMLServiceStatus(BaseRestAPIModel): class ZenMLServiceRead(BaseRestAPIModel): """Pydantic Model for viewing a ZenML service.""" - configuration: Optional[ZenMLServiceConfiguration] = Field( + configuration: ZenMLServiceConfiguration | None = Field( description="The service configuration." ) - status: Optional[ZenMLServiceStatus] = Field( + status: ZenMLServiceStatus | None = Field( default=None, description="Information about the service status. Only set if the " "service is deployed and active.", @@ -93,7 +92,7 @@ class WorkspaceRead(BaseRestAPIModel): id: UUID name: str - description: Optional[str] = Field( + description: str | None = Field( default=None, description="The description of the workspace." ) @@ -129,7 +128,7 @@ def organization_name(self) -> str: return self.organization.name @property - def version(self) -> Optional[str]: + def version(self) -> str | None: """Get the ZenML service version. Returns: @@ -144,7 +143,7 @@ def version(self) -> Optional[str]: return version @property - def url(self) -> Optional[str]: + def url(self) -> str | None: """Get the ZenML server URL. Returns: diff --git a/src/zenml/login/server_info.py b/src/zenml/login/server_info.py index ba4a240e0cd..4a823a196ab 100644 --- a/src/zenml/login/server_info.py +++ b/src/zenml/login/server_info.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""ZenML server information retrieval.""" -from typing import Optional from zenml.logger import get_logger from zenml.models import ServerModel @@ -25,7 +24,7 @@ logger = get_logger(__name__) -def get_server_info(url: str) -> Optional[ServerModel]: +def get_server_info(url: str) -> ServerModel | None: """Retrieve server information from a remote ZenML server. Args: diff --git a/src/zenml/login/web_login.py b/src/zenml/login/web_login.py index 932530908ed..3acffa398f0 100644 --- a/src/zenml/login/web_login.py +++ b/src/zenml/login/web_login.py @@ -16,7 +16,6 @@ import platform import time import webbrowser -from typing import Optional, Union import requests @@ -40,9 +39,9 @@ def web_login( - url: Optional[str] = None, - verify_ssl: Optional[Union[str, bool]] = None, - pro_api_url: Optional[str] = None, + url: str | None = None, + verify_ssl: str | bool | None = None, + pro_api_url: str | None = None, ) -> APIToken: """Implements the OAuth2 Device Authorization Grant flow. @@ -86,7 +85,7 @@ def web_login( # Make a request to the OAuth2 server to get the device code and user code. # The client ID used for the request is the unique ID of the ZenML client. 
- response: Optional[requests.Response] = None + response: requests.Response | None = None # Add the following information in the user agent header to be used by users # to identify the ZenML client: diff --git a/src/zenml/materializers/base_materializer.py b/src/zenml/materializers/base_materializer.py index d9774fd5866..390db197615 100644 --- a/src/zenml/materializers/base_materializer.py +++ b/src/zenml/materializers/base_materializer.py @@ -17,7 +17,8 @@ import inspect import shutil import tempfile -from typing import Any, ClassVar, Dict, Iterator, Optional, Tuple, Type, cast +from typing import Any, ClassVar, cast +from collections.abc import Iterator from zenml.artifact_stores.base_artifact_store import BaseArtifactStore from zenml.enums import ArtifactType, VisualizationType @@ -37,7 +38,7 @@ class BaseMaterializerMeta(type): """ def __new__( - mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any] + mcs, name: str, bases: tuple[type[Any], ...], dct: dict[str, Any] ) -> "BaseMaterializerMeta": """Creates a Materializer class and registers it at the `MaterializerRegistry`. @@ -53,7 +54,7 @@ def __new__( MaterializerInterfaceError: If the class was improperly defined. """ cls = cast( - Type["BaseMaterializer"], super().__new__(mcs, name, bases, dct) + type["BaseMaterializer"], super().__new__(mcs, name, bases, dct) ) if not cls._DOCS_BUILDING_MODE: # Skip the following validation and registration for base classes. @@ -112,7 +113,7 @@ class BaseMaterializer(metaclass=BaseMaterializerMeta): """Base Materializer to realize artifact data.""" ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.BASE - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = () + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = () # `SKIP_REGISTRATION` can be set to True to not register the class in the # materializer registry. This is primarily useful for defining base classes. 
@@ -123,7 +124,7 @@ class BaseMaterializer(metaclass=BaseMaterializerMeta): _DOCS_BUILDING_MODE: ClassVar[bool] = False def __init__( - self, uri: str, artifact_store: Optional[BaseArtifactStore] = None + self, uri: str, artifact_store: BaseArtifactStore | None = None ): """Initializes a materializer with the given URI. @@ -155,7 +156,7 @@ def artifact_store(self) -> BaseArtifactStore: # Public Interface # ================ - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: type[Any]) -> Any: """Write logic here to load the data of an artifact. Args: @@ -175,7 +176,7 @@ def save(self, data: Any) -> None: """ # write `data` into self.uri - def save_visualizations(self, data: Any) -> Dict[str, VisualizationType]: + def save_visualizations(self, data: Any) -> dict[str, VisualizationType]: """Save visualizations of the given data. If this method is not overridden, no visualizations will be saved. @@ -207,7 +208,7 @@ def save_visualizations(self, data: Any) -> Dict[str, VisualizationType]: # Optionally, save some visualizations of `data` inside `self.uri`. return {} - def extract_metadata(self, data: Any) -> Dict[str, "MetadataType"]: + def extract_metadata(self, data: Any) -> dict[str, "MetadataType"]: """Extract metadata from the given data. This metadata will be tracked and displayed alongside the artifact. @@ -229,7 +230,7 @@ def extract_metadata(self, data: Any) -> Dict[str, "MetadataType"]: # Optionally, extract some metadata from `data` for ZenML to store. return {} - def compute_content_hash(self, data: Any) -> Optional[str]: + def compute_content_hash(self, data: Any) -> str | None: """Compute the content hash of the given data. 
Args: @@ -243,7 +244,7 @@ def compute_content_hash(self, data: Any) -> Optional[str]: # ================ # Internal Methods # ================ - def validate_save_type_compatibility(self, data_type: Type[Any]) -> None: + def validate_save_type_compatibility(self, data_type: type[Any]) -> None: """Checks whether the materializer can save the given type. Args: @@ -259,7 +260,7 @@ def validate_save_type_compatibility(self, data_type: Type[Any]) -> None: f"{self.ASSOCIATED_TYPES}." ) - def validate_load_type_compatibility(self, data_type: Type[Any]) -> None: + def validate_load_type_compatibility(self, data_type: type[Any]) -> None: """Checks whether the materializer can load the given type. Args: @@ -276,7 +277,7 @@ def validate_load_type_compatibility(self, data_type: Type[Any]) -> None: ) @classmethod - def can_save_type(cls, data_type: Type[Any]) -> bool: + def can_save_type(cls, data_type: type[Any]) -> bool: """Whether the materializer can save a certain type. Args: @@ -291,7 +292,7 @@ def can_save_type(cls, data_type: Type[Any]) -> bool: ) @classmethod - def can_load_type(cls, data_type: Type[Any]) -> bool: + def can_load_type(cls, data_type: type[Any]) -> bool: """Whether the materializer can load an artifact as the given type. Args: @@ -309,7 +310,7 @@ def can_load_type(cls, data_type: Type[Any]) -> bool: for associated_type in cls.ASSOCIATED_TYPES ) - def extract_full_metadata(self, data: Any) -> Dict[str, "MetadataType"]: + def extract_full_metadata(self, data: Any) -> dict[str, "MetadataType"]: """Extract both base and custom metadata from the given data. Args: @@ -322,7 +323,7 @@ def extract_full_metadata(self, data: Any) -> Dict[str, "MetadataType"]: custom_metadata = self.extract_metadata(data) return {**base_metadata, **custom_metadata} - def _extract_base_metadata(self, data: Any) -> Dict[str, "MetadataType"]: + def _extract_base_metadata(self, data: Any) -> dict[str, "MetadataType"]: """Extract metadata from the given data. 
This metadata will be extracted for all artifacts in addition to the diff --git a/src/zenml/materializers/built_in_materializer.py b/src/zenml/materializers/built_in_materializer.py index 39a34142b4c..39ec2830136 100644 --- a/src/zenml/materializers/built_in_materializer.py +++ b/src/zenml/materializers/built_in_materializer.py @@ -20,14 +20,8 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - Iterable, - List, - Optional, - Tuple, - Type, - Union, ) +from collections.abc import Iterable from zenml.artifact_stores.base_artifact_store import BaseArtifactStore from zenml.constants import ( @@ -63,10 +57,10 @@ class BuiltInMaterializer(BaseMaterializer): """Handle JSON-serializable basic types (`bool`, `float`, `int`, `str`).""" ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = BASIC_TYPES + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = BASIC_TYPES def __init__( - self, uri: str, artifact_store: Optional[BaseArtifactStore] = None + self, uri: str, artifact_store: BaseArtifactStore | None = None ): """Define `self.data_path`. @@ -78,7 +72,7 @@ def __init__( self.data_path = os.path.join(self.uri, DEFAULT_FILENAME) def load( - self, data_type: Union[Type[bool], Type[float], Type[int], Type[str]] + self, data_type: type[bool] | type[float] | type[int] | type[str] ) -> Any: """Reads basic primitive types from JSON. @@ -97,7 +91,7 @@ def load( ) return contents - def save(self, data: Union[bool, float, int, str]) -> None: + def save(self, data: bool | float | int | str) -> None: """Serialize a basic type to JSON. Args: @@ -110,8 +104,8 @@ def save(self, data: Union[bool, float, int, str]) -> None: ) def save_visualizations( - self, data: Union[bool, float, int, str] - ) -> Dict[str, VisualizationType]: + self, data: bool | float | int | str + ) -> dict[str, VisualizationType]: """Save visualizations for the given basic type. 
Args: @@ -123,8 +117,8 @@ def save_visualizations( return {self.data_path.replace("\\", "/"): VisualizationType.JSON} def extract_metadata( - self, data: Union[bool, float, int, str] - ) -> Dict[str, "MetadataType"]: + self, data: bool | float | int | str + ) -> dict[str, "MetadataType"]: """Extract metadata from the given built-in container object. Args: @@ -140,7 +134,7 @@ def extract_metadata( return {} - def compute_content_hash(self, data: Any) -> Optional[str]: + def compute_content_hash(self, data: Any) -> str | None: """Compute the content hash of the given data. Args: @@ -159,10 +153,10 @@ class BytesMaterializer(BaseMaterializer): """Handle `bytes` data type, which is not JSON serializable.""" ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (bytes,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (bytes,) def __init__( - self, uri: str, artifact_store: Optional[BaseArtifactStore] = None + self, uri: str, artifact_store: BaseArtifactStore | None = None ): """Define `self.data_path`. @@ -173,7 +167,7 @@ def __init__( super().__init__(uri, artifact_store) self.data_path = os.path.join(self.uri, DEFAULT_BYTES_FILENAME) - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: type[Any]) -> Any: """Reads a bytes object from file. Args: @@ -194,7 +188,7 @@ def save(self, data: Any) -> None: with self.artifact_store.open(self.data_path, "wb") as file_: file_.write(data) - def save_visualizations(self, data: bytes) -> Dict[str, VisualizationType]: + def save_visualizations(self, data: bytes) -> dict[str, VisualizationType]: """Save visualizations for the given bytes data. 
Args: @@ -205,7 +199,7 @@ def save_visualizations(self, data: bytes) -> Dict[str, VisualizationType]: """ return {self.data_path.replace("\\", "/"): VisualizationType.MARKDOWN} - def compute_content_hash(self, data: Any) -> Optional[str]: + def compute_content_hash(self, data: Any) -> str | None: """Compute the content hash of the given data. Args: @@ -269,7 +263,7 @@ def _custom_json_converter(obj: Any) -> Any: return obj -def find_type_by_str(type_str: str) -> Type[Any]: +def find_type_by_str(type_str: str) -> type[Any]: """Get a Python type, given its string representation. E.g., "" should resolve to `int`. @@ -294,7 +288,7 @@ def find_type_by_str(type_str: str) -> Type[Any]: raise RuntimeError(f"Cannot resolve type '{type_str}'.") -def find_materializer_registry_type(type_: Type[Any]) -> Type[Any]: +def find_materializer_registry_type(type_: type[Any]) -> type[Any]: """For a given type, find the type registered in the registry. This can be either the type itself, or a superclass of the type. @@ -332,7 +326,7 @@ def find_materializer_registry_type(type_: Type[Any]) -> Type[Any]: class BuiltInContainerMaterializer(BaseMaterializer): """Handle built-in container types (dict, list, set, tuple).""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = ( + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = ( dict, list, set, @@ -340,7 +334,7 @@ class BuiltInContainerMaterializer(BaseMaterializer): ) def __init__( - self, uri: str, artifact_store: Optional[BaseArtifactStore] = None + self, uri: str, artifact_store: BaseArtifactStore | None = None ): """Define `self.data_path` and `self.metadata_path`. @@ -352,7 +346,7 @@ def __init__( self.data_path = os.path.join(self.uri, DEFAULT_FILENAME) self.metadata_path = os.path.join(self.uri, DEFAULT_METADATA_FILENAME) - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: type[Any]) -> Any: """Reads a materialized built-in container object. If the data was serialized to JSON, deserialize it. 
@@ -468,8 +462,8 @@ def save(self, data: Any) -> None: # non-serializable list: Materialize each element into a subfolder. # Get path, type, and corresponding materializer for each element. - metadata: List[Dict[str, str]] = [] - materializers: List[BaseMaterializer] = [] + metadata: list[dict[str, str]] = [] + materializers: list[BaseMaterializer] = [] try: for i, element in enumerate(data): element_path = os.path.join(self.uri, str(i)) @@ -506,7 +500,7 @@ def save(self, data: Any) -> None: raise e # save dict type objects to JSON file with JSON visualization type - def save_visualizations(self, data: Any) -> Dict[str, "VisualizationType"]: + def save_visualizations(self, data: Any) -> dict[str, "VisualizationType"]: """Save visualizations for the given data. Args: @@ -522,7 +516,7 @@ def save_visualizations(self, data: Any) -> Dict[str, "VisualizationType"]: return {self.data_path.replace("\\", "/"): VisualizationType.JSON} return {} - def extract_metadata(self, data: Any) -> Dict[str, "MetadataType"]: + def extract_metadata(self, data: Any) -> dict[str, "MetadataType"]: """Extract metadata from the given built-in container object. Args: @@ -535,7 +529,7 @@ def extract_metadata(self, data: Any) -> Dict[str, "MetadataType"]: return {"length": len(data)} return {} - def compute_content_hash(self, data: Any) -> Optional[str]: + def compute_content_hash(self, data: Any) -> str | None: """Compute the content hash of the given data. 
Args: diff --git a/src/zenml/materializers/cloudpickle_materializer.py b/src/zenml/materializers/cloudpickle_materializer.py index 11b3123c9b5..2045c068782 100644 --- a/src/zenml/materializers/cloudpickle_materializer.py +++ b/src/zenml/materializers/cloudpickle_materializer.py @@ -14,7 +14,7 @@ """Implementation of ZenML's cloudpickle materializer.""" import os -from typing import Any, ClassVar, Tuple, Type +from typing import Any, ClassVar import cloudpickle @@ -44,11 +44,11 @@ class CloudpickleMaterializer(BaseMaterializer): only used as a fallback materializer inside the materializer registry. """ - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (object,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (object,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA SKIP_REGISTRATION: ClassVar[bool] = True - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: type[Any]) -> Any: """Reads an artifact from a cloudpickle file. Args: diff --git a/src/zenml/materializers/in_memory_materializer.py b/src/zenml/materializers/in_memory_materializer.py index b21160fb2b1..942afa0e767 100644 --- a/src/zenml/materializers/in_memory_materializer.py +++ b/src/zenml/materializers/in_memory_materializer.py @@ -16,10 +16,6 @@ from typing import ( Any, ClassVar, - Dict, - Optional, - Tuple, - Type, ) from zenml.enums import ArtifactType @@ -30,7 +26,7 @@ class InMemoryMaterializer(BaseMaterializer): """Materializer that stores artifacts in memory.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (object,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (object,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA SKIP_REGISTRATION: ClassVar[bool] = True @@ -44,7 +40,7 @@ def save(self, data: Any) -> None: runtime.put_in_memory_data(self.uri, data) - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: type[Any]) -> Any: """Load data from memory. 
Args: @@ -65,7 +61,7 @@ def load(self, data_type: Type[Any]) -> Any: f"No data available for artifactURI `{self.uri}`" ) - def extract_full_metadata(self, data: Any) -> Dict[str, MetadataType]: + def extract_full_metadata(self, data: Any) -> dict[str, MetadataType]: """No metadata extraction. Args: @@ -76,7 +72,7 @@ def extract_full_metadata(self, data: Any) -> Dict[str, MetadataType]: """ return {} - def save_visualizations(self, data: Any) -> Dict[str, Any]: + def save_visualizations(self, data: Any) -> dict[str, Any]: """No visualizations. Args: @@ -87,7 +83,7 @@ def save_visualizations(self, data: Any) -> Dict[str, Any]: """ return {} - def compute_content_hash(self, data: Any) -> Optional[str]: + def compute_content_hash(self, data: Any) -> str | None: """No content hash computation in serving mode. Args: diff --git a/src/zenml/materializers/materializer_registry.py b/src/zenml/materializers/materializer_registry.py index 509dbb806bf..3340128d7c0 100644 --- a/src/zenml/materializers/materializer_registry.py +++ b/src/zenml/materializers/materializer_registry.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of a default materializer registry.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Type +from typing import TYPE_CHECKING, Any from zenml.logger import get_logger @@ -28,11 +28,11 @@ class MaterializerRegistry: def __init__(self) -> None: """Initialize the materializer registry.""" - self.default_materializer: Optional[Type["BaseMaterializer"]] = None - self.materializer_types: Dict[Type[Any], Type["BaseMaterializer"]] = {} + self.default_materializer: type["BaseMaterializer"] | None = None + self.materializer_types: dict[type[Any], type["BaseMaterializer"]] = {} def register_materializer_type( - self, key: Type[Any], type_: Type["BaseMaterializer"] + self, key: type[Any], type_: type["BaseMaterializer"] ) -> None: """Registers a new materializer. 
@@ -51,7 +51,7 @@ def register_materializer_type( ) def register_and_overwrite_type( - self, key: Type[Any], type_: Type["BaseMaterializer"] + self, key: type[Any], type_: type["BaseMaterializer"] ) -> None: """Registers a new materializer and also overwrites a default if set. @@ -62,7 +62,7 @@ def register_and_overwrite_type( self.materializer_types[key] = type_ logger.debug(f"Registered materializer {type_} for {key}") - def __getitem__(self, key: Type[Any]) -> Type["BaseMaterializer"]: + def __getitem__(self, key: type[Any]) -> type["BaseMaterializer"]: """Get a single materializers based on the key. Args: @@ -77,7 +77,7 @@ def __getitem__(self, key: Type[Any]) -> Type["BaseMaterializer"]: return materializer return self.get_default_materializer() - def get_default_materializer(self) -> Type["BaseMaterializer"]: + def get_default_materializer(self) -> type["BaseMaterializer"]: """Get the default materializer that is used if no other is found. Returns: @@ -94,7 +94,7 @@ def get_default_materializer(self) -> Type["BaseMaterializer"]: def get_materializer_types( self, - ) -> Dict[Type[Any], Type["BaseMaterializer"]]: + ) -> dict[type[Any], type["BaseMaterializer"]]: """Get all registered materializer types. Returns: @@ -102,7 +102,7 @@ def get_materializer_types( """ return self.materializer_types - def is_registered(self, key: Type[Any]) -> bool: + def is_registered(self, key: type[Any]) -> bool: """Returns if a materializer class is registered for the given type. 
Args: diff --git a/src/zenml/materializers/path_materializer.py b/src/zenml/materializers/path_materializer.py index e78c33cd4da..b89b709a1a4 100644 --- a/src/zenml/materializers/path_materializer.py +++ b/src/zenml/materializers/path_materializer.py @@ -17,7 +17,7 @@ import shutil import tarfile from pathlib import Path -from typing import Any, ClassVar, Tuple, Type +from typing import Any, ClassVar from zenml.constants import ( ENV_ZENML_DISABLE_PATH_MATERIALIZER, @@ -61,7 +61,7 @@ class PathMaterializer(BaseMaterializer): or directly copying the file if it's a single file. """ - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Path,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (Path,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA ARCHIVE_NAME: ClassVar[str] = "data.tar.gz" FILE_NAME: ClassVar[str] = "file_data" @@ -71,7 +71,7 @@ class PathMaterializer(BaseMaterializer): ENV_ZENML_DISABLE_PATH_MATERIALIZER, default=False ) - def load(self, data_type: Type[Any]) -> Any: + def load(self, data_type: type[Any]) -> Any: """Copy the artifact files to a local temp directory or file. 
Args: diff --git a/src/zenml/materializers/pydantic_materializer.py b/src/zenml/materializers/pydantic_materializer.py index 4ac8d88ff4d..3ea9cb2e083 100644 --- a/src/zenml/materializers/pydantic_materializer.py +++ b/src/zenml/materializers/pydantic_materializer.py @@ -16,7 +16,7 @@ import hashlib import json import os -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar from pydantic import BaseModel @@ -34,9 +34,9 @@ class PydanticMaterializer(BaseMaterializer): """Handle Pydantic BaseModel objects.""" ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (BaseModel,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (BaseModel,) - def load(self, data_type: Type[BaseModel]) -> Any: + def load(self, data_type: type[BaseModel]) -> Any: """Reads BaseModel from JSON. Args: @@ -58,7 +58,7 @@ def save(self, data: BaseModel) -> None: data_path = os.path.join(self.uri, DEFAULT_FILENAME) yaml_utils.write_json(data_path, data.model_dump_json()) - def extract_metadata(self, data: BaseModel) -> Dict[str, "MetadataType"]: + def extract_metadata(self, data: BaseModel) -> dict[str, "MetadataType"]: """Extract metadata from the given BaseModel object. Args: @@ -69,7 +69,7 @@ def extract_metadata(self, data: BaseModel) -> Dict[str, "MetadataType"]: """ return {"schema": data.schema()} - def compute_content_hash(self, data: BaseModel) -> Optional[str]: + def compute_content_hash(self, data: BaseModel) -> str | None: """Compute the content hash of the given data. Args: @@ -85,7 +85,7 @@ def compute_content_hash(self, data: BaseModel) -> Optional[str]: hash_.update(json.dumps(json_data, sort_keys=True).encode()) return hash_.hexdigest() - def save_visualizations(self, data: Any) -> Dict[str, "VisualizationType"]: + def save_visualizations(self, data: Any) -> dict[str, "VisualizationType"]: """Save visualizations for the given data. 
Args: diff --git a/src/zenml/materializers/service_materializer.py b/src/zenml/materializers/service_materializer.py index 2462641637f..ce4ee66a23d 100644 --- a/src/zenml/materializers/service_materializer.py +++ b/src/zenml/materializers/service_materializer.py @@ -15,7 +15,7 @@ import os import uuid -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type +from typing import TYPE_CHECKING, Any, ClassVar from zenml.client import Client from zenml.enums import ArtifactType @@ -31,10 +31,10 @@ class ServiceMaterializer(BaseMaterializer): """Materializer to read/write service instances.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (BaseService,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (BaseService,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.SERVICE - def load(self, data_type: Type[Any]) -> BaseService: + def load(self, data_type: type[Any]) -> BaseService: """Creates and returns a service. This service is instantiated from the serialized service configuration @@ -68,7 +68,7 @@ def save(self, service: BaseService) -> None: def extract_metadata( self, service: BaseService - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Extract metadata from the given service. 
Args: diff --git a/src/zenml/materializers/structured_string_materializer.py b/src/zenml/materializers/structured_string_materializer.py index 179f19c9862..c5aadf8f067 100644 --- a/src/zenml/materializers/structured_string_materializer.py +++ b/src/zenml/materializers/structured_string_materializer.py @@ -15,7 +15,7 @@ import hashlib import os -from typing import Dict, Optional, Type, Union +from typing import Union from zenml.enums import ArtifactType, VisualizationType from zenml.logger import get_logger @@ -39,7 +39,7 @@ class StructuredStringMaterializer(BaseMaterializer): ASSOCIATED_TYPES = (CSVString, HTMLString, MarkdownString, JSONString) ASSOCIATED_ARTIFACT_TYPE = ArtifactType.DATA_ANALYSIS - def load(self, data_type: Type[STRUCTURED_STRINGS]) -> STRUCTURED_STRINGS: + def load(self, data_type: type[STRUCTURED_STRINGS]) -> STRUCTURED_STRINGS: """Loads the data from the HTML or Markdown file. Args: @@ -64,7 +64,7 @@ def save(self, data: STRUCTURED_STRINGS) -> None: def save_visualizations( self, data: STRUCTURED_STRINGS - ) -> Dict[str, VisualizationType]: + ) -> dict[str, VisualizationType]: """Save visualizations for the given data. Args: @@ -78,7 +78,7 @@ def save_visualizations( visualization_type = self._get_visualization_type(type(data)) return {filepath: visualization_type} - def _get_filepath(self, data_type: Type[STRUCTURED_STRINGS]) -> str: + def _get_filepath(self, data_type: type[STRUCTURED_STRINGS]) -> str: """Get the file path for the given data type. Args: @@ -105,7 +105,7 @@ def _get_filepath(self, data_type: Type[STRUCTURED_STRINGS]) -> str: return os.path.join(self.uri, filename) def _get_visualization_type( - self, data_type: Type[STRUCTURED_STRINGS] + self, data_type: type[STRUCTURED_STRINGS] ) -> VisualizationType: """Get the visualization type for the given data type. @@ -131,7 +131,7 @@ def _get_visualization_type( f"Data type {data_type} is not supported by this materializer." 
) - def compute_content_hash(self, data: STRUCTURED_STRINGS) -> Optional[str]: + def compute_content_hash(self, data: STRUCTURED_STRINGS) -> str | None: """Compute the content hash of the given data. Args: diff --git a/src/zenml/materializers/uuid_materializer.py b/src/zenml/materializers/uuid_materializer.py index fad84dde7c1..1e5c265782e 100644 --- a/src/zenml/materializers/uuid_materializer.py +++ b/src/zenml/materializers/uuid_materializer.py @@ -16,7 +16,7 @@ import hashlib import os import uuid -from typing import Any, ClassVar, Dict, Optional, Tuple, Type +from typing import Any, ClassVar from zenml.artifact_stores.base_artifact_store import BaseArtifactStore from zenml.enums import ArtifactType @@ -29,11 +29,11 @@ class UUIDMaterializer(BaseMaterializer): """Materializer to handle UUID objects.""" - ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (uuid.UUID,) + ASSOCIATED_TYPES: ClassVar[tuple[type[Any], ...]] = (uuid.UUID,) ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA def __init__( - self, uri: str, artifact_store: Optional[BaseArtifactStore] = None + self, uri: str, artifact_store: BaseArtifactStore | None = None ): """Define `self.data_path`. @@ -44,7 +44,7 @@ def __init__( super().__init__(uri, artifact_store) self.data_path = os.path.join(self.uri, DEFAULT_FILENAME) - def load(self, _: Type[uuid.UUID]) -> uuid.UUID: + def load(self, _: type[uuid.UUID]) -> uuid.UUID: """Read UUID from artifact store. Args: @@ -66,7 +66,7 @@ def save(self, data: uuid.UUID) -> None: with self.artifact_store.open(self.data_path, "w") as f: f.write(str(data)) - def extract_metadata(self, data: uuid.UUID) -> Dict[str, MetadataType]: + def extract_metadata(self, data: uuid.UUID) -> dict[str, MetadataType]: """Extract metadata from the UUID. 
Args: @@ -79,7 +79,7 @@ def extract_metadata(self, data: uuid.UUID) -> Dict[str, MetadataType]: "string_representation": str(data), } - def compute_content_hash(self, data: uuid.UUID) -> Optional[str]: + def compute_content_hash(self, data: uuid.UUID) -> str | None: """Compute the content hash of the given data. Args: diff --git a/src/zenml/metadata/lazy_load.py b/src/zenml/metadata/lazy_load.py index 7ce2cc0d30a..dc08d3d4404 100644 --- a/src/zenml/metadata/lazy_load.py +++ b/src/zenml/metadata/lazy_load.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Run Metadata Lazy Loader definition.""" -from typing import Optional from pydantic import BaseModel @@ -27,11 +26,11 @@ class LazyRunMetadataResponse(BaseModel): a pipeline context available only during pipeline compilation. """ - lazy_load_artifact_name: Optional[str] = None - lazy_load_artifact_version: Optional[str] = None - lazy_load_metadata_name: Optional[str] = None + lazy_load_artifact_name: str | None = None + lazy_load_artifact_version: str | None = None + lazy_load_metadata_name: str | None = None lazy_load_model_name: str - lazy_load_model_version: Optional[str] = None + lazy_load_model_version: str | None = None class RunMetadataLazyGetter: @@ -45,9 +44,9 @@ class RunMetadataLazyGetter: def __init__( self, _lazy_load_model_name: str, - _lazy_load_model_version: Optional[str], - _lazy_load_artifact_name: Optional[str] = None, - _lazy_load_artifact_version: Optional[str] = None, + _lazy_load_model_version: str | None, + _lazy_load_artifact_name: str | None = None, + _lazy_load_artifact_version: str | None = None, ): """Initialize a RunMetadataLazyGetter. 
diff --git a/src/zenml/metadata/metadata_types.py b/src/zenml/metadata/metadata_types.py index c00ed5e886c..e5698ebeee0 100644 --- a/src/zenml/metadata/metadata_types.py +++ b/src/zenml/metadata/metadata_types.py @@ -14,7 +14,7 @@ """Custom types that can be used as metadata of ZenML artifacts.""" import json -from typing import Any, Dict, List, Set, Tuple, Union +from typing import Any, Union from pydantic import GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema @@ -108,10 +108,10 @@ def __get_pydantic_core_schema__( int, float, bool, - Dict[Any, Any], - List[Any], - Set[Any], - Tuple[Any, ...], + dict[Any, Any], + list[Any], + set[Any], + tuple[Any, ...], Uri, Path, DType, @@ -211,8 +211,8 @@ def cast_to_metadata_type( def validate_metadata( - metadata: Dict[str, MetadataType], -) -> Dict[str, MetadataType]: + metadata: dict[str, MetadataType], +) -> dict[str, MetadataType]: """Validate metadata. This function excludes and warns about metadata values that are too long diff --git a/src/zenml/model/lazy_load.py b/src/zenml/model/lazy_load.py index cfc821bb6bf..f2b15568335 100644 --- a/src/zenml/model/lazy_load.py +++ b/src/zenml/model/lazy_load.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Model Version Data Lazy Loader definition.""" -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict, model_validator @@ -33,10 +33,10 @@ class ModelVersionDataLazyLoader(BaseModel): """ model_name: str - model_version: Optional[str] = None - artifact_name: Optional[str] = None - artifact_version: Optional[str] = None - metadata_name: Optional[str] = None + model_version: str | None = None + artifact_name: str | None = None + artifact_version: str | None = None + metadata_name: str | None = None # TODO: In Pydantic v2, the `model_` is a protected namespaces for all # fields defined under base models. If not handled, this raises a warning. 
@@ -49,7 +49,7 @@ class ModelVersionDataLazyLoader(BaseModel): @model_validator(mode="before") @classmethod @before_validator_handler - def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def _root_validator(cls, data: dict[str, Any]) -> dict[str, Any]: """Validate all in one. Args: diff --git a/src/zenml/model/model.py b/src/zenml/model/model.py index 8b7f8cae2fe..48f7325b6c2 100644 --- a/src/zenml/model/model.py +++ b/src/zenml/model/model.py @@ -16,10 +16,7 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Union, ) from uuid import UUID @@ -65,24 +62,24 @@ class Model(BaseModel): """ name: str - license: Optional[str] = None - description: Optional[str] = None - audience: Optional[str] = None - use_cases: Optional[str] = None - limitations: Optional[str] = None - trade_offs: Optional[str] = None - ethics: Optional[str] = None - tags: Optional[List[str]] = None - version: Optional[Union[ModelStages, int, str]] = Field( + license: str | None = None + description: str | None = None + audience: str | None = None + use_cases: str | None = None + limitations: str | None = None + trade_offs: str | None = None + ethics: str | None = None + tags: list[str] | None = None + version: ModelStages | int | str | None = Field( default=None, union_mode="smart" ) save_models_to_registry: bool = True # technical attributes - model_version_id: Optional[UUID] = None + model_version_id: UUID | None = None suppress_class_validation_warnings: bool = False - _model_id: Optional[UUID] = PrivateAttr(None) - _number: Optional[int] = PrivateAttr(None) + _model_id: UUID | None = PrivateAttr(None) + _number: int | None = PrivateAttr(None) _created_model_version: bool = PrivateAttr(False) # TODO: In Pydantic v2, the `model_` is a protected namespaces for all @@ -160,7 +157,7 @@ def number(self) -> int: return self._number @property - def stage(self) -> Optional[ModelStages]: + def stage(self) -> ModelStages | None: """Get version stage from the Model 
Control Plane. Returns: @@ -180,7 +177,7 @@ def stage(self) -> Optional[ModelStages]: ) return None - def load_artifact(self, name: str, version: Optional[str] = None) -> Any: + def load_artifact(self, name: str, version: str | None = None) -> Any: """Load artifact from the Model Control Plane. Args: @@ -210,7 +207,7 @@ def load_artifact(self, name: str, version: Optional[str] = None) -> Any: def get_artifact( self, name: str, - version: Optional[str] = None, + version: str | None = None, ) -> Optional["ArtifactVersionResponse"]: """Get the artifact linked to this model version. @@ -234,7 +231,7 @@ def get_artifact( def get_model_artifact( self, name: str, - version: Optional[str] = None, + version: str | None = None, ) -> Optional["ArtifactVersionResponse"]: """Get the model artifact linked to this model version. @@ -258,7 +255,7 @@ def get_model_artifact( def get_data_artifact( self, name: str, - version: Optional[str] = None, + version: str | None = None, ) -> Optional["ArtifactVersionResponse"]: """Get the data artifact linked to this model version. @@ -282,7 +279,7 @@ def get_data_artifact( def get_deployment_artifact( self, name: str, - version: Optional[str] = None, + version: str | None = None, ) -> Optional["ArtifactVersionResponse"]: """Get the deployment artifact linked to this model version. @@ -304,7 +301,7 @@ def get_deployment_artifact( ) def set_stage( - self, stage: Union[str, ModelStages], force: bool = False + self, stage: str | ModelStages, force: bool = False ) -> None: """Sets this Model to a desired stage. @@ -317,7 +314,7 @@ def set_stage( def log_metadata( self, - metadata: Dict[str, "MetadataType"], + metadata: dict[str, "MetadataType"], ) -> None: """Log model version metadata. @@ -340,7 +337,7 @@ def log_metadata( ) @property - def run_metadata(self) -> Dict[str, "MetadataType"]: + def run_metadata(self) -> dict[str, "MetadataType"]: """Get model version run metadata. 
Returns: @@ -371,7 +368,7 @@ def run_metadata(self) -> Dict[str, "MetadataType"]: def delete_artifact( self, name: str, - version: Optional[str] = None, + version: str | None = None, only_link: bool = True, delete_metadata: bool = True, delete_from_artifact_store: bool = False, @@ -438,7 +435,7 @@ def delete_all_artifacts( def _lazy_artifact_get( self, name: str, - version: Optional[str] = None, + version: str | None = None, ) -> Optional["ArtifactVersionResponse"]: from zenml.models.v2.core.artifact_version import ( LazyArtifactVersionResponse, @@ -479,7 +476,7 @@ def __eq__(self, other: object) -> bool: @model_validator(mode="before") @classmethod @before_validator_handler - def _root_validator(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def _root_validator(cls, data: dict[str, Any]) -> dict[str, Any]: """Validate all in one. Args: @@ -599,7 +596,7 @@ def _get_model_version( ) self.model_version_id = mv.id - difference: Dict[str, Any] = {} + difference: dict[str, Any] = {} if mv.metadata: if self.description and mv.description != self.description: difference["description"] = { @@ -726,18 +723,16 @@ def __hash__(self) -> int: """ return hash( "::".join( - ( str(v) for v in ( self.name, self.version, ) - ) ) ) @property - def _lazy_version(self) -> Optional[str]: + def _lazy_version(self) -> str | None: """Get version name for lazy loader. This getter ensures that new model version diff --git a/src/zenml/model/utils.py b/src/zenml/model/utils.py index 2cff2b3f058..635640bcbeb 100644 --- a/src/zenml/model/utils.py +++ b/src/zenml/model/utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Utility functions for linking step outputs to model versions.""" -from typing import Dict, Optional, Union +from typing import Optional from uuid import UUID from zenml.client import Client @@ -34,9 +34,9 @@ def log_model_metadata( - metadata: Dict[str, "MetadataType"], - model_name: Optional[str] = None, - model_version: Optional[Union[ModelStages, int, str]] = None, + metadata: dict[str, "MetadataType"], + model_name: str | None = None, + model_version: ModelStages | int | str | None = None, ) -> None: """Log model version metadata. @@ -139,7 +139,7 @@ def link_artifact_to_model( def link_service_to_model( service_id: UUID, model: Optional["Model"] = None, - model_version_id: Optional[UUID] = None, + model_version_id: UUID | None = None, ) -> None: """Links a service to a model. diff --git a/src/zenml/model_deployers/base_model_deployer.py b/src/zenml/model_deployers/base_model_deployer.py index c881e8b3d12..a604584dba7 100644 --- a/src/zenml/model_deployers/base_model_deployer.py +++ b/src/zenml/model_deployers/base_model_deployer.py @@ -18,13 +18,9 @@ from typing import ( Any, ClassVar, - Dict, - Generator, - List, - Optional, - Type, cast, ) +from collections.abc import Generator from uuid import UUID from zenml.client import Client @@ -82,7 +78,7 @@ class BaseModelDeployer(StackComponent, ABC): """ NAME: ClassVar[str] - FLAVOR: ClassVar[Type["BaseModelDeployerFlavor"]] + FLAVOR: ClassVar[type["BaseModelDeployerFlavor"]] @property def config(self) -> BaseModelDeployerConfig: @@ -268,7 +264,7 @@ def perform_deploy_model( @abstractmethod def get_model_server_info( service: BaseService, - ) -> Dict[str, Optional[str]]: + ) -> dict[str, str | None]: """Give implementation specific way to extract relevant model server properties for the user. 
Args: @@ -280,19 +276,19 @@ def get_model_server_info( def find_model_server( self, - config: Optional[Dict[str, Any]] = None, - running: Optional[bool] = None, - service_uuid: Optional[UUID] = None, - pipeline_name: Optional[str] = None, - pipeline_step_name: Optional[str] = None, - service_name: Optional[str] = None, - model_name: Optional[str] = None, - model_version: Optional[str] = None, - service_type: Optional[ServiceType] = None, - type: Optional[str] = None, - flavor: Optional[str] = None, - pipeline_run_id: Optional[str] = None, - ) -> List[BaseService]: + config: dict[str, Any] | None = None, + running: bool | None = None, + service_uuid: UUID | None = None, + pipeline_name: str | None = None, + pipeline_step_name: str | None = None, + service_name: str | None = None, + model_name: str | None = None, + model_version: str | None = None, + service_type: ServiceType | None = None, + type: str | None = None, + flavor: str | None = None, + pipeline_run_id: str | None = None, + ) -> list[BaseService]: """Abstract method to find one or more a model servers that match the given criteria. Args: @@ -528,7 +524,7 @@ def get_model_server_logs( self, uuid: UUID, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> Generator[str, bool, None]: """Get the logs of a model server. @@ -578,7 +574,7 @@ def type(self) -> StackComponentType: return StackComponentType.MODEL_DEPLOYER @property - def config_class(self) -> Type[BaseModelDeployerConfig]: + def config_class(self) -> type[BaseModelDeployerConfig]: """Returns `BaseModelDeployerConfig` config class. 
Returns: @@ -588,14 +584,14 @@ def config_class(self) -> Type[BaseModelDeployerConfig]: @property @abstractmethod - def implementation_class(self) -> Type[BaseModelDeployer]: + def implementation_class(self) -> type[BaseModelDeployer]: """The class that implements the model deployer.""" def get_model_version_id_if_exists( - model_name: Optional[str], - model_version: Optional[str], -) -> Optional[UUID]: + model_name: str | None, + model_version: str | None, +) -> UUID | None: """Get the model version id if it exists. Args: diff --git a/src/zenml/model_registries/base_model_registry.py b/src/zenml/model_registries/base_model_registry.py index 0cc8c0bbd1a..4e046e10ed8 100644 --- a/src/zenml/model_registries/base_model_registry.py +++ b/src/zenml/model_registries/base_model_registry.py @@ -16,7 +16,7 @@ from abc import ABC, abstractmethod from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional, Type, cast +from typing import Any, cast from pydantic import BaseModel, ConfigDict @@ -47,8 +47,8 @@ class RegisteredModel(BaseModel): """ name: str - description: Optional[str] = None - metadata: Optional[Dict[str, str]] = None + description: str | None = None + metadata: dict[str, str] | None = None class ModelRegistryModelMetadata(BaseModel): @@ -62,16 +62,16 @@ class ModelRegistryModelMetadata(BaseModel): model and its development process. 
""" - zenml_version: Optional[str] = None - zenml_run_name: Optional[str] = None - zenml_pipeline_name: Optional[str] = None - zenml_pipeline_uuid: Optional[str] = None - zenml_pipeline_run_uuid: Optional[str] = None - zenml_step_name: Optional[str] = None - zenml_project: Optional[str] = None + zenml_version: str | None = None + zenml_run_name: str | None = None + zenml_pipeline_name: str | None = None + zenml_pipeline_uuid: str | None = None + zenml_pipeline_run_uuid: str | None = None + zenml_step_name: str | None = None + zenml_project: str | None = None @property - def custom_attributes(self) -> Dict[str, str]: + def custom_attributes(self) -> dict[str, str]: """Returns a dictionary of custom attributes. Returns: @@ -89,7 +89,7 @@ def model_dump( exclude_unset: bool = False, exclude_none: bool = True, **kwargs: Any, - ) -> Dict[str, str]: + ) -> dict[str, str]: """Returns a dictionary representation of the metadata. This method overrides the default Pydantic `model_dump` method to allow @@ -151,13 +151,13 @@ class RegistryModelVersion(BaseModel): version: str model_source_uri: str model_format: str - model_library: Optional[str] = None + model_library: str | None = None registered_model: RegisteredModel - description: Optional[str] = None - created_at: Optional[datetime] = None - last_updated_at: Optional[datetime] = None + description: str | None = None + created_at: datetime | None = None + last_updated_at: datetime | None = None stage: ModelVersionStage = ModelVersionStage.NONE - metadata: Optional[ModelRegistryModelMetadata] = None + metadata: ModelRegistryModelMetadata | None = None # TODO: In Pydantic v2, the `model_` is a protected namespaces for all # fields defined under base models. If not handled, this raises a warning. 
@@ -192,8 +192,8 @@ def config(self) -> BaseModelRegistryConfig: def register_model( self, name: str, - description: Optional[str] = None, - metadata: Optional[Dict[str, str]] = None, + description: str | None = None, + metadata: dict[str, str] | None = None, ) -> RegisteredModel: """Registers a model in the model registry. @@ -229,9 +229,9 @@ def delete_model( def update_model( self, name: str, - description: Optional[str] = None, - metadata: Optional[Dict[str, str]] = None, - remove_metadata: Optional[List[str]] = None, + description: str | None = None, + metadata: dict[str, str] | None = None, + remove_metadata: list[str] | None = None, ) -> RegisteredModel: """Updates a registered model in the model registry. @@ -264,9 +264,9 @@ def get_model(self, name: str) -> RegisteredModel: @abstractmethod def list_models( self, - name: Optional[str] = None, - metadata: Optional[Dict[str, str]] = None, - ) -> List[RegisteredModel]: + name: str | None = None, + metadata: dict[str, str] | None = None, + ) -> list[RegisteredModel]: """Lists all registered models in the model registry. Args: @@ -285,10 +285,10 @@ def list_models( def register_model_version( self, name: str, - version: Optional[str] = None, - model_source_uri: Optional[str] = None, - description: Optional[str] = None, - metadata: Optional[ModelRegistryModelMetadata] = None, + version: str | None = None, + model_source_uri: str | None = None, + description: str | None = None, + metadata: ModelRegistryModelMetadata | None = None, **kwargs: Any, ) -> RegistryModelVersion: """Registers a model version in the model registry. 
@@ -331,10 +331,10 @@ def update_model_version( self, name: str, version: str, - description: Optional[str] = None, - metadata: Optional[ModelRegistryModelMetadata] = None, - remove_metadata: Optional[List[str]] = None, - stage: Optional[ModelVersionStage] = None, + description: str | None = None, + metadata: ModelRegistryModelMetadata | None = None, + remove_metadata: list[str] | None = None, + stage: ModelVersionStage | None = None, ) -> RegistryModelVersion: """Updates a model version in the model registry. @@ -357,16 +357,16 @@ def update_model_version( @abstractmethod def list_model_versions( self, - name: Optional[str] = None, - model_source_uri: Optional[str] = None, - metadata: Optional[ModelRegistryModelMetadata] = None, - stage: Optional[ModelVersionStage] = None, - count: Optional[int] = None, - created_after: Optional[datetime] = None, - created_before: Optional[datetime] = None, - order_by_date: Optional[str] = None, + name: str | None = None, + model_source_uri: str | None = None, + metadata: ModelRegistryModelMetadata | None = None, + stage: ModelVersionStage | None = None, + count: int | None = None, + created_after: datetime | None = None, + created_before: datetime | None = None, + order_by_date: str | None = None, **kwargs: Any, - ) -> Optional[List[RegistryModelVersion]]: + ) -> list[RegistryModelVersion] | None: """Lists all model versions for a registered model. Args: @@ -388,8 +388,8 @@ def list_model_versions( def get_latest_model_version( self, name: str, - stage: Optional[ModelVersionStage] = None, - ) -> Optional[RegistryModelVersion]: + stage: ModelVersionStage | None = None, + ) -> RegistryModelVersion | None: """Gets the latest model version for a registered model. 
This method is used to get the latest model version for a registered @@ -493,7 +493,7 @@ def type(self) -> StackComponentType: return StackComponentType.MODEL_REGISTRY @property - def config_class(self) -> Type[BaseModelRegistryConfig]: + def config_class(self) -> type[BaseModelRegistryConfig]: """Config class for this flavor. Returns: @@ -503,7 +503,7 @@ def config_class(self) -> Type[BaseModelRegistryConfig]: @property @abstractmethod - def implementation_class(self) -> Type[StackComponent]: + def implementation_class(self) -> type[StackComponent]: """Returns the implementation class for this flavor. Returns: diff --git a/src/zenml/models/v2/base/base.py b/src/zenml/models/v2/base/base.py index b8194e38a05..74b5219ec74 100644 --- a/src/zenml/models/v2/base/base.py +++ b/src/zenml/models/v2/base/base.py @@ -14,7 +14,7 @@ """Base model definitions.""" from datetime import datetime -from typing import Any, Dict, Generic, Optional, TypeVar +from typing import Any, Generic, Optional, TypeVar from uuid import UUID from pydantic import ConfigDict, Field @@ -453,7 +453,7 @@ def get_metadata(self) -> "AnyMetadata": return super().get_metadata() # Analytics - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """Fetches the analytics metadata for base response models. Returns: diff --git a/src/zenml/models/v2/base/filter.py b/src/zenml/models/v2/base/filter.py index 74843317622..bff9094a5c5 100644 --- a/src/zenml/models/v2/base/filter.py +++ b/src/zenml/models/v2/base/filter.py @@ -20,12 +20,7 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - List, Optional, - Set, - Tuple, - Type, TypeVar, Union, ) @@ -81,11 +76,11 @@ class Filter(BaseModel, ABC): This operation set is defined in the ALLOWED_OPS class variable. 
""" - ALLOWED_OPS: ClassVar[List[str]] = [] + ALLOWED_OPS: ClassVar[list[str]] = [] operation: GenericFilterOps column: str - value: Optional[Any] = None + value: Any | None = None @field_validator("operation", mode="before") @classmethod @@ -111,7 +106,7 @@ def validate_operation(cls, value: Any) -> Any: def generate_query_conditions( self, - table: Type["SQLModel"], + table: type["SQLModel"], ) -> "ColumnElement[bool]": """Generate the query conditions for the database. @@ -147,7 +142,7 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: class BoolFilter(Filter): """Filter for all Boolean fields.""" - ALLOWED_OPS: ClassVar[List[str]] = [ + ALLOWED_OPS: ClassVar[list[str]] = [ GenericFilterOps.EQUALS, GenericFilterOps.NOT_EQUALS, ] @@ -170,7 +165,7 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: class StrFilter(Filter): """Filter for all string fields.""" - ALLOWED_OPS: ClassVar[List[str]] = [ + ALLOWED_OPS: ClassVar[list[str]] = [ GenericFilterOps.EQUALS, GenericFilterOps.NOT_EQUALS, GenericFilterOps.STARTSWITH, @@ -455,9 +450,9 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: class NumericFilter(Filter): """Filter for all numeric fields.""" - value: Union[float, datetime] = Field(union_mode="left_to_right") + value: float | datetime = Field(union_mode="left_to_right") - ALLOWED_OPS: ClassVar[List[str]] = [ + ALLOWED_OPS: ClassVar[list[str]] = [ GenericFilterOps.EQUALS, GenericFilterOps.NOT_EQUALS, GenericFilterOps.GT, @@ -491,11 +486,11 @@ def generate_query_conditions_from_column(self, column: Any) -> Any: class DatetimeFilter(Filter): """Filter for all datetime fields.""" - value: Union[datetime, Tuple[datetime, datetime]] = Field( + value: datetime | tuple[datetime, datetime] = Field( union_mode="left_to_right" ) - ALLOWED_OPS: ClassVar[List[str]] = [ + ALLOWED_OPS: ClassVar[list[str]] = [ GenericFilterOps.EQUALS, GenericFilterOps.NOT_EQUALS, GenericFilterOps.GT, @@ -553,19 +548,19 @@ class 
BaseFilter(BaseModel): """ # List of fields that cannot be used as filters. - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ "sort_by", "page", "size", "logical_operator", ] - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [] + CUSTOM_SORTING_OPTIONS: ClassVar[list[str]] = [] # List of fields that are not even mentioned as options in the CLI. - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [] + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [] # List of fields that are wrapped with `fastapi.Query(default)` in API. - API_MULTI_INPUT_PARAMS: ClassVar[List[str]] = [] + API_MULTI_INPUT_PARAMS: ClassVar[list[str]] = [] sort_by: str = Field( default="created", description="Which column to sort by." @@ -584,21 +579,21 @@ class BaseFilter(BaseModel): le=PAGE_SIZE_MAXIMUM, description="Page size", ) - id: Optional[Union[UUID, str]] = Field( + id: UUID | str | None = Field( default=None, description="Id for this resource", union_mode="left_to_right", ) - created: Optional[Union[datetime, str]] = Field( + created: datetime | str | None = Field( default=None, description="Created", union_mode="left_to_right" ) - updated: Optional[Union[datetime, str]] = Field( + updated: datetime | str | None = Field( default=None, description="Updated", union_mode="left_to_right" ) - _rbac_configuration: Optional[ - Tuple[UUID, Dict[str, Optional[Set[UUID]]]] - ] = None + _rbac_configuration: None | ( + tuple[UUID, dict[str, set[UUID] | None]] + ) = None @field_validator("sort_by", mode="before") @classmethod @@ -654,7 +649,7 @@ def validate_sort_by(cls, value: Any) -> Any: @model_validator(mode="before") @classmethod @before_validator_handler - def filter_ops(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def filter_ops(cls, data: dict[str, Any]) -> dict[str, Any]: """Parse incoming filters to ensure all filters are legal. 
Args: @@ -667,7 +662,7 @@ def filter_ops(cls, data: Dict[str, Any]) -> Dict[str, Any]: return data @property - def list_of_filters(self) -> List[Filter]: + def list_of_filters(self) -> list[Filter]: """Converts the class variables into a list of usable Filter Models. Returns: @@ -678,7 +673,7 @@ def list_of_filters(self) -> List[Filter]: ) @property - def sorting_params(self) -> Tuple[str, SorterOps]: + def sorting_params(self) -> tuple[str, SorterOps]: """Converts the class variables into a list of usable Filter Models. Returns: @@ -699,7 +694,7 @@ def sorting_params(self) -> Tuple[str, SorterOps]: def configure_rbac( self, authenticated_user_id: UUID, - **column_allowed_ids: Optional[Set[UUID]], + **column_allowed_ids: set[UUID] | None, ) -> None: """Configure RBAC allowed column values. @@ -715,7 +710,7 @@ def configure_rbac( def generate_rbac_filter( self, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> Optional["ColumnElement[bool]"]: """Generates an optional RBAC filter. @@ -756,7 +751,7 @@ def generate_rbac_filter( return None @classmethod - def _generate_filter_list(cls, values: Dict[str, Any]) -> List[Filter]: + def _generate_filter_list(cls, values: dict[str, Any]) -> list[Filter]: """Create a list of filters from a (column, value) dictionary. Args: @@ -765,7 +760,7 @@ def _generate_filter_list(cls, values: Dict[str, Any]) -> List[Filter]: Returns: A list of filters. """ - list_of_filters: List[Filter] = [] + list_of_filters: list[Filter] = [] for key, value in values.items(): # Ignore excluded filters @@ -788,7 +783,7 @@ def _generate_filter_list(cls, values: Dict[str, Any]) -> List[Filter]: return list_of_filters @staticmethod - def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]: + def _resolve_operator(value: Any) -> tuple[Any, GenericFilterOps]: """Determine the operator and filter value from a user-provided value. 
If the user-provided value is a string of the form "operator:value", @@ -828,9 +823,9 @@ def _resolve_operator(value: Any) -> Tuple[Any, GenericFilterOps]: def generate_name_or_id_query_conditions( self, - value: Union[UUID, str], - table: Type["NamedSchema"], - additional_columns: Optional[List[str]] = None, + value: UUID | str, + table: type["NamedSchema"], + additional_columns: list[str] | None = None, ) -> "ColumnElement[bool]": """Generate filter conditions for name or id of a table. @@ -871,7 +866,7 @@ def generate_name_or_id_query_conditions( @staticmethod def generate_custom_query_conditions_for_column( value: Any, - table: Type["SQLModel"], + table: type["SQLModel"], column: str, ) -> "ColumnElement[bool]": """Generate custom filter conditions for a column of a table. @@ -900,7 +895,7 @@ def offset(self) -> int: return self.size * (self.page - 1) def generate_filter( - self, table: Type["AnySchema"] + self, table: type["AnySchema"] ) -> Union["ColumnElement[bool]"]: """Generate the filter for the query. @@ -930,8 +925,8 @@ def generate_filter( raise RuntimeError("No valid logical operator was supplied.") def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List["ColumnElement[bool]"]: + self, table: type["AnySchema"] + ) -> list["ColumnElement[bool]"]: """Get custom filters. This can be overridden by subclasses to define custom filters that are @@ -948,7 +943,7 @@ def get_custom_filters( def apply_filter( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Applies the filter to a query. @@ -974,7 +969,7 @@ def apply_filter( def apply_sorting( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Apply sorting to the query. 
@@ -1005,7 +1000,7 @@ def apply_sorting( class FilterGenerator: """Helper class to define filters for a class.""" - def __init__(self, model_class: Type[BaseModel]) -> None: + def __init__(self, model_class: type[BaseModel]) -> None: """Initialize the object. Args: @@ -1192,7 +1187,7 @@ def _define_datetime_filter( ValueError: If the value is not a valid datetime. """ try: - filter_value: Union[datetime, Tuple[datetime, datetime]] + filter_value: datetime | tuple[datetime, datetime] if isinstance(value, datetime): filter_value = value elif "," in value: diff --git a/src/zenml/models/v2/base/page.py b/src/zenml/models/v2/base/page.py index 35f136d6fec..f63d9029619 100644 --- a/src/zenml/models/v2/base/page.py +++ b/src/zenml/models/v2/base/page.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. """Page model definitions.""" -from typing import Generator, Generic, List, TypeVar +from typing import Generic, TypeVar +from collections.abc import Generator from pydantic import BaseModel from pydantic.types import NonNegativeInt, PositiveInt @@ -30,7 +31,7 @@ class Page(BaseModel, Generic[B]): max_size: PositiveInt total_pages: NonNegativeInt total: NonNegativeInt - items: List[B] + items: list[B] __params_type__ = BaseFilter @@ -74,8 +75,7 @@ def __iter__(self) -> Generator[B, None, None]: # type: ignore[override] Yields: An iterator over the items in the page. """ - for item in self.items.__iter__(): - yield item + yield from self.items.__iter__() def __contains__(self, item: B) -> bool: """Returns whether the page contains a specific item. 
diff --git a/src/zenml/models/v2/base/scoped.py b/src/zenml/models/v2/base/scoped.py index 8f15f7049a8..85f9cedd5ac 100644 --- a/src/zenml/models/v2/base/scoped.py +++ b/src/zenml/models/v2/base/scoped.py @@ -17,13 +17,9 @@ TYPE_CHECKING, Any, ClassVar, - Dict, Generic, - List, Optional, - Type, TypeVar, - Union, ) from uuid import UUID @@ -60,7 +56,7 @@ class UserScopedRequest(BaseRequest): Used as a base class for all domain models that are "owned" by a user. """ - user: Optional[UUID] = Field( + user: UUID | None = Field( default=None, title="The id of the user that created this resource. Set " "automatically by the server.", @@ -69,7 +65,7 @@ class UserScopedRequest(BaseRequest): exclude=True, ) - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """Fetches the analytics metadata for user scoped models. Returns: @@ -88,7 +84,7 @@ class ProjectScopedRequest(UserScopedRequest): project: UUID = Field(title="The project to which this resource belongs.") - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """Fetches the analytics metadata for project scoped models. Returns: @@ -106,7 +102,7 @@ def get_analytics_metadata(self) -> Dict[str, Any]: class UserScopedResponseBody(BaseDatedResponseBody): """Base user-owned body.""" - user_id: Optional[UUID] = Field(title="The user id.", default=None) + user_id: UUID | None = Field(title="The user id.", default=None) class UserScopedResponseMetadata(BaseResponseMetadata): @@ -136,7 +132,7 @@ class UserScopedResponse( """ # Analytics - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """Fetches the analytics metadata for user scoped models. Returns: @@ -149,7 +145,7 @@ def get_analytics_metadata(self) -> Dict[str, Any]: # Body and metadata properties @property - def user_id(self) -> Optional[UUID]: + def user_id(self) -> UUID | None: """The user ID property. 
Returns: @@ -170,25 +166,25 @@ def user(self) -> Optional["UserResponse"]: class UserScopedFilter(BaseFilter): """Model to enable advanced user-based scoping.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *BaseFilter.FILTER_EXCLUDE_FIELDS, "user", "scope_user", ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *BaseFilter.CLI_EXCLUDE_FIELDS, "scope_user", ] - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + CUSTOM_SORTING_OPTIONS: ClassVar[list[str]] = [ *BaseFilter.CUSTOM_SORTING_OPTIONS, "user", ] - scope_user: Optional[UUID] = Field( + scope_user: UUID | None = Field( default=None, description="The user to scope this query to.", ) - user: Optional[Union[UUID, str]] = Field( + user: UUID | str | None = Field( default=None, description="Name/ID of the user that created the entity.", union_mode="left_to_right", @@ -203,8 +199,8 @@ def set_scope_user(self, user_id: UUID) -> None: self.scope_user = user_id def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List["ColumnElement[bool]"]: + self, table: type["AnySchema"] + ) -> list["ColumnElement[bool]"]: """Get custom filters. Args: @@ -235,7 +231,7 @@ def get_custom_filters( def apply_sorting( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Apply sorting to the query. @@ -275,7 +271,7 @@ def apply_sorting( def apply_filter( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Applies the filter to a query. @@ -330,7 +326,7 @@ class ProjectScopedResponse( """ # Analytics - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """Fetches the analytics metadata for project scoped models. 
Returns: @@ -368,11 +364,11 @@ def project(self) -> "ProjectResponse": class ProjectScopedFilter(UserScopedFilter): """Model to enable advanced scoping with project.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *UserScopedFilter.FILTER_EXCLUDE_FIELDS, "project", ] - project: Optional[Union[UUID, str]] = Field( + project: UUID | str | None = Field( default=None, description="Name/ID of the project which the search is scoped to. " "This field must always be set and is always applied in addition to " @@ -384,7 +380,7 @@ class ProjectScopedFilter(UserScopedFilter): def apply_filter( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Applies the filter to a query. @@ -428,26 +424,26 @@ def apply_filter( class TaggableFilter(BaseFilter): """Model to enable filtering and sorting by tags.""" - tag: Optional[str] = Field( + tag: str | None = Field( description="Tag to apply to the filter query.", default=None ) - tags: Optional[List[str]] = Field( + tags: list[str] | None = Field( description="Tags to apply to the filter query.", default=None ) CLI_EXCLUDE_FIELDS = [ *BaseFilter.CLI_EXCLUDE_FIELDS, ] - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *BaseFilter.FILTER_EXCLUDE_FIELDS, "tag", "tags", ] - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + CUSTOM_SORTING_OPTIONS: ClassVar[list[str]] = [ *BaseFilter.CUSTOM_SORTING_OPTIONS, "tags", ] - API_MULTI_INPUT_PARAMS: ClassVar[List[str]] = [ + API_MULTI_INPUT_PARAMS: ClassVar[list[str]] = [ *BaseFilter.API_MULTI_INPUT_PARAMS, "tags", ] @@ -476,7 +472,7 @@ def add_tag_to_tags(self) -> "TaggableFilter": def apply_filter( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Applies the filter to a query. 
@@ -500,8 +496,8 @@ def apply_filter( return query def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List["ColumnElement[bool]"]: + self, table: type["AnySchema"] + ) -> list["ColumnElement[bool]"]: """Get custom tag filters. Args: @@ -535,7 +531,7 @@ def get_custom_filters( def apply_sorting( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Apply sorting to the query. @@ -617,15 +613,15 @@ def apply_sorting( class RunMetadataFilterMixin(BaseFilter): """Model to enable filtering and sorting by run metadata.""" - run_metadata: Optional[List[str]] = Field( + run_metadata: list[str] | None = Field( default=None, description="The run_metadata to filter the pipeline runs by.", ) - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *BaseFilter.FILTER_EXCLUDE_FIELDS, "run_metadata", ] - API_MULTI_INPUT_PARAMS: ClassVar[List[str]] = [ + API_MULTI_INPUT_PARAMS: ClassVar[list[str]] = [ *BaseFilter.API_MULTI_INPUT_PARAMS, "run_metadata", ] @@ -674,8 +670,8 @@ def validate_run_metadata_format(self) -> "RunMetadataFilterMixin": return self def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List["ColumnElement[bool]"]: + self, table: type["AnySchema"] + ) -> list["ColumnElement[bool]"]: """Get custom run metadata filters. 
Args: diff --git a/src/zenml/models/v2/core/action.py b/src/zenml/models/v2/core/action.py index 037b800b2d8..e6dc49e775c 100644 --- a/src/zenml/models/v2/core/action.py +++ b/src/zenml/models/v2/core/action.py @@ -17,8 +17,6 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - Optional, TypeVar, ) from uuid import UUID @@ -66,13 +64,13 @@ class ActionRequest(ProjectScopedRequest): title="The subtype of the action.", max_length=STR_FIELD_MAX_LENGTH, ) - configuration: Dict[str, Any] = Field( + configuration: dict[str, Any] = Field( title="The configuration for the action.", ) service_account_id: UUID = Field( title="The service account that is used to execute the action.", ) - auth_window: Optional[int] = Field( + auth_window: int | None = Field( default=None, title="The time window in minutes for which the service account is " "authorized to execute the action. Set this to 0 to authorize the " @@ -87,25 +85,25 @@ class ActionRequest(ProjectScopedRequest): class ActionUpdate(BaseUpdate): """Update model for actions.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, title="The new name for the action.", max_length=STR_FIELD_MAX_LENGTH, ) - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The new description for the action.", max_length=STR_FIELD_MAX_LENGTH, ) - configuration: Optional[Dict[str, Any]] = Field( + configuration: dict[str, Any] | None = Field( default=None, title="The configuration for the action.", ) - service_account_id: Optional[UUID] = Field( + service_account_id: UUID | None = Field( default=None, title="The service account that is used to execute the action.", ) - auth_window: Optional[int] = Field( + auth_window: int | None = Field( default=None, title="The time window in minutes for which the service account is " "authorized to execute the action. 
Set this to 0 to authorize the " @@ -152,7 +150,7 @@ class ActionResponseMetadata(ProjectScopedResponseMetadata): title="The description of the action.", max_length=STR_FIELD_MAX_LENGTH, ) - configuration: Dict[str, Any] = Field( + configuration: dict[str, Any] = Field( title="The configuration for the action.", ) auth_window: int = Field( @@ -229,7 +227,7 @@ def auth_window(self) -> int: return self.get_metadata().auth_window @property - def configuration(self) -> Dict[str, Any]: + def configuration(self) -> dict[str, Any]: """The `configuration` property. Returns: @@ -237,7 +235,7 @@ def configuration(self) -> Dict[str, Any]: """ return self.get_metadata().configuration - def set_configuration(self, configuration: Dict[str, Any]) -> None: + def set_configuration(self, configuration: dict[str, Any]) -> None: """Set the `configuration` property. Args: @@ -262,15 +260,15 @@ def service_account(self) -> "UserResponse": class ActionFilter(ProjectScopedFilter): """Model to enable advanced filtering of all actions.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the action.", ) - flavor: Optional[str] = Field( + flavor: str | None = Field( default=None, title="The flavor of the action.", ) - plugin_subtype: Optional[str] = Field( + plugin_subtype: str | None = Field( default=None, title="The subtype of the action.", ) diff --git a/src/zenml/models/v2/core/action_flavor.py b/src/zenml/models/v2/core/action_flavor.py index b331e1efed3..6f3818163b6 100644 --- a/src/zenml/models/v2/core/action_flavor.py +++ b/src/zenml/models/v2/core/action_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Action flavor model definitions.""" -from typing import Any, Dict +from typing import Any from zenml.models.v2.base.base_plugin_flavor import ( BasePluginFlavorResponse, @@ -30,7 +30,7 @@ class ActionFlavorResponseBody(BasePluginResponseBody): class ActionFlavorResponseMetadata(BasePluginResponseMetadata): """Response metadata for action flavors.""" - config_schema: Dict[str, Any] + config_schema: dict[str, Any] class ActionFlavorResponseResources(BasePluginResponseResources): @@ -48,7 +48,7 @@ class ActionFlavorResponse( # Body and metadata properties @property - def config_schema(self) -> Dict[str, Any]: + def config_schema(self) -> dict[str, Any]: """The `source_config_schema` property. Returns: diff --git a/src/zenml/models/v2/core/api_key.py b/src/zenml/models/v2/core/api_key.py index 43a928228c0..f94a871f7dd 100644 --- a/src/zenml/models/v2/core/api_key.py +++ b/src/zenml/models/v2/core/api_key.py @@ -14,7 +14,7 @@ """Models representing API keys.""" from datetime import datetime, timedelta -from typing import TYPE_CHECKING, ClassVar, List, Optional, Type, Union +from typing import TYPE_CHECKING, ClassVar from uuid import UUID from pydantic import BaseModel, Field @@ -89,7 +89,7 @@ class APIKeyRequest(BaseRequest): max_length=STR_FIELD_MAX_LENGTH, ) - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The description of the API Key.", max_length=TEXT_FIELD_MAX_LENGTH, @@ -112,17 +112,17 @@ class APIKeyRotateRequest(BaseRequest): class APIKeyUpdate(BaseUpdate): """Update model for API keys.""" - name: Optional[str] = Field( + name: str | None = Field( title="The name of the API Key.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - description: Optional[str] = Field( + description: str | None = Field( title="The description of the API Key.", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - active: Optional[bool] = Field( + active: bool | None = Field( title="Whether the API key is active.", default=None, ) @@ 
-143,7 +143,7 @@ class APIKeyInternalUpdate(APIKeyUpdate): class APIKeyResponseBody(BaseDatedResponseBody): """Response body for API keys.""" - key: Optional[str] = Field( + key: str | None = Field( default=None, title="The API key. Only set immediately after creation or rotation.", ) @@ -168,10 +168,10 @@ class APIKeyResponseMetadata(BaseResponseMetadata): title="Number of minutes for which the previous key is still valid " "after it has been rotated.", ) - last_login: Optional[datetime] = Field( + last_login: datetime | None = Field( default=None, title="Time when the API key was last used to log in." ) - last_rotated: Optional[datetime] = Field( + last_rotated: datetime | None = Field( default=None, title="Time when the API key was last rotated." ) @@ -218,7 +218,7 @@ def set_key(self, key: str) -> None: # Body and metadata properties @property - def key(self) -> Optional[str]: + def key(self) -> str | None: """The `key` property. Returns: @@ -263,7 +263,7 @@ def retain_period_minutes(self) -> int: return self.get_metadata().retain_period_minutes @property - def last_login(self) -> Optional[datetime]: + def last_login(self) -> datetime | None: """The `last_login` property. Returns: @@ -272,7 +272,7 @@ def last_login(self) -> Optional[datetime]: return self.get_metadata().last_login @property - def last_rotated(self) -> Optional[datetime]: + def last_rotated(self) -> datetime | None: """The `last_rotated` property. Returns: @@ -284,7 +284,7 @@ def last_rotated(self) -> Optional[datetime]: class APIKeyInternalResponse(APIKeyResponse): """Response model for API keys used internally.""" - previous_key: Optional[str] = Field( + previous_key: str | None = Field( default=None, title="The previous API key. 
Only set if the key was rotated.", ) @@ -306,7 +306,7 @@ def verify_key( # even when the hashed key is not set, we still want to execute # the hash verification to protect against response discrepancy # attacks (https://cwe.mitre.org/data/definitions/204.html) - key_hash: Optional[str] = None + key_hash: str | None = None context = CryptContext(schemes=["bcrypt"], deprecated="auto") if self.key is not None and self.active: key_hash = self.key @@ -338,38 +338,38 @@ def verify_key( class APIKeyFilter(BaseFilter): """Filter model for API keys.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *BaseFilter.FILTER_EXCLUDE_FIELDS, "service_account", ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *BaseFilter.CLI_EXCLUDE_FIELDS, "service_account", ] - service_account: Optional[UUID] = Field( + service_account: UUID | None = Field( default=None, description="The service account to scope this query to.", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the API key", ) - description: Optional[str] = Field( + description: str | None = Field( default=None, title="Filter by the API key description.", ) - active: Optional[Union[bool, str]] = Field( + active: bool | str | None = Field( default=None, title="Whether the API key is active.", union_mode="left_to_right", ) - last_login: Optional[Union[datetime, str]] = Field( + last_login: datetime | str | None = Field( default=None, title="Time when the API key was last used to log in.", union_mode="left_to_right", ) - last_rotated: Optional[Union[datetime, str]] = Field( + last_rotated: datetime | str | None = Field( default=None, title="Time when the API key was last rotated.", union_mode="left_to_right", @@ -386,7 +386,7 @@ def set_service_account(self, service_account_id: UUID) -> None: def apply_filter( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> 
AnyQuery: """Override to apply the service account scope as an additional filter. diff --git a/src/zenml/models/v2/core/api_transaction.py b/src/zenml/models/v2/core/api_transaction.py index 99c2c0f886f..4a03bbb01e7 100644 --- a/src/zenml/models/v2/core/api_transaction.py +++ b/src/zenml/models/v2/core/api_transaction.py @@ -15,7 +15,6 @@ from typing import ( TYPE_CHECKING, - Optional, TypeVar, ) from uuid import UUID @@ -60,7 +59,7 @@ class ApiTransactionRequest(UserScopedRequest): class ApiTransactionUpdate(BaseUpdate): """Update model for stack components.""" - result: Optional[PlainSerializedSecretStr] = Field( + result: PlainSerializedSecretStr | None = Field( default=None, title="The response payload.", ) @@ -69,7 +68,7 @@ class ApiTransactionUpdate(BaseUpdate): "completion." ) - def get_result(self) -> Optional[str]: + def get_result(self) -> str | None: """Get the result of the API transaction. Returns: @@ -104,7 +103,7 @@ class ApiTransactionResponseBody(UserScopedResponseBody): completed: bool = Field( title="Whether the transaction is completed.", ) - result: Optional[PlainSerializedSecretStr] = Field( + result: PlainSerializedSecretStr | None = Field( default=None, title="The response payload.", ) @@ -165,7 +164,7 @@ def completed(self) -> bool: return self.get_body().completed @property - def result(self) -> Optional[PlainSerializedSecretStr]: + def result(self) -> PlainSerializedSecretStr | None: """The `result` property. Returns: @@ -173,7 +172,7 @@ def result(self) -> Optional[PlainSerializedSecretStr]: """ return self.get_body().result - def get_result(self) -> Optional[str]: + def get_result(self) -> str | None: """Get the result of the API transaction. 
Returns: diff --git a/src/zenml/models/v2/core/artifact.py b/src/zenml/models/v2/core/artifact.py index 8cc178f9033..37f62dcccaa 100644 --- a/src/zenml/models/v2/core/artifact.py +++ b/src/zenml/models/v2/core/artifact.py @@ -17,10 +17,6 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - List, - Optional, - Type, TypeVar, ) from uuid import UUID @@ -62,7 +58,7 @@ class ArtifactRequest(ProjectScopedRequest): title="Whether the name is custom (True) or auto-generated (False).", default=False, ) - tags: Optional[List[str]] = Field( + tags: list[str] | None = Field( title="Artifact tags.", description="Should be a list of plain strings, e.g., ['tag1', 'tag2']", default=None, @@ -75,10 +71,10 @@ class ArtifactRequest(ProjectScopedRequest): class ArtifactUpdate(BaseUpdate): """Artifact update model.""" - name: Optional[str] = None - add_tags: Optional[List[str]] = None - remove_tags: Optional[List[str]] = None - has_custom_name: Optional[bool] = None + name: str | None = None + add_tags: list[str] | None = None + remove_tags: list[str] | None = None + has_custom_name: bool | None = None # ------------------ Response Model ------------------ @@ -100,12 +96,12 @@ class ArtifactResponseMetadata(ProjectScopedResponseMetadata): class ArtifactResponseResources(ProjectScopedResponseResources): """Class for all resource models associated with the Artifact Entity.""" - tags: List[TagResponse] = Field( + tags: list[TagResponse] = Field( title="Tags associated with the artifact.", ) # TODO: maybe move these back to body or figure out a better solution - latest_version_name: Optional[str] = None - latest_version_id: Optional[UUID] = None + latest_version_name: str | None = None + latest_version_id: UUID | None = None class ArtifactResponse( @@ -134,7 +130,7 @@ def get_hydrated_version(self) -> "ArtifactResponse": # Body and metadata properties @property - def tags(self) -> List[TagResponse]: + def tags(self) -> list[TagResponse]: """The `tags` property. 
Returns: @@ -143,7 +139,7 @@ def tags(self) -> List[TagResponse]: return self.get_resources().tags @property - def latest_version_name(self) -> Optional[str]: + def latest_version_name(self) -> str | None: """The `latest_version_name` property. Returns: @@ -152,7 +148,7 @@ def latest_version_name(self) -> Optional[str]: return self.get_resources().latest_version_name @property - def latest_version_id(self) -> Optional[UUID]: + def latest_version_id(self) -> UUID | None: """The `latest_version_id` property. Returns: @@ -171,7 +167,7 @@ def has_custom_name(self) -> bool: # Helper methods @property - def versions(self) -> Dict[str, "ArtifactVersionResponse"]: + def versions(self) -> dict[str, "ArtifactVersionResponse"]: """Get a list of all versions of this artifact. Returns: @@ -189,29 +185,29 @@ def versions(self) -> Dict[str, "ArtifactVersionResponse"]: class ArtifactFilter(ProjectScopedFilter, TaggableFilter): """Model to enable advanced filtering of artifacts.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS, *TaggableFilter.FILTER_EXCLUDE_FIELDS, ] - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + CUSTOM_SORTING_OPTIONS: ClassVar[list[str]] = [ *ProjectScopedFilter.CUSTOM_SORTING_OPTIONS, *TaggableFilter.CUSTOM_SORTING_OPTIONS, SORT_BY_LATEST_VERSION_KEY, ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.CLI_EXCLUDE_FIELDS, *TaggableFilter.CLI_EXCLUDE_FIELDS, ] - name: Optional[str] = None - has_custom_name: Optional[bool] = None + name: str | None = None + has_custom_name: bool | None = None def apply_sorting( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Apply sorting to the query for Artifacts. 
diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index fcc5a0f753a..5668ec8e605 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -17,10 +17,6 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - List, - Optional, - Type, TypeVar, Union, ) @@ -75,15 +71,15 @@ class ArtifactVersionRequest(ProjectScopedRequest): """Request model for artifact versions.""" - artifact_id: Optional[UUID] = Field( + artifact_id: UUID | None = Field( default=None, title="ID of the artifact to which this version belongs.", ) - artifact_name: Optional[str] = Field( + artifact_name: str | None = Field( default=None, title="Name of the artifact to which this version belongs.", ) - version: Optional[Union[int, str]] = Field( + version: int | str | None = Field( default=None, title="Version of the artifact." ) has_custom_name: bool = Field( @@ -91,7 +87,7 @@ class ArtifactVersionRequest(ProjectScopedRequest): default=False, ) type: ArtifactType = Field(title="Type of the artifact.") - artifact_store_id: Optional[UUID] = Field( + artifact_store_id: UUID | None = Field( title="ID of the artifact store in which this artifact is stored.", default=None, ) @@ -104,23 +100,23 @@ class ArtifactVersionRequest(ProjectScopedRequest): data_type: SourceWithValidator = Field( title="Data type of the artifact.", ) - content_hash: Optional[str] = Field( + content_hash: str | None = Field( title="The content hash of the artifact version.", default=None, max_length=STR_FIELD_MAX_LENGTH, ) - tags: Optional[List[str]] = Field( + tags: list[str] | None = Field( title="Tags of the artifact.", description="Should be a list of plain strings, e.g., ['tag1', 'tag2']", default=None, ) - visualizations: Optional[List["ArtifactVisualizationRequest"]] = Field( + visualizations: list["ArtifactVisualizationRequest"] | None = Field( default=None, title="Visualizations of the artifact." 
) save_type: ArtifactSaveType = Field( title="The save type of the artifact version.", ) - metadata: Optional[Dict[str, MetadataType]] = Field( + metadata: dict[str, MetadataType] | None = Field( default=None, title="Metadata of the artifact version." ) @@ -174,9 +170,9 @@ def _validate_request(self) -> "ArtifactVersionRequest": class ArtifactVersionUpdate(BaseUpdate): """Artifact version update model.""" - name: Optional[str] = None - add_tags: Optional[List[str]] = None - remove_tags: Optional[List[str]] = None + name: str | None = None + add_tags: list[str] | None = None + remove_tags: list[str] | None = None # ------------------ Response Model ------------------ @@ -202,11 +198,11 @@ class ArtifactVersionResponseBody(ProjectScopedResponseBody): save_type: ArtifactSaveType = Field( title="The save type of the artifact version.", ) - artifact_store_id: Optional[UUID] = Field( + artifact_store_id: UUID | None = Field( title="ID of the artifact store in which this artifact is stored.", default=None, ) - content_hash: Optional[str] = Field( + content_hash: str | None = Field( title="The content hash of the artifact version.", default=None, ) @@ -236,10 +232,10 @@ def str_field_max_length_check(cls, value: Any) -> Any: class ArtifactVersionResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for artifact versions.""" - visualizations: Optional[List["ArtifactVisualizationResponse"]] = Field( + visualizations: list["ArtifactVisualizationResponse"] | None = Field( default=None, title="Visualizations of the artifact." ) - run_metadata: Dict[str, MetadataType] = Field( + run_metadata: dict[str, MetadataType] = Field( default={}, title="Metadata of the artifact." 
) @@ -247,14 +243,14 @@ class ArtifactVersionResponseMetadata(ProjectScopedResponseMetadata): class ArtifactVersionResponseResources(ProjectScopedResponseResources): """Class for all resource models associated with the artifact version entity.""" - tags: List[TagResponse] = Field( + tags: list[TagResponse] = Field( title="Tags associated with the artifact version.", ) - producer_step_run_id: Optional[UUID] = Field( + producer_step_run_id: UUID | None = Field( title="ID of the step run that produced this artifact.", default=None, ) - producer_pipeline_run_id: Optional[UUID] = Field( + producer_pipeline_run_id: UUID | None = Field( title="The ID of the pipeline run that generated this artifact version.", default=None, ) @@ -317,7 +313,7 @@ def type(self) -> ArtifactType: return self.get_body().type @property - def content_hash(self) -> Optional[str]: + def content_hash(self) -> str | None: """The `content_hash` property. Returns: @@ -326,7 +322,7 @@ def content_hash(self) -> Optional[str]: return self.get_body().content_hash @property - def tags(self) -> List[TagResponse]: + def tags(self) -> list[TagResponse]: """The `tags` property. Returns: @@ -335,7 +331,7 @@ def tags(self) -> List[TagResponse]: return self.get_resources().tags @property - def producer_pipeline_run_id(self) -> Optional[UUID]: + def producer_pipeline_run_id(self) -> UUID | None: """The `producer_pipeline_run_id` property. Returns: @@ -353,7 +349,7 @@ def save_type(self) -> ArtifactSaveType: return self.get_body().save_type @property - def artifact_store_id(self) -> Optional[UUID]: + def artifact_store_id(self) -> UUID | None: """The `artifact_store_id` property. Returns: @@ -362,7 +358,7 @@ def artifact_store_id(self) -> Optional[UUID]: return self.get_body().artifact_store_id @property - def producer_step_run_id(self) -> Optional[UUID]: + def producer_step_run_id(self) -> UUID | None: """The `producer_step_run_id` property. 
Returns: @@ -373,7 +369,7 @@ def producer_step_run_id(self) -> Optional[UUID]: @property def visualizations( self, - ) -> Optional[List["ArtifactVisualizationResponse"]]: + ) -> list["ArtifactVisualizationResponse"] | None: """The `visualizations` property. Returns: @@ -382,7 +378,7 @@ def visualizations( return self.get_metadata().visualizations @property - def run_metadata(self) -> Dict[str, MetadataType]: + def run_metadata(self) -> dict[str, MetadataType]: """The `metadata` property. Returns: @@ -476,7 +472,7 @@ def download_files(self, path: str, overwrite: bool = False) -> None: overwrite=overwrite, ) - def visualize(self, title: Optional[str] = None) -> None: + def visualize(self, title: str | None = None) -> None: """Visualize the artifact in notebook environments. Args: @@ -495,7 +491,7 @@ class ArtifactVersionFilter( ): """Model to enable advanced filtering of artifact versions.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS, *TaggableFilter.FILTER_EXCLUDE_FIELDS, *RunMetadataFilterMixin.FILTER_EXCLUDE_FIELDS, @@ -507,24 +503,24 @@ class ArtifactVersionFilter( "pipeline_run", "model_version_id", ] - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + CUSTOM_SORTING_OPTIONS: ClassVar[list[str]] = [ *ProjectScopedFilter.CUSTOM_SORTING_OPTIONS, *TaggableFilter.CUSTOM_SORTING_OPTIONS, *RunMetadataFilterMixin.CUSTOM_SORTING_OPTIONS, ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.CLI_EXCLUDE_FIELDS, *TaggableFilter.CLI_EXCLUDE_FIELDS, *RunMetadataFilterMixin.CLI_EXCLUDE_FIELDS, "artifact_id", ] - API_MULTI_INPUT_PARAMS: ClassVar[List[str]] = [ + API_MULTI_INPUT_PARAMS: ClassVar[list[str]] = [ *ProjectScopedFilter.API_MULTI_INPUT_PARAMS, *TaggableFilter.API_MULTI_INPUT_PARAMS, *RunMetadataFilterMixin.API_MULTI_INPUT_PARAMS, ] - artifact: Optional[Union[UUID, str]] = Field( + artifact: UUID | str | None 
= Field( default=None, description="The name or ID of the artifact which the search is scoped " "to. This field must always be set and is always applied in addition " @@ -532,60 +528,60 @@ class ArtifactVersionFilter( "logical_operator field.", union_mode="left_to_right", ) - artifact_id: Optional[Union[UUID, str]] = Field( + artifact_id: UUID | str | None = Field( default=None, description="[Deprecated] Use 'artifact' instead. ID of the artifact to which this version belongs.", union_mode="left_to_right", ) - version: Optional[str] = Field( + version: str | None = Field( default=None, description="Version of the artifact", ) - version_number: Optional[Union[int, str]] = Field( + version_number: int | str | None = Field( default=None, description="Version of the artifact if it is an integer", union_mode="left_to_right", ) - uri: Optional[str] = Field( + uri: str | None = Field( default=None, description="Uri of the artifact", ) - materializer: Optional[str] = Field( + materializer: str | None = Field( default=None, description="Materializer used to produce the artifact", ) - type: Optional[str] = Field( + type: str | None = Field( default=None, description="Type of the artifact", ) - data_type: Optional[str] = Field( + data_type: str | None = Field( default=None, description="Datatype of the artifact", ) - artifact_store_id: Optional[Union[UUID, str]] = Field( + artifact_store_id: UUID | str | None = Field( default=None, description="Artifact store for this artifact", union_mode="left_to_right", ) - model_version_id: Optional[Union[UUID, str]] = Field( + model_version_id: UUID | str | None = Field( default=None, description="ID of the model version that is associated with this " "artifact version.", union_mode="left_to_right", ) - only_unused: Optional[bool] = Field( + only_unused: bool | None = Field( default=False, description="Filter only for unused artifacts" ) - has_custom_name: Optional[bool] = Field( + has_custom_name: bool | None = Field( default=None, 
description="Filter only artifacts with/without custom names.", ) - model: Optional[Union[UUID, str]] = Field( + model: UUID | str | None = Field( default=None, description="Name/ID of the model that is associated with this " "artifact version.", ) - pipeline_run: Optional[Union[UUID, str]] = Field( + pipeline_run: UUID | str | None = Field( default=None, description="Name/ID of a pipeline run that is associated with this " "artifact version.", @@ -594,8 +590,8 @@ class ArtifactVersionFilter( model_config = ConfigDict(protected_namespaces=()) def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List[Union["ColumnElement[bool]"]]: + self, table: type["AnySchema"] + ) -> list[Union["ColumnElement[bool]"]]: """Get custom filters. Args: @@ -728,11 +724,11 @@ class LazyArtifactVersionResponse(ArtifactVersionResponse): a pipeline context available only during pipeline compilation. """ - id: Optional[UUID] = None # type: ignore[assignment] - lazy_load_name: Optional[str] = None - lazy_load_version: Optional[str] = None + id: UUID | None = None # type: ignore[assignment] + lazy_load_name: str | None = None + lazy_load_version: str | None = None lazy_load_model_name: str - lazy_load_model_version: Optional[str] = None + lazy_load_model_version: str | None = None def get_body(self) -> None: # type: ignore[override] """Protects from misuse of the lazy loader. @@ -753,7 +749,7 @@ def get_metadata(self) -> None: # type: ignore[override] ) @property - def run_metadata(self) -> Dict[str, MetadataType]: + def run_metadata(self) -> dict[str, MetadataType]: """The `metadata` property in lazy loading mode. Returns: diff --git a/src/zenml/models/v2/core/code_repository.py b/src/zenml/models/v2/core/code_repository.py index a235ce8b6a4..44666fd94bd 100644 --- a/src/zenml/models/v2/core/code_repository.py +++ b/src/zenml/models/v2/core/code_repository.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Models representing code repositories.""" -from typing import Any, Dict, Optional +from typing import Any from pydantic import Field @@ -39,16 +39,16 @@ class CodeRepositoryRequest(ProjectScopedRequest): title="The name of the code repository.", max_length=STR_FIELD_MAX_LENGTH, ) - config: Dict[str, Any] = Field( + config: dict[str, Any] = Field( description="Configuration for the code repository." ) source: Source = Field(description="The code repository source.") - logo_url: Optional[str] = Field( + logo_url: str | None = Field( description="Optional URL of a logo (png, jpg or svg) for the " "code repository.", default=None, ) - description: Optional[str] = Field( + description: str | None = Field( description="Code repository description.", max_length=TEXT_FIELD_MAX_LENGTH, default=None, @@ -61,24 +61,24 @@ class CodeRepositoryRequest(ProjectScopedRequest): class CodeRepositoryUpdate(BaseUpdate): """Update model for code repositories.""" - name: Optional[str] = Field( + name: str | None = Field( title="The name of the code repository.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - config: Optional[Dict[str, Any]] = Field( + config: dict[str, Any] | None = Field( description="Configuration for the code repository.", default=None, ) - source: Optional[SourceWithValidator] = Field( + source: SourceWithValidator | None = Field( description="The code repository source.", default=None ) - logo_url: Optional[str] = Field( + logo_url: str | None = Field( description="Optional URL of a logo (png, jpg or svg) for the " "code repository.", default=None, ) - description: Optional[str] = Field( + description: str | None = Field( description="Code repository description.", max_length=TEXT_FIELD_MAX_LENGTH, default=None, @@ -92,7 +92,7 @@ class CodeRepositoryResponseBody(ProjectScopedResponseBody): """Response body for code repositories.""" source: Source = Field(description="The code repository source.") - logo_url: Optional[str] = Field( + logo_url: str | None = 
Field( default=None, description="Optional URL of a logo (png, jpg or svg) for the " "code repository.", @@ -102,10 +102,10 @@ class CodeRepositoryResponseBody(ProjectScopedResponseBody): class CodeRepositoryResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for code repositories.""" - config: Dict[str, Any] = Field( + config: dict[str, Any] = Field( description="Configuration for the code repository." ) - description: Optional[str] = Field( + description: str | None = Field( default=None, description="Code repository description.", max_length=TEXT_FIELD_MAX_LENGTH, @@ -151,7 +151,7 @@ def source(self) -> Source: return self.get_body().source @property - def logo_url(self) -> Optional[str]: + def logo_url(self) -> str | None: """The `logo_url` property. Returns: @@ -160,7 +160,7 @@ def logo_url(self) -> Optional[str]: return self.get_body().logo_url @property - def config(self) -> Dict[str, Any]: + def config(self) -> dict[str, Any]: """The `config` property. Returns: @@ -169,7 +169,7 @@ def config(self) -> Dict[str, Any]: return self.get_metadata().config @property - def description(self) -> Optional[str]: + def description(self) -> str | None: """The `description` property. 
Returns: @@ -184,7 +184,7 @@ def description(self) -> Optional[str]: class CodeRepositoryFilter(ProjectScopedFilter): """Model to enable advanced filtering of all code repositories.""" - name: Optional[str] = Field( + name: str | None = Field( description="Name of the code repository.", default=None, ) diff --git a/src/zenml/models/v2/core/component.py b/src/zenml/models/v2/core/component.py index d3b2f4ff25e..d719ba7cdb0 100644 --- a/src/zenml/models/v2/core/component.py +++ b/src/zenml/models/v2/core/component.py @@ -17,10 +17,7 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - List, Optional, - Type, TypeVar, Union, ) @@ -68,26 +65,26 @@ class ComponentBase(BaseModel): title="The flavor of the stack component.", max_length=STR_FIELD_MAX_LENGTH, ) - environment: Optional[Dict[str, str]] = Field( + environment: dict[str, str] | None = Field( default=None, title="Environment variables to set when running on this component.", ) - secrets: Optional[List[Union[UUID, str]]] = Field( + secrets: list[UUID | str] | None = Field( default=None, title="Secrets to set as environment variables when running on this component.", ) - configuration: Dict[str, Any] = Field( + configuration: dict[str, Any] = Field( title="The stack component configuration.", ) - connector_resource_id: Optional[str] = Field( + connector_resource_id: str | None = Field( default=None, description="The ID of a specific resource instance to " "gain access to through the connector", ) - labels: Optional[Dict[str, Any]] = Field( + labels: dict[str, Any] | None = Field( default=None, title="The stack component labels.", ) @@ -99,9 +96,9 @@ class ComponentBase(BaseModel): class ComponentRequest(ComponentBase, UserScopedRequest): """Request model for stack components.""" - ANALYTICS_FIELDS: ClassVar[List[str]] = ["type", "flavor"] + ANALYTICS_FIELDS: ClassVar[list[str]] = ["type", "flavor"] - connector: Optional[UUID] = Field( + connector: UUID | None = Field( default=None, title="The service connector linked to 
this stack component.", ) @@ -138,37 +135,37 @@ class DefaultComponentRequest(ComponentRequest): class ComponentUpdate(BaseUpdate): """Update model for stack components.""" - name: Optional[str] = Field( + name: str | None = Field( title="The name of the stack component.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - configuration: Optional[Dict[str, Any]] = Field( + configuration: dict[str, Any] | None = Field( title="The stack component configuration.", default=None, ) - environment: Optional[Dict[str, str]] = Field( + environment: dict[str, str] | None = Field( default=None, title="Environment variables to set when running on this component.", ) - connector_resource_id: Optional[str] = Field( + connector_resource_id: str | None = Field( description="The ID of a specific resource instance to " "gain access to through the connector", default=None, ) - labels: Optional[Dict[str, Any]] = Field( + labels: dict[str, Any] | None = Field( title="The stack component labels.", default=None, ) - connector: Optional[UUID] = Field( + connector: UUID | None = Field( title="The service connector linked to this stack component.", default=None, ) - add_secrets: Optional[List[Union[UUID, str]]] = Field( + add_secrets: list[UUID | str] | None = Field( default=None, title="New secrets to add to the stack component.", ) - remove_secrets: Optional[List[Union[UUID, str]]] = Field( + remove_secrets: list[UUID | str] | None = Field( default=None, title="Secrets to remove from the stack component.", ) @@ -187,13 +184,13 @@ class ComponentResponseBody(UserScopedResponseBody): title="The flavor of the stack component.", max_length=STR_FIELD_MAX_LENGTH, ) - integration: Optional[str] = Field( + integration: str | None = Field( default=None, title="The name of the integration that the component's flavor " "belongs to.", max_length=STR_FIELD_MAX_LENGTH, ) - logo_url: Optional[str] = Field( + logo_url: str | None = Field( default=None, title="Optionally, a url pointing to a png," "svg or 
jpg can be attached.", @@ -203,23 +200,23 @@ class ComponentResponseBody(UserScopedResponseBody): class ComponentResponseMetadata(UserScopedResponseMetadata): """Response metadata for stack components.""" - configuration: Dict[str, Any] = Field( + configuration: dict[str, Any] = Field( title="The stack component configuration.", ) - environment: Dict[str, str] = Field( + environment: dict[str, str] = Field( default={}, title="Environment variables to set when running on this component.", ) - secrets: List[UUID] = Field( + secrets: list[UUID] = Field( default=[], title="Secrets to set as environment variables when running on this " "component.", ) - labels: Optional[Dict[str, Any]] = Field( + labels: dict[str, Any] | None = Field( default=None, title="The stack component labels.", ) - connector_resource_id: Optional[str] = Field( + connector_resource_id: str | None = Field( default=None, description="The ID of a specific resource instance to " "gain access to through the connector", @@ -247,14 +244,14 @@ class ComponentResponse( ): """Response model for stack components.""" - ANALYTICS_FIELDS: ClassVar[List[str]] = ["type"] + ANALYTICS_FIELDS: ClassVar[list[str]] = ["type"] name: str = Field( title="The name of the stack component.", max_length=STR_FIELD_MAX_LENGTH, ) - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """Add the component labels to analytics metadata. Returns: @@ -304,7 +301,7 @@ def flavor_name(self) -> str: return self.get_body().flavor_name @property - def integration(self) -> Optional[str]: + def integration(self) -> str | None: """The `integration` property. Returns: @@ -313,7 +310,7 @@ def integration(self) -> Optional[str]: return self.get_body().integration @property - def logo_url(self) -> Optional[str]: + def logo_url(self) -> str | None: """The `logo_url` property. 
Returns: @@ -322,7 +319,7 @@ def logo_url(self) -> Optional[str]: return self.get_body().logo_url @property - def configuration(self) -> Dict[str, Any]: + def configuration(self) -> dict[str, Any]: """The `configuration` property. Returns: @@ -331,7 +328,7 @@ def configuration(self) -> Dict[str, Any]: return self.get_metadata().configuration @property - def environment(self) -> Dict[str, str]: + def environment(self) -> dict[str, str]: """The `environment` property. Returns: @@ -340,7 +337,7 @@ def environment(self) -> Dict[str, str]: return self.get_metadata().environment @property - def secrets(self) -> List[UUID]: + def secrets(self) -> list[UUID]: """The `secrets` property. Returns: @@ -349,7 +346,7 @@ def secrets(self) -> List[UUID]: return self.get_metadata().secrets @property - def labels(self) -> Optional[Dict[str, Any]]: + def labels(self) -> dict[str, Any] | None: """The `labels` property. Returns: @@ -358,7 +355,7 @@ def labels(self) -> Optional[Dict[str, Any]]: return self.get_metadata().labels @property - def connector_resource_id(self) -> Optional[str]: + def connector_resource_id(self) -> str | None: """The `connector_resource_id` property. 
Returns: @@ -391,37 +388,37 @@ def flavor(self) -> "FlavorResponse": class ComponentFilter(UserScopedFilter): """Model to enable advanced stack component filtering.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *UserScopedFilter.FILTER_EXCLUDE_FIELDS, "scope_type", "stack_id", ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *UserScopedFilter.CLI_EXCLUDE_FIELDS, "scope_type", ] - scope_type: Optional[str] = Field( + scope_type: str | None = Field( default=None, description="The type to scope this query to.", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the stack component", ) - flavor: Optional[str] = Field( + flavor: str | None = Field( default=None, description="Flavor of the stack component", ) - type: Optional[str] = Field( + type: str | None = Field( default=None, description="Type of the stack component", ) - connector_id: Optional[Union[UUID, str]] = Field( + connector_id: UUID | str | None = Field( default=None, description="Connector linked to the stack component", union_mode="left_to_right", ) - stack_id: Optional[Union[UUID, str]] = Field( + stack_id: UUID | str | None = Field( default=None, description="Stack of the stack component", union_mode="left_to_right", @@ -436,7 +433,7 @@ def set_scope_type(self, component_type: str) -> None: self.scope_type = component_type def generate_filter( - self, table: Type["AnySchema"] + self, table: type["AnySchema"] ) -> Union["ColumnElement[bool]"]: """Generate the filter for the query. diff --git a/src/zenml/models/v2/core/curated_visualization.py b/src/zenml/models/v2/core/curated_visualization.py index 6da41b4eaa2..217c9fc8b95 100644 --- a/src/zenml/models/v2/core/curated_visualization.py +++ b/src/zenml/models/v2/core/curated_visualization.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Models representing curated visualizations.""" -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from uuid import UUID from pydantic import Field, NonNegativeInt @@ -58,11 +58,11 @@ class CuratedVisualizationRequest(ProjectScopedRequest): "for the target resource." ), ) - display_name: Optional[str] = Field( + display_name: str | None = Field( default=None, title="The display name of the visualization.", ) - display_order: Optional[NonNegativeInt] = Field( + display_order: NonNegativeInt | None = Field( default=None, title="The display order of the visualization.", description=( @@ -98,11 +98,11 @@ class CuratedVisualizationRequest(ProjectScopedRequest): class CuratedVisualizationUpdate(BaseUpdate): """Update model for curated visualizations.""" - display_name: Optional[str] = Field( + display_name: str | None = Field( default=None, title="The new display name of the visualization.", ) - display_order: Optional[NonNegativeInt] = Field( + display_order: NonNegativeInt | None = Field( default=None, title="The new display order of the visualization.", description=( @@ -110,7 +110,7 @@ class CuratedVisualizationUpdate(BaseUpdate): "the combination of resource type and resource ID." ), ) - layout_size: Optional[CuratedVisualizationSize] = Field( + layout_size: CuratedVisualizationSize | None = Field( default=None, title="The updated layout size of the visualization.", ) @@ -135,11 +135,11 @@ class CuratedVisualizationResponseBody(ProjectScopedResponseBody): "Provided for read-only context when available." 
), ) - display_name: Optional[str] = Field( + display_name: str | None = Field( default=None, title="The display name of the visualization.", ) - display_order: Optional[NonNegativeInt] = Field( + display_order: NonNegativeInt | None = Field( default=None, title="The display order of the visualization.", description=( @@ -216,7 +216,7 @@ def artifact_version_id(self) -> UUID: return self.get_body().artifact_version_id @property - def display_name(self) -> Optional[str]: + def display_name(self) -> str | None: """The display name of the visualization. Returns: @@ -225,7 +225,7 @@ def display_name(self) -> Optional[str]: return self.get_body().display_name @property - def display_order(self) -> Optional[int]: + def display_order(self) -> int | None: """The display order of the visualization. Returns: diff --git a/src/zenml/models/v2/core/deployment.py b/src/zenml/models/v2/core/deployment.py index d08cd30ed58..126b453203e 100644 --- a/src/zenml/models/v2/core/deployment.py +++ b/src/zenml/models/v2/core/deployment.py @@ -17,12 +17,8 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - List, Optional, - Type, TypeVar, - Union, ) from uuid import UUID @@ -64,8 +60,8 @@ class DeploymentOperationalState(BaseModel): """Operational state of a deployment.""" status: DeploymentStatus = Field(default=DeploymentStatus.UNKNOWN) - url: Optional[str] = None - metadata: Optional[Dict[str, Any]] = None + url: str | None = None + metadata: dict[str, Any] | None = None # ------------------ Request Model ------------------ @@ -88,12 +84,12 @@ class DeploymentRequest(ProjectScopedRequest): title="The deployer ID.", description="The ID of the deployer component managing this deployment.", ) - auth_key: Optional[str] = Field( + auth_key: str | None = Field( default=None, title="The auth key of the deployment.", description="The auth key of the deployment.", ) - tags: Optional[List[Union[str, Tag]]] = Field( + tags: list[str | Tag] | None = Field( default=None, title="Tags of the deployment.", ) @@ 
-105,35 +101,35 @@ class DeploymentRequest(ProjectScopedRequest): class DeploymentUpdate(BaseUpdate): """Update model for deployments.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, title="The new name of the deployment.", max_length=STR_FIELD_MAX_LENGTH, ) - snapshot_id: Optional[UUID] = Field( + snapshot_id: UUID | None = Field( default=None, title="New pipeline snapshot ID.", ) - url: Optional[str] = Field( + url: str | None = Field( default=None, title="The new URL of the deployment.", ) - status: Optional[DeploymentStatus] = Field( + status: DeploymentStatus | None = Field( default=None, title="The new status of the deployment.", ) - deployment_metadata: Optional[Dict[str, Any]] = Field( + deployment_metadata: dict[str, Any] | None = Field( default=None, title="The new metadata of the deployment.", ) - auth_key: Optional[str] = Field( + auth_key: str | None = Field( default=None, title="The new auth key of the deployment.", ) - add_tags: Optional[List[str]] = Field( + add_tags: list[str] | None = Field( default=None, title="New tags to add to the deployment." ) - remove_tags: Optional[List[str]] = Field( + remove_tags: list[str] | None = Field( default=None, title="Tags to remove from the deployment." 
) @@ -162,12 +158,12 @@ def from_operational_state( class DeploymentResponseBody(ProjectScopedResponseBody): """Response body for deployments.""" - url: Optional[str] = Field( + url: str | None = Field( default=None, title="The URL of the deployment.", description="The HTTP URL where the deployment can be accessed.", ) - status: Optional[DeploymentStatus] = Field( + status: DeploymentStatus | None = Field( default=None, title="The status of the deployment.", description="Current operational status of the deployment.", @@ -177,10 +173,10 @@ class DeploymentResponseBody(ProjectScopedResponseBody): class DeploymentResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for deployments.""" - deployment_metadata: Dict[str, Any] = Field( + deployment_metadata: dict[str, Any] = Field( title="The metadata of the deployment.", ) - auth_key: Optional[str] = Field( + auth_key: str | None = Field( default=None, title="The auth key of the deployment.", description="The auth key of the deployment.", @@ -205,10 +201,10 @@ class DeploymentResponseResources(ProjectScopedResponseResources): title="The pipeline.", description="The pipeline being deployed.", ) - tags: List["TagResponse"] = Field( + tags: list["TagResponse"] = Field( title="Tags associated with the deployment.", ) - visualizations: List["CuratedVisualizationResponse"] = Field( + visualizations: list["CuratedVisualizationResponse"] = Field( default_factory=list, title="Curated deployment visualizations.", ) @@ -242,7 +238,7 @@ def get_hydrated_version(self) -> "DeploymentResponse": # Helper properties @property - def url(self) -> Optional[str]: + def url(self) -> str | None: """The URL of the deployment. Returns: @@ -251,7 +247,7 @@ def url(self) -> Optional[str]: return self.get_body().url @property - def status(self) -> Optional[DeploymentStatus]: + def status(self) -> DeploymentStatus | None: """The status of the deployment. 
Returns: @@ -260,7 +256,7 @@ def status(self) -> Optional[DeploymentStatus]: return self.get_body().status @property - def deployment_metadata(self) -> Dict[str, Any]: + def deployment_metadata(self) -> dict[str, Any]: """The metadata of the deployment. Returns: @@ -269,7 +265,7 @@ def deployment_metadata(self) -> Dict[str, Any]: return self.get_metadata().deployment_metadata @property - def auth_key(self) -> Optional[str]: + def auth_key(self) -> str | None: """The auth key of the deployment. Returns: @@ -304,7 +300,7 @@ def pipeline(self) -> Optional["PipelineResponse"]: """ return self.get_resources().pipeline - def tags(self) -> List["TagResponse"]: + def tags(self) -> list["TagResponse"]: """The tags of the deployment. Returns: @@ -313,7 +309,7 @@ def tags(self) -> List["TagResponse"]: return self.get_resources().tags @property - def visualizations(self) -> List["CuratedVisualizationResponse"]: + def visualizations(self) -> list["CuratedVisualizationResponse"]: """The visualizations of the deployment. Returns: @@ -322,7 +318,7 @@ def visualizations(self) -> List["CuratedVisualizationResponse"]: return self.get_resources().visualizations @property - def snapshot_id(self) -> Optional[UUID]: + def snapshot_id(self) -> UUID | None: """The pipeline snapshot ID. Returns: @@ -334,7 +330,7 @@ def snapshot_id(self) -> Optional[UUID]: return None @property - def deployer_id(self) -> Optional[UUID]: + def deployer_id(self) -> UUID | None: """The deployer ID. 
Returns: @@ -352,13 +348,13 @@ def deployer_id(self) -> Optional[UUID]: class DeploymentFilter(ProjectScopedFilter, TaggableFilter): """Model to enable advanced filtering of deployments.""" - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + CUSTOM_SORTING_OPTIONS: ClassVar[list[str]] = [ *ProjectScopedFilter.CUSTOM_SORTING_OPTIONS, *TaggableFilter.CUSTOM_SORTING_OPTIONS, "snapshot", "pipeline", ] - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS, *TaggableFilter.FILTER_EXCLUDE_FIELDS, "pipeline", @@ -368,37 +364,37 @@ class DeploymentFilter(ProjectScopedFilter, TaggableFilter): *TaggableFilter.CLI_EXCLUDE_FIELDS, ] - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the deployment.", ) - url: Optional[str] = Field( + url: str | None = Field( default=None, description="URL of the deployment.", ) - status: Optional[str] = Field( + status: str | None = Field( default=None, description="Status of the deployment.", ) - pipeline: Optional[Union[UUID, str]] = Field( + pipeline: UUID | str | None = Field( default=None, description="Pipeline associated with the deployment.", union_mode="left_to_right", ) - snapshot_id: Optional[Union[UUID, str]] = Field( + snapshot_id: UUID | str | None = Field( default=None, description="Pipeline snapshot ID associated with the deployment.", union_mode="left_to_right", ) - deployer_id: Optional[Union[UUID, str]] = Field( + deployer_id: UUID | str | None = Field( default=None, description="Deployer ID managing the deployment.", union_mode="left_to_right", ) def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List["ColumnElement[bool]"]: + self, table: type["AnySchema"] + ) -> list["ColumnElement[bool]"]: """Get custom filters. 
Args: @@ -432,7 +428,7 @@ def get_custom_filters( def apply_sorting( self, query: "AnyQuery", - table: Type["AnySchema"], + table: type["AnySchema"], ) -> "AnyQuery": """Apply sorting to the query. diff --git a/src/zenml/models/v2/core/device.py b/src/zenml/models/v2/core/device.py index 5292e36b9dc..1f567dcdf1a 100644 --- a/src/zenml/models/v2/core/device.py +++ b/src/zenml/models/v2/core/device.py @@ -14,7 +14,6 @@ """Models representing devices.""" from datetime import datetime -from typing import Optional, Union from uuid import UUID from pydantic import Field @@ -43,36 +42,36 @@ class OAuthDeviceInternalRequest(BaseRequest): description="The number of seconds after which the OAuth2 device " "expires and can no longer be used for authentication." ) - os: Optional[str] = Field( + os: str | None = Field( default=None, description="The operating system of the device used for " "authentication.", ) - ip_address: Optional[str] = Field( + ip_address: str | None = Field( default=None, description="The IP address of the device used for authentication.", ) - hostname: Optional[str] = Field( + hostname: str | None = Field( default=None, description="The hostname of the device used for authentication.", ) - python_version: Optional[str] = Field( + python_version: str | None = Field( default=None, description="The Python version of the device used for authentication.", ) - zenml_version: Optional[str] = Field( + zenml_version: str | None = Field( default=None, description="The ZenML version of the device used for authentication.", ) - city: Optional[str] = Field( + city: str | None = Field( default=None, description="The city where the device is located.", ) - region: Optional[str] = Field( + region: str | None = Field( default=None, description="The region where the device is located.", ) - country: Optional[str] = Field( + country: str | None = Field( default=None, description="The country where the device is located.", ) @@ -84,7 +83,7 @@ class 
OAuthDeviceInternalRequest(BaseRequest): class OAuthDeviceUpdate(BaseUpdate): """OAuth2 device update model.""" - locked: Optional[bool] = Field( + locked: bool | None = Field( default=None, description="Whether to lock or unlock the OAuth2 device. A locked " "device cannot be used for authentication.", @@ -94,22 +93,22 @@ class OAuthDeviceUpdate(BaseUpdate): class OAuthDeviceInternalUpdate(OAuthDeviceUpdate): """OAuth2 device update model used internally for authentication.""" - user_id: Optional[UUID] = Field( + user_id: UUID | None = Field( default=None, description="User that owns the OAuth2 device." ) - status: Optional[OAuthDeviceStatus] = Field( + status: OAuthDeviceStatus | None = Field( default=None, description="The new status of the OAuth2 device." ) - expires_in: Optional[int] = Field( + expires_in: int | None = Field( default=None, description="Set the device to expire in the given number of seconds. " "If the value is 0 or negative, the device is set to never expire.", ) - failed_auth_attempts: Optional[int] = Field( + failed_auth_attempts: int | None = Field( default=None, description="Set the number of failed authentication attempts.", ) - trusted_device: Optional[bool] = Field( + trusted_device: bool | None = Field( default=None, description="Whether to mark the OAuth2 device as trusted. 
A trusted " "device has a much longer validity time.", @@ -121,36 +120,36 @@ class OAuthDeviceInternalUpdate(OAuthDeviceUpdate): default=False, description="Whether to generate new user and device codes.", ) - os: Optional[str] = Field( + os: str | None = Field( default=None, description="The operating system of the device used for " "authentication.", ) - ip_address: Optional[str] = Field( + ip_address: str | None = Field( default=None, description="The IP address of the device used for authentication.", ) - hostname: Optional[str] = Field( + hostname: str | None = Field( default=None, description="The hostname of the device used for authentication.", ) - python_version: Optional[str] = Field( + python_version: str | None = Field( default=None, description="The Python version of the device used for authentication.", ) - zenml_version: Optional[str] = Field( + zenml_version: str | None = Field( default=None, description="The ZenML version of the device used for authentication.", ) - city: Optional[str] = Field( + city: str | None = Field( default=None, description="The city where the device is located.", ) - region: Optional[str] = Field( + region: str | None = Field( default=None, description="The region where the device is located.", ) - country: Optional[str] = Field( + country: str | None = Field( default=None, description="The country where the device is located.", ) @@ -163,7 +162,7 @@ class OAuthDeviceResponseBody(UserScopedResponseBody): """Response body for OAuth2 devices.""" client_id: UUID = Field(description="The client ID of the OAuth2 device.") - expires: Optional[datetime] = Field( + expires: datetime | None = Field( default=None, description="The expiration date of the OAuth2 device after which " "the device is no longer valid and cannot be used for " @@ -176,16 +175,16 @@ class OAuthDeviceResponseBody(UserScopedResponseBody): status: OAuthDeviceStatus = Field( description="The status of the OAuth2 device." 
) - os: Optional[str] = Field( + os: str | None = Field( default=None, description="The operating system of the device used for " "authentication.", ) - ip_address: Optional[str] = Field( + ip_address: str | None = Field( default=None, description="The IP address of the device used for authentication.", ) - hostname: Optional[str] = Field( + hostname: str | None = Field( default=None, description="The hostname of the device used for authentication.", ) @@ -194,30 +193,30 @@ class OAuthDeviceResponseBody(UserScopedResponseBody): class OAuthDeviceResponseMetadata(UserScopedResponseMetadata): """Response metadata for OAuth2 devices.""" - python_version: Optional[str] = Field( + python_version: str | None = Field( default=None, description="The Python version of the device used for authentication.", ) - zenml_version: Optional[str] = Field( + zenml_version: str | None = Field( default=None, description="The ZenML version of the device used for authentication.", ) - city: Optional[str] = Field( + city: str | None = Field( default=None, description="The city where the device is located.", ) - region: Optional[str] = Field( + region: str | None = Field( default=None, description="The region where the device is located.", ) - country: Optional[str] = Field( + country: str | None = Field( default=None, description="The country where the device is located.", ) failed_auth_attempts: int = Field( description="The number of failed authentication attempts.", ) - last_login: Optional[datetime] = Field( + last_login: datetime | None = Field( description="The date of the last successful login." ) @@ -258,7 +257,7 @@ def client_id(self) -> UUID: return self.get_body().client_id @property - def expires(self) -> Optional[datetime]: + def expires(self) -> datetime | None: """The `expires` property. 
Returns: @@ -285,7 +284,7 @@ def status(self) -> OAuthDeviceStatus: return self.get_body().status @property - def os(self) -> Optional[str]: + def os(self) -> str | None: """The `os` property. Returns: @@ -294,7 +293,7 @@ def os(self) -> Optional[str]: return self.get_body().os @property - def ip_address(self) -> Optional[str]: + def ip_address(self) -> str | None: """The `ip_address` property. Returns: @@ -303,7 +302,7 @@ def ip_address(self) -> Optional[str]: return self.get_body().ip_address @property - def hostname(self) -> Optional[str]: + def hostname(self) -> str | None: """The `hostname` property. Returns: @@ -312,7 +311,7 @@ def hostname(self) -> Optional[str]: return self.get_body().hostname @property - def python_version(self) -> Optional[str]: + def python_version(self) -> str | None: """The `python_version` property. Returns: @@ -321,7 +320,7 @@ def python_version(self) -> Optional[str]: return self.get_metadata().python_version @property - def zenml_version(self) -> Optional[str]: + def zenml_version(self) -> str | None: """The `zenml_version` property. Returns: @@ -330,7 +329,7 @@ def zenml_version(self) -> Optional[str]: return self.get_metadata().zenml_version @property - def city(self) -> Optional[str]: + def city(self) -> str | None: """The `city` property. Returns: @@ -339,7 +338,7 @@ def city(self) -> Optional[str]: return self.get_metadata().city @property - def region(self) -> Optional[str]: + def region(self) -> str | None: """The `region` property. Returns: @@ -348,7 +347,7 @@ def region(self) -> Optional[str]: return self.get_metadata().region @property - def country(self) -> Optional[str]: + def country(self) -> str | None: """The `country` property. Returns: @@ -366,7 +365,7 @@ def failed_auth_attempts(self) -> int: return self.get_metadata().failed_auth_attempts @property - def last_login(self) -> Optional[datetime]: + def last_login(self) -> datetime | None: """The `last_login` property. 
Returns: @@ -388,7 +387,7 @@ class OAuthDeviceInternalResponse(OAuthDeviceResponse): def _verify_code( self, code: str, - code_hash: Optional[str], + code_hash: str | None, ) -> bool: """Verifies a given code against the stored (hashed) code. @@ -441,32 +440,32 @@ def verify_device_code( class OAuthDeviceFilter(UserScopedFilter): """Model to enable advanced filtering of OAuth2 devices.""" - expires: Optional[Union[datetime, str, None]] = Field( + expires: datetime | str | None = Field( default=None, description="The expiration date of the OAuth2 device.", union_mode="left_to_right", ) - client_id: Union[UUID, str, None] = Field( + client_id: UUID | str | None = Field( default=None, description="The client ID of the OAuth2 device.", union_mode="left_to_right", ) - status: Union[OAuthDeviceStatus, str, None] = Field( + status: OAuthDeviceStatus | str | None = Field( default=None, description="The status of the OAuth2 device.", union_mode="left_to_right", ) - trusted_device: Union[bool, str, None] = Field( + trusted_device: bool | str | None = Field( default=None, description="Whether the OAuth2 device was marked as trusted.", union_mode="left_to_right", ) - failed_auth_attempts: Union[int, str, None] = Field( + failed_auth_attempts: int | str | None = Field( default=None, description="The number of failed authentication attempts.", union_mode="left_to_right", ) - last_login: Optional[Union[datetime, str, None]] = Field( + last_login: datetime | str | None = Field( default=None, description="The date of the last successful login.", union_mode="left_to_right", diff --git a/src/zenml/models/v2/core/event_source.py b/src/zenml/models/v2/core/event_source.py index 91b21cecc6e..d4c429325f1 100644 --- a/src/zenml/models/v2/core/event_source.py +++ b/src/zenml/models/v2/core/event_source.py @@ -14,7 +14,7 @@ """Collection of all models concerning event configurations.""" import copy -from typing import Any, Dict, Optional +from typing import Any from pydantic
import Field @@ -55,7 +55,7 @@ class EventSourceRequest(ProjectScopedRequest): max_length=STR_FIELD_MAX_LENGTH, ) - configuration: Dict[str, Any] = Field( + configuration: dict[str, Any] = Field( title="The event source configuration.", ) @@ -66,21 +66,21 @@ class EventSourceRequest(ProjectScopedRequest): class EventSourceUpdate(BaseUpdate): """Update model for event sources.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, title="The updated name of the event source.", max_length=STR_FIELD_MAX_LENGTH, ) - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The updated description of the event source.", max_length=STR_FIELD_MAX_LENGTH, ) - configuration: Optional[Dict[str, Any]] = Field( + configuration: dict[str, Any] | None = Field( default=None, title="The updated event source configuration.", ) - is_active: Optional[bool] = Field( + is_active: bool | None = Field( default=None, title="The status of the event source.", ) @@ -131,7 +131,7 @@ class EventSourceResponseMetadata(ProjectScopedResponseMetadata): title="The description of the event source.", max_length=STR_FIELD_MAX_LENGTH, ) - configuration: Dict[str, Any] = Field( + configuration: dict[str, Any] = Field( title="The event source configuration.", ) @@ -206,7 +206,7 @@ def description(self) -> str: return self.get_metadata().description @property - def configuration(self) -> Dict[str, Any]: + def configuration(self) -> dict[str, Any]: """The `configuration` property. Returns: @@ -214,7 +214,7 @@ def configuration(self) -> Dict[str, Any]: """ return self.get_metadata().configuration - def set_configuration(self, configuration: Dict[str, Any]) -> None: + def set_configuration(self, configuration: dict[str, Any]) -> None: """Set the `configuration` property. 
Args: @@ -229,15 +229,15 @@ def set_configuration(self, configuration: Dict[str, Any]) -> None: class EventSourceFilter(ProjectScopedFilter): """Model to enable advanced filtering of all EventSourceModels.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the event source", ) - flavor: Optional[str] = Field( + flavor: str | None = Field( default=None, description="Flavor of the event source", ) - plugin_subtype: Optional[str] = Field( + plugin_subtype: str | None = Field( default=None, title="The plugin sub type of the event source.", max_length=STR_FIELD_MAX_LENGTH, diff --git a/src/zenml/models/v2/core/event_source_flavor.py b/src/zenml/models/v2/core/event_source_flavor.py index 11db12d74f7..01ffcd294b9 100644 --- a/src/zenml/models/v2/core/event_source_flavor.py +++ b/src/zenml/models/v2/core/event_source_flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing event source flavors..""" -from typing import Any, Dict +from typing import Any from zenml.models.v2.base.base_plugin_flavor import ( BasePluginFlavorResponse, @@ -30,8 +30,8 @@ class EventSourceFlavorResponseBody(BasePluginResponseBody): class EventSourceFlavorResponseMetadata(BasePluginResponseMetadata): """Response metadata for event flavors.""" - source_config_schema: Dict[str, Any] - filter_config_schema: Dict[str, Any] + source_config_schema: dict[str, Any] + filter_config_schema: dict[str, Any] class EventSourceFlavorResponseResources(BasePluginResponseResources): @@ -49,7 +49,7 @@ class EventSourceFlavorResponse( # Body and metadata properties @property - def source_config_schema(self) -> Dict[str, Any]: + def source_config_schema(self) -> dict[str, Any]: """The `source_config_schema` property. 
Returns: @@ -58,7 +58,7 @@ def source_config_schema(self) -> Dict[str, Any]: return self.get_metadata().source_config_schema @property - def filter_config_schema(self) -> Dict[str, Any]: + def filter_config_schema(self) -> dict[str, Any]: """The `filter_config_schema` property. Returns: diff --git a/src/zenml/models/v2/core/flavor.py b/src/zenml/models/v2/core/flavor.py index 331d4861ba5..58b83696fb5 100644 --- a/src/zenml/models/v2/core/flavor.py +++ b/src/zenml/models/v2/core/flavor.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing flavors.""" -from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Optional from pydantic import Field @@ -40,7 +40,7 @@ class FlavorRequest(UserScopedRequest): """Request model for stack component flavors.""" - ANALYTICS_FIELDS: ClassVar[List[str]] = [ + ANALYTICS_FIELDS: ClassVar[list[str]] = [ "type", "integration", ] @@ -50,20 +50,20 @@ class FlavorRequest(UserScopedRequest): max_length=STR_FIELD_MAX_LENGTH, ) type: StackComponentType = Field(title="The type of the Flavor.") - config_schema: Dict[str, Any] = Field( + config_schema: dict[str, Any] = Field( title="The JSON schema of this flavor's corresponding configuration.", ) - connector_type: Optional[str] = Field( + connector_type: str | None = Field( default=None, title="The type of the connector that this flavor uses.", max_length=STR_FIELD_MAX_LENGTH, ) - connector_resource_type: Optional[str] = Field( + connector_resource_type: str | None = Field( default=None, title="The resource type of the connector that this flavor uses.", max_length=STR_FIELD_MAX_LENGTH, ) - connector_resource_id_attr: Optional[str] = Field( + connector_resource_id_attr: str | None = Field( default=None, title="The name of an attribute in the stack component configuration " "that plays the role of resource ID when linked to a service " @@ -74,20 +74,20 @@ class FlavorRequest(UserScopedRequest): 
title="The path to the module which contains this Flavor.", max_length=STR_FIELD_MAX_LENGTH, ) - integration: Optional[str] = Field( + integration: str | None = Field( title="The name of the integration that the Flavor belongs to.", max_length=STR_FIELD_MAX_LENGTH, ) - logo_url: Optional[str] = Field( + logo_url: str | None = Field( default=None, title="Optionally, a url pointing to a png," "svg or jpg can be attached.", ) - docs_url: Optional[str] = Field( + docs_url: str | None = Field( default=None, title="Optionally, a url pointing to docs, within docs.zenml.io.", ) - sdk_docs_url: Optional[str] = Field( + sdk_docs_url: str | None = Field( default=None, title="Optionally, a url pointing to SDK docs," "within sdkdocs.zenml.io.", @@ -104,60 +104,60 @@ class FlavorRequest(UserScopedRequest): class FlavorUpdate(BaseUpdate): """Update model for stack component flavors.""" - name: Optional[str] = Field( + name: str | None = Field( title="The name of the Flavor.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - type: Optional[StackComponentType] = Field( + type: StackComponentType | None = Field( title="The type of the Flavor.", default=None ) - config_schema: Optional[Dict[str, Any]] = Field( + config_schema: dict[str, Any] | None = Field( title="The JSON schema of this flavor's corresponding configuration.", default=None, ) - connector_type: Optional[str] = Field( + connector_type: str | None = Field( title="The type of the connector that this flavor uses.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - connector_resource_type: Optional[str] = Field( + connector_resource_type: str | None = Field( title="The resource type of the connector that this flavor uses.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - connector_resource_id_attr: Optional[str] = Field( + connector_resource_id_attr: str | None = Field( title="The name of an attribute in the stack component configuration " "that plays the role of resource ID when linked to a service " "connector.", 
max_length=STR_FIELD_MAX_LENGTH, default=None, ) - source: Optional[str] = Field( + source: str | None = Field( title="The path to the module which contains this Flavor.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - integration: Optional[str] = Field( + integration: str | None = Field( title="The name of the integration that the Flavor belongs to.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - logo_url: Optional[str] = Field( + logo_url: str | None = Field( title="Optionally, a url pointing to a png," "svg or jpg can be attached.", default=None, ) - docs_url: Optional[str] = Field( + docs_url: str | None = Field( title="Optionally, a url pointing to docs, within docs.zenml.io.", default=None, ) - sdk_docs_url: Optional[str] = Field( + sdk_docs_url: str | None = Field( title="Optionally, a url pointing to SDK docs," "within sdkdocs.zenml.io.", default=None, ) - is_custom: Optional[bool] = Field( + is_custom: bool | None = Field( title="Whether or not this flavor is a custom, user created flavor.", default=None, ) @@ -170,7 +170,7 @@ class FlavorResponseBody(UserScopedResponseBody): """Response body for stack component flavors.""" type: StackComponentType = Field(title="The type of the Flavor.") - integration: Optional[str] = Field( + integration: str | None = Field( title="The name of the integration that the Flavor belongs to.", max_length=STR_FIELD_MAX_LENGTH, ) @@ -178,7 +178,7 @@ class FlavorResponseBody(UserScopedResponseBody): title="The path to the module which contains this Flavor.", max_length=STR_FIELD_MAX_LENGTH, ) - logo_url: Optional[str] = Field( + logo_url: str | None = Field( default=None, title="Optionally, a url pointing to a png," "svg or jpg can be attached.", @@ -192,31 +192,31 @@ class FlavorResponseBody(UserScopedResponseBody): class FlavorResponseMetadata(UserScopedResponseMetadata): """Response metadata for stack component flavors.""" - config_schema: Dict[str, Any] = Field( + config_schema: dict[str, Any] = Field( title="The 
JSON schema of this flavor's corresponding configuration.", ) - connector_type: Optional[str] = Field( + connector_type: str | None = Field( default=None, title="The type of the connector that this flavor uses.", max_length=STR_FIELD_MAX_LENGTH, ) - connector_resource_type: Optional[str] = Field( + connector_resource_type: str | None = Field( default=None, title="The resource type of the connector that this flavor uses.", max_length=STR_FIELD_MAX_LENGTH, ) - connector_resource_id_attr: Optional[str] = Field( + connector_resource_id_attr: str | None = Field( default=None, title="The name of an attribute in the stack component configuration " "that plays the role of resource ID when linked to a service " "connector.", max_length=STR_FIELD_MAX_LENGTH, ) - docs_url: Optional[str] = Field( + docs_url: str | None = Field( default=None, title="Optionally, a url pointing to docs, within docs.zenml.io.", ) - sdk_docs_url: Optional[str] = Field( + sdk_docs_url: str | None = Field( default=None, title="Optionally, a url pointing to SDK docs," "within sdkdocs.zenml.io.", @@ -237,7 +237,7 @@ class FlavorResponse( """Response model for stack component flavors.""" # Analytics - ANALYTICS_FIELDS: ClassVar[List[str]] = [ + ANALYTICS_FIELDS: ClassVar[list[str]] = [ "id", "type", "integration", @@ -292,7 +292,7 @@ def type(self) -> StackComponentType: return self.get_body().type @property - def integration(self) -> Optional[str]: + def integration(self) -> str | None: """The `integration` property. Returns: @@ -310,7 +310,7 @@ def source(self) -> str: return self.get_body().source @property - def logo_url(self) -> Optional[str]: + def logo_url(self) -> str | None: """The `logo_url` property. Returns: @@ -328,7 +328,7 @@ def is_custom(self) -> bool: return self.get_body().is_custom @property - def config_schema(self) -> Dict[str, Any]: + def config_schema(self) -> dict[str, Any]: """The `config_schema` property. 
Returns: @@ -337,7 +337,7 @@ def config_schema(self) -> Dict[str, Any]: return self.get_metadata().config_schema @property - def connector_type(self) -> Optional[str]: + def connector_type(self) -> str | None: """The `connector_type` property. Returns: @@ -346,7 +346,7 @@ def connector_type(self) -> Optional[str]: return self.get_metadata().connector_type @property - def connector_resource_type(self) -> Optional[str]: + def connector_resource_type(self) -> str | None: """The `connector_resource_type` property. Returns: @@ -355,7 +355,7 @@ def connector_resource_type(self) -> Optional[str]: return self.get_metadata().connector_resource_type @property - def connector_resource_id_attr(self) -> Optional[str]: + def connector_resource_id_attr(self) -> str | None: """The `connector_resource_id_attr` property. Returns: @@ -364,7 +364,7 @@ def connector_resource_id_attr(self) -> Optional[str]: return self.get_metadata().connector_resource_id_attr @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """The `docs_url` property. Returns: @@ -373,7 +373,7 @@ def docs_url(self) -> Optional[str]: return self.get_metadata().docs_url @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """The `sdk_docs_url` property. 
Returns: @@ -388,15 +388,15 @@ def sdk_docs_url(self) -> Optional[str]: class FlavorFilter(UserScopedFilter): """Model to enable advanced stack component flavor filtering.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the flavor", ) - type: Optional[str] = Field( + type: str | None = Field( default=None, description="Stack Component Type of the stack flavor", ) - integration: Optional[str] = Field( + integration: str | None = Field( default=None, description="Integration associated with the flavor", ) diff --git a/src/zenml/models/v2/core/logs.py b/src/zenml/models/v2/core/logs.py index ae67c8dc635..060224dba18 100644 --- a/src/zenml/models/v2/core/logs.py +++ b/src/zenml/models/v2/core/logs.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing logs.""" -from typing import Any, Optional +from typing import Any from uuid import UUID from pydantic import Field, field_validator @@ -85,12 +85,12 @@ class LogsResponseBody(BaseDatedResponseBody): class LogsResponseMetadata(BaseResponseMetadata): """Response metadata for logs.""" - step_run_id: Optional[UUID] = Field( + step_run_id: UUID | None = Field( title="Step ID to associate the logs with.", default=None, description="When this is set, pipeline_run_id should be set to None.", ) - pipeline_run_id: Optional[UUID] = Field( + pipeline_run_id: UUID | None = Field( title="Pipeline run ID to associate the logs with.", default=None, description="When this is set, step_run_id should be set to None.", @@ -141,7 +141,7 @@ def source(self) -> str: return self.get_body().source @property - def step_run_id(self) -> Optional[UUID]: + def step_run_id(self) -> UUID | None: """The `step_run_id` property. Returns: @@ -150,7 +150,7 @@ def step_run_id(self) -> Optional[UUID]: return self.get_metadata().step_run_id @property - def pipeline_run_id(self) -> Optional[UUID]: + def pipeline_run_id(self) -> UUID | None: """The `pipeline_run_id` property. 
Returns: diff --git a/src/zenml/models/v2/core/model.py b/src/zenml/models/v2/core/model.py index 68ff7f9582d..c5b0dce3ff2 100644 --- a/src/zenml/models/v2/core/model.py +++ b/src/zenml/models/v2/core/model.py @@ -17,9 +17,6 @@ TYPE_CHECKING, Any, ClassVar, - List, - Optional, - Type, TypeVar, ) from uuid import UUID @@ -65,42 +62,42 @@ class ModelRequest(ProjectScopedRequest): title="The name of the model", max_length=STR_FIELD_MAX_LENGTH, ) - license: Optional[str] = Field( + license: str | None = Field( title="The license model created under", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - description: Optional[str] = Field( + description: str | None = Field( title="The description of the model", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - audience: Optional[str] = Field( + audience: str | None = Field( title="The target audience of the model", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - use_cases: Optional[str] = Field( + use_cases: str | None = Field( title="The use cases of the model", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - limitations: Optional[str] = Field( + limitations: str | None = Field( title="The know limitations of the model", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - trade_offs: Optional[str] = Field( + trade_offs: str | None = Field( title="The trade offs of the model", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - ethics: Optional[str] = Field( + ethics: str | None = Field( title="The ethical implications of the model", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - tags: Optional[List[str]] = Field( + tags: list[str] | None = Field( title="Tags associated with the model", default=None, ) @@ -116,17 +113,17 @@ class ModelRequest(ProjectScopedRequest): class ModelUpdate(BaseUpdate): """Update model for models.""" - name: Optional[str] = None - license: Optional[str] = None - description: Optional[str] = None - audience: Optional[str] = None - use_cases: Optional[str] = None - limitations: 
Optional[str] = None - trade_offs: Optional[str] = None - ethics: Optional[str] = None - add_tags: Optional[List[str]] = None - remove_tags: Optional[List[str]] = None - save_models_to_registry: Optional[bool] = None + name: str | None = None + license: str | None = None + description: str | None = None + audience: str | None = None + use_cases: str | None = None + limitations: str | None = None + trade_offs: str | None = None + ethics: str | None = None + add_tags: list[str] | None = None + remove_tags: list[str] | None = None + save_models_to_registry: bool | None = None # ------------------ Response Model ------------------ @@ -139,37 +136,37 @@ class ModelResponseBody(ProjectScopedResponseBody): class ModelResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for models.""" - license: Optional[str] = Field( + license: str | None = Field( title="The license model created under", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - description: Optional[str] = Field( + description: str | None = Field( title="The description of the model", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - audience: Optional[str] = Field( + audience: str | None = Field( title="The target audience of the model", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - use_cases: Optional[str] = Field( + use_cases: str | None = Field( title="The use cases of the model", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - limitations: Optional[str] = Field( + limitations: str | None = Field( title="The know limitations of the model", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - trade_offs: Optional[str] = Field( + trade_offs: str | None = Field( title="The trade offs of the model", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - ethics: Optional[str] = Field( + ethics: str | None = Field( title="The ethical implications of the model", max_length=TEXT_FIELD_MAX_LENGTH, default=None, @@ -183,12 +180,12 @@ class ModelResponseMetadata(ProjectScopedResponseMetadata): 
class ModelResponseResources(ProjectScopedResponseResources): """Class for all resource models associated with the model entity.""" - tags: List["TagResponse"] = Field( + tags: list["TagResponse"] = Field( title="Tags associated with the model", ) - latest_version_name: Optional[str] = None - latest_version_id: Optional[UUID] = None - visualizations: List["CuratedVisualizationResponse"] = Field( + latest_version_name: str | None = None + latest_version_id: UUID | None = None + visualizations: list["CuratedVisualizationResponse"] = Field( default_factory=list, title="Curated visualizations associated with the model.", ) @@ -218,7 +215,7 @@ def get_hydrated_version(self) -> "ModelResponse": # Body and metadata properties @property - def tags(self) -> List["TagResponse"]: + def tags(self) -> list["TagResponse"]: """The `tags` property. Returns: @@ -227,7 +224,7 @@ def tags(self) -> List["TagResponse"]: return self.get_resources().tags @property - def latest_version_name(self) -> Optional[str]: + def latest_version_name(self) -> str | None: """The `latest_version_name` property. Returns: @@ -236,7 +233,7 @@ def latest_version_name(self) -> Optional[str]: return self.get_resources().latest_version_name @property - def latest_version_id(self) -> Optional[UUID]: + def latest_version_id(self) -> UUID | None: """The `latest_version_id` property. Returns: @@ -245,7 +242,7 @@ def latest_version_id(self) -> Optional[UUID]: return self.get_resources().latest_version_id @property - def license(self) -> Optional[str]: + def license(self) -> str | None: """The `license` property. Returns: @@ -254,7 +251,7 @@ def license(self) -> Optional[str]: return self.get_metadata().license @property - def description(self) -> Optional[str]: + def description(self) -> str | None: """The `description` property. 
Returns: @@ -263,7 +260,7 @@ def description(self) -> Optional[str]: return self.get_metadata().description @property - def audience(self) -> Optional[str]: + def audience(self) -> str | None: """The `audience` property. Returns: @@ -272,7 +269,7 @@ def audience(self) -> Optional[str]: return self.get_metadata().audience @property - def use_cases(self) -> Optional[str]: + def use_cases(self) -> str | None: """The `use_cases` property. Returns: @@ -281,7 +278,7 @@ def use_cases(self) -> Optional[str]: return self.get_metadata().use_cases @property - def limitations(self) -> Optional[str]: + def limitations(self) -> str | None: """The `limitations` property. Returns: @@ -290,7 +287,7 @@ def limitations(self) -> Optional[str]: return self.get_metadata().limitations @property - def trade_offs(self) -> Optional[str]: + def trade_offs(self) -> str | None: """The `trade_offs` property. Returns: @@ -299,7 +296,7 @@ def trade_offs(self) -> Optional[str]: return self.get_metadata().trade_offs @property - def ethics(self) -> Optional[str]: + def ethics(self) -> str | None: """The `ethics` property. Returns: @@ -317,7 +314,7 @@ def save_models_to_registry(self) -> bool: return self.get_metadata().save_models_to_registry @property - def visualizations(self) -> List["CuratedVisualizationResponse"]: + def visualizations(self) -> list["CuratedVisualizationResponse"]: """The `visualizations` property. Returns: @@ -327,7 +324,7 @@ def visualizations(self) -> List["CuratedVisualizationResponse"]: # Helper functions @property - def versions(self) -> List["Model"]: + def versions(self) -> list["Model"]: """List all versions of the model. 
Returns: @@ -353,21 +350,21 @@ def versions(self) -> List["Model"]: class ModelFilter(ProjectScopedFilter, TaggableFilter): """Model to enable advanced filtering of all models.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the Model", ) - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS, *TaggableFilter.FILTER_EXCLUDE_FIELDS, ] - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + CUSTOM_SORTING_OPTIONS: ClassVar[list[str]] = [ *ProjectScopedFilter.CUSTOM_SORTING_OPTIONS, *TaggableFilter.CUSTOM_SORTING_OPTIONS, SORT_BY_LATEST_VERSION_KEY, ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.CLI_EXCLUDE_FIELDS, *TaggableFilter.CLI_EXCLUDE_FIELDS, ] @@ -375,7 +372,7 @@ class ModelFilter(ProjectScopedFilter, TaggableFilter): def apply_sorting( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Apply sorting to the query for Models. 
diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index b181a6d6c4f..8f373603f03 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -17,10 +17,7 @@ from typing import ( TYPE_CHECKING, ClassVar, - Dict, - List, Optional, - Type, TypeVar, Union, ) @@ -70,17 +67,17 @@ class ModelVersionRequest(ProjectScopedRequest): """Request model for model versions.""" - name: Optional[str] = Field( + name: str | None = Field( description="The name of the model version", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - description: Optional[str] = Field( + description: str | None = Field( description="The description of the model version", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - stage: Optional[str] = Field( + stage: str | None = Field( description="The stage of the model version", max_length=STR_FIELD_MAX_LENGTH, default=None, @@ -89,7 +86,7 @@ class ModelVersionRequest(ProjectScopedRequest): model: UUID = Field( description="The ID of the model containing version", ) - tags: Optional[List[str]] = Field( + tags: list[str] | None = Field( title="Tags associated with the model version", default=None, ) @@ -101,7 +98,7 @@ class ModelVersionRequest(ProjectScopedRequest): class ModelVersionUpdate(BaseUpdate): """Update model for model versions.""" - stage: Optional[Union[str, ModelStages]] = Field( + stage: str | ModelStages | None = Field( description="Target model version stage to be set", default=None, union_mode="left_to_right", @@ -111,19 +108,19 @@ class ModelVersionUpdate(BaseUpdate): "silently archived or an error should be raised.", default=False, ) - name: Optional[str] = Field( + name: str | None = Field( description="Target model version name to be set", default=None, ) - description: Optional[str] = Field( + description: str | None = Field( description="Target model version description to be set", default=None, ) - add_tags: Optional[List[str]] = Field( + 
add_tags: list[str] | None = Field( description="Tags to be added to the model version", default=None, ) - remove_tags: Optional[List[str]] = Field( + remove_tags: list[str] | None = Field( description="Tags to be removed from the model version", default=None, ) @@ -145,7 +142,7 @@ def _validate_stage(cls, stage: str) -> str: class ModelVersionResponseBody(ProjectScopedResponseBody): """Response body for model versions.""" - stage: Optional[str] = Field( + stage: str | None = Field( description="The stage of the model version", max_length=STR_FIELD_MAX_LENGTH, default=None, @@ -169,12 +166,12 @@ class ModelVersionResponseBody(ProjectScopedResponseBody): class ModelVersionResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for model versions.""" - description: Optional[str] = Field( + description: str | None = Field( description="The description of the model version", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - run_metadata: Dict[str, MetadataType] = Field( + run_metadata: dict[str, MetadataType] = Field( description="Metadata linked to the model version", default={}, ) @@ -186,7 +183,7 @@ class ModelVersionResponseResources(ProjectScopedResponseResources): services: Page[ServiceResponse] = Field( description="Services linked to the model version", ) - tags: List[TagResponse] = Field( + tags: list[TagResponse] = Field( title="Tags associated with the model version", default=[] ) @@ -200,14 +197,14 @@ class ModelVersionResponse( ): """Response model for model versions.""" - name: Optional[str] = Field( + name: str | None = Field( description="The name of the model version", max_length=STR_FIELD_MAX_LENGTH, default=None, ) @property - def stage(self) -> Optional[str]: + def stage(self) -> str | None: """The `stage` property. Returns: @@ -234,7 +231,7 @@ def model(self) -> "ModelResponse": return self.get_body().model @property - def tags(self) -> List[TagResponse]: + def tags(self) -> list[TagResponse]: """The `tags` property. 
Returns: @@ -243,7 +240,7 @@ def tags(self) -> List[TagResponse]: return self.get_resources().tags @property - def description(self) -> Optional[str]: + def description(self) -> str | None: """The `description` property. Returns: @@ -252,7 +249,7 @@ def description(self) -> Optional[str]: return self.get_metadata().description @property - def run_metadata(self) -> Dict[str, MetadataType]: + def run_metadata(self) -> dict[str, MetadataType]: """The `run_metadata` property. Returns: @@ -306,7 +303,7 @@ def to_model_class( @property def model_artifacts( self, - ) -> Dict[str, Dict[str, "ArtifactVersionResponse"]]: + ) -> dict[str, dict[str, "ArtifactVersionResponse"]]: """Get all model artifacts linked to this model version. Returns: @@ -326,7 +323,7 @@ def model_artifacts( project=self.project_id, ) - result: Dict[str, Dict[str, "ArtifactVersionResponse"]] = {} + result: dict[str, dict[str, "ArtifactVersionResponse"]] = {} for artifact_version in artifact_versions: result.setdefault(artifact_version.name, {}) result[artifact_version.name][artifact_version.version] = ( @@ -336,7 +333,7 @@ def model_artifacts( return result @property - def data_artifact_ids(self) -> Dict[str, Dict[str, UUID]]: + def data_artifact_ids(self) -> dict[str, dict[str, UUID]]: """Data artifacts linked to this model version. Returns: @@ -356,7 +353,7 @@ def data_artifact_ids(self) -> Dict[str, Dict[str, UUID]]: } @property - def model_artifact_ids(self) -> Dict[str, Dict[str, UUID]]: + def model_artifact_ids(self) -> dict[str, dict[str, UUID]]: """Model artifacts linked to this model version. Returns: @@ -376,7 +373,7 @@ def model_artifact_ids(self) -> Dict[str, Dict[str, UUID]]: } @property - def deployment_artifact_ids(self) -> Dict[str, Dict[str, UUID]]: + def deployment_artifact_ids(self) -> dict[str, dict[str, UUID]]: """Deployment artifacts linked to this model version. 
Returns: @@ -398,7 +395,7 @@ def deployment_artifact_ids(self) -> Dict[str, Dict[str, UUID]]: @property def data_artifacts( self, - ) -> Dict[str, Dict[str, "ArtifactVersionResponse"]]: + ) -> dict[str, dict[str, "ArtifactVersionResponse"]]: """Get all data artifacts linked to this model version. Returns: @@ -426,7 +423,7 @@ def data_artifacts( project=self.project_id, ) - result: Dict[str, Dict[str, "ArtifactVersionResponse"]] = {} + result: dict[str, dict[str, "ArtifactVersionResponse"]] = {} for artifact_version in artifact_versions: result.setdefault(artifact_version.name, {}) result[artifact_version.name][artifact_version.version] = ( @@ -438,7 +435,7 @@ def data_artifacts( @property def deployment_artifacts( self, - ) -> Dict[str, Dict[str, "ArtifactVersionResponse"]]: + ) -> dict[str, dict[str, "ArtifactVersionResponse"]]: """Get all deployment artifacts linked to this model version. Returns: @@ -459,7 +456,7 @@ def deployment_artifacts( project=self.project_id, ) - result: Dict[str, Dict[str, "ArtifactVersionResponse"]] = {} + result: dict[str, dict[str, "ArtifactVersionResponse"]] = {} for artifact_version in artifact_versions: result.setdefault(artifact_version.name, {}) result[artifact_version.name][artifact_version.version] = ( @@ -469,7 +466,7 @@ def deployment_artifacts( return result @property - def pipeline_run_ids(self) -> Dict[str, UUID]: + def pipeline_run_ids(self) -> dict[str, UUID]: """Pipeline runs linked to this model version. Returns: @@ -491,7 +488,7 @@ def pipeline_run_ids(self) -> Dict[str, UUID]: } @property - def pipeline_runs(self) -> Dict[str, "PipelineRunResponse"]: + def pipeline_runs(self) -> dict[str, "PipelineRunResponse"]: """Get all pipeline runs linked to this version. 
Returns: @@ -515,8 +512,8 @@ def pipeline_runs(self) -> Dict[str, "PipelineRunResponse"]: def _get_linked_object( self, name: str, - version: Optional[str] = None, - type: Optional[ArtifactType] = None, + version: str | None = None, + type: ArtifactType | None = None, ) -> Optional["ArtifactVersionResponse"]: """Get the artifact linked to this model version given type. @@ -549,7 +546,7 @@ def _get_linked_object( def get_artifact( self, name: str, - version: Optional[str] = None, + version: str | None = None, ) -> Optional["ArtifactVersionResponse"]: """Get the artifact linked to this model version. @@ -566,7 +563,7 @@ def get_artifact( def get_model_artifact( self, name: str, - version: Optional[str] = None, + version: str | None = None, ) -> Optional["ArtifactVersionResponse"]: """Get the model artifact linked to this model version. @@ -583,7 +580,7 @@ def get_model_artifact( def get_data_artifact( self, name: str, - version: Optional[str] = None, + version: str | None = None, ) -> Optional["ArtifactVersionResponse"]: """Get the data artifact linked to this model version. @@ -600,7 +597,7 @@ def get_data_artifact( def get_deployment_artifact( self, name: str, - version: Optional[str] = None, + version: str | None = None, ) -> Optional["ArtifactVersionResponse"]: """Get the deployment artifact linked to this model version. @@ -615,7 +612,7 @@ def get_deployment_artifact( return self._get_linked_object(name, version, ArtifactType.SERVICE) def set_stage( - self, stage: Union[str, ModelStages], force: bool = False + self, stage: str | ModelStages, force: bool = False ) -> None: """Sets this Model Version to a desired stage. 
@@ -649,42 +646,42 @@ class ModelVersionFilter( ): """Filter model for model versions.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS, *TaggableFilter.FILTER_EXCLUDE_FIELDS, *RunMetadataFilterMixin.FILTER_EXCLUDE_FIELDS, "model", ] - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + CUSTOM_SORTING_OPTIONS: ClassVar[list[str]] = [ *ProjectScopedFilter.CUSTOM_SORTING_OPTIONS, *TaggableFilter.CUSTOM_SORTING_OPTIONS, *RunMetadataFilterMixin.CUSTOM_SORTING_OPTIONS, ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.CLI_EXCLUDE_FIELDS, *TaggableFilter.CLI_EXCLUDE_FIELDS, *RunMetadataFilterMixin.CLI_EXCLUDE_FIELDS, ] - API_MULTI_INPUT_PARAMS: ClassVar[List[str]] = [ + API_MULTI_INPUT_PARAMS: ClassVar[list[str]] = [ *ProjectScopedFilter.API_MULTI_INPUT_PARAMS, *TaggableFilter.API_MULTI_INPUT_PARAMS, *RunMetadataFilterMixin.API_MULTI_INPUT_PARAMS, ] - name: Optional[str] = Field( + name: str | None = Field( default=None, description="The name of the Model Version", ) - number: Optional[int] = Field( + number: int | None = Field( default=None, description="The number of the Model Version", ) - stage: Optional[Union[str, ModelStages]] = Field( + stage: str | ModelStages | None = Field( description="The model version stage", default=None, union_mode="left_to_right", ) - model: Optional[Union[str, UUID]] = Field( + model: str | UUID | None = Field( default=None, description="The name or ID of the model which the search is scoped " "to. This field must always be set and is always applied in addition " @@ -694,8 +691,8 @@ class ModelVersionFilter( ) def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List[Union["ColumnElement[bool]"]]: + self, table: type["AnySchema"] + ) -> list[Union["ColumnElement[bool]"]]: """Get custom filters. 
Args: diff --git a/src/zenml/models/v2/core/model_version_artifact.py b/src/zenml/models/v2/core/model_version_artifact.py index 6c9514b9735..d36593ca3f7 100644 --- a/src/zenml/models/v2/core/model_version_artifact.py +++ b/src/zenml/models/v2/core/model_version_artifact.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing the link between model versions and artifacts.""" -from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, TypeVar, Union from uuid import UUID from pydantic import ConfigDict, Field @@ -136,25 +136,25 @@ class ModelVersionArtifactFilter(BaseFilter): "id", ] - model_version_id: Optional[Union[UUID, str]] = Field( + model_version_id: UUID | str | None = Field( default=None, description="Filter by model version ID", union_mode="left_to_right", ) - artifact_version_id: Optional[Union[UUID, str]] = Field( + artifact_version_id: UUID | str | None = Field( default=None, description="Filter by artifact ID", union_mode="left_to_right", ) - artifact_name: Optional[str] = Field( + artifact_name: str | None = Field( default=None, description="Name of the artifact", ) - only_data_artifacts: Optional[bool] = False - only_model_artifacts: Optional[bool] = False - only_deployment_artifacts: Optional[bool] = False - has_custom_name: Optional[bool] = None - user: Optional[Union[UUID, str]] = Field( + only_data_artifacts: bool | None = False + only_model_artifacts: bool | None = False + only_deployment_artifacts: bool | None = False + has_custom_name: bool | None = None + user: UUID | str | None = Field( default=None, description="Name/ID of the user that created the artifact.", ) @@ -168,8 +168,8 @@ class ModelVersionArtifactFilter(BaseFilter): model_config = ConfigDict(protected_namespaces=()) def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List[Union["ColumnElement[bool]"]]: + self, table: type["AnySchema"] + ) -> list[Union["ColumnElement[bool]"]]: """Get 
custom filters. Args: diff --git a/src/zenml/models/v2/core/model_version_pipeline_run.py b/src/zenml/models/v2/core/model_version_pipeline_run.py index 994a7e9b5f0..68aa66ff8cb 100644 --- a/src/zenml/models/v2/core/model_version_pipeline_run.py +++ b/src/zenml/models/v2/core/model_version_pipeline_run.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing the link between model versions and pipeline runs.""" -from typing import TYPE_CHECKING, List, Optional, Type, TypeVar, Union +from typing import TYPE_CHECKING, TypeVar from uuid import UUID from pydantic import ConfigDict, Field @@ -127,21 +127,21 @@ class ModelVersionPipelineRunFilter(BaseFilter): "id", ] - model_version_id: Optional[Union[UUID, str]] = Field( + model_version_id: UUID | str | None = Field( default=None, description="Filter by model version ID", union_mode="left_to_right", ) - pipeline_run_id: Optional[Union[UUID, str]] = Field( + pipeline_run_id: UUID | str | None = Field( default=None, description="Filter by pipeline run ID", union_mode="left_to_right", ) - pipeline_run_name: Optional[str] = Field( + pipeline_run_name: str | None = Field( default=None, description="Name of the pipeline run", ) - user: Optional[Union[UUID, str]] = Field( + user: UUID | str | None = Field( default=None, description="Name/ID of the user that created the pipeline run.", ) @@ -155,8 +155,8 @@ class ModelVersionPipelineRunFilter(BaseFilter): model_config = ConfigDict(protected_namespaces=()) def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List["ColumnElement[bool]"]: + self, table: type["AnySchema"] + ) -> list["ColumnElement[bool]"]: """Get custom filters. 
Args: diff --git a/src/zenml/models/v2/core/pipeline.py b/src/zenml/models/v2/core/pipeline.py index 6122a07c978..2171e344130 100644 --- a/src/zenml/models/v2/core/pipeline.py +++ b/src/zenml/models/v2/core/pipeline.py @@ -17,11 +17,8 @@ TYPE_CHECKING, Any, ClassVar, - List, Optional, - Type, TypeVar, - Union, ) from uuid import UUID @@ -67,12 +64,12 @@ class PipelineRequest(ProjectScopedRequest): title="The name of the pipeline.", max_length=STR_FIELD_MAX_LENGTH, ) - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The description of the pipeline.", max_length=TEXT_FIELD_MAX_LENGTH, ) - tags: Optional[List[str]] = Field( + tags: list[str] | None = Field( default=None, title="Tags of the pipeline.", ) @@ -84,15 +81,15 @@ class PipelineRequest(ProjectScopedRequest): class PipelineUpdate(BaseUpdate): """Update model for pipelines.""" - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The description of the pipeline.", max_length=TEXT_FIELD_MAX_LENGTH, ) - add_tags: Optional[List[str]] = Field( + add_tags: list[str] | None = Field( default=None, title="New tags to add to the pipeline." ) - remove_tags: Optional[List[str]] = Field( + remove_tags: list[str] | None = Field( default=None, title="Tags to remove from the pipeline." 
) @@ -107,7 +104,7 @@ class PipelineResponseBody(ProjectScopedResponseBody): class PipelineResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for pipelines.""" - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The description of the pipeline.", ) @@ -120,18 +117,18 @@ class PipelineResponseResources(ProjectScopedResponseResources): default=None, title="The user that created the latest run of this pipeline.", ) - latest_run_id: Optional[UUID] = Field( + latest_run_id: UUID | None = Field( default=None, title="The ID of the latest run of the pipeline.", ) - latest_run_status: Optional[ExecutionStatus] = Field( + latest_run_status: ExecutionStatus | None = Field( default=None, title="The status of the latest run of the pipeline.", ) - tags: List[TagResponse] = Field( + tags: list[TagResponse] = Field( title="Tags associated with the pipeline.", ) - visualizations: List["CuratedVisualizationResponse"] = Field( + visualizations: list["CuratedVisualizationResponse"] = Field( default=[], title="Curated visualizations associated with the pipeline.", ) @@ -162,7 +159,7 @@ def get_hydrated_version(self) -> "PipelineResponse": return Client().zen_store.get_pipeline(self.id) # Helper methods - def get_runs(self, **kwargs: Any) -> List["PipelineRunResponse"]: + def get_runs(self, **kwargs: Any) -> list["PipelineRunResponse"]: """Get runs of this pipeline. Can be used to fetch runs other than `self.runs` and supports @@ -180,7 +177,7 @@ def get_runs(self, **kwargs: Any) -> List["PipelineRunResponse"]: return Client().list_pipeline_runs(pipeline_id=self.id, **kwargs).items @property - def runs(self) -> List["PipelineRunResponse"]: + def runs(self) -> list["PipelineRunResponse"]: """Returns the 20 most recent runs of this pipeline in descending order. 
Returns: @@ -235,7 +232,7 @@ def last_successful_run(self) -> "PipelineRunResponse": return runs[0] @property - def latest_run_id(self) -> Optional[UUID]: + def latest_run_id(self) -> UUID | None: """The `latest_run_id` property. Returns: @@ -244,7 +241,7 @@ def latest_run_id(self) -> Optional[UUID]: return self.get_resources().latest_run_id @property - def latest_run_status(self) -> Optional[ExecutionStatus]: + def latest_run_status(self) -> ExecutionStatus | None: """The `latest_run_status` property. Returns: @@ -253,7 +250,7 @@ def latest_run_status(self) -> Optional[ExecutionStatus]: return self.get_resources().latest_run_status @property - def tags(self) -> List[TagResponse]: + def tags(self) -> list[TagResponse]: """The `tags` property. Returns: @@ -268,33 +265,33 @@ def tags(self) -> List[TagResponse]: class PipelineFilter(ProjectScopedFilter, TaggableFilter): """Pipeline filter model.""" - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + CUSTOM_SORTING_OPTIONS: ClassVar[list[str]] = [ *ProjectScopedFilter.CUSTOM_SORTING_OPTIONS, *TaggableFilter.CUSTOM_SORTING_OPTIONS, SORT_PIPELINES_BY_LATEST_RUN_KEY, ] - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS, *TaggableFilter.FILTER_EXCLUDE_FIELDS, "latest_run_status", "latest_run_user", ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.CLI_EXCLUDE_FIELDS, *TaggableFilter.CLI_EXCLUDE_FIELDS, ] - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the Pipeline", ) - latest_run_status: Optional[str] = Field( + latest_run_status: str | None = Field( default=None, description="Filter by the status of the latest run of a pipeline. 
" "This will always be applied as an `AND` filter for now.", ) - latest_run_user: Optional[Union[UUID, str]] = Field( + latest_run_user: UUID | str | None = Field( default=None, description="Filter by the name or id of the last user that executed the pipeline. ", ) @@ -309,7 +306,7 @@ def filter_by_latest_execution(self) -> bool: return bool(self.latest_run_user) or bool(self.latest_run_status) def apply_filter( - self, query: AnyQuery, table: Type["AnySchema"] + self, query: AnyQuery, table: type["AnySchema"] ) -> AnyQuery: """Applies the filter to a query. @@ -380,7 +377,7 @@ def apply_filter( def apply_sorting( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Apply sorting to the query. diff --git a/src/zenml/models/v2/core/pipeline_build.py b/src/zenml/models/v2/core/pipeline_build.py index f3157f02e3c..1fb6aa403eb 100644 --- a/src/zenml/models/v2/core/pipeline_build.py +++ b/src/zenml/models/v2/core/pipeline_build.py @@ -18,12 +18,8 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - List, Optional, - Type, TypeVar, - Union, ) from uuid import UUID @@ -56,7 +52,7 @@ class PipelineBuildBase(BaseZenModel): """Base model for pipeline builds.""" - images: Dict[str, BuildItem] = Field( + images: dict[str, BuildItem] = Field( default={}, title="The images of this build." 
) is_local: bool = Field( @@ -66,13 +62,13 @@ class PipelineBuildBase(BaseZenModel): contains_code: bool = Field( title="Whether any image of the build contains user code.", ) - zenml_version: Optional[str] = Field( + zenml_version: str | None = Field( title="The version of ZenML used for this build.", default=None ) - python_version: Optional[str] = Field( + python_version: str | None = Field( title="The Python version used for this build.", default=None ) - duration: Optional[int] = Field( + duration: int | None = Field( title="The duration of the build in seconds.", default=None ) @@ -89,7 +85,7 @@ def requires_code_download(self) -> bool: ) @staticmethod - def get_image_key(component_key: str, step: Optional[str] = None) -> str: + def get_image_key(component_key: str, step: str | None = None) -> str: """Get the image key. Args: @@ -104,7 +100,7 @@ def get_image_key(component_key: str, step: Optional[str] = None) -> str: else: return component_key - def get_image(self, component_key: str, step: Optional[str] = None) -> str: + def get_image(self, component_key: str, step: str | None = None) -> str: """Get the image built for a specific key. Args: @@ -119,8 +115,8 @@ def get_image(self, component_key: str, step: Optional[str] = None) -> str: return self._get_item(component_key=component_key, step=step).image def get_settings_checksum( - self, component_key: str, step: Optional[str] = None - ) -> Optional[str]: + self, component_key: str, step: str | None = None + ) -> str | None: """Get the settings checksum for a specific key. Args: @@ -137,7 +133,7 @@ def get_settings_checksum( ).settings_checksum def _get_item( - self, component_key: str, step: Optional[str] = None + self, component_key: str, step: str | None = None ) -> "BuildItem": """Get the item for a specific key. 
@@ -174,15 +170,15 @@ def _get_item( class PipelineBuildRequest(PipelineBuildBase, ProjectScopedRequest): """Request model for pipelines builds.""" - checksum: Optional[str] = Field(title="The build checksum.", default=None) - stack_checksum: Optional[str] = Field( + checksum: str | None = Field(title="The build checksum.", default=None) + stack_checksum: str | None = Field( title="The stack checksum.", default=None ) - stack: Optional[UUID] = Field( + stack: UUID | None = Field( title="The stack that was used for this build.", default=None ) - pipeline: Optional[UUID] = Field( + pipeline: UUID | None = Field( title="The pipeline that was used for this build.", default=None ) @@ -200,7 +196,7 @@ class PipelineBuildResponseBody(ProjectScopedResponseBody): class PipelineBuildResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for pipeline builds.""" - __zenml_skip_dehydration__: ClassVar[List[str]] = [ + __zenml_skip_dehydration__: ClassVar[list[str]] = [ "images", ] @@ -210,17 +206,17 @@ class PipelineBuildResponseMetadata(ProjectScopedResponseMetadata): stack: Optional["StackResponse"] = Field( default=None, title="The stack that was used for this build." ) - images: Dict[str, "BuildItem"] = Field( + images: dict[str, "BuildItem"] = Field( default={}, title="The images of this build." ) - zenml_version: Optional[str] = Field( + zenml_version: str | None = Field( default=None, title="The version of ZenML used for this build." ) - python_version: Optional[str] = Field( + python_version: str | None = Field( default=None, title="The Python version used for this build." ) - checksum: Optional[str] = Field(default=None, title="The build checksum.") - stack_checksum: Optional[str] = Field( + checksum: str | None = Field(default=None, title="The build checksum.") + stack_checksum: str | None = Field( default=None, title="The stack checksum." 
) is_local: bool = Field( @@ -230,7 +226,7 @@ class PipelineBuildResponseMetadata(ProjectScopedResponseMetadata): contains_code: bool = Field( title="Whether any image of the build contains user code.", ) - duration: Optional[int] = Field( + duration: int | None = Field( title="The duration of the build in seconds.", default=None ) @@ -259,7 +255,7 @@ def get_hydrated_version(self) -> "PipelineBuildResponse": return Client().zen_store.get_build(self.id) # Helper methods - def to_yaml(self) -> Dict[str, Any]: + def to_yaml(self) -> dict[str, Any]: """Create a yaml representation of the pipeline build. Create a yaml representation of the pipeline build that can be used @@ -269,7 +265,7 @@ def to_yaml(self) -> Dict[str, Any]: The yaml representation of the pipeline build. """ # Get the base attributes - yaml_dict: Dict[str, Any] = json.loads( + yaml_dict: dict[str, Any] = json.loads( self.model_dump_json( exclude={ "body", @@ -301,7 +297,7 @@ def requires_code_download(self) -> bool: ) @staticmethod - def get_image_key(component_key: str, step: Optional[str] = None) -> str: + def get_image_key(component_key: str, step: str | None = None) -> str: """Get the image key. Args: @@ -316,7 +312,7 @@ def get_image_key(component_key: str, step: Optional[str] = None) -> str: else: return component_key - def get_image(self, component_key: str, step: Optional[str] = None) -> str: + def get_image(self, component_key: str, step: str | None = None) -> str: """Get the image built for a specific key. Args: @@ -331,8 +327,8 @@ def get_image(self, component_key: str, step: Optional[str] = None) -> str: return self._get_item(component_key=component_key, step=step).image def get_settings_checksum( - self, component_key: str, step: Optional[str] = None - ) -> Optional[str]: + self, component_key: str, step: str | None = None + ) -> str | None: """Get the settings checksum for a specific key. 
Args: @@ -349,7 +345,7 @@ def get_settings_checksum( ).settings_checksum def _get_item( - self, component_key: str, step: Optional[str] = None + self, component_key: str, step: str | None = None ) -> "BuildItem": """Get the item for a specific key. @@ -402,7 +398,7 @@ def stack(self) -> Optional["StackResponse"]: return self.get_metadata().stack @property - def images(self) -> Dict[str, "BuildItem"]: + def images(self) -> dict[str, "BuildItem"]: """The `images` property. Returns: @@ -411,7 +407,7 @@ def images(self) -> Dict[str, "BuildItem"]: return self.get_metadata().images @property - def zenml_version(self) -> Optional[str]: + def zenml_version(self) -> str | None: """The `zenml_version` property. Returns: @@ -420,7 +416,7 @@ def zenml_version(self) -> Optional[str]: return self.get_metadata().zenml_version @property - def python_version(self) -> Optional[str]: + def python_version(self) -> str | None: """The `python_version` property. Returns: @@ -429,7 +425,7 @@ def python_version(self) -> Optional[str]: return self.get_metadata().python_version @property - def checksum(self) -> Optional[str]: + def checksum(self) -> str | None: """The `checksum` property. Returns: @@ -438,7 +434,7 @@ def checksum(self) -> Optional[str]: return self.get_metadata().checksum @property - def stack_checksum(self) -> Optional[str]: + def stack_checksum(self) -> str | None: """The `stack_checksum` property. Returns: @@ -465,7 +461,7 @@ def contains_code(self) -> bool: return self.get_metadata().contains_code @property - def duration(self) -> Optional[int]: + def duration(self) -> int | None: """The `duration` property. 
Returns: @@ -480,55 +476,55 @@ def duration(self) -> Optional[int]: class PipelineBuildFilter(ProjectScopedFilter): """Model to enable advanced filtering of all pipeline builds.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS, "container_registry_id", ] - pipeline_id: Optional[Union[UUID, str]] = Field( + pipeline_id: UUID | str | None = Field( description="Pipeline associated with the pipeline build.", default=None, union_mode="left_to_right", ) - stack_id: Optional[Union[UUID, str]] = Field( + stack_id: UUID | str | None = Field( description="Stack associated with the pipeline build.", default=None, union_mode="left_to_right", ) - container_registry_id: Optional[Union[UUID, str]] = Field( + container_registry_id: UUID | str | None = Field( description="Container registry associated with the pipeline build.", default=None, union_mode="left_to_right", ) - is_local: Optional[bool] = Field( + is_local: bool | None = Field( description="Whether the build images are stored in a container " "registry or locally.", default=None, ) - contains_code: Optional[bool] = Field( + contains_code: bool | None = Field( description="Whether any image of the build contains user code.", default=None, ) - zenml_version: Optional[str] = Field( + zenml_version: str | None = Field( description="The version of ZenML used for this build.", default=None ) - python_version: Optional[str] = Field( + python_version: str | None = Field( description="The Python version used for this build.", default=None ) - checksum: Optional[str] = Field( + checksum: str | None = Field( description="The build checksum.", default=None ) - stack_checksum: Optional[str] = Field( + stack_checksum: str | None = Field( description="The stack checksum.", default=None ) - duration: Optional[Union[int, str]] = Field( + duration: int | str | None = Field( description="The duration of the build in seconds.", default=None ) def 
get_custom_filters( self, - table: Type["AnySchema"], - ) -> List["ColumnElement[bool]"]: + table: type["AnySchema"], + ) -> list["ColumnElement[bool]"]: """Get custom filters. Args: diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 1f74a9e99e3..7a9ef443532 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -18,12 +18,8 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - List, Optional, - Type, TypeVar, - Union, ) from uuid import UUID @@ -81,11 +77,11 @@ class PipelineRunTriggerInfo(BaseZenModel): """Trigger information model.""" - step_run_id: Optional[UUID] = Field( + step_run_id: UUID | None = Field( default=None, title="The ID of the step run that triggered the pipeline run.", ) - deployment_id: Optional[UUID] = Field( + deployment_id: UUID | None = Field( default=None, title="The ID of the deployment that triggered the pipeline run.", ) @@ -101,51 +97,51 @@ class PipelineRunRequest(ProjectScopedRequest): snapshot: UUID = Field( title="The snapshot associated with the pipeline run." 
) - pipeline: Optional[UUID] = Field( + pipeline: UUID | None = Field( title="The pipeline associated with the pipeline run.", default=None, ) - orchestrator_run_id: Optional[str] = Field( + orchestrator_run_id: str | None = Field( title="The orchestrator run ID.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - start_time: Optional[datetime] = Field( + start_time: datetime | None = Field( title="The start time of the pipeline run.", default=None, ) - end_time: Optional[datetime] = Field( + end_time: datetime | None = Field( title="The end time of the pipeline run.", default=None, ) status: ExecutionStatus = Field( title="The status of the pipeline run.", ) - status_reason: Optional[str] = Field( + status_reason: str | None = Field( title="The reason for the status of the pipeline run.", default=None, max_length=STR_FIELD_MAX_LENGTH, ) - orchestrator_environment: Dict[str, Any] = Field( + orchestrator_environment: dict[str, Any] = Field( default={}, title=( "Environment of the orchestrator that executed this pipeline run " "(OS, Python version, etc.)." 
), ) - trigger_execution_id: Optional[UUID] = Field( + trigger_execution_id: UUID | None = Field( default=None, title="ID of the trigger execution that triggered this run.", ) - trigger_info: Optional[PipelineRunTriggerInfo] = Field( + trigger_info: PipelineRunTriggerInfo | None = Field( default=None, title="Trigger information for the pipeline run.", ) - tags: Optional[List[Union[str, Tag]]] = Field( + tags: list[str | Tag] | None = Field( default=None, title="Tags of the pipeline run.", ) - logs: Optional[LogsRequest] = Field( + logs: LogsRequest | None = Field( default=None, title="Logs of the pipeline run.", ) @@ -171,23 +167,23 @@ def is_placeholder_request(self) -> bool: class PipelineRunUpdate(BaseUpdate): """Pipeline run update model.""" - status: Optional[ExecutionStatus] = None - status_reason: Optional[str] = Field( + status: ExecutionStatus | None = None + status_reason: str | None = Field( default=None, title="The reason for the status of the pipeline run.", max_length=STR_FIELD_MAX_LENGTH, ) - end_time: Optional[datetime] = None - orchestrator_run_id: Optional[str] = None + end_time: datetime | None = None + orchestrator_run_id: str | None = None # TODO: we should maybe have a different update model here, the upper # three attributes should only be for internal use - add_tags: Optional[List[str]] = Field( + add_tags: list[str] | None = Field( default=None, title="New tags to add to the pipeline run." ) - remove_tags: Optional[List[str]] = Field( + remove_tags: list[str] | None = Field( default=None, title="Tags to remove from the pipeline run." ) - add_logs: Optional[List[LogsRequest]] = Field( + add_logs: list[LogsRequest] | None = Field( default=None, title="New logs to add to the pipeline run." 
) @@ -206,7 +202,7 @@ class PipelineRunResponseBody(ProjectScopedResponseBody): in_progress: bool = Field( title="Whether the pipeline run is in progress.", ) - status_reason: Optional[str] = Field( + status_reason: str | None = Field( default=None, title="The reason for the status of the pipeline run.", ) @@ -217,52 +213,52 @@ class PipelineRunResponseBody(ProjectScopedResponseBody): class PipelineRunResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for pipeline runs.""" - __zenml_skip_dehydration__: ClassVar[List[str]] = [ + __zenml_skip_dehydration__: ClassVar[list[str]] = [ "run_metadata", "config", "client_environment", "orchestrator_environment", ] - run_metadata: Dict[str, MetadataType] = Field( + run_metadata: dict[str, MetadataType] = Field( default={}, title="Metadata associated with this pipeline run.", ) config: PipelineConfiguration = Field( title="The pipeline configuration used for this pipeline run.", ) - start_time: Optional[datetime] = Field( + start_time: datetime | None = Field( title="The start time of the pipeline run.", default=None, ) - end_time: Optional[datetime] = Field( + end_time: datetime | None = Field( title="The end time of the pipeline run.", default=None, ) - client_environment: Dict[str, Any] = Field( + client_environment: dict[str, Any] = Field( default={}, title=( "Environment of the client that initiated this pipeline run " "(OS, Python version, etc.)." ), ) - orchestrator_environment: Dict[str, Any] = Field( + orchestrator_environment: dict[str, Any] = Field( default={}, title=( "Environment of the orchestrator that executed this pipeline run " "(OS, Python version, etc.)." 
), ) - orchestrator_run_id: Optional[str] = Field( + orchestrator_run_id: str | None = Field( title="The orchestrator run ID.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - code_path: Optional[str] = Field( + code_path: str | None = Field( default=None, title="Optional path where the code is stored in the artifact store.", ) - template_id: Optional[UUID] = Field( + template_id: UUID | None = Field( default=None, description="DEPRECATED: Template used for the pipeline run.", deprecated=True, @@ -271,7 +267,7 @@ class PipelineRunResponseMetadata(ProjectScopedResponseMetadata): default=False, description="Whether a template can be created from this run.", ) - trigger_info: Optional[PipelineRunTriggerInfo] = Field( + trigger_info: PipelineRunTriggerInfo | None = Field( default=None, title="Trigger information for the pipeline run.", ) @@ -300,19 +296,19 @@ class PipelineRunResponseResources(ProjectScopedResponseResources): trigger_execution: Optional["TriggerExecutionResponse"] = Field( default=None, title="The trigger execution that triggered this run." 
) - model_version: Optional[ModelVersionResponse] = None - tags: List[TagResponse] = Field( + model_version: ModelVersionResponse | None = None + tags: list[TagResponse] = Field( title="Tags associated with the pipeline run.", ) logs: Optional["LogsResponse"] = Field( title="Logs associated with this pipeline run.", default=None, ) - log_collection: Optional[List["LogsResponse"]] = Field( + log_collection: list["LogsResponse"] | None = Field( title="Logs associated with this pipeline run.", default=None, ) - visualizations: List["CuratedVisualizationResponse"] = Field( + visualizations: list["CuratedVisualizationResponse"] = Field( default=[], title="Curated visualizations associated with the pipeline run.", ) @@ -352,7 +348,7 @@ def get_hydrated_version(self) -> "PipelineRunResponse": # Helper methods @property - def artifact_versions(self) -> List["ArtifactVersionResponse"]: + def artifact_versions(self) -> list["ArtifactVersionResponse"]: """Get all artifact versions that are outputs of steps of this run. Returns: @@ -365,7 +361,7 @@ def artifact_versions(self) -> List["ArtifactVersionResponse"]: return get_artifacts_versions_of_pipeline_run(self) @property - def produced_artifact_versions(self) -> List["ArtifactVersionResponse"]: + def produced_artifact_versions(self) -> list["ArtifactVersionResponse"]: """Get all artifact versions produced during this pipeline run. Returns: @@ -388,7 +384,7 @@ def status(self) -> ExecutionStatus: return self.get_body().status @property - def run_metadata(self) -> Dict[str, MetadataType]: + def run_metadata(self) -> dict[str, MetadataType]: """The `run_metadata` property. Returns: @@ -397,7 +393,7 @@ def run_metadata(self) -> Dict[str, MetadataType]: return self.get_metadata().run_metadata @property - def steps(self) -> Dict[str, "StepRunResponse"]: + def steps(self) -> dict[str, "StepRunResponse"]: """The `steps` property. 
Returns: @@ -425,7 +421,7 @@ def config(self) -> PipelineConfiguration: return self.get_metadata().config @property - def start_time(self) -> Optional[datetime]: + def start_time(self) -> datetime | None: """The `start_time` property. Returns: @@ -434,7 +430,7 @@ def start_time(self) -> Optional[datetime]: return self.get_metadata().start_time @property - def end_time(self) -> Optional[datetime]: + def end_time(self) -> datetime | None: """The `end_time` property. Returns: @@ -452,7 +448,7 @@ def in_progress(self) -> bool: return self.get_body().in_progress @property - def client_environment(self) -> Dict[str, Any]: + def client_environment(self) -> dict[str, Any]: """The `client_environment` property. Returns: @@ -461,7 +457,7 @@ def client_environment(self) -> Dict[str, Any]: return self.get_metadata().client_environment @property - def orchestrator_environment(self) -> Dict[str, Any]: + def orchestrator_environment(self) -> dict[str, Any]: """The `orchestrator_environment` property. Returns: @@ -470,7 +466,7 @@ def orchestrator_environment(self) -> Dict[str, Any]: return self.get_metadata().orchestrator_environment @property - def orchestrator_run_id(self) -> Optional[str]: + def orchestrator_run_id(self) -> str | None: """The `orchestrator_run_id` property. Returns: @@ -479,7 +475,7 @@ def orchestrator_run_id(self) -> Optional[str]: return self.get_metadata().orchestrator_run_id @property - def code_path(self) -> Optional[str]: + def code_path(self) -> str | None: """The `code_path` property. Returns: @@ -488,7 +484,7 @@ def code_path(self) -> Optional[str]: return self.get_metadata().code_path @property - def template_id(self) -> Optional[UUID]: + def template_id(self) -> UUID | None: """The `template_id` property. 
Returns: @@ -578,7 +574,7 @@ def code_reference(self) -> Optional["CodeReferenceResponse"]: return self.get_resources().code_reference @property - def model_version(self) -> Optional[ModelVersionResponse]: + def model_version(self) -> ModelVersionResponse | None: """The `model_version` property. Returns: @@ -587,7 +583,7 @@ def model_version(self) -> Optional[ModelVersionResponse]: return self.get_resources().model_version @property - def tags(self) -> List[TagResponse]: + def tags(self) -> list[TagResponse]: """The `tags` property. Returns: @@ -605,7 +601,7 @@ def logs(self) -> Optional["LogsResponse"]: return self.get_resources().logs @property - def log_collection(self) -> Optional[List["LogsResponse"]]: + def log_collection(self) -> list["LogsResponse"] | None: """The `log_collection` property. Returns: @@ -622,7 +618,7 @@ class PipelineRunFilter( ): """Model to enable advanced filtering of all pipeline runs.""" - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + CUSTOM_SORTING_OPTIONS: ClassVar[list[str]] = [ *ProjectScopedFilter.CUSTOM_SORTING_OPTIONS, *TaggableFilter.CUSTOM_SORTING_OPTIONS, *RunMetadataFilterMixin.CUSTOM_SORTING_OPTIONS, @@ -631,7 +627,7 @@ class PipelineRunFilter( "model", "model_version", ] - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS, *TaggableFilter.FILTER_EXCLUDE_FIELDS, *RunMetadataFilterMixin.FILTER_EXCLUDE_FIELDS, @@ -658,67 +654,67 @@ class PipelineRunFilter( *TaggableFilter.CLI_EXCLUDE_FIELDS, *RunMetadataFilterMixin.CLI_EXCLUDE_FIELDS, ] - API_MULTI_INPUT_PARAMS: ClassVar[List[str]] = [ + API_MULTI_INPUT_PARAMS: ClassVar[list[str]] = [ *ProjectScopedFilter.API_MULTI_INPUT_PARAMS, *TaggableFilter.API_MULTI_INPUT_PARAMS, *RunMetadataFilterMixin.API_MULTI_INPUT_PARAMS, ] - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the Pipeline Run", ) - orchestrator_run_id: Optional[str] = Field( + 
orchestrator_run_id: str | None = Field( default=None, description="Name of the Pipeline Run within the orchestrator", ) - pipeline_id: Optional[Union[UUID, str]] = Field( + pipeline_id: UUID | str | None = Field( default=None, description="Pipeline associated with the Pipeline Run", union_mode="left_to_right", ) - stack_id: Optional[Union[UUID, str]] = Field( + stack_id: UUID | str | None = Field( default=None, description="Stack used for the Pipeline Run", union_mode="left_to_right", ) - schedule_id: Optional[Union[UUID, str]] = Field( + schedule_id: UUID | str | None = Field( default=None, description="Schedule that triggered the Pipeline Run", union_mode="left_to_right", ) - build_id: Optional[Union[UUID, str]] = Field( + build_id: UUID | str | None = Field( default=None, description="Build used for the Pipeline Run", union_mode="left_to_right", ) - snapshot_id: Optional[Union[UUID, str]] = Field( + snapshot_id: UUID | str | None = Field( default=None, description="Snapshot used for the Pipeline Run", union_mode="left_to_right", ) - code_repository_id: Optional[Union[UUID, str]] = Field( + code_repository_id: UUID | str | None = Field( default=None, description="Code repository used for the Pipeline Run", union_mode="left_to_right", ) - template_id: Optional[Union[UUID, str]] = Field( + template_id: UUID | str | None = Field( default=None, description="DEPRECATED: Template used for the pipeline run.", union_mode="left_to_right", deprecated=True, ) - source_snapshot_id: Optional[Union[UUID, str]] = Field( + source_snapshot_id: UUID | str | None = Field( default=None, description="Source snapshot used for the pipeline run.", union_mode="left_to_right", ) - model_version_id: Optional[Union[UUID, str]] = Field( + model_version_id: UUID | str | None = Field( default=None, description="Model version associated with the pipeline run.", union_mode="left_to_right", ) - linked_to_model_version_id: Optional[Union[UUID, str]] = Field( + linked_to_model_version_id: UUID | 
str | None = Field( default=None, description="Filter by model version linked to the pipeline run. " "The difference to `model_version_id` is that this filter will " @@ -726,60 +722,60 @@ class PipelineRunFilter( "version, but also if any step run is linked to the model version.", union_mode="left_to_right", ) - status: Optional[str] = Field( + status: str | None = Field( default=None, description="Name of the Pipeline Run", ) - in_progress: Optional[bool] = Field( + in_progress: bool | None = Field( default=None, description="Whether the pipeline run is in progress.", ) - start_time: Optional[Union[datetime, str]] = Field( + start_time: datetime | str | None = Field( default=None, description="Start time for this run", union_mode="left_to_right", ) - end_time: Optional[Union[datetime, str]] = Field( + end_time: datetime | str | None = Field( default=None, description="End time for this run", union_mode="left_to_right", ) - unlisted: Optional[bool] = None + unlisted: bool | None = None # TODO: Remove once frontend is ready for it. This is replaced by the more # generic `pipeline` filter below. 
- pipeline_name: Optional[str] = Field( + pipeline_name: str | None = Field( default=None, description="Name of the pipeline associated with the run", ) - pipeline: Optional[Union[UUID, str]] = Field( + pipeline: UUID | str | None = Field( default=None, description="Name/ID of the pipeline associated with the run.", ) - stack: Optional[Union[UUID, str]] = Field( + stack: UUID | str | None = Field( default=None, description="Name/ID of the stack associated with the run.", ) - code_repository: Optional[Union[UUID, str]] = Field( + code_repository: UUID | str | None = Field( default=None, description="Name/ID of the code repository associated with the run.", ) - model: Optional[Union[UUID, str]] = Field( + model: UUID | str | None = Field( default=None, description="Name/ID of the model associated with the run.", ) - stack_component: Optional[Union[UUID, str]] = Field( + stack_component: UUID | str | None = Field( default=None, description="Name/ID of the stack component associated with the run.", ) - templatable: Optional[bool] = Field( + templatable: bool | None = Field( default=None, description="Whether the run is templatable." ) - triggered_by_step_run_id: Optional[Union[UUID, str]] = Field( + triggered_by_step_run_id: UUID | str | None = Field( default=None, description="The ID of the step run that triggered this pipeline run.", union_mode="left_to_right", ) - triggered_by_deployment_id: Optional[Union[UUID, str]] = Field( + triggered_by_deployment_id: UUID | str | None = Field( default=None, description="The ID of the deployment that triggered this pipeline run.", union_mode="left_to_right", @@ -788,8 +784,8 @@ class PipelineRunFilter( def get_custom_filters( self, - table: Type["AnySchema"], - ) -> List["ColumnElement[bool]"]: + table: type["AnySchema"], + ) -> list["ColumnElement[bool]"]: """Get custom filters. 
Args: @@ -1023,7 +1019,7 @@ def get_custom_filters( def apply_sorting( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Apply sorting to the query. diff --git a/src/zenml/models/v2/core/pipeline_snapshot.py b/src/zenml/models/v2/core/pipeline_snapshot.py index c56cea969d9..847d73382ec 100644 --- a/src/zenml/models/v2/core/pipeline_snapshot.py +++ b/src/zenml/models/v2/core/pipeline_snapshot.py @@ -17,12 +17,8 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - List, Optional, - Type, TypeVar, - Union, ) from uuid import UUID @@ -82,25 +78,25 @@ class PipelineSnapshotBase(BaseZenModel): pipeline_configuration: PipelineConfiguration = Field( title="The pipeline configuration for this snapshot." ) - step_configurations: Dict[str, Step] = Field( + step_configurations: dict[str, Step] = Field( default={}, title="The step configurations for this snapshot." ) - client_environment: Dict[str, Any] = Field( + client_environment: dict[str, Any] = Field( default={}, title="The client environment for this snapshot." 
) - client_version: Optional[str] = Field( + client_version: str | None = Field( default=None, title="The version of the ZenML installation on the client side.", ) - server_version: Optional[str] = Field( + server_version: str | None = Field( default=None, title="The version of the ZenML installation on the server side.", ) - pipeline_version_hash: Optional[str] = Field( + pipeline_version_hash: str | None = Field( default=None, title="The pipeline version hash of the snapshot.", ) - pipeline_spec: Optional[PipelineSpec] = Field( + pipeline_spec: PipelineSpec | None = Field( default=None, title="The pipeline spec of the snapshot.", ) @@ -121,45 +117,45 @@ def should_prevent_build_reuse(self) -> bool: class PipelineSnapshotRequest(PipelineSnapshotBase, ProjectScopedRequest): """Request model for pipeline snapshots.""" - name: Optional[Union[str, bool]] = Field( + name: str | bool | None = Field( default=None, title="The name of the snapshot.", ) - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The description of the snapshot.", max_length=TEXT_FIELD_MAX_LENGTH, ) - replace: Optional[bool] = Field( + replace: bool | None = Field( default=None, title="Whether to replace the existing snapshot with the same name.", ) - tags: Optional[List[str]] = Field( + tags: list[str] | None = Field( default=None, title="Tags of the snapshot.", ) stack: UUID = Field(title="The stack associated with the snapshot.") pipeline: UUID = Field(title="The pipeline associated with the snapshot.") - build: Optional[UUID] = Field( + build: UUID | None = Field( default=None, title="The build associated with the snapshot." ) - schedule: Optional[UUID] = Field( + schedule: UUID | None = Field( default=None, title="The schedule associated with the snapshot." 
) code_reference: Optional["CodeReferenceRequest"] = Field( default=None, title="The code reference associated with the snapshot.", ) - code_path: Optional[str] = Field( + code_path: str | None = Field( default=None, title="Optional path where the code is stored in the artifact store.", ) - template: Optional[UUID] = Field( + template: UUID | None = Field( default=None, description="DEPRECATED: Template used for the snapshot.", ) - source_snapshot: Optional[UUID] = Field( + source_snapshot: UUID | None = Field( default=None, description="Snapshot that is the source of this snapshot.", ) @@ -187,24 +183,24 @@ def _validate_name(cls, v: Any) -> Any: class PipelineSnapshotUpdate(BaseUpdate): """Pipeline snapshot update model.""" - name: Optional[Union[str, bool]] = Field( + name: str | bool | None = Field( default=None, title="The name of the snapshot. If set to " "False, the name will be removed.", ) - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The description of the snapshot.", max_length=TEXT_FIELD_MAX_LENGTH, ) - replace: Optional[bool] = Field( + replace: bool | None = Field( default=None, title="Whether to replace the existing snapshot with the same name.", ) - add_tags: Optional[List[str]] = Field( + add_tags: list[str] | None = Field( default=None, title="New tags to add to the snapshot." ) - remove_tags: Optional[List[str]] = Field( + remove_tags: list[str] | None = Field( default=None, title="Tags to remove from the snapshot." 
) @@ -242,14 +238,14 @@ class PipelineSnapshotResponseBody(ProjectScopedResponseBody): class PipelineSnapshotResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for pipeline snapshots.""" - __zenml_skip_dehydration__: ClassVar[List[str]] = [ + __zenml_skip_dehydration__: ClassVar[list[str]] = [ "pipeline_configuration", "step_configurations", "client_environment", "pipeline_spec", ] - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The description of the snapshot.", ) @@ -259,41 +255,41 @@ class PipelineSnapshotResponseMetadata(ProjectScopedResponseMetadata): pipeline_configuration: PipelineConfiguration = Field( title="The pipeline configuration for this snapshot." ) - step_configurations: Dict[str, Step] = Field( + step_configurations: dict[str, Step] = Field( default={}, title="The step configurations for this snapshot." ) - client_environment: Dict[str, Any] = Field( + client_environment: dict[str, Any] = Field( default={}, title="The client environment for this snapshot." ) - client_version: Optional[str] = Field( + client_version: str | None = Field( title="The version of the ZenML installation on the client side." ) - server_version: Optional[str] = Field( + server_version: str | None = Field( title="The version of the ZenML installation on the server side." ) - pipeline_version_hash: Optional[str] = Field( + pipeline_version_hash: str | None = Field( default=None, title="The pipeline version hash of the snapshot." ) - pipeline_spec: Optional[PipelineSpec] = Field( + pipeline_spec: PipelineSpec | None = Field( default=None, title="The pipeline spec of the snapshot." 
) - code_path: Optional[str] = Field( + code_path: str | None = Field( default=None, title="Optional path where the code is stored in the artifact store.", ) - template_id: Optional[UUID] = Field( + template_id: UUID | None = Field( default=None, description="Template from which this snapshot was created.", deprecated=True, ) - source_snapshot_id: Optional[UUID] = Field( + source_snapshot_id: UUID | None = Field( default=None, description="Snapshot that is the source of this snapshot.", ) - config_template: Optional[Dict[str, Any]] = Field( + config_template: dict[str, Any] | None = Field( default=None, title="Run configuration template." ) - config_schema: Optional[Dict[str, Any]] = Field( + config_schema: dict[str, Any] | None = Field( default=None, title="Run configuration schema." ) @@ -304,41 +300,41 @@ class PipelineSnapshotResponseResources(ProjectScopedResponseResources): pipeline: PipelineResponse = Field( title="The pipeline associated with the snapshot." ) - stack: Optional[StackResponse] = Field( + stack: StackResponse | None = Field( default=None, title="The stack associated with the snapshot." ) - build: Optional[PipelineBuildResponse] = Field( + build: PipelineBuildResponse | None = Field( default=None, title="The pipeline build associated with the snapshot.", ) - schedule: Optional[ScheduleResponse] = Field( + schedule: ScheduleResponse | None = Field( default=None, title="The schedule associated with the snapshot." 
) - code_reference: Optional[CodeReferenceResponse] = Field( + code_reference: CodeReferenceResponse | None = Field( default=None, title="The code reference associated with the snapshot.", ) - deployment: Optional[DeploymentResponse] = Field( + deployment: DeploymentResponse | None = Field( default=None, title="The deployment associated with the snapshot.", ) - tags: List[TagResponse] = Field( + tags: list[TagResponse] = Field( default=[], title="Tags associated with the snapshot.", ) - latest_run_id: Optional[UUID] = Field( + latest_run_id: UUID | None = Field( default=None, title="The ID of the latest run of the snapshot.", ) - latest_run_status: Optional[ExecutionStatus] = Field( + latest_run_status: ExecutionStatus | None = Field( default=None, title="The status of the latest run of the snapshot.", ) - latest_run_user: Optional[UserResponse] = Field( + latest_run_user: UserResponse | None = Field( default=None, title="The user that created the latest run of the snapshot.", ) - visualizations: List["CuratedVisualizationResponse"] = Field( + visualizations: list["CuratedVisualizationResponse"] = Field( default=[], title="Curated visualizations associated with the pipeline snapshot.", ) @@ -353,7 +349,7 @@ class PipelineSnapshotResponse( ): """Response model for pipeline snapshots.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, title="The name of the snapshot.", max_length=STR_FIELD_MAX_LENGTH, @@ -390,7 +386,7 @@ def deployable(self) -> bool: return self.get_body().deployable @property - def description(self) -> Optional[str]: + def description(self) -> str | None: """The `description` property. Returns: @@ -417,7 +413,7 @@ def pipeline_configuration(self) -> PipelineConfiguration: return self.get_metadata().pipeline_configuration @property - def step_configurations(self) -> Dict[str, Step]: + def step_configurations(self) -> dict[str, Step]: """The `step_configurations` property. 
Returns: @@ -426,7 +422,7 @@ def step_configurations(self) -> Dict[str, Step]: return self.get_metadata().step_configurations @property - def client_environment(self) -> Dict[str, Any]: + def client_environment(self) -> dict[str, Any]: """The `client_environment` property. Returns: @@ -435,7 +431,7 @@ def client_environment(self) -> Dict[str, Any]: return self.get_metadata().client_environment @property - def client_version(self) -> Optional[str]: + def client_version(self) -> str | None: """The `client_version` property. Returns: @@ -444,7 +440,7 @@ def client_version(self) -> Optional[str]: return self.get_metadata().client_version @property - def server_version(self) -> Optional[str]: + def server_version(self) -> str | None: """The `server_version` property. Returns: @@ -453,7 +449,7 @@ def server_version(self) -> Optional[str]: return self.get_metadata().server_version @property - def pipeline_version_hash(self) -> Optional[str]: + def pipeline_version_hash(self) -> str | None: """The `pipeline_version_hash` property. Returns: @@ -462,7 +458,7 @@ def pipeline_version_hash(self) -> Optional[str]: return self.get_metadata().pipeline_version_hash @property - def pipeline_spec(self) -> Optional[PipelineSpec]: + def pipeline_spec(self) -> PipelineSpec | None: """The `pipeline_spec` property. Returns: @@ -471,7 +467,7 @@ def pipeline_spec(self) -> Optional[PipelineSpec]: return self.get_metadata().pipeline_spec @property - def code_path(self) -> Optional[str]: + def code_path(self) -> str | None: """The `code_path` property. Returns: @@ -480,7 +476,7 @@ def code_path(self) -> Optional[str]: return self.get_metadata().code_path @property - def template_id(self) -> Optional[UUID]: + def template_id(self) -> UUID | None: """The `template_id` property. 
Returns: @@ -489,7 +485,7 @@ def template_id(self) -> Optional[UUID]: return self.get_metadata().template_id @property - def source_snapshot_id(self) -> Optional[UUID]: + def source_snapshot_id(self) -> UUID | None: """The `source_snapshot_id` property. Returns: @@ -498,7 +494,7 @@ def source_snapshot_id(self) -> Optional[UUID]: return self.get_metadata().source_snapshot_id @property - def config_schema(self) -> Optional[Dict[str, Any]]: + def config_schema(self) -> dict[str, Any] | None: """The `config_schema` property. Returns: @@ -507,7 +503,7 @@ def config_schema(self) -> Optional[Dict[str, Any]]: return self.get_metadata().config_schema @property - def config_template(self) -> Optional[Dict[str, Any]]: + def config_template(self) -> dict[str, Any] | None: """The `config_template` property. Returns: @@ -525,7 +521,7 @@ def pipeline(self) -> PipelineResponse: return self.get_resources().pipeline @property - def stack(self) -> Optional[StackResponse]: + def stack(self) -> StackResponse | None: """The `stack` property. Returns: @@ -534,7 +530,7 @@ def stack(self) -> Optional[StackResponse]: return self.get_resources().stack @property - def build(self) -> Optional[PipelineBuildResponse]: + def build(self) -> PipelineBuildResponse | None: """The `build` property. Returns: @@ -543,7 +539,7 @@ def build(self) -> Optional[PipelineBuildResponse]: return self.get_resources().build @property - def schedule(self) -> Optional[ScheduleResponse]: + def schedule(self) -> ScheduleResponse | None: """The `schedule` property. Returns: @@ -552,7 +548,7 @@ def schedule(self) -> Optional[ScheduleResponse]: return self.get_resources().schedule @property - def code_reference(self) -> Optional[CodeReferenceResponse]: + def code_reference(self) -> CodeReferenceResponse | None: """The `code_reference` property. 
Returns: @@ -561,7 +557,7 @@ def code_reference(self) -> Optional[CodeReferenceResponse]: return self.get_resources().code_reference @property - def deployment(self) -> Optional[DeploymentResponse]: + def deployment(self) -> DeploymentResponse | None: """The `deployment` property. Returns: @@ -570,7 +566,7 @@ def deployment(self) -> Optional[DeploymentResponse]: return self.get_resources().deployment @property - def tags(self) -> List[TagResponse]: + def tags(self) -> list[TagResponse]: """The `tags` property. Returns: @@ -579,7 +575,7 @@ def tags(self) -> List[TagResponse]: return self.get_resources().tags @property - def latest_run_id(self) -> Optional[UUID]: + def latest_run_id(self) -> UUID | None: """The `latest_run_id` property. Returns: @@ -588,7 +584,7 @@ def latest_run_id(self) -> Optional[UUID]: return self.get_resources().latest_run_id @property - def latest_run_status(self) -> Optional[ExecutionStatus]: + def latest_run_status(self) -> ExecutionStatus | None: """The `latest_run_status` property. Returns: @@ -597,7 +593,7 @@ def latest_run_status(self) -> Optional[ExecutionStatus]: return self.get_resources().latest_run_status @property - def latest_run_user(self) -> Optional[UserResponse]: + def latest_run_user(self) -> UserResponse | None: """The `latest_run_user` property. 
Returns: @@ -612,7 +608,7 @@ def latest_run_user(self) -> Optional[UserResponse]: class PipelineSnapshotFilter(ProjectScopedFilter, TaggableFilter): """Model for filtering pipeline snapshots.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS, *TaggableFilter.FILTER_EXCLUDE_FIELDS, "named_only", @@ -629,60 +625,60 @@ class PipelineSnapshotFilter(ProjectScopedFilter, TaggableFilter): "stack", "deployment", ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.CLI_EXCLUDE_FIELDS, *TaggableFilter.CLI_EXCLUDE_FIELDS, ] - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the snapshot.", ) - named_only: Optional[bool] = Field( + named_only: bool | None = Field( default=None, description="Whether to only return snapshots with a name.", ) - pipeline: Optional[Union[UUID, str]] = Field( + pipeline: UUID | str | None = Field( default=None, description="Pipeline associated with the snapshot.", union_mode="left_to_right", ) - stack: Optional[Union[UUID, str]] = Field( + stack: UUID | str | None = Field( default=None, description="Stack associated with the snapshot.", union_mode="left_to_right", ) - build_id: Optional[Union[UUID, str]] = Field( + build_id: UUID | str | None = Field( default=None, description="Build associated with the snapshot.", union_mode="left_to_right", ) - schedule_id: Optional[Union[UUID, str]] = Field( + schedule_id: UUID | str | None = Field( default=None, description="Schedule associated with the snapshot.", union_mode="left_to_right", ) - source_snapshot_id: Optional[Union[UUID, str]] = Field( + source_snapshot_id: UUID | str | None = Field( default=None, description="Source snapshot used for the snapshot.", union_mode="left_to_right", ) - runnable: Optional[bool] = Field( + runnable: bool | None = Field( default=None, description="Whether the snapshot is 
runnable.", ) - deployable: Optional[bool] = Field( + deployable: bool | None = Field( default=None, description="Whether the snapshot is deployable.", ) - deployed: Optional[bool] = Field( + deployed: bool | None = Field( default=None, description="Whether the snapshot is deployed.", ) def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List["ColumnElement[bool]"]: + self, table: type["AnySchema"] + ) -> list["ColumnElement[bool]"]: """Get custom filters. Args: @@ -785,7 +781,7 @@ def get_custom_filters( def apply_sorting( self, query: "AnyQuery", - table: Type["AnySchema"], + table: type["AnySchema"], ) -> "AnyQuery": """Apply sorting to the query. @@ -845,11 +841,11 @@ def apply_sorting( class PipelineSnapshotRunRequest(BaseZenModel): """Request model for running a pipeline snapshot.""" - run_configuration: Optional[PipelineRunConfiguration] = Field( + run_configuration: PipelineRunConfiguration | None = Field( default=None, title="The run configuration for the snapshot.", ) - step_run: Optional[UUID] = Field( + step_run: UUID | None = Field( default=None, title="The ID of the step run that ran the snapshot.", ) diff --git a/src/zenml/models/v2/core/project.py b/src/zenml/models/v2/core/project.py index ec35850b7d8..d0765368047 100644 --- a/src/zenml/models/v2/core/project.py +++ b/src/zenml/models/v2/core/project.py @@ -14,7 +14,7 @@ """Models representing projects.""" import re -from typing import Any, Dict, Optional +from typing import Any from pydantic import Field, model_validator @@ -58,7 +58,7 @@ class ProjectRequest(BaseRequest): @model_validator(mode="before") @classmethod @before_validator_handler - def _validate_project_name(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def _validate_project_name(cls, data: dict[str, Any]) -> dict[str, Any]: """Validate the project name. 
Args: @@ -94,7 +94,7 @@ def _validate_project_name(cls, data: Dict[str, Any]) -> Dict[str, Any]: class ProjectUpdate(BaseUpdate): """Update model for projects.""" - name: Optional[str] = Field( + name: str | None = Field( title="The unique name of the project. The project name must only " "contain only lowercase letters, numbers, underscores, and hyphens and " "be at most 50 characters long.", @@ -103,12 +103,12 @@ class ProjectUpdate(BaseUpdate): pattern=r"^[a-z0-9_-]+$", default=None, ) - display_name: Optional[str] = Field( + display_name: str | None = Field( title="The display name of the project.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - description: Optional[str] = Field( + description: str | None = Field( title="The description of the project.", max_length=STR_FIELD_MAX_LENGTH, default=None, @@ -192,12 +192,12 @@ def description(self) -> str: class ProjectFilter(BaseFilter): """Model to enable advanced filtering of all projects.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the project", ) - display_name: Optional[str] = Field( + display_name: str | None = Field( default=None, description="Display name of the project", ) diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index dfe651f77be..3ed1029378e 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Models representing run metadata.""" -from typing import Dict, List, Optional from uuid import UUID from pydantic import Field, model_validator @@ -30,20 +29,20 @@ class RunMetadataRequest(ProjectScopedRequest): """Request model for run metadata.""" - resources: List[RunMetadataResource] = Field( + resources: list[RunMetadataResource] = Field( title="The list of resources that this metadata belongs to." 
) - stack_component_id: Optional[UUID] = Field( + stack_component_id: UUID | None = Field( title="The ID of the stack component that this metadata belongs to.", default=None, ) - values: Dict[str, "MetadataType"] = Field( + values: dict[str, "MetadataType"] = Field( title="The metadata to be created.", ) - types: Dict[str, "MetadataTypeEnum"] = Field( + types: dict[str, "MetadataTypeEnum"] = Field( title="The types of the metadata to be created.", ) - publisher_step_id: Optional[UUID] = Field( + publisher_step_id: UUID | None = Field( title="The ID of the step execution that published this metadata.", default=None, ) diff --git a/src/zenml/models/v2/core/run_template.py b/src/zenml/models/v2/core/run_template.py index 19e853aeb2a..677e6527200 100644 --- a/src/zenml/models/v2/core/run_template.py +++ b/src/zenml/models/v2/core/run_template.py @@ -17,12 +17,7 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - List, - Optional, - Type, TypeVar, - Union, ) from uuid import UUID @@ -71,7 +66,7 @@ class RunTemplateRequest(ProjectScopedRequest): title="The name of the run template.", max_length=STR_FIELD_MAX_LENGTH, ) - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The description of the run template.", max_length=TEXT_FIELD_MAX_LENGTH, @@ -83,7 +78,7 @@ class RunTemplateRequest(ProjectScopedRequest): default=False, title="Whether the run template is hidden.", ) - tags: Optional[List[str]] = Field( + tags: list[str] | None = Field( default=None, title="Tags of the run template.", ) @@ -95,24 +90,24 @@ class RunTemplateRequest(ProjectScopedRequest): class RunTemplateUpdate(BaseUpdate): """Run template update model.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, title="The name of the run template.", max_length=STR_FIELD_MAX_LENGTH, ) - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The description of the run template.", max_length=TEXT_FIELD_MAX_LENGTH, ) - 
hidden: Optional[bool] = Field( + hidden: bool | None = Field( default=None, title="Whether the run template is hidden.", ) - add_tags: Optional[List[str]] = Field( + add_tags: list[str] | None = Field( default=None, title="New tags to add to the run template." ) - remove_tags: Optional[List[str]] = Field( + remove_tags: list[str] | None = Field( default=None, title="Tags to remove from the run template." ) @@ -135,17 +130,17 @@ class RunTemplateResponseBody(ProjectScopedResponseBody): class RunTemplateResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for run templates.""" - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The description of the run template.", ) - pipeline_spec: Optional[PipelineSpec] = Field( + pipeline_spec: PipelineSpec | None = Field( default=None, title="The spec of the pipeline." ) - config_template: Optional[Dict[str, Any]] = Field( + config_template: dict[str, Any] | None = Field( default=None, title="Run configuration template." ) - config_schema: Optional[Dict[str, Any]] = Field( + config_schema: dict[str, Any] | None = Field( default=None, title="Run configuration schema." ) @@ -153,29 +148,29 @@ class RunTemplateResponseMetadata(ProjectScopedResponseMetadata): class RunTemplateResponseResources(ProjectScopedResponseResources): """All resource models associated with the run template.""" - source_snapshot: Optional[PipelineSnapshotResponse] = Field( + source_snapshot: PipelineSnapshotResponse | None = Field( default=None, title="The snapshot that is the source of the template.", ) - pipeline: Optional[PipelineResponse] = Field( + pipeline: PipelineResponse | None = Field( default=None, title="The pipeline associated with the template." 
) - build: Optional[PipelineBuildResponse] = Field( + build: PipelineBuildResponse | None = Field( default=None, title="The pipeline build associated with the template.", ) - code_reference: Optional[CodeReferenceResponse] = Field( + code_reference: CodeReferenceResponse | None = Field( default=None, title="The code reference associated with the template.", ) - tags: List[TagResponse] = Field( + tags: list[TagResponse] = Field( title="Tags associated with the run template.", ) - latest_run_id: Optional[UUID] = Field( + latest_run_id: UUID | None = Field( default=None, title="The ID of the latest run of the run template.", ) - latest_run_status: Optional[ExecutionStatus] = Field( + latest_run_status: ExecutionStatus | None = Field( default=None, title="The status of the latest run of the run template.", ) @@ -227,7 +222,7 @@ def hidden(self) -> bool: return self.get_body().hidden @property - def latest_run_id(self) -> Optional[UUID]: + def latest_run_id(self) -> UUID | None: """The `latest_run_id` property. Returns: @@ -236,7 +231,7 @@ def latest_run_id(self) -> Optional[UUID]: return self.get_resources().latest_run_id @property - def latest_run_status(self) -> Optional[ExecutionStatus]: + def latest_run_status(self) -> ExecutionStatus | None: """The `latest_run_status` property. Returns: @@ -245,7 +240,7 @@ def latest_run_status(self) -> Optional[ExecutionStatus]: return self.get_resources().latest_run_status @property - def description(self) -> Optional[str]: + def description(self) -> str | None: """The `description` property. Returns: @@ -254,7 +249,7 @@ def description(self) -> Optional[str]: return self.get_metadata().description @property - def pipeline_spec(self) -> Optional[PipelineSpec]: + def pipeline_spec(self) -> PipelineSpec | None: """The `pipeline_spec` property. 
Returns: @@ -263,7 +258,7 @@ def pipeline_spec(self) -> Optional[PipelineSpec]: return self.get_metadata().pipeline_spec @property - def config_template(self) -> Optional[Dict[str, Any]]: + def config_template(self) -> dict[str, Any] | None: """The `config_template` property. Returns: @@ -272,7 +267,7 @@ def config_template(self) -> Optional[Dict[str, Any]]: return self.get_metadata().config_template @property - def config_schema(self) -> Optional[Dict[str, Any]]: + def config_schema(self) -> dict[str, Any] | None: """The `config_schema` property. Returns: @@ -281,7 +276,7 @@ def config_schema(self) -> Optional[Dict[str, Any]]: return self.get_metadata().config_schema @property - def source_snapshot(self) -> Optional[PipelineSnapshotResponse]: + def source_snapshot(self) -> PipelineSnapshotResponse | None: """The `source_snapshot` property. Returns: @@ -290,7 +285,7 @@ def source_snapshot(self) -> Optional[PipelineSnapshotResponse]: return self.get_resources().source_snapshot @property - def pipeline(self) -> Optional[PipelineResponse]: + def pipeline(self) -> PipelineResponse | None: """The `pipeline` property. Returns: @@ -299,7 +294,7 @@ def pipeline(self) -> Optional[PipelineResponse]: return self.get_resources().pipeline @property - def build(self) -> Optional[PipelineBuildResponse]: + def build(self) -> PipelineBuildResponse | None: """The `build` property. Returns: @@ -308,7 +303,7 @@ def build(self) -> Optional[PipelineBuildResponse]: return self.get_resources().build @property - def code_reference(self) -> Optional[CodeReferenceResponse]: + def code_reference(self) -> CodeReferenceResponse | None: """The `code_reference` property. Returns: @@ -317,7 +312,7 @@ def code_reference(self) -> Optional[CodeReferenceResponse]: return self.get_resources().code_reference @property - def tags(self) -> List[TagResponse]: + def tags(self) -> list[TagResponse]: """The `tags` property. 
Returns: @@ -332,7 +327,7 @@ def tags(self) -> List[TagResponse]: class RunTemplateFilter(ProjectScopedFilter, TaggableFilter): """Model for filtering of run templates.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS, *TaggableFilter.FILTER_EXCLUDE_FIELDS, "code_repository_id", @@ -347,51 +342,51 @@ class RunTemplateFilter(ProjectScopedFilter, TaggableFilter): *ProjectScopedFilter.CUSTOM_SORTING_OPTIONS, *TaggableFilter.CUSTOM_SORTING_OPTIONS, ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.CLI_EXCLUDE_FIELDS, *TaggableFilter.CLI_EXCLUDE_FIELDS, ] - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the run template.", ) - hidden: Optional[bool] = Field( + hidden: bool | None = Field( default=None, description="Whether the run template is hidden.", ) - pipeline_id: Optional[Union[UUID, str]] = Field( + pipeline_id: UUID | str | None = Field( default=None, description="Pipeline associated with the template.", union_mode="left_to_right", ) - build_id: Optional[Union[UUID, str]] = Field( + build_id: UUID | str | None = Field( default=None, description="Build associated with the template.", union_mode="left_to_right", ) - stack_id: Optional[Union[UUID, str]] = Field( + stack_id: UUID | str | None = Field( default=None, description="Stack associated with the template.", union_mode="left_to_right", ) - code_repository_id: Optional[Union[UUID, str]] = Field( + code_repository_id: UUID | str | None = Field( default=None, description="Code repository associated with the template.", union_mode="left_to_right", ) - pipeline: Optional[Union[UUID, str]] = Field( + pipeline: UUID | str | None = Field( default=None, description="Name/ID of the pipeline associated with the template.", ) - stack: Optional[Union[UUID, str]] = Field( + stack: UUID | str | None = Field( default=None, 
description="Name/ID of the stack associated with the template.", ) def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List["ColumnElement[bool]"]: + self, table: type["AnySchema"] + ) -> list["ColumnElement[bool]"]: """Get custom filters. Args: diff --git a/src/zenml/models/v2/core/schedule.py b/src/zenml/models/v2/core/schedule.py index 4870938460a..d382cc1766e 100644 --- a/src/zenml/models/v2/core/schedule.py +++ b/src/zenml/models/v2/core/schedule.py @@ -14,7 +14,6 @@ """Models representing schedules.""" from datetime import datetime, timedelta, timezone -from typing import Dict, Optional, Union from uuid import UUID from pydantic import Field, field_validator, model_validator @@ -45,23 +44,23 @@ class ScheduleRequest(ProjectScopedRequest): name: str active: bool - cron_expression: Optional[str] = None - start_time: Optional[datetime] = None - end_time: Optional[datetime] = None - interval_second: Optional[timedelta] = None + cron_expression: str | None = None + start_time: datetime | None = None + end_time: datetime | None = None + interval_second: timedelta | None = None catchup: bool = False - run_once_start_time: Optional[datetime] = None + run_once_start_time: datetime | None = None - orchestrator_id: Optional[UUID] - pipeline_id: Optional[UUID] + orchestrator_id: UUID | None + pipeline_id: UUID | None @field_validator( "start_time", "end_time", "run_once_start_time", mode="after" ) @classmethod def _ensure_tzunaware_utc( - cls, value: Optional[datetime] - ) -> Optional[datetime]: + cls, value: datetime | None + ) -> datetime | None: """Ensures that all datetimes are timezone unaware and in UTC time. 
Args: @@ -126,8 +125,8 @@ def _ensure_cron_or_periodic_schedule_configured( class ScheduleUpdate(BaseUpdate): """Update model for schedules.""" - name: Optional[str] = None - cron_expression: Optional[str] = None + name: str | None = None + cron_expression: str | None = None # ------------------ Response Model ------------------ @@ -137,21 +136,21 @@ class ScheduleResponseBody(ProjectScopedResponseBody): """Response body for schedules.""" active: bool - cron_expression: Optional[str] = None - start_time: Optional[datetime] = None - end_time: Optional[datetime] = None - interval_second: Optional[timedelta] = None + cron_expression: str | None = None + start_time: datetime | None = None + end_time: datetime | None = None + interval_second: timedelta | None = None catchup: bool = False - run_once_start_time: Optional[datetime] = None + run_once_start_time: datetime | None = None class ScheduleResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for schedules.""" - orchestrator_id: Optional[UUID] - pipeline_id: Optional[UUID] + orchestrator_id: UUID | None + pipeline_id: UUID | None - run_metadata: Dict[str, MetadataType] = Field( + run_metadata: dict[str, MetadataType] = Field( title="Metadata associated with this schedule.", default={}, ) @@ -187,7 +186,7 @@ def get_hydrated_version(self) -> "ScheduleResponse": # Helper methods @property - def utc_start_time(self) -> Optional[str]: + def utc_start_time(self) -> str | None: """Optional ISO-formatted string of the UTC start time. Returns: @@ -199,7 +198,7 @@ def utc_start_time(self) -> Optional[str]: return to_utc_timezone(self.start_time).isoformat() @property - def utc_end_time(self) -> Optional[str]: + def utc_end_time(self) -> str | None: """Optional ISO-formatted string of the UTC end time. 
Returns: @@ -221,7 +220,7 @@ def active(self) -> bool: return self.get_body().active @property - def cron_expression(self) -> Optional[str]: + def cron_expression(self) -> str | None: """The `cron_expression` property. Returns: @@ -230,7 +229,7 @@ def cron_expression(self) -> Optional[str]: return self.get_body().cron_expression @property - def start_time(self) -> Optional[datetime]: + def start_time(self) -> datetime | None: """The `start_time` property. Returns: @@ -239,7 +238,7 @@ def start_time(self) -> Optional[datetime]: return self.get_body().start_time @property - def end_time(self) -> Optional[datetime]: + def end_time(self) -> datetime | None: """The `end_time` property. Returns: @@ -248,7 +247,7 @@ def end_time(self) -> Optional[datetime]: return self.get_body().end_time @property - def run_once_start_time(self) -> Optional[datetime]: + def run_once_start_time(self) -> datetime | None: """The `run_once_start_time` property. Returns: @@ -257,7 +256,7 @@ def run_once_start_time(self) -> Optional[datetime]: return self.get_body().run_once_start_time @property - def interval_second(self) -> Optional[timedelta]: + def interval_second(self) -> timedelta | None: """The `interval_second` property. Returns: @@ -275,7 +274,7 @@ def catchup(self) -> bool: return self.get_body().catchup @property - def orchestrator_id(self) -> Optional[UUID]: + def orchestrator_id(self) -> UUID | None: """The `orchestrator_id` property. Returns: @@ -284,7 +283,7 @@ def orchestrator_id(self) -> Optional[UUID]: return self.get_metadata().orchestrator_id @property - def pipeline_id(self) -> Optional[UUID]: + def pipeline_id(self) -> UUID | None: """The `pipeline_id` property. Returns: @@ -293,7 +292,7 @@ def pipeline_id(self) -> Optional[UUID]: return self.get_metadata().pipeline_id @property - def run_metadata(self) -> Dict[str, MetadataType]: + def run_metadata(self) -> dict[str, MetadataType]: """The `run_metadata` property. 
Returns: @@ -308,44 +307,44 @@ def run_metadata(self) -> Dict[str, MetadataType]: class ScheduleFilter(ProjectScopedFilter): """Model to enable advanced filtering of all Users.""" - pipeline_id: Optional[Union[UUID, str]] = Field( + pipeline_id: UUID | str | None = Field( default=None, description="Pipeline that the schedule is attached to.", union_mode="left_to_right", ) - orchestrator_id: Optional[Union[UUID, str]] = Field( + orchestrator_id: UUID | str | None = Field( default=None, description="Orchestrator that the schedule is attached to.", union_mode="left_to_right", ) - active: Optional[bool] = Field( + active: bool | None = Field( default=None, description="If the schedule is active", ) - cron_expression: Optional[str] = Field( + cron_expression: str | None = Field( default=None, description="The cron expression, describing the schedule", ) - start_time: Optional[Union[datetime, str]] = Field( + start_time: datetime | str | None = Field( default=None, description="Start time", union_mode="left_to_right" ) - end_time: Optional[Union[datetime, str]] = Field( + end_time: datetime | str | None = Field( default=None, description="End time", union_mode="left_to_right" ) - interval_second: Optional[Optional[float]] = Field( + interval_second: float | None = Field( default=None, description="The repetition interval in seconds", ) - catchup: Optional[bool] = Field( + catchup: bool | None = Field( default=None, description="Whether or not the schedule is set to catchup past missed " "events", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the schedule", ) - run_once_start_time: Optional[Union[datetime, str]] = Field( + run_once_start_time: datetime | str | None = Field( default=None, description="The time at which the schedule should run once", union_mode="left_to_right", diff --git a/src/zenml/models/v2/core/secret.py b/src/zenml/models/v2/core/secret.py index b592238d232..57c0f9fe2f7 100644 ---
a/src/zenml/models/v2/core/secret.py +++ b/src/zenml/models/v2/core/secret.py @@ -16,10 +16,6 @@ from typing import ( TYPE_CHECKING, ClassVar, - Dict, - List, - Optional, - Type, TypeVar, ) @@ -49,7 +45,7 @@ class SecretRequest(UserScopedRequest): """Request model for secrets.""" - ANALYTICS_FIELDS: ClassVar[List[str]] = ["private"] + ANALYTICS_FIELDS: ClassVar[list[str]] = ["private"] name: str = Field( title="The name of the secret.", @@ -60,12 +56,12 @@ class SecretRequest(UserScopedRequest): title="Whether the secret is private. A private secret is only " "accessible to the user who created it.", ) - values: Dict[str, Optional[PlainSerializedSecretStr]] = Field( + values: dict[str, PlainSerializedSecretStr | None] = Field( default_factory=dict, title="The values stored in this secret." ) @property - def secret_values(self) -> Dict[str, str]: + def secret_values(self) -> dict[str, str]: """A dictionary with all un-obfuscated values stored in this secret. The values are returned as strings, not SecretStr. If a value is @@ -89,24 +85,24 @@ def secret_values(self) -> Dict[str, str]: class SecretUpdate(BaseUpdate): """Update model for secrets.""" - ANALYTICS_FIELDS: ClassVar[List[str]] = ["private"] + ANALYTICS_FIELDS: ClassVar[list[str]] = ["private"] - name: Optional[str] = Field( + name: str | None = Field( title="The name of the secret.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - private: Optional[bool] = Field( + private: bool | None = Field( default=None, title="Whether the secret is private. A private secret is only " "accessible to the user who created it.", ) - values: Optional[Dict[str, Optional[PlainSerializedSecretStr]]] = Field( + values: dict[str, PlainSerializedSecretStr | None] | None = Field( title="The values stored in this secret.", default=None, ) - def get_secret_values_update(self) -> Dict[str, Optional[str]]: + def get_secret_values_update(self) -> dict[str, str | None]: """Returns a dictionary with the secret values to update. 
Returns: @@ -132,7 +128,7 @@ class SecretResponseBody(UserScopedResponseBody): title="Whether the secret is private. A private secret is only " "accessible to the user who created it.", ) - values: Dict[str, Optional[PlainSerializedSecretStr]] = Field( + values: dict[str, PlainSerializedSecretStr | None] = Field( default_factory=dict, title="The values stored in this secret." ) @@ -154,7 +150,7 @@ class SecretResponse( ): """Response model for secrets.""" - ANALYTICS_FIELDS: ClassVar[List[str]] = ["private"] + ANALYTICS_FIELDS: ClassVar[list[str]] = ["private"] name: str = Field( title="The name of the secret.", @@ -183,7 +179,7 @@ def private(self) -> bool: return self.get_body().private @property - def values(self) -> Dict[str, Optional[SecretStr]]: + def values(self) -> dict[str, SecretStr | None]: """The `values` property. Returns: @@ -193,7 +189,7 @@ def values(self) -> Dict[str, Optional[SecretStr]]: # Helper methods @property - def secret_values(self) -> Dict[str, str]: + def secret_values(self) -> dict[str, str]: """A dictionary with all un-obfuscated values stored in this secret. The values are returned as strings, not SecretStr. If a value is @@ -243,7 +239,7 @@ def remove_secrets(self) -> None: """Removes all secret values from the secret but keep the keys.""" self.get_body().values = {k: None for k in self.values.keys()} - def set_secrets(self, values: Dict[str, str]) -> None: + def set_secrets(self, values: dict[str, str]) -> None: """Sets the secret values of the secret. 
Args: @@ -258,16 +254,16 @@ def set_secrets(self, values: Dict[str, str]) -> None: class SecretFilter(UserScopedFilter): """Model to enable advanced secret filtering.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *UserScopedFilter.FILTER_EXCLUDE_FIELDS, "values", ] - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the secret", ) - private: Optional[bool] = Field( + private: bool | None = Field( default=None, description="Whether to filter secrets by private status", ) @@ -275,7 +271,7 @@ class SecretFilter(UserScopedFilter): def apply_filter( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Applies the filter to a query. diff --git a/src/zenml/models/v2/core/server_settings.py b/src/zenml/models/v2/core/server_settings.py index 543130a0f9a..e988aeec72e 100644 --- a/src/zenml/models/v2/core/server_settings.py +++ b/src/zenml/models/v2/core/server_settings.py @@ -14,9 +14,6 @@ """Models representing server settings stored in the database.""" from datetime import datetime -from typing import ( - Optional, -) from uuid import UUID from pydantic import Field @@ -37,21 +34,21 @@ class ServerSettingsUpdate(BaseUpdate): """Model for updating server settings.""" - server_name: Optional[str] = Field( + server_name: str | None = Field( default=None, title="The name of the server." ) - logo_url: Optional[str] = Field( + logo_url: str | None = Field( default=None, title="The logo URL of the server." 
) - enable_analytics: Optional[bool] = Field( + enable_analytics: bool | None = Field( default=None, title="Whether to enable analytics for the server.", ) - display_announcements: Optional[bool] = Field( + display_announcements: bool | None = Field( default=None, title="Whether to display announcements about ZenML in the dashboard.", ) - display_updates: Optional[bool] = Field( + display_updates: bool | None = Field( default=None, title="Whether to display notifications about ZenML updates in the dashboard.", ) @@ -67,7 +64,7 @@ class ServerSettingsResponseBody(BaseResponseBody): title="The unique server id.", ) server_name: str = Field(title="The name of the server.") - logo_url: Optional[str] = Field( + logo_url: str | None = Field( default=None, title="The logo URL of the server." ) active: bool = Field( @@ -76,10 +73,10 @@ class ServerSettingsResponseBody(BaseResponseBody): enable_analytics: bool = Field( title="Whether analytics are enabled for the server.", ) - display_announcements: Optional[bool] = Field( + display_announcements: bool | None = Field( title="Whether to display announcements about ZenML in the dashboard.", ) - display_updates: Optional[bool] = Field( + display_updates: bool | None = Field( title="Whether to display notifications about ZenML updates in the dashboard.", ) last_user_activity: datetime = Field( @@ -138,7 +135,7 @@ def server_name(self) -> str: return self.get_body().server_name @property - def logo_url(self) -> Optional[str]: + def logo_url(self) -> str | None: """The `logo_url` property. Returns: @@ -156,7 +153,7 @@ def enable_analytics(self) -> bool: return self.get_body().enable_analytics @property - def display_announcements(self) -> Optional[bool]: + def display_announcements(self) -> bool | None: """The `display_announcements` property. 
Returns: @@ -165,7 +162,7 @@ def display_announcements(self) -> Optional[bool]: return self.get_body().display_announcements @property - def display_updates(self) -> Optional[bool]: + def display_updates(self) -> bool | None: """The `display_updates` property. Returns: @@ -211,13 +208,13 @@ def updated(self) -> datetime: class ServerActivationRequest(ServerSettingsUpdate): """Model for activating the server.""" - admin_username: Optional[str] = Field( + admin_username: str | None = Field( default=None, title="The username of the default admin account to create. Leave " "empty to skip creating the default admin account.", ) - admin_password: Optional[str] = Field( + admin_password: str | None = Field( default=None, title="The password of the default admin account to create. Leave " "empty to skip creating the default admin account.", diff --git a/src/zenml/models/v2/core/service.py b/src/zenml/models/v2/core/service.py index dda4ca0ba75..89735067132 100644 --- a/src/zenml/models/v2/core/service.py +++ b/src/zenml/models/v2/core/service.py @@ -18,10 +18,7 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - List, Optional, - Type, TypeVar, Union, ) @@ -64,48 +61,48 @@ class ServiceRequest(ProjectScopedRequest): service_type: ServiceType = Field( title="The type of the service.", ) - service_source: Optional[str] = Field( + service_source: str | None = Field( title="The class of the service.", description="The fully qualified class name of the service " "implementation.", default=None, ) - admin_state: Optional[ServiceState] = Field( + admin_state: ServiceState | None = Field( title="The admin state of the service.", description="The administrative state of the service, e.g., ACTIVE, " "INACTIVE.", default=None, ) - config: Dict[str, Any] = Field( + config: dict[str, Any] = Field( title="The service config.", description="A dictionary containing configuration parameters for the " "service.", ) - labels: Optional[Dict[str, str]] = Field( + labels: dict[str, str] | None = Field( 
default=None, title="The service labels.", ) - status: Optional[Dict[str, Any]] = Field( + status: dict[str, Any] | None = Field( default=None, title="The status of the service.", ) - endpoint: Optional[Dict[str, Any]] = Field( + endpoint: dict[str, Any] | None = Field( default=None, title="The service endpoint.", ) - prediction_url: Optional[str] = Field( + prediction_url: str | None = Field( default=None, title="The service endpoint URL.", ) - health_check_url: Optional[str] = Field( + health_check_url: str | None = Field( default=None, title="The service health check URL.", ) - model_version_id: Optional[UUID] = Field( + model_version_id: UUID | None = Field( default=None, title="The model version id linked to the service.", ) - pipeline_run_id: Optional[UUID] = Field( + pipeline_run_id: UUID | None = Field( default=None, title="The pipeline run id linked to the service.", ) @@ -125,44 +122,44 @@ class ServiceRequest(ProjectScopedRequest): class ServiceUpdate(BaseUpdate): """Update model for stack components.""" - name: Optional[str] = Field( + name: str | None = Field( None, title="The name of the service.", max_length=STR_FIELD_MAX_LENGTH, ) - admin_state: Optional[ServiceState] = Field( + admin_state: ServiceState | None = Field( None, title="The admin state of the service.", description="The administrative state of the service, e.g., ACTIVE, " "INACTIVE.", ) - service_source: Optional[str] = Field( + service_source: str | None = Field( None, title="The class of the service.", description="The fully qualified class name of the service " "implementation.", ) - status: Optional[Dict[str, Any]] = Field( + status: dict[str, Any] | None = Field( None, title="The status of the service.", ) - endpoint: Optional[Dict[str, Any]] = Field( + endpoint: dict[str, Any] | None = Field( None, title="The service endpoint.", ) - prediction_url: Optional[str] = Field( + prediction_url: str | None = Field( None, title="The service endpoint URL.", ) - health_check_url: 
Optional[str] = Field( + health_check_url: str | None = Field( None, title="The service health check URL.", ) - labels: Optional[Dict[str, str]] = Field( + labels: dict[str, str] | None = Field( default=None, title="The service labels.", ) - model_version_id: Optional[UUID] = Field( + model_version_id: UUID | None = Field( default=None, title="The model version id linked to the service.", ) @@ -185,7 +182,7 @@ class ServiceResponseBody(ProjectScopedResponseBody): service_type: ServiceType = Field( title="The type of the service.", ) - labels: Optional[Dict[str, str]] = Field( + labels: dict[str, str] | None = Field( default=None, title="The service labels.", ) @@ -195,7 +192,7 @@ class ServiceResponseBody(ProjectScopedResponseBody): updated: datetime = Field( title="The timestamp when this component was last updated.", ) - state: Optional[ServiceState] = Field( + state: ServiceState | None = Field( default=None, title="The current state of the service.", ) @@ -204,27 +201,27 @@ class ServiceResponseBody(ProjectScopedResponseBody): class ServiceResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for services.""" - service_source: Optional[str] = Field( + service_source: str | None = Field( title="The class of the service.", ) - admin_state: Optional[ServiceState] = Field( + admin_state: ServiceState | None = Field( title="The admin state of the service.", ) - config: Dict[str, Any] = Field( + config: dict[str, Any] = Field( title="The service config.", ) - status: Optional[Dict[str, Any]] = Field( + status: dict[str, Any] | None = Field( title="The status of the service.", ) - endpoint: Optional[Dict[str, Any]] = Field( + endpoint: dict[str, Any] | None = Field( default=None, title="The service endpoint.", ) - prediction_url: Optional[str] = Field( + prediction_url: str | None = Field( default=None, title="The service endpoint URL.", ) - health_check_url: Optional[str] = Field( + health_check_url: str | None = Field( default=None, title="The 
service health check URL.", ) @@ -285,7 +282,7 @@ def service_type(self) -> ServiceType: return self.get_body().service_type @property - def labels(self) -> Optional[Dict[str, str]]: + def labels(self) -> dict[str, str] | None: """The `labels` property. Returns: @@ -294,7 +291,7 @@ def labels(self) -> Optional[Dict[str, str]]: return self.get_body().labels @property - def service_source(self) -> Optional[str]: + def service_source(self) -> str | None: """The `service_source` property. Returns: @@ -303,7 +300,7 @@ def service_source(self) -> Optional[str]: return self.get_metadata().service_source @property - def config(self) -> Dict[str, Any]: + def config(self) -> dict[str, Any]: """The `config` property. Returns: @@ -312,7 +309,7 @@ def config(self) -> Dict[str, Any]: return self.get_metadata().config @property - def status(self) -> Optional[Dict[str, Any]]: + def status(self) -> dict[str, Any] | None: """The `status` property. Returns: @@ -321,7 +318,7 @@ def status(self) -> Optional[Dict[str, Any]]: return self.get_metadata().status @property - def endpoint(self) -> Optional[Dict[str, Any]]: + def endpoint(self) -> dict[str, Any] | None: """The `endpoint` property. Returns: @@ -348,7 +345,7 @@ def updated(self) -> datetime: return self.get_body().updated @property - def admin_state(self) -> Optional[ServiceState]: + def admin_state(self) -> ServiceState | None: """The `admin_state` property. Returns: @@ -357,7 +354,7 @@ def admin_state(self) -> Optional[ServiceState]: return self.get_metadata().admin_state @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """The `prediction_url` property. Returns: @@ -366,7 +363,7 @@ def prediction_url(self) -> Optional[str]: return self.get_metadata().prediction_url @property - def health_check_url(self) -> Optional[str]: + def health_check_url(self) -> str | None: """The `health_check_url` property. 
Returns: @@ -375,7 +372,7 @@ def health_check_url(self) -> Optional[str]: return self.get_metadata().health_check_url @property - def state(self) -> Optional[ServiceState]: + def state(self) -> ServiceState | None: """The `state` property. Returns: @@ -408,42 +405,42 @@ def model_version(self) -> Optional["ModelVersionResponse"]: class ServiceFilter(ProjectScopedFilter): """Model to enable advanced filtering of services.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the service. Use this to filter services by " "their name.", ) - type: Optional[str] = Field( + type: str | None = Field( default=None, description="Type of the service. Filter services by their type.", ) - flavor: Optional[str] = Field( + flavor: str | None = Field( default=None, description="Flavor of the service. Use this to filter services by " "their flavor.", ) - config: Optional[bytes] = Field( + config: bytes | None = Field( default=None, description="Config of the service. 
Use this to filter services by " "their config.", ) - pipeline_name: Optional[str] = Field( + pipeline_name: str | None = Field( default=None, description="Pipeline name responsible for deploying the service", ) - pipeline_step_name: Optional[str] = Field( + pipeline_step_name: str | None = Field( default=None, description="Pipeline step name responsible for deploying the service", ) - running: Optional[bool] = Field( + running: bool | None = Field( default=None, description="Whether the service is running" ) - model_version_id: Optional[Union[UUID, str]] = Field( + model_version_id: UUID | str | None = Field( default=None, description="By the model version this service is attached to.", union_mode="left_to_right", ) - pipeline_run_id: Optional[Union[UUID, str]] = Field( + pipeline_run_id: UUID | str | None = Field( default=None, description="By the pipeline run this service is attached to.", union_mode="left_to_right", @@ -483,7 +480,7 @@ def set_flavor(self, flavor: str) -> None: "pipeline_name", "config", ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.CLI_EXCLUDE_FIELDS, "flavor", "type", @@ -493,7 +490,7 @@ def set_flavor(self, flavor: str) -> None: ] def generate_filter( - self, table: Type["AnySchema"] + self, table: type["AnySchema"] ) -> Union["ColumnElement[bool]"]: """Generate the filter for the query. diff --git a/src/zenml/models/v2/core/service_account.py b/src/zenml/models/v2/core/service_account.py index 1f9f39453d6..115e112978f 100644 --- a/src/zenml/models/v2/core/service_account.py +++ b/src/zenml/models/v2/core/service_account.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Models representing service accounts.""" -from typing import TYPE_CHECKING, ClassVar, List, Optional, Type, Union +from typing import TYPE_CHECKING, ClassVar from uuid import UUID from pydantic import ConfigDict, Field @@ -38,7 +38,7 @@ class ServiceAccountRequest(BaseRequest): """Request model for service accounts.""" - ANALYTICS_FIELDS: ClassVar[List[str]] = [ + ANALYTICS_FIELDS: ClassVar[list[str]] = [ "name", "active", ] @@ -52,13 +52,13 @@ class ServiceAccountRequest(BaseRequest): max_length=STR_FIELD_MAX_LENGTH, title="The display name of the service account.", ) - description: Optional[str] = Field( + description: str | None = Field( default=None, title="A description of the service account.", max_length=TEXT_FIELD_MAX_LENGTH, ) active: bool = Field(title="Whether the service account is active or not.") - avatar_url: Optional[str] = Field( + avatar_url: str | None = Field( default=None, title="The avatar URL for the account.", ) @@ -69,7 +69,7 @@ class ServiceAccountRequest(BaseRequest): class ServiceAccountInternalRequest(ServiceAccountRequest): """Internal request model for service accounts.""" - external_user_id: Optional[UUID] = Field( + external_user_id: UUID | None = Field( default=None, title="The external user ID associated with the account.", ) @@ -81,29 +81,29 @@ class ServiceAccountInternalRequest(ServiceAccountRequest): class ServiceAccountUpdate(BaseUpdate): """Update model for service accounts.""" - ANALYTICS_FIELDS: ClassVar[List[str]] = ["name", "active"] + ANALYTICS_FIELDS: ClassVar[list[str]] = ["name", "active"] - name: Optional[str] = Field( + name: str | None = Field( title="The unique name for the service account.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - full_name: Optional[str] = Field( + full_name: str | None = Field( title="The display name of the service account.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - description: Optional[str] = Field( + description: str | None = Field( title="A description of the 
service account.", max_length=TEXT_FIELD_MAX_LENGTH, default=None, ) - active: Optional[bool] = Field( + active: bool | None = Field( title="Whether the service account is active or not.", default=None, ) - avatar_url: Optional[str] = Field( + avatar_url: str | None = Field( default=None, title="The avatar URL for the account.", ) @@ -114,7 +114,7 @@ class ServiceAccountUpdate(BaseUpdate): class ServiceAccountInternalUpdate(ServiceAccountUpdate): """Internal update model for service accounts.""" - external_user_id: Optional[UUID] = Field( + external_user_id: UUID | None = Field( default=None, title="The external user ID associated with the account.", ) @@ -131,7 +131,7 @@ class ServiceAccountResponseBody(BaseDatedResponseBody): title="The display name of the service account.", ) active: bool = Field(default=False, title="Whether the account is active.") - avatar_url: Optional[str] = Field( + avatar_url: str | None = Field( default=None, title="The avatar URL for the account.", ) @@ -146,7 +146,7 @@ class ServiceAccountResponseMetadata(BaseResponseMetadata): max_length=TEXT_FIELD_MAX_LENGTH, ) - external_user_id: Optional[UUID] = Field( + external_user_id: UUID | None = Field( default=None, title="The external user ID associated with the account.", ) @@ -165,7 +165,7 @@ class ServiceAccountResponse( ): """Response model for service accounts.""" - ANALYTICS_FIELDS: ClassVar[List[str]] = [ + ANALYTICS_FIELDS: ClassVar[list[str]] = [ "name", "active", ] @@ -249,7 +249,7 @@ def description(self) -> str: return self.get_metadata().description @property - def external_user_id(self) -> Optional[UUID]: + def external_user_id(self) -> UUID | None: """The `external_user_id` property. Returns: @@ -258,7 +258,7 @@ def external_user_id(self) -> Optional[UUID]: return self.get_metadata().external_user_id @property - def avatar_url(self) -> Optional[str]: + def avatar_url(self) -> str | None: """The `avatar_url` property. 
Returns: @@ -271,20 +271,20 @@ def avatar_url(self) -> Optional[str]: class ServiceAccountFilter(BaseFilter): """Model to enable advanced filtering of service accounts.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the user", ) - description: Optional[str] = Field( + description: str | None = Field( default=None, title="Filter by the service account description.", ) - active: Optional[Union[bool, str]] = Field( + active: bool | str | None = Field( default=None, description="Whether the user is active", union_mode="left_to_right", ) - external_user_id: Optional[Union[UUID, str]] = Field( + external_user_id: UUID | str | None = Field( default=None, title="The external user ID associated with the account.", union_mode="left_to_right", @@ -293,7 +293,7 @@ class ServiceAccountFilter(BaseFilter): def apply_filter( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Override to filter out user accounts from the query. diff --git a/src/zenml/models/v2/core/service_connector.py b/src/zenml/models/v2/core/service_connector.py index e81c9c77849..7f0a3f679f4 100644 --- a/src/zenml/models/v2/core/service_connector.py +++ b/src/zenml/models/v2/core/service_connector.py @@ -15,7 +15,7 @@ import json from datetime import datetime -from typing import Any, ClassVar, Dict, List, Optional, Union +from typing import Any, ClassVar, Union from pydantic import ( Field, @@ -47,12 +47,12 @@ # ------------------ Configuration Model ------------------ -class ServiceConnectorConfiguration(Dict[str, Any]): +class ServiceConnectorConfiguration(dict[str, Any]): """Model for service connector configuration.""" @classmethod def from_dict( - cls, data: Dict[str, Any] + cls, data: dict[str, Any] ) -> "ServiceConnectorConfiguration": """Create a configuration model from a dictionary. 
@@ -65,7 +65,7 @@ def from_dict( return cls(**data) @property - def secrets(self) -> Dict[str, PlainSerializedSecretStr]: + def secrets(self) -> dict[str, PlainSerializedSecretStr]: """Get the secrets from the configuration. Returns: @@ -74,7 +74,7 @@ def secrets(self) -> Dict[str, PlainSerializedSecretStr]: return {k: v for k, v in self.items() if isinstance(v, SecretStr)} @property - def plain_secrets(self) -> Dict[str, str]: + def plain_secrets(self) -> dict[str, str]: """Get the plain secrets from the configuration. Returns: @@ -87,7 +87,7 @@ def plain_secrets(self) -> Dict[str, str]: } @property - def non_secrets(self) -> Dict[str, Any]: + def non_secrets(self) -> dict[str, Any]: """Get the non-secrets from the configuration. Returns: @@ -96,7 +96,7 @@ def non_secrets(self) -> Dict[str, Any]: return {k: v for k, v in self.items() if not isinstance(v, SecretStr)} @property - def plain(self) -> Dict[str, Any]: + def plain(self) -> dict[str, Any]: """Get the configuration with secrets unpacked. Returns: @@ -122,7 +122,7 @@ def get_plain(self, key: str, default: Any = None) -> Any: return result.get_secret_value() return result - def add_secrets(self, secrets: Dict[str, str]) -> None: + def add_secrets(self, secrets: dict[str, str]) -> None: """Add the secrets to the configuration. Args: @@ -181,12 +181,12 @@ class ServiceConnectorRequest(UserScopedRequest): "access the resources.", max_length=STR_FIELD_MAX_LENGTH, ) - resource_types: List[str] = Field( + resource_types: list[str] = Field( default_factory=list, title="The type(s) of resource that the connector instance can be used " "to gain access to.", ) - resource_id: Optional[str] = Field( + resource_id: str | None = Field( default=None, title="Uniquely identifies a specific resource instance that the " "connector instance can be used to access. 
If omitted, the connector " @@ -199,18 +199,18 @@ class ServiceConnectorRequest(UserScopedRequest): title="Indicates whether the connector instance can be used to access " "multiple instances of the configured resource type.", ) - expires_at: Optional[datetime] = Field( + expires_at: datetime | None = Field( default=None, title="Time when the authentication credentials configured for the " "connector expire. If omitted, the credentials do not expire.", ) - expires_skew_tolerance: Optional[int] = Field( + expires_skew_tolerance: int | None = Field( default=None, title="The number of seconds of tolerance to apply when checking " "whether the authentication credentials configured for the connector " "have expired. If omitted, no tolerance is applied.", ) - expiration_seconds: Optional[int] = Field( + expiration_seconds: int | None = Field( default=None, title="The duration, in seconds, that the temporary credentials " "generated by this connector should remain valid. Only applicable for " @@ -221,19 +221,19 @@ class ServiceConnectorRequest(UserScopedRequest): default_factory=ServiceConnectorConfiguration, title="The service connector configuration.", ) - labels: Dict[str, str] = Field( + labels: dict[str, str] = Field( default_factory=dict, title="Service connector labels.", ) # Analytics - ANALYTICS_FIELDS: ClassVar[List[str]] = [ + ANALYTICS_FIELDS: ClassVar[list[str]] = [ "connector_type", "auth_method", "resource_types", ] - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """Format the resource types in the analytics metadata. Returns: @@ -272,7 +272,7 @@ def emojified_connector_type(self) -> str: return self.connector_type @property - def emojified_resource_types(self) -> List[str]: + def emojified_resource_types(self) -> list[str]: """Get the emojified connector type. 
Returns: @@ -291,9 +291,9 @@ def emojified_resource_types(self) -> List[str]: def validate_and_configure_resources( self, connector_type: "ServiceConnectorTypeModel", - resource_types: Optional[Union[str, List[str]]] = None, - resource_id: Optional[str] = None, - configuration: Optional[Dict[str, Any]] = None, + resource_types: str | list[str] | None = None, + resource_id: str | None = None, + configuration: dict[str, Any] | None = None, ) -> None: """Validate and configure the resources that the connector can be used to access. @@ -343,32 +343,32 @@ class ServiceConnectorUpdate(BaseUpdate): have a None default value. """ - name: Optional[str] = Field( + name: str | None = Field( title="The service connector name.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - connector_type: Optional[Union[str, "ServiceConnectorTypeModel"]] = Field( + connector_type: Union[str, "ServiceConnectorTypeModel"] | None = Field( title="The type of service connector.", default=None, union_mode="left_to_right", ) - description: Optional[str] = Field( + description: str | None = Field( title="The service connector instance description.", default=None, ) - auth_method: Optional[str] = Field( + auth_method: str | None = Field( title="The authentication method that the connector instance uses to " "access the resources.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - resource_types: Optional[List[str]] = Field( + resource_types: list[str] | None = Field( title="The type(s) of resource that the connector instance can be used " "to gain access to.", default=None, ) - resource_id: Optional[str] = Field( + resource_id: str | None = Field( title="Uniquely identifies a specific resource instance that the " "connector instance can be used to access. 
If omitted, the " "connector instance can be used to access any and all resource " @@ -377,23 +377,23 @@ class ServiceConnectorUpdate(BaseUpdate): max_length=STR_FIELD_MAX_LENGTH, default=None, ) - supports_instances: Optional[bool] = Field( + supports_instances: bool | None = Field( title="Indicates whether the connector instance can be used to access " "multiple instances of the configured resource type.", default=None, ) - expires_at: Optional[datetime] = Field( + expires_at: datetime | None = Field( title="Time when the authentication credentials configured for the " "connector expire. If omitted, the credentials do not expire.", default=None, ) - expires_skew_tolerance: Optional[int] = Field( + expires_skew_tolerance: int | None = Field( title="The number of seconds of tolerance to apply when checking " "whether the authentication credentials configured for the " "connector have expired. If omitted, no tolerance is applied.", default=None, ) - expiration_seconds: Optional[int] = Field( + expiration_seconds: int | None = Field( title="The duration, in seconds, that the temporary credentials " "generated by this connector should remain valid. 
Only " "applicable for connectors and authentication methods that " @@ -401,23 +401,23 @@ class ServiceConnectorUpdate(BaseUpdate): "configured in the connector.", default=None, ) - configuration: Optional[ServiceConnectorConfiguration] = Field( + configuration: ServiceConnectorConfiguration | None = Field( title="The service connector full configuration replacement.", default=None, ) - labels: Optional[Dict[str, str]] = Field( + labels: dict[str, str] | None = Field( title="Service connector labels.", default=None, ) # Analytics - ANALYTICS_FIELDS: ClassVar[List[str]] = [ + ANALYTICS_FIELDS: ClassVar[list[str]] = [ "connector_type", "auth_method", "resource_types", ] - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """Format the resource types in the analytics metadata. Returns: @@ -438,7 +438,7 @@ def get_analytics_metadata(self) -> Dict[str, Any]: # Helper methods @property - def type(self) -> Optional[str]: + def type(self) -> str | None: """Get the connector type. Returns: @@ -453,9 +453,9 @@ def type(self) -> Optional[str]: def validate_and_configure_resources( self, connector_type: "ServiceConnectorTypeModel", - resource_types: Optional[Union[str, List[str]]] = None, - resource_id: Optional[str] = None, - configuration: Optional[Dict[str, Any]] = None, + resource_types: str | list[str] | None = None, + resource_id: str | None = None, + configuration: dict[str, Any] | None = None, ) -> None: """Validate and configure the resources that the connector can be used to access. 
@@ -518,12 +518,12 @@ class ServiceConnectorResponseBody(UserScopedResponseBody): "access the resources.", max_length=STR_FIELD_MAX_LENGTH, ) - resource_types: List[str] = Field( + resource_types: list[str] = Field( default_factory=list, title="The type(s) of resource that the connector instance can be used " "to gain access to.", ) - resource_id: Optional[str] = Field( + resource_id: str | None = Field( default=None, title="Uniquely identifies a specific resource instance that the " "connector instance can be used to access. If omitted, the connector " @@ -536,12 +536,12 @@ class ServiceConnectorResponseBody(UserScopedResponseBody): title="Indicates whether the connector instance can be used to access " "multiple instances of the configured resource type.", ) - expires_at: Optional[datetime] = Field( + expires_at: datetime | None = Field( default=None, title="Time when the authentication credentials configured for the " "connector expire. If omitted, the credentials do not expire.", ) - expires_skew_tolerance: Optional[int] = Field( + expires_skew_tolerance: int | None = Field( default=None, title="The number of seconds of tolerance to apply when checking " "whether the authentication credentials configured for the connector " @@ -556,14 +556,14 @@ class ServiceConnectorResponseMetadata(UserScopedResponseMetadata): default_factory=ServiceConnectorConfiguration, title="The service connector configuration.", ) - expiration_seconds: Optional[int] = Field( + expiration_seconds: int | None = Field( default=None, title="The duration, in seconds, that the temporary credentials " "generated by this connector should remain valid. 
Only applicable for " "connectors and authentication methods that involve generating " "temporary credentials from the ones configured in the connector.", ) - labels: Dict[str, str] = Field( + labels: dict[str, str] = Field( default_factory=dict, title="Service connector labels.", ) @@ -591,7 +591,7 @@ class ServiceConnectorResponse( max_length=STR_FIELD_MAX_LENGTH, ) - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """Add the service connector labels to analytics metadata. Returns: @@ -643,7 +643,7 @@ def emojified_connector_type(self) -> str: return self.connector_type @property - def emojified_resource_types(self) -> List[str]: + def emojified_resource_types(self) -> list[str]: """Get the emojified connector type. Returns: @@ -723,9 +723,9 @@ def validate_configuration(self) -> None: def validate_and_configure_resources( self, connector_type: "ServiceConnectorTypeModel", - resource_types: Optional[Union[str, List[str]]] = None, - resource_id: Optional[str] = None, - configuration: Optional[Dict[str, Any]] = None, + resource_types: str | list[str] | None = None, + resource_id: str | None = None, + configuration: dict[str, Any] | None = None, ) -> None: """Validate and configure the resources that the connector can be used to access. @@ -776,7 +776,7 @@ def auth_method(self) -> str: return self.get_body().auth_method @property - def resource_types(self) -> List[str]: + def resource_types(self) -> list[str]: """The `resource_types` property. Returns: @@ -785,7 +785,7 @@ def resource_types(self) -> List[str]: return self.get_body().resource_types @property - def resource_id(self) -> Optional[str]: + def resource_id(self) -> str | None: """The `resource_id` property. Returns: @@ -803,7 +803,7 @@ def supports_instances(self) -> bool: return self.get_body().supports_instances @property - def expires_at(self) -> Optional[datetime]: + def expires_at(self) -> datetime | None: """The `expires_at` property. 
Returns: @@ -812,7 +812,7 @@ def expires_at(self) -> Optional[datetime]: return self.get_body().expires_at @property - def expires_skew_tolerance(self) -> Optional[int]: + def expires_skew_tolerance(self) -> int | None: """The `expires_skew_tolerance` property. Returns: @@ -836,7 +836,7 @@ def remove_secrets(self) -> None: **metadata.configuration.non_secrets ) - def add_secrets(self, secrets: Dict[str, str]) -> None: + def add_secrets(self, secrets: dict[str, str]) -> None: """Add the secrets to the configuration. Args: @@ -845,7 +845,7 @@ def add_secrets(self, secrets: Dict[str, str]) -> None: self.get_metadata().configuration.add_secrets(secrets) @property - def expiration_seconds(self) -> Optional[int]: + def expiration_seconds(self) -> int | None: """The `expiration_seconds` property. Returns: @@ -854,7 +854,7 @@ def expiration_seconds(self) -> Optional[int]: return self.get_metadata().expiration_seconds @property - def labels(self) -> Dict[str, str]: + def labels(self) -> dict[str, str]: """The `labels` property. 
Returns: @@ -869,41 +869,41 @@ def labels(self) -> Dict[str, str]: class ServiceConnectorFilter(UserScopedFilter): """Model to enable advanced filtering of service connectors.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *UserScopedFilter.FILTER_EXCLUDE_FIELDS, "resource_type", "labels_str", "labels", ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *UserScopedFilter.CLI_EXCLUDE_FIELDS, "labels_str", "labels", ] - name: Optional[str] = Field( + name: str | None = Field( default=None, description="The name to filter by", ) - connector_type: Optional[str] = Field( + connector_type: str | None = Field( default=None, description="The type of service connector to filter by", ) - auth_method: Optional[str] = Field( + auth_method: str | None = Field( default=None, title="Filter by the authentication method configured for the " "connector", ) - resource_type: Optional[str] = Field( + resource_type: str | None = Field( default=None, title="Filter by the type of resource that the connector can be used " "to access", ) - resource_id: Optional[str] = Field( + resource_id: str | None = Field( default=None, title="Filter by the ID of the resource instance that the connector " "is configured to access", ) - labels_str: Optional[str] = Field( + labels_str: str | None = Field( default=None, title="Filter by one or more labels. 
This field can be either a JSON " "formatted dictionary of label names and values, where the values are " @@ -915,7 +915,7 @@ class ServiceConnectorFilter(UserScopedFilter): ) # Use this internally to configure and access the labels as a dictionary - labels: Optional[Dict[str, Optional[str]]] = Field( + labels: dict[str, str | None] | None = Field( default=None, title="The labels to filter by, as a dictionary", exclude=True, @@ -949,15 +949,15 @@ def validate_labels(self) -> "ServiceConnectorFilter": def _validate_and_configure_resources( - connector: Union[ - ServiceConnectorRequest, - ServiceConnectorUpdate, - ServiceConnectorResponse, - ], + connector: ( + ServiceConnectorRequest | + ServiceConnectorUpdate | + ServiceConnectorResponse + ), connector_type: "ServiceConnectorTypeModel", - resource_types: Optional[Union[str, List[str]]] = None, - resource_id: Optional[str] = None, - configuration: Optional[Dict[str, Any]] = None, + resource_types: str | list[str] | None = None, + resource_id: str | None = None, + configuration: dict[str, Any] | None = None, ) -> None: """Validate and configure the resources that a connector can be used to access. @@ -980,16 +980,16 @@ def _validate_and_configure_resources( # and response models. For the request model, the fields are in the # connector model itself, while for the response model, they are in the # metadata field. 
- update_connector_metadata: Union[ - ServiceConnectorRequest, - ServiceConnectorUpdate, - ServiceConnectorResponseMetadata, - ] - update_connector_body: Union[ - ServiceConnectorRequest, - ServiceConnectorUpdate, - ServiceConnectorResponseBody, - ] + update_connector_metadata: ( + ServiceConnectorRequest | + ServiceConnectorUpdate | + ServiceConnectorResponseMetadata + ) + update_connector_body: ( + ServiceConnectorRequest | + ServiceConnectorUpdate | + ServiceConnectorResponseBody + ) if isinstance(connector, ServiceConnectorRequest): update_connector_metadata = connector update_connector_body = connector diff --git a/src/zenml/models/v2/core/stack.py b/src/zenml/models/v2/core/stack.py index 51297580552..94d94ada439 100644 --- a/src/zenml/models/v2/core/stack.py +++ b/src/zenml/models/v2/core/stack.py @@ -18,12 +18,7 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - List, - Optional, - Type, TypeVar, - Union, ) from uuid import UUID @@ -68,11 +63,11 @@ class StackRequest(UserScopedRequest): title="The description of the stack", max_length=STR_FIELD_MAX_LENGTH, ) - stack_spec_path: Optional[str] = Field( + stack_spec_path: str | None = Field( default=None, title="The path to the stack spec used for mlstacks deployments.", ) - components: Dict[StackComponentType, List[Union[UUID, ComponentInfo]]] = ( + components: dict[StackComponentType, list[UUID | ComponentInfo]] = ( Field( title="The mapping for the components of the full stack registration.", description="The mapping from component types to either UUIDs of " @@ -80,20 +75,20 @@ class StackRequest(UserScopedRequest): "components.", ) ) - environment: Optional[Dict[str, str]] = Field( + environment: dict[str, str] | None = Field( default=None, title="Environment variables to set when running on this stack.", ) - secrets: Optional[List[Union[UUID, str]]] = Field( + secrets: list[UUID | str] | None = Field( default=None, title="Secrets to set as environment variables when running on this " "stack.", ) - labels: 
Optional[Dict[str, Any]] = Field( + labels: dict[str, Any] | None = Field( default=None, title="The stack labels.", ) - service_connectors: List[Union[UUID, ServiceConnectorInfo]] = Field( + service_connectors: list[UUID | ServiceConnectorInfo] = Field( default=[], title="The service connectors dictionary for the full stack " "registration.", @@ -104,8 +99,8 @@ class StackRequest(UserScopedRequest): @field_validator("components") def _validate_components( - cls, value: Dict[StackComponentType, List[Union[UUID, ComponentInfo]]] - ) -> Dict[StackComponentType, List[Union[UUID, ComponentInfo]]]: + cls, value: dict[StackComponentType, list[UUID | ComponentInfo]] + ) -> dict[StackComponentType, list[UUID | ComponentInfo]]: """Validate the components of the stack. Args: @@ -160,38 +155,38 @@ class DefaultStackRequest(StackRequest): class StackUpdate(BaseUpdate): """Update model for stacks.""" - name: Optional[str] = Field( + name: str | None = Field( title="The name of the stack.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - description: Optional[str] = Field( + description: str | None = Field( title="The description of the stack", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - stack_spec_path: Optional[str] = Field( + stack_spec_path: str | None = Field( title="The path to the stack spec used for mlstacks deployments.", default=None, ) - components: Optional[Dict[StackComponentType, List[UUID]]] = Field( + components: dict[StackComponentType, list[UUID]] | None = Field( title="A mapping of stack component types to the actual" "instances of components of this type.", default=None, ) - environment: Optional[Dict[str, str]] = Field( + environment: dict[str, str] | None = Field( default=None, title="Environment variables to set when running on this stack.", ) - labels: Optional[Dict[str, Any]] = Field( + labels: dict[str, Any] | None = Field( default=None, title="The stack labels.", ) - add_secrets: Optional[List[Union[UUID, str]]] = Field( + add_secrets: 
list[UUID | str] | None = Field( default=None, title="New secrets to add to the stack.", ) - remove_secrets: Optional[List[Union[UUID, str]]] = Field( + remove_secrets: list[UUID | str] | None = Field( default=None, title="Secrets to remove from the stack.", ) @@ -199,10 +194,10 @@ class StackUpdate(BaseUpdate): @field_validator("components") def _validate_components( cls, - value: Optional[ - Dict[StackComponentType, List[Union[UUID, ComponentInfo]]] - ], - ) -> Optional[Dict[StackComponentType, List[Union[UUID, ComponentInfo]]]]: + value: None | ( + dict[StackComponentType, list[UUID | ComponentInfo]] + ), + ) -> dict[StackComponentType, list[UUID | ComponentInfo]] | None: """Validate the components of the stack. Args: @@ -240,29 +235,29 @@ class StackResponseBody(UserScopedResponseBody): class StackResponseMetadata(UserScopedResponseMetadata): """Response metadata for stacks.""" - components: Dict[StackComponentType, List["ComponentResponse"]] = Field( + components: dict[StackComponentType, list["ComponentResponse"]] = Field( title="A mapping of stack component types to the actual" "instances of components of this type." 
) - description: Optional[str] = Field( + description: str | None = Field( default="", title="The description of the stack", max_length=STR_FIELD_MAX_LENGTH, ) - stack_spec_path: Optional[str] = Field( + stack_spec_path: str | None = Field( default=None, title="The path to the stack spec used for mlstacks deployments.", ) - environment: Dict[str, str] = Field( + environment: dict[str, str] = Field( default={}, title="Environment variables to set when running on this stack.", ) - secrets: List[UUID] = Field( + secrets: list[UUID] = Field( default=[], title="Secrets to set as environment variables when running on this " "stack.", ) - labels: Optional[Dict[str, Any]] = Field( + labels: dict[str, Any] | None = Field( default=None, title="The stack labels.", ) @@ -308,7 +303,7 @@ def is_valid(self) -> bool: and StackComponentType.ORCHESTRATOR in self.components ) - def to_yaml(self) -> Dict[str, Any]: + def to_yaml(self) -> dict[str, Any]: """Create yaml representation of the Stack Model. Returns: @@ -340,7 +335,7 @@ def to_yaml(self) -> Dict[str, Any]: return yaml_data # Analytics - def get_analytics_metadata(self) -> Dict[str, Any]: + def get_analytics_metadata(self) -> dict[str, Any]: """Add the stack components to the stack analytics metadata. Returns: @@ -362,7 +357,7 @@ def get_analytics_metadata(self) -> Dict[str, Any]: return metadata @property - def description(self) -> Optional[str]: + def description(self) -> str | None: """The `description` property. Returns: @@ -371,7 +366,7 @@ def description(self) -> Optional[str]: return self.get_metadata().description @property - def stack_spec_path(self) -> Optional[str]: + def stack_spec_path(self) -> str | None: """The `stack_spec_path` property. Returns: @@ -382,7 +377,7 @@ def stack_spec_path(self) -> Optional[str]: @property def components( self, - ) -> Dict[StackComponentType, List["ComponentResponse"]]: + ) -> dict[StackComponentType, list["ComponentResponse"]]: """The `components` property. 
Returns: @@ -391,7 +386,7 @@ def components( return self.get_metadata().components @property - def environment(self) -> Dict[str, str]: + def environment(self) -> dict[str, str]: """The `environment` property. Returns: @@ -400,7 +395,7 @@ def environment(self) -> Dict[str, str]: return self.get_metadata().environment @property - def secrets(self) -> List[UUID]: + def secrets(self) -> list[UUID]: """The `secrets` property. Returns: @@ -409,7 +404,7 @@ def secrets(self) -> List[UUID]: return self.get_metadata().secrets @property - def labels(self) -> Optional[Dict[str, Any]]: + def labels(self) -> dict[str, Any] | None: """The `labels` property. Returns: @@ -424,31 +419,31 @@ def labels(self) -> Optional[Dict[str, Any]]: class StackFilter(UserScopedFilter): """Model to enable advanced stack filtering.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *UserScopedFilter.FILTER_EXCLUDE_FIELDS, "component_id", "component", ] - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the stack", ) - description: Optional[str] = Field( + description: str | None = Field( default=None, description="Description of the stack" ) - component_id: Optional[Union[UUID, str]] = Field( + component_id: UUID | str | None = Field( default=None, description="Component in the stack", union_mode="left_to_right", ) - component: Optional[Union[UUID, str]] = Field( + component: UUID | str | None = Field( default=None, description="Name/ID of a component in the stack." ) def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List["ColumnElement[bool]"]: + self, table: type["AnySchema"] + ) -> list["ColumnElement[bool]"]: """Get custom filters. 
Args: diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index 8c831fca4a9..76826ea80e3 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -17,12 +17,8 @@ from typing import ( TYPE_CHECKING, ClassVar, - Dict, - List, Optional, - Type, TypeVar, - Union, ) from uuid import UUID @@ -95,32 +91,32 @@ class StepRunRequest(ProjectScopedRequest): start_time: datetime = Field( title="The start time of the step run.", ) - end_time: Optional[datetime] = Field( + end_time: datetime | None = Field( title="The end time of the step run.", default=None, ) status: ExecutionStatus = Field(title="The status of the step.") - cache_key: Optional[str] = Field( + cache_key: str | None = Field( title="The cache key of the step run.", default=None, max_length=STR_FIELD_MAX_LENGTH, ) - cache_expires_at: Optional[datetime] = Field( + cache_expires_at: datetime | None = Field( title="The time at which this step run should not be used for cached " "results anymore. 
If not set, the result will never expire.", default=None, ) - code_hash: Optional[str] = Field( + code_hash: str | None = Field( title="The code hash of the step run.", default=None, max_length=STR_FIELD_MAX_LENGTH, ) - docstring: Optional[str] = Field( + docstring: str | None = Field( title="The docstring of the step function or class.", default=None, max_length=TEXT_FIELD_MAX_LENGTH, ) - source_code: Optional[str] = Field( + source_code: str | None = Field( title="The source code of the step function or class.", default=None, max_length=TEXT_FIELD_MAX_LENGTH, @@ -128,20 +124,20 @@ class StepRunRequest(ProjectScopedRequest): pipeline_run_id: UUID = Field( title="The ID of the pipeline run that this step run belongs to.", ) - original_step_run_id: Optional[UUID] = Field( + original_step_run_id: UUID | None = Field( title="The ID of the original step run if this step was cached.", default=None, ) - parent_step_ids: List[UUID] = Field( + parent_step_ids: list[UUID] = Field( title="The IDs of the parent steps of this step run.", default_factory=list, deprecated=True, ) - inputs: Dict[str, List[UUID]] = Field( + inputs: dict[str, list[UUID]] = Field( title="The IDs of the input artifact versions of the step run.", default_factory=dict, ) - outputs: Dict[str, List[UUID]] = Field( + outputs: dict[str, list[UUID]] = Field( title="The IDs of the output artifact versions of the step run.", default_factory=dict, ) @@ -149,7 +145,7 @@ class StepRunRequest(ProjectScopedRequest): title="Logs associated with this step run.", default=None, ) - exception_info: Optional[ExceptionInfo] = Field( + exception_info: ExceptionInfo | None = Field( default=None, title="The exception information of the step run.", ) @@ -163,27 +159,27 @@ class StepRunRequest(ProjectScopedRequest): class StepRunUpdate(BaseUpdate): """Update model for step runs.""" - outputs: Dict[str, List[UUID]] = Field( + outputs: dict[str, list[UUID]] = Field( title="The IDs of the output artifact versions of the step 
run.", default={}, ) - loaded_artifact_versions: Dict[str, UUID] = Field( + loaded_artifact_versions: dict[str, UUID] = Field( title="The IDs of artifact versions that were loaded by this step run.", default={}, ) - status: Optional[ExecutionStatus] = Field( + status: ExecutionStatus | None = Field( title="The status of the step.", default=None, ) - end_time: Optional[datetime] = Field( + end_time: datetime | None = Field( title="The end time of the step run.", default=None, ) - exception_info: Optional[ExceptionInfo] = Field( + exception_info: ExceptionInfo | None = Field( default=None, title="The exception information of the step run.", ) - cache_expires_at: Optional[datetime] = Field( + cache_expires_at: datetime | None = Field( title="The time at which this step run should not be used for cached " "results anymore.", default=None, @@ -202,20 +198,20 @@ class StepRunResponseBody(ProjectScopedResponseBody): is_retriable: bool = Field( title="Whether the step run is retriable.", ) - start_time: Optional[datetime] = Field( + start_time: datetime | None = Field( title="The start time of the step run.", default=None, ) - end_time: Optional[datetime] = Field( + end_time: datetime | None = Field( title="The end time of the step run.", default=None, ) - model_version_id: Optional[UUID] = Field( + model_version_id: UUID | None = Field( title="The ID of the model version that was " "configured by this step run explicitly.", default=None, ) - substitutions: Dict[str, str] = Field( + substitutions: dict[str, str] = Field( title="The substitutions of the step run.", default={}, ) @@ -225,7 +221,7 @@ class StepRunResponseBody(ProjectScopedResponseBody): class StepRunResponseMetadata(ProjectScopedResponseMetadata): """Response metadata for step runs.""" - __zenml_skip_dehydration__: ClassVar[List[str]] = [ + __zenml_skip_dehydration__: ClassVar[list[str]] = [ "config", "spec", "metadata", @@ -236,32 +232,32 @@ class StepRunResponseMetadata(ProjectScopedResponseMetadata): spec: 
"StepSpec" = Field(title="The spec of the step.") # Code related fields - cache_key: Optional[str] = Field( + cache_key: str | None = Field( title="The cache key of the step run.", default=None, max_length=STR_FIELD_MAX_LENGTH, ) - cache_expires_at: Optional[datetime] = Field( + cache_expires_at: datetime | None = Field( title="The time at which this step run should not be used for cached " "results anymore. If not set, the result will never expire.", default=None, ) - code_hash: Optional[str] = Field( + code_hash: str | None = Field( title="The code hash of the step run.", default=None, max_length=STR_FIELD_MAX_LENGTH, ) - docstring: Optional[str] = Field( + docstring: str | None = Field( title="The docstring of the step function or class.", default=None, max_length=TEXT_FIELD_MAX_LENGTH, ) - source_code: Optional[str] = Field( + source_code: str | None = Field( title="The source code of the step function or class.", default=None, max_length=TEXT_FIELD_MAX_LENGTH, ) - exception_info: Optional[ExceptionInfo] = Field( + exception_info: ExceptionInfo | None = Field( default=None, title="The exception information of the step run.", ) @@ -277,15 +273,15 @@ class StepRunResponseMetadata(ProjectScopedResponseMetadata): pipeline_run_id: UUID = Field( title="The ID of the pipeline run that this step run belongs to.", ) - original_step_run_id: Optional[UUID] = Field( + original_step_run_id: UUID | None = Field( title="The ID of the original step run if this step was cached.", default=None, ) - parent_step_ids: List[UUID] = Field( + parent_step_ids: list[UUID] = Field( title="The IDs of the parent steps of this step run.", default_factory=list, ) - run_metadata: Dict[str, MetadataType] = Field( + run_metadata: dict[str, MetadataType] = Field( title="Metadata associated with this step run.", default={}, ) @@ -294,12 +290,12 @@ class StepRunResponseMetadata(ProjectScopedResponseMetadata): class StepRunResponseResources(ProjectScopedResponseResources): """Class for all resource 
models associated with the step run entity.""" - model_version: Optional[ModelVersionResponse] = None - inputs: Dict[str, List[StepRunInputResponse]] = Field( + model_version: ModelVersionResponse | None = None + inputs: dict[str, list[StepRunInputResponse]] = Field( title="The input artifact versions of the step run.", default_factory=dict, ) - outputs: Dict[str, List[ArtifactVersionResponse]] = Field( + outputs: dict[str, list[ArtifactVersionResponse]] = Field( title="The output artifact versions of the step run.", default_factory=dict, ) @@ -380,7 +376,7 @@ def output(self) -> ArtifactVersionResponse: return next(iter(self.outputs.values()))[0] @property - def regular_inputs(self) -> Dict[str, StepRunInputResponse]: + def regular_inputs(self) -> dict[str, StepRunInputResponse]: """Returns the regular step inputs of the step run. Regular step inputs are the inputs that are defined in the step function @@ -412,7 +408,7 @@ def regular_inputs(self) -> Dict[str, StepRunInputResponse]: return result @property - def regular_outputs(self) -> Dict[str, ArtifactVersionResponse]: + def regular_outputs(self) -> dict[str, ArtifactVersionResponse]: """Returns the regular step outputs of the step run. Regular step outputs are the outputs that are defined in the step @@ -473,7 +469,7 @@ def is_retriable(self) -> bool: return self.get_body().is_retriable @property - def inputs(self) -> Dict[str, List[StepRunInputResponse]]: + def inputs(self) -> dict[str, list[StepRunInputResponse]]: """The `inputs` property. Returns: @@ -482,7 +478,7 @@ def inputs(self) -> Dict[str, List[StepRunInputResponse]]: return self.get_resources().inputs @property - def outputs(self) -> Dict[str, List[ArtifactVersionResponse]]: + def outputs(self) -> dict[str, list[ArtifactVersionResponse]]: """The `outputs` property. 
Returns: @@ -491,7 +487,7 @@ def outputs(self) -> Dict[str, List[ArtifactVersionResponse]]: return self.get_resources().outputs @property - def model_version_id(self) -> Optional[UUID]: + def model_version_id(self) -> UUID | None: """The `model_version_id` property. Returns: @@ -500,7 +496,7 @@ def model_version_id(self) -> Optional[UUID]: return self.get_body().model_version_id @property - def substitutions(self) -> Dict[str, str]: + def substitutions(self) -> dict[str, str]: """The `substitutions` property. Returns: @@ -527,7 +523,7 @@ def spec(self) -> "StepSpec": return self.get_metadata().spec @property - def cache_key(self) -> Optional[str]: + def cache_key(self) -> str | None: """The `cache_key` property. Returns: @@ -536,7 +532,7 @@ def cache_key(self) -> Optional[str]: return self.get_metadata().cache_key @property - def cache_expires_at(self) -> Optional[datetime]: + def cache_expires_at(self) -> datetime | None: """The `cache_expires_at` property. Returns: @@ -545,7 +541,7 @@ def cache_expires_at(self) -> Optional[datetime]: return self.get_metadata().cache_expires_at @property - def code_hash(self) -> Optional[str]: + def code_hash(self) -> str | None: """The `code_hash` property. Returns: @@ -554,7 +550,7 @@ def code_hash(self) -> Optional[str]: return self.get_metadata().code_hash @property - def docstring(self) -> Optional[str]: + def docstring(self) -> str | None: """The `docstring` property. Returns: @@ -563,7 +559,7 @@ def docstring(self) -> Optional[str]: return self.get_metadata().docstring @property - def source_code(self) -> Optional[str]: + def source_code(self) -> str | None: """The `source_code` property. Returns: @@ -572,7 +568,7 @@ def source_code(self) -> Optional[str]: return self.get_metadata().source_code @property - def start_time(self) -> Optional[datetime]: + def start_time(self) -> datetime | None: """The `start_time` property. 
Returns: @@ -581,7 +577,7 @@ def start_time(self) -> Optional[datetime]: return self.get_body().start_time @property - def end_time(self) -> Optional[datetime]: + def end_time(self) -> datetime | None: """The `end_time` property. Returns: @@ -617,7 +613,7 @@ def pipeline_run_id(self) -> UUID: return self.get_metadata().pipeline_run_id @property - def original_step_run_id(self) -> Optional[UUID]: + def original_step_run_id(self) -> UUID | None: """The `original_step_run_id` property. Returns: @@ -626,7 +622,7 @@ def original_step_run_id(self) -> Optional[UUID]: return self.get_metadata().original_step_run_id @property - def parent_step_ids(self) -> List[UUID]: + def parent_step_ids(self) -> list[UUID]: """The `parent_step_ids` property. Returns: @@ -635,7 +631,7 @@ def parent_step_ids(self) -> List[UUID]: return self.get_metadata().parent_step_ids @property - def run_metadata(self) -> Dict[str, MetadataType]: + def run_metadata(self) -> dict[str, MetadataType]: """The `run_metadata` property. Returns: @@ -644,7 +640,7 @@ def run_metadata(self) -> Dict[str, MetadataType]: return self.get_metadata().run_metadata @property - def model_version(self) -> Optional[ModelVersionResponse]: + def model_version(self) -> ModelVersionResponse | None: """The `model_version` property. 
Returns: @@ -659,86 +655,86 @@ def model_version(self) -> Optional[ModelVersionResponse]: class StepRunFilter(ProjectScopedFilter, RunMetadataFilterMixin): """Model to enable advanced filtering of step runs.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS, *RunMetadataFilterMixin.FILTER_EXCLUDE_FIELDS, "model", "exclude_retried", "cache_expired", ] - CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + CLI_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.CLI_EXCLUDE_FIELDS, *RunMetadataFilterMixin.CLI_EXCLUDE_FIELDS, ] - CUSTOM_SORTING_OPTIONS: ClassVar[List[str]] = [ + CUSTOM_SORTING_OPTIONS: ClassVar[list[str]] = [ *ProjectScopedFilter.CUSTOM_SORTING_OPTIONS, *RunMetadataFilterMixin.CUSTOM_SORTING_OPTIONS, ] - API_MULTI_INPUT_PARAMS: ClassVar[List[str]] = [ + API_MULTI_INPUT_PARAMS: ClassVar[list[str]] = [ *ProjectScopedFilter.API_MULTI_INPUT_PARAMS, *RunMetadataFilterMixin.API_MULTI_INPUT_PARAMS, ] - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the step run", ) - code_hash: Optional[str] = Field( + code_hash: str | None = Field( default=None, description="Code hash for this step run", ) - cache_key: Optional[str] = Field( + cache_key: str | None = Field( default=None, description="Cache key for this step run", ) - status: Optional[str] = Field( + status: str | None = Field( default=None, description="Status of the Step Run", ) - start_time: Optional[Union[datetime, str]] = Field( + start_time: datetime | str | None = Field( default=None, description="Start time for this run", union_mode="left_to_right", ) - end_time: Optional[Union[datetime, str]] = Field( + end_time: datetime | str | None = Field( default=None, description="End time for this run", union_mode="left_to_right", ) - pipeline_run_id: Optional[Union[UUID, str]] = Field( + pipeline_run_id: UUID | str | None = Field( default=None, description="Pipeline run 
of this step run", union_mode="left_to_right", ) - snapshot_id: Optional[Union[UUID, str]] = Field( + snapshot_id: UUID | str | None = Field( default=None, description="Snapshot of this step run", union_mode="left_to_right", ) - original_step_run_id: Optional[Union[UUID, str]] = Field( + original_step_run_id: UUID | str | None = Field( default=None, description="Original id for this step run", union_mode="left_to_right", ) - model_version_id: Optional[Union[UUID, str]] = Field( + model_version_id: UUID | str | None = Field( default=None, description="Model version associated with the step run.", union_mode="left_to_right", ) - model: Optional[Union[UUID, str]] = Field( + model: UUID | str | None = Field( default=None, description="Name/ID of the model associated with the step run.", ) - exclude_retried: Optional[bool] = Field( + exclude_retried: bool | None = Field( default=None, description="Whether to exclude retried step runs.", ) - cache_expires_at: Optional[Union[datetime, str]] = Field( + cache_expires_at: datetime | str | None = Field( default=None, description="Cache expiration time of the step run.", union_mode="left_to_right", ) - cache_expired: Optional[bool] = Field( + cache_expired: bool | None = Field( default=None, description="Whether the cache expiration time of the step run has " "passed.", @@ -746,8 +742,8 @@ class StepRunFilter(ProjectScopedFilter, RunMetadataFilterMixin): model_config = ConfigDict(protected_namespaces=()) def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List["ColumnElement[bool]"]: + self, table: type["AnySchema"] + ) -> list["ColumnElement[bool]"]: """Get custom filters. 
Args: diff --git a/src/zenml/models/v2/core/tag.py b/src/zenml/models/v2/core/tag.py index 08134e31c04..dd32a93a5bd 100644 --- a/src/zenml/models/v2/core/tag.py +++ b/src/zenml/models/v2/core/tag.py @@ -14,7 +14,7 @@ """Models representing tags.""" import random -from typing import TYPE_CHECKING, ClassVar, List, Optional, Type, TypeVar +from typing import TYPE_CHECKING, ClassVar, TypeVar from pydantic import Field, field_validator @@ -86,13 +86,13 @@ def validate_name_not_uuid(cls, value: str) -> str: class TagUpdate(BaseUpdate): """Update model for tags.""" - name: Optional[str] = None - exclusive: Optional[bool] = None - color: Optional[ColorVariants] = None + name: str | None = None + exclusive: bool | None = None + color: ColorVariants | None = None @field_validator("name") @classmethod - def validate_name_not_uuid(cls, value: Optional[str]) -> Optional[str]: + def validate_name_not_uuid(cls, value: str | None) -> str | None: """Validates that the tag name is not a UUID. Args: @@ -195,29 +195,29 @@ def tagged_count(self) -> int: class TagFilter(UserScopedFilter): """Model to enable advanced filtering of all tags.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *UserScopedFilter.FILTER_EXCLUDE_FIELDS, "resource_type", ] - name: Optional[str] = Field( + name: str | None = Field( description="The unique title of the tag.", default=None ) - color: Optional[ColorVariants] = Field( + color: ColorVariants | None = Field( description="The color variant assigned to the tag.", default=None ) - exclusive: Optional[bool] = Field( + exclusive: bool | None = Field( description="The flag signifying whether the tag is an exclusive tag.", default=None, ) - resource_type: Optional[TaggableResourceTypes] = Field( + resource_type: TaggableResourceTypes | None = Field( description="Filter tags associated with a specific resource type.", default=None, ) def get_custom_filters( - self, table: Type["AnySchema"] - ) -> 
List["ColumnElement[bool]"]: + self, table: type["AnySchema"] + ) -> list["ColumnElement[bool]"]: """Get custom filters. Args: diff --git a/src/zenml/models/v2/core/trigger.py b/src/zenml/models/v2/core/trigger.py index c87c28c7772..0f508652b26 100644 --- a/src/zenml/models/v2/core/trigger.py +++ b/src/zenml/models/v2/core/trigger.py @@ -17,12 +17,8 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - List, Optional, - Type, TypeVar, - Union, ) from uuid import UUID @@ -71,17 +67,17 @@ class TriggerRequest(ProjectScopedRequest): action_id: UUID = Field( title="The action that is executed by this trigger.", ) - schedule: Optional[Schedule] = Field( + schedule: Schedule | None = Field( default=None, title="The schedule for the trigger. Either a schedule or an event " "source is required.", ) - event_source_id: Optional[UUID] = Field( + event_source_id: UUID | None = Field( default=None, title="The event source that activates this trigger. Either a schedule " "or an event source is required.", ) - event_filter: Optional[Dict[str, Any]] = Field( + event_filter: dict[str, Any] | None = Field( default=None, title="Filter applied to events that activate this trigger. Only " "set if the trigger is activated by an event source.", @@ -115,28 +111,28 @@ def _validate_schedule_or_event_source(self) -> "TriggerRequest": class TriggerUpdate(BaseUpdate): """Update model for triggers.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, title="The new name for the trigger.", max_length=STR_FIELD_MAX_LENGTH, ) - description: Optional[str] = Field( + description: str | None = Field( default=None, title="The new description for the trigger.", max_length=STR_FIELD_MAX_LENGTH, ) - event_filter: Optional[Dict[str, Any]] = Field( + event_filter: dict[str, Any] | None = Field( default=None, title="New filter applied to events that activate this trigger. 
Only " "valid if the trigger is already configured to be activated by an " "event source.", ) - schedule: Optional[Schedule] = Field( + schedule: Schedule | None = Field( default=None, title="The updated schedule for the trigger. Only valid if the trigger " "is already configured to be activated by a schedule.", ) - is_active: Optional[bool] = Field( + is_active: bool | None = Field( default=None, title="The new status of the trigger.", ) @@ -155,13 +151,13 @@ class TriggerResponseBody(ProjectScopedResponseBody): action_subtype: str = Field( title="The subtype of the action that is executed by this trigger.", ) - event_source_flavor: Optional[str] = Field( + event_source_flavor: str | None = Field( default=None, title="The flavor of the event source that activates this trigger. Not " "set if the trigger is activated by a schedule.", max_length=STR_FIELD_MAX_LENGTH, ) - event_source_subtype: Optional[str] = Field( + event_source_subtype: str | None = Field( default=None, title="The subtype of the event source that activates this trigger. " "Not set if the trigger is activated by a schedule.", @@ -180,12 +176,12 @@ class TriggerResponseMetadata(ProjectScopedResponseMetadata): title="The description of the trigger.", max_length=STR_FIELD_MAX_LENGTH, ) - event_filter: Optional[Dict[str, Any]] = Field( + event_filter: dict[str, Any] | None = Field( default=None, title="The event that activates this trigger. Not set if the trigger " "is activated by a schedule.", ) - schedule: Optional[Schedule] = Field( + schedule: Schedule | None = Field( default=None, title="The schedule that activates this trigger. Not set if the " "trigger is activated by an event source.", @@ -249,7 +245,7 @@ def action_subtype(self) -> str: return self.get_body().action_subtype @property - def event_source_flavor(self) -> Optional[str]: + def event_source_flavor(self) -> str | None: """The `event_source_flavor` property. 
Returns: @@ -258,7 +254,7 @@ def event_source_flavor(self) -> Optional[str]: return self.get_body().event_source_flavor @property - def event_source_subtype(self) -> Optional[str]: + def event_source_subtype(self) -> str | None: """The `event_source_subtype` property. Returns: @@ -276,7 +272,7 @@ def is_active(self) -> bool: return self.get_body().is_active @property - def event_filter(self) -> Optional[Dict[str, Any]]: + def event_filter(self) -> dict[str, Any] | None: """The `event_filter` property. Returns: @@ -327,7 +323,7 @@ def executions(self) -> Page[TriggerExecutionResponse]: class TriggerFilter(ProjectScopedFilter): """Model to enable advanced filtering of all triggers.""" - FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + FILTER_EXCLUDE_FIELDS: ClassVar[list[str]] = [ *ProjectScopedFilter.FILTER_EXCLUDE_FIELDS, "action_flavor", "action_subtype", @@ -335,44 +331,44 @@ class TriggerFilter(ProjectScopedFilter): "event_source_subtype", ] - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the trigger.", ) - event_source_id: Optional[Union[UUID, str]] = Field( + event_source_id: UUID | str | None = Field( default=None, description="The event source this trigger is attached to.", union_mode="left_to_right", ) - action_id: Optional[Union[UUID, str]] = Field( + action_id: UUID | str | None = Field( default=None, description="The action this trigger is attached to.", union_mode="left_to_right", ) - is_active: Optional[bool] = Field( + is_active: bool | None = Field( default=None, description="Whether the trigger is active.", ) - action_flavor: Optional[str] = Field( + action_flavor: str | None = Field( default=None, title="The flavor of the action that is executed by this trigger.", ) - action_subtype: Optional[str] = Field( + action_subtype: str | None = Field( default=None, title="The subtype of the action that is executed by this trigger.", ) - event_source_flavor: Optional[str] = Field( + event_source_flavor: str | 
None = Field( default=None, title="The flavor of the event source that activates this trigger.", ) - event_source_subtype: Optional[str] = Field( + event_source_subtype: str | None = Field( default=None, title="The subtype of the event source that activates this trigger.", ) def get_custom_filters( - self, table: Type["AnySchema"] - ) -> List["ColumnElement[bool]"]: + self, table: type["AnySchema"] + ) -> list["ColumnElement[bool]"]: """Get custom filters. Args: diff --git a/src/zenml/models/v2/core/trigger_execution.py b/src/zenml/models/v2/core/trigger_execution.py index 98aa2d0e669..9993cf416e5 100644 --- a/src/zenml/models/v2/core/trigger_execution.py +++ b/src/zenml/models/v2/core/trigger_execution.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Collection of all models concerning trigger executions.""" -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any from uuid import UUID from pydantic import Field @@ -39,7 +39,7 @@ class TriggerExecutionRequest(BaseRequest): """Model for creating a new Trigger execution.""" trigger: UUID - event_metadata: Dict[str, Any] = {} + event_metadata: dict[str, Any] = {} # ------------------ Update Model ------------------ @@ -55,7 +55,7 @@ class TriggerExecutionResponseBody(BaseDatedResponseBody): class TriggerExecutionResponseMetadata(BaseResponseMetadata): """Response metadata for trigger executions.""" - event_metadata: Dict[str, Any] = {} + event_metadata: dict[str, Any] = {} class TriggerExecutionResponseResources(BaseResponseResources): @@ -97,7 +97,7 @@ def trigger(self) -> "TriggerResponse": return self.get_resources().trigger @property - def event_metadata(self) -> Dict[str, Any]: + def event_metadata(self) -> dict[str, Any]: """The `event_metadata` property. 
Returns: @@ -112,7 +112,7 @@ def event_metadata(self) -> Dict[str, Any]: class TriggerExecutionFilter(ProjectScopedFilter): """Model to enable advanced filtering of all trigger executions.""" - trigger_id: Optional[Union[UUID, str]] = Field( + trigger_id: UUID | str | None = Field( default=None, description="ID of the trigger of the execution.", union_mode="left_to_right", diff --git a/src/zenml/models/v2/core/user.py b/src/zenml/models/v2/core/user.py index 361ee6d4104..768657621a9 100644 --- a/src/zenml/models/v2/core/user.py +++ b/src/zenml/models/v2/core/user.py @@ -19,11 +19,6 @@ AbstractSet, Any, ClassVar, - Dict, - List, - Optional, - Type, - Union, ) from uuid import UUID @@ -53,35 +48,35 @@ class UserBase(BaseModel): # Fields - email: Optional[str] = Field( + email: str | None = Field( default=None, title="The email address associated with the account.", max_length=STR_FIELD_MAX_LENGTH, ) - email_opted_in: Optional[bool] = Field( + email_opted_in: bool | None = Field( default=None, title="Whether the user agreed to share their email. 
Only relevant for " "user accounts", description="`null` if not answered, `true` if agreed, " "`false` if skipped.", ) - password: Optional[str] = Field( + password: str | None = Field( default=None, title="A password for the user.", max_length=STR_FIELD_MAX_LENGTH, ) - activation_token: Optional[str] = Field( + activation_token: str | None = Field( default=None, max_length=STR_FIELD_MAX_LENGTH ) - external_user_id: Optional[UUID] = Field( + external_user_id: UUID | None = Field( default=None, title="The external user ID associated with the account.", ) - user_metadata: Optional[Dict[str, Any]] = Field( + user_metadata: dict[str, Any] | None = Field( default=None, title="The metadata associated with the user.", ) - avatar_url: Optional[str] = Field( + avatar_url: str | None = Field( default=None, title="The avatar URL for the account.", ) @@ -98,7 +93,7 @@ def _get_crypt_context(cls) -> "CryptContext": return CryptContext(schemes=["bcrypt"], deprecated="auto") @classmethod - def _create_hashed_secret(cls, secret: Optional[str]) -> Optional[str]: + def _create_hashed_secret(cls, secret: str | None) -> str | None: """Hashes the input secret and returns the hash value. Only applied if supplied and if not already hashed. @@ -114,7 +109,7 @@ def _create_hashed_secret(cls, secret: Optional[str]) -> Optional[str]: pwd_context = cls._get_crypt_context() return pwd_context.hash(secret) - def create_hashed_password(self) -> Optional[str]: + def create_hashed_password(self) -> str | None: """Hashes the password. Returns: @@ -122,7 +117,7 @@ def create_hashed_password(self) -> Optional[str]: """ return self._create_hashed_secret(self.password) - def create_hashed_activation_token(self) -> Optional[str]: + def create_hashed_activation_token(self) -> str | None: """Hashes the activation token. 
Returns: @@ -147,7 +142,7 @@ class UserRequest(UserBase, BaseRequest): """Request model for users.""" # Analytics fields for user request models - ANALYTICS_FIELDS: ClassVar[List[str]] = [ + ANALYTICS_FIELDS: ClassVar[list[str]] = [ "name", "full_name", "active", @@ -182,30 +177,30 @@ class UserRequest(UserBase, BaseRequest): class UserUpdate(UserBase, BaseUpdate): """Update model for users.""" - name: Optional[str] = Field( + name: str | None = Field( title="The unique username for the account.", max_length=STR_FIELD_MAX_LENGTH, default=None, ) - full_name: Optional[str] = Field( + full_name: str | None = Field( default=None, title="The display name for the account.", max_length=STR_FIELD_MAX_LENGTH, ) - is_admin: Optional[bool] = Field( + is_admin: bool | None = Field( default=None, title="Whether the account is an administrator.", ) - active: Optional[bool] = Field( + active: bool | None = Field( default=None, title="Whether the account is active." ) - old_password: Optional[str] = Field( + old_password: str | None = Field( default=None, title="The previous password for the user. Only relevant for user " "accounts. Required when updating the password.", max_length=STR_FIELD_MAX_LENGTH, ) - default_project_id: Optional[UUID] = Field( + default_project_id: UUID | None = Field( default=None, title="The default project ID for the user.", ) @@ -260,7 +255,7 @@ class UserResponseBody(BaseDatedResponseBody): """Response body for users.""" active: bool = Field(default=False, title="Whether the account is active.") - activation_token: Optional[str] = Field( + activation_token: str | None = Field( default=None, max_length=STR_FIELD_MAX_LENGTH, title="The activation token for the user. 
Only relevant for user " @@ -271,7 +266,7 @@ class UserResponseBody(BaseDatedResponseBody): title="The display name for the account.", max_length=STR_FIELD_MAX_LENGTH, ) - email_opted_in: Optional[bool] = Field( + email_opted_in: bool | None = Field( default=None, title="Whether the user agreed to share their email. Only relevant for " "user accounts", @@ -284,11 +279,11 @@ class UserResponseBody(BaseDatedResponseBody): is_admin: bool = Field( title="Whether the account is an administrator.", ) - default_project_id: Optional[UUID] = Field( + default_project_id: UUID | None = Field( default=None, title="The default project ID for the user.", ) - avatar_url: Optional[str] = Field( + avatar_url: str | None = Field( default=None, title="The avatar URL for the account.", ) @@ -297,18 +292,18 @@ class UserResponseBody(BaseDatedResponseBody): class UserResponseMetadata(BaseResponseMetadata): """Response metadata for users.""" - email: Optional[str] = Field( + email: str | None = Field( default="", title="The email address associated with the account. Only relevant " "for user accounts.", max_length=STR_FIELD_MAX_LENGTH, ) - external_user_id: Optional[UUID] = Field( + external_user_id: UUID | None = Field( default=None, title="The external user ID associated with the account. Only relevant " "for user accounts.", ) - user_metadata: Dict[str, Any] = Field( + user_metadata: dict[str, Any] = Field( default={}, title="The metadata associated with the user.", ) @@ -330,7 +325,7 @@ class UserResponse( well for use by the analytics on the client-side. """ - ANALYTICS_FIELDS: ClassVar[List[str]] = [ + ANALYTICS_FIELDS: ClassVar[list[str]] = [ "name", "full_name", "active", @@ -364,7 +359,7 @@ def active(self) -> bool: return self.get_body().active @property - def activation_token(self) -> Optional[str]: + def activation_token(self) -> str | None: """The `activation_token` property. 
Returns: @@ -382,7 +377,7 @@ def full_name(self) -> str: return self.get_body().full_name @property - def email_opted_in(self) -> Optional[bool]: + def email_opted_in(self) -> bool | None: """The `email_opted_in` property. Returns: @@ -409,7 +404,7 @@ def is_admin(self) -> bool: return self.get_body().is_admin @property - def email(self) -> Optional[str]: + def email(self) -> str | None: """The `email` property. Returns: @@ -418,7 +413,7 @@ def email(self) -> Optional[str]: return self.get_metadata().email @property - def external_user_id(self) -> Optional[UUID]: + def external_user_id(self) -> UUID | None: """The `external_user_id` property. Returns: @@ -427,7 +422,7 @@ def external_user_id(self) -> Optional[UUID]: return self.get_metadata().external_user_id @property - def user_metadata(self) -> Dict[str, Any]: + def user_metadata(self) -> dict[str, Any]: """The `user_metadata` property. Returns: @@ -436,7 +431,7 @@ def user_metadata(self) -> Dict[str, Any]: return self.get_metadata().user_metadata @property - def default_project_id(self) -> Optional[UUID]: + def default_project_id(self) -> UUID | None: """The `default_project_id` property. Returns: @@ -445,7 +440,7 @@ def default_project_id(self) -> Optional[UUID]: return self.get_body().default_project_id @property - def avatar_url(self) -> Optional[str]: + def avatar_url(self) -> str | None: """The `avatar_url` property. 
Returns: @@ -472,29 +467,29 @@ def _get_crypt_context(cls) -> "CryptContext": class UserFilter(BaseFilter): """Model to enable advanced filtering of all Users.""" - name: Optional[str] = Field( + name: str | None = Field( default=None, description="Name of the user", ) - full_name: Optional[str] = Field( + full_name: str | None = Field( default=None, description="Full Name of the user", ) - email: Optional[str] = Field( + email: str | None = Field( default=None, description="Email of the user", ) - active: Optional[Union[bool, str]] = Field( + active: bool | str | None = Field( default=None, description="Whether the user is active", union_mode="left_to_right", ) - email_opted_in: Optional[Union[bool, str]] = Field( + email_opted_in: bool | str | None = Field( default=None, description="Whether the user has opted in to emails", union_mode="left_to_right", ) - external_user_id: Optional[Union[UUID, str]] = Field( + external_user_id: UUID | str | None = Field( default=None, title="The external user ID associated with the account.", union_mode="left_to_right", @@ -503,7 +498,7 @@ class UserFilter(BaseFilter): def apply_filter( self, query: AnyQuery, - table: Type["AnySchema"], + table: type["AnySchema"], ) -> AnyQuery: """Override to filter out service accounts from the query. diff --git a/src/zenml/models/v2/misc/auth_models.py b/src/zenml/models/v2/misc/auth_models.py index ab6f3a78762..9f84de2e698 100644 --- a/src/zenml/models/v2/misc/auth_models.py +++ b/src/zenml/models/v2/misc/auth_models.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Models representing OAuth2 requests and responses.""" -from typing import Any, Dict, Optional +from typing import Any from uuid import UUID from pydantic import BaseModel, ConfigDict @@ -29,7 +29,7 @@ class OAuthDeviceAuthorizationRequest(BaseModel): """OAuth2 device authorization grant request.""" client_id: UUID - device_id: Optional[UUID] = None + device_id: UUID | None = None class OAuthDeviceVerificationRequest(BaseModel): @@ -50,10 +50,10 @@ class OAuthDeviceTokenRequest(BaseModel): class OAuthDeviceUserAgentHeader(BaseModel): """OAuth2 device user agent header.""" - hostname: Optional[str] = None - os: Optional[str] = None - python_version: Optional[str] = None - zenml_version: Optional[str] = None + hostname: str | None = None + os: str | None = None + python_version: str | None = None + zenml_version: str | None = None @classmethod def decode(cls, header_str: str) -> "OAuthDeviceUserAgentHeader": @@ -107,7 +107,7 @@ class OAuthDeviceAuthorizationResponse(BaseModel): device_code: str user_code: str verification_uri: str - verification_uri_complete: Optional[str] = None + verification_uri_complete: str | None = None expires_in: int interval: int @@ -117,12 +117,12 @@ class OAuthTokenResponse(BaseModel): access_token: str token_type: str - expires_in: Optional[int] = None - refresh_token: Optional[str] = None - csrf_token: Optional[str] = None - scope: Optional[str] = None - device_id: Optional[UUID] = None - device_metadata: Optional[Dict[str, Any]] = None + expires_in: int | None = None + refresh_token: str | None = None + csrf_token: str | None = None + scope: str | None = None + device_id: UUID | None = None + device_metadata: dict[str, Any] | None = None model_config = ConfigDict( # Allow extra attributes to allow compatibility with different versions diff --git a/src/zenml/models/v2/misc/build_item.py b/src/zenml/models/v2/misc/build_item.py index 13d35ddefd9..2227aa30613 100644 --- a/src/zenml/models/v2/misc/build_item.py +++ 
b/src/zenml/models/v2/misc/build_item.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Model definition for pipeline build item.""" -from typing import Optional from pydantic import BaseModel, Field @@ -32,13 +31,13 @@ class BuildItem(BaseModel): """ image: str = Field(title="The image name or digest.") - dockerfile: Optional[str] = Field( + dockerfile: str | None = Field( default=None, title="The dockerfile used to build the image." ) - requirements: Optional[str] = Field( + requirements: str | None = Field( default=None, title="The pip requirements installed in the image." ) - settings_checksum: Optional[str] = Field( + settings_checksum: str | None = Field( default=None, title="The checksum of the build settings." ) contains_code: bool = Field( diff --git a/src/zenml/models/v2/misc/exception_info.py b/src/zenml/models/v2/misc/exception_info.py index 878880246b2..c056a7266d0 100644 --- a/src/zenml/models/v2/misc/exception_info.py +++ b/src/zenml/models/v2/misc/exception_info.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Exception information models.""" -from typing import Optional from pydantic import BaseModel, Field @@ -24,7 +23,7 @@ class ExceptionInfo(BaseModel): traceback: str = Field( title="The traceback of the exception.", ) - step_code_line: Optional[int] = Field( + step_code_line: int | None = Field( default=None, title="The line number of the step code that raised the exception.", ) diff --git a/src/zenml/models/v2/misc/external_user.py b/src/zenml/models/v2/misc/external_user.py index f6a7110d8ce..6859ccdea25 100644 --- a/src/zenml/models/v2/misc/external_user.py +++ b/src/zenml/models/v2/misc/external_user.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Models representing users.""" -from typing import Optional from uuid import UUID from pydantic import BaseModel, ConfigDict @@ -24,10 +23,10 @@ class ExternalUserModel(BaseModel): id: UUID username: str - email: Optional[str] = None - name: Optional[str] = None + email: str | None = None + name: str | None = None is_admin: bool = False is_service_account: bool = False - avatar_url: Optional[str] = None + avatar_url: str | None = None model_config = ConfigDict(extra="ignore") diff --git a/src/zenml/models/v2/misc/info_models.py b/src/zenml/models/v2/misc/info_models.py index 5fb60228db0..048c82f59d8 100644 --- a/src/zenml/models/v2/misc/info_models.py +++ b/src/zenml/models/v2/misc/info_models.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing full stack requests.""" -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, Field, model_validator @@ -28,22 +28,22 @@ class ServiceConnectorInfo(BaseModel): type: str auth_method: str - configuration: Dict[str, Any] = {} + configuration: dict[str, Any] = {} class ComponentInfo(BaseModel): """Information about each stack components when creating a full stack.""" flavor: str - service_connector_index: Optional[int] = Field( + service_connector_index: int | None = Field( default=None, title="The id of the service connector from the list " "`service_connectors`.", description="The id of the service connector from the list " "`service_connectors` from `FullStackRequest`.", ) - service_connector_resource_id: Optional[str] = None - configuration: Dict[str, Any] = {} + service_connector_resource_id: str | None = None + configuration: dict[str, Any] = {} class ResourcesInfo(BaseModel): @@ -51,11 +51,11 @@ class ResourcesInfo(BaseModel): flavor: str flavor_display_name: str - required_configuration: Dict[str, str] = {} + required_configuration: dict[str, str] = {} use_resource_value_as_fixed_config: bool = False 
- accessible_by_service_connector: List[str] - connected_through_service_connector: List["ComponentResponse"] + accessible_by_service_connector: list[str] + connected_through_service_connector: list["ComponentResponse"] @model_validator(mode="after") def _validate_resource_info(self) -> "ResourcesInfo": @@ -75,4 +75,4 @@ class ServiceConnectorResourcesInfo(BaseModel): connector_type: str - components_resources_info: Dict[StackComponentType, List[ResourcesInfo]] + components_resources_info: dict[StackComponentType, list[ResourcesInfo]] diff --git a/src/zenml/models/v2/misc/loaded_visualization.py b/src/zenml/models/v2/misc/loaded_visualization.py index 419dfa76f4e..5d7c077305c 100644 --- a/src/zenml/models/v2/misc/loaded_visualization.py +++ b/src/zenml/models/v2/misc/loaded_visualization.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Model representing loaded visualizations.""" -from typing import Union from pydantic import BaseModel, Field @@ -24,4 +23,4 @@ class LoadedVisualization(BaseModel): """Model for loaded visualizations.""" type: VisualizationType - value: Union[str, bytes] = Field(union_mode="left_to_right") + value: str | bytes = Field(union_mode="left_to_right") diff --git a/src/zenml/models/v2/misc/param_groups.py b/src/zenml/models/v2/misc/param_groups.py index 9df32db6c07..2ab331f6cd3 100644 --- a/src/zenml/models/v2/misc/param_groups.py +++ b/src/zenml/models/v2/misc/param_groups.py @@ -48,13 +48,11 @@ def _validate_options(self) -> "VersionedIdentifier": class ArtifactVersionIdentifier(VersionedIdentifier): """Class for artifact version identifier group.""" - pass class ModelVersionIdentifier(VersionedIdentifier): """Class for model version identifier group.""" - pass class PipelineRunIdentifier(BaseModel): diff --git a/src/zenml/models/v2/misc/pipeline_run_dag.py b/src/zenml/models/v2/misc/pipeline_run_dag.py index 4dcda9a347b..4e7fdfb41ed 100644 --- a/src/zenml/models/v2/misc/pipeline_run_dag.py +++ 
b/src/zenml/models/v2/misc/pipeline_run_dag.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Pipeline run DAG models.""" -from typing import Any, Dict, List, Optional +from typing import Any from uuid import UUID from pydantic import BaseModel @@ -26,21 +26,21 @@ class PipelineRunDAG(BaseModel): id: UUID status: ExecutionStatus - nodes: List["Node"] - edges: List["Edge"] + nodes: list["Node"] + edges: list["Edge"] class Node(BaseModel): """Node in the pipeline run DAG.""" node_id: str type: str - id: Optional[UUID] = None + id: UUID | None = None name: str - metadata: Dict[str, Any] = {} + metadata: dict[str, Any] = {} class Edge(BaseModel): """Edge in the pipeline run DAG.""" source: str target: str - metadata: Dict[str, Any] = {} + metadata: dict[str, Any] = {} diff --git a/src/zenml/models/v2/misc/server_models.py b/src/zenml/models/v2/misc/server_models.py index bdc62b3417b..02a85e1b349 100644 --- a/src/zenml/models/v2/misc/server_models.py +++ b/src/zenml/models/v2/misc/server_models.py @@ -14,7 +14,6 @@ """Model definitions for ZenML servers.""" from datetime import datetime -from typing import Dict, Optional from uuid import UUID, uuid4 from pydantic import BaseModel, Field @@ -52,7 +51,7 @@ class ServerModel(BaseModel): id: UUID = Field(default_factory=uuid4, title="The unique server id.") - name: Optional[str] = Field(None, title="The name of the ZenML server.") + name: str | None = Field(None, title="The name of the ZenML server.") version: str = Field( title="The ZenML version that the server is running.", @@ -97,47 +96,47 @@ class ServerModel(BaseModel): title="Enable server-side analytics.", ) - metadata: Dict[str, str] = Field( + metadata: dict[str, str] = Field( {}, title="The metadata associated with the server.", ) - last_user_activity: Optional[datetime] = Field( + last_user_activity: datetime | None = Field( None, title="Timestamp of latest user activity traced on the server.", ) - pro_dashboard_url: Optional[str] = Field( 
+ pro_dashboard_url: str | None = Field( None, title="The base URL of the ZenML Pro dashboard to which the server " "is connected. Only set if the server is a ZenML Pro server.", ) - pro_api_url: Optional[str] = Field( + pro_api_url: str | None = Field( None, title="The base URL of the ZenML Pro API to which the server is " "connected. Only set if the server is a ZenML Pro server.", ) - pro_organization_id: Optional[UUID] = Field( + pro_organization_id: UUID | None = Field( None, title="The ID of the ZenML Pro organization to which the server is " "connected. Only set if the server is a ZenML Pro server.", ) - pro_organization_name: Optional[str] = Field( + pro_organization_name: str | None = Field( None, title="The name of the ZenML Pro organization to which the server is " "connected. Only set if the server is a ZenML Pro server.", ) - pro_workspace_id: Optional[UUID] = Field( + pro_workspace_id: UUID | None = Field( None, title="The ID of the ZenML Pro workspace to which the server is " "connected. Only set if the server is a ZenML Pro server.", ) - pro_workspace_name: Optional[str] = Field( + pro_workspace_name: str | None = Field( None, title="The name of the ZenML Pro workspace to which the server is " "connected. Only set if the server is a ZenML Pro server.", diff --git a/src/zenml/models/v2/misc/service_connector_type.py b/src/zenml/models/v2/misc/service_connector_type.py index 70195bb59d5..e4934ca6c93 100644 --- a/src/zenml/models/v2/misc/service_connector_type.py +++ b/src/zenml/models/v2/misc/service_connector_type.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Model definitions for ZenML service connectors.""" -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import TYPE_CHECKING, Any, Union from uuid import UUID from pydantic import BaseModel, Field, field_validator @@ -49,7 +49,7 @@ class ResourceTypeModel(BaseModel): default="", title="A description of the resource type.", ) - auth_methods: List[str] = Field( + auth_methods: list[str] = Field( title="The list of authentication methods that can be used to access " "resources of this type.", ) @@ -64,12 +64,12 @@ class ResourceTypeModel(BaseModel): "access a single resource and a resource ID is not required to access " "the resource.", ) - logo_url: Optional[str] = Field( + logo_url: str | None = Field( default=None, title="Optionally, a URL pointing to a png," "svg or jpg file can be attached.", ) - emoji: Optional[str] = Field( + emoji: str | None = Field( default=None, title="Optionally, a python-rich emoji can be attached.", ) @@ -104,33 +104,33 @@ class AuthenticationMethodModel(BaseModel): default="", title="A description of the authentication method.", ) - config_schema: Dict[str, Any] = Field( + config_schema: dict[str, Any] = Field( default_factory=dict, title="The JSON schema of the configuration for this authentication " "method.", ) - min_expiration_seconds: Optional[int] = Field( + min_expiration_seconds: int | None = Field( default=None, title="The minimum number of seconds that the authentication " "session can be configured to be valid for. Set to None for " "authentication sessions and long-lived credentials that don't expire.", ) - max_expiration_seconds: Optional[int] = Field( + max_expiration_seconds: int | None = Field( default=None, title="The maximum number of seconds that the authentication " "session can be configured to be valid for. 
Set to None for " "authentication sessions and long-lived credentials that don't expire.", ) - default_expiration_seconds: Optional[int] = Field( + default_expiration_seconds: int | None = Field( default=None, title="The default number of seconds that the authentication " "session is valid for. Set to None for authentication sessions and " "long-lived credentials that don't expire.", ) - _config_class: Optional[Type[BaseModel]] = None + _config_class: type[BaseModel] | None = None def __init__( - self, config_class: Optional[Type[BaseModel]] = None, **values: Any + self, config_class: type[BaseModel] | None = None, **values: Any ): """Initialize the authentication method. @@ -146,7 +146,7 @@ def __init__( self._config_class = config_class @property - def config_class(self) -> Optional[Type[BaseModel]]: + def config_class(self) -> type[BaseModel] | None: """Get the configuration class for the authentication method. Returns: @@ -168,8 +168,8 @@ def supports_temporary_credentials(self) -> bool: ) def validate_expiration( - self, expiration_seconds: Optional[int] - ) -> Optional[int]: + self, expiration_seconds: int | None + ) -> int | None: """Validate the expiration time. 
Args: @@ -243,11 +243,11 @@ class ServiceConnectorTypeModel(BaseModel): default="", title="A description of the service connector.", ) - resource_types: List[ResourceTypeModel] = Field( + resource_types: list[ResourceTypeModel] = Field( title="A list of resource types that the connector can be used to " "access.", ) - auth_methods: List[AuthenticationMethodModel] = Field( + auth_methods: list[AuthenticationMethodModel] = Field( title="A list of specifications describing the authentication " "methods that are supported by the service connector, along with the " "configuration and secrets attributes that need to be configured for " @@ -258,20 +258,20 @@ class ServiceConnectorTypeModel(BaseModel): title="Models if the connector can be configured automatically based " "on information extracted from a local environment.", ) - logo_url: Optional[str] = Field( + logo_url: str | None = Field( default=None, title="Optionally, a URL pointing to a png," "svg or jpg can be attached.", ) - emoji: Optional[str] = Field( + emoji: str | None = Field( default=None, title="Optionally, a python-rich emoji can be attached.", ) - docs_url: Optional[str] = Field( + docs_url: str | None = Field( default=None, title="Optionally, a URL pointing to docs, within docs.zenml.io.", ) - sdk_docs_url: Optional[str] = Field( + sdk_docs_url: str | None = Field( default=None, title="Optionally, a URL pointing to SDK docs," "within sdkdocs.zenml.io.", @@ -284,10 +284,10 @@ class ServiceConnectorTypeModel(BaseModel): default=False, title="If True, the service connector is available remotely.", ) - _connector_class: Optional[Type["ServiceConnector"]] = None + _connector_class: type["ServiceConnector"] | None = None @property - def connector_class(self) -> Optional[Type["ServiceConnector"]]: + def connector_class(self) -> type["ServiceConnector"] | None: """Get the service connector class. 
Returns: @@ -307,7 +307,7 @@ def emojified_connector_type(self) -> str: return f"{self.emoji} {self.connector_type}" @property - def emojified_resource_types(self) -> List[str]: + def emojified_resource_types(self) -> list[str]: """Get the emojified connector types. Returns: @@ -319,7 +319,7 @@ def emojified_resource_types(self) -> List[str]: ] def set_connector_class( - self, connector_class: Type["ServiceConnector"] + self, connector_class: type["ServiceConnector"] ) -> None: """Set the service connector class. @@ -331,8 +331,8 @@ def set_connector_class( @field_validator("resource_types") @classmethod def validate_resource_types( - cls, values: List[ResourceTypeModel] - ) -> List[ResourceTypeModel]: + cls, values: list[ResourceTypeModel] + ) -> list[ResourceTypeModel]: """Validate that the resource types are unique. Args: @@ -359,8 +359,8 @@ def validate_resource_types( @field_validator("auth_methods") @classmethod def validate_auth_methods( - cls, values: List[AuthenticationMethodModel] - ) -> List[AuthenticationMethodModel]: + cls, values: list[AuthenticationMethodModel] + ) -> list[AuthenticationMethodModel]: """Validate that the authentication methods are unique. Args: @@ -387,7 +387,7 @@ def validate_auth_methods( @property def resource_type_dict( self, - ) -> Dict[str, ResourceTypeModel]: + ) -> dict[str, ResourceTypeModel]: """Returns a map of resource types to resource type specifications. Returns: @@ -398,7 +398,7 @@ def resource_type_dict( @property def auth_method_dict( self, - ) -> Dict[str, AuthenticationMethodModel]: + ) -> dict[str, AuthenticationMethodModel]: """Returns a map of authentication methods to authentication method specifications. 
Returns: @@ -410,8 +410,8 @@ def auth_method_dict( def find_resource_specifications( self, auth_method: str, - resource_type: Optional[str] = None, - ) -> Tuple[AuthenticationMethodModel, Optional[ResourceTypeModel]]: + resource_type: str | None = None, + ) -> tuple[AuthenticationMethodModel, ResourceTypeModel | None]: """Find the specifications for a configurable resource. Validate the supplied connector configuration parameters against the @@ -486,9 +486,9 @@ class ServiceConnectorRequirements(BaseModel): the service connector instance must be able to access. """ - connector_type: Optional[str] = None + connector_type: str | None = None resource_type: str - resource_id_attr: Optional[str] = None + resource_id_attr: str | None = None def is_satisfied_by( self, @@ -496,7 +496,7 @@ def is_satisfied_by( "ServiceConnectorResponse", "ServiceConnectorRequest" ], component: Union["ComponentResponse", "ComponentBase"], - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: """Check if the requirements are satisfied by a connector. Args: @@ -551,7 +551,7 @@ class ServiceConnectorTypedResourcesModel(BaseModel): max_length=STR_FIELD_MAX_LENGTH, ) - resource_ids: Optional[List[str]] = Field( + resource_ids: list[str] | None = Field( default=None, title="The resource IDs of all resource instances that the service " "connector instance can be used to access. Omitted (set to None) for " @@ -564,7 +564,7 @@ class ServiceConnectorTypedResourcesModel(BaseModel): "listed.", ) - error: Optional[str] = Field( + error: str | None = Field( default=None, title="An error message describing why the service connector instance " "could not list the resources that it is configured to access.", @@ -578,13 +578,13 @@ class ServiceConnectorResourcesModel(BaseModel): can provide access to. 
""" - id: Optional[UUID] = Field( + id: UUID | None = Field( default=None, title="The ID of the service connector instance providing this " "resource.", ) - name: Optional[str] = Field( + name: str | None = Field( default=None, title="The name of the service connector instance providing this " "resource.", @@ -595,21 +595,21 @@ class ServiceConnectorResourcesModel(BaseModel): title="The type of service connector.", union_mode="left_to_right" ) - resources: List[ServiceConnectorTypedResourcesModel] = Field( + resources: list[ServiceConnectorTypedResourcesModel] = Field( default_factory=list, title="The list of resources that the service connector instance can " "give access to. Contains one entry for every resource type " "that the connector is configured for.", ) - error: Optional[str] = Field( + error: str | None = Field( default=None, title="A global error message describing why the service connector " "instance could not authenticate to the remote service.", ) @property - def resources_dict(self) -> Dict[str, ServiceConnectorTypedResourcesModel]: + def resources_dict(self) -> dict[str, ServiceConnectorTypedResourcesModel]: """Get the resources as a dictionary indexed by resource type. Returns: @@ -620,7 +620,7 @@ def resources_dict(self) -> Dict[str, ServiceConnectorTypedResourcesModel]: } @property - def resource_types(self) -> List[str]: + def resource_types(self) -> list[str]: """Get the resource types. Returns: @@ -629,7 +629,7 @@ def resource_types(self) -> List[str]: return [resource.resource_type for resource in self.resources] def set_error( - self, error: str, resource_type: Optional[str] = None + self, error: str, resource_type: str | None = None ) -> None: """Set a global error message or an error for a single resource type. 
@@ -662,7 +662,7 @@ def set_error( resource.resource_ids = None def set_resource_ids( - self, resource_type: str, resource_ids: List[str] + self, resource_type: str, resource_ids: list[str] ) -> None: """Set the resource IDs for a resource type. @@ -706,8 +706,8 @@ def emojified_connector_type(self) -> str: return self.connector_type def get_emojified_resource_types( - self, resource_type: Optional[str] = None - ) -> List[str]: + self, resource_type: str | None = None + ) -> list[str]: """Get the emojified resource type. Args: @@ -736,7 +736,7 @@ def get_emojified_resource_types( return [resource_type] return list(self.resources_dict.keys()) - def get_default_resource_id(self) -> Optional[str]: + def get_default_resource_id(self) -> str | None: """Get the default resource ID, if included in the resource list. The default resource ID is a resource ID supplied by the connector @@ -773,7 +773,7 @@ def get_default_resource_id(self) -> Optional[str]: def from_connector_model( cls, connector_model: "ServiceConnectorResponse", - resource_type: Optional[str] = None, + resource_type: str | None = None, ) -> "ServiceConnectorResourcesModel": """Initialize a resource model from a connector model. diff --git a/src/zenml/models/v2/misc/stack_deployment.py b/src/zenml/models/v2/misc/stack_deployment.py index d8a7950e1ae..015d8fb07e7 100644 --- a/src/zenml/models/v2/misc/stack_deployment.py +++ b/src/zenml/models/v2/misc/stack_deployment.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Models related to cloud stack deployments.""" -from typing import Dict, List, Optional from pydantic import BaseModel, Field @@ -40,22 +39,22 @@ class StackDeploymentInfo(BaseModel): title="The instructions for post-deployment.", description="The instructions for post-deployment.", ) - integrations: List[str] = Field( + integrations: list[str] = Field( title="ZenML integrations required for the stack.", description="The list of ZenML integrations that need to be installed " "for the stack to be usable.", ) - permissions: Dict[str, List[str]] = Field( + permissions: dict[str, list[str]] = Field( title="The permissions granted to ZenML to access the cloud resources.", description="The permissions granted to ZenML to access the cloud " "resources, as a dictionary grouping permissions by resource.", ) - locations: Dict[str, str] = Field( + locations: dict[str, str] = Field( title="The locations where the stack can be deployed.", description="The locations where the stack can be deployed, as a " "dictionary mapping location names to descriptions.", ) - skypilot_default_regions: Dict[str, str] = Field( + skypilot_default_regions: dict[str, str] = Field( title="The locations where the Skypilot clusters can be deployed by default.", description="The locations where the Skypilot clusters can be deployed by default, as a " "dictionary mapping location names to descriptions.", @@ -71,12 +70,12 @@ class StackDeploymentConfig(BaseModel): deployment_url_text: str = Field( title="A textual description for the cloud provider console URL.", ) - configuration: Optional[str] = Field( + configuration: str | None = Field( default=None, title="Configuration for the stack deployment that the user must " "manually configure into the cloud provider console.", ) - instructions: Optional[str] = Field( + instructions: str | None = Field( default=None, title="Instructions for deploying the stack.", ) @@ -89,7 +88,7 @@ class DeployedStack(BaseModel): title="The stack that was deployed.", 
description="The stack that was deployed.", ) - service_connector: Optional[ServiceConnectorResponse] = Field( + service_connector: ServiceConnectorResponse | None = Field( default=None, title="The service connector for the deployed stack.", description="The service connector for the deployed stack.", diff --git a/src/zenml/models/v2/misc/user_auth.py b/src/zenml/models/v2/misc/user_auth.py index a98566dcf2b..c31c4699108 100644 --- a/src/zenml/models/v2/misc/user_auth.py +++ b/src/zenml/models/v2/misc/user_auth.py @@ -48,10 +48,10 @@ class UserAuthModel(BaseZenModel): "account." ) - activation_token: Optional[PlainSerializedSecretStr] = Field( + activation_token: PlainSerializedSecretStr | None = Field( default=None, exclude=True ) - password: Optional[PlainSerializedSecretStr] = Field( + password: PlainSerializedSecretStr | None = Field( default=None, exclude=True ) name: str = Field( @@ -65,7 +65,7 @@ class UserAuthModel(BaseZenModel): max_length=STR_FIELD_MAX_LENGTH, ) - email_opted_in: Optional[bool] = Field( + email_opted_in: bool | None = Field( default=None, title="Whether the user agreed to share their email. Only relevant for " "user accounts", @@ -100,7 +100,7 @@ def _is_hashed_secret(cls, secret: SecretStr) -> bool: ) @classmethod - def _get_hashed_secret(cls, secret: Optional[SecretStr]) -> Optional[str]: + def _get_hashed_secret(cls, secret: SecretStr | None) -> str | None: """Hashes the input secret and returns the hash value. Only applied if supplied and if not already hashed. @@ -118,7 +118,7 @@ def _get_hashed_secret(cls, secret: Optional[SecretStr]) -> Optional[str]: pwd_context = cls._get_crypt_context() return pwd_context.hash(secret.get_secret_value()) - def get_password(self) -> Optional[str]: + def get_password(self) -> str | None: """Get the password. 
Returns: @@ -128,7 +128,7 @@ def get_password(self) -> Optional[str]: return None return self.password.get_secret_value() - def get_hashed_password(self) -> Optional[str]: + def get_hashed_password(self) -> str | None: """Returns the hashed password, if configured. Returns: @@ -136,7 +136,7 @@ def get_hashed_password(self) -> Optional[str]: """ return self._get_hashed_secret(self.password) - def get_hashed_activation_token(self) -> Optional[str]: + def get_hashed_activation_token(self) -> str | None: """Returns the hashed activation token, if configured. Returns: @@ -160,7 +160,7 @@ def verify_password( # even when the user or password is not set, we still want to execute # the password hash verification to protect against response discrepancy # attacks (https://cwe.mitre.org/data/definitions/204.html) - password_hash: Optional[str] = None + password_hash: str | None = None if ( user is not None # Disable password verification for service accounts as an extra diff --git a/src/zenml/orchestrators/base_orchestrator.py b/src/zenml/orchestrators/base_orchestrator.py index 5e7422ff6ce..312a7c31b8a 100644 --- a/src/zenml/orchestrators/base_orchestrator.py +++ b/src/zenml/orchestrators/base_orchestrator.py @@ -18,15 +18,10 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - Iterator, - List, Optional, - Tuple, - Type, cast, ) +from collections.abc import Callable, Iterator from uuid import UUID from pydantic import model_validator @@ -74,8 +69,8 @@ class SubmissionResult: def __init__( self, - wait_for_completion: Optional[Callable[[], None]] = None, - metadata: Optional[Dict[str, MetadataType]] = None, + wait_for_completion: Callable[[], None] | None = None, + metadata: dict[str, MetadataType] | None = None, ): """Initialize a submission result. 
@@ -95,7 +90,7 @@ class BaseOrchestratorConfig(StackComponentConfig): @model_validator(mode="before") @classmethod @before_validator_handler - def _deprecations(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def _deprecations(cls, data: dict[str, Any]) -> dict[str, Any]: """Validate and/or remove deprecated fields. Args: @@ -182,10 +177,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. This method should only submit the pipeline and not wait for it to @@ -212,9 +207,9 @@ def prepare_or_run_pipeline( self, deployment: "PipelineSnapshotResponse", stack: "Stack", - environment: Dict[str, str], + environment: dict[str, str], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[Iterator[Dict[str, MetadataType]]]: + ) -> Iterator[dict[str, MetadataType]] | None: """DEPRECATED: Prepare or run a pipeline. Args: @@ -247,8 +242,8 @@ def run( """ self._prepare_run(snapshot=snapshot) - pipeline_run_id: Optional[UUID] = None - schedule_id: Optional[UUID] = None + pipeline_run_id: UUID | None = None + schedule_id: UUID | None = None if snapshot.schedule: schedule_id = snapshot.schedule.id if placeholder_run: @@ -490,7 +485,7 @@ def _cleanup_run(self) -> None: self._active_snapshot = None @property - def supported_execution_modes(self) -> List[ExecutionMode]: + def supported_execution_modes(self) -> list[ExecutionMode]: """Returns the supported execution modes for this flavor. 
Returns: @@ -609,8 +604,8 @@ def _validate_execution_mode( def fetch_status( self, run: "PipelineRunResponse", include_steps: bool = False - ) -> Tuple[ - Optional[ExecutionStatus], Optional[Dict[str, ExecutionStatus]] + ) -> tuple[ + ExecutionStatus | None, dict[str, ExecutionStatus] | None ]: """Refreshes the status of a specific pipeline run. @@ -762,7 +757,7 @@ def type(self) -> StackComponentType: return StackComponentType.ORCHESTRATOR @property - def config_class(self) -> Type[BaseOrchestratorConfig]: + def config_class(self) -> type[BaseOrchestratorConfig]: """Config class for the base orchestrator flavor. Returns: @@ -772,7 +767,7 @@ def config_class(self) -> Type[BaseOrchestratorConfig]: @property @abstractmethod - def implementation_class(self) -> Type["BaseOrchestrator"]: + def implementation_class(self) -> type["BaseOrchestrator"]: """Implementation class for this flavor. Returns: diff --git a/src/zenml/orchestrators/cache_utils.py b/src/zenml/orchestrators/cache_utils.py index 2592993b1b7..bafed291668 100644 --- a/src/zenml/orchestrators/cache_utils.py +++ b/src/zenml/orchestrators/cache_utils.py @@ -15,7 +15,8 @@ import hashlib import os -from typing import TYPE_CHECKING, Mapping, Optional +from typing import TYPE_CHECKING, Optional +from collections.abc import Mapping from uuid import UUID from zenml.client import Client diff --git a/src/zenml/orchestrators/containerized_orchestrator.py b/src/zenml/orchestrators/containerized_orchestrator.py index 80d50072463..cfb9d5ea079 100644 --- a/src/zenml/orchestrators/containerized_orchestrator.py +++ b/src/zenml/orchestrators/containerized_orchestrator.py @@ -14,7 +14,6 @@ """Containerized orchestrator class.""" from abc import ABC -from typing import List, Optional, Set import zenml from zenml.config.build_configuration import BuildConfiguration @@ -28,7 +27,7 @@ class ContainerizedOrchestrator(BaseOrchestrator, ABC): """Base class for containerized orchestrators.""" @property - def requirements(self) -> 
Set[str]: + def requirements(self) -> set[str]: """Set of PyPI requirements for the component. Returns: @@ -46,7 +45,7 @@ def requirements(self) -> Set[str]: @staticmethod def get_image( snapshot: "PipelineSnapshotResponse", - step_name: Optional[str] = None, + step_name: str | None = None, ) -> str: """Gets the Docker image for the pipeline/a step. @@ -86,7 +85,7 @@ def should_build_pipeline_image( def get_docker_builds( self, snapshot: "PipelineSnapshotBase" - ) -> List["BuildConfiguration"]: + ) -> list["BuildConfiguration"]: """Gets the Docker builds required for the component. Args: diff --git a/src/zenml/orchestrators/dag_runner.py b/src/zenml/orchestrators/dag_runner.py index 2f056b3f5dc..8ecfa5b3ed1 100644 --- a/src/zenml/orchestrators/dag_runner.py +++ b/src/zenml/orchestrators/dag_runner.py @@ -17,14 +17,15 @@ import time from collections import defaultdict from enum import Enum -from typing import Any, Callable, Dict, List, Optional +from typing import Any +from collections.abc import Callable from zenml.logger import get_logger logger = get_logger(__name__) -def reverse_dag(dag: Dict[str, List[str]]) -> Dict[str, List[str]]: +def reverse_dag(dag: dict[str, list[str]]) -> dict[str, list[str]]: """Reverse a DAG. Args: @@ -71,13 +72,13 @@ class ThreadedDagRunner: def __init__( self, - dag: Dict[str, List[str]], + dag: dict[str, list[str]], run_fn: Callable[[str], Any], - preparation_fn: Optional[Callable[[str], bool]] = None, - finalize_fn: Optional[Callable[[Dict[str, NodeStatus]], None]] = None, + preparation_fn: Callable[[str], bool] | None = None, + finalize_fn: Callable[[dict[str, NodeStatus]], None] | None = None, parallel_node_startup_waiting_period: float = 0.0, - max_parallelism: Optional[int] = None, - continue_fn: Optional[Callable[[], bool]] = None, + max_parallelism: int | None = None, + continue_fn: Callable[[], bool] | None = None, ) -> None: """Define attributes and initialize all nodes in waiting state. 
@@ -246,7 +247,7 @@ def _finish_node( return # Run downstream nodes. - threads: List[threading.Thread] = [] + threads: list[threading.Thread] = [] for downstream_node in self.reversed_dag[node]: if self._can_run(downstream_node): if threads and self.parallel_node_startup_waiting_period > 0: @@ -268,7 +269,7 @@ def run(self) -> None: # Run all nodes that can be started immediately. # These will, in turn, start other nodes once all of their respective # upstream nodes have completed. - threads: List[threading.Thread] = [] + threads: list[threading.Thread] = [] for node in self.nodes: if self._can_run(node): if threads and self.parallel_node_startup_waiting_period > 0: diff --git a/src/zenml/orchestrators/input_utils.py b/src/zenml/orchestrators/input_utils.py index 1578ffa7bee..f85165c0b06 100644 --- a/src/zenml/orchestrators/input_utils.py +++ b/src/zenml/orchestrators/input_utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Utilities for inputs.""" -from typing import TYPE_CHECKING, Dict, Optional +from typing import TYPE_CHECKING from zenml.client import Client from zenml.config.step_configurations import Step @@ -29,8 +29,8 @@ def resolve_step_inputs( step: "Step", pipeline_run: "PipelineRunResponse", - step_runs: Optional[Dict[str, "StepRunResponse"]] = None, -) -> Dict[str, "StepRunInputResponse"]: + step_runs: dict[str, "StepRunResponse"] | None = None, +) -> dict[str, "StepRunInputResponse"]: """Resolves inputs for the current step. Args: @@ -55,9 +55,9 @@ def resolve_step_inputs( step_runs = step_runs or {} - steps_to_fetch = set( + steps_to_fetch = { input_.step_name for input_ in step.spec.inputs.values() - ) + } # Remove all the step runs that we've already fetched. 
steps_to_fetch.difference_update(step_runs.keys()) @@ -68,7 +68,7 @@ def resolve_step_inputs( ) ) - input_artifacts: Dict[str, StepRunInputResponse] = {} + input_artifacts: dict[str, StepRunInputResponse] = {} for name, input_ in step.spec.inputs.items(): try: step_run = step_runs[input_.step_name] diff --git a/src/zenml/orchestrators/local/local_orchestrator.py b/src/zenml/orchestrators/local/local_orchestrator.py index e45c4c1918d..9c4a42664c5 100644 --- a/src/zenml/orchestrators/local/local_orchestrator.py +++ b/src/zenml/orchestrators/local/local_orchestrator.py @@ -14,7 +14,7 @@ """Implementation of the ZenML local orchestrator.""" import time -from typing import TYPE_CHECKING, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Optional from uuid import uuid4 from zenml.enums import ExecutionMode @@ -42,7 +42,7 @@ class LocalOrchestrator(BaseOrchestrator): does not support running on a schedule. """ - _orchestrator_run_id: Optional[str] = None + _orchestrator_run_id: str | None = None @property def run_init_cleanup_at_step_level(self) -> bool: @@ -67,10 +67,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. 
This method should only submit the pipeline and not wait for it to @@ -108,9 +108,9 @@ def submit_pipeline( execution_mode = snapshot.pipeline_configuration.execution_mode - failed_steps: List[str] = [] - step_exception: Optional[Exception] = None - skipped_steps: List[str] = [] + failed_steps: list[str] = [] + step_exception: Exception | None = None + skipped_steps: list[str] = [] self.run_init_hook(snapshot=snapshot) @@ -209,7 +209,7 @@ def get_orchestrator_run_id(self) -> str: return self._orchestrator_run_id @property - def supported_execution_modes(self) -> List[ExecutionMode]: + def supported_execution_modes(self) -> list[ExecutionMode]: """Returns the supported execution modes for this flavor. Returns: @@ -257,7 +257,7 @@ def name(self) -> str: return "local" @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A URL to point at docs explaining this flavor. Returns: @@ -266,7 +266,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A URL to point at SDK docs explaining this flavor. Returns: @@ -284,7 +284,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/local.png" @property - def config_class(self) -> Type[BaseOrchestratorConfig]: + def config_class(self) -> type[BaseOrchestratorConfig]: """Config class for the base orchestrator flavor. Returns: @@ -293,7 +293,7 @@ def config_class(self) -> Type[BaseOrchestratorConfig]: return LocalOrchestratorConfig @property - def implementation_class(self) -> Type[LocalOrchestrator]: + def implementation_class(self) -> type[LocalOrchestrator]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py b/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py index ac7540594eb..7c49da1dab3 100644 --- a/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py +++ b/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py @@ -17,7 +17,7 @@ import os import sys import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, cast +from typing import TYPE_CHECKING, Any, Optional, cast from uuid import uuid4 from docker.errors import ContainerError @@ -55,7 +55,7 @@ class LocalDockerOrchestrator(ContainerizedOrchestrator): """ @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Settings class for the Local Docker orchestrator. Returns: @@ -73,7 +73,7 @@ def config(self) -> "LocalDockerOrchestratorConfig": return cast(LocalDockerOrchestratorConfig, self._config) @property - def validator(self) -> Optional[StackValidator]: + def validator(self) -> StackValidator | None: """Ensures there is an image builder in the stack. Returns: @@ -105,10 +105,10 @@ def submit_pipeline( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - base_environment: Dict[str, str], - step_environments: Dict[str, Dict[str, str]], + base_environment: dict[str, str], + step_environments: dict[str, dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, - ) -> Optional[SubmissionResult]: + ) -> SubmissionResult | None: """Submits a pipeline to the orchestrator. 
This method should only submit the pipeline and not wait for it to @@ -158,8 +158,8 @@ def submit_pipeline( execution_mode = snapshot.pipeline_configuration.execution_mode - failed_steps: List[str] = [] - skipped_steps: List[str] = [] + failed_steps: list[str] = [] + skipped_steps: list[str] = [] # Run each step for step_name, step in snapshot.step_configurations.items(): @@ -278,7 +278,7 @@ def submit_pipeline( return None @property - def supported_execution_modes(self) -> List[ExecutionMode]: + def supported_execution_modes(self) -> list[ExecutionMode]: """Supported execution modes for this orchestrator. Returns: @@ -300,7 +300,7 @@ class LocalDockerOrchestratorSettings(BaseSettings): of what can be passed.) """ - run_args: Dict[str, Any] = {} + run_args: dict[str, Any] = {} class LocalDockerOrchestratorConfig( @@ -340,7 +340,7 @@ def name(self) -> str: return "local_docker" @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -349,7 +349,7 @@ def docs_url(self) -> Optional[str]: return self.generate_default_docs_url() @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -367,7 +367,7 @@ def logo_url(self) -> str: return "https://public-flavor-logos.s3.eu-central-1.amazonaws.com/orchestrator/docker.png" @property - def config_class(self) -> Type[BaseOrchestratorConfig]: + def config_class(self) -> type[BaseOrchestratorConfig]: """Config class for the base orchestrator flavor. Returns: @@ -376,7 +376,7 @@ def config_class(self) -> Type[BaseOrchestratorConfig]: return LocalDockerOrchestratorConfig @property - def implementation_class(self) -> Type["LocalDockerOrchestrator"]: + def implementation_class(self) -> type["LocalDockerOrchestrator"]: """Implementation class for this flavor. 
Returns: diff --git a/src/zenml/orchestrators/output_utils.py b/src/zenml/orchestrators/output_utils.py index 7e9ef2edb9f..6cdabe223e4 100644 --- a/src/zenml/orchestrators/output_utils.py +++ b/src/zenml/orchestrators/output_utils.py @@ -14,7 +14,8 @@ """Utilities for outputs.""" import os -from typing import TYPE_CHECKING, Dict, Sequence +from typing import TYPE_CHECKING +from collections.abc import Sequence from uuid import uuid4 from zenml.client import Client @@ -64,7 +65,7 @@ def prepare_output_artifact_uris( step: "Step", *, skip_artifact_materialization: bool = False, -) -> Dict[str, str]: +) -> dict[str, str]: """Prepares the output artifact URIs to run the current step. Args: @@ -80,7 +81,7 @@ def prepare_output_artifact_uris( A dictionary mapping output names to artifact URIs. """ artifact_store = stack.artifact_store - output_artifact_uris: Dict[str, str] = {} + output_artifact_uris: dict[str, str] = {} for output_name in step.config.outputs.keys(): substituted_output_name = string_utils.format_name_template( diff --git a/src/zenml/orchestrators/publish_utils.py b/src/zenml/orchestrators/publish_utils.py index 8956852f934..d40766d2fbf 100644 --- a/src/zenml/orchestrators/publish_utils.py +++ b/src/zenml/orchestrators/publish_utils.py @@ -15,7 +15,7 @@ from contextvars import ContextVar from datetime import datetime -from typing import TYPE_CHECKING, Dict, List, Optional +from typing import TYPE_CHECKING from zenml.client import Client from zenml.enums import ExecutionStatus, MetadataResourceTypes @@ -34,13 +34,13 @@ from zenml.metadata.metadata_types import MetadataType -step_exception_info: ContextVar[Optional[ExceptionInfo]] = ContextVar( +step_exception_info: ContextVar[ExceptionInfo | None] = ContextVar( "step_exception_info", default=None ) def publish_successful_step_run( - step_run_id: "UUID", output_artifact_ids: Dict[str, List["UUID"]] + step_run_id: "UUID", output_artifact_ids: dict[str, list["UUID"]] ) -> "StepRunResponse": """Publishes a 
successful step run. @@ -64,8 +64,8 @@ def publish_successful_step_run( def publish_step_run_status_update( step_run_id: "UUID", status: "ExecutionStatus", - end_time: Optional[datetime] = None, - exception_info: Optional[ExceptionInfo] = None, + end_time: datetime | None = None, + exception_info: ExceptionInfo | None = None, ) -> "StepRunResponse": """Publishes a step run update. @@ -141,8 +141,8 @@ def publish_failed_pipeline_run( def publish_pipeline_run_status_update( pipeline_run_id: "UUID", status: ExecutionStatus, - status_reason: Optional[str] = None, - end_time: Optional[datetime] = None, + status_reason: str | None = None, + end_time: datetime | None = None, ) -> "PipelineRunResponse": """Publishes a pipeline run status update. @@ -177,7 +177,7 @@ def publish_pipeline_run_status_update( def get_pipeline_run_status( run_status: ExecutionStatus, - step_statuses: List[ExecutionStatus], + step_statuses: list[ExecutionStatus], num_steps: int, ) -> ExecutionStatus: """Gets the pipeline run status for the given step statuses. @@ -229,7 +229,7 @@ def get_pipeline_run_status( def publish_pipeline_run_metadata( pipeline_run_id: "UUID", - pipeline_run_metadata: Dict["UUID", Dict[str, "MetadataType"]], + pipeline_run_metadata: dict["UUID", dict[str, "MetadataType"]], ) -> None: """Publishes the given pipeline run metadata. @@ -253,7 +253,7 @@ def publish_pipeline_run_metadata( def publish_step_run_metadata( step_run_id: "UUID", - step_run_metadata: Dict["UUID", Dict[str, "MetadataType"]], + step_run_metadata: dict["UUID", dict[str, "MetadataType"]], ) -> None: """Publishes the given step run metadata. @@ -277,7 +277,7 @@ def publish_step_run_metadata( def publish_schedule_metadata( schedule_id: "UUID", - schedule_metadata: Dict["UUID", Dict[str, "MetadataType"]], + schedule_metadata: dict["UUID", dict[str, "MetadataType"]], ) -> None: """Publishes the given schedule metadata. 
diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index eacb6fd24d1..6d7aa4b7c09 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -16,7 +16,8 @@ import signal import time from contextlib import nullcontext -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any +from collections.abc import Callable from zenml.client import Client from zenml.config.step_configurations import Step @@ -53,7 +54,7 @@ def _get_step_operator( - stack: "Stack", step_operator_name: Optional[str] + stack: "Stack", step_operator_name: str | None ) -> "BaseStepOperator": """Fetches the step operator from the stack. @@ -132,7 +133,7 @@ def __init__( self._invocation_id = step.spec.invocation_id # Internal properties and methods - self._step_run: Optional[StepRunResponse] = None + self._step_run: StepRunResponse | None = None self._setup_signal_handlers() def _setup_signal_handlers(self) -> None: @@ -377,7 +378,7 @@ def _bypass() -> None: model_version=model_version, ) - def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]: + def _create_or_reuse_run(self) -> tuple[PipelineRunResponse, bool]: """Creates a pipeline run or reuses an existing one. Returns: @@ -475,7 +476,7 @@ def _run_step( def _run_step_with_step_operator( self, - step_operator_name: Optional[str], + step_operator_name: str | None, step_run_info: StepRunInfo, ) -> None: """Runs the current step with a step operator. @@ -527,8 +528,8 @@ def _run_step_without_step_operator( pipeline_run: PipelineRunResponse, step_run: StepRunResponse, step_run_info: StepRunInfo, - input_artifacts: Dict[str, StepRunInputResponse], - output_artifact_uris: Dict[str, str], + input_artifacts: dict[str, StepRunInputResponse], + output_artifact_uris: dict[str, str], ) -> None: """Runs the current step without a step operator. 
diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py index ea892203a6f..d037091943b 100644 --- a/src/zenml/orchestrators/step_run_utils.py +++ b/src/zenml/orchestrators/step_run_utils.py @@ -15,7 +15,6 @@ import json from datetime import timedelta -from typing import Dict, List, Optional, Set, Tuple from zenml.client import Client from zenml.config.step_configurations import Step @@ -101,7 +100,7 @@ def create_request(self, invocation_id: str) -> StepRunRequest: def populate_request( self, request: StepRunRequest, - step_runs: Optional[Dict[str, "StepRunResponse"]] = None, + step_runs: dict[str, "StepRunResponse"] | None = None, ) -> None: """Populate a step run request with additional information. @@ -186,7 +185,7 @@ def populate_request( def _get_docstring_and_source_code( self, invocation_id: str - ) -> Tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """Get the docstring and source code for the step. Args: @@ -216,7 +215,7 @@ def _get_docstring_and_source_code( @staticmethod def _get_docstring_and_source_code_from_step_instance( step: "Step", - ) -> Tuple[Optional[str], str]: + ) -> tuple[str | None, str]: """Get the docstring and source code of a step. Args: @@ -241,7 +240,7 @@ def _get_docstring_and_source_code_from_step_instance( def _try_to_get_docstring_and_source_code_from_template( self, invocation_id: str - ) -> Tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """Try to get the docstring and source code via a potential template. Args: @@ -274,8 +273,8 @@ def _try_to_get_docstring_and_source_code_from_template( def find_cacheable_invocation_candidates( snapshot: "PipelineSnapshotResponse", - finished_invocations: Set[str], -) -> Set[str]: + finished_invocations: set[str], +) -> set[str]: """Find invocations that can potentially be cached. 
Args: @@ -310,7 +309,7 @@ def create_cached_step_runs( snapshot: "PipelineSnapshotResponse", pipeline_run: PipelineRunResponse, stack: "Stack", -) -> Set[str]: +) -> set[str]: """Create all cached step runs for a pipeline run. Args: @@ -321,14 +320,14 @@ def create_cached_step_runs( Returns: The invocation IDs of the created step runs. """ - cached_invocations: Set[str] = set() - visited_invocations: Set[str] = set() + cached_invocations: set[str] = set() + visited_invocations: set[str] = set() request_factory = StepRunRequestFactory( snapshot=snapshot, pipeline_run=pipeline_run, stack=stack ) # This is used to cache the step runs that we created to avoid unnecessary # server requests. - step_runs: Dict[str, "StepRunResponse"] = {} + step_runs: dict[str, "StepRunResponse"] = {} while ( cache_candidates := find_cacheable_invocation_candidates( @@ -409,7 +408,7 @@ def log_model_version_dashboard_url( def link_output_artifacts_to_model_version( - artifacts: Dict[str, List[ArtifactVersionResponse]], + artifacts: dict[str, list[ArtifactVersionResponse]], model_version: ModelVersionResponse, ) -> None: """Link the outputs of a step run to a model version. @@ -450,8 +449,8 @@ def publish_cached_step_run( def fetch_step_runs_by_names( - step_run_names: List[str], pipeline_run: "PipelineRunResponse" -) -> Dict[str, "StepRunResponse"]: + step_run_names: list[str], pipeline_run: "PipelineRunResponse" +) -> dict[str, "StepRunResponse"]: """Fetch step runs by names. 
Args: diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index eb67b6f2a63..310159fb800 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -21,10 +21,6 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, - Tuple, - Type, ) from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact @@ -120,8 +116,8 @@ def run( self, pipeline_run: "PipelineRunResponse", step_run: "StepRunResponse", - input_artifacts: Dict[str, StepRunInputResponse], - output_artifact_uris: Dict[str, str], + input_artifacts: dict[str, StepRunInputResponse], + output_artifact_uris: dict[str, str], step_run_info: StepRunInfo, ) -> None: """Runs the step. @@ -358,8 +354,8 @@ def run( def _evaluate_artifact_names_in_collections( self, step_run: "StepRunResponse", - output_annotations: Dict[str, OutputSignature], - collections: List[Dict[str, Any]], + output_annotations: dict[str, OutputSignature], + collections: list[dict[str, Any]], ) -> None: """Evaluates the artifact names in the collections. @@ -396,7 +392,7 @@ def _load_step(self) -> "BaseStep": def _load_output_materializers( self, - ) -> Dict[str, Tuple[Type[BaseMaterializer], ...]]: + ) -> dict[str, tuple[type[BaseMaterializer], ...]]: """Loads the output materializers for the step. 
Returns: @@ -407,7 +403,7 @@ def _load_output_materializers( output_materializers = [] for source in output.materializer_source: - materializer_class: Type[BaseMaterializer] = ( + materializer_class: type[BaseMaterializer] = ( source_utils.load_and_validate_class( source, expected_class=BaseMaterializer ) @@ -420,10 +416,10 @@ def _load_output_materializers( def _parse_inputs( self, - args: List[str], - annotations: Dict[str, Any], - input_artifacts: Dict[str, StepRunInputResponse], - ) -> Dict[str, Any]: + args: list[str], + annotations: dict[str, Any], + input_artifacts: dict[str, StepRunInputResponse], + ) -> dict[str, Any]: """Parses the inputs for a step entrypoint function. Args: @@ -437,7 +433,7 @@ def _parse_inputs( Raises: RuntimeError: If a function argument value is missing. """ - function_params: Dict[str, Any] = {} + function_params: dict[str, Any] = {} if args and args[0] == "self": args.pop(0) @@ -460,7 +456,7 @@ def _parse_inputs( return function_params def _load_input_artifact( - self, artifact: "ArtifactVersionResponse", data_type: Type[Any] + self, artifact: "ArtifactVersionResponse", data_type: type[Any] ) -> Any: """Loads an input artifact. @@ -486,7 +482,7 @@ def _load_input_artifact( register_artifact_store_filesystem, ) - materializer_class: Type[BaseMaterializer] = ( + materializer_class: type[BaseMaterializer] = ( source_utils.load_and_validate_class( artifact.materializer, expected_class=BaseMaterializer ) @@ -514,8 +510,8 @@ def _load_artifact(artifact_store: "BaseArtifactStore") -> Any: def _validate_outputs( self, return_values: Any, - output_annotations: Dict[str, OutputSignature], - ) -> Dict[str, Any]: + output_annotations: dict[str, OutputSignature], + ) -> dict[str, Any]: """Validates the step function outputs. 
Args: @@ -569,7 +565,7 @@ def _validate_outputs( from zenml.steps.utils import get_args - validated_outputs: Dict[str, Any] = {} + validated_outputs: dict[str, Any] = {} for return_value, (output_name, output_annotation) in zip( return_values, output_annotations.items() ): @@ -591,13 +587,13 @@ def _validate_outputs( def _store_output_artifacts( self, - output_data: Dict[str, Any], - output_materializers: Dict[str, Tuple[Type[BaseMaterializer], ...]], - output_artifact_uris: Dict[str, str], - output_annotations: Dict[str, OutputSignature], + output_data: dict[str, Any], + output_materializers: dict[str, tuple[type[BaseMaterializer], ...]], + output_artifact_uris: dict[str, str], + output_annotations: dict[str, OutputSignature], artifact_metadata_enabled: bool, artifact_visualization_enabled: bool, - ) -> Dict[str, "ArtifactVersionResponse"]: + ) -> dict[str, "ArtifactVersionResponse"]: """Stores the output artifacts of the step. Args: @@ -640,7 +636,7 @@ def _store_output_artifacts( ].default_materializer_source if default_materializer_source: - default_materializer_class: Type[BaseMaterializer] = ( + default_materializer_class: type[BaseMaterializer] = ( source_utils.load_and_validate_class( default_materializer_source, expected_class=BaseMaterializer, diff --git a/src/zenml/orchestrators/topsort.py b/src/zenml/orchestrators/topsort.py index f515f31826c..a7ac2aea36c 100644 --- a/src/zenml/orchestrators/topsort.py +++ b/src/zenml/orchestrators/topsort.py @@ -31,7 +31,8 @@ https://github.com/tensorflow/tfx/blob/master/tfx/utils/topsort.py """ -from typing import Callable, List, Sequence, TypeVar +from typing import TypeVar +from collections.abc import Callable, Sequence from zenml.logger import get_logger @@ -43,9 +44,9 @@ def topsorted_layers( nodes: Sequence[NodeT], get_node_id_fn: Callable[[NodeT], str], - get_parent_nodes: Callable[[NodeT], List[NodeT]], - get_child_nodes: Callable[[NodeT], List[NodeT]], -) -> List[List[NodeT]]: + get_parent_nodes: 
Callable[[NodeT], list[NodeT]], + get_child_nodes: Callable[[NodeT], list[NodeT]], +) -> list[list[NodeT]]: """Sorts the DAG of nodes in topological order. Args: @@ -67,15 +68,15 @@ def topsorted_layers( ValueError: If the nodes are not unique. """ # Make sure the nodes are unique. - node_ids = set(get_node_id_fn(n) for n in nodes) + node_ids = {get_node_id_fn(n) for n in nodes} if len(node_ids) != len(nodes): raise ValueError("Nodes must have unique ids.") # The outputs of get_(parent|child)_nodes should always be deduplicated, # and references to unknown nodes should be removed. def _apply_and_clean( - func: Callable[[NodeT], List[NodeT]], func_name: str, node: NodeT - ) -> List[NodeT]: + func: Callable[[NodeT], list[NodeT]], func_name: str, node: NodeT + ) -> list[NodeT]: seen_inner_node_ids = set() result = [] for inner_node in func(node): @@ -104,10 +105,10 @@ def _apply_and_clean( return result - def get_clean_parent_nodes(node: NodeT) -> List[NodeT]: + def get_clean_parent_nodes(node: NodeT) -> list[NodeT]: return _apply_and_clean(get_parent_nodes, "get_parent_nodes", node) - def get_clean_child_nodes(node: NodeT) -> List[NodeT]: + def get_clean_child_nodes(node: NodeT) -> list[NodeT]: return _apply_and_clean(get_child_nodes, "get_child_nodes", node) # The first layer contains nodes with no incoming edges. @@ -126,10 +127,10 @@ def get_clean_child_nodes(node: NodeT) -> List[NodeT]: # Include the child node if all its parents are visited. If the child # node is part of a cycle, it will never be included since it will have # at least one unvisited parent node which is also part of the cycle. 
- parent_node_ids = set( + parent_node_ids = { get_node_id_fn(p) for p in get_clean_parent_nodes(child_node) - ) + } if parent_node_ids.issubset(visited_node_ids): next_layer.append(child_node) layer = next_layer diff --git a/src/zenml/orchestrators/utils.py b/src/zenml/orchestrators/utils.py index 2e88bb43f91..e7a64f8d307 100644 --- a/src/zenml/orchestrators/utils.py +++ b/src/zenml/orchestrators/utils.py @@ -15,7 +15,7 @@ import os import random -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, cast from uuid import UUID from zenml.client import Client @@ -42,7 +42,7 @@ def get_orchestrator_run_name( - pipeline_name: str, max_length: Optional[int] = None + pipeline_name: str, max_length: int | None = None ) -> str: """Gets an orchestrator run name. @@ -79,8 +79,8 @@ def get_orchestrator_run_name( def is_setting_enabled( - is_enabled_on_step: Optional[bool], - is_enabled_on_pipeline: Optional[bool], + is_enabled_on_step: bool | None, + is_enabled_on_pipeline: bool | None, ) -> bool: """Checks if a certain setting is enabled within a step run. @@ -103,10 +103,10 @@ def is_setting_enabled( def get_config_environment_vars( - schedule_id: Optional[UUID] = None, - pipeline_run_id: Optional[UUID] = None, - deployment_id: Optional[UUID] = None, -) -> Tuple[Dict[str, str], Dict[str, str]]: + schedule_id: UUID | None = None, + pipeline_run_id: UUID | None = None, + deployment_id: UUID | None = None, +) -> tuple[dict[str, str], dict[str, str]]: """Gets environment variables to set for mirroring the active config. 
If a schedule ID, pipeline run ID or step run ID is given, and the current @@ -130,7 +130,7 @@ def get_config_environment_vars( global_config = GlobalConfiguration() environment_vars = global_config.get_config_environment_vars() - secrets: Dict[str, str] = {} + secrets: dict[str, str] = {} if ( global_config.store_configuration.type == StoreType.REST @@ -247,7 +247,7 @@ class register_artifact_store_filesystem: will be restored. """ - def __init__(self, target_artifact_store_id: Optional[UUID]) -> None: + def __init__(self, target_artifact_store_id: UUID | None) -> None: """Initialization of the context manager. Args: @@ -308,9 +308,9 @@ def __enter__(self) -> "BaseArtifactStore": def __exit__( self, - exc_type: Optional[Any], - exc_value: Optional[Any], - traceback: Optional[Any], + exc_type: Any | None, + exc_value: Any | None, + traceback: Any | None, ) -> None: """Set it back to the original state. diff --git a/src/zenml/pipelines/build_utils.py b/src/zenml/pipelines/build_utils.py index 4e117561b89..ea48ac68356 100644 --- a/src/zenml/pipelines/build_utils.py +++ b/src/zenml/pipelines/build_utils.py @@ -18,8 +18,6 @@ import time from typing import ( TYPE_CHECKING, - Dict, - List, Optional, Union, ) @@ -161,7 +159,7 @@ def code_download_possible( def reuse_or_create_pipeline_build( snapshot: "PipelineSnapshotBase", allow_build_reuse: bool, - pipeline_id: Optional[UUID] = None, + pipeline_id: UUID | None = None, build: Union["UUID", "PipelineBuildBase", None] = None, code_repository: Optional["BaseCodeRepository"] = None, ) -> Optional["PipelineBuildResponse"]: @@ -304,7 +302,7 @@ def find_existing_build( def create_pipeline_build( snapshot: "PipelineSnapshotBase", - pipeline_id: Optional[UUID] = None, + pipeline_id: UUID | None = None, code_repository: Optional["BaseCodeRepository"] = None, ) -> Optional["PipelineBuildResponse"]: """Builds images and registers the output in the server. 
@@ -338,8 +336,8 @@ def create_pipeline_build( start_time = time.time() docker_image_builder = PipelineDockerImageBuilder() - images: Dict[str, BuildItem] = {} - checksums: Dict[str, str] = {} + images: dict[str, BuildItem] = {} + checksums: dict[str, str] = {} for build_config in required_builds: combined_key = PipelineBuildBase.get_image_key( @@ -447,7 +445,7 @@ def create_pipeline_build( def compute_build_checksum( - items: List["BuildConfiguration"], + items: list["BuildConfiguration"], stack: "Stack", code_repository: Optional["BaseCodeRepository"] = None, ) -> str: @@ -485,7 +483,7 @@ def compute_build_checksum( def verify_local_repository_context( snapshot: "PipelineSnapshotBase", local_repo_context: Optional["LocalRepositoryContext"], -) -> Optional[BaseCodeRepository]: +) -> BaseCodeRepository | None: """Verifies the local repository. If the local repository exists and has no local changes, code download @@ -693,7 +691,7 @@ def compute_stack_checksum(stack: StackResponse) -> str: def should_upload_code( snapshot: PipelineSnapshotBase, - build: Optional[PipelineBuildResponse], + build: PipelineBuildResponse | None, can_download_from_code_repository: bool, ) -> bool: """Checks whether the current code should be uploaded for the snapshot. diff --git a/src/zenml/pipelines/pipeline_decorator.py b/src/zenml/pipelines/pipeline_decorator.py index 97bd5c07b93..6b05a23eb2b 100644 --- a/src/zenml/pipelines/pipeline_decorator.py +++ b/src/zenml/pipelines/pipeline_decorator.py @@ -16,14 +16,12 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - List, Optional, TypeVar, Union, overload, ) +from collections.abc import Callable from uuid import UUID from zenml.enums import ExecutionMode @@ -50,24 +48,24 @@ def pipeline(_func: "F") -> "Pipeline": ... 
@overload def pipeline( *, - name: Optional[str] = None, - enable_cache: Optional[bool] = None, - enable_artifact_metadata: Optional[bool] = None, - enable_step_logs: Optional[bool] = None, - environment: Optional[Dict[str, Any]] = None, - secrets: Optional[List[Union[UUID, str]]] = None, - enable_pipeline_logs: Optional[bool] = None, - settings: Optional[Dict[str, "SettingsOrDict"]] = None, - tags: Optional[List[Union[str, "Tag"]]] = None, - extra: Optional[Dict[str, Any]] = None, + name: str | None = None, + enable_cache: bool | None = None, + enable_artifact_metadata: bool | None = None, + enable_step_logs: bool | None = None, + environment: dict[str, Any] | None = None, + secrets: list[UUID | str] | None = None, + enable_pipeline_logs: bool | None = None, + settings: dict[str, "SettingsOrDict"] | None = None, + tags: list[Union[str, "Tag"]] | None = None, + extra: dict[str, Any] | None = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, on_init: Optional["InitHookSpecification"] = None, - on_init_kwargs: Optional[Dict[str, Any]] = None, + on_init_kwargs: dict[str, Any] | None = None, on_cleanup: Optional["HookSpecification"] = None, model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, - substitutions: Optional[Dict[str, str]] = None, + substitutions: dict[str, str] | None = None, execution_mode: Optional["ExecutionMode"] = None, cache_policy: Optional["CachePolicyOrString"] = None, ) -> Callable[["F"], "Pipeline"]: ... 
@@ -76,24 +74,24 @@ def pipeline( def pipeline( _func: Optional["F"] = None, *, - name: Optional[str] = None, - enable_cache: Optional[bool] = None, - enable_artifact_metadata: Optional[bool] = None, - enable_step_logs: Optional[bool] = None, - environment: Optional[Dict[str, Any]] = None, - secrets: Optional[List[Union[UUID, str]]] = None, - enable_pipeline_logs: Optional[bool] = None, - settings: Optional[Dict[str, "SettingsOrDict"]] = None, - tags: Optional[List[Union[str, "Tag"]]] = None, - extra: Optional[Dict[str, Any]] = None, + name: str | None = None, + enable_cache: bool | None = None, + enable_artifact_metadata: bool | None = None, + enable_step_logs: bool | None = None, + environment: dict[str, Any] | None = None, + secrets: list[UUID | str] | None = None, + enable_pipeline_logs: bool | None = None, + settings: dict[str, "SettingsOrDict"] | None = None, + tags: list[Union[str, "Tag"]] | None = None, + extra: dict[str, Any] | None = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, on_init: Optional["InitHookSpecification"] = None, - on_init_kwargs: Optional[Dict[str, Any]] = None, + on_init_kwargs: dict[str, Any] | None = None, on_cleanup: Optional["HookSpecification"] = None, model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, - substitutions: Optional[Dict[str, str]] = None, + substitutions: dict[str, str] | None = None, execution_mode: Optional["ExecutionMode"] = None, cache_policy: Optional["CachePolicyOrString"] = None, ) -> Union["Pipeline", Callable[["F"], "Pipeline"]]: diff --git a/src/zenml/pipelines/pipeline_definition.py b/src/zenml/pipelines/pipeline_definition.py index f388dc5f581..efbb88979e7 100644 --- a/src/zenml/pipelines/pipeline_definition.py +++ b/src/zenml/pipelines/pipeline_definition.py @@ -21,20 +21,12 @@ from typing import ( TYPE_CHECKING, Any, - Callable, ClassVar, - Dict, - Iterator, - List, - Mapping, Optional, - Sequence, - Set, - Tuple, - Type, 
TypeVar, Union, ) +from collections.abc import Callable, Iterator, Mapping, Sequence from uuid import UUID import yaml @@ -116,7 +108,7 @@ from zenml.types import HookSpecification, InitHookSpecification StepConfigurationUpdateOrDict = Union[ - Dict[str, Any], StepConfigurationUpdate + dict[str, Any], StepConfigurationUpdate ] logger = get_logger(__name__) @@ -136,24 +128,24 @@ def __init__( self, name: str, entrypoint: F, - enable_cache: Optional[bool] = None, - enable_artifact_metadata: Optional[bool] = None, - enable_artifact_visualization: Optional[bool] = None, - enable_step_logs: Optional[bool] = None, - environment: Optional[Dict[str, Any]] = None, - secrets: Optional[List[Union[UUID, str]]] = None, - enable_pipeline_logs: Optional[bool] = None, - settings: Optional[Mapping[str, "SettingsOrDict"]] = None, - tags: Optional[List[Union[str, "Tag"]]] = None, - extra: Optional[Dict[str, Any]] = None, + enable_cache: bool | None = None, + enable_artifact_metadata: bool | None = None, + enable_artifact_visualization: bool | None = None, + enable_step_logs: bool | None = None, + environment: dict[str, Any] | None = None, + secrets: list[UUID | str] | None = None, + enable_pipeline_logs: bool | None = None, + settings: Mapping[str, "SettingsOrDict"] | None = None, + tags: list[Union[str, "Tag"]] | None = None, + extra: dict[str, Any] | None = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, on_init: Optional["InitHookSpecification"] = None, - on_init_kwargs: Optional[Dict[str, Any]] = None, + on_init_kwargs: dict[str, Any] | None = None, on_cleanup: Optional["HookSpecification"] = None, model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, - substitutions: Optional[Dict[str, str]] = None, + substitutions: dict[str, str] | None = None, execution_mode: Optional["ExecutionMode"] = None, cache_policy: Optional["CachePolicyOrString"] = None, ) -> None: @@ -196,13 +188,13 @@ def __init__( 
execution_mode: The execution mode of the pipeline. cache_policy: Cache policy for this pipeline. """ - self._invocations: Dict[str, StepInvocation] = {} - self._run_args: Dict[str, Any] = {} + self._invocations: dict[str, StepInvocation] = {} + self._run_args: dict[str, Any] = {} self._configuration = PipelineConfiguration( name=name, ) - self._from_config_file: Dict[str, Any] = {} + self._from_config_file: dict[str, Any] = {} with self.__suppress_configure_warnings__(): self.configure( enable_cache=enable_cache, @@ -227,8 +219,8 @@ def __init__( cache_policy=cache_policy, ) self.entrypoint = entrypoint - self._parameters: Dict[str, Any] = {} - self._output_artifacts: List[StepArtifact] = [] + self._parameters: dict[str, Any] = {} + self._output_artifacts: list[StepArtifact] = [] self.__suppress_warnings_flag__ = False @@ -242,7 +234,7 @@ def name(self) -> str: return self.configuration.name @property - def enable_cache(self) -> Optional[bool]: + def enable_cache(self) -> bool | None: """If caching is enabled for the pipeline. Returns: @@ -260,7 +252,7 @@ def configuration(self) -> PipelineConfiguration: return self._configuration @property - def invocations(self) -> Dict[str, StepInvocation]: + def invocations(self) -> dict[str, StepInvocation]: """Returns the step invocations of this pipeline. 
This dictionary will only be populated once the pipeline has been @@ -334,25 +326,25 @@ def __suppress_configure_warnings__(self) -> Iterator[Any]: def configure( self, - enable_cache: Optional[bool] = None, - enable_artifact_metadata: Optional[bool] = None, - enable_artifact_visualization: Optional[bool] = None, - enable_step_logs: Optional[bool] = None, - environment: Optional[Dict[str, Any]] = None, - secrets: Optional[Sequence[Union[UUID, str]]] = None, - enable_pipeline_logs: Optional[bool] = None, - settings: Optional[Mapping[str, "SettingsOrDict"]] = None, - tags: Optional[List[Union[str, "Tag"]]] = None, - extra: Optional[Dict[str, Any]] = None, + enable_cache: bool | None = None, + enable_artifact_metadata: bool | None = None, + enable_artifact_visualization: bool | None = None, + enable_step_logs: bool | None = None, + environment: dict[str, Any] | None = None, + secrets: Sequence[UUID | str] | None = None, + enable_pipeline_logs: bool | None = None, + settings: Mapping[str, "SettingsOrDict"] | None = None, + tags: list[Union[str, "Tag"]] | None = None, + extra: dict[str, Any] | None = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, on_init: Optional["InitHookSpecification"] = None, - on_init_kwargs: Optional[Dict[str, Any]] = None, + on_init_kwargs: dict[str, Any] | None = None, on_cleanup: Optional["HookSpecification"] = None, model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, - parameters: Optional[Dict[str, Any]] = None, - substitutions: Optional[Dict[str, str]] = None, + parameters: dict[str, Any] | None = None, + substitutions: dict[str, str] | None = None, execution_mode: Optional["ExecutionMode"] = None, cache_policy: Optional["CachePolicyOrString"] = None, merge: bool = True, @@ -524,7 +516,7 @@ def configure( return self @property - def required_parameters(self) -> List[str]: + def required_parameters(self) -> list[str]: """List of required parameters for the 
pipeline entrypoint. Returns: @@ -538,7 +530,7 @@ def required_parameters(self) -> List[str]: ] @property - def missing_parameters(self) -> List[str]: + def missing_parameters(self) -> list[str]: """List of missing parameters for the pipeline entrypoint. Returns: @@ -648,11 +640,11 @@ def register(self) -> "PipelineResponse": def build( self, - settings: Optional[Mapping[str, "SettingsOrDict"]] = None, - step_configurations: Optional[ + settings: Mapping[str, "SettingsOrDict"] | None = None, + step_configurations: None | ( Mapping[str, "StepConfigurationUpdateOrDict"] - ] = None, - config_path: Optional[str] = None, + ) = None, + config_path: str | None = None, ) -> Optional["PipelineBuildResponse"]: """Builds Docker images for the pipeline. @@ -698,7 +690,7 @@ def build( def deploy( self, deployment_name: str, - timeout: Optional[int] = None, + timeout: int | None = None, *args: Any, **kwargs: Any, ) -> DeploymentResponse: @@ -729,20 +721,20 @@ def deploy( def _create_snapshot( self, *, - run_name: Optional[str] = None, - enable_cache: Optional[bool] = None, - enable_artifact_metadata: Optional[bool] = None, - enable_artifact_visualization: Optional[bool] = None, - enable_step_logs: Optional[bool] = None, - enable_pipeline_logs: Optional[bool] = None, - schedule: Optional[Schedule] = None, + run_name: str | None = None, + enable_cache: bool | None = None, + enable_artifact_metadata: bool | None = None, + enable_artifact_visualization: bool | None = None, + enable_step_logs: bool | None = None, + enable_pipeline_logs: bool | None = None, + schedule: Schedule | None = None, build: Union[str, "UUID", "PipelineBuildBase", None] = None, - settings: Optional[Mapping[str, "SettingsOrDict"]] = None, - step_configurations: Optional[ + settings: Mapping[str, "SettingsOrDict"] | None = None, + step_configurations: None | ( Mapping[str, "StepConfigurationUpdateOrDict"] - ] = None, - extra: Optional[Dict[str, Any]] = None, - config_path: Optional[str] = None, + ) = None, + 
extra: dict[str, Any] | None = None, + config_path: str | None = None, prevent_build_reuse: bool = False, skip_schedule_registration: bool = False, **snapshot_request_kwargs: Any, @@ -920,7 +912,7 @@ def _create_snapshot( def _run( self, - ) -> Optional[PipelineRunResponse]: + ) -> PipelineRunResponse | None: """Runs the pipeline on the active stack. Returns: @@ -1168,8 +1160,8 @@ def _get_pipeline_analytics_metadata( self, snapshot: "PipelineSnapshotResponse", stack: "Stack", - run_id: Optional[UUID] = None, - ) -> Dict[str, Any]: + run_id: UUID | None = None, + ) -> dict[str, Any]: """Compute analytics metadata for the pipeline snapshot. Args: @@ -1207,8 +1199,8 @@ def _get_pipeline_analytics_metadata( } def _compile( - self, config_path: Optional[str] = None, **run_configuration_args: Any - ) -> Tuple[ + self, config_path: str | None = None, **run_configuration_args: Any + ) -> tuple[ "PipelineSnapshotBase", Optional["Schedule"], Union["PipelineBuildBase", UUID, None], @@ -1321,16 +1313,16 @@ def _compute_unique_identifier(self, pipeline_spec: PipelineSpec) -> str: def add_step_invocation( self, step: "BaseStep", - input_artifacts: Dict[str, StepArtifact], - external_artifacts: Dict[ + input_artifacts: dict[str, StepArtifact], + external_artifacts: dict[ str, Union["ExternalArtifact", "ArtifactVersionResponse"] ], - model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], - client_lazy_loaders: Dict[str, "ClientLazyLoader"], - parameters: Dict[str, Any], - default_parameters: Dict[str, Any], - upstream_steps: Set[str], - custom_id: Optional[str] = None, + model_artifacts_or_metadata: dict[str, "ModelVersionDataLazyLoader"], + client_lazy_loaders: dict[str, "ClientLazyLoader"], + parameters: dict[str, Any], + default_parameters: dict[str, Any], + upstream_steps: set[str], + custom_id: str | None = None, allow_id_suffix: bool = True, ) -> str: """Adds a step invocation to the pipeline. 
@@ -1391,7 +1383,7 @@ def add_step_invocation( def _compute_invocation_id( self, step: "BaseStep", - custom_id: Optional[str] = None, + custom_id: str | None = None, allow_suffix: bool = True, ) -> str: """Compute the invocation ID. @@ -1452,8 +1444,8 @@ def __exit__(self, *args: Any) -> None: Pipeline.ACTIVE_PIPELINE = None def _parse_config_file( - self, config_path: Optional[str], matcher: List[str] - ) -> Dict[str, Any]: + self, config_path: str | None, matcher: list[str] + ) -> dict[str, Any]: """Parses the given configuration file and sets `self._from_config_file`. Args: @@ -1463,9 +1455,9 @@ def _parse_config_file( Returns: Parsed config file according to matcher settings. """ - _from_config_file: Dict[str, Any] = {} + _from_config_file: dict[str, Any] = {} if config_path: - with open(config_path, "r") as f: + with open(config_path) as f: _from_config_file = yaml.load(f, Loader=yaml.SafeLoader) _from_config_file = dict_utils.remove_none_values( @@ -1487,14 +1479,14 @@ def _parse_config_file( def with_options( self, - run_name: Optional[str] = None, - schedule: Optional[Schedule] = None, + run_name: str | None = None, + schedule: Schedule | None = None, build: Union[str, "UUID", "PipelineBuildBase", None] = None, - step_configurations: Optional[ + step_configurations: None | ( Mapping[str, "StepConfigurationUpdateOrDict"] - ] = None, - steps: Optional[Mapping[str, "StepConfigurationUpdateOrDict"]] = None, - config_path: Optional[str] = None, + ) = None, + steps: Mapping[str, "StepConfigurationUpdateOrDict"] | None = None, + config_path: str | None = None, unlisted: bool = False, prevent_build_reuse: bool = False, **kwargs: Any, @@ -1567,7 +1559,7 @@ def copy(self) -> "Pipeline": def __call__( self, *args: Any, **kwargs: Any - ) -> Optional[PipelineRunResponse]: + ) -> PipelineRunResponse | None: """Handle a call of the pipeline. 
This method does one of two things: @@ -1691,9 +1683,9 @@ def create_run_template( def create_snapshot( self, name: str, - description: Optional[str] = None, - replace: Optional[bool] = None, - tags: Optional[List[str]] = None, + description: str | None = None, + replace: bool | None = None, + tags: list[str] | None = None, ) -> PipelineSnapshotResponse: """Create a snapshot of the pipeline. @@ -1719,7 +1711,7 @@ def create_snapshot( def _reconfigure_from_file_with_overrides( self, - config_path: Optional[str] = None, + config_path: str | None = None, **kwargs: Any, ) -> None: """Update the pipeline configuration from config file. @@ -1750,7 +1742,7 @@ def _reconfigure_from_file_with_overrides( with self.__suppress_configure_warnings__(): self.configure(**_from_config_file) - def _compute_output_schema(self) -> Optional[Dict[str, Any]]: + def _compute_output_schema(self) -> dict[str, Any] | None: """Computes the output schema for the pipeline. Returns: @@ -1765,14 +1757,14 @@ def _compute_output_schema(self) -> Optional[Dict[str, Any]]: } ) - fields: Dict[str, Any] = { + fields: dict[str, Any] = { entry[1]: ( entry[0].annotation.resolved_annotation, ..., ) for _, entry in unique_step_output_mapping.items() } - output_model_class: Type[BaseModel] = create_model( + output_model_class: type[BaseModel] = create_model( "PipelineOutput", __config__=ConfigDict(arbitrary_types_allowed=True), **fields, @@ -1789,7 +1781,7 @@ def _compute_output_schema(self) -> Optional[Dict[str, Any]]: return None - def _compute_input_model(self) -> Optional[Type[BaseModel]]: + def _compute_input_model(self) -> type[BaseModel] | None: """Create a Pydantic model that represents the pipeline input parameters. 
Returns: @@ -1805,8 +1797,8 @@ def _compute_input_model(self) -> Optional[Type[BaseModel]]: self.entrypoint ) - defaults: Dict[str, Any] = self._parameters - model_args: Dict[str, Any] = {} + defaults: dict[str, Any] = self._parameters + model_args: dict[str, Any] = {} for name, param in entrypoint_definition.inputs.items(): if name in defaults: default_value = defaults[name] @@ -1818,7 +1810,7 @@ def _compute_input_model(self) -> Optional[Type[BaseModel]]: model_args[name] = (param.annotation, default_value) model_args["__config__"] = ConfigDict(extra="forbid") - params_model: Type[BaseModel] = create_model( + params_model: type[BaseModel] = create_model( "PipelineInput", **model_args, ) @@ -1831,7 +1823,7 @@ def _compute_input_model(self) -> Optional[Type[BaseModel]]: ) return None - def _compute_input_schema(self) -> Optional[Dict[str, Any]]: + def _compute_input_schema(self) -> dict[str, Any] | None: """Create a JSON schema that represents the pipeline input parameters. Returns: diff --git a/src/zenml/pipelines/run_utils.py b/src/zenml/pipelines/run_utils.py index 0adedd83468..170423e35fb 100644 --- a/src/zenml/pipelines/run_utils.py +++ b/src/zenml/pipelines/run_utils.py @@ -1,7 +1,7 @@ """Utility functions for running pipelines.""" import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union +from typing import TYPE_CHECKING, Any, Optional, Union from uuid import UUID from pydantic import BaseModel @@ -32,7 +32,7 @@ if TYPE_CHECKING: StepConfigurationUpdateOrDict = Union[ - Dict[str, Any], StepConfigurationUpdate + dict[str, Any], StepConfigurationUpdate ] logger = get_logger(__name__) @@ -52,9 +52,9 @@ def get_default_run_name(pipeline_name: str) -> str: def create_placeholder_run( snapshot: "PipelineSnapshotResponse", - orchestrator_run_id: Optional[str] = None, + orchestrator_run_id: str | None = None, logs: Optional["LogsRequest"] = None, - trigger_info: Optional[PipelineRunTriggerInfo] = None, + trigger_info: 
PipelineRunTriggerInfo | None = None, ) -> "PipelineRunResponse": """Create a placeholder run for the snapshot. @@ -335,7 +335,7 @@ def upload_notebook_cell_code_if_necessary( logger.info("Upload finished.") -def get_all_sources_from_value(value: Any) -> List[Source]: +def get_all_sources_from_value(value: Any) -> list[Source]: """Get all source objects from a value. Args: @@ -350,10 +350,10 @@ def get_all_sources_from_value(value: Any) -> List[Source]: elif isinstance(value, BaseModel): for v in value.__dict__.values(): sources.extend(get_all_sources_from_value(v)) - elif isinstance(value, Dict): + elif isinstance(value, dict): for v in value.values(): sources.extend(get_all_sources_from_value(v)) - elif isinstance(value, (List, Set, tuple)): + elif isinstance(value, (list, set, tuple)): for v in value: sources.extend(get_all_sources_from_value(v)) diff --git a/src/zenml/plugins/base_plugin_flavor.py b/src/zenml/plugins/base_plugin_flavor.py index bbbbb5c8419..5e3339a015f 100644 --- a/src/zenml/plugins/base_plugin_flavor.py +++ b/src/zenml/plugins/base_plugin_flavor.py @@ -51,7 +51,7 @@ def zen_store(self) -> "BaseZenStore": @property @abstractmethod - def config_class(self) -> Type[BasePluginConfig]: + def config_class(self) -> type[BasePluginConfig]: """Returns the `BasePluginConfig` config. 
Returns: @@ -74,7 +74,7 @@ class BasePluginFlavor(ABC): TYPE: ClassVar[PluginType] SUBTYPE: ClassVar[PluginSubType] FLAVOR: ClassVar[str] - PLUGIN_CLASS: ClassVar[Type[BasePlugin]] + PLUGIN_CLASS: ClassVar[type[BasePlugin]] @classmethod @abstractmethod diff --git a/src/zenml/plugins/plugin_flavor_registry.py b/src/zenml/plugins/plugin_flavor_registry.py index 574de90c664..6417933eb37 100644 --- a/src/zenml/plugins/plugin_flavor_registry.py +++ b/src/zenml/plugins/plugin_flavor_registry.py @@ -14,7 +14,8 @@ """Registry for all plugins.""" import math -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type +from typing import TYPE_CHECKING, Any +from collections.abc import Sequence from pydantic import BaseModel, ConfigDict @@ -34,8 +35,8 @@ class RegistryEntry(BaseModel): """Registry Entry Class for the Plugin Registry.""" - flavor_class: Type[BasePluginFlavor] - plugin_instance: Optional[BasePlugin] = None + flavor_class: type[BasePluginFlavor] + plugin_instance: BasePlugin | None = None model_config = ConfigDict(arbitrary_types_allowed=True) @@ -44,13 +45,13 @@ class PluginFlavorRegistry: def __init__(self) -> None: """Initialize the event flavor registry.""" - self.plugin_flavors: Dict[ - PluginType, Dict[PluginSubType, Dict[str, RegistryEntry]] + self.plugin_flavors: dict[ + PluginType, dict[PluginSubType, dict[str, RegistryEntry]] ] = {} self.register_plugin_flavors() @property - def _types(self) -> List[PluginType]: + def _types(self) -> list[PluginType]: """Returns all available types. Returns: @@ -60,7 +61,7 @@ def _types(self) -> List[PluginType]: def list_subtypes_within_type( self, _type: PluginType - ) -> List[PluginSubType]: + ) -> list[PluginSubType]: """Returns all available subtypes for a given type. 
Args: @@ -73,7 +74,7 @@ def list_subtypes_within_type( def _flavor_entries( self, _type: PluginType, subtype: PluginSubType - ) -> Dict[str, RegistryEntry]: + ) -> dict[str, RegistryEntry]: """Get a list of all subtypes for a specific flavor and type. Args: @@ -95,7 +96,7 @@ def list_available_flavors_for_type_and_subtype( self, _type: PluginType, subtype: PluginSubType, - ) -> List[Type[BasePluginFlavor]]: + ) -> list[type[BasePluginFlavor]]: """Get a list of all subtypes for a specific flavor and type. Args: @@ -173,7 +174,7 @@ def list_available_flavor_responses_for_type_and_subtype( ) @property - def _builtin_flavors(self) -> Sequence[Type["BasePluginFlavor"]]: + def _builtin_flavors(self) -> Sequence[type["BasePluginFlavor"]]: """A list of all default in-built flavors. Returns: @@ -187,7 +188,7 @@ def _builtin_flavors(self) -> Sequence[Type["BasePluginFlavor"]]: return flavors @property - def _integration_flavors(self) -> Sequence[Type["BasePluginFlavor"]]: + def _integration_flavors(self) -> Sequence[type["BasePluginFlavor"]]: """A list of all integration event flavors. Returns: @@ -227,7 +228,7 @@ def register_plugin_flavors(self) -> None: self.register_plugin_flavor(flavor_class=flavor) def register_plugin_flavor( - self, flavor_class: Type[BasePluginFlavor] + self, flavor_class: type[BasePluginFlavor] ) -> None: """Registers a new event_source. @@ -267,7 +268,7 @@ def get_flavor_class( _type: PluginType, subtype: PluginSubType, name: str, - ) -> Type[BasePluginFlavor]: + ) -> type[BasePluginFlavor]: """Get a single event_source based on the key. Args: diff --git a/src/zenml/secret/base_secret.py b/src/zenml/secret/base_secret.py index b69bd829fe0..b9e00bac5fb 100644 --- a/src/zenml/secret/base_secret.py +++ b/src/zenml/secret/base_secret.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Implementation of the Base SecretSchema class.""" -from typing import Any, Dict, List +from typing import Any from pydantic import BaseModel, ConfigDict @@ -22,7 +22,7 @@ class BaseSecretSchema(BaseModel): """Base class for all Secret Schemas.""" @classmethod - def get_schema_keys(cls) -> List[str]: + def get_schema_keys(cls) -> list[str]: """Get all attributes that are part of the schema. These schema keys can be used to define all required key-value pairs of @@ -33,7 +33,7 @@ def get_schema_keys(cls) -> List[str]: """ return list(cls.model_fields.keys()) - def get_values(self) -> Dict[str, Any]: + def get_values(self) -> dict[str, Any]: """Get all values of the secret schema. Returns: diff --git a/src/zenml/secret/schemas/aws_secret_schema.py b/src/zenml/secret/schemas/aws_secret_schema.py index 7df719ee634..3e611f5a719 100644 --- a/src/zenml/secret/schemas/aws_secret_schema.py +++ b/src/zenml/secret/schemas/aws_secret_schema.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """AWS Authentication Secret Schema definition.""" -from typing import Optional from zenml.secret.base_secret import BaseSecretSchema @@ -23,4 +22,4 @@ class AWSSecretSchema(BaseSecretSchema): aws_access_key_id: str aws_secret_access_key: str - aws_session_token: Optional[str] = None + aws_session_token: str | None = None diff --git a/src/zenml/secret/schemas/azure_secret_schema.py b/src/zenml/secret/schemas/azure_secret_schema.py index 3f7a516a031..85ff8c18110 100644 --- a/src/zenml/secret/schemas/azure_secret_schema.py +++ b/src/zenml/secret/schemas/azure_secret_schema.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Azure Authentication Secret Schema definition.""" -from typing import Optional from zenml.secret.base_secret import BaseSecretSchema @@ -21,10 +20,10 @@ class AzureSecretSchema(BaseSecretSchema): """Azure Authentication Secret Schema definition.""" - account_name: Optional[str] = None - account_key: Optional[str] = None - sas_token: Optional[str] = None - connection_string: Optional[str] = None - client_id: Optional[str] = None - client_secret: Optional[str] = None - tenant_id: Optional[str] = None + account_name: str | None = None + account_key: str | None = None + sas_token: str | None = None + connection_string: str | None = None + client_id: str | None = None + client_secret: str | None = None + tenant_id: str | None = None diff --git a/src/zenml/secret/schemas/gcp_secret_schema.py b/src/zenml/secret/schemas/gcp_secret_schema.py index 34d214a8eec..193b2be4107 100644 --- a/src/zenml/secret/schemas/gcp_secret_schema.py +++ b/src/zenml/secret/schemas/gcp_secret_schema.py @@ -14,7 +14,7 @@ """GCP Authentication Secret Schema definition.""" import json -from typing import Any, Dict +from typing import Any from zenml.secret.base_secret import BaseSecretSchema @@ -24,7 +24,7 @@ class GCPSecretSchema(BaseSecretSchema): token: str - def get_credential_dict(self) -> Dict[str, Any]: + def get_credential_dict(self) -> dict[str, Any]: """Gets a dictionary of credentials for authenticating to GCP. Returns: @@ -41,7 +41,7 @@ def get_credential_dict(self) -> Dict[str, Any]: "valid JSON string." ) - if not isinstance(dict_, Dict): + if not isinstance(dict_, dict): raise ValueError( "Failed to parse GCP secret token. The token value does not " "represent a GCP credential dictionary." 
diff --git a/src/zenml/service_connectors/docker_service_connector.py b/src/zenml/service_connectors/docker_service_connector.py index 5b54e9c33ab..0beb399272e 100644 --- a/src/zenml/service_connectors/docker_service_connector.py +++ b/src/zenml/service_connectors/docker_service_connector.py @@ -19,7 +19,7 @@ import re import subprocess -from typing import Any, List, Optional +from typing import Any from docker.client import DockerClient from docker.errors import DockerException @@ -58,7 +58,7 @@ class DockerCredentials(AuthenticationConfig): class DockerConfiguration(DockerCredentials): """Docker client configuration.""" - registry: Optional[str] = Field( + registry: str | None = Field( default=None, title="Registry server URL. Omit to use DockerHub.", ) @@ -159,7 +159,7 @@ def _parse_resource_id( ValueError: If the provided resource ID is not a valid Docker registry. """ - registry: Optional[str] = None + registry: str | None = None if re.match( r"^(https?://)?[a-zA-Z0-9-]+(\.[a-zA-Z0-9-]+)*(:[0-9]+)?(/.+)*$", resource_id, @@ -321,9 +321,9 @@ def _configure_local_client( @classmethod def _auto_configure( cls, - auth_method: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + auth_method: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, **kwargs: Any, ) -> "DockerServiceConnector": """Auto-configure the connector. @@ -351,9 +351,9 @@ def _auto_configure( def _verify( self, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> List[str]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> list[str]: """Verify that the connector can authenticate and access resources. 
Args: diff --git a/src/zenml/service_connectors/service_connector.py b/src/zenml/service_connectors/service_connector.py index bf26a7bf3ad..59a5f47b9d3 100644 --- a/src/zenml/service_connectors/service_connector.py +++ b/src/zenml/service_connectors/service_connector.py @@ -19,11 +19,7 @@ from typing import ( Any, ClassVar, - Dict, - List, Optional, - Tuple, - Type, Union, cast, ) @@ -62,7 +58,7 @@ class AuthenticationConfig(BaseModel): """Base authentication configuration.""" @property - def all_values(self) -> Dict[str, Any]: + def all_values(self) -> dict[str, Any]: """Get all values as a dictionary. Returns: @@ -75,7 +71,7 @@ class ServiceConnectorMeta(ModelMetaclass): """Metaclass responsible for automatically registering ServiceConnector classes.""" def __new__( - mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any] + mcs, name: str, bases: tuple[type[Any], ...], dct: dict[str, Any] ) -> "ServiceConnectorMeta": """Creates a new ServiceConnector class and registers it. @@ -88,7 +84,7 @@ def __new__( The ServiceConnectorMeta class. """ cls = cast( - Type["ServiceConnector"], super().__new__(mcs, name, bases, dct) + type["ServiceConnector"], super().__new__(mcs, name, bases, dct) ) # Skip the following validation and registration for the base class. @@ -146,18 +142,18 @@ class ServiceConnector(BaseModel, metaclass=ServiceConnectorMeta): types of resources that they need to access. 
""" - id: Optional[UUID] = None - name: Optional[str] = None + id: UUID | None = None + name: str | None = None auth_method: str - resource_type: Optional[str] = None - resource_id: Optional[str] = None - expires_at: Optional[datetime] = None - expires_skew_tolerance: Optional[int] = None - expiration_seconds: Optional[int] = None + resource_type: str | None = None + resource_id: str | None = None + expires_at: datetime | None = None + expires_skew_tolerance: int | None = None + expiration_seconds: int | None = None config: AuthenticationConfig allow_implicit_auth_methods: bool = False - _TYPE: ClassVar[Optional[ServiceConnectorTypeModel]] = None + _TYPE: ClassVar[ServiceConnectorTypeModel | None] = None def __init__(self, **kwargs: Any) -> None: """Initialize a new service connector instance. @@ -304,9 +300,9 @@ def _configure_local_client( @abstractmethod def _auto_configure( cls, - auth_method: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + auth_method: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, **kwargs: Any, ) -> "ServiceConnector": """Auto-configure a connector instance. @@ -341,9 +337,9 @@ def _auto_configure( @abstractmethod def _verify( self, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, - ) -> List[str]: + resource_type: str | None = None, + resource_id: str | None = None, + ) -> list[str]: """Verify and list all the resources that the connector can access. This method uses the connector's configuration to verify that it can @@ -466,8 +462,8 @@ def _get_connector_client( return copy def _validate_resource_id( - self, resource_type: str, resource_id: Optional[str] - ) -> Optional[str]: + self, resource_type: str, resource_id: str | None + ) -> str | None: """Validate a resource ID value of a certain type against the connector configuration. 
Args: @@ -554,7 +550,7 @@ def type(self) -> ServiceConnectorTypeModel: return self.get_type() @property - def supported_resource_types(self) -> List[str]: + def supported_resource_types(self) -> list[str]: """The resource types supported by this connector instance. Returns: @@ -596,11 +592,11 @@ def from_model( # instance is configured to access any of the supported resource # types (a multi-type connector). We represent that here by setting the # resource type to None. - resource_type: Optional[str] = None + resource_type: str | None = None if len(model.resource_types) == 1: resource_type = model.resource_types[0] - expiration_seconds: Optional[int] = None + expiration_seconds: int | None = None try: method_spec, _ = spec.find_resource_specifications( model.auth_method, @@ -658,9 +654,9 @@ def from_model( def to_model( self, - name: Optional[str] = None, + name: str | None = None, description: str = "", - labels: Optional[Dict[str, str]] = None, + labels: dict[str, str] | None = None, ) -> "ServiceConnectorRequest": """Convert the connector instance to a service connector model. @@ -707,11 +703,11 @@ def to_model( def to_response_model( self, - user: Optional[UserResponse] = None, - name: Optional[str] = None, - id: Optional[UUID] = None, + user: UserResponse | None = None, + name: str | None = None, + id: UUID | None = None, description: str = "", - labels: Optional[Dict[str, str]] = None, + labels: dict[str, str] | None = None, ) -> "ServiceConnectorResponse": """Convert the connector instance to a service connector response model. 
@@ -808,12 +804,12 @@ def has_expired(self) -> bool: def validate_runtime_args( self, - resource_type: Optional[str], - resource_id: Optional[str] = None, + resource_type: str | None, + resource_id: str | None = None, require_resource_type: bool = False, require_resource_id: bool = False, **kwargs: Any, - ) -> Tuple[Optional[str], Optional[str]]: + ) -> tuple[str | None, str | None]: """Validate the runtime arguments against the connector configuration. Validate that the supplied runtime arguments are compatible with the @@ -970,9 +966,9 @@ def connect( @classmethod def auto_configure( cls, - auth_method: Optional[str] = None, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + auth_method: str | None = None, + resource_type: str | None = None, + resource_id: str | None = None, **kwargs: Any, ) -> Optional["ServiceConnector"]: """Auto-configure a connector instance. @@ -1092,8 +1088,8 @@ def configure_local_client( def verify( self, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, list_resources: bool = True, ) -> ServiceConnectorResourcesModel: """Verify and optionally list all the resources that the connector can access. @@ -1161,7 +1157,7 @@ def verify( resources.set_error(error) return resources - verify_resource_types: List[Optional[str]] = [] + verify_resource_types: list[str | None] = [] verify_resource_id = None if not list_resources and not resource_id: # If we're not listing resources, we're only verifying that the @@ -1281,8 +1277,8 @@ def verify( def get_connector_client( self, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, ) -> "ServiceConnector": """Get a connector client that can be used to connect to a resource. 
diff --git a/src/zenml/service_connectors/service_connector_registry.py b/src/zenml/service_connectors/service_connector_registry.py index 83d1f35543b..7bc7ff7b645 100644 --- a/src/zenml/service_connectors/service_connector_registry.py +++ b/src/zenml/service_connectors/service_connector_registry.py @@ -14,7 +14,7 @@ """Implementation of a service connector registry.""" import threading -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Union from zenml.logger import get_logger from zenml.models import ServiceConnectorTypeModel @@ -33,7 +33,7 @@ class ServiceConnectorRegistry: def __init__(self) -> None: """Initialize the service connector registry.""" - self.service_connector_types: Dict[str, ServiceConnectorTypeModel] = {} + self.service_connector_types: dict[str, ServiceConnectorTypeModel] = {} self.initialized = False self.lock = threading.RLock() @@ -122,10 +122,10 @@ def is_registered(self, connector_type: str) -> bool: def list_service_connector_types( self, - connector_type: Optional[str] = None, - resource_type: Optional[str] = None, - auth_method: Optional[str] = None, - ) -> List[ServiceConnectorTypeModel]: + connector_type: str | None = None, + resource_type: str | None = None, + auth_method: str | None = None, + ) -> list[ServiceConnectorTypeModel]: """Find one or more service connector types that match the given criteria. 
Args: @@ -141,7 +141,7 @@ def list_service_connector_types( """ self.register_builtin_service_connectors() - matches: List[ServiceConnectorTypeModel] = [] + matches: list[ServiceConnectorTypeModel] = [] for service_connector_type in self.service_connector_types.values(): if ( ( diff --git a/src/zenml/service_connectors/service_connector_utils.py b/src/zenml/service_connectors/service_connector_utils.py index ce803bf6295..30915c7efc7 100644 --- a/src/zenml/service_connectors/service_connector_utils.py +++ b/src/zenml/service_connectors/service_connector_utils.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Utility methods for Service Connectors.""" -from typing import Dict, List, Union from uuid import UUID from zenml.client import Client @@ -29,11 +28,11 @@ def _prepare_resource_info( - connector_details: Union[UUID, ServiceConnectorInfo], - resource_ids: List[str], + connector_details: UUID | ServiceConnectorInfo, + resource_ids: list[str], stack_component_type: StackComponentType, flavor: str, - required_configuration: Dict[str, str], + required_configuration: dict[str, str], flavor_display_name: str, use_resource_value_as_fixed_config: bool = False, ) -> ResourcesInfo: @@ -57,9 +56,9 @@ def _prepare_resource_info( def _raise_specific_cloud_exception_if_needed( cloud_provider: str, - artifact_stores: List[ResourcesInfo], - orchestrators: List[ResourcesInfo], - container_registries: List[ResourcesInfo], + artifact_stores: list[ResourcesInfo], + orchestrators: list[ResourcesInfo], + container_registries: list[ResourcesInfo], ) -> None: AWS_DOCS = "https://docs.zenml.io/stacks/service-connectors/connector-types/aws-service-connector" GCP_DOCS = "https://docs.zenml.io/stacks/service-connectors/connector-types/gcp-service-connector" @@ -164,7 +163,7 @@ def _raise_specific_cloud_exception_if_needed( def get_resources_options_from_resource_model_for_full_stack( - connector_details: Union[UUID, ServiceConnectorInfo], + connector_details: UUID | 
ServiceConnectorInfo, ) -> ServiceConnectorResourcesInfo: """Get the resource options from the resource model for the full stack. @@ -206,9 +205,9 @@ def get_resources_options_from_resource_model_for_full_stack( else: connector_type = resource_model.connector_type.connector_type - artifact_stores: List[ResourcesInfo] = [] - orchestrators: List[ResourcesInfo] = [] - container_registries: List[ResourcesInfo] = [] + artifact_stores: list[ResourcesInfo] = [] + orchestrators: list[ResourcesInfo] = [] + container_registries: list[ResourcesInfo] = [] if connector_type == "aws": for each in resources: diff --git a/src/zenml/services/__init__.py b/src/zenml/services/__init__.py index d6244539246..94231866efd 100644 --- a/src/zenml/services/__init__.py +++ b/src/zenml/services/__init__.py @@ -52,8 +52,6 @@ TCPEndpointHealthMonitorConfig, ) from zenml.services.service_status import ServiceStatus -from zenml.enums import ServiceState -from zenml.models.v2.misc.service import ServiceType __all__ = [ "ServiceConfig", diff --git a/src/zenml/services/container/container_service.py b/src/zenml/services/container/container_service.py index 2223cf58637..45380c11cdb 100644 --- a/src/zenml/services/container/container_service.py +++ b/src/zenml/services/container/container_service.py @@ -19,7 +19,8 @@ import tempfile import time from abc import abstractmethod -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import Any +from collections.abc import Generator import docker.errors as docker_errors from docker.client import DockerClient @@ -67,7 +68,7 @@ class ContainerServiceConfig(ServiceConfig): image: the container image to use for the service. 
""" - root_runtime_path: Optional[str] = None + root_runtime_path: str | None = None singleton: bool = False image: str = DOCKER_ZENML_SERVER_DEFAULT_IMAGE @@ -80,10 +81,10 @@ class ContainerServiceStatus(ServiceStatus): file used to start the service daemon and the logfile) are located """ - runtime_path: Optional[str] = None + runtime_path: str | None = None @property - def config_file(self) -> Optional[str]: + def config_file(self) -> str | None: """Get the path to the service configuration file. Returns: @@ -95,7 +96,7 @@ def config_file(self) -> Optional[str]: return os.path.join(self.runtime_path, SERVICE_CONFIG_FILE_NAME) @property - def log_file(self) -> Optional[str]: + def log_file(self) -> str | None: """Get the path to the log file where the service output is/has been logged. Returns: @@ -168,9 +169,9 @@ def run(self) -> None: default_factory=ContainerServiceStatus ) # TODO [ENG-705]: allow multiple endpoints per service - endpoint: Optional[ContainerServiceEndpoint] = None + endpoint: ContainerServiceEndpoint | None = None - _docker_client: Optional[DockerClient] = None + _docker_client: DockerClient | None = None @property def docker_client(self) -> DockerClient: @@ -210,7 +211,7 @@ def get_service_status_message(self) -> str: ) return msg - def check_status(self) -> Tuple[ServiceState, str]: + def check_status(self) -> tuple[ServiceState, str]: """Check the the current operational state of the docker container. 
Returns: @@ -222,7 +223,7 @@ def check_status(self) -> Tuple[ServiceState, str]: if not check_docker(): return (ServiceState.INACTIVE, "Docker daemon is not running") - container: Optional[Container] = None + container: Container | None = None try: container = self.docker_client.containers.get(self.container_id) except docker_errors.NotFound: @@ -268,7 +269,7 @@ def _setup_runtime_path(self) -> None: prefix="zenml-service-" ) - def _get_container_cmd(self) -> Tuple[List[str], Dict[str, str]]: + def _get_container_cmd(self) -> tuple[list[str], dict[str, str]]: """Get the command to run the service container. The default implementation provided by this class is the following: @@ -320,7 +321,7 @@ def _get_container_cmd(self) -> Tuple[List[str], Dict[str, str]]: return command, command_env - def _get_container_volumes(self) -> Dict[str, Dict[str, str]]: + def _get_container_volumes(self) -> dict[str, dict[str, str]]: """Get the volumes to mount into the service container. The default implementation provided by this class mounts the @@ -336,7 +337,7 @@ def _get_container_volumes(self) -> Dict[str, Dict[str, str]]: A dictionary mapping host paths to dictionaries containing the mount options for each volume. """ - volumes: Dict[str, Dict[str, str]] = {} + volumes: dict[str, dict[str, str]] = {} assert self.status.runtime_path is not None @@ -353,7 +354,7 @@ def _get_container_volumes(self) -> Dict[str, Dict[str, str]]: return volumes @property - def container(self) -> Optional[Container]: + def container(self) -> Container | None: """Get the docker container for the service. 
Returns: @@ -400,7 +401,7 @@ def _start_container(self) -> None: self._setup_runtime_path() - ports: Dict[int, Optional[int]] = {} + ports: dict[int, int | None] = {} if self.endpoint: self.endpoint.prepare_for_start() if self.endpoint.status.port: @@ -410,7 +411,7 @@ def _start_container(self) -> None: volumes = self._get_container_volumes() try: - uid_args: Dict[str, Any] = {} + uid_args: dict[str, Any] = {} if sys.platform == "win32": # File permissions are not checked on Windows. This if clause # prevents mypy from complaining about unused 'type: ignore' @@ -494,7 +495,7 @@ def deprovision(self, force: bool = False) -> None: self._stop_daemon(force) def get_logs( - self, follow: bool = False, tail: Optional[int] = None + self, follow: bool = False, tail: int | None = None ) -> Generator[str, bool, None]: """Retrieve the service logs. @@ -510,7 +511,7 @@ def get_logs( ): return - with open(self.status.log_file, "r") as f: + with open(self.status.log_file) as f: if tail: # TODO[ENG-864]: implement a more efficient tailing mechanism that # doesn't read the entire file diff --git a/src/zenml/services/container/container_service_endpoint.py b/src/zenml/services/container/container_service_endpoint.py index 7bcfdf34391..4d93868f39c 100644 --- a/src/zenml/services/container/container_service_endpoint.py +++ b/src/zenml/services/container/container_service_endpoint.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Implementation of a containerized service endpoint.""" -from typing import Optional, Union from pydantic import Field @@ -51,7 +50,7 @@ class ContainerServiceEndpointConfig(ServiceEndpointConfig): """ protocol: ServiceEndpointProtocol = ServiceEndpointProtocol.TCP - port: Optional[int] = None + port: int | None = None allocate_port: bool = True @@ -78,9 +77,9 @@ class ContainerServiceEndpoint(BaseServiceEndpoint): status: ContainerServiceEndpointStatus = Field( default_factory=ContainerServiceEndpointStatus ) - monitor: Optional[ - Union[HTTPEndpointHealthMonitor, TCPEndpointHealthMonitor] - ] = Field(..., discriminator="type") + monitor: None | ( + HTTPEndpointHealthMonitor | TCPEndpointHealthMonitor + ) = Field(..., discriminator="type") def _lookup_free_port(self) -> int: """Search for a free TCP port for the service endpoint. @@ -107,7 +106,7 @@ def _lookup_free_port(self) -> int: if port_available(self.config.port): return self.config.port if not self.config.allocate_port: - raise IOError(f"TCP port {self.config.port} is not available.") + raise OSError(f"TCP port {self.config.port} is not available.") # Attempt to reuse the port used when the services was last running if self.status.port and port_available(self.status.port): @@ -116,7 +115,7 @@ def _lookup_free_port(self) -> int: port = scan_for_available_port() if port: return port - raise IOError("No free TCP ports found") + raise OSError("No free TCP ports found") def prepare_for_start(self) -> None: """Prepare the service endpoint for starting. 
diff --git a/src/zenml/services/container/entrypoint.py b/src/zenml/services/container/entrypoint.py index b7476bc19b9..59828f3f1f0 100644 --- a/src/zenml/services/container/entrypoint.py +++ b/src/zenml/services/container/entrypoint.py @@ -56,7 +56,7 @@ def launch_service(service_config_file: str) -> None: logger = get_logger(__name__) logger.info("Loading service configuration from %s", service_config_file) - with open(service_config_file, "r") as f: + with open(service_config_file) as f: config = f.read() integration_registry.activate_integrations() diff --git a/src/zenml/services/local/local_daemon_entrypoint.py b/src/zenml/services/local/local_daemon_entrypoint.py index 255541c0418..c09b2cd6c98 100644 --- a/src/zenml/services/local/local_daemon_entrypoint.py +++ b/src/zenml/services/local/local_daemon_entrypoint.py @@ -76,7 +76,7 @@ def launch_service(service_config_file: str) -> None: logger.info( "Loading service daemon configuration from %s", service_config_file ) - with open(service_config_file, "r") as f: + with open(service_config_file) as f: config = f.read() integration_registry.activate_integrations() diff --git a/src/zenml/services/local/local_service.py b/src/zenml/services/local/local_service.py index 767e8075f01..3abf11818e5 100644 --- a/src/zenml/services/local/local_service.py +++ b/src/zenml/services/local/local_service.py @@ -20,7 +20,7 @@ import tempfile import time from abc import abstractmethod -from typing import Dict, Generator, List, Optional, Tuple +from collections.abc import Generator import psutil from psutil import NoSuchProcess @@ -63,7 +63,7 @@ class LocalDaemonServiceConfig(ServiceConfig): """ silent_daemon: bool = False - root_runtime_path: Optional[str] = None + root_runtime_path: str | None = None singleton: bool = False blocking: bool = False @@ -79,14 +79,14 @@ class LocalDaemonServiceStatus(ServiceStatus): is suppressed (redirected to /dev/null). 
""" - runtime_path: Optional[str] = None + runtime_path: str | None = None # TODO [ENG-704]: remove field duplication between XServiceStatus and # XServiceConfig (e.g. keep a private reference to the config in the # status) silent_daemon: bool = False @property - def config_file(self) -> Optional[str]: + def config_file(self) -> str | None: """Get the path to the configuration file used to start the service daemon. Returns: @@ -98,7 +98,7 @@ def config_file(self) -> Optional[str]: return os.path.join(self.runtime_path, SERVICE_DAEMON_CONFIG_FILE_NAME) @property - def log_file(self) -> Optional[str]: + def log_file(self) -> str | None: """Get the path to the log file where the service output is/has been logged. Returns: @@ -110,7 +110,7 @@ def log_file(self) -> Optional[str]: return os.path.join(self.runtime_path, SERVICE_DAEMON_LOG_FILE_NAME) @property - def pid_file(self) -> Optional[str]: + def pid_file(self) -> str | None: """Get the path to a daemon PID file. This is where the last known PID of the daemon process is stored. @@ -124,7 +124,7 @@ def pid_file(self) -> Optional[str]: return os.path.join(self.runtime_path, SERVICE_DAEMON_PID_FILE_NAME) @property - def pid(self) -> Optional[int]: + def pid(self) -> int | None: """Return the PID of the currently running daemon. Returns: @@ -250,7 +250,7 @@ def run(self) -> None: default_factory=LocalDaemonServiceStatus ) # TODO [ENG-705]: allow multiple endpoints per service - endpoint: Optional[LocalDaemonServiceEndpoint] = None + endpoint: LocalDaemonServiceEndpoint | None = None def get_service_status_message(self) -> str: """Get a message about the current operational state of the service. @@ -270,7 +270,7 @@ def get_service_status_message(self) -> str: ) return msg - def check_status(self) -> Tuple[ServiceState, str]: + def check_status(self) -> tuple[ServiceState, str]: """Check the the current operational state of the daemon process. 
Returns: @@ -284,7 +284,7 @@ def check_status(self) -> Tuple[ServiceState, str]: # the daemon is running return ServiceState.ACTIVE, "" - def _get_daemon_cmd(self) -> Tuple[List[str], Dict[str, str]]: + def _get_daemon_cmd(self) -> tuple[list[str], dict[str, str]]: """Get the command to run the service daemon. The default implementation provided by this class is the following: @@ -449,7 +449,7 @@ def start(self, timeout: int = 0) -> None: self.run() def get_logs( - self, follow: bool = False, tail: Optional[int] = None + self, follow: bool = False, tail: int | None = None ) -> Generator[str, bool, None]: """Retrieve the service logs. @@ -465,7 +465,7 @@ def get_logs( ): return - with open(self.status.log_file, "r") as f: + with open(self.status.log_file) as f: if tail: # TODO[ENG-864]: implement a more efficient tailing mechanism that # doesn't read the entire file diff --git a/src/zenml/services/local/local_service_endpoint.py b/src/zenml/services/local/local_service_endpoint.py index ecad675c464..bf657227c50 100644 --- a/src/zenml/services/local/local_service_endpoint.py +++ b/src/zenml/services/local/local_service_endpoint.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Implementation of a local service endpoint.""" -from typing import Optional, Union from pydantic import Field @@ -53,7 +52,7 @@ class LocalDaemonServiceEndpointConfig(ServiceEndpointConfig): """ protocol: ServiceEndpointProtocol = ServiceEndpointProtocol.TCP - port: Optional[int] = None + port: int | None = None ip_address: str = DEFAULT_LOCAL_SERVICE_IP_ADDRESS allocate_port: bool = True @@ -81,9 +80,9 @@ class LocalDaemonServiceEndpoint(BaseServiceEndpoint): status: LocalDaemonServiceEndpointStatus = Field( default_factory=LocalDaemonServiceEndpointStatus ) - monitor: Optional[ - Union[HTTPEndpointHealthMonitor, TCPEndpointHealthMonitor] - ] = Field(..., discriminator="type") + monitor: None | ( + HTTPEndpointHealthMonitor | TCPEndpointHealthMonitor + ) = Field(..., discriminator="type") def _lookup_free_port(self) -> int: """Search for a free TCP port for the service endpoint. @@ -110,7 +109,7 @@ def _lookup_free_port(self) -> int: if port_available(self.config.port, self.config.ip_address): return self.config.port if not self.config.allocate_port: - raise IOError(f"TCP port {self.config.port} is not available.") + raise OSError(f"TCP port {self.config.port} is not available.") # Attempt to reuse the port used when the services was last running if self.status.port and port_available(self.status.port): @@ -119,7 +118,7 @@ def _lookup_free_port(self) -> int: port = scan_for_available_port() if port: return port - raise IOError("No free TCP ports found") + raise OSError("No free TCP ports found") def prepare_for_start(self) -> None: """Prepare the service endpoint for starting. 
diff --git a/src/zenml/services/service.py b/src/zenml/services/service.py index 24f85d07f3a..fe960433cc0 100644 --- a/src/zenml/services/service.py +++ b/src/zenml/services/service.py @@ -20,15 +20,10 @@ from typing import ( TYPE_CHECKING, Any, - Callable, ClassVar, - Dict, - Generator, - Optional, - Tuple, - Type, TypeVar, ) +from collections.abc import Callable, Generator from uuid import UUID from pydantic import ConfigDict @@ -54,8 +49,8 @@ def update_service_status( - pre_status: Optional[ServiceState] = None, - post_status: Optional[ServiceState] = None, + pre_status: ServiceState | None = None, + post_status: ServiceState | None = None, error_status: ServiceState = ServiceState.ERROR, ) -> Callable[[T], T]: """A decorator to update the service status before and after a method call. @@ -146,7 +141,7 @@ def __init__(self, **data: Any): else: raise ValueError("Either 'name' or 'model_name' must be set.") - def get_service_labels(self) -> Dict[str, str]: + def get_service_labels(self) -> dict[str, str]: """Get the service labels. Returns: @@ -183,7 +178,7 @@ class BaseService(BaseTypedModel): config: ServiceConfig status: ServiceStatus # TODO [ENG-703]: allow multiple endpoints per service - endpoint: Optional[BaseServiceEndpoint] = None + endpoint: BaseServiceEndpoint | None = None def __init__( self, @@ -212,7 +207,7 @@ def from_model(cls, model: "ServiceResponse") -> "BaseService": """ if not model.service_source: raise ValueError("Service source not found in the model.") - class_: Type[BaseService] = source_utils.load_and_validate_class( + class_: type[BaseService] = source_utils.load_and_validate_class( source=model.service_source, expected_class=BaseService ) return class_( @@ -235,13 +230,13 @@ def from_json(cls, json_str: str) -> "BaseTypedModel": The loaded service object. 
""" service_dict = json.loads(json_str) - class_: Type[BaseService] = source_utils.load_and_validate_class( + class_: type[BaseService] = source_utils.load_and_validate_class( source=service_dict["type"], expected_class=BaseService ) return class_.from_dict(service_dict) @abstractmethod - def check_status(self) -> Tuple[ServiceState, str]: + def check_status(self) -> tuple[ServiceState, str]: """Check the the current operational state of the external service. This method should be overridden by subclasses that implement @@ -256,7 +251,7 @@ def check_status(self) -> Tuple[ServiceState, str]: @abstractmethod def get_logs( - self, follow: bool = False, tail: Optional[int] = None + self, follow: bool = False, tail: int | None = None ) -> Generator[str, bool, None]: """Retrieve the service logs. @@ -479,7 +474,7 @@ def stop(self, timeout: int = 0, force: bool = False) -> None: f"'{self.status.last_error}'" ) - def get_prediction_url(self) -> Optional[str]: + def get_prediction_url(self) -> str | None: """Gets the prediction URL for the endpoint. Returns: @@ -494,7 +489,7 @@ def get_prediction_url(self) -> Optional[str]: ) return prediction_url - def get_healthcheck_url(self) -> Optional[str]: + def get_healthcheck_url(self) -> str | None: """Gets the healthcheck URL for the endpoint. Returns: @@ -533,7 +528,7 @@ class BaseDeploymentService(BaseService): """Base class for deployment services.""" @property - def prediction_url(self) -> Optional[str]: + def prediction_url(self) -> str | None: """Gets the prediction URL for the endpoint. Returns: @@ -542,7 +537,7 @@ def prediction_url(self) -> Optional[str]: return None @property - def healthcheck_url(self) -> Optional[str]: + def healthcheck_url(self) -> str | None: """Gets the healthcheck URL for the endpoint. 
Returns: diff --git a/src/zenml/services/service_endpoint.py b/src/zenml/services/service_endpoint.py index 6064fb1614a..7577d2d9fa7 100644 --- a/src/zenml/services/service_endpoint.py +++ b/src/zenml/services/service_endpoint.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Implementation of a ZenML service endpoint.""" -from typing import Any, Optional, Tuple +from typing import Any from zenml.constants import DEFAULT_LOCAL_SERVICE_IP_ADDRESS from zenml.enums import ServiceState @@ -65,11 +65,11 @@ class ServiceEndpointStatus(ServiceStatus): """ protocol: ServiceEndpointProtocol = ServiceEndpointProtocol.TCP - hostname: Optional[str] = None - port: Optional[int] = None + hostname: str | None = None + port: int | None = None @property - def uri(self) -> Optional[str]: + def uri(self) -> str | None: """Get the URI of the service endpoint. Returns: @@ -106,7 +106,7 @@ class BaseServiceEndpoint(BaseTypedModel): config: ServiceEndpointConfig status: ServiceEndpointStatus # TODO [ENG-701]: allow multiple monitors per endpoint - monitor: Optional[BaseServiceEndpointHealthMonitor] = None + monitor: BaseServiceEndpointHealthMonitor | None = None def __init__( self, @@ -122,7 +122,7 @@ def __init__( super().__init__(*args, **kwargs) self.config.name = self.config.name or self.__class__.__name__ - def check_status(self) -> Tuple[ServiceState, str]: + def check_status(self) -> tuple[ServiceState, str]: """Check the the current operational state of the external service endpoint. 
Returns: diff --git a/src/zenml/services/service_monitor.py b/src/zenml/services/service_monitor.py index 424194b7bec..f33099e29b6 100644 --- a/src/zenml/services/service_monitor.py +++ b/src/zenml/services/service_monitor.py @@ -14,7 +14,7 @@ """Implementation of the service health monitor.""" from abc import abstractmethod -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING import requests from pydantic import Field @@ -57,7 +57,7 @@ class BaseServiceEndpointHealthMonitor(BaseTypedModel): @abstractmethod def check_endpoint_status( self, endpoint: "BaseServiceEndpoint" - ) -> Tuple[ServiceState, str]: + ) -> tuple[ServiceState, str]: """Check the the current operational state of the external service endpoint. Args: @@ -106,7 +106,7 @@ class HTTPEndpointHealthMonitor(BaseServiceEndpointHealthMonitor): def get_healthcheck_uri( self, endpoint: "BaseServiceEndpoint" - ) -> Optional[str]: + ) -> str | None: """Get the healthcheck URI for the given service endpoint. Args: @@ -127,7 +127,7 @@ def get_healthcheck_uri( def check_endpoint_status( self, endpoint: "BaseServiceEndpoint" - ) -> Tuple[ServiceState, str]: + ) -> tuple[ServiceState, str]: """Run a HTTP endpoint API healthcheck. Args: @@ -199,7 +199,7 @@ class TCPEndpointHealthMonitor(BaseServiceEndpointHealthMonitor): def check_endpoint_status( self, endpoint: "BaseServiceEndpoint" - ) -> Tuple[ServiceState, str]: + ) -> tuple[ServiceState, str]: """Run a TCP endpoint healthcheck. Args: diff --git a/src/zenml/services/service_status.py b/src/zenml/services/service_status.py index ae8b128da11..559721e37bf 100644 --- a/src/zenml/services/service_status.py +++ b/src/zenml/services/service_status.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Implementation of the ServiceStatus class.""" -from typing import Optional from zenml.enums import ServiceState from zenml.logger import get_logger @@ -44,7 +43,7 @@ class ServiceStatus(BaseTypedModel): def update_state( self, - new_state: Optional[ServiceState] = None, + new_state: ServiceState | None = None, error: str = "", ) -> None: """Update the current operational state to reflect a new state value and/or error. diff --git a/src/zenml/stack/authentication_mixin.py b/src/zenml/stack/authentication_mixin.py index 6e4024d78b3..8d70e1a3a39 100644 --- a/src/zenml/stack/authentication_mixin.py +++ b/src/zenml/stack/authentication_mixin.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Stack component mixin for authentication.""" -from typing import Optional, Type, TypeVar, cast +from typing import TypeVar, cast from pydantic import BaseModel, Field @@ -33,7 +33,7 @@ class AuthenticationConfigMixin(StackComponentConfig): Field descriptions are defined inline using Field() descriptors. """ - authentication_secret: Optional[str] = Field( + authentication_secret: str | None = Field( default=None, description="Name of the ZenML secret containing authentication credentials.", ) @@ -57,7 +57,7 @@ def config(self) -> AuthenticationConfigMixin: def get_authentication_secret( self, - ) -> Optional[SecretResponse]: + ) -> SecretResponse | None: """Gets the secret referred to by the authentication secret attribute. Returns: @@ -83,8 +83,8 @@ def get_authentication_secret( ) def get_typed_authentication_secret( - self, expected_schema_type: Type[T] - ) -> Optional[T]: + self, expected_schema_type: type[T] + ) -> T | None: """Gets a typed secret referred to by the authentication secret attribute. 
Args: diff --git a/src/zenml/stack/flavor.py b/src/zenml/stack/flavor.py index f80b7f52936..0b518e182b8 100644 --- a/src/zenml/stack/flavor.py +++ b/src/zenml/stack/flavor.py @@ -15,7 +15,7 @@ import os from abc import abstractmethod -from typing import Any, Dict, Optional, Type, cast +from typing import Any, cast from zenml.enums import StackComponentType from zenml.exceptions import CustomFlavorImportError @@ -42,7 +42,7 @@ def name(self) -> str: """ @property - def docs_url(self) -> Optional[str]: + def docs_url(self) -> str | None: """A url to point at docs explaining this flavor. Returns: @@ -51,7 +51,7 @@ def docs_url(self) -> Optional[str]: return None @property - def sdk_docs_url(self) -> Optional[str]: + def sdk_docs_url(self) -> str | None: """A url to point at SDK docs explaining this flavor. Returns: @@ -60,7 +60,7 @@ def sdk_docs_url(self) -> Optional[str]: return None @property - def logo_url(self) -> Optional[str]: + def logo_url(self) -> str | None: """A url to represent the flavor in the dashboard. Returns: @@ -79,7 +79,7 @@ def type(self) -> StackComponentType: @property @abstractmethod - def implementation_class(self) -> Type[StackComponent]: + def implementation_class(self) -> type[StackComponent]: """Implementation class for this flavor. Returns: @@ -88,7 +88,7 @@ def implementation_class(self) -> Type[StackComponent]: @property @abstractmethod - def config_class(self) -> Type[StackComponentConfig]: + def config_class(self) -> type[StackComponentConfig]: """Returns `StackComponentConfig` config class. Returns: @@ -96,7 +96,7 @@ def config_class(self) -> Type[StackComponentConfig]: """ @property - def config_schema(self) -> Dict[str, Any]: + def config_schema(self) -> dict[str, Any]: """The config schema for a flavor. 
Returns: @@ -107,7 +107,7 @@ def config_schema(self) -> Dict[str, Any]: @property def service_connector_requirements( self, - ) -> Optional[ServiceConnectorRequirements]: + ) -> ServiceConnectorRequirements | None: """Service connector resource requirements for service connectors. Specifies resource requirements that are used to filter the available @@ -158,7 +158,7 @@ def from_model(cls, flavor_model: FlavorResponse) -> "Flavor": def to_model( self, - integration: Optional[str] = None, + integration: str | None = None, is_custom: bool = True, ) -> FlavorRequest: """Converts a flavor to a model. @@ -266,7 +266,7 @@ def generate_default_sdk_docs_url(self) -> str: def validate_flavor_source( source: str, component_type: StackComponentType -) -> Type["Flavor"]: +) -> type["Flavor"]: """Import a StackComponent class from a given source and validate its type. Args: diff --git a/src/zenml/stack/flavor_registry.py b/src/zenml/stack/flavor_registry.py index e75520bee43..cd14635ee55 100644 --- a/src/zenml/stack/flavor_registry.py +++ b/src/zenml/stack/flavor_registry.py @@ -14,7 +14,7 @@ """Implementation of the ZenML flavor registry.""" from collections import defaultdict -from typing import DefaultDict, Dict, List, Type +from typing import DefaultDict from zenml.analytics.utils import analytics_disabler from zenml.enums import StackComponentType @@ -40,7 +40,7 @@ class FlavorRegistry: def __init__(self) -> None: """Initialization of the flavors.""" self._flavors: DefaultDict[ - StackComponentType, Dict[str, FlavorResponse] + StackComponentType, dict[str, FlavorResponse] ] = defaultdict(dict) def register_flavors(self, store: BaseZenStore) -> None: @@ -53,7 +53,7 @@ def register_flavors(self, store: BaseZenStore) -> None: self.register_integration_flavors(store=store) @property - def builtin_flavors(self) -> List[Type[Flavor]]: + def builtin_flavors(self) -> list[type[Flavor]]: """A list of all default in-built flavors. 
Returns: @@ -90,7 +90,7 @@ def builtin_flavors(self) -> List[Type[Flavor]]: return flavors @property - def integration_flavors(self) -> List[Type[Flavor]]: + def integration_flavors(self) -> list[type[Flavor]]: """A list of all default integration flavors. Returns: diff --git a/src/zenml/stack/stack.py b/src/zenml/stack/stack.py index 4651b50c9ca..f26c11cbf11 100644 --- a/src/zenml/stack/stack.py +++ b/src/zenml/stack/stack.py @@ -21,13 +21,8 @@ TYPE_CHECKING, AbstractSet, Any, - Dict, - List, NoReturn, Optional, - Set, - Tuple, - Type, ) from uuid import UUID @@ -79,7 +74,7 @@ logger = get_logger(__name__) -_STACK_CACHE: Dict[Tuple[UUID, Optional[datetime]], "Stack"] = {} +_STACK_CACHE: dict[tuple[UUID, datetime | None], "Stack"] = {} class Stack: @@ -97,8 +92,8 @@ def __init__( id: UUID, name: str, *, - environment: Optional[Dict[str, str]] = None, - secrets: Optional[List[UUID]] = None, + environment: dict[str, str] | None = None, + secrets: list[UUID] | None = None, orchestrator: "BaseOrchestrator", artifact_store: "BaseArtifactStore", container_registry: Optional["BaseContainerRegistry"] = None, @@ -206,9 +201,9 @@ def from_components( cls, id: UUID, name: str, - components: Dict[StackComponentType, "StackComponent"], - environment: Optional[Dict[str, str]] = None, - secrets: Optional[List[UUID]] = None, + components: dict[StackComponentType, "StackComponent"], + environment: dict[str, str] | None = None, + secrets: list[UUID] | None = None, ) -> "Stack": """Creates a stack instance from a dict of stack components. @@ -245,7 +240,7 @@ def from_components( from zenml.step_operators import BaseStepOperator def _raise_type_error( - component: Optional["StackComponent"], expected_class: Type[Any] + component: Optional["StackComponent"], expected_class: type[Any] ) -> NoReturn: """Raises a TypeError that the component has an unexpected type. 
@@ -355,7 +350,7 @@ def _raise_type_error( ) @property - def components(self) -> Dict[StackComponentType, "StackComponent"]: + def components(self) -> dict[StackComponentType, "StackComponent"]: """All components of the stack. Returns: @@ -517,7 +512,7 @@ def deployer(self) -> Optional["BaseDeployer"]: """ return self._deployer - def dict(self) -> Dict[str, str]: + def dict(self) -> dict[str, str]: """Converts the stack into a dictionary. Returns: @@ -534,8 +529,8 @@ def dict(self) -> Dict[str, str]: def requirements( self, - exclude_components: Optional[AbstractSet[StackComponentType]] = None, - ) -> Set[str]: + exclude_components: AbstractSet[StackComponentType] | None = None, + ) -> set[str]: """Set of PyPI requirements for the stack. This method combines the requirements of all stack components (except @@ -557,7 +552,7 @@ def requirements( return set.union(*requirements) if requirements else set() @property - def apt_packages(self) -> List[str]: + def apt_packages(self) -> list[str]: """List of APT package requirements for the stack. Returns: @@ -570,7 +565,7 @@ def apt_packages(self) -> List[str]: ] @property - def environment(self) -> Dict[str, str]: + def environment(self) -> dict[str, str]: """Environment variables to set when running on this stack. Returns: @@ -579,7 +574,7 @@ def environment(self) -> Dict[str, str]: return self._environment @property - def secrets(self) -> List[UUID]: + def secrets(self) -> list[UUID]: """Secrets to set as environment variables when running on this stack. Returns: @@ -623,7 +618,7 @@ def check_local_paths(self) -> bool: return has_local_paths @property - def required_secrets(self) -> Set["secret_utils.SecretReference"]: + def required_secrets(self) -> set["secret_utils.SecretReference"]: """All required secrets for this stack. 
Returns: @@ -636,7 +631,7 @@ def required_secrets(self) -> Set["secret_utils.SecretReference"]: return set.union(*secrets) if secrets else set() @property - def setting_classes(self) -> Dict[str, Type["BaseSettings"]]: + def setting_classes(self) -> dict[str, type["BaseSettings"]]: """Setting classes of all components of this stack. Returns: @@ -850,7 +845,7 @@ def prepare_pipeline_submission( def get_docker_builds( self, snapshot: "PipelineSnapshotBase" - ) -> List["BuildConfiguration"]: + ) -> list["BuildConfiguration"]: """Gets the Docker builds required for the stack. Args: @@ -885,7 +880,7 @@ def deploy_pipeline( self, snapshot: "PipelineSnapshotResponse", deployment_name: str, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> "DeploymentResponse": """Deploys a pipeline on this stack. @@ -916,7 +911,7 @@ def deploy_pipeline( def _get_active_components_for_step( self, step_config: "StepConfiguration" - ) -> Dict[StackComponentType, "StackComponent"]: + ) -> dict[StackComponentType, "StackComponent"]: """Gets all the active stack components for a stack. Args: @@ -963,7 +958,7 @@ def prepare_step_run(self, info: "StepRunInfo") -> None: def get_pipeline_run_metadata( self, run_id: UUID - ) -> Dict[UUID, Dict[str, MetadataType]]: + ) -> dict[UUID, dict[str, MetadataType]]: """Get general component-specific metadata for a pipeline run. Args: @@ -972,7 +967,7 @@ def get_pipeline_run_metadata( Returns: A dictionary mapping component IDs to the metadata they created. """ - pipeline_run_metadata: Dict[UUID, Dict[str, MetadataType]] = {} + pipeline_run_metadata: dict[UUID, dict[str, MetadataType]] = {} for component in self.components.values(): try: component_metadata = component.get_pipeline_run_metadata( @@ -989,7 +984,7 @@ def get_pipeline_run_metadata( def get_step_run_metadata( self, info: "StepRunInfo" - ) -> Dict[UUID, Dict[str, MetadataType]]: + ) -> dict[UUID, dict[str, MetadataType]]: """Get component-specific metadata for a step run. 
Args: @@ -998,7 +993,7 @@ def get_step_run_metadata( Returns: A dictionary mapping component IDs to the metadata they created. """ - step_run_metadata: Dict[UUID, Dict[str, MetadataType]] = {} + step_run_metadata: dict[UUID, dict[str, MetadataType]] = {} for component in self._get_active_components_for_step( info.config ).values(): diff --git a/src/zenml/stack/stack_component.py b/src/zenml/stack/stack_component.py index 284592c44f4..a0086cfd58c 100644 --- a/src/zenml/stack/stack_component.py +++ b/src/zenml/stack/stack_component.py @@ -18,7 +18,7 @@ from collections.abc import Mapping, Sequence from datetime import datetime from inspect import isclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union +from typing import TYPE_CHECKING, Any, Optional, Union from uuid import UUID from pydantic import BaseModel, ConfigDict, model_validator @@ -122,7 +122,7 @@ def __init__( super().__init__(**kwargs) @property - def required_secrets(self) -> Set[secret_utils.SecretReference]: + def required_secrets(self) -> set[secret_utils.SecretReference]: """All required secrets for this stack component. Returns: @@ -259,7 +259,7 @@ def _is_part_of_active_stack(self) -> bool: @model_validator(mode="before") @classmethod @pydantic_utils.before_validator_handler - def _convert_json_strings(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def _convert_json_strings(cls, data: dict[str, Any]) -> dict[str, Any]: """Converts potential JSON strings. 
Args: @@ -332,15 +332,15 @@ def __init__( config: StackComponentConfig, flavor: str, type: StackComponentType, - user: Optional[UUID], + user: UUID | None, created: datetime, updated: datetime, - environment: Optional[Dict[str, str]] = None, - secrets: Optional[List[UUID]] = None, - labels: Optional[Dict[str, Any]] = None, - connector_requirements: Optional[ServiceConnectorRequirements] = None, - connector: Optional[UUID] = None, - connector_resource_id: Optional[str] = None, + environment: dict[str, str] | None = None, + secrets: list[UUID] | None = None, + labels: dict[str, Any] | None = None, + connector_requirements: ServiceConnectorRequirements | None = None, + connector: UUID | None = None, + connector_resource_id: str | None = None, *args: Any, **kwargs: Any, ): @@ -390,7 +390,7 @@ def __init__( self.connector_requirements = connector_requirements self.connector = connector self.connector_resource_id = connector_resource_id - self._connector_instance: Optional[ServiceConnector] = None + self._connector_instance: ServiceConnector | None = None @classmethod def from_model( @@ -481,7 +481,7 @@ def config(self) -> StackComponentConfig: return self._config @property - def settings_class(self) -> Optional[Type["BaseSettings"]]: + def settings_class(self) -> type["BaseSettings"] | None: """Class specifying available settings for this component. Returns: @@ -631,7 +631,7 @@ def get_connector(self) -> Optional["ServiceConnector"]: return self._connector_instance @property - def log_file(self) -> Optional[str]: + def log_file(self) -> str | None: """Optional path to a log file for the stack component. Returns: @@ -643,7 +643,7 @@ def log_file(self) -> Optional[str]: return None @property - def requirements(self) -> Set[str]: + def requirements(self) -> set[str]: """Set of PyPI requirements for the component. 
Returns: @@ -654,7 +654,7 @@ def requirements(self) -> Set[str]: return set(get_requirements_for_module(self.__module__)) @property - def apt_packages(self) -> List[str]: + def apt_packages(self) -> list[str]: """List of APT package requirements for the component. Returns: @@ -666,7 +666,7 @@ def apt_packages(self) -> List[str]: return integration.APT_PACKAGES if integration else [] @property - def local_path(self) -> Optional[str]: + def local_path(self) -> str | None: """Path to a local directory to store persistent information. This property should only be implemented by components that need to @@ -701,7 +701,7 @@ def local_path(self) -> Optional[str]: def get_docker_builds( self, snapshot: "PipelineSnapshotBase" - ) -> List["BuildConfiguration"]: + ) -> list["BuildConfiguration"]: """Gets the Docker builds required for the component. Args: @@ -714,7 +714,7 @@ def get_docker_builds( def get_pipeline_run_metadata( self, run_id: UUID - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get general component-specific metadata for a pipeline run. Args: @@ -734,7 +734,7 @@ def prepare_step_run(self, info: "StepRunInfo") -> None: def get_step_run_metadata( self, info: "StepRunInfo" - ) -> Dict[str, "MetadataType"]: + ) -> dict[str, "MetadataType"]: """Get component- and step-specific metadata after a step ran. Args: @@ -754,7 +754,7 @@ def cleanup_step_run(self, info: "StepRunInfo", step_failed: bool) -> None: """ @property - def post_registration_message(self) -> Optional[str]: + def post_registration_message(self) -> str | None: """Optional message printed after the stack component is registered. Returns: @@ -778,7 +778,6 @@ def validator(self) -> Optional["StackValidator"]: def cleanup(self) -> None: """Cleans up the component after it has been used.""" - pass def __repr__(self) -> str: """String representation of the stack component. 
diff --git a/src/zenml/stack/stack_validator.py b/src/zenml/stack/stack_validator.py index 8516c0cb9a7..2c13c8663a5 100644 --- a/src/zenml/stack/stack_validator.py +++ b/src/zenml/stack/stack_validator.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. """Implementation of the ZenML Stack Validator.""" -from typing import TYPE_CHECKING, AbstractSet, Callable, Optional, Tuple +from typing import TYPE_CHECKING, AbstractSet +from collections.abc import Callable from zenml.enums import StackComponentType from zenml.exceptions import StackValidationError @@ -37,10 +38,10 @@ class StackValidator: def __init__( self, - required_components: Optional[AbstractSet[StackComponentType]] = None, - custom_validation_function: Optional[ - Callable[["Stack"], Tuple[bool, str]] - ] = None, + required_components: AbstractSet[StackComponentType] | None = None, + custom_validation_function: None | ( + Callable[["Stack"], tuple[bool, str]] + ) = None, ): """Initializes a `StackValidator` instance. diff --git a/src/zenml/stack/utils.py b/src/zenml/stack/utils.py index 90d76493341..5129748d7c7 100644 --- a/src/zenml/stack/utils.py +++ b/src/zenml/stack/utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Util functions for handling stacks, components, and flavors.""" -from typing import Any, Dict, Optional, Type, Union +from typing import Any from zenml.client import Client from zenml.enums import StackComponentType, StoreType @@ -27,12 +27,12 @@ def validate_stack_component_config( - configuration_dict: Dict[str, Any], - flavor: Union[FlavorResponse, str], + configuration_dict: dict[str, Any], + flavor: FlavorResponse | str, component_type: StackComponentType, - zen_store: Optional[BaseZenStore] = None, + zen_store: BaseZenStore | None = None, validate_custom_flavors: bool = True, -) -> Optional[StackComponentConfig]: +) -> StackComponentConfig | None: """Validate the configuration of a stack component. 
Args: @@ -82,7 +82,7 @@ def validate_stack_component_config( config_class = flavor_class.config_class # Make sure extras are forbidden for the config class. Due to inheritance # order, some config classes allow extras by accident which we patch here. - validation_config_class: Type[StackComponentConfig] = type( + validation_config_class: type[StackComponentConfig] = type( config_class.__name__, (config_class,), {"model_config": {"extra": "ignore"}}, diff --git a/src/zenml/stack_deployments/aws_stack_deployment.py b/src/zenml/stack_deployments/aws_stack_deployment.py index db2645b0994..1d568f7b457 100644 --- a/src/zenml/stack_deployments/aws_stack_deployment.py +++ b/src/zenml/stack_deployments/aws_stack_deployment.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Functionality to deploy a ZenML stack to AWS.""" -from typing import ClassVar, Dict, List, Optional +from typing import ClassVar from zenml.enums import StackDeploymentProvider from zenml.models import StackDeploymentConfig @@ -119,7 +119,7 @@ def post_deploy_instructions(cls) -> str: """ @classmethod - def integrations(cls) -> List[str]: + def integrations(cls) -> list[str]: """Return the ZenML integrations required for the stack. Returns: @@ -132,7 +132,7 @@ def integrations(cls) -> List[str]: ] @classmethod - def permissions(cls) -> Dict[str, List[str]]: + def permissions(cls) -> dict[str, list[str]]: """Return the permissions granted to ZenML to access the cloud resources. Returns: @@ -195,7 +195,7 @@ def permissions(cls) -> Dict[str, List[str]]: } @classmethod - def locations(cls) -> Dict[str, str]: + def locations(cls) -> dict[str, str]: """Return the locations where the ZenML stack can be deployed. 
Returns: @@ -282,7 +282,7 @@ def get_deployment_config( f"{region}#/stacks/create/review?{query_params}" ) - config: Optional[str] = None + config: str | None = None if self.deployment_type == STACK_DEPLOYMENT_TERRAFORM: config = f"""terraform {{ required_providers {{ diff --git a/src/zenml/stack_deployments/azure_stack_deployment.py b/src/zenml/stack_deployments/azure_stack_deployment.py index 0b70fd84d76..af1a8f329c8 100644 --- a/src/zenml/stack_deployments/azure_stack_deployment.py +++ b/src/zenml/stack_deployments/azure_stack_deployment.py @@ -14,7 +14,7 @@ """Functionality to deploy a ZenML stack to Azure.""" import re -from typing import ClassVar, Dict, List +from typing import ClassVar from zenml.enums import StackDeploymentProvider from zenml.models import StackDeploymentConfig @@ -118,7 +118,7 @@ def post_deploy_instructions(cls) -> str: """ @classmethod - def integrations(cls) -> List[str]: + def integrations(cls) -> list[str]: """Return the ZenML integrations required for the stack. Returns: @@ -128,7 +128,7 @@ def integrations(cls) -> List[str]: return ["azure"] @classmethod - def permissions(cls) -> Dict[str, List[str]]: + def permissions(cls) -> dict[str, list[str]]: """Return the permissions granted to ZenML to access the cloud resources. Returns: @@ -151,7 +151,7 @@ def permissions(cls) -> Dict[str, List[str]]: } @classmethod - def locations(cls) -> Dict[str, str]: + def locations(cls) -> dict[str, str]: """Return the locations where the ZenML stack can be deployed. Returns: @@ -227,7 +227,7 @@ def locations(cls) -> Dict[str, str]: } @classmethod - def skypilot_default_regions(cls) -> Dict[str, str]: + def skypilot_default_regions(cls) -> dict[str, str]: """Returns the regions supported by default for the Skypilot. 
Returns: diff --git a/src/zenml/stack_deployments/gcp_stack_deployment.py b/src/zenml/stack_deployments/gcp_stack_deployment.py index 7a511ca2749..1c3a794fd3b 100644 --- a/src/zenml/stack_deployments/gcp_stack_deployment.py +++ b/src/zenml/stack_deployments/gcp_stack_deployment.py @@ -14,7 +14,7 @@ """Functionality to deploy a ZenML stack to GCP.""" import re -from typing import ClassVar, Dict, List +from typing import ClassVar from zenml.enums import StackDeploymentProvider from zenml.models import StackDeploymentConfig @@ -125,7 +125,7 @@ def post_deploy_instructions(cls) -> str: """ @classmethod - def integrations(cls) -> List[str]: + def integrations(cls) -> list[str]: """Return the ZenML integrations required for the stack. Returns: @@ -137,7 +137,7 @@ def integrations(cls) -> List[str]: ] @classmethod - def permissions(cls) -> Dict[str, List[str]]: + def permissions(cls) -> dict[str, list[str]]: """Return the permissions granted to ZenML to access the cloud resources. Returns: @@ -163,7 +163,7 @@ def permissions(cls) -> Dict[str, List[str]]: } @classmethod - def locations(cls) -> Dict[str, str]: + def locations(cls) -> dict[str, str]: """Return the locations where the ZenML stack can be deployed. Returns: @@ -218,7 +218,7 @@ def locations(cls) -> Dict[str, str]: } @classmethod - def skypilot_default_regions(cls) -> Dict[str, str]: + def skypilot_default_regions(cls) -> dict[str, str]: """Returns the regions supported by default for the Skypilot. 
Returns: diff --git a/src/zenml/stack_deployments/stack_deployment.py b/src/zenml/stack_deployments/stack_deployment.py index 99aaa473b9c..5a61b44f7f8 100644 --- a/src/zenml/stack_deployments/stack_deployment.py +++ b/src/zenml/stack_deployments/stack_deployment.py @@ -15,7 +15,7 @@ import datetime from abc import abstractmethod -from typing import ClassVar, Dict, List, Optional +from typing import ClassVar from pydantic import BaseModel @@ -40,7 +40,7 @@ class ZenMLCloudStackDeployment(BaseModel): stack_name: str zenml_server_url: str zenml_server_api_token: str - location: Optional[str] = None + location: str | None = None @classmethod @abstractmethod @@ -81,7 +81,7 @@ def post_deploy_instructions(cls) -> str: @classmethod @abstractmethod - def integrations(cls) -> List[str]: + def integrations(cls) -> list[str]: """Return the ZenML integrations required for the stack. Returns: @@ -91,7 +91,7 @@ def integrations(cls) -> List[str]: @classmethod @abstractmethod - def permissions(cls) -> Dict[str, List[str]]: + def permissions(cls) -> dict[str, list[str]]: """Return the permissions granted to ZenML to access the cloud resources. Returns: @@ -101,7 +101,7 @@ def permissions(cls) -> Dict[str, List[str]]: @classmethod @abstractmethod - def locations(cls) -> Dict[str, str]: + def locations(cls) -> dict[str, str]: """Return the locations where the ZenML stack can be deployed. Returns: @@ -121,7 +121,7 @@ def deployment_type(self) -> str: return self.deployment @classmethod - def skypilot_default_regions(cls) -> Dict[str, str]: + def skypilot_default_regions(cls) -> dict[str, str]: """Returns the regions supported by default for the Skypilot. Returns: @@ -174,8 +174,8 @@ def get_deployment_config( def get_stack( self, - date_start: Optional[datetime.datetime] = None, - ) -> Optional[DeployedStack]: + date_start: datetime.datetime | None = None, + ) -> DeployedStack | None: """Return the ZenML stack that was deployed and registered. 
This method is called to retrieve a ZenML stack matching the deployment diff --git a/src/zenml/stack_deployments/utils.py b/src/zenml/stack_deployments/utils.py index 6866f2ee1f2..2fff9fac067 100644 --- a/src/zenml/stack_deployments/utils.py +++ b/src/zenml/stack_deployments/utils.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Functionality to deploy a ZenML stack to a cloud provider.""" -from typing import Type from zenml.enums import StackDeploymentProvider from zenml.stack_deployments.aws_stack_deployment import ( @@ -36,7 +35,7 @@ def get_stack_deployment_class( provider: StackDeploymentProvider, -) -> Type[ZenMLCloudStackDeployment]: +) -> type[ZenMLCloudStackDeployment]: """Get the ZenML Cloud Stack Deployment class for the specified provider. Args: diff --git a/src/zenml/step_operators/base_step_operator.py b/src/zenml/step_operators/base_step_operator.py index 5c7bfbe8e7a..7b21eee0902 100644 --- a/src/zenml/step_operators/base_step_operator.py +++ b/src/zenml/step_operators/base_step_operator.py @@ -14,7 +14,7 @@ """Base class for ZenML step operators.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, List, Type, cast +from typing import TYPE_CHECKING, cast from zenml.enums import StackComponentType from zenml.logger import get_logger @@ -49,7 +49,7 @@ def config(self) -> BaseStepOperatorConfig: @property def entrypoint_config_class( self, - ) -> Type[StepOperatorEntrypointConfiguration]: + ) -> type[StepOperatorEntrypointConfiguration]: """Returns the entrypoint configuration class for this step operator. Concrete step operator implementations may override this property @@ -65,8 +65,8 @@ def entrypoint_config_class( def launch( self, info: "StepRunInfo", - entrypoint_command: List[str], - environment: Dict[str, str], + entrypoint_command: list[str], + environment: dict[str, str], ) -> None: """Abstract method to execute a step. 
@@ -94,7 +94,7 @@ def type(self) -> StackComponentType: return StackComponentType.STEP_OPERATOR @property - def config_class(self) -> Type[BaseStepOperatorConfig]: + def config_class(self) -> type[BaseStepOperatorConfig]: """Returns the config class for this flavor. Returns: @@ -104,7 +104,7 @@ def config_class(self) -> Type[BaseStepOperatorConfig]: @property @abstractmethod - def implementation_class(self) -> Type[BaseStepOperator]: + def implementation_class(self) -> type[BaseStepOperator]: """Returns the implementation class for this flavor. Returns: diff --git a/src/zenml/step_operators/step_operator_entrypoint_configuration.py b/src/zenml/step_operators/step_operator_entrypoint_configuration.py index cb3b71c9c74..c42e5fd8682 100644 --- a/src/zenml/step_operators/step_operator_entrypoint_configuration.py +++ b/src/zenml/step_operators/step_operator_entrypoint_configuration.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Abstract base class for entrypoint configurations that run a single step.""" -from typing import TYPE_CHECKING, Any, List, Set +from typing import TYPE_CHECKING, Any from uuid import UUID from zenml.client import Client @@ -36,7 +36,7 @@ class StepOperatorEntrypointConfiguration(StepEntrypointConfiguration): """Base class for step operator entrypoint configurations.""" @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> set[str]: """Gets all options required for running with this configuration. Returns: @@ -50,7 +50,7 @@ def get_entrypoint_options(cls) -> Set[str]: def get_entrypoint_arguments( cls, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Gets all arguments that the entrypoint command should be called with. 
Args: diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index ef18b99603e..735927a7ab1 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -21,16 +21,11 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, - Mapping, Optional, - Sequence, - Tuple, - Type, TypeVar, Union, ) +from collections.abc import Mapping, Sequence from uuid import UUID from pydantic import BaseModel, ConfigDict, ValidationError @@ -83,7 +78,7 @@ from zenml.models import ArtifactVersionResponse from zenml.types import HookSpecification - MaterializerClassOrSource = Union[str, Source, Type["BaseMaterializer"]] + MaterializerClassOrSource = Union[str, Source, type["BaseMaterializer"]] OutputMaterializersSpecification = Union[ "MaterializerClassOrSource", Sequence["MaterializerClassOrSource"], @@ -101,27 +96,27 @@ class BaseStep: def __init__( self, - name: Optional[str] = None, - enable_cache: Optional[bool] = None, - enable_artifact_metadata: Optional[bool] = None, - enable_artifact_visualization: Optional[bool] = None, - enable_step_logs: Optional[bool] = None, - experiment_tracker: Optional[Union[bool, str]] = None, - step_operator: Optional[Union[bool, str]] = None, - parameters: Optional[Dict[str, Any]] = None, + name: str | None = None, + enable_cache: bool | None = None, + enable_artifact_metadata: bool | None = None, + enable_artifact_visualization: bool | None = None, + enable_step_logs: bool | None = None, + experiment_tracker: bool | str | None = None, + step_operator: bool | str | None = None, + parameters: dict[str, Any] | None = None, output_materializers: Optional[ "OutputMaterializersSpecification" ] = None, - environment: Optional[Dict[str, Any]] = None, - secrets: Optional[List[Union[str, UUID]]] = None, - settings: Optional[Mapping[str, "SettingsOrDict"]] = None, - extra: Optional[Dict[str, Any]] = None, + environment: dict[str, Any] | None = None, + secrets: list[str | UUID] | None = None, + settings: Mapping[str, 
"SettingsOrDict"] | None = None, + extra: dict[str, Any] | None = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, model: Optional["Model"] = None, - retry: Optional[StepRetryConfig] = None, - substitutions: Optional[Dict[str, str]] = None, - cache_policy: Optional[CachePolicyOrString] = None, + retry: StepRetryConfig | None = None, + substitutions: dict[str, str] | None = None, + cache_policy: CachePolicyOrString | None = None, ) -> None: """Initializes a step. @@ -236,7 +231,7 @@ def entrypoint(self, *args: Any, **kwargs: Any) -> Any: """ @classmethod - def load_from_source(cls, source: Union[Source, str]) -> "BaseStep": + def load_from_source(cls, source: Source | str) -> "BaseStep": """Loads a step from source. Args: @@ -293,7 +288,7 @@ def source_code_cache_value(self) -> str: return self.source_code @property - def docstring(self) -> Optional[str]: + def docstring(self) -> str | None: """The docstring of this step. Returns: @@ -302,7 +297,7 @@ def docstring(self) -> Optional[str]: return self.__doc__ @property - def caching_parameters(self) -> Dict[str, Any]: + def caching_parameters(self) -> dict[str, Any]: """Caching parameters for this step. Returns: @@ -331,13 +326,13 @@ def caching_parameters(self) -> Dict[str, Any]: def _parse_call_args( self, *args: Any, **kwargs: Any - ) -> Tuple[ - Dict[str, "StepArtifact"], - Dict[str, Union["ExternalArtifact", "ArtifactVersionResponse"]], - Dict[str, "ModelVersionDataLazyLoader"], - Dict[str, "ClientLazyLoader"], - Dict[str, Any], - Dict[str, Any], + ) -> tuple[ + dict[str, "StepArtifact"], + dict[str, Union["ExternalArtifact", "ArtifactVersionResponse"]], + dict[str, "ModelVersionDataLazyLoader"], + dict[str, "ClientLazyLoader"], + dict[str, Any], + dict[str, Any], ]: """Parses the call args for the step entrypoint. 
@@ -369,7 +364,7 @@ def _parse_call_args( ) from e artifacts = {} - external_artifacts: Dict[ + external_artifacts: dict[ str, Union["ExternalArtifact", "ArtifactVersionResponse"] ] = {} model_artifacts_or_metadata = {} @@ -452,10 +447,10 @@ def _parse_call_args( def __call__( self, *args: Any, - id: Optional[str] = None, - after: Union[ - str, StepArtifact, Sequence[Union[str, StepArtifact]], None - ] = None, + id: str | None = None, + after: ( + str | StepArtifact | Sequence[str | StepArtifact] | None + ) = None, **kwargs: Any, ) -> Any: """Handle a call of the step. @@ -592,7 +587,7 @@ def name(self) -> str: return self.configuration.name @property - def enable_cache(self) -> Optional[bool]: + def enable_cache(self) -> bool | None: """If caching is enabled for the step. Returns: @@ -611,26 +606,26 @@ def configuration(self) -> "PartialStepConfiguration": def configure( self: T, - enable_cache: Optional[bool] = None, - enable_artifact_metadata: Optional[bool] = None, - enable_artifact_visualization: Optional[bool] = None, - enable_step_logs: Optional[bool] = None, - experiment_tracker: Optional[Union[bool, str]] = None, - step_operator: Optional[Union[bool, str]] = None, - parameters: Optional[Dict[str, Any]] = None, + enable_cache: bool | None = None, + enable_artifact_metadata: bool | None = None, + enable_artifact_visualization: bool | None = None, + enable_step_logs: bool | None = None, + experiment_tracker: bool | str | None = None, + step_operator: bool | str | None = None, + parameters: dict[str, Any] | None = None, output_materializers: Optional[ "OutputMaterializersSpecification" ] = None, - environment: Optional[Dict[str, Any]] = None, - secrets: Optional[Sequence[Union[str, UUID]]] = None, - settings: Optional[Mapping[str, "SettingsOrDict"]] = None, - extra: Optional[Dict[str, Any]] = None, + environment: dict[str, Any] | None = None, + secrets: Sequence[str | UUID] | None = None, + settings: Mapping[str, "SettingsOrDict"] | None = None, + extra: 
dict[str, Any] | None = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, model: Optional["Model"] = None, - retry: Optional[StepRetryConfig] = None, - substitutions: Optional[Dict[str, str]] = None, - cache_policy: Optional[CachePolicyOrString] = None, + retry: StepRetryConfig | None = None, + substitutions: dict[str, str] | None = None, + cache_policy: CachePolicyOrString | None = None, merge: bool = True, ) -> T: """Configures the step. @@ -687,7 +682,7 @@ def configure( from zenml.hooks.hook_validators import resolve_and_validate_hook def _resolve_if_necessary( - value: Union[str, Source, Type[Any]], + value: str | Source | type[Any], ) -> Source: if isinstance(value, str): return Source.from_import_path(value) @@ -696,13 +691,13 @@ def _resolve_if_necessary( else: return source_utils.resolve(value) - def _convert_to_tuple(value: Any) -> Tuple[Source, ...]: + def _convert_to_tuple(value: Any) -> tuple[Source, ...]: if isinstance(value, str) or not isinstance(value, Sequence): return (_resolve_if_necessary(value),) else: return tuple(_resolve_if_necessary(v) for v in value) - outputs: Dict[str, Dict[str, Tuple[Source, ...]]] = defaultdict(dict) + outputs: dict[str, dict[str, tuple[Source, ...]]] = defaultdict(dict) allowed_output_names = set(self.entrypoint_definition.outputs) if output_materializers: @@ -760,26 +755,26 @@ def _convert_to_tuple(value: Any) -> Tuple[Source, ...]: def with_options( self, - enable_cache: Optional[bool] = None, - enable_artifact_metadata: Optional[bool] = None, - enable_artifact_visualization: Optional[bool] = None, - enable_step_logs: Optional[bool] = None, - experiment_tracker: Optional[Union[bool, str]] = None, - step_operator: Optional[Union[bool, str]] = None, - parameters: Optional[Dict[str, Any]] = None, + enable_cache: bool | None = None, + enable_artifact_metadata: bool | None = None, + enable_artifact_visualization: bool | None = None, + enable_step_logs: bool | None 
= None, + experiment_tracker: bool | str | None = None, + step_operator: bool | str | None = None, + parameters: dict[str, Any] | None = None, output_materializers: Optional[ "OutputMaterializersSpecification" ] = None, - environment: Optional[Dict[str, Any]] = None, - secrets: Optional[List[Union[str, UUID]]] = None, - settings: Optional[Mapping[str, "SettingsOrDict"]] = None, - extra: Optional[Dict[str, Any]] = None, + environment: dict[str, Any] | None = None, + secrets: list[str | UUID] | None = None, + settings: Mapping[str, "SettingsOrDict"] | None = None, + extra: dict[str, Any] | None = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, model: Optional["Model"] = None, - retry: Optional[StepRetryConfig] = None, - substitutions: Optional[Dict[str, str]] = None, - cache_policy: Optional[CachePolicyOrString] = None, + retry: StepRetryConfig | None = None, + substitutions: dict[str, str] | None = None, + cache_policy: CachePolicyOrString | None = None, merge: bool = True, ) -> "BaseStep": """Copies the step and applies the given configurations. @@ -858,7 +853,7 @@ def _apply_configuration( self, config: "StepConfigurationUpdate", merge: bool = True, - runtime_parameters: Dict[str, Any] = {}, + runtime_parameters: dict[str, Any] = {}, ) -> None: """Applies an update to the step configuration. @@ -881,7 +876,7 @@ def _apply_configuration( def _validate_configuration( self, config: "StepConfigurationUpdate", - runtime_parameters: Dict[str, Any], + runtime_parameters: dict[str, Any], ) -> None: """Validates a configuration update. @@ -898,8 +893,8 @@ def _validate_configuration( def _validate_function_parameters( self, - parameters: Optional[Dict[str, Any]], - runtime_parameters: Dict[str, Any], + parameters: dict[str, Any] | None, + runtime_parameters: dict[str, Any], ) -> None: """Validates step function parameters. 
@@ -993,10 +988,10 @@ def _validate_outputs( def _validate_inputs( self, - input_artifacts: Dict[str, "StepArtifact"], - external_artifacts: Dict[str, "ExternalArtifactConfiguration"], - model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], - client_lazy_loaders: Dict[str, "ClientLazyLoader"], + input_artifacts: dict[str, "StepArtifact"], + external_artifacts: dict[str, "ExternalArtifactConfiguration"], + model_artifacts_or_metadata: dict[str, "ModelVersionDataLazyLoader"], + client_lazy_loaders: dict[str, "ClientLazyLoader"], ) -> None: """Validates the step inputs. @@ -1027,10 +1022,10 @@ def _validate_inputs( def _finalize_configuration( self, - input_artifacts: Dict[str, "StepArtifact"], - external_artifacts: Dict[str, "ExternalArtifactConfiguration"], - model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], - client_lazy_loaders: Dict[str, "ClientLazyLoader"], + input_artifacts: dict[str, "StepArtifact"], + external_artifacts: dict[str, "ExternalArtifactConfiguration"], + model_artifacts_or_metadata: dict[str, "ModelVersionDataLazyLoader"], + client_lazy_loaders: dict[str, "ClientLazyLoader"], ) -> "StepConfiguration": """Finalizes the configuration after the step was called. 
@@ -1060,7 +1055,7 @@ def _finalize_configuration( StepConfigurationUpdate, ) - outputs: Dict[str, Dict[str, Any]] = defaultdict(dict) + outputs: dict[str, dict[str, Any]] = defaultdict(dict) for ( output_name, @@ -1083,7 +1078,7 @@ def _finalize_configuration( if output_annotation.resolved_annotation is Any: continue - materializer_classes: List[Type["BaseMaterializer"]] = [ + materializer_classes: list[type["BaseMaterializer"]] = [ source_utils.load(materializer_source) for materializer_source in output.materializer_source ] @@ -1158,7 +1153,7 @@ def _finalize_configuration( self._configuration.model_dump() ) - def _finalize_parameters(self) -> Dict[str, Any]: + def _finalize_parameters(self) -> dict[str, Any]: """Finalizes the config parameters for running this step. Returns: diff --git a/src/zenml/steps/decorated_step.py b/src/zenml/steps/decorated_step.py index 9476184ecd4..fe75ef7ff35 100644 --- a/src/zenml/steps/decorated_step.py +++ b/src/zenml/steps/decorated_step.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Internal BaseStep subclass used by the step decorator.""" -from typing import Any, Optional +from typing import Any from zenml.config.source import Source from zenml.steps import BaseStep, step @@ -84,7 +84,7 @@ def remove_decorator_from_source_code( class _DecoratedStep(BaseStep): """Internal BaseStep subclass used by the step decorator.""" - def _get_step_decorator_name(self) -> Optional[str]: + def _get_step_decorator_name(self) -> str | None: """The name of the step decorator. 
Returns: diff --git a/src/zenml/steps/entrypoint_function_utils.py b/src/zenml/steps/entrypoint_function_utils.py index 28121759497..73975ebbb89 100644 --- a/src/zenml/steps/entrypoint_function_utils.py +++ b/src/zenml/steps/entrypoint_function_utils.py @@ -17,14 +17,11 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, NamedTuple, NoReturn, - Sequence, - Type, Union, ) +from collections.abc import Callable, Sequence from pydantic import ConfigDict, ValidationError, create_model @@ -44,7 +41,7 @@ from zenml.config.source import Source from zenml.pipelines.pipeline_definition import Pipeline - MaterializerClassOrSource = Union[str, "Source", Type["BaseMaterializer"]] + MaterializerClassOrSource = Union[str, "Source", type["BaseMaterializer"]] logger = get_logger(__name__) @@ -113,8 +110,8 @@ class EntrypointFunctionDefinition(NamedTuple): names to output annotations. """ - inputs: Dict[str, inspect.Parameter] - outputs: Dict[str, OutputSignature] + inputs: dict[str, inspect.Parameter] + outputs: dict[str, OutputSignature] def validate_input(self, key: str, value: Any) -> None: """Validates an input to the step entrypoint function. diff --git a/src/zenml/steps/step_context.py b/src/zenml/steps/step_context.py index 196727fc647..106999443b9 100644 --- a/src/zenml/steps/step_context.py +++ b/src/zenml/steps/step_context.py @@ -16,13 +16,9 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, - Mapping, Optional, - Sequence, - Type, ) +from collections.abc import Mapping, Sequence from zenml.exceptions import StepContextError from zenml.logger import get_logger @@ -76,10 +72,10 @@ class RunContext(metaclass=SingletonMetaClass): def __init__(self) -> None: """Create the run context.""" self.initialized = False - self._state: Optional[Any] = None + self._state: Any | None = None @property - def state(self) -> Optional[Any]: + def state(self) -> Any | None: """Returns the pipeline state. 
Returns: @@ -95,7 +91,7 @@ def state(self) -> Optional[Any]: ) return self._state - def initialize(self, state: Optional[Any]) -> None: + def initialize(self, state: Any | None) -> None: """Initialize the run context. Args: @@ -142,7 +138,7 @@ def __init__( self, pipeline_run: "PipelineRunResponse", step_run: "StepRunResponse", - output_materializers: Mapping[str, Sequence[Type["BaseMaterializer"]]], + output_materializers: Mapping[str, Sequence[type["BaseMaterializer"]]], output_artifact_uris: Mapping[str, str], output_artifact_configs: Mapping[str, Optional["ArtifactConfig"]], ) -> None: @@ -217,7 +213,7 @@ def pipeline(self) -> "PipelineResponse": ) @property - def pipeline_state(self) -> Optional[Any]: + def pipeline_state(self) -> Any | None: """Returns the pipeline state. Returns: @@ -250,7 +246,7 @@ def model(self) -> "Model": return self.model_version.to_model_class() @property - def inputs(self) -> Dict[str, "StepRunInputResponse"]: + def inputs(self) -> dict[str, "StepRunInputResponse"]: """Returns the input artifacts of the current step. Returns: @@ -259,7 +255,7 @@ def inputs(self) -> Dict[str, "StepRunInputResponse"]: return self.step_run.regular_inputs def _get_output( - self, output_name: Optional[str] = None + self, output_name: str | None = None ) -> "StepContextOutput": """Returns the materializer and artifact URI for a given step output. @@ -304,9 +300,9 @@ def _get_output( def get_output_materializer( self, - output_name: Optional[str] = None, - custom_materializer_class: Optional[Type["BaseMaterializer"]] = None, - data_type: Optional[Type[Any]] = None, + output_name: str | None = None, + custom_materializer_class: type["BaseMaterializer"] | None = None, + data_type: type[Any] | None = None, ) -> "BaseMaterializer": """Returns a materializer for a given step output. 
@@ -347,7 +343,7 @@ def get_output_materializer( return materializer_class(artifact_uri) def get_output_artifact_uri( - self, output_name: Optional[str] = None + self, output_name: str | None = None ) -> str: """Returns the artifact URI for a given step output. @@ -363,8 +359,8 @@ def get_output_artifact_uri( return self._get_output(output_name).artifact_uri def get_output_metadata( - self, output_name: Optional[str] = None - ) -> Dict[str, "MetadataType"]: + self, output_name: str | None = None + ) -> dict[str, "MetadataType"]: """Returns the metadata for a given step output. Args: @@ -384,7 +380,7 @@ def get_output_metadata( ) return custom_metadata - def get_output_tags(self, output_name: Optional[str] = None) -> List[str]: + def get_output_tags(self, output_name: str | None = None) -> list[str]: """Returns the tags for a given step output. Args: @@ -406,8 +402,8 @@ def get_output_tags(self, output_name: Optional[str] = None) -> List[str]: def add_output_metadata( self, - metadata: Dict[str, "MetadataType"], - output_name: Optional[str] = None, + metadata: dict[str, "MetadataType"], + output_name: str | None = None, ) -> None: """Adds metadata for a given step output. @@ -425,8 +421,8 @@ def add_output_metadata( def add_output_tags( self, - tags: List[str], - output_name: Optional[str] = None, + tags: list[str], + output_name: str | None = None, ) -> None: """Adds tags for a given step output. @@ -444,8 +440,8 @@ def add_output_tags( def remove_output_tags( self, - tags: List[str], - output_name: Optional[str] = None, + tags: list[str], + output_name: str | None = None, ) -> None: """Removes tags for a given step output. 
@@ -465,15 +461,15 @@ def remove_output_tags( class StepContextOutput: """Represents a step output in the step context.""" - materializer_classes: Sequence[Type["BaseMaterializer"]] + materializer_classes: Sequence[type["BaseMaterializer"]] artifact_uri: str - run_metadata: Optional[Dict[str, "MetadataType"]] = None + run_metadata: dict[str, "MetadataType"] | None = None artifact_config: Optional["ArtifactConfig"] - tags: Optional[List[str]] = None + tags: list[str] | None = None def __init__( self, - materializer_classes: Sequence[Type["BaseMaterializer"]], + materializer_classes: Sequence[type["BaseMaterializer"]], artifact_uri: str, artifact_config: Optional["ArtifactConfig"], ): diff --git a/src/zenml/steps/step_decorator.py b/src/zenml/steps/step_decorator.py index 3ea538c57a7..e0e2acd8d1f 100644 --- a/src/zenml/steps/step_decorator.py +++ b/src/zenml/steps/step_decorator.py @@ -16,17 +16,12 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - List, - Mapping, Optional, - Sequence, - Type, TypeVar, Union, overload, ) +from collections.abc import Callable, Mapping, Sequence from uuid import UUID from zenml.logger import get_logger @@ -41,7 +36,7 @@ from zenml.steps import BaseStep from zenml.types import HookSpecification - MaterializerClassOrSource = Union[str, Source, Type[BaseMaterializer]] + MaterializerClassOrSource = Union[str, Source, type[BaseMaterializer]] OutputMaterializersSpecification = Union[ MaterializerClassOrSource, @@ -62,23 +57,23 @@ def step(_func: "F") -> "BaseStep": ... 
@overload def step( *, - name: Optional[str] = None, - enable_cache: Optional[bool] = None, - enable_artifact_metadata: Optional[bool] = None, - enable_artifact_visualization: Optional[bool] = None, - enable_step_logs: Optional[bool] = None, - experiment_tracker: Optional[Union[bool, str]] = None, - step_operator: Optional[Union[bool, str]] = None, + name: str | None = None, + enable_cache: bool | None = None, + enable_artifact_metadata: bool | None = None, + enable_artifact_visualization: bool | None = None, + enable_step_logs: bool | None = None, + experiment_tracker: bool | str | None = None, + step_operator: bool | str | None = None, output_materializers: Optional["OutputMaterializersSpecification"] = None, - environment: Optional[Dict[str, Any]] = None, - secrets: Optional[List[Union[UUID, str]]] = None, - settings: Optional[Dict[str, "SettingsOrDict"]] = None, - extra: Optional[Dict[str, Any]] = None, + environment: dict[str, Any] | None = None, + secrets: list[UUID | str] | None = None, + settings: dict[str, "SettingsOrDict"] | None = None, + extra: dict[str, Any] | None = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, - substitutions: Optional[Dict[str, str]] = None, + substitutions: dict[str, str] | None = None, cache_policy: Optional["CachePolicyOrString"] = None, ) -> Callable[["F"], "BaseStep"]: ... 
@@ -86,23 +81,23 @@ def step( def step( _func: Optional["F"] = None, *, - name: Optional[str] = None, - enable_cache: Optional[bool] = None, - enable_artifact_metadata: Optional[bool] = None, - enable_artifact_visualization: Optional[bool] = None, - enable_step_logs: Optional[bool] = None, - experiment_tracker: Optional[Union[bool, str]] = None, - step_operator: Optional[Union[bool, str]] = None, + name: str | None = None, + enable_cache: bool | None = None, + enable_artifact_metadata: bool | None = None, + enable_artifact_visualization: bool | None = None, + enable_step_logs: bool | None = None, + experiment_tracker: bool | str | None = None, + step_operator: bool | str | None = None, output_materializers: Optional["OutputMaterializersSpecification"] = None, - environment: Optional[Dict[str, Any]] = None, - secrets: Optional[List[Union[UUID, str]]] = None, - settings: Optional[Dict[str, "SettingsOrDict"]] = None, - extra: Optional[Dict[str, Any]] = None, + environment: dict[str, Any] | None = None, + secrets: list[UUID | str] | None = None, + settings: dict[str, "SettingsOrDict"] | None = None, + extra: dict[str, Any] | None = None, on_failure: Optional["HookSpecification"] = None, on_success: Optional["HookSpecification"] = None, model: Optional["Model"] = None, retry: Optional["StepRetryConfig"] = None, - substitutions: Optional[Dict[str, str]] = None, + substitutions: dict[str, str] | None = None, cache_policy: Optional["CachePolicyOrString"] = None, ) -> Union["BaseStep", Callable[["F"], "BaseStep"]]: """Decorator to create a ZenML step. 
@@ -147,7 +142,7 @@ def step( def inner_decorator(func: "F") -> "BaseStep": from zenml.steps.decorated_step import _DecoratedStep - class_: Type["BaseStep"] = type( + class_: type["BaseStep"] = type( func.__name__, (_DecoratedStep,), { diff --git a/src/zenml/steps/step_invocation.py b/src/zenml/steps/step_invocation.py index 17341d40845..41b8a5a1c7c 100644 --- a/src/zenml/steps/step_invocation.py +++ b/src/zenml/steps/step_invocation.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Step invocation class definition.""" -from typing import TYPE_CHECKING, Any, Dict, Set, Union +from typing import TYPE_CHECKING, Any, Union from zenml.models import ArtifactVersionResponse @@ -34,15 +34,15 @@ def __init__( self, id: str, step: "BaseStep", - input_artifacts: Dict[str, "StepArtifact"], - external_artifacts: Dict[ + input_artifacts: dict[str, "StepArtifact"], + external_artifacts: dict[ str, Union["ExternalArtifact", "ArtifactVersionResponse"] ], - model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], - client_lazy_loaders: Dict[str, "ClientLazyLoader"], - parameters: Dict[str, Any], - default_parameters: Dict[str, Any], - upstream_steps: Set[str], + model_artifacts_or_metadata: dict[str, "ModelVersionDataLazyLoader"], + client_lazy_loaders: dict[str, "ClientLazyLoader"], + parameters: dict[str, Any], + default_parameters: dict[str, Any], + upstream_steps: set[str], pipeline: "Pipeline", ) -> None: """Initialize a step invocation. @@ -71,7 +71,7 @@ def __init__( self.upstream_steps = upstream_steps self.pipeline = pipeline - def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration": + def finalize(self, parameters_to_ignore: set[str]) -> "StepConfiguration": """Finalizes a step invocation. 
It will validate the upstream steps and run final configurations on the @@ -103,7 +103,7 @@ def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration": ) self.step.configure(parameters=parameters_to_apply) - external_artifacts: Dict[str, ExternalArtifactConfiguration] = {} + external_artifacts: dict[str, ExternalArtifactConfiguration] = {} for key, artifact in self.external_artifacts.items(): if isinstance(artifact, ArtifactVersionResponse): external_artifacts[key] = ExternalArtifactConfiguration( diff --git a/src/zenml/steps/utils.py b/src/zenml/steps/utils.py index e6e058b6d84..730d42e99ce 100644 --- a/src/zenml/steps/utils.py +++ b/src/zenml/steps/utils.py @@ -22,17 +22,13 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - Optional, - Tuple, TypeVar, - Union, ) +from collections.abc import Callable from uuid import UUID from pydantic import BaseModel -from typing_extensions import Annotated +from typing import Annotated from zenml.artifacts.artifact_config import ArtifactConfig from zenml.client import Client @@ -61,10 +57,10 @@ class OutputSignature(BaseModel): """The signature of an output artifact.""" resolved_annotation: Any = None - artifact_config: Optional[ArtifactConfig] = None + artifact_config: ArtifactConfig | None = None has_custom_name: bool = False - def get_output_types(self) -> Tuple[Any, ...]: + def get_output_types(self) -> tuple[Any, ...]: """Get all output types that match the type annotation. Returns: @@ -87,7 +83,7 @@ def get_output_types(self) -> Tuple[Any, ...]: return (self.resolved_annotation,) -def get_args(obj: Any) -> Tuple[Any, ...]: +def get_args(obj: Any) -> tuple[Any, ...]: """Get arguments of a type annotation. 
Example: @@ -107,7 +103,7 @@ def get_args(obj: Any) -> Tuple[Any, ...]: def parse_return_type_annotations( func: Callable[..., Any], enforce_type_annotations: bool = False, -) -> Dict[str, OutputSignature]: +) -> dict[str, OutputSignature]: """Parse the return type annotation of a step function. Args: @@ -126,7 +122,7 @@ def parse_return_type_annotations( """ signature = inspect.signature(func, follow_wrapped=True) return_annotation = signature.return_annotation - output_name: Optional[str] + output_name: str | None # Return type annotated as `None` if return_annotation is None: @@ -147,7 +143,7 @@ def parse_return_type_annotations( if typing_utils.get_origin(return_annotation) is tuple: requires_multiple_artifacts = has_tuple_return(func) if requires_multiple_artifacts: - output_signature: Dict[str, Any] = {} + output_signature: dict[str, Any] = {} args = typing_utils.get_args(return_annotation) if args[-1] is Ellipsis: raise RuntimeError( @@ -213,7 +209,7 @@ def resolve_type_annotation(obj: Any) -> Any: def get_artifact_config_from_annotation_metadata( annotation: Any, -) -> Optional[ArtifactConfig]: +) -> ArtifactConfig | None: """Get the artifact config from the annotation metadata of a step output. Example: @@ -298,7 +294,7 @@ def __init__(self, ignore_nested_functions: bool = True) -> None: self._inside_function = False def _visit_function( - self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef] + self, node: ast.FunctionDef | ast.AsyncFunctionDef ) -> None: """Visit a (async) function definition node. @@ -436,10 +432,10 @@ def f3(condition): def log_step_metadata( - metadata: Dict[str, "MetadataType"], - step_name: Optional[str] = None, - pipeline_name_id_or_prefix: Optional[Union[str, UUID]] = None, - run_id: Optional[str] = None, + metadata: dict[str, "MetadataType"], + step_name: str | None = None, + pipeline_name_id_or_prefix: str | UUID | None = None, + run_id: str | None = None, ) -> None: """Logs step metadata. 
@@ -544,7 +540,7 @@ def run_as_single_step_pipeline( "error above for more details." ) from e - inputs: Dict[str, Any] = {} + inputs: dict[str, Any] = {} for key, value in validated_arguments.items(): try: __step.entrypoint_definition.validate_input(key=key, value=value) @@ -592,8 +588,8 @@ def single_step_pipeline() -> None: def get_unique_step_output_names( - step_outputs: Dict[Tuple[str, str], T], -) -> Dict[Tuple[str, str], Tuple[T, str]]: + step_outputs: dict[tuple[str, str], T], +) -> dict[tuple[str, str], tuple[T, str]]: """Get unique step output names. Given a dictionary of step outputs indexed by (invocation_id, output_name), diff --git a/src/zenml/types.py b/src/zenml/types.py index 65ea344221f..2377eead198 100644 --- a/src/zenml/types.py +++ b/src/zenml/types.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. """Custom ZenML types.""" -from typing import TYPE_CHECKING, Any, Callable, Union +from typing import TYPE_CHECKING, Any, Union +from collections.abc import Callable if TYPE_CHECKING: from types import FunctionType diff --git a/src/zenml/utils/archivable.py b/src/zenml/utils/archivable.py index 23804016cce..b6aafdda553 100644 --- a/src/zenml/utils/archivable.py +++ b/src/zenml/utils/archivable.py @@ -18,7 +18,7 @@ import zipfile from abc import ABC, abstractmethod from pathlib import Path -from typing import IO, Any, Dict, Optional +from typing import IO, Any from zenml.io import fileio from zenml.utils.enum_utils import StrEnum @@ -42,7 +42,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: *args: Unused args for subclasses. **kwargs: Unused keyword args for subclasses. """ - self._extra_files: Dict[str, str] = {} + self._extra_files: dict[str, str] = {} def add_file(self, source: str, destination: str) -> None: """Adds a file to the archive. 
@@ -103,7 +103,7 @@ def write_archive( """ files = self.get_files() extra_files = self.get_extra_files() - close_fileobj: Optional[Any] = None + close_fileobj: Any | None = None fileobj: Any = output_file if archive_type == ArchiveType.ZIP: @@ -160,7 +160,7 @@ def write_archive( output_file.seek(0) @abstractmethod - def get_files(self) -> Dict[str, str]: + def get_files(self) -> dict[str, str]: """Gets all regular files that should be included in the archive. Returns: @@ -168,7 +168,7 @@ def get_files(self) -> Dict[str, str]: in the archive. """ - def get_extra_files(self) -> Dict[str, str]: + def get_extra_files(self) -> dict[str, str]: """Gets all extra files that should be included in the archive. Returns: diff --git a/src/zenml/utils/callback_registry.py b/src/zenml/utils/callback_registry.py index fa2c1977731..a8611baf143 100644 --- a/src/zenml/utils/callback_registry.py +++ b/src/zenml/utils/callback_registry.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. """Callback registry implementation.""" -from typing import Any, Callable, Dict, List, Tuple +from typing import Any +from collections.abc import Callable from typing_extensions import ParamSpec @@ -29,8 +30,8 @@ class CallbackRegistry: def __init__(self) -> None: """Initializes the callback registry.""" - self._callbacks: List[ - Tuple[Callable[P, Any], Tuple[Any], Dict[str, Any]] + self._callbacks: list[ + tuple[Callable[P, Any], tuple[Any], dict[str, Any]] ] = [] def register_callback( diff --git a/src/zenml/utils/code_repository_utils.py b/src/zenml/utils/code_repository_utils.py index bb32b3b025f..29669640de5 100644 --- a/src/zenml/utils/code_repository_utils.py +++ b/src/zenml/utils/code_repository_utils.py @@ -15,7 +15,6 @@ import os from typing import ( - Dict, Optional, ) @@ -30,7 +29,7 @@ logger = get_logger(__name__) -_CODE_REPOSITORY_CACHE: Dict[str, Optional["LocalRepositoryContext"]] = {} +_CODE_REPOSITORY_CACHE: dict[str, Optional["LocalRepositoryContext"]] = {} def 
set_custom_local_repository( @@ -84,7 +83,7 @@ def set_custom_local_repository( def find_active_code_repository( - path: Optional[str] = None, + path: str | None = None, ) -> Optional["LocalRepositoryContext"]: """Find the active code repository for a given path. diff --git a/src/zenml/utils/code_utils.py b/src/zenml/utils/code_utils.py index 6bcd9c9582c..6540d6e0e1f 100644 --- a/src/zenml/utils/code_utils.py +++ b/src/zenml/utils/code_utils.py @@ -19,7 +19,7 @@ import sys import tempfile from pathlib import Path -from typing import IO, TYPE_CHECKING, Dict, Optional +from typing import IO, TYPE_CHECKING, Optional from zenml.client import Client from zenml.io import fileio @@ -44,7 +44,7 @@ class CodeArchive(Archivable): excluded by gitignores will be included in the archive. """ - def __init__(self, root: Optional[str] = None) -> None: + def __init__(self, root: str | None = None) -> None: """Initialize the object. Args: @@ -74,7 +74,7 @@ def git_repo(self) -> Optional["Repo"]: return git_repo - def _get_all_files(self, archive_root: str) -> Dict[str, str]: + def _get_all_files(self, archive_root: str) -> dict[str, str]: """Get all files inside the archive root. Args: @@ -92,7 +92,7 @@ def _get_all_files(self, archive_root: str) -> Dict[str, str]: return all_files - def get_files(self) -> Dict[str, str]: + def get_files(self) -> dict[str, str]: """Gets all regular files that should be included in the archive. Raises: diff --git a/src/zenml/utils/context_utils.py b/src/zenml/utils/context_utils.py index 2add3bc862c..2db4c1166e0 100644 --- a/src/zenml/utils/context_utils.py +++ b/src/zenml/utils/context_utils.py @@ -15,7 +15,7 @@ import threading from contextvars import ContextVar -from typing import Generic, List, Optional, TypeVar +from typing import Generic, TypeVar T = TypeVar("T") @@ -30,13 +30,13 @@ def __init__(self, name: str) -> None: name: The name for the underlying ContextVar. 
""" # Use None as default to avoid mutable default issues - self._context_var: ContextVar[Optional[List[T]]] = ContextVar( + self._context_var: ContextVar[list[T] | None] = ContextVar( name, default=None ) # Lock to ensure atomic operations self._lock = threading.Lock() - def get(self) -> List[T]: + def get(self) -> list[T]: """Get the current list value. Returns empty list if not set. Returns: diff --git a/src/zenml/utils/daemon.py b/src/zenml/utils/daemon.py index 2dc0a034327..aa123879844 100644 --- a/src/zenml/utils/daemon.py +++ b/src/zenml/utils/daemon.py @@ -24,7 +24,8 @@ import signal import sys import types -from typing import Any, Callable, Optional, TypeVar, cast +from typing import Any, TypeVar, cast +from collections.abc import Callable import psutil @@ -42,7 +43,7 @@ def daemonize( pid_file: str, - log_file: Optional[str] = None, + log_file: str | None = None, working_directory: str = "/", ) -> Callable[[F], F]: """Decorator that executes the decorated function as a daemon process. @@ -113,7 +114,7 @@ def daemon(*args: Any, **kwargs: Any) -> None: def setup_daemon( - pid_file: Optional[str] = None, log_file: Optional[str] = None + pid_file: str | None = None, log_file: str | None = None ) -> None: """Sets up a daemon process. @@ -179,7 +180,7 @@ def cleanup() -> None: os.remove(pid_file) sys.stderr.flush() - def sighndl(signum: int, frame: Optional[types.FrameType]) -> None: + def sighndl(signum: int, frame: types.FrameType | None) -> None: """Daemon signal handler. Args: @@ -229,16 +230,16 @@ def stop_daemon(pid_file: str) -> None: kill. 
""" try: - with open(pid_file, "r") as f: + with open(pid_file) as f: pid = int(f.read().strip()) - except (IOError, FileNotFoundError): + except (OSError, FileNotFoundError): logger.debug("Daemon PID file '%s' does not exist.", pid_file) return stop_process(pid) -def get_daemon_pid_if_running(pid_file: str) -> Optional[int]: +def get_daemon_pid_if_running(pid_file: str) -> int | None: """Read and return the PID value from a PID file. It does this if the daemon process tracked by the PID file is running. @@ -251,9 +252,9 @@ def get_daemon_pid_if_running(pid_file: str) -> Optional[int]: The PID of the daemon process if it is running, otherwise None. """ try: - with open(pid_file, "r") as f: + with open(pid_file) as f: pid = int(f.read().strip()) - except (IOError, FileNotFoundError): + except (OSError, FileNotFoundError): logger.debug( f"Daemon PID file '{pid_file}' does not exist or cannot be read.", ) @@ -313,7 +314,7 @@ def run_as_daemon( daemon_function: F, *args: Any, pid_file: str, - log_file: Optional[str] = None, + log_file: str | None = None, working_directory: str = "/", **kwargs: Any, ) -> None: diff --git a/src/zenml/utils/dashboard_utils.py b/src/zenml/utils/dashboard_utils.py index bde967095a3..f24659f9da3 100644 --- a/src/zenml/utils/dashboard_utils.py +++ b/src/zenml/utils/dashboard_utils.py @@ -14,7 +14,6 @@ """Utility class to help with interacting with the dashboard.""" import os -from typing import Optional from urllib.parse import urlparse from zenml import constants @@ -35,7 +34,7 @@ logger = get_logger(__name__) -def get_cloud_dashboard_url() -> Optional[str]: +def get_cloud_dashboard_url() -> str | None: """Get the base url of the cloud dashboard if the server is a ZenML Pro workspace. Returns: @@ -52,7 +51,7 @@ def get_cloud_dashboard_url() -> Optional[str]: return None -def get_server_dashboard_url() -> Optional[str]: +def get_server_dashboard_url() -> str | None: """Get the base url of the dashboard deployed by the server. 
Returns: @@ -73,7 +72,7 @@ def get_server_dashboard_url() -> Optional[str]: return None -def get_stack_url(stack: StackResponse) -> Optional[str]: +def get_stack_url(stack: StackResponse) -> str | None: """Function to get the dashboard URL of a given stack model. Args: @@ -97,7 +96,7 @@ def get_stack_url(stack: StackResponse) -> Optional[str]: return None -def get_component_url(component: ComponentResponse) -> Optional[str]: +def get_component_url(component: ComponentResponse) -> str | None: """Function to get the dashboard URL of a given component model. Args: @@ -118,7 +117,7 @@ def get_component_url(component: ComponentResponse) -> Optional[str]: return None -def get_run_url(run: PipelineRunResponse) -> Optional[str]: +def get_run_url(run: PipelineRunResponse) -> str | None: """Function to get the dashboard URL of a given pipeline run. Args: @@ -140,7 +139,7 @@ def get_run_url(run: PipelineRunResponse) -> Optional[str]: def get_model_version_url( model_version: ModelVersionResponse, -) -> Optional[str]: +) -> str | None: """Function to get the dashboard URL of a given model version. Args: @@ -222,7 +221,7 @@ def show_dashboard_with_url(url: str) -> None: def show_dashboard( local: bool = False, - ngrok_token: Optional[str] = None, + ngrok_token: str | None = None, ) -> None: """Show the ZenML dashboard. 
@@ -238,7 +237,7 @@ def show_dashboard( """ from zenml.utils.networking_utils import get_or_create_ngrok_tunnel - url: Optional[str] = None + url: str | None = None if not local: gc = GlobalConfiguration() if gc.store_configuration.type == StoreType.REST: diff --git a/src/zenml/utils/deprecation_utils.py b/src/zenml/utils/deprecation_utils.py index 4a9ceae7342..99f55815af8 100644 --- a/src/zenml/utils/deprecation_utils.py +++ b/src/zenml/utils/deprecation_utils.py @@ -14,7 +14,7 @@ """Deprecation utilities.""" import warnings -from typing import Any, Dict, Set, Tuple, Type, Union +from typing import Any from pydantic import BaseModel, model_validator @@ -27,7 +27,7 @@ def deprecate_pydantic_attributes( - *attributes: Union[str, Tuple[str, str]], + *attributes: str | tuple[str, str], ) -> Any: """Utility function for deprecating and migrating pydantic attributes. @@ -64,8 +64,8 @@ class MyModel(BaseModel): @classmethod @before_validator_handler def _deprecation_validator( - cls: Type[BaseModel], data: Dict[str, Any] - ) -> Dict[str, Any]: + cls: type[BaseModel], data: dict[str, Any] + ) -> dict[str, Any]: """Pydantic validator function for deprecating pydantic attributes. Args: @@ -82,7 +82,7 @@ def _deprecation_validator( Returns: Input values with potentially migrated values. 
""" - previous_deprecation_warnings: Set[str] = getattr( + previous_deprecation_warnings: set[str] = getattr( cls, PREVIOUS_DEPRECATION_WARNINGS_ATTRIBUTE, set() ) diff --git a/src/zenml/utils/dict_utils.py b/src/zenml/utils/dict_utils.py index 9c40c80d4a8..643ba30c597 100644 --- a/src/zenml/utils/dict_utils.py +++ b/src/zenml/utils/dict_utils.py @@ -15,14 +15,14 @@ import base64 import json -from typing import Any, Dict +from typing import Any from zenml.utils.json_utils import pydantic_encoder def recursive_update( - original: Dict[str, Any], update: Dict[str, Any] -) -> Dict[str, Any]: + original: dict[str, Any], update: dict[str, Any] +) -> dict[str, Any]: """Recursively updates a dictionary. Args: @@ -33,9 +33,9 @@ def recursive_update( The updated dictionary. """ for key, value in update.items(): - if isinstance(value, Dict): + if isinstance(value, dict): original_value = original.get(key, None) or {} - if isinstance(original_value, Dict): + if isinstance(original_value, dict): original[key] = recursive_update(original_value, value) else: original[key] = value @@ -45,8 +45,8 @@ def recursive_update( def remove_none_values( - dict_: Dict[str, Any], recursive: bool = False -) -> Dict[str, Any]: + dict_: dict[str, Any], recursive: bool = False +) -> dict[str, Any]: """Removes all key-value pairs with `None` value. Args: @@ -67,7 +67,7 @@ def _maybe_recurse(value: Any) -> Any: Returns: The updated dictionary value. """ - if recursive and isinstance(value, Dict): + if recursive and isinstance(value, dict): return remove_none_values(value, recursive=True) else: return value @@ -75,7 +75,7 @@ def _maybe_recurse(value: Any) -> Any: return {k: _maybe_recurse(v) for k, v in dict_.items() if v is not None} -def dict_to_bytes(dict_: Dict[str, Any]) -> bytes: +def dict_to_bytes(dict_: dict[str, Any]) -> bytes: """Converts a dictionary to bytes. 
Args: diff --git a/src/zenml/utils/docker_utils.py b/src/zenml/utils/docker_utils.py index 29e792852b1..76494209428 100644 --- a/src/zenml/utils/docker_utils.py +++ b/src/zenml/utils/docker_utils.py @@ -18,15 +18,9 @@ import re from typing import ( Any, - Dict, - Iterable, - List, - Optional, - Sequence, - Tuple, - Union, cast, ) +from collections.abc import Iterable, Sequence from docker.client import DockerClient from docker.errors import DockerException @@ -56,7 +50,7 @@ def check_docker() -> bool: return False -def _parse_dockerignore(dockerignore_path: str) -> List[str]: +def _parse_dockerignore(dockerignore_path: str) -> list[str]: """Parses a dockerignore file and returns a list of patterns to ignore. Args: @@ -84,9 +78,9 @@ def _parse_dockerignore(dockerignore_path: str) -> List[str]: def _create_custom_build_context( dockerfile_contents: str, - build_context_root: Optional[str] = None, - dockerignore: Optional[str] = None, - extra_files: Sequence[Tuple[str, str]] = (), + build_context_root: str | None = None, + dockerignore: str | None = None, + extra_files: Sequence[tuple[str, str]] = (), ) -> Any: """Creates a docker build context. @@ -175,10 +169,10 @@ def _create_custom_build_context( def build_image( image_name: str, - dockerfile: Union[str, List[str]], - build_context_root: Optional[str] = None, - dockerignore: Optional[str] = None, - extra_files: Sequence[Tuple[str, str]] = (), + dockerfile: str | list[str], + build_context_root: str | None = None, + dockerignore: str | None = None, + extra_files: Sequence[tuple[str, str]] = (), **custom_build_options: Any, ) -> None: """Builds a docker image. @@ -243,7 +237,7 @@ def build_image( def push_image( - image_name: str, docker_client: Optional[DockerClient] = None + image_name: str, docker_client: DockerClient | None = None ) -> str: """Pushes an image to a container registry. 
@@ -277,7 +271,7 @@ def push_image( prefix_candidates.append(f"{image_name_without_index}@") image = docker_client.images.get(image_name) - repo_digests: List[str] = image.attrs["RepoDigests"] + repo_digests: list[str] = image.attrs["RepoDigests"] for digest in repo_digests: if digest.startswith(tuple(prefix_candidates)): @@ -307,7 +301,7 @@ def tag_image(image_name: str, target: str) -> None: image.tag(target) -def get_image_digest(image_name: str) -> Optional[str]: +def get_image_digest(image_name: str) -> str | None: """Gets the digest of an image. Args: @@ -371,7 +365,7 @@ def _try_get_docker_client_from_env() -> DockerClient: ) from e -def _process_stream(stream: Iterable[bytes]) -> List[Dict[str, Any]]: +def _process_stream(stream: Iterable[bytes]) -> list[dict[str, Any]]: """Processes the output stream of a docker command call. Args: diff --git a/src/zenml/utils/enum_utils.py b/src/zenml/utils/enum_utils.py index 113f589efae..505de1945b9 100644 --- a/src/zenml/utils/enum_utils.py +++ b/src/zenml/utils/enum_utils.py @@ -14,7 +14,6 @@ """Util functions for enums.""" from enum import Enum -from typing import List class StrEnum(str, Enum): @@ -29,7 +28,7 @@ def __str__(self) -> str: return self.value # type: ignore @classmethod - def names(cls) -> List[str]: + def names(cls) -> list[str]: """Get all enum names as a list of strings. Returns: @@ -38,7 +37,7 @@ def names(cls) -> List[str]: return [c.name for c in cls] @classmethod - def values(cls) -> List[str]: + def values(cls) -> list[str]: """Get all enum values as a list of strings. 
Returns: diff --git a/src/zenml/utils/env_utils.py b/src/zenml/utils/env_utils.py index 250bd234e7e..587dfb78047 100644 --- a/src/zenml/utils/env_utils.py +++ b/src/zenml/utils/env_utils.py @@ -19,14 +19,11 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - Iterator, - List, - Match, - Optional, TypeVar, cast, ) +from collections.abc import Iterator +from re import Match from zenml.client import Client from zenml.logger import get_logger @@ -46,7 +43,7 @@ def split_environment_variables( size_limit: int, - env: Optional[Dict[str, str]] = None, + env: dict[str, str] | None = None, ) -> None: """Split long environment variables into chunks. @@ -64,7 +61,7 @@ def split_environment_variables( more than 10 chunks. """ if env is None: - env = cast(Dict[str, str], os.environ) + env = cast(dict[str, str], os.environ) for key, value in env.copy().items(): if len(value) <= size_limit: @@ -90,7 +87,7 @@ def split_environment_variables( def reconstruct_environment_variables( - env: Optional[Dict[str, str]] = None, + env: dict[str, str] | None = None, ) -> None: """Reconstruct environment variables that were split into chunks. @@ -103,9 +100,9 @@ def reconstruct_environment_variables( environment variables are used. """ if env is None: - env = cast(Dict[str, str], os.environ) + env = cast(dict[str, str], os.environ) - chunks: Dict[str, List[str]] = {} + chunks: dict[str, list[str]] = {} for key in env.keys(): if not key[:-1].endswith(ENV_VAR_CHUNK_SUFFIX): continue @@ -171,7 +168,7 @@ def _substitution_func(v: str) -> str: @contextlib.contextmanager -def temporary_environment(environment: Dict[str, str]) -> Iterator[None]: +def temporary_environment(environment: dict[str, str]) -> Iterator[None]: """Temporarily set environment variables. 
Args: @@ -205,7 +202,7 @@ def temporary_environment(environment: Dict[str, str]) -> Iterator[None]: def get_step_environment( step_config: "StepConfiguration", stack: "Stack" -) -> Dict[str, str]: +) -> dict[str, str]: """Get the environment variables for a step. Args: @@ -230,7 +227,7 @@ def get_step_environment( def get_step_secret_environment( step_config: "StepConfiguration", stack: "Stack" -) -> Dict[str, str]: +) -> dict[str, str]: """Get the environment variables for a step. Args: diff --git a/src/zenml/utils/filesync_model.py b/src/zenml/utils/filesync_model.py index 5de3934c287..64f15f72a2c 100644 --- a/src/zenml/utils/filesync_model.py +++ b/src/zenml/utils/filesync_model.py @@ -14,7 +14,7 @@ """Filesync utils for ZenML.""" import os -from typing import Any, Optional +from typing import Any from pydantic import ( BaseModel, @@ -43,7 +43,7 @@ class FileSyncModel(BaseModel): """ _config_file: str - _config_file_timestamp: Optional[float] = None + _config_file_timestamp: float | None = None @model_validator(mode="wrap") @classmethod @@ -116,7 +116,7 @@ def __setattr__(self, key: str, value: Any) -> None: key: attribute name. value: attribute value. 
""" - super(FileSyncModel, self).__setattr__(key, value) + super().__setattr__(key, value) if key.startswith("_"): return self.write_config() @@ -132,7 +132,7 @@ def __getattribute__(self, key: str) -> Any: """ if not key.startswith("_") and key in self.__dict__: self.load_config() - return super(FileSyncModel, self).__getattribute__(key) + return super().__getattribute__(key) def write_config(self) -> None: """Writes the model to the configuration file.""" @@ -156,6 +156,6 @@ def load_config(self) -> None: # refresh the model from the configuration file values config_dict = yaml_utils.read_yaml(self._config_file) for key, value in config_dict.items(): - super(FileSyncModel, self).__setattr__(key, value) + super().__setattr__(key, value) self._config_file_timestamp = file_timestamp diff --git a/src/zenml/utils/function_utils.py b/src/zenml/utils/function_utils.py index 89b51220f23..edfef03b0b6 100644 --- a/src/zenml/utils/function_utils.py +++ b/src/zenml/utils/function_utils.py @@ -17,7 +17,8 @@ import os from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Iterator, List, Tuple, TypeVar, Union +from typing import Any, TypeVar, Union +from collections.abc import Callable, Iterator import click @@ -125,7 +126,7 @@ def _cli_wrapped_function(func: F) -> F: Raises: ValueError: If the function arguments are not valid. """ - options: List[Any] = [] + options: list[Any] = [] fullargspec = inspect.getfullargspec(func) if fullargspec.defaults is not None: defaults = [None] * ( @@ -205,7 +206,7 @@ def wrapper(function: F) -> F: @contextmanager def create_cli_wrapped_script( func: F, flavor: str = "accelerate" -) -> Iterator[Tuple[Path, Path]]: +) -> Iterator[tuple[Path, Path]]: """Create a script with the CLI-wrapped function. 
Args: diff --git a/src/zenml/utils/git_utils.py b/src/zenml/utils/git_utils.py index 7cf4b99703a..160b99b0370 100644 --- a/src/zenml/utils/git_utils.py +++ b/src/zenml/utils/git_utils.py @@ -14,7 +14,6 @@ """Utility function to clone a Git repository.""" import os -from typing import Optional from git.exc import GitCommandError from git.repo import Repo @@ -23,8 +22,8 @@ def clone_git_repository( url: str, to_path: str, - branch: Optional[str] = None, - commit: Optional[str] = None, + branch: str | None = None, + commit: str | None = None, ) -> Repo: """Clone a Git repository. diff --git a/src/zenml/utils/io_utils.py b/src/zenml/utils/io_utils.py index 375193825a1..c00e916a3dd 100644 --- a/src/zenml/utils/io_utils.py +++ b/src/zenml/utils/io_utils.py @@ -16,7 +16,8 @@ import fnmatch import os from pathlib import Path -from typing import TYPE_CHECKING, Iterable +from typing import TYPE_CHECKING +from collections.abc import Iterable import click diff --git a/src/zenml/utils/json_utils.py b/src/zenml/utils/json_utils.py index f356651af5f..9dad46a8dd9 100644 --- a/src/zenml/utils/json_utils.py +++ b/src/zenml/utils/json_utils.py @@ -32,7 +32,8 @@ from pathlib import Path from re import Pattern from types import GeneratorType -from typing import Any, Callable, Dict, Type, Union +from typing import Any +from collections.abc import Callable from uuid import UUID from pydantic import NameEmail, SecretBytes, SecretStr @@ -41,7 +42,7 @@ __all__ = "pydantic_encoder" -def isoformat(obj: Union[datetime.date, datetime.time]) -> str: +def isoformat(obj: datetime.date | datetime.time) -> str: """Function to convert a datetime into iso format. Args: @@ -53,7 +54,7 @@ def isoformat(obj: Union[datetime.date, datetime.time]) -> str: return obj.isoformat() -def decimal_encoder(dec_value: Decimal) -> Union[int, float]: +def decimal_encoder(dec_value: Decimal) -> int | float: """Encodes a Decimal as int of there's no exponent, otherwise float. 
This is useful when we use ConstrainedDecimal to represent Numeric(x,0) @@ -79,7 +80,7 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]: return float(dec_value) -ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { +ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { bytes: lambda obj: obj.decode(), Color: str, datetime.date: isoformat, diff --git a/src/zenml/utils/materializer_utils.py b/src/zenml/utils/materializer_utils.py index dfe1ebb7577..3ab6c58f97a 100644 --- a/src/zenml/utils/materializer_utils.py +++ b/src/zenml/utils/materializer_utils.py @@ -13,16 +13,17 @@ # permissions and limitations under the License. """Util functions for materializers.""" -from typing import TYPE_CHECKING, Any, Optional, Sequence, Type +from typing import TYPE_CHECKING, Any +from collections.abc import Sequence if TYPE_CHECKING: from zenml.materializers.base_materializer import BaseMaterializer def select_materializer( - data_type: Type[Any], - materializer_classes: Sequence[Type["BaseMaterializer"]], -) -> Type["BaseMaterializer"]: + data_type: type[Any], + materializer_classes: Sequence[type["BaseMaterializer"]], +) -> type["BaseMaterializer"]: """Select a materializer for a given data type. Args: @@ -35,7 +36,7 @@ def select_materializer( Returns: The first materializer that can handle the given data type. """ - fallback: Optional[Type["BaseMaterializer"]] = None + fallback: type["BaseMaterializer"] | None = None for class_ in data_type.__mro__: for materializer_class in materializer_classes: diff --git a/src/zenml/utils/metadata_utils.py b/src/zenml/utils/metadata_utils.py index 2a5a4d8575d..09c423764f2 100644 --- a/src/zenml/utils/metadata_utils.py +++ b/src/zenml/utils/metadata_utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Utility functions to handle metadata for ZenML entities.""" -from typing import Dict, List, Optional, Set, Union, overload +from typing import overload from uuid import UUID from zenml.client import Client @@ -34,14 +34,14 @@ @overload def log_metadata( - metadata: Dict[str, MetadataType], + metadata: dict[str, MetadataType], ) -> None: ... @overload def log_metadata( *, - metadata: Dict[str, MetadataType], + metadata: dict[str, MetadataType], step_id: UUID, ) -> None: ... @@ -49,24 +49,24 @@ def log_metadata( @overload def log_metadata( *, - metadata: Dict[str, MetadataType], + metadata: dict[str, MetadataType], step_name: str, - run_id_name_or_prefix: Union[UUID, str], + run_id_name_or_prefix: UUID | str, ) -> None: ... @overload def log_metadata( *, - metadata: Dict[str, MetadataType], - run_id_name_or_prefix: Union[UUID, str], + metadata: dict[str, MetadataType], + run_id_name_or_prefix: UUID | str, ) -> None: ... @overload def log_metadata( *, - metadata: Dict[str, MetadataType], + metadata: dict[str, MetadataType], artifact_version_id: UUID, ) -> None: ... @@ -74,18 +74,18 @@ def log_metadata( @overload def log_metadata( *, - metadata: Dict[str, MetadataType], + metadata: dict[str, MetadataType], artifact_name: str, - artifact_version: Optional[str] = None, + artifact_version: str | None = None, ) -> None: ... @overload def log_metadata( *, - metadata: Dict[str, MetadataType], + metadata: dict[str, MetadataType], infer_artifact: bool = False, - artifact_name: Optional[str] = None, + artifact_name: str | None = None, ) -> None: ... @@ -93,7 +93,7 @@ def log_metadata( @overload def log_metadata( *, - metadata: Dict[str, MetadataType], + metadata: dict[str, MetadataType], model_version_id: UUID, ) -> None: ... 
@@ -101,35 +101,35 @@ def log_metadata( @overload def log_metadata( *, - metadata: Dict[str, MetadataType], + metadata: dict[str, MetadataType], model_name: str, - model_version: Union[ModelStages, int, str], + model_version: ModelStages | int | str, ) -> None: ... @overload def log_metadata( *, - metadata: Dict[str, MetadataType], + metadata: dict[str, MetadataType], infer_model: bool = False, ) -> None: ... def log_metadata( - metadata: Dict[str, MetadataType], + metadata: dict[str, MetadataType], # Steps and runs - step_id: Optional[UUID] = None, - step_name: Optional[str] = None, - run_id_name_or_prefix: Optional[Union[UUID, str]] = None, + step_id: UUID | None = None, + step_name: str | None = None, + run_id_name_or_prefix: UUID | str | None = None, # Artifacts - artifact_version_id: Optional[UUID] = None, - artifact_name: Optional[str] = None, - artifact_version: Optional[str] = None, + artifact_version_id: UUID | None = None, + artifact_name: str | None = None, + artifact_version: str | None = None, infer_artifact: bool = False, # Models - model_version_id: Optional[UUID] = None, - model_name: Optional[str] = None, - model_version: Optional[Union[ModelStages, int, str]] = None, + model_version_id: UUID | None = None, + model_name: str | None = None, + model_version: ModelStages | int | str | None = None, infer_model: bool = False, ) -> None: """Logs metadata for various resource types in a generalized way. 
@@ -156,7 +156,7 @@ def log_metadata( """ client = Client() - resources: List[RunMetadataResource] = [] + resources: list[RunMetadataResource] = [] publisher_step_id = None # Log metadata to a step by ID @@ -375,7 +375,7 @@ def log_metadata( def bulk_log_metadata( - metadata: Dict[str, MetadataType], + metadata: dict[str, MetadataType], pipeline_runs: list[PipelineRunIdentifier] | None = None, step_runs: list[StepRunIdentifier] | None = None, artifact_versions: list[ArtifactVersionIdentifier] | None = None, @@ -400,7 +400,7 @@ def bulk_log_metadata( """ client = Client() - resources: Set[RunMetadataResource] = set() + resources: set[RunMetadataResource] = set() if not metadata: raise ValueError("You must provide metadata to log.") diff --git a/src/zenml/utils/networking_utils.py b/src/zenml/utils/networking_utils.py index 7e9735f5049..9283ef2d9c3 100644 --- a/src/zenml/utils/networking_utils.py +++ b/src/zenml/utils/networking_utils.py @@ -14,7 +14,7 @@ """Utility functions for networking.""" import socket -from typing import List, Optional, Tuple, cast +from typing import cast from urllib.parse import urlparse from zenml.environment import Environment @@ -47,7 +47,7 @@ def port_available(port: int, address: str = "127.0.0.1") -> bool: # missing code paths. pass s.bind((address, port)) - except socket.error as e: + except OSError as e: logger.debug("Port %d unavailable on %s: %s", port, address, e) return False @@ -68,9 +68,9 @@ def find_available_port() -> int: def lookup_preferred_or_free_port( - preferred_ports: List[int] = [], + preferred_ports: list[int] = [], allocate_port_if_busy: bool = True, - range: Tuple[int, int] = SCAN_PORT_RANGE, + range: tuple[int, int] = SCAN_PORT_RANGE, address: str = "127.0.0.1", ) -> int: """Find a preferred TCP port that is available or search for a free TCP port. 
@@ -98,21 +98,21 @@ def lookup_preferred_or_free_port( if port_available(port, address): return port if not allocate_port_if_busy: - raise IOError(f"TCP port {preferred_ports} is not available.") + raise OSError(f"TCP port {preferred_ports} is not available.") available_port = scan_for_available_port( start=range[0], stop=range[1], address=address ) if available_port: return available_port - raise IOError(f"No free TCP ports found in range {range}") + raise OSError(f"No free TCP ports found in range {range}") def scan_for_available_port( start: int = SCAN_PORT_RANGE[0], stop: int = SCAN_PORT_RANGE[1], address: str = "127.0.0.1", -) -> Optional[int]: +) -> int | None: """Scan the local network for an available port in the given range. Args: @@ -149,7 +149,7 @@ def port_is_open(hostname: str, port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: result = sock.connect_ex((hostname, port)) return result == 0 - except socket.error as e: + except OSError as e: logger.debug( f"Error checking TCP port {port} on host {hostname}: {str(e)}" ) diff --git a/src/zenml/utils/notebook_utils.py b/src/zenml/utils/notebook_utils.py index 6c3df9239ad..457f5a77c86 100644 --- a/src/zenml/utils/notebook_utils.py +++ b/src/zenml/utils/notebook_utils.py @@ -14,7 +14,8 @@ """Notebook utilities.""" import hashlib -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Optional, TypeVar, Union +from collections.abc import Callable from zenml.environment import Environment from zenml.logger import get_logger @@ -64,7 +65,7 @@ def inner_decorator(obj: "AnyObject") -> "AnyObject": return inner_decorator(_obj) -def get_active_notebook_cell_code() -> Optional[str]: +def get_active_notebook_cell_code() -> str | None: """Get the code of the currently active notebook cell. 
Returns: @@ -95,7 +96,7 @@ def try_to_save_notebook_cell_code(obj: Any) -> None: ) -def load_notebook_cell_code(obj: Any) -> Optional[str]: +def load_notebook_cell_code(obj: Any) -> str | None: """Load the notebook cell code for an object. Args: diff --git a/src/zenml/utils/package_utils.py b/src/zenml/utils/package_utils.py index 1b8dc0d44ce..9423ef439f6 100644 --- a/src/zenml/utils/package_utils.py +++ b/src/zenml/utils/package_utils.py @@ -18,7 +18,7 @@ distribution, distributions, ) -from typing import Dict, List, Optional, Union, cast +from typing import cast import requests from packaging import version @@ -59,7 +59,7 @@ def is_latest_zenml_version() -> bool: return True -def clean_requirements(requirements: List[str]) -> List[str]: +def clean_requirements(requirements: list[str]) -> list[str]: """Clean requirements list from redundant requirements. Args: @@ -96,7 +96,7 @@ def clean_requirements(requirements: List[str]) -> List[str]: return sorted(cleaned.values()) -def requirement_installed(requirement: Union[str, Requirement]) -> bool: +def requirement_installed(requirement: str | Requirement) -> bool: """Check if a requirement is installed. Args: @@ -118,7 +118,7 @@ def requirement_installed(requirement: Union[str, Requirement]) -> bool: def get_dependencies( requirement: Requirement, recursive: bool = False -) -> List[Requirement]: +) -> list[Requirement]: """Get the dependencies of a requirement. Args: @@ -129,7 +129,7 @@ def get_dependencies( A list of requirements. """ dist = distribution(requirement.name) - marker_environment = cast(Dict[str, str], default_environment()) + marker_environment = cast(dict[str, str], default_environment()) dependencies = [] @@ -168,8 +168,8 @@ def get_dependencies( def get_package_information( - package_names: Optional[List[str]] = None, -) -> Dict[str, str]: + package_names: list[str] | None = None, +) -> dict[str, str]: """Get package information. 
Args: diff --git a/src/zenml/utils/pagination_utils.py b/src/zenml/utils/pagination_utils.py index 7d9d8d369fa..adc110762bf 100644 --- a/src/zenml/utils/pagination_utils.py +++ b/src/zenml/utils/pagination_utils.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. """Pagination utilities.""" -from typing import Any, Callable, List, TypeVar +from typing import Any, TypeVar +from collections.abc import Callable from zenml.models import BaseIdentifiedResponse, Page @@ -22,7 +23,7 @@ def depaginate( list_method: Callable[..., Page[AnyResponse]], **kwargs: Any -) -> List[AnyResponse]: +) -> list[AnyResponse]: """Depaginate the results from a client or store method that returns pages. Args: diff --git a/src/zenml/utils/pipeline_docker_image_builder.py b/src/zenml/utils/pipeline_docker_image_builder.py index b8657f74b06..4736699a3e7 100644 --- a/src/zenml/utils/pipeline_docker_image_builder.py +++ b/src/zenml/utils/pipeline_docker_image_builder.py @@ -20,12 +20,9 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, Optional, - Sequence, - Tuple, ) +from collections.abc import Sequence import zenml from zenml.config import DockerSettings @@ -81,11 +78,11 @@ def build_docker_image( tag: str, stack: "Stack", include_files: bool, - entrypoint: Optional[str] = None, - extra_files: Optional[Dict[str, str]] = None, + entrypoint: str | None = None, + extra_files: dict[str, str] | None = None, code_repository: Optional["BaseCodeRepository"] = None, - extra_requirements_files: Dict[str, List[str]] = {}, - ) -> Tuple[str, Optional[str], Optional[str]]: + extra_requirements_files: dict[str, list[str]] = {}, + ) -> tuple[str, str | None, str | None]: """Builds (and optionally pushes) a Docker image to run a pipeline. Use the image name returned by this method whenever you need to uniquely @@ -123,8 +120,8 @@ def build_docker_image( image build. ValueError: If the specified Dockerfile does not exist. 
""" - requirements: Optional[str] = None - dockerfile: Optional[str] = None + requirements: str | None = None + dockerfile: str | None = None if docker_settings.skip_build: assert ( @@ -400,7 +397,7 @@ def _get_target_image_name( @classmethod def _add_requirements_files( cls, - requirements_files: List[Tuple[str, str, List[str]]], + requirements_files: list[tuple[str, str, list[str]]], build_context: "BuildContext", ) -> None: """Adds requirements files to the build context. @@ -419,8 +416,8 @@ def gather_requirements_files( stack: "Stack", code_repository: Optional["BaseCodeRepository"] = None, log: bool = True, - extra_requirements_files: Dict[str, List[str]] = {}, - ) -> List[Tuple[str, str, List[str]]]: + extra_requirements_files: dict[str, list[str]] = {}, + ) -> list[tuple[str, str, list[str]]]: """Gathers and/or generates pip requirements files. This method is called in `PipelineDockerImageBuilder.build_docker_image` @@ -504,7 +501,7 @@ def gather_requirements_files( requirements = None pyproject_path = None - requirements_files: List[Tuple[str, str, List[str]]] = [] + requirements_files: list[tuple[str, str, list[str]]] = [] # Generate requirements file for the local environment if configured if docker_settings.replicate_local_python_environment: @@ -680,7 +677,7 @@ def _run_command(command: str) -> str: else "Including", path, ) - elif isinstance(requirements, List): + elif isinstance(requirements, list): user_requirements = "\n".join(requirements) if log: logger.info( @@ -701,9 +698,9 @@ def _run_command(command: str) -> str: def _generate_zenml_pipeline_dockerfile( parent_image: str, docker_settings: DockerSettings, - requirements_files: Sequence[Tuple[str, str, List[str]]] = (), + requirements_files: Sequence[tuple[str, str, list[str]]] = (), apt_packages: Sequence[str] = (), - entrypoint: Optional[str] = None, + entrypoint: str | None = None, ) -> str: """Generates a Dockerfile. 
@@ -744,7 +741,7 @@ def _generate_zenml_pipeline_dockerfile( == PythonPackageInstaller.PIP ): install_command = "pip install" - default_installer_args: Dict[str, Any] = PIP_DEFAULT_ARGS + default_installer_args: dict[str, Any] = PIP_DEFAULT_ARGS elif ( docker_settings.python_package_installer == PythonPackageInstaller.UV diff --git a/src/zenml/utils/proxy_utils.py b/src/zenml/utils/proxy_utils.py index 29bf836d090..52c993804db 100644 --- a/src/zenml/utils/proxy_utils.py +++ b/src/zenml/utils/proxy_utils.py @@ -15,13 +15,14 @@ from abc import ABC from functools import wraps -from typing import Any, Callable, Type, TypeVar, cast +from typing import Any, TypeVar, cast +from collections.abc import Callable -C = TypeVar("C", bound=Type[ABC]) +C = TypeVar("C", bound=type[ABC]) F = TypeVar("F", bound=Callable[..., Any]) -def make_proxy_class(interface: Type[ABC], attribute: str) -> Callable[[C], C]: +def make_proxy_class(interface: type[ABC], attribute: str) -> Callable[[C], C]: """Proxy class decorator. Use this decorator to transform the decorated class into a proxy that diff --git a/src/zenml/utils/pydantic_utils.py b/src/zenml/utils/pydantic_utils.py index 8afb84ff079..968643af218 100644 --- a/src/zenml/utils/pydantic_utils.py +++ b/src/zenml/utils/pydantic_utils.py @@ -16,7 +16,8 @@ import inspect import json from json.decoder import JSONDecodeError -from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union, cast +from typing import Any, TypeVar, Union, cast +from collections.abc import Callable import yaml from pydantic import ( @@ -43,7 +44,7 @@ def update_model( original: M, - update: Union["BaseModel", Dict[str, Any]], + update: Union["BaseModel", dict[str, Any]], recursive: bool = True, exclude_none: bool = False, ) -> M: @@ -58,7 +59,7 @@ def update_model( Returns: The updated model. 
""" - if isinstance(update, Dict): + if isinstance(update, dict): if exclude_none: update_dict = dict_utils.remove_none_values( update, recursive=recursive @@ -83,7 +84,7 @@ class TemplateGenerator: """Class to generate templates for pydantic models or classes.""" def __init__( - self, instance_or_class: Union[BaseModel, Type[BaseModel]] + self, instance_or_class: BaseModel | type[BaseModel] ) -> None: """Initializes the template generator. @@ -93,7 +94,7 @@ def __init__( """ self.instance_or_class = instance_or_class - def run(self) -> Dict[str, Any]: + def run(self) -> dict[str, Any]: """Generates the template. Returns: @@ -111,9 +112,9 @@ def run(self) -> Dict[str, Any]: # Convert to json in an intermediate step, so we can leverage Pydantic's # encoder to support types like UUID and datetime json_string = json.dumps(template, default=pydantic_encoder) - return cast(Dict[str, Any], json.loads(json_string)) + return cast(dict[str, Any], json.loads(json_string)) - def _generate_template_for_model(self, model: BaseModel) -> Dict[str, Any]: + def _generate_template_for_model(self, model: BaseModel) -> dict[str, Any]: """Generates a template for a pydantic model. Args: @@ -132,8 +133,8 @@ def _generate_template_for_model(self, model: BaseModel) -> Dict[str, Any]: def _generate_template_for_model_class( self, - model_class: Type[BaseModel], - ) -> Dict[str, Any]: + model_class: type[BaseModel], + ) -> dict[str, Any]: """Generates a template for a pydantic model class. Args: @@ -142,7 +143,7 @@ def _generate_template_for_model_class( Returns: The model class template. """ - template: Dict[str, Any] = {} + template: dict[str, Any] = {} for name, field in model_class.model_fields.items(): annotation = field.annotation @@ -175,7 +176,7 @@ def _generate_template_for_value(self, value: Any) -> Any: Returns: The value template. 
""" - if isinstance(value, Dict): + if isinstance(value, dict): return { k: self._generate_template_for_value(v) for k, v in value.items() @@ -221,7 +222,7 @@ def yaml(self, sort_keys: bool = False, **kwargs: Any) -> str: return yaml.dump(dict_, sort_keys=sort_keys) @classmethod - def from_yaml(cls: Type[M], path: str) -> M: + def from_yaml(cls: type[M], path: str) -> M: """Creates an instance from a YAML file. Args: @@ -236,10 +237,10 @@ def from_yaml(cls: Type[M], path: str) -> M: def validate_function_args( __func: Callable[..., Any], - __config: Optional[ConfigDict], + __config: ConfigDict | None, *args: Any, **kwargs: Any, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Validates arguments passed to a function. This function validates that all arguments to call the function exist and @@ -262,7 +263,7 @@ def validate_function_args( validated_args = () validated_kwargs = {} - def f(*args: Any, **kwargs: Dict[Any, Any]) -> None: + def f(*args: Any, **kwargs: dict[Any, Any]) -> None: nonlocal validated_args nonlocal validated_kwargs @@ -286,9 +287,9 @@ def f(*args: Any, **kwargs: Dict[Any, Any]) -> None: def model_validator_data_handler( raw_data: Any, - base_class: Type[BaseModel], + base_class: type[BaseModel], validation_info: ValidationInfo, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Utility function to parse raw input data of varying types to a dict. With the change to pydantic v2, validators which operate with "before" @@ -412,7 +413,7 @@ def before_validator_handler( """ def before_validator( - cls: Type[BaseModel], data: Any, validation_info: ValidationInfo + cls: type[BaseModel], data: Any, validation_info: ValidationInfo ) -> Any: """Wrapper method to handle the raw data. 
@@ -433,8 +434,8 @@ def before_validator( def has_validators( - pydantic_class: Type[BaseModel], - field_name: Optional[str] = None, + pydantic_class: type[BaseModel], + field_name: str | None = None, ) -> bool: """Function to check if a Pydantic model or a pydantic field has validators. diff --git a/src/zenml/utils/requirements_utils.py b/src/zenml/utils/requirements_utils.py index b0facb4ad92..fc6a934dca7 100644 --- a/src/zenml/utils/requirements_utils.py +++ b/src/zenml/utils/requirements_utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Requirement utils.""" -from typing import TYPE_CHECKING, List, Optional, Set, Tuple +from typing import TYPE_CHECKING from zenml.integrations.utils import get_integration_for_module @@ -23,8 +23,8 @@ def get_requirements_for_stack( stack: "StackResponse", - python_version: Optional[str] = None, -) -> Tuple[List[str], List[str]]: + python_version: str | None = None, +) -> tuple[list[str], list[str]]: """Get requirements for a stack model. Args: @@ -34,8 +34,8 @@ def get_requirements_for_stack( Returns: Tuple of PyPI and APT requirements of the stack. """ - pypi_requirements: Set[str] = set() - apt_packages: Set[str] = set() + pypi_requirements: set[str] = set() + apt_packages: set[str] = set() for component_list in stack.components.values(): assert len(component_list) == 1 @@ -57,8 +57,8 @@ def get_requirements_for_stack( def get_requirements_for_component( component: "ComponentResponse", - python_version: Optional[str] = None, -) -> Tuple[List[str], List[str]]: + python_version: str | None = None, +) -> tuple[list[str], list[str]]: """Get requirements for a component model. Args: diff --git a/src/zenml/utils/run_utils.py b/src/zenml/utils/run_utils.py index 9f6894d9527..3a50e376e8d 100644 --- a/src/zenml/utils/run_utils.py +++ b/src/zenml/utils/run_utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Utility functions for runs.""" -from typing import TYPE_CHECKING, Dict, List, Optional, Set, cast +from typing import TYPE_CHECKING, Optional, cast from zenml.enums import ExecutionStatus, StackComponentType from zenml.exceptions import IllegalOperationError @@ -167,7 +167,7 @@ def refresh_run_status( return run -def build_dag(steps: Dict[str, List[str]]) -> Dict[str, Set[str]]: +def build_dag(steps: dict[str, list[str]]) -> dict[str, set[str]]: """Build DAG with downstream steps from a list of steps. Args: @@ -176,7 +176,7 @@ def build_dag(steps: Dict[str, List[str]]) -> Dict[str, Set[str]]: Returns: The DAG with downstream steps. """ - dag: Dict[str, Set[str]] = {step: set() for step in steps} + dag: dict[str, set[str]] = {step: set() for step in steps} for step_name, upstream_steps in steps.items(): for upstream_step in upstream_steps: @@ -186,8 +186,8 @@ def build_dag(steps: Dict[str, List[str]]) -> Dict[str, Set[str]]: def find_all_downstream_steps( - step_name: str, dag: Dict[str, Set[str]] -) -> Set[str]: + step_name: str, dag: dict[str, set[str]] +) -> set[str]: """Find all downstream steps of a given step. Args: diff --git a/src/zenml/utils/secret_utils.py b/src/zenml/utils/secret_utils.py index 6a6a8355c4a..f06487ce721 100644 --- a/src/zenml/utils/secret_utils.py +++ b/src/zenml/utils/secret_utils.py @@ -14,10 +14,10 @@ """Utility functions for secrets and secret references.""" import re -from typing import TYPE_CHECKING, Any, List, NamedTuple, Union +from typing import TYPE_CHECKING, Any, NamedTuple, Union from pydantic import Field, PlainSerializer, SecretStr -from typing_extensions import Annotated +from typing import Annotated from zenml.logger import get_logger @@ -187,8 +187,8 @@ def is_clear_text_field(field: "FieldInfo") -> bool: def resolve_and_verify_secrets( - secrets: List[Union[str, "UUID"]], -) -> List["UUID"]: + secrets: list[Union[str, "UUID"]], +) -> list["UUID"]: """Convert a list of secret names or IDs to a list of secret IDs. 
Args: diff --git a/src/zenml/utils/settings_utils.py b/src/zenml/utils/settings_utils.py index ece67b769e3..fab49615261 100644 --- a/src/zenml/utils/settings_utils.py +++ b/src/zenml/utils/settings_utils.py @@ -14,7 +14,8 @@ """Utility functions for ZenML settings.""" import re -from typing import TYPE_CHECKING, Dict, Sequence, Type +from typing import TYPE_CHECKING +from collections.abc import Sequence from zenml.config.constants import ( DEPLOYMENT_SETTINGS_KEY, @@ -124,7 +125,7 @@ def get_stack_component_for_settings_key( return stack_component -def get_general_settings() -> Dict[str, Type["BaseSettings"]]: +def get_general_settings() -> dict[str, type["BaseSettings"]]: """Returns all general settings. Returns: diff --git a/src/zenml/utils/source_code_utils.py b/src/zenml/utils/source_code_utils.py index 199109a5b71..ac8f58fc8c4 100644 --- a/src/zenml/utils/source_code_utils.py +++ b/src/zenml/utils/source_code_utils.py @@ -26,10 +26,8 @@ ) from typing import ( Any, - Callable, - Type, - Union, ) +from collections.abc import Callable from zenml.environment import Environment @@ -53,16 +51,16 @@ def _new_getfile( object: Any, _old_getfile: Callable[ [ - Union[ - ModuleType, - Type[Any], - MethodType, - FunctionType, - TracebackType, - FrameType, - CodeType, - Callable[..., Any], - ] + ( + ModuleType | + type[Any] | + MethodType | + FunctionType | + TracebackType | + FrameType | + CodeType | + Callable[..., Any] + ) ], str, ] = inspect.getfile, diff --git a/src/zenml/utils/source_utils.py b/src/zenml/utils/source_utils.py index 8d982aa72fa..998f334742b 100644 --- a/src/zenml/utils/source_utils.py +++ b/src/zenml/utils/source_utils.py @@ -24,13 +24,8 @@ from types import BuiltinFunctionType, FunctionType, ModuleType from typing import ( Any, - Callable, - Dict, - Iterator, - Optional, - Type, - Union, ) +from collections.abc import Callable, Iterator from uuid import UUID from zenml.config.source import ( @@ -65,16 +60,16 @@ ) -_CUSTOM_SOURCE_ROOT: Optional[str] = 
os.getenv( +_CUSTOM_SOURCE_ROOT: str | None = os.getenv( ENV_ZENML_CUSTOM_SOURCE_ROOT, None ) -_SHARED_TEMPDIR: Optional[str] = None -_resolved_notebook_sources: Dict[str, str] = {} -_notebook_modules: Dict[str, UUID] = {} +_SHARED_TEMPDIR: str | None = None +_resolved_notebook_sources: dict[str, str] = {} +_notebook_modules: dict[str, UUID] = {} -def load(source: Union[Source, str]) -> Any: +def load(source: Source | str) -> Any: """Load a source or import path. Args: @@ -150,14 +145,14 @@ def load(source: Union[Source, str]) -> Any: def resolve( - obj: Union[ - Type[Any], - Callable[..., Any], - ModuleType, - FunctionType, - BuiltinFunctionType, - NoneType, - ], + obj: ( + type[Any] | + Callable[..., Any] | + ModuleType | + FunctionType | + BuiltinFunctionType | + NoneType + ), skip_validation: bool = False, ) -> Source: """Resolve an object. @@ -334,7 +329,7 @@ def get_source_root() -> str: return implicit_source_root -def set_custom_source_root(source_root: Optional[str]) -> None: +def set_custom_source_root(source_root: str | None) -> None: """Sets a custom source root. If set this has the highest priority and will always be used as the source @@ -576,7 +571,7 @@ def _resolve_module(module: ModuleType) -> str: def _load_module( - module_name: str, import_root: Optional[str] = None + module_name: str, import_root: str | None = None ) -> ModuleType: """Load a module. @@ -703,7 +698,7 @@ def _try_to_load_notebook_source(source: NotebookSource) -> Any: return obj -def _get_package_for_module(module_name: str) -> Optional[str]: +def _get_package_for_module(module_name: str) -> str | None: """Get the package name for a module. Args: @@ -724,7 +719,7 @@ def _get_package_for_module(module_name: str) -> Optional[str]: return None -def _get_package_version(package_name: str) -> Optional[str]: +def _get_package_version(package_name: str) -> str | None: """Gets the version of a package. 
Args: @@ -746,8 +741,8 @@ def _get_package_version(package_name: str) -> Optional[str]: # currently doesn't support this for abstract classes: # https://github.com/python/mypy/issues/4717 def load_and_validate_class( - source: Union[str, Source], expected_class: Type[Any] -) -> Type[Any]: + source: str | Source, expected_class: type[Any] +) -> type[Any]: """Loads a source class and validates its class. Args: @@ -772,7 +767,7 @@ def load_and_validate_class( def validate_source_class( - source: Union[Source, str], expected_class: Type[Any] + source: Source | str, expected_class: type[Any] ) -> bool: """Validates that a source resolves to a certain class. @@ -794,7 +789,7 @@ def validate_source_class( return False -def get_resolved_notebook_sources() -> Dict[str, str]: +def get_resolved_notebook_sources() -> dict[str, str]: """Get all notebook sources that were resolved in this process. Returns: diff --git a/src/zenml/utils/string_utils.py b/src/zenml/utils/string_utils.py index a57d4dbd4f9..01a7f77034f 100644 --- a/src/zenml/utils/string_utils.py +++ b/src/zenml/utils/string_utils.py @@ -17,7 +17,8 @@ import functools import random import string -from typing import Any, Callable, Dict, Optional, TypeVar, cast +from typing import Any, TypeVar, cast +from collections.abc import Callable from pydantic import BaseModel @@ -148,7 +149,7 @@ def validate_name(model: BaseModel, name_field: str = "name") -> None: def format_name_template( name_template: str, - substitutions: Optional[Dict[str, str]] = None, + substitutions: dict[str, str] | None = None, ) -> str: """Formats a name template with the given arguments. 
@@ -233,7 +234,7 @@ def substitute_string(value: V, substitution_func: Callable[[str], str]) -> V: model_values[k] = new_value return cast(V, type(value).model_validate(model_values)) # type: ignore[redundant-cast] - elif isinstance(value, Dict): + elif isinstance(value, dict): return cast( V, {substitute_(k): substitute_(v) for k, v in value.items()} ) diff --git a/src/zenml/utils/tag_utils.py b/src/zenml/utils/tag_utils.py index 9a25619006f..acf14a27e43 100644 --- a/src/zenml/utils/tag_utils.py +++ b/src/zenml/utils/tag_utils.py @@ -15,10 +15,7 @@ from typing import ( TYPE_CHECKING, - List, - Optional, TypeVar, - Union, overload, ) from uuid import UUID @@ -104,9 +101,9 @@ class Tag(BaseModel): """A model representing a tag.""" name: str - color: Optional[ColorVariants] = None - exclusive: Optional[bool] = None - cascade: Optional[bool] = None + color: ColorVariants | None = None + exclusive: bool | None = None + cascade: bool | None = None def to_request(self) -> "TagRequest": """Convert the tag to a TagRequest. @@ -127,30 +124,30 @@ def to_request(self) -> "TagRequest": @overload def add_tags( - tags: Union[str, Tag, List[Union[str, Tag]]], + tags: str | Tag | list[str | Tag], ) -> None: ... @overload def add_tags( *, - tags: Union[str, Tag, List[Union[str, Tag]]], - run: Union[UUID, str], + tags: str | Tag | list[str | Tag], + run: UUID | str, ) -> None: ... @overload def add_tags( *, - tags: Union[str, Tag, List[Union[str, Tag]]], - artifact: Union[UUID, str], + tags: str | Tag | list[str | Tag], + artifact: UUID | str, ) -> None: ... @overload def add_tags( *, - tags: Union[str, Tag, List[Union[str, Tag]]], + tags: str | Tag | list[str | Tag], artifact_version_id: UUID, ) -> None: ... @@ -158,7 +155,7 @@ def add_tags( @overload def add_tags( *, - tags: Union[str, Tag, List[Union[str, Tag]]], + tags: str | Tag | list[str | Tag], artifact_name: str, artifact_version: str, ) -> None: ... 
@@ -167,63 +164,63 @@ def add_tags( @overload def add_tags( *, - tags: Union[str, Tag, List[Union[str, Tag]]], + tags: str | Tag | list[str | Tag], infer_artifact: bool = False, - artifact_name: Optional[str] = None, + artifact_name: str | None = None, ) -> None: ... @overload def add_tags( *, - tags: Union[str, Tag, List[Union[str, Tag]]], - pipeline: Union[UUID, str], + tags: str | Tag | list[str | Tag], + pipeline: UUID | str, ) -> None: ... @overload def add_tags( *, - tags: Union[str, Tag, List[Union[str, Tag]]], - run_template: Union[UUID, str], + tags: str | Tag | list[str | Tag], + run_template: UUID | str, ) -> None: ... @overload def add_tags( *, - tags: Union[str, Tag, List[Union[str, Tag]]], - snapshot: Union[UUID, str], + tags: str | Tag | list[str | Tag], + snapshot: UUID | str, ) -> None: ... @overload def add_tags( *, - tags: Union[str, Tag, List[Union[str, Tag]]], - deployment: Union[UUID, str], + tags: str | Tag | list[str | Tag], + deployment: UUID | str, ) -> None: ... 
def add_tags( - tags: Union[str, Tag, List[Union[str, Tag]]], + tags: str | Tag | list[str | Tag], # Pipelines - pipeline: Optional[Union[UUID, str]] = None, + pipeline: UUID | str | None = None, # Runs - run: Optional[Union[UUID, str]] = None, + run: UUID | str | None = None, # Run Templates - run_template: Optional[Union[UUID, str]] = None, + run_template: UUID | str | None = None, # Snapshots - snapshot: Optional[Union[UUID, str]] = None, + snapshot: UUID | str | None = None, # Deployments - deployment: Optional[Union[UUID, str]] = None, + deployment: UUID | str | None = None, # Artifacts - artifact: Optional[Union[UUID, str]] = None, + artifact: UUID | str | None = None, # Artifact Versions - artifact_version_id: Optional[UUID] = None, - artifact_name: Optional[str] = None, - artifact_version: Optional[str] = None, - infer_artifact: Optional[bool] = None, + artifact_version_id: UUID | None = None, + artifact_name: str | None = None, + artifact_version: str | None = None, + infer_artifact: bool | None = None, ) -> None: """Add tags to various resource types in a generalized way. @@ -522,62 +519,62 @@ def add_tags( @overload def remove_tags( - tags: Union[str, List[str]], + tags: str | list[str], ) -> None: ... @overload def remove_tags( *, - tags: Union[str, List[str]], - pipeline: Union[UUID, str], + tags: str | list[str], + pipeline: UUID | str, ) -> None: ... @overload def remove_tags( *, - tags: Union[str, List[str]], - run: Union[UUID, str], + tags: str | list[str], + run: UUID | str, ) -> None: ... @overload def remove_tags( *, - tags: Union[str, List[str]], - run_template: Union[UUID, str], + tags: str | list[str], + run_template: UUID | str, ) -> None: ... @overload def remove_tags( *, - tags: Union[str, List[str]], - snapshot: Union[UUID, str], + tags: str | list[str], + snapshot: UUID | str, ) -> None: ... 
@overload def remove_tags( *, - tags: Union[str, List[str]], - deployment: Union[UUID, str], + tags: str | list[str], + deployment: UUID | str, ) -> None: ... @overload def remove_tags( *, - tags: Union[str, List[str]], - artifact: Union[UUID, str], + tags: str | list[str], + artifact: UUID | str, ) -> None: ... @overload def remove_tags( *, - tags: Union[str, List[str]], + tags: str | list[str], artifact_version_id: UUID, ) -> None: ... @@ -585,7 +582,7 @@ def remove_tags( @overload def remove_tags( *, - tags: Union[str, List[str]], + tags: str | list[str], artifact_name: str, artifact_version: str, ) -> None: ... @@ -594,31 +591,31 @@ def remove_tags( @overload def remove_tags( *, - tags: Union[str, List[str]], + tags: str | list[str], infer_artifact: bool = False, - artifact_name: Optional[str] = None, + artifact_name: str | None = None, ) -> None: ... def remove_tags( - tags: Union[str, List[str]], + tags: str | list[str], # Pipelines - pipeline: Optional[Union[UUID, str]] = None, + pipeline: UUID | str | None = None, # Runs - run: Optional[Union[UUID, str]] = None, + run: UUID | str | None = None, # Run Templates - run_template: Optional[Union[UUID, str]] = None, + run_template: UUID | str | None = None, # Snapshots - snapshot: Optional[Union[UUID, str]] = None, + snapshot: UUID | str | None = None, # Deployments - deployment: Optional[Union[UUID, str]] = None, + deployment: UUID | str | None = None, # Artifacts - artifact: Optional[Union[UUID, str]] = None, + artifact: UUID | str | None = None, # Artifact Versions - artifact_version_id: Optional[UUID] = None, - artifact_name: Optional[str] = None, - artifact_version: Optional[str] = None, - infer_artifact: Optional[bool] = None, + artifact_version_id: UUID | None = None, + artifact_name: str | None = None, + artifact_version: str | None = None, + infer_artifact: bool | None = None, ) -> None: """Remove tags from various resource types in a generalized way. 
diff --git a/src/zenml/utils/time_utils.py b/src/zenml/utils/time_utils.py index a76baf087ac..0a66de3d020 100644 --- a/src/zenml/utils/time_utils.py +++ b/src/zenml/utils/time_utils.py @@ -14,10 +14,9 @@ """Time utils.""" from datetime import datetime, timedelta, timezone -from typing import Optional, Union -def utc_now(tz_aware: Union[bool, datetime] = False) -> datetime: +def utc_now(tz_aware: bool | datetime = False) -> datetime: """Get the current time in the UTC timezone. Args: @@ -116,7 +115,7 @@ def seconds_to_human_readable(time_seconds: int) -> str: def expires_in( expires_at: datetime, expired_str: str, - skew_tolerance: Optional[int] = None, + skew_tolerance: int | None = None, ) -> str: """Returns a human-readable string of the time until an expiration. diff --git a/src/zenml/utils/typed_model.py b/src/zenml/utils/typed_model.py index f37bc96d1e6..805b2e32b58 100644 --- a/src/zenml/utils/typed_model.py +++ b/src/zenml/utils/typed_model.py @@ -14,13 +14,13 @@ """Utility classes for adding type information to Pydantic models.""" import json -from typing import Any, Dict, Tuple, Type, cast +from typing import Any, cast from pydantic import BaseModel, Field # TODO: Investigate if we can solve this import a different way. from pydantic._internal._model_construction import ModelMetaclass -from typing_extensions import Literal +from typing import Literal from zenml.utils import source_utils @@ -29,7 +29,7 @@ class BaseTypedModelMeta(ModelMetaclass): """Metaclass responsible for adding type information to Pydantic models.""" def __new__( - mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any] + mcs, name: str, bases: tuple[type[Any], ...], dct: dict[str, Any] ) -> "BaseTypedModelMeta": """Creates a Pydantic BaseModel class. 
@@ -59,7 +59,7 @@ def __new__( dct.setdefault("__annotations__", dict())["type"] = type_ann dct["type"] = type cls = cast( - Type["BaseTypedModel"], super().__new__(mcs, name, bases, dct) + type["BaseTypedModel"], super().__new__(mcs, name, bases, dct) ) return cls @@ -104,7 +104,7 @@ class TheMatrix(BaseTypedModel): @classmethod def from_dict( cls, - model_dict: Dict[str, Any], + model_dict: dict[str, Any], ) -> "BaseTypedModel": """Instantiate a Pydantic model from a serialized JSON-able dict representation. diff --git a/src/zenml/utils/typing_utils.py b/src/zenml/utils/typing_utils.py index 09248bb898b..5adc58fd0d3 100644 --- a/src/zenml/utils/typing_utils.py +++ b/src/zenml/utils/typing_utils.py @@ -17,23 +17,22 @@ https://github.com/pydantic/pydantic/blob/v1.10.14/pydantic/typing.py """ -import sys import typing -from typing import Any, Optional, Set, Tuple, Type, Union, cast +from typing import Any, Union, cast from typing import get_args as _typing_get_args from typing import get_origin as _typing_get_origin -from typing_extensions import Annotated, Literal +from typing import Annotated, Literal # Annotated[...] is implemented by returning an instance of one of these # classes, depending on python/typing_extensions version. AnnotatedTypeNames = {"AnnotatedMeta", "_AnnotatedAlias"} # None types -NONE_TYPES: Tuple[Any, Any, Any] = (None, None.__class__, Literal[None]) +NONE_TYPES: tuple[Any, Any, Any] = (None, None.__class__, Literal[None]) # Literal types -LITERAL_TYPES: Set[Any] = {Literal} +LITERAL_TYPES: set[Any] = {Literal} if hasattr(typing, "Literal"): LITERAL_TYPES.add(typing.Literal) @@ -54,10 +53,9 @@ def is_none_type(type_: Any) -> bool: # ----- is_union ----- -if sys.version_info < (3, 10): - def is_union(type_: Optional[Type[Any]]) -> bool: - """Checks if the provided type is a union type. +def is_union(type_: type[Any] | None) -> bool: + """Checks if the provided type is a union type. Args: type_: type to check. 
@@ -65,29 +63,15 @@ def is_union(type_: Optional[Type[Any]]) -> bool: Returns: boolean indicating whether the type is union type. """ - return type_ is Union # type: ignore[comparison-overlap] + import types - -else: - - def is_union(type_: Optional[Type[Any]]) -> bool: - """Checks if the provided type is a union type. - - Args: - type_: type to check. - - Returns: - boolean indicating whether the type is union type. - """ - import types - - return type_ is Union or type_ is types.UnionType # type: ignore[comparison-overlap] + return type_ is Union or type_ is types.UnionType # type: ignore[comparison-overlap] # ----- literal ----- -def is_literal_type(type_: Type[Any]) -> bool: +def is_literal_type(type_: type[Any]) -> bool: """Checks if the provided type is a literal type. Args: @@ -99,7 +83,7 @@ def is_literal_type(type_: Type[Any]) -> bool: return Literal is not None and get_origin(type_) in LITERAL_TYPES -def literal_values(type_: Type[Any]) -> Tuple[Any, ...]: +def literal_values(type_: type[Any]) -> tuple[Any, ...]: """Fetches the literal values defined in a type. Args: @@ -111,7 +95,7 @@ def literal_values(type_: Type[Any]) -> Tuple[Any, ...]: return get_args(type_) -def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]: +def all_literal_values(type_: type[Any]) -> tuple[Any, ...]: """Fetches the literal values defined in a type in a recursive manner. This method is used to retrieve all Literal values as Literal can be @@ -134,7 +118,7 @@ def all_literal_values(type_: Type[Any]) -> Tuple[Any, ...]: # ----- get_origin ----- -def get_origin(tp: Type[Any]) -> Optional[Type[Any]]: +def get_origin(tp: type[Any]) -> type[Any] | None: """Fetches the origin of a given type. We can't directly use `typing.get_origin` since we need a fallback to @@ -149,14 +133,14 @@ def get_origin(tp: Type[Any]) -> Optional[Type[Any]]: the origin type of the provided type. 
""" if type(tp).__name__ in AnnotatedTypeNames: - return cast(Type[Any], Annotated) # mypy complains about _SpecialForm + return cast(type[Any], Annotated) # mypy complains about _SpecialForm return _typing_get_origin(tp) or getattr(tp, "__origin__", None) # ----- get_args ----- -def _generic_get_args(tp: Type[Any]) -> Tuple[Any, ...]: +def _generic_get_args(tp: type[Any]) -> tuple[Any, ...]: """Generic get args function. In python 3.9, `typing.Dict`, `typing.List`, ... @@ -176,7 +160,7 @@ def _generic_get_args(tp: Type[Any]) -> Tuple[Any, ...]: # Special case for `tuple[()]`, which used to return ((),) with # `typing.Tuple in python 3.10- but now returns () for `tuple` and `Tuple`. try: - if tp == Tuple[()] or tp == tuple[()]: # type: ignore[comparison-overlap] + if tp == tuple[()] or tp == tuple[()]: # type: ignore[comparison-overlap] return ((),) # there is a TypeError when compiled with cython except TypeError: # pragma: no cover @@ -184,7 +168,7 @@ def _generic_get_args(tp: Type[Any]) -> Tuple[Any, ...]: return () -def get_args(tp: Type[Any]) -> Tuple[Any, ...]: +def get_args(tp: type[Any]) -> tuple[Any, ...]: """Get type arguments with all substitutions performed. For unions, basic simplifications used by Union constructor are performed. @@ -211,7 +195,7 @@ def get_args(tp: Type[Any]) -> Tuple[Any, ...]: ) -def is_optional(tp: Type[Any]) -> bool: +def is_optional(tp: type[Any]) -> bool: """Checks whether a given annotation is typing.Optional. 
Args: diff --git a/src/zenml/utils/uuid_utils.py b/src/zenml/utils/uuid_utils.py index 5feaf519197..dfdc4e82173 100644 --- a/src/zenml/utils/uuid_utils.py +++ b/src/zenml/utils/uuid_utils.py @@ -14,7 +14,7 @@ """Utility functions for handling UUIDs.""" import hashlib -from typing import Any, Optional, Union +from typing import Any from uuid import UUID @@ -40,8 +40,8 @@ def is_valid_uuid(value: Any, version: int = 4) -> bool: def parse_name_or_uuid( - name_or_id: Optional[str], -) -> Optional[Union[str, UUID]]: + name_or_id: str | None, +) -> str | UUID | None: """Convert a "name or id" string value to a string or UUID. Args: diff --git a/src/zenml/utils/visualization_utils.py b/src/zenml/utils/visualization_utils.py index 699f1b53775..ec19d817039 100644 --- a/src/zenml/utils/visualization_utils.py +++ b/src/zenml/utils/visualization_utils.py @@ -14,7 +14,7 @@ """Utility functions for dashboard visualizations.""" import json -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from IPython.core.display_functions import display from IPython.display import HTML, JSON, Image, Markdown @@ -28,7 +28,7 @@ def visualize_artifact( - artifact: "ArtifactVersionResponse", title: Optional[str] = None + artifact: "ArtifactVersionResponse", title: str | None = None ) -> None: """Visualize an artifact in notebook environments. diff --git a/src/zenml/utils/yaml_utils.py b/src/zenml/utils/yaml_utils.py index b7e28b923ba..f9cbf6b7eac 100644 --- a/src/zenml/utils/yaml_utils.py +++ b/src/zenml/utils/yaml_utils.py @@ -16,7 +16,7 @@ import json import os from pathlib import Path -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any from uuid import UUID import yaml @@ -27,7 +27,7 @@ def write_yaml( file_path: str, - contents: Union[Dict[Any, Any], List[Any]], + contents: dict[Any, Any] | list[Any], sort_keys: bool = True, ) -> None: """Write contents as YAML format to file_path. 
@@ -51,7 +51,7 @@ def write_yaml( ) -def append_yaml(file_path: str, contents: Dict[Any, Any]) -> None: +def append_yaml(file_path: str, contents: dict[Any, Any]) -> None: """Append contents to a YAML file at file_path. Args: @@ -121,7 +121,7 @@ def comment_out_yaml(yaml_string: str) -> str: def write_json( file_path: str, contents: Any, - encoder: Optional[Type[json.JSONEncoder]] = None, + encoder: type[json.JSONEncoder] | None = None, **json_dump_args: Any, ) -> None: """Write contents as JSON format to file_path. diff --git a/src/zenml/zen_server/auth.py b/src/zenml/zen_server/auth.py index e58afe3bfc8..d5aa6cfcaea 100644 --- a/src/zenml/zen_server/auth.py +++ b/src/zenml/zen_server/auth.py @@ -16,7 +16,8 @@ import functools from datetime import datetime, timedelta from functools import wraps -from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union, cast +from typing import Any, cast +from collections.abc import Awaitable, Callable from urllib.parse import urlencode, urlparse from uuid import UUID, uuid4 @@ -92,14 +93,14 @@ class AuthContext(BaseModel): """The authentication context.""" user: UserResponse - access_token: Optional[JWTToken] = None - encoded_access_token: Optional[str] = None - device: Optional[OAuthDeviceInternalResponse] = None - api_key: Optional[APIKeyInternalResponse] = None + access_token: JWTToken | None = None + encoded_access_token: str | None = None + device: OAuthDeviceInternalResponse | None = None + api_key: APIKeyInternalResponse | None = None def _fetch_and_verify_api_key( - api_key_id: UUID, key_to_verify: Optional[str] = None + api_key_id: UUID, key_to_verify: str | None = None ) -> APIKeyInternalResponse: """Fetches an API key from the database and verifies it. 
@@ -170,11 +171,11 @@ def _fetch_and_verify_api_key( def authenticate_credentials( - user_name_or_id: Optional[Union[str, UUID]] = None, - password: Optional[str] = None, - access_token: Optional[str] = None, - csrf_token: Optional[str] = None, - activation_token: Optional[str] = None, + user_name_or_id: str | UUID | None = None, + password: str | None = None, + access_token: str | None = None, + csrf_token: str | None = None, + activation_token: str | None = None, ) -> AuthContext: """Verify if user authentication credentials are valid. @@ -201,8 +202,8 @@ def authenticate_credentials( Raises: CredentialsNotValid: If the credentials are invalid. """ - user: Optional[UserAuthModel] = None - auth_context: Optional[AuthContext] = None + user: UserAuthModel | None = None + auth_context: AuthContext | None = None if user_name_or_id: try: # NOTE: this method will not return a user if the user name or ID @@ -223,7 +224,6 @@ def authenticate_credentials( f"Authentication error: error retrieving account " f"{user_name_or_id}" ) - pass if password is not None: if not UserAuthModel.verify_password(password, user): @@ -290,14 +290,14 @@ def authenticate_credentials( logger.error(error) raise CredentialsNotValid(error) - api_key_model: Optional[APIKeyInternalResponse] = None + api_key_model: APIKeyInternalResponse | None = None if decoded_token.api_key_id: # The API token was generated from an API key. We still have to # verify if the API key hasn't been deactivated or deleted in the # meantime. api_key_model = _fetch_and_verify_api_key(decoded_token.api_key_id) - device_model: Optional[OAuthDeviceInternalResponse] = None + device_model: OAuthDeviceInternalResponse | None = None if decoded_token.device_id: if server_config().auth_scheme in [ AuthScheme.NO_AUTH, @@ -367,7 +367,7 @@ def authenticate_credentials( # queries. 
@cache_result(expiry=30) - def get_schedule_active(schedule_id: UUID) -> Optional[bool]: + def get_schedule_active(schedule_id: UUID) -> bool | None: """Get the active status of a schedule. Args: @@ -412,7 +412,7 @@ def get_schedule_active(schedule_id: UUID) -> Optional[bool]: @cache_result(expiry=30) def check_if_pipeline_run_in_progress( pipeline_run_id: UUID, - ) -> Tuple[Optional[bool], Optional[datetime]]: + ) -> tuple[bool | None, datetime | None]: """Get the status of a pipeline run. Args: @@ -685,7 +685,7 @@ def authenticate_external_user( "Error fetching user information from external authenticator." ) - external_user: Optional[ExternalUserModel] = None + external_user: ExternalUserModel | None = None if 200 <= auth_response.status_code < 300: try: @@ -706,7 +706,6 @@ def authenticate_external_user( f"Error parsing user information from external " f"authenticator: {e}" ) - pass elif auth_response.status_code in [401, 403]: raise AuthorizationException("Not authorized to access this server.") @@ -732,7 +731,7 @@ def authenticate_external_user( # Check if the external user already exists in the ZenML server database # If not, create a new user. If yes, update the existing user. 
- user: Optional[UserResponse] = None + user: UserResponse | None = None if not external_user.is_service_account: users = store.list_users( UserFilter( @@ -924,14 +923,14 @@ def authenticate_api_key( def generate_access_token( user_id: UUID, - response: Optional[Response] = None, - request: Optional[Request] = None, - device: Optional[OAuthDeviceInternalResponse] = None, - api_key: Optional[APIKeyInternalResponse] = None, - expires_in: Optional[int] = None, - schedule_id: Optional[UUID] = None, - pipeline_run_id: Optional[UUID] = None, - deployment_id: Optional[UUID] = None, + response: Response | None = None, + request: Request | None = None, + device: OAuthDeviceInternalResponse | None = None, + api_key: APIKeyInternalResponse | None = None, + expires_in: int | None = None, + schedule_id: UUID | None = None, + pipeline_run_id: UUID | None = None, + deployment_id: UUID | None = None, ) -> OAuthTokenResponse: """Generates an access token for the given user. @@ -960,7 +959,7 @@ def generate_access_token( # If the expiration time is not supplied, the JWT tokens are set to expire # according to the values configured in the server config. Device tokens are # handled separately from regular user tokens. - expires: Optional[datetime] = None + expires: datetime | None = None if expires_in == 0: expires_in = None elif expires_in is not None: @@ -985,7 +984,7 @@ def generate_access_token( if response and request: # Extract the origin domain from the request; use the referer as a # fallback - origin_domain: Optional[str] = None + origin_domain: str | None = None origin = request.headers.get("origin", request.headers.get("referer")) if origin: # If the request origin is known, we use it to determine whether @@ -993,7 +992,7 @@ def generate_access_token( # measures. 
origin_domain = urlparse(origin).netloc - server_domain: Optional[str] = config.auth_cookie_domain + server_domain: str | None = config.auth_cookie_domain # If the server's cookie domain is not explicitly set in the # server's configuration, we use other sources to determine it: # @@ -1012,8 +1011,8 @@ def generate_access_token( if origin_domain and server_domain: same_site = is_same_or_subdomain(origin_domain, server_domain) - csrf_token: Optional[str] = None - session_id: Optional[UUID] = None + csrf_token: str | None = None + session_id: UUID | None = None if not same_site: # If responding to a cross-site login request, we need to generate and # sign a CSRF token associated with the authentication session. @@ -1056,7 +1055,7 @@ def generate_access_token( def generate_download_token( download_type: DownloadType, resource_id: UUID, - extra_claims: Optional[Dict[str, Any]] = None, + extra_claims: dict[str, Any] | None = None, expires_in_seconds: int = 30, ) -> str: """Generate a JWT token for downloading content. @@ -1094,7 +1093,7 @@ def verify_download_token( token: str, download_type: DownloadType, resource_id: UUID, - extra_claims: Optional[Dict[str, Any]] = None, + extra_claims: dict[str, Any] | None = None, ) -> None: """Verify a JWT token for downloading content. @@ -1112,7 +1111,7 @@ def verify_download_token( config = server_config() try: claims = cast( - Dict[str, Any], + dict[str, Any], jwt.decode( token, config.jwt_secret_key, @@ -1163,7 +1162,7 @@ def http_authentication( class CookieOAuth2TokenBearer(OAuth2PasswordBearer): """OAuth2 token bearer authentication scheme that uses a cookie.""" - async def __call__(self, request: Request) -> Optional[str]: + async def __call__(self, request: Request) -> str | None: """Extract the bearer token from the request. 
Args: diff --git a/src/zenml/zen_server/cache.py b/src/zenml/zen_server/cache.py index f6463d0484e..6c22524f3a4 100644 --- a/src/zenml/zen_server/cache.py +++ b/src/zenml/zen_server/cache.py @@ -16,8 +16,9 @@ import time from collections import OrderedDict from threading import Lock -from typing import Any, Callable, Optional -from typing import OrderedDict as OrderedDictType +from typing import Any +from collections.abc import Callable +from collections import OrderedDict as OrderedDictType from uuid import UUID from zenml.logger import get_logger @@ -95,7 +96,7 @@ def __init__(self, max_capacity: int, default_expiry: int) -> None: self.default_expiry = default_expiry self._lock = Lock() - def set(self, key: UUID, value: Any, expiry: Optional[int] = None) -> None: + def set(self, key: UUID, value: Any, expiry: int | None = None) -> None: """Insert value into cache with optional custom expiry time in seconds. Args: @@ -109,7 +110,7 @@ def set(self, key: UUID, value: Any, expiry: Optional[int] = None) -> None: ) self._cleanup() - def get(self, key: UUID) -> Optional[Any]: + def get(self, key: UUID) -> Any | None: """Retrieve value if it's still valid; otherwise, return None. Args: @@ -121,7 +122,7 @@ def get(self, key: UUID) -> Optional[Any]: with self._lock: return self._get_internal(key) - def _get_internal(self, key: UUID) -> Optional[Any]: + def _get_internal(self, key: UUID) -> Any | None: """Helper to retrieve a value without lock (internal use only). Args: @@ -153,7 +154,7 @@ def _cleanup(self) -> None: def cache_result( - expiry: Optional[int] = None, + expiry: int | None = None, ) -> Callable[[F], F]: """A decorator to cache the result of a function based on a UUID key argument. 
diff --git a/src/zenml/zen_server/cloud_utils.py b/src/zenml/zen_server/cloud_utils.py index a9eb27ec0ed..ba7719acf9c 100644 --- a/src/zenml/zen_server/cloud_utils.py +++ b/src/zenml/zen_server/cloud_utils.py @@ -4,7 +4,7 @@ import time from datetime import datetime, timedelta from threading import RLock -from typing import Any, Dict, Optional +from typing import Any, Optional import requests from requests.adapters import HTTPAdapter, Retry @@ -33,16 +33,16 @@ class ZenMLCloudConnection: def __init__(self) -> None: """Initialize the RBAC component.""" self._config = ServerProConfiguration.get_server_config() - self._session: Optional[requests.Session] = None - self._token: Optional[str] = None - self._token_expires_at: Optional[datetime] = None + self._session: requests.Session | None = None + self._token: str | None = None + self._token_expires_at: datetime | None = None self._lock = RLock() def request( self, method: str, endpoint: str, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, data: Any = None, ) -> requests.Response: """Send a request using the active session. @@ -81,7 +81,7 @@ def request( ) start_time = time.time() - status_code: Optional[int] = None + status_code: int | None = None try: response = self.session.request( method=method, @@ -126,7 +126,7 @@ def request( return response def get( - self, endpoint: str, params: Optional[Dict[str, Any]] + self, endpoint: str, params: dict[str, Any] | None ) -> requests.Response: """Send a GET request using the active session. @@ -143,7 +143,7 @@ def get( def post( self, endpoint: str, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, data: Any = None, ) -> requests.Response: """Send a POST request using the active session. 
@@ -164,8 +164,8 @@ def post( def patch( self, endpoint: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, ) -> requests.Response: """Send a PATCH request using the active session. @@ -185,8 +185,8 @@ def patch( def delete( self, endpoint: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, + data: dict[str, Any] | None = None, ) -> requests.Response: """Send a DELETE request using the active session. diff --git a/src/zenml/zen_server/deploy/base_provider.py b/src/zenml/zen_server/deploy/base_provider.py index f7382c84cf3..8743f2446d2 100644 --- a/src/zenml/zen_server/deploy/base_provider.py +++ b/src/zenml/zen_server/deploy/base_provider.py @@ -14,7 +14,8 @@ """Base ZenML server provider class.""" from abc import ABC, abstractmethod -from typing import ClassVar, Generator, Optional, Tuple, Type +from typing import ClassVar +from collections.abc import Generator from pydantic import ValidationError @@ -48,7 +49,7 @@ class BaseServerProvider(ABC): """ TYPE: ClassVar[ServerProviderType] - CONFIG_TYPE: ClassVar[Type[LocalServerDeploymentConfig]] = ( + CONFIG_TYPE: ClassVar[type[LocalServerDeploymentConfig]] = ( LocalServerDeploymentConfig ) @@ -87,7 +88,7 @@ def _convert_config( def deploy_server( self, config: LocalServerDeploymentConfig, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> LocalServerDeployment: """Deploy a new ZenML server. @@ -121,7 +122,7 @@ def deploy_server( def update_server( self, config: LocalServerDeploymentConfig, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> LocalServerDeployment: """Update an existing ZenML server deployment. 
@@ -165,7 +166,7 @@ def update_server( def remove_server( self, config: LocalServerDeploymentConfig, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """Tears down and removes all resources and files associated with a ZenML server deployment. @@ -218,7 +219,7 @@ def get_server_logs( self, config: LocalServerDeploymentConfig, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> Generator[str, bool, None]: """Retrieve the logs of a ZenML server. @@ -255,7 +256,7 @@ def _get_deployment_status( The status of the server deployment. """ gc = GlobalConfiguration() - url: Optional[str] = None + url: str | None = None if service.is_running: # all services must have an endpoint assert service.endpoint is not None @@ -291,7 +292,7 @@ def _get_deployment(self, service: BaseService) -> LocalServerDeployment: def _get_service_configuration( cls, server_config: LocalServerDeploymentConfig, - ) -> Tuple[ + ) -> tuple[ ServiceConfig, ServiceEndpointConfig, ServiceEndpointHealthMonitorConfig, @@ -309,7 +310,7 @@ def _get_service_configuration( def _create_service( self, config: LocalServerDeploymentConfig, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> BaseService: """Create, start and return a service instance for a ZenML server deployment. @@ -328,7 +329,7 @@ def _update_service( self, service: BaseService, config: LocalServerDeploymentConfig, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> BaseService: """Update an existing service instance for a ZenML server deployment. @@ -347,7 +348,7 @@ def _update_service( def _start_service( self, service: BaseService, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> BaseService: """Start a service instance for a ZenML server deployment. 
@@ -365,7 +366,7 @@ def _start_service( def _stop_service( self, service: BaseService, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> BaseService: """Stop a service instance for a ZenML server deployment. @@ -383,7 +384,7 @@ def _stop_service( def _delete_service( self, service: BaseService, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """Remove a service instance for a ZenML server deployment. diff --git a/src/zenml/zen_server/deploy/daemon/daemon_provider.py b/src/zenml/zen_server/deploy/daemon/daemon_provider.py index 657936dce18..aa21af7f7e7 100644 --- a/src/zenml/zen_server/deploy/daemon/daemon_provider.py +++ b/src/zenml/zen_server/deploy/daemon/daemon_provider.py @@ -14,7 +14,7 @@ """Zen Server daemon provider implementation.""" import shutil -from typing import ClassVar, Optional, Tuple, Type, cast +from typing import ClassVar, cast from uuid import uuid4 from zenml import __version__ @@ -48,7 +48,7 @@ class DaemonServerProvider(BaseServerProvider): """Daemon ZenML server provider.""" TYPE: ClassVar[ServerProviderType] = ServerProviderType.DAEMON - CONFIG_TYPE: ClassVar[Type[LocalServerDeploymentConfig]] = ( + CONFIG_TYPE: ClassVar[type[LocalServerDeploymentConfig]] = ( DaemonServerDeploymentConfig ) @@ -79,7 +79,7 @@ def check_local_server_dependencies() -> None: def _get_service_configuration( cls, server_config: LocalServerDeploymentConfig, - ) -> Tuple[ + ) -> tuple[ ServiceConfig, ServiceEndpointConfig, ServiceEndpointHealthMonitorConfig, @@ -116,7 +116,7 @@ def _get_service_configuration( def _create_service( self, config: LocalServerDeploymentConfig, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> BaseService: """Create, start and return the local daemon ZenML server deployment service. 
@@ -166,7 +166,7 @@ def _update_service( self, service: BaseService, config: LocalServerDeploymentConfig, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> BaseService: """Update the local daemon ZenML server deployment service. @@ -207,7 +207,7 @@ def _update_service( def _start_service( self, service: BaseService, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> BaseService: """Start the local daemon ZenML server deployment service. @@ -228,7 +228,7 @@ def _start_service( def _stop_service( self, service: BaseService, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> BaseService: """Stop the local daemon ZenML server deployment service. @@ -249,7 +249,7 @@ def _stop_service( def _delete_service( self, service: BaseService, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """Remove the local daemon ZenML server deployment service. diff --git a/src/zenml/zen_server/deploy/daemon/daemon_zen_server.py b/src/zenml/zen_server/deploy/daemon/daemon_zen_server.py index 935bb8d5c88..75720faab8e 100644 --- a/src/zenml/zen_server/deploy/daemon/daemon_zen_server.py +++ b/src/zenml/zen_server/deploy/daemon/daemon_zen_server.py @@ -15,7 +15,7 @@ import ipaddress import os -from typing import Dict, List, Optional, Tuple, Union, cast +from typing import Optional, cast from pydantic import ConfigDict, Field @@ -65,15 +65,15 @@ class DaemonServerDeploymentConfig(LocalServerDeploymentConfig): """ port: int = 8237 - ip_address: Union[ipaddress.IPv4Address, ipaddress.IPv6Address] = Field( + ip_address: ipaddress.IPv4Address | ipaddress.IPv6Address = Field( default=ipaddress.IPv4Address(DEFAULT_LOCAL_SERVICE_IP_ADDRESS), union_mode="left_to_right", ) blocking: bool = False - store: Optional[StoreConfiguration] = None + store: StoreConfiguration | None = None @property - def url(self) -> Optional[str]: + def url(self) -> str | None: """Get the configured server URL. 
Returns: @@ -144,14 +144,14 @@ def get_service(cls) -> Optional["DaemonZenServer"]: """ config_filename = os.path.join(cls.config_path(), "service.json") try: - with open(config_filename, "r") as f: + with open(config_filename) as f: return cast( "DaemonZenServer", DaemonZenServer.from_json(f.read()) ) except FileNotFoundError: return None - def _get_daemon_cmd(self) -> Tuple[List[str], Dict[str, str]]: + def _get_daemon_cmd(self) -> tuple[list[str], dict[str, str]]: """Get the command to start the daemon. Overrides the base class implementation to add the environment variable diff --git a/src/zenml/zen_server/deploy/deployer.py b/src/zenml/zen_server/deploy/deployer.py index 8b6bb2e06fb..fd2934b1733 100644 --- a/src/zenml/zen_server/deploy/deployer.py +++ b/src/zenml/zen_server/deploy/deployer.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. """ZenML server deployer singleton implementation.""" -from typing import ClassVar, Dict, Generator, Optional, Type +from typing import ClassVar +from collections.abc import Generator from zenml.config.global_config import GlobalConfiguration from zenml.enums import ServerProviderType, StoreType @@ -44,10 +45,10 @@ class LocalServerDeployer(metaclass=SingletonMetaClass): server providers. """ - _providers: ClassVar[Dict[ServerProviderType, BaseServerProvider]] = {} + _providers: ClassVar[dict[ServerProviderType, BaseServerProvider]] = {} @classmethod - def register_provider(cls, provider: Type[BaseServerProvider]) -> None: + def register_provider(cls, provider: type[BaseServerProvider]) -> None: """Register a server provider. Args: @@ -93,7 +94,7 @@ def initialize_local_database(self) -> None: def deploy_server( self, config: LocalServerDeploymentConfig, - timeout: Optional[int] = None, + timeout: int | None = None, restart: bool = False, ) -> LocalServerDeployment: """Deploy the local ZenML server or update the existing deployment. 
@@ -131,7 +132,7 @@ def deploy_server( def update_server( self, config: LocalServerDeploymentConfig, - timeout: Optional[int] = None, + timeout: int | None = None, restart: bool = False, ) -> LocalServerDeployment: """Update an existing local ZenML server deployment. @@ -169,7 +170,7 @@ def update_server( def remove_server( self, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """Tears down and removes all resources and files associated with the local ZenML server deployment. @@ -312,7 +313,7 @@ def get_server( def get_server_logs( self, follow: bool = False, - tail: Optional[int] = None, + tail: int | None = None, ) -> Generator[str, bool, None]: """Retrieve the logs for the local ZenML server. diff --git a/src/zenml/zen_server/deploy/deployment.py b/src/zenml/zen_server/deploy/deployment.py index 8ce69223460..56fa33c4bb9 100644 --- a/src/zenml/zen_server/deploy/deployment.py +++ b/src/zenml/zen_server/deploy/deployment.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Zen Server deployment definitions.""" -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -33,7 +32,7 @@ class LocalServerDeploymentConfig(BaseModel): provider: ServerProviderType @property - def url(self) -> Optional[str]: + def url(self) -> str | None: """Get the configured server URL. 
Returns: @@ -73,10 +72,10 @@ class LocalServerDeploymentStatus(BaseModel): """ status: ServiceState - status_message: Optional[str] = None + status_message: str | None = None connected: bool - url: Optional[str] = None - ca_crt: Optional[str] = None + url: str | None = None + ca_crt: str | None = None class LocalServerDeployment(BaseModel): @@ -88,7 +87,7 @@ class LocalServerDeployment(BaseModel): """ config: LocalServerDeploymentConfig - status: Optional[LocalServerDeploymentStatus] = None + status: LocalServerDeploymentStatus | None = None @property def is_running(self) -> bool: diff --git a/src/zenml/zen_server/deploy/docker/docker_provider.py b/src/zenml/zen_server/deploy/docker/docker_provider.py index 9fe6df50054..3ac0ec872d0 100644 --- a/src/zenml/zen_server/deploy/docker/docker_provider.py +++ b/src/zenml/zen_server/deploy/docker/docker_provider.py @@ -15,7 +15,7 @@ import os import shutil -from typing import ClassVar, Optional, Tuple, Type, cast +from typing import ClassVar, cast from uuid import uuid4 from zenml.enums import ServerProviderType @@ -49,7 +49,7 @@ class DockerServerProvider(BaseServerProvider): """Docker ZenML server provider.""" TYPE: ClassVar[ServerProviderType] = ServerProviderType.DOCKER - CONFIG_TYPE: ClassVar[Type[LocalServerDeploymentConfig]] = ( + CONFIG_TYPE: ClassVar[type[LocalServerDeploymentConfig]] = ( DockerServerDeploymentConfig ) @@ -57,7 +57,7 @@ class DockerServerProvider(BaseServerProvider): def _get_service_configuration( cls, server_config: LocalServerDeploymentConfig, - ) -> Tuple[ + ) -> tuple[ ServiceConfig, ServiceEndpointConfig, ServiceEndpointHealthMonitorConfig, @@ -94,7 +94,7 @@ def _get_service_configuration( def _create_service( self, config: LocalServerDeploymentConfig, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> BaseService: """Create, start and return the docker ZenML server deployment service. 
@@ -145,7 +145,7 @@ def _update_service( self, service: BaseService, config: LocalServerDeploymentConfig, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> BaseService: """Update the docker ZenML server deployment service. @@ -187,7 +187,7 @@ def _update_service( def _start_service( self, service: BaseService, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> BaseService: """Start the docker ZenML server deployment service. @@ -208,7 +208,7 @@ def _start_service( def _stop_service( self, service: BaseService, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> BaseService: """Stop the docker ZenML server deployment service. @@ -229,7 +229,7 @@ def _stop_service( def _delete_service( self, service: BaseService, - timeout: Optional[int] = None, + timeout: int | None = None, ) -> None: """Remove the docker ZenML server deployment service. diff --git a/src/zenml/zen_server/deploy/docker/docker_zen_server.py b/src/zenml/zen_server/deploy/docker/docker_zen_server.py index 7fa7c4ff7d5..c159982c4a2 100644 --- a/src/zenml/zen_server/deploy/docker/docker_zen_server.py +++ b/src/zenml/zen_server/deploy/docker/docker_zen_server.py @@ -14,7 +14,7 @@ """Service implementation for the ZenML docker server deployment.""" import os -from typing import Dict, List, Optional, Tuple, cast +from typing import Optional, cast from pydantic import ConfigDict @@ -72,10 +72,10 @@ class DockerServerDeploymentConfig(LocalServerDeploymentConfig): port: int = 8238 image: str = DOCKER_ZENML_SERVER_DEFAULT_IMAGE - store: Optional[StoreConfiguration] = None + store: StoreConfiguration | None = None @property - def url(self) -> Optional[str]: + def url(self) -> str | None: """Get the configured server URL. 
Returns: @@ -148,14 +148,14 @@ def get_service(cls) -> Optional["DockerZenServer"]: """ config_filename = os.path.join(cls.config_path(), "service.json") try: - with open(config_filename, "r") as f: + with open(config_filename) as f: return cast( "DockerZenServer", DockerZenServer.from_json(f.read()) ) except FileNotFoundError: return None - def _get_container_cmd(self) -> Tuple[List[str], Dict[str, str]]: + def _get_container_cmd(self) -> tuple[list[str], dict[str, str]]: """Get the command to run the service container. Override the inherited method to use a ZenML global config path inside diff --git a/src/zenml/zen_server/exceptions.py b/src/zenml/zen_server/exceptions.py index 0454e3e6507..bb8017af100 100644 --- a/src/zenml/zen_server/exceptions.py +++ b/src/zenml/zen_server/exceptions.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """REST API exception handling.""" -from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Type +from typing import TYPE_CHECKING, Any import requests from pydantic import BaseModel @@ -39,7 +39,7 @@ class ErrorModel(BaseModel): """Base class for error responses.""" - detail: Optional[Any] = None + detail: Any | None = None error_response = dict(model=ErrorModel) @@ -67,7 +67,7 @@ class ErrorModel(BaseModel): # An exception may be associated with multiple status codes if the same # exception can be reconstructed from two or more HTTP error responses with # different status codes (e.g. `ValueError` and the 400 and 422 status codes). -REST_API_EXCEPTIONS: List[Tuple[Type[Exception], int]] = [ +REST_API_EXCEPTIONS: list[tuple[type[Exception], int]] = [ # 409 Conflict (EntityExistsError, 409), # 403 Forbidden @@ -99,8 +99,8 @@ class ErrorModel(BaseModel): def error_detail( - error: Exception, exception_type: Optional[Type[Exception]] = None -) -> List[str]: + error: Exception, exception_type: type[Exception] | None = None +) -> list[str]: """Convert an Exception to API representation. 
Args: @@ -141,7 +141,7 @@ def http_exception_from_error(error: Exception) -> "HTTPException": from fastapi import HTTPException status_code = 0 - matching_exception_type: Optional[Type[Exception]] = None + matching_exception_type: type[Exception] | None = None for exception_type, exc_status_code in REST_API_EXCEPTIONS: if error.__class__ is exception_type: @@ -179,7 +179,7 @@ def http_exception_from_error(error: Exception) -> "HTTPException": def exception_from_response( response: requests.Response, -) -> Optional[Exception]: +) -> Exception | None: """Convert an error HTTP response to an exception. Uses the REST_API_EXCEPTIONS list to determine the appropriate exception @@ -199,7 +199,7 @@ class to use based on the response status code and the exception class name into an exception. """ - def unpack_exc() -> Tuple[Optional[str], str]: + def unpack_exc() -> tuple[str | None, str]: """Unpack the response body into an exception name and message. Returns: @@ -233,7 +233,7 @@ def unpack_exc() -> Tuple[Optional[str], str]: return detail[0], message exc_name, exc_msg = unpack_exc() - default_exc: Optional[Type[Exception]] = None + default_exc: type[Exception] | None = None for exception, status_code in REST_API_EXCEPTIONS: if response.status_code != status_code: diff --git a/src/zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py b/src/zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py index f16a20af9c7..4b76d14727a 100644 --- a/src/zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py +++ b/src/zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """ZenML Pro implementation of the feature gate.""" -from typing import Any, Dict +from typing import Any from uuid import UUID from pydantic import BaseModel, Field @@ -52,7 +52,7 @@ class RawUsageEvent(BaseModel): total: int = Field( description="The total amount of entities of this type." 
) - metadata: Dict[str, Any] = Field( + metadata: dict[str, Any] = Field( default={}, description="Allows attaching additional metadata to events.", ) diff --git a/src/zenml/zen_server/jwt.py b/src/zenml/zen_server/jwt.py index a95fb9673b2..287fd3c48ed 100644 --- a/src/zenml/zen_server/jwt.py +++ b/src/zenml/zen_server/jwt.py @@ -16,8 +16,6 @@ from datetime import datetime, timedelta from typing import ( Any, - Dict, - Optional, cast, ) from uuid import UUID @@ -52,13 +50,13 @@ class JWTToken(BaseModel): """ user_id: UUID - device_id: Optional[UUID] = None - api_key_id: Optional[UUID] = None - schedule_id: Optional[UUID] = None - pipeline_run_id: Optional[UUID] = None - deployment_id: Optional[UUID] = None - session_id: Optional[UUID] = None - claims: Dict[str, Any] = {} + device_id: UUID | None = None + api_key_id: UUID | None = None + schedule_id: UUID | None = None + pipeline_run_id: UUID | None = None + deployment_id: UUID | None = None + session_id: UUID | None = None + claims: dict[str, Any] = {} @classmethod def decode_token( @@ -93,7 +91,7 @@ def decode_token( verify=verify, leeway=timedelta(seconds=config.jwt_token_leeway_seconds), ) - claims = cast(Dict[str, Any], claims_data) + claims = cast(dict[str, Any], claims_data) except jwt.PyJWTError as e: raise CredentialsNotValid(f"Invalid JWT token: {e}") from e @@ -110,7 +108,7 @@ def decode_token( "Invalid JWT token: the subject claim is not a valid UUID" ) - device_id: Optional[UUID] = None + device_id: UUID | None = None if "device_id" in claims: try: device_id = UUID(claims.pop("device_id")) @@ -120,7 +118,7 @@ def decode_token( "UUID" ) - api_key_id: Optional[UUID] = None + api_key_id: UUID | None = None if "api_key_id" in claims: try: api_key_id = UUID(claims.pop("api_key_id")) @@ -130,7 +128,7 @@ def decode_token( "UUID" ) - schedule_id: Optional[UUID] = None + schedule_id: UUID | None = None if "schedule_id" in claims: try: schedule_id = UUID(claims.pop("schedule_id")) @@ -140,7 +138,7 @@ def 
decode_token( "UUID" ) - pipeline_run_id: Optional[UUID] = None + pipeline_run_id: UUID | None = None if "pipeline_run_id" in claims: try: pipeline_run_id = UUID(claims.pop("pipeline_run_id")) @@ -150,7 +148,7 @@ def decode_token( "UUID" ) - deployment_id: Optional[UUID] = None + deployment_id: UUID | None = None if "deployment_id" in claims: try: deployment_id = UUID(claims.pop("deployment_id")) @@ -160,7 +158,7 @@ def decode_token( "UUID" ) - session_id: Optional[UUID] = None + session_id: UUID | None = None if "session_id" in claims: try: session_id = UUID(claims.pop("session_id")) @@ -181,7 +179,7 @@ def decode_token( claims=claims, ) - def encode(self, expires: Optional[datetime] = None) -> str: + def encode(self, expires: datetime | None = None) -> str: """Creates a JWT access token. Encodes, signs and returns a JWT access token. @@ -195,7 +193,7 @@ def encode(self, expires: Optional[datetime] = None) -> str: """ config = server_config() - claims: Dict[str, Any] = self.claims.copy() + claims: dict[str, Any] = self.claims.copy() claims["sub"] = str(self.user_id) claims["iss"] = config.get_jwt_token_issuer() diff --git a/src/zenml/zen_server/middleware.py b/src/zenml/zen_server/middleware.py index 7e61059d7e6..f9a6022113b 100644 --- a/src/zenml/zen_server/middleware.py +++ b/src/zenml/zen_server/middleware.py @@ -17,7 +17,7 @@ from asyncio import Lock from asyncio.log import logger from datetime import datetime, timedelta -from typing import Any, Set +from typing import Any from anyio import CapacityLimiter, to_thread from fastapi import FastAPI, Request @@ -111,7 +111,7 @@ async def dispatch( class RestrictFileUploadsMiddleware(BaseHTTPMiddleware): """Restrict file uploads to certain paths.""" - def __init__(self, app: ASGIApp, allowed_paths: Set[str]): + def __init__(self, app: ASGIApp, allowed_paths: set[str]): """Restrict file uploads to certain paths. 
Args: @@ -156,7 +156,7 @@ async def dispatch( ) -ALLOWED_FOR_FILE_UPLOAD: Set[str] = set() +ALLOWED_FOR_FILE_UPLOAD: set[str] = set() async def track_last_user_activity(request: Request, call_next: Any) -> Any: diff --git a/src/zenml/zen_server/pipeline_execution/runner_entrypoint_configuration.py b/src/zenml/zen_server/pipeline_execution/runner_entrypoint_configuration.py index 635c8b13e97..cec5efe643f 100644 --- a/src/zenml/zen_server/pipeline_execution/runner_entrypoint_configuration.py +++ b/src/zenml/zen_server/pipeline_execution/runner_entrypoint_configuration.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Runner entrypoint configuration.""" -from typing import Any, List, Set +from typing import Any from uuid import UUID from zenml.client import Client @@ -29,7 +29,7 @@ class RunnerEntrypointConfiguration(BaseEntrypointConfiguration): """Runner entrypoint configuration.""" @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> set[str]: """Gets all options required for running with this configuration. Returns: @@ -42,7 +42,7 @@ def get_entrypoint_options(cls) -> Set[str]: def get_entrypoint_arguments( cls, **kwargs: Any, - ) -> List[str]: + ) -> list[str]: """Gets all arguments that the entrypoint command should be called with. 
Args: diff --git a/src/zenml/zen_server/pipeline_execution/utils.py b/src/zenml/zen_server/pipeline_execution/utils.py index 12d613de954..c77b2812459 100644 --- a/src/zenml/zen_server/pipeline_execution/utils.py +++ b/src/zenml/zen_server/pipeline_execution/utils.py @@ -18,7 +18,8 @@ import sys import threading from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Callable, Dict, List, Optional +from typing import Any +from collections.abc import Callable from uuid import UUID from packaging import version @@ -143,7 +144,7 @@ def run_snapshot( auth_context: AuthContext, request: PipelineSnapshotRunRequest, sync: bool = False, - template_id: Optional[UUID] = None, + template_id: UUID | None = None, ) -> PipelineRunResponse: """Run a pipeline from a snapshot. @@ -417,8 +418,8 @@ def generate_image_hash(dockerfile: str) -> str: def generate_dockerfile( - pypi_requirements: List[str], - apt_packages: List[str], + pypi_requirements: list[str], + apt_packages: list[str], zenml_version: str, python_version: str, ) -> str: @@ -470,7 +471,7 @@ def generate_dockerfile( def snapshot_request_from_source_snapshot( source_snapshot: PipelineSnapshotResponse, config: PipelineRunConfiguration, - template_id: Optional[UUID] = None, + template_id: UUID | None = None, ) -> "PipelineSnapshotRequest": """Generate a snapshot request from a source snapshot. @@ -606,7 +607,7 @@ def get_pipeline_run_analytics_metadata( stack: StackResponse, source_snapshot_id: UUID, run_id: UUID, -) -> Dict[str, Any]: +) -> dict[str, Any]: """Get metadata for the pipeline run analytics event. 
Args: diff --git a/src/zenml/zen_server/pipeline_execution/workload_manager_interface.py b/src/zenml/zen_server/pipeline_execution/workload_manager_interface.py index 2eab1e47763..e295f388c68 100644 --- a/src/zenml/zen_server/pipeline_execution/workload_manager_interface.py +++ b/src/zenml/zen_server/pipeline_execution/workload_manager_interface.py @@ -14,7 +14,6 @@ """Workload manager interface definition.""" from abc import ABC, abstractmethod -from typing import Dict, List, Optional from uuid import UUID @@ -26,9 +25,9 @@ def run( self, workload_id: UUID, image: str, - command: List[str], - arguments: List[str], - environment: Optional[Dict[str, str]] = None, + command: list[str], + arguments: list[str], + environment: dict[str, str] | None = None, sync: bool = True, timeout_in_seconds: int = 0, ) -> None: @@ -46,7 +45,6 @@ def run( the container. If set to 0 the container will run until it fails or finishes. """ - pass @abstractmethod def build_and_push_image( @@ -71,7 +69,6 @@ def build_and_push_image( Returns: The full image name including container registry. """ - pass @abstractmethod def delete_workload(self, workload_id: UUID) -> None: @@ -80,7 +77,6 @@ def delete_workload(self, workload_id: UUID) -> None: Args: workload_id: Workload ID. """ - pass @abstractmethod def get_logs(self, workload_id: UUID) -> str: @@ -92,7 +88,6 @@ def get_logs(self, workload_id: UUID) -> str: Returns: The stored logs. """ - pass @abstractmethod def log(self, workload_id: UUID, message: str) -> None: @@ -102,4 +97,3 @@ def log(self, workload_id: UUID, message: str) -> None: workload_id: Workload ID. message: The message to log. 
""" - pass diff --git a/src/zenml/zen_server/rate_limit.py b/src/zenml/zen_server/rate_limit.py index afef9530015..2f5e61dc0ed 100644 --- a/src/zenml/zen_server/rate_limit.py +++ b/src/zenml/zen_server/rate_limit.py @@ -20,14 +20,10 @@ from functools import wraps from typing import ( Any, - Callable, - Dict, - Generator, - List, - Optional, TypeVar, cast, ) +from collections.abc import Callable, Generator from starlette.requests import Request @@ -43,8 +39,8 @@ class RequestLimiter: def __init__( self, - day_limit: Optional[int] = None, - minute_limit: Optional[int] = None, + day_limit: int | None = None, + minute_limit: int | None = None, ): """Initializes the limiter. @@ -62,7 +58,7 @@ def __init__( raise ValueError("Pass either day or minuter limits, or both.") self.day_limit = day_limit self.minute_limit = minute_limit - self.limiter: Dict[str, List[float]] = defaultdict(list) + self.limiter: dict[str, list[float]] = defaultdict(list) def hit_limiter(self, request: Request) -> None: """Increase the number of hits in the limiter. @@ -156,8 +152,8 @@ def limit_failed_requests( def rate_limit_requests( - day_limit: Optional[int] = None, - minute_limit: Optional[int] = None, + day_limit: int | None = None, + minute_limit: int | None = None, ) -> Callable[..., Any]: """Decorator to handle exceptions in the API. diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py index ebd7d9d80a1..66576fd435d 100644 --- a/src/zenml/zen_server/rbac/endpoint_utils.py +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. 
"""High-level helper functions to write endpoints with RBAC.""" -from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union +from typing import Any, TypeVar, Union +from collections.abc import Callable from uuid import UUID from zenml.models import ( @@ -58,7 +59,7 @@ def verify_permissions_and_create_entity( request_model: AnyRequest, create_method: Callable[[AnyRequest], AnyResponse], - surrogate_models: Optional[List[AnyOtherResponse]] = None, + surrogate_models: list[AnyOtherResponse] | None = None, skip_entitlements: bool = False, ) -> AnyResponse: """Verify permissions and create the entity if authorized. @@ -110,9 +111,9 @@ def verify_permissions_and_create_entity( def verify_permissions_and_batch_create_entity( - batch: List[AnyRequest], - create_method: Callable[[List[AnyRequest]], List[AnyResponse]], -) -> List[AnyResponse]: + batch: list[AnyRequest], + create_method: Callable[[list[AnyRequest]], list[AnyResponse]], +) -> list[AnyResponse]: """Verify permissions and create a batch of entities if authorized. Args: @@ -156,9 +157,9 @@ def verify_permissions_and_batch_create_entity( def verify_permissions_and_get_or_create_entity( request_model: AnyRequest, get_or_create_method: Callable[ - [AnyRequest, Optional[Callable[[], None]]], Tuple[AnyResponse, bool] + [AnyRequest, Callable[[], None] | None], tuple[AnyResponse, bool] ], -) -> Tuple[AnyResponse, bool]: +) -> tuple[AnyResponse, bool]: """Verify permissions and create the entity if authorized. Args: @@ -240,7 +241,7 @@ def verify_permissions_and_list_entities( auth_context = get_auth_context() assert auth_context - project_id: Optional[UUID] = None + project_id: UUID | None = None if isinstance(filter_model, ProjectScopedFilter): # A project scoped filter must always be scoped to a specific # project. This is required for the RBAC check to work. 
@@ -320,7 +321,7 @@ def verify_permissions_and_delete_entity( def verify_permissions_and_prune_entities( resource_type: ResourceType, prune_method: Callable[..., None], - project_id: Optional[UUID] = None, + project_id: UUID | None = None, **kwargs: Any, ) -> None: """Verify permissions and prune entities of certain type. diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index bc8984e403b..909ea4e6d09 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """RBAC model classes.""" -from typing import Optional from uuid import UUID from pydantic import ( @@ -104,8 +103,8 @@ class Resource(BaseModel): """RBAC resource model.""" type: str - id: Optional[UUID] = None - project_id: Optional[UUID] = None + id: UUID | None = None + project_id: UUID | None = None def __str__(self) -> str: """Convert to a string. @@ -135,7 +134,7 @@ def parse(cls, resource: str) -> "Resource": Returns: The converted resource. 
""" - project_id: Optional[str] = None + project_id: str | None = None if ":" in resource: ( project_id, @@ -145,7 +144,7 @@ def parse(cls, resource: str) -> "Resource": project_id = None resource_type_and_id = resource - resource_id: Optional[str] = None + resource_id: str | None = None if "/" in resource_type_and_id: resource_type, resource_id = resource_type_and_id.split("/") else: diff --git a/src/zenml/zen_server/rbac/rbac_interface.py b/src/zenml/zen_server/rbac/rbac_interface.py index 9ae627ee90f..4953d5f55be 100644 --- a/src/zenml/zen_server/rbac/rbac_interface.py +++ b/src/zenml/zen_server/rbac/rbac_interface.py @@ -14,7 +14,7 @@ """RBAC interface definition.""" from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING from zenml.zen_server.rbac.models import Action, Resource @@ -27,8 +27,8 @@ class RBACInterface(ABC): @abstractmethod def check_permissions( - self, user: "UserResponse", resources: Set[Resource], action: Action - ) -> Dict[Resource, bool]: + self, user: "UserResponse", resources: set[Resource], action: Action + ) -> dict[Resource, bool]: """Checks if a user has permissions to perform an action on resources. Args: @@ -44,7 +44,7 @@ def check_permissions( @abstractmethod def list_allowed_resource_ids( self, user: "UserResponse", resource: Resource, action: Action - ) -> Tuple[bool, List[str]]: + ) -> tuple[bool, list[str]]: """Lists all resource IDs of a resource type that a user can access. Args: @@ -66,9 +66,9 @@ def update_resource_membership( self, sharing_user: "UserResponse", resource: Resource, - actions: List[Action], - user_id: Optional[str] = None, - team_id: Optional[str] = None, + actions: list[Action], + user_id: str | None = None, + team_id: str | None = None, ) -> None: """Update the resource membership of a user. 
@@ -82,7 +82,7 @@ def update_resource_membership( """ @abstractmethod - def delete_resources(self, resources: List[Resource]) -> None: + def delete_resources(self, resources: list[Resource]) -> None: """Delete resource membership information for a list of resources. Args: diff --git a/src/zenml/zen_server/rbac/rbac_sql_zen_store.py b/src/zenml/zen_server/rbac/rbac_sql_zen_store.py index 37e1cb3dfa9..a70687cc424 100644 --- a/src/zenml/zen_server/rbac/rbac_sql_zen_store.py +++ b/src/zenml/zen_server/rbac/rbac_sql_zen_store.py @@ -13,10 +13,6 @@ # permissions and limitations under the License. """RBAC SQL Zen Store implementation.""" -from typing import ( - Optional, - Tuple, -) from uuid import UUID from zenml.logger import get_logger @@ -45,7 +41,7 @@ class RBACSqlZenStore(SqlZenStore): def _get_or_create_model( self, model_request: ModelRequest - ) -> Tuple[bool, ModelResponse]: + ) -> tuple[bool, ModelResponse]: """Get or create a model. Args: @@ -102,8 +98,8 @@ def _get_or_create_model( def _get_model_version( self, model_id: UUID, - version_name: Optional[str] = None, - producer_run_id: Optional[UUID] = None, + version_name: str | None = None, + producer_run_id: UUID | None = None, ) -> ModelVersionResponse: """Get a model version. @@ -128,8 +124,8 @@ def _get_model_version( def _get_or_create_model_version( self, model_version_request: ModelVersionRequest, - producer_run_id: Optional[UUID] = None, - ) -> Tuple[bool, ModelVersionResponse]: + producer_run_id: UUID | None = None, + ) -> tuple[bool, ModelVersionResponse]: """Get or create a model version. 
Args: diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index d3db78fe409..471785f46f2 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -16,14 +16,9 @@ from typing import ( TYPE_CHECKING, Any, - Dict, - List, - Optional, - Sequence, - Set, - Type, TypeVar, ) +from collections.abc import Sequence from uuid import UUID from pydantic import BaseModel @@ -61,8 +56,8 @@ def dehydrate_page(page: Page[AnyResponse]) -> Page[AnyResponse]: def dehydrate_response_model_batch( - batch: List[AnyResponse], -) -> List[AnyResponse]: + batch: list[AnyResponse], +) -> list[AnyResponse]: """Dehydrate all items of a batch. Args: @@ -92,7 +87,7 @@ def dehydrate_response_model_batch( def dehydrate_response_model( - model: AnyModel, permissions: Optional[Dict[Resource, bool]] = None + model: AnyModel, permissions: dict[Resource, bool] | None = None ) -> AnyModel: """Dehydrate a model if necessary. @@ -134,7 +129,7 @@ def dehydrate_response_model( def _dehydrate_value( - value: Any, permissions: Optional[Dict[Resource, bool]] = None + value: Any, permissions: dict[Resource, bool] | None = None ) -> Any: """Helper function to recursive dehydrate any object. @@ -167,12 +162,12 @@ def _dehydrate_value( return dehydrate_page(page=value) elif isinstance(value, BaseModel): return dehydrate_response_model(value, permissions=permissions) - elif isinstance(value, Dict): + elif isinstance(value, dict): return { k: _dehydrate_value(v, permissions=permissions) for k, v in value.items() } - elif isinstance(value, (List, Set, tuple)): + elif isinstance(value, (list, set, tuple)): type_ = type(value) return type_( _dehydrate_value(v, permissions=permissions) for v in value @@ -260,7 +255,7 @@ def verify_permission_for_model(model: AnyModel, action: Action) -> None: def batch_verify_permissions( - resources: Set[Resource], + resources: set[Resource], action: Action, ) -> None: """Batch permission verification. 
@@ -302,8 +297,8 @@ def batch_verify_permissions( def verify_permission( resource_type: str, action: Action, - resource_id: Optional[UUID] = None, - project_id: Optional[UUID] = None, + resource_id: UUID | None = None, + project_id: UUID | None = None, ) -> None: """Verifies if a user has permission to perform an action on a resource. @@ -324,8 +319,8 @@ def verify_permission( def get_allowed_resource_ids( resource_type: str, action: Action = Action.READ, - project_id: Optional[UUID] = None, -) -> Optional[Set[UUID]]: + project_id: UUID | None = None, +) -> set[UUID] | None: """Get all resource IDs of a resource type that a user can access. Args: @@ -359,7 +354,7 @@ def get_allowed_resource_ids( return {UUID(id) for id in allowed_ids} -def get_resource_for_model(model: AnyModel) -> Optional[Resource]: +def get_resource_for_model(model: AnyModel) -> Resource | None: """Get the resource associated with a model object. Args: @@ -374,14 +369,14 @@ def get_resource_for_model(model: AnyModel) -> Optional[Resource]: # This model is not tied to any RBAC resource type return None - project_id: Optional[UUID] = None + project_id: UUID | None = None if isinstance(model, ProjectScopedResponse): project_id = model.project_id elif isinstance(model, ProjectScopedRequest): # A project scoped request is always scoped to a specific project project_id = model.project - resource_id: Optional[UUID] = None + resource_id: UUID | None = None if isinstance(model, BaseIdentifiedResponse): resource_id = model.id @@ -419,7 +414,7 @@ def get_surrogate_permission_model_for_model( def get_resource_type_for_model( model: AnyModel, -) -> Optional[ResourceType]: +) -> ResourceType | None: """Get the resource type associated with a model object. 
Args: @@ -483,7 +478,7 @@ def get_resource_type_for_model( TriggerResponse, ) - mapping: Dict[ + mapping: dict[ Any, ResourceType, ] = { @@ -569,7 +564,7 @@ def is_owned_by_authenticated_user(model: AnyModel) -> bool: def get_subresources_for_model( model: AnyModel, -) -> Set[Resource]: +) -> set[Resource]: """Get all sub-resources of a model which need permission verification. Args: @@ -598,7 +593,7 @@ def get_subresources_for_model( return resources -def _get_subresources_for_value(value: Any) -> Set[Resource]: +def _get_subresources_for_value(value: Any) -> set[Resource]: """Helper function to recursive retrieve resources of any object. Args: @@ -619,12 +614,12 @@ def _get_subresources_for_value(value: Any) -> Set[Resource]: return resources.union(get_subresources_for_model(value)) elif isinstance(value, BaseModel): return get_subresources_for_model(value) - elif isinstance(value, Dict): + elif isinstance(value, dict): resources_list = [ _get_subresources_for_value(v) for v in value.values() ] return set.union(*resources_list) if resources_list else set() - elif isinstance(value, (List, Set, tuple)): + elif isinstance(value, (list, set, tuple)): resources_list = [_get_subresources_for_value(v) for v in value] return set.union(*resources_list) if resources_list else set() else: @@ -633,7 +628,7 @@ def _get_subresources_for_value(value: Any) -> Set[Resource]: def get_schema_for_resource_type( resource_type: ResourceType, -) -> Type["BaseSchema"]: +) -> type["BaseSchema"]: """Get the database schema for a resource type. 
Args: @@ -670,7 +665,7 @@ def get_schema_for_resource_type( UserSchema, ) - mapping: Dict[ResourceType, Type["BaseSchema"]] = { + mapping: dict[ResourceType, type["BaseSchema"]] = { ResourceType.STACK: StackSchema, ResourceType.FLAVOR: FlavorSchema, ResourceType.STACK_COMPONENT: StackComponentSchema, @@ -706,9 +701,9 @@ def get_schema_for_resource_type( def update_resource_membership( sharing_user: "UserResponse", resource: Resource, - actions: List[Action], - user_id: Optional[str] = None, - team_id: Optional[str] = None, + actions: list[Action], + user_id: str | None = None, + team_id: str | None = None, ) -> None: """Update the resource membership of a user. @@ -741,7 +736,7 @@ def delete_model_resource(model: AnyModel) -> None: delete_model_resources(models=[model]) -def delete_model_resources(models: List[AnyModel]) -> None: +def delete_model_resources(models: list[AnyModel]) -> None: """Delete resource membership information for a list of models. Args: @@ -758,7 +753,7 @@ def delete_model_resources(models: List[AnyModel]) -> None: delete_resources(resources=list(resources)) -def delete_resources(resources: List[Resource]) -> None: +def delete_resources(resources: list[Resource]) -> None: """Delete resource membership information for a list of resources. Args: diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py index 5aaa6ac18fc..ff122cc483d 100644 --- a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Cloud RBAC implementation.""" -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING from zenml.zen_server.cloud_utils import cloud_connection from zenml.zen_server.rbac.models import Action, Resource @@ -37,8 +37,8 @@ def __init__(self) -> None: self._connection = cloud_connection() def check_permissions( - self, user: "UserResponse", resources: Set[Resource], action: Action - ) -> Dict[Resource, bool]: + self, user: "UserResponse", resources: set[Resource], action: Action + ) -> dict[Resource, bool]: """Checks if a user has permissions to perform an action on resources. Args: @@ -80,7 +80,7 @@ def check_permissions( def list_allowed_resource_ids( self, user: "UserResponse", resource: Resource, action: Action - ) -> Tuple[bool, List[str]]: + ) -> tuple[bool, list[str]]: """Lists all resource IDs of a resource type that a user can access. Args: @@ -117,7 +117,7 @@ def list_allowed_resource_ids( response_json = response.json() full_resource_access: bool = response_json["full_access"] - allowed_ids: List[str] = response_json["ids"] + allowed_ids: list[str] = response_json["ids"] return full_resource_access, allowed_ids @@ -125,9 +125,9 @@ def update_resource_membership( self, sharing_user: "UserResponse", resource: Resource, - actions: List[Action], - user_id: Optional[str] = None, - team_id: Optional[str] = None, + actions: list[Action], + user_id: str | None = None, + team_id: str | None = None, ) -> None: """Update the resource membership of a user. @@ -149,7 +149,7 @@ def update_resource_membership( } self._connection.post(endpoint=RESOURCE_MEMBERSHIP_ENDPOINT, data=data) - def delete_resources(self, resources: List[Resource]) -> None: + def delete_resources(self, resources: list[Resource]) -> None: """Delete resource membership information for a list of resources. 
Args: diff --git a/src/zenml/zen_server/request_management.py b/src/zenml/zen_server/request_management.py index 8a747f6935c..c2ecad2c4a7 100644 --- a/src/zenml/zen_server/request_management.py +++ b/src/zenml/zen_server/request_management.py @@ -17,7 +17,8 @@ import base64 import json from contextvars import ContextVar -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional +from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Callable from uuid import UUID, uuid4 from fastapi import Request, Response @@ -51,7 +52,7 @@ def __init__(self, request: Request) -> None: """ self.request = request self.request_id = request.headers.get("X-Request-ID", str(uuid4())[:8]) - self.transaction_id: Optional[UUID] = None + self.transaction_id: UUID | None = None transaction_id = request.headers.get("Idempotency-Key") if transaction_id: try: @@ -190,12 +191,12 @@ def __init__( deduplicate: Whether to deduplicate requests. """ self.deduplicate = deduplicate - self.transactions: Dict[UUID, RequestRecord] = dict() + self.transactions: dict[UUID, RequestRecord] = dict() self.lock = asyncio.Lock() self.transaction_ttl = transaction_ttl self.request_timeout = request_timeout - self.request_contexts: ContextVar[Optional[RequestContext]] = ( + self.request_contexts: ContextVar[RequestContext | None] = ( ContextVar("request_contexts", default=None) ) @@ -225,11 +226,9 @@ def current_request(self, request_context: RequestContext) -> None: async def startup(self) -> None: """Start the request manager.""" - pass async def shutdown(self) -> None: """Shutdown the request manager.""" - pass async def async_run_and_cache_result( self, @@ -361,7 +360,7 @@ def sync_run_and_cache_result(*args: Any, **kwargs: Any) -> Any: if deduplicate_request: assert transaction_id is not None cache_result = True - result_to_cache: Optional[bytes] = None + result_to_cache: bytes | None = None if result is not None: try: result_to_cache = base64.b64encode( @@ -443,7 +442,7 @@ def 
sync_run_and_cache_result(*args: Any, **kwargs: Any) -> Any: async def execute( self, func: Callable[..., Any], - deduplicate: Optional[bool], + deduplicate: bool | None, *args: Any, **kwargs: Any, ) -> Any: diff --git a/src/zenml/zen_server/routers/artifact_version_endpoints.py b/src/zenml/zen_server/routers/artifact_version_endpoints.py index 181f89408e3..df429309b9c 100644 --- a/src/zenml/zen_server/routers/artifact_version_endpoints.py +++ b/src/zenml/zen_server/routers/artifact_version_endpoints.py @@ -14,7 +14,6 @@ """Endpoint definitions for artifact versions.""" import os -from typing import List, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -153,9 +152,9 @@ def create_artifact_version( ) @async_fastapi_endpoint_wrapper def batch_create_artifact_version( - artifact_versions: List[ArtifactVersionRequest], + artifact_versions: list[ArtifactVersionRequest], _: AuthContext = Security(authorize), -) -> List[ArtifactVersionResponse]: +) -> list[ArtifactVersionResponse]: """Create a batch of artifact versions. Args: @@ -251,7 +250,7 @@ def delete_artifact_version( ) @async_fastapi_endpoint_wrapper def prune_artifact_versions( - project_name_or_id: Union[str, UUID], + project_name_or_id: str | UUID, only_versions: bool = True, _: AuthContext = Security(authorize), ) -> None: diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index fb3c60a2902..8a4483858b2 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Endpoint definitions for authentication (login).""" -from typing import Optional, Union +from typing import Union from urllib.parse import urlencode from uuid import UUID @@ -98,11 +98,11 @@ class OAuthLoginRequestForm: def __init__( self, - grant_type: Optional[str] = Form(None), - username: Optional[str] = Form(None), - password: Optional[str] = Form(None), - client_id: Optional[str] = Form(None), - device_code: Optional[str] = Form(None), + grant_type: str | None = Form(None), + username: str | None = Form(None), + password: str | None = Form(None), + client_id: str | None = Form(None), + device_code: str | None = Form(None), ): """Initializes the form. @@ -231,7 +231,7 @@ def token( request: Request, response: Response, auth_form_data: OAuthLoginRequestForm = Depends(), -) -> Union[OAuthTokenResponse, OAuthRedirectResponse]: +) -> OAuthTokenResponse | OAuthRedirectResponse: """OAuth2 token endpoint. Args: @@ -246,7 +246,7 @@ def token( ValueError: If the grant type is invalid. """ config = server_config() - cookie_response: Optional[Response] = response + cookie_response: Response | None = response if auth_form_data.grant_type == OAuthGrantTypes.OAUTH_PASSWORD: auth_context = authenticate_credentials( @@ -475,10 +475,10 @@ def device_authorization( @async_fastapi_endpoint_wrapper def api_token( token_type: APITokenType = APITokenType.GENERIC, - expires_in: Optional[int] = None, - schedule_id: Optional[UUID] = None, - pipeline_run_id: Optional[UUID] = None, - deployment_id: Optional[UUID] = None, + expires_in: int | None = None, + schedule_id: UUID | None = None, + pipeline_run_id: UUID | None = None, + deployment_id: UUID | None = None, auth_context: AuthContext = Security(authorize), ) -> str: """Generate an API token for the current user. @@ -588,7 +588,7 @@ def api_token( f"deployment {token.deployment_id}." 
) - project_id: Optional[UUID] = None + project_id: UUID | None = None if schedule_id: # The schedule must exist diff --git a/src/zenml/zen_server/routers/code_repositories_endpoints.py b/src/zenml/zen_server/routers/code_repositories_endpoints.py index 97462a08c48..1fe0c5652d5 100644 --- a/src/zenml/zen_server/routers/code_repositories_endpoints.py +++ b/src/zenml/zen_server/routers/code_repositories_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for code repositories.""" -from typing import Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -65,7 +64,7 @@ @async_fastapi_endpoint_wrapper def create_code_repository( code_repository: CodeRepositoryRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), ) -> CodeRepositoryResponse: """Creates a code repository. @@ -104,7 +103,7 @@ def list_code_repositories( filter_model: CodeRepositoryFilter = Depends( make_dependable(CodeRepositoryFilter) ), - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, hydrate: bool = False, _: AuthContext = Security(authorize), ) -> Page[CodeRepositoryResponse]: diff --git a/src/zenml/zen_server/routers/devices_endpoints.py b/src/zenml/zen_server/routers/devices_endpoints.py index 94d99bd36a2..cbb58598c74 100644 --- a/src/zenml/zen_server/routers/devices_endpoints.py +++ b/src/zenml/zen_server/routers/devices_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Endpoint definitions for code repositories.""" -from typing import Optional from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -89,7 +88,7 @@ def list_authorized_devices( @async_fastapi_endpoint_wrapper def get_authorization_device( device_id: UUID, - user_code: Optional[str] = None, + user_code: str | None = None, hydrate: bool = True, auth_context: AuthContext = Security(authorize), ) -> OAuthDeviceResponse: diff --git a/src/zenml/zen_server/routers/model_versions_endpoints.py b/src/zenml/zen_server/routers/model_versions_endpoints.py index 494156e911f..cb9d71390fb 100644 --- a/src/zenml/zen_server/routers/model_versions_endpoints.py +++ b/src/zenml/zen_server/routers/model_versions_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for models.""" -from typing import Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -92,8 +91,8 @@ @async_fastapi_endpoint_wrapper def create_model_version( model_version: ModelVersionRequest, - model_id: Optional[UUID] = None, - project_name_or_id: Optional[Union[str, UUID]] = None, + model_id: UUID | None = None, + project_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), ) -> ModelVersionResponse: """Creates a model version. 
@@ -132,7 +131,7 @@ def list_model_versions( model_version_filter_model: ModelVersionFilter = Depends( make_dependable(ModelVersionFilter) ), - model_name_or_id: Optional[Union[str, UUID]] = None, + model_name_or_id: str | UUID | None = None, hydrate: bool = False, auth_context: AuthContext = Security(authorize), ) -> Page[ModelVersionResponse]: @@ -334,7 +333,7 @@ def list_model_version_artifact_links( @async_fastapi_endpoint_wrapper def delete_model_version_artifact_link( model_version_id: UUID, - model_version_artifact_link_name_or_id: Union[str, UUID], + model_version_artifact_link_name_or_id: str | UUID, _: AuthContext = Security(authorize), ) -> None: """Deletes a model version to artifact link. @@ -457,7 +456,7 @@ def list_model_version_pipeline_run_links( @async_fastapi_endpoint_wrapper def delete_model_version_pipeline_run_link( model_version_id: UUID, - model_version_pipeline_run_link_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_name_or_id: str | UUID, _: AuthContext = Security(authorize), ) -> None: """Deletes a model version link. diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index 50fe7197186..5c870ba659e 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for models.""" -from typing import Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -75,7 +74,7 @@ @async_fastapi_endpoint_wrapper def create_model( model: ModelRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), ) -> ModelResponse: """Creates a model. 
diff --git a/src/zenml/zen_server/routers/pipeline_builds_endpoints.py b/src/zenml/zen_server/routers/pipeline_builds_endpoints.py index 81fb6aeb1dd..a5616c005a4 100644 --- a/src/zenml/zen_server/routers/pipeline_builds_endpoints.py +++ b/src/zenml/zen_server/routers/pipeline_builds_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for builds.""" -from typing import Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -63,7 +62,7 @@ @async_fastapi_endpoint_wrapper def create_build( build: PipelineBuildRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), ) -> PipelineBuildResponse: """Creates a build, optionally in a specific project. @@ -102,7 +101,7 @@ def list_builds( build_filter_model: PipelineBuildFilter = Depends( make_dependable(PipelineBuildFilter) ), - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, hydrate: bool = False, _: AuthContext = Security(authorize), ) -> Page[PipelineBuildResponse]: diff --git a/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py b/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py index d6333cbd559..1a91141239b 100644 --- a/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py +++ b/src/zenml/zen_server/routers/pipeline_deployments_endpoints.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Endpoint definitions for deployments.""" -from typing import Any, List, Optional, Union +from typing import Any from uuid import UUID from fastapi import APIRouter, Depends, Query, Request, Security @@ -98,7 +98,7 @@ def _should_remove_step_config_overrides( def create_deployment( request: Request, deployment: PipelineSnapshotRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), ) -> Any: """Creates a deployment. @@ -149,7 +149,7 @@ def list_deployments( deployment_filter_model: PipelineSnapshotFilter = Depends( make_dependable(PipelineSnapshotFilter) ), - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, hydrate: bool = False, _: AuthContext = Security(authorize), ) -> Any: @@ -202,7 +202,7 @@ def get_deployment( request: Request, deployment_id: UUID, hydrate: bool = True, - step_configuration_filter: Optional[List[str]] = Query(None), + step_configuration_filter: list[str] | None = Query(None), _: AuthContext = Security(authorize), ) -> Any: """Gets a specific deployment using its unique id. diff --git a/src/zenml/zen_server/routers/pipeline_snapshot_endpoints.py b/src/zenml/zen_server/routers/pipeline_snapshot_endpoints.py index 95fb462ca2b..68cc65214b4 100644 --- a/src/zenml/zen_server/routers/pipeline_snapshot_endpoints.py +++ b/src/zenml/zen_server/routers/pipeline_snapshot_endpoints.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""Endpoint definitions for pipeline snapshots.""" -from typing import Any, List, Optional, Union +from typing import Any from uuid import UUID from fastapi import APIRouter, Depends, Query, Security @@ -70,7 +70,7 @@ @async_fastapi_endpoint_wrapper def create_pipeline_snapshot( snapshot: PipelineSnapshotRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), ) -> PipelineSnapshotResponse: """Creates a snapshot. @@ -101,7 +101,7 @@ def list_pipeline_snapshots( snapshot_filter_model: PipelineSnapshotFilter = Depends( make_dependable(PipelineSnapshotFilter) ), - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, hydrate: bool = False, _: AuthContext = Security(authorize), ) -> Page[PipelineSnapshotResponse]: @@ -136,8 +136,8 @@ def list_pipeline_snapshots( def get_pipeline_snapshot( snapshot_id: UUID, hydrate: bool = True, - step_configuration_filter: Optional[List[str]] = Query(None), - include_config_schema: Optional[bool] = None, + step_configuration_filter: list[str] | None = Query(None), + include_config_schema: bool | None = None, _: AuthContext = Security(authorize), ) -> PipelineSnapshotResponse: """Gets a specific snapshot using its unique id. diff --git a/src/zenml/zen_server/routers/pipelines_endpoints.py b/src/zenml/zen_server/routers/pipelines_endpoints.py index d0eba672a51..36ffe06371e 100644 --- a/src/zenml/zen_server/routers/pipelines_endpoints.py +++ b/src/zenml/zen_server/routers/pipelines_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Endpoint definitions for pipelines.""" -from typing import Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -76,7 +75,7 @@ @async_fastapi_endpoint_wrapper def create_pipeline( pipeline: PipelineRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), ) -> PipelineResponse: """Creates a pipeline. @@ -124,7 +123,7 @@ def list_pipelines( pipeline_filter_model: PipelineFilter = Depends( make_dependable(PipelineFilter) ), - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, hydrate: bool = False, _: AuthContext = Security(authorize), ) -> Page[PipelineResponse]: diff --git a/src/zenml/zen_server/routers/projects_endpoints.py b/src/zenml/zen_server/routers/projects_endpoints.py index 738ff9f34f3..bc9aa56b7e3 100644 --- a/src/zenml/zen_server/routers/projects_endpoints.py +++ b/src/zenml/zen_server/routers/projects_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for projects.""" -from typing import Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -151,7 +150,7 @@ def create_project( ) @async_fastapi_endpoint_wrapper def get_project( - project_name_or_id: Union[str, UUID], + project_name_or_id: str | UUID, hydrate: bool = True, _: AuthContext = Security(authorize), ) -> ProjectResponse: @@ -221,7 +220,7 @@ def update_project( ) @async_fastapi_endpoint_wrapper def delete_project( - project_name_or_id: Union[str, UUID], + project_name_or_id: str | UUID, _: AuthContext = Security(authorize), ) -> None: """Deletes a project. @@ -251,7 +250,7 @@ def delete_project( ) @async_fastapi_endpoint_wrapper def get_project_statistics( - project_name_or_id: Union[str, UUID], + project_name_or_id: str | UUID, auth_context: AuthContext = Security(authorize), ) -> ProjectStatistics: """Gets statistics of a project. 
diff --git a/src/zenml/zen_server/routers/run_metadata_endpoints.py b/src/zenml/zen_server/routers/run_metadata_endpoints.py index 0744bca2ce6..3eee8a25997 100644 --- a/src/zenml/zen_server/routers/run_metadata_endpoints.py +++ b/src/zenml/zen_server/routers/run_metadata_endpoints.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Endpoint definitions for run metadata.""" -from typing import Any, List, Optional, Union +from typing import Any from uuid import UUID from fastapi import APIRouter, Security @@ -53,7 +53,7 @@ @async_fastapi_endpoint_wrapper def create_run_metadata( run_metadata: RunMetadataRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, auth_context: AuthContext = Security(authorize), ) -> None: """Creates run metadata. @@ -72,7 +72,7 @@ def create_run_metadata( run_metadata.user = auth_context.user.id - verify_models: List[Any] = [] + verify_models: list[Any] = [] for resource in run_metadata.resources: if resource.type == MetadataResourceTypes.PIPELINE_RUN: verify_models.append(zen_store().get_run(resource.id)) diff --git a/src/zenml/zen_server/routers/run_templates_endpoints.py b/src/zenml/zen_server/routers/run_templates_endpoints.py index 50eecbe26f1..59b09632851 100644 --- a/src/zenml/zen_server/routers/run_templates_endpoints.py +++ b/src/zenml/zen_server/routers/run_templates_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for run templates.""" -from typing import Optional, Union from uuid import UUID from fastapi import ( @@ -88,7 +87,7 @@ @async_fastapi_endpoint_wrapper def create_run_template( run_template: RunTemplateRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), ) -> RunTemplateResponse: """Create a run template. 
@@ -127,7 +126,7 @@ def list_run_templates( filter_model: RunTemplateFilter = Depends( make_dependable(RunTemplateFilter) ), - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, hydrate: bool = False, _: AuthContext = Security(authorize), ) -> Page[RunTemplateResponse]: @@ -243,7 +242,7 @@ def delete_run_template( @async_fastapi_endpoint_wrapper def create_template_run( template_id: UUID, - config: Optional[PipelineRunConfiguration] = None, + config: PipelineRunConfiguration | None = None, auth_context: AuthContext = Security(authorize), ) -> PipelineRunResponse: """Run a pipeline from a template. diff --git a/src/zenml/zen_server/routers/runs_endpoints.py b/src/zenml/zen_server/routers/runs_endpoints.py index 1707d8bd87d..e438effb58e 100644 --- a/src/zenml/zen_server/routers/runs_endpoints.py +++ b/src/zenml/zen_server/routers/runs_endpoints.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Endpoint definitions for pipeline runs.""" -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -99,9 +99,9 @@ @async_fastapi_endpoint_wrapper def get_or_create_pipeline_run( pipeline_run: PipelineRunRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), -) -> Tuple[PipelineRunResponse, bool]: +) -> tuple[PipelineRunResponse, bool]: """Get or create a pipeline run. 
Args: @@ -139,7 +139,7 @@ def list_runs( runs_filter_model: PipelineRunFilter = Depends( make_dependable(PipelineRunFilter) ), - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, hydrate: bool = False, include_full_metadata: bool = False, _: AuthContext = Security(authorize), @@ -310,7 +310,7 @@ def get_run_steps( def get_pipeline_configuration( run_id: UUID, _: AuthContext = Security(authorize), -) -> Dict[str, Any]: +) -> dict[str, Any]: """Get the pipeline configuration of a specific pipeline run using its ID. Args: @@ -440,7 +440,7 @@ def run_logs( run_id: UUID, source: str, _: AuthContext = Security(authorize), -) -> List[LogEntry]: +) -> list[LogEntry]: """Get log entries for efficient pagination. This endpoint returns the log entries. diff --git a/src/zenml/zen_server/routers/schedule_endpoints.py b/src/zenml/zen_server/routers/schedule_endpoints.py index 73458e83681..1aa894fdfa0 100644 --- a/src/zenml/zen_server/routers/schedule_endpoints.py +++ b/src/zenml/zen_server/routers/schedule_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for pipeline run schedules.""" -from typing import Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -65,7 +64,7 @@ @async_fastapi_endpoint_wrapper def create_schedule( schedule: ScheduleRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, auth_context: AuthContext = Security(authorize), ) -> ScheduleResponse: """Creates a schedule. 
@@ -105,7 +104,7 @@ def list_schedules( schedule_filter_model: ScheduleFilter = Depends( make_dependable(ScheduleFilter) ), - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, hydrate: bool = False, _: AuthContext = Security(authorize), ) -> Page[ScheduleResponse]: diff --git a/src/zenml/zen_server/routers/secrets_endpoints.py b/src/zenml/zen_server/routers/secrets_endpoints.py index 91f60095112..0ecc079f0c2 100644 --- a/src/zenml/zen_server/routers/secrets_endpoints.py +++ b/src/zenml/zen_server/routers/secrets_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for pipeline run secrets.""" -from typing import Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -84,7 +83,7 @@ @async_fastapi_endpoint_wrapper def create_secret( secret: SecretRequest, - workspace_name_or_id: Optional[Union[str, UUID]] = None, + workspace_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), ) -> SecretResponse: """Creates a secret. @@ -189,7 +188,7 @@ def get_secret( def update_secret( secret_id: UUID, secret_update: SecretUpdate, - patch_values: Optional[bool] = False, + patch_values: bool | None = False, _: AuthContext = Security(authorize), ) -> SecretResponse: """Updates the attribute on a specific secret using its unique id. diff --git a/src/zenml/zen_server/routers/server_endpoints.py b/src/zenml/zen_server/routers/server_endpoints.py index 4e2c300f617..22dac2942bb 100644 --- a/src/zenml/zen_server/routers/server_endpoints.py +++ b/src/zenml/zen_server/routers/server_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Endpoint definitions for authentication (login).""" -from typing import List, Optional from fastapi import APIRouter, Security @@ -139,7 +138,7 @@ def server_load_info(_: AuthContext = Security(authorize)) -> ServerLoadInfo: @async_fastapi_endpoint_wrapper def get_onboarding_state( _: AuthContext = Security(authorize), -) -> List[str]: +) -> list[str]: """Get the onboarding state of the server. Returns: @@ -234,7 +233,7 @@ def update_server_settings( @async_fastapi_endpoint_wrapper def activate_server( activate_request: ServerActivationRequest, - ) -> Optional[UserResponse]: + ) -> UserResponse | None: """Updates a stack. Args: diff --git a/src/zenml/zen_server/routers/service_accounts_endpoints.py b/src/zenml/zen_server/routers/service_accounts_endpoints.py index c3ed545b27d..b45157c84e7 100644 --- a/src/zenml/zen_server/routers/service_accounts_endpoints.py +++ b/src/zenml/zen_server/routers/service_accounts_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Endpoint definitions for API keys.""" -from typing import Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -120,7 +119,7 @@ def create_service_account( ) @async_fastapi_endpoint_wrapper def get_service_account( - service_account_name_or_id: Union[str, UUID], + service_account_name_or_id: str | UUID, _: AuthContext = Security(authorize), hydrate: bool = True, ) -> ServiceAccountResponse: @@ -182,7 +181,7 @@ def list_service_accounts( ) @async_fastapi_endpoint_wrapper def update_service_account( - service_account_name_or_id: Union[str, UUID], + service_account_name_or_id: str | UUID, service_account_update: ServiceAccountUpdate, _: AuthContext = Security(authorize), ) -> ServiceAccountResponse: @@ -222,7 +221,7 @@ def update_service_account( ) @async_fastapi_endpoint_wrapper def delete_service_account( - service_account_name_or_id: Union[str, UUID], + service_account_name_or_id: str | UUID, _: AuthContext = Security(authorize), ) -> None: """Delete a specific service account. 
@@ -311,7 +310,7 @@ def create_api_key_wrapper( @async_fastapi_endpoint_wrapper def get_api_key( service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, hydrate: bool = True, _: AuthContext = Security(authorize), ) -> APIKeyResponse: @@ -378,7 +377,7 @@ def list_api_keys( @async_fastapi_endpoint_wrapper def update_api_key( service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, api_key_update: APIKeyUpdate, _: AuthContext = Security(authorize), ) -> APIKeyResponse: @@ -424,7 +423,7 @@ def update_api_key( @async_fastapi_endpoint_wrapper def rotate_api_key( service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, rotate_request: APIKeyRotateRequest, _: AuthContext = Security(authorize), ) -> APIKeyResponse: @@ -467,7 +466,7 @@ def rotate_api_key( @async_fastapi_endpoint_wrapper def delete_api_key( service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, _: AuthContext = Security(authorize), ) -> None: """Deletes an API key. diff --git a/src/zenml/zen_server/routers/service_connectors_endpoints.py b/src/zenml/zen_server/routers/service_connectors_endpoints.py index 9aba2168e7d..bbd0bd652e8 100644 --- a/src/zenml/zen_server/routers/service_connectors_endpoints.py +++ b/src/zenml/zen_server/routers/service_connectors_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for service connectors.""" -from typing import List, Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -94,7 +93,7 @@ @async_fastapi_endpoint_wrapper def create_service_connector( connector: ServiceConnectorRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), ) -> ServiceConnectorResponse: """Creates a service connector. 
@@ -129,7 +128,7 @@ def list_service_connectors( connector_filter_model: ServiceConnectorFilter = Depends( make_dependable(ServiceConnectorFilter) ), - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, expand_secrets: bool = True, hydrate: bool = False, _: AuthContext = Security(authorize), @@ -195,9 +194,9 @@ def list_service_connector_resources( filter_model: ServiceConnectorFilter = Depends( make_dependable(ServiceConnectorFilter) ), - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, auth_context: AuthContext = Security(authorize), -) -> List[ServiceConnectorResourcesModel]: +) -> list[ServiceConnectorResourcesModel]: """List resources that can be accessed by service connectors. Args: @@ -345,8 +344,8 @@ def validate_and_verify_service_connector_config( @async_fastapi_endpoint_wrapper(deduplicate=True) def validate_and_verify_service_connector( connector_id: UUID, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, list_resources: bool = True, _: AuthContext = Security(authorize), ) -> ServiceConnectorResourcesModel: @@ -386,8 +385,8 @@ def validate_and_verify_service_connector( @async_fastapi_endpoint_wrapper(deduplicate=True) def get_service_connector_client( connector_id: UUID, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, _: AuthContext = Security(authorize), ) -> ServiceConnectorResponse: """Get a service connector client for a service connector and given resource. 
@@ -422,11 +421,11 @@ def get_service_connector_client( ) @async_fastapi_endpoint_wrapper def list_service_connector_types( - connector_type: Optional[str] = None, - resource_type: Optional[str] = None, - auth_method: Optional[str] = None, + connector_type: str | None = None, + resource_type: str | None = None, + auth_method: str | None = None, _: AuthContext = Security(authorize), -) -> List[ServiceConnectorTypeModel]: +) -> list[ServiceConnectorTypeModel]: """Get a list of service connector types. Args: @@ -472,8 +471,8 @@ def get_service_connector_type( ) @async_fastapi_endpoint_wrapper def get_resources_based_on_service_connector_info( - connector_info: Optional[ServiceConnectorInfo] = None, - connector_uuid: Optional[UUID] = None, + connector_info: ServiceConnectorInfo | None = None, + connector_uuid: UUID | None = None, _: AuthContext = Security(authorize), ) -> ServiceConnectorResourcesInfo: """Gets the list of resources that a service connector can access. diff --git a/src/zenml/zen_server/routers/service_endpoints.py b/src/zenml/zen_server/routers/service_endpoints.py index e5047e94277..7aa96a12ba7 100644 --- a/src/zenml/zen_server/routers/service_endpoints.py +++ b/src/zenml/zen_server/routers/service_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for services.""" -from typing import Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -65,7 +64,7 @@ @async_fastapi_endpoint_wrapper def create_service( service: ServiceRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), ) -> ServiceResponse: """Creates a new service. 
diff --git a/src/zenml/zen_server/routers/stack_components_endpoints.py b/src/zenml/zen_server/routers/stack_components_endpoints.py index 50d760885a4..baabb27f840 100644 --- a/src/zenml/zen_server/routers/stack_components_endpoints.py +++ b/src/zenml/zen_server/routers/stack_components_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for stack components.""" -from typing import List, Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -74,7 +73,7 @@ @async_fastapi_endpoint_wrapper def create_stack_component( component: ComponentRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, _: AuthContext = Security(authorize), ) -> ComponentResponse: """Creates a stack component. @@ -86,7 +85,7 @@ def create_stack_component( Returns: The created stack component. """ - rbac_read_checks: List[BaseModel] = [] + rbac_read_checks: list[BaseModel] = [] if component.connector: service_connector = zen_store().get_service_connector( component.connector @@ -142,7 +141,7 @@ def list_stack_components( component_filter_model: ComponentFilter = Depends( make_dependable(ComponentFilter) ), - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, hydrate: bool = False, _: AuthContext = Security(authorize), ) -> Page[ComponentResponse]: @@ -229,7 +228,7 @@ def update_stack_component( mode="json", exclude_unset=True ) - rbac_read_checks: List[BaseModel] = [] + rbac_read_checks: list[BaseModel] = [] if component_update.connector: service_connector = zen_store().get_service_connector( component_update.connector @@ -282,7 +281,7 @@ def deregister_stack_component( @async_fastapi_endpoint_wrapper def get_stack_component_types( _: AuthContext = Security(authorize), -) -> List[str]: +) -> list[str]: """Get a list of all stack component types. 
Returns: diff --git a/src/zenml/zen_server/routers/stack_deployment_endpoints.py b/src/zenml/zen_server/routers/stack_deployment_endpoints.py index 84dd40952bc..48e4ce4e3c0 100644 --- a/src/zenml/zen_server/routers/stack_deployment_endpoints.py +++ b/src/zenml/zen_server/routers/stack_deployment_endpoints.py @@ -14,7 +14,6 @@ """Endpoint definitions for stack deployments.""" import datetime -from typing import Optional from fastapi import APIRouter, Request, Security @@ -78,7 +77,7 @@ def get_stack_deployment_config( request: Request, provider: StackDeploymentProvider, stack_name: str, - location: Optional[str] = None, + location: str | None = None, terraform: bool = False, auth_context: AuthContext = Security(authorize), ) -> StackDeploymentConfig: @@ -143,11 +142,11 @@ def get_stack_deployment_config( def get_deployed_stack( provider: StackDeploymentProvider, stack_name: str, - location: Optional[str] = None, - date_start: Optional[datetime.datetime] = None, + location: str | None = None, + date_start: datetime.datetime | None = None, terraform: bool = False, _: AuthContext = Security(authorize), -) -> Optional[DeployedStack]: +) -> DeployedStack | None: """Return a matching ZenML stack that was deployed and registered. Args: diff --git a/src/zenml/zen_server/routers/stacks_endpoints.py b/src/zenml/zen_server/routers/stacks_endpoints.py index 01505c383ee..86d56141f36 100644 --- a/src/zenml/zen_server/routers/stacks_endpoints.py +++ b/src/zenml/zen_server/routers/stacks_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Endpoint definitions for stacks.""" -from typing import List, Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -70,7 +69,7 @@ @async_fastapi_endpoint_wrapper def create_stack( stack: StackRequest, - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, auth_context: AuthContext = Security(authorize), ) -> StackResponse: """Creates a stack. @@ -83,7 +82,7 @@ def create_stack( Returns: The created stack. """ - rbac_read_checks: List[BaseModel] = [] + rbac_read_checks: list[BaseModel] = [] # Check the service connector creation is_connector_create_needed = False @@ -147,7 +146,7 @@ def create_stack( ) @async_fastapi_endpoint_wrapper def list_stacks( - project_name_or_id: Optional[Union[str, UUID]] = None, + project_name_or_id: str | UUID | None = None, stack_filter_model: StackFilter = Depends(make_dependable(StackFilter)), hydrate: bool = False, _: AuthContext = Security(authorize), @@ -216,7 +215,7 @@ def update_stack( Returns: The updated stack. """ - rbac_read_checks: List[BaseModel] = [] + rbac_read_checks: list[BaseModel] = [] if stack_update.components: rbac_read_checks.extend( [ diff --git a/src/zenml/zen_server/routers/steps_endpoints.py b/src/zenml/zen_server/routers/steps_endpoints.py index 772d480348c..aac98613193 100644 --- a/src/zenml/zen_server/routers/steps_endpoints.py +++ b/src/zenml/zen_server/routers/steps_endpoints.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Endpoint definitions for steps (and artifacts) of pipeline runs.""" -from typing import Any, Dict, List +from typing import Any from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -208,7 +208,7 @@ def update_step( def get_step_configuration( step_id: UUID, _: AuthContext = Security(authorize), -) -> Dict[str, Any]: +) -> dict[str, Any]: """Get the configuration of a specific step. 
Args: @@ -261,7 +261,7 @@ def get_step_status( def get_step_logs( step_id: UUID, _: AuthContext = Security(authorize), -) -> List[LogEntry]: +) -> list[LogEntry]: """Get log entries for a step. Args: diff --git a/src/zenml/zen_server/routers/tag_resource_endpoints.py b/src/zenml/zen_server/routers/tag_resource_endpoints.py index 197120e3011..5e3fe8f8f4d 100644 --- a/src/zenml/zen_server/routers/tag_resource_endpoints.py +++ b/src/zenml/zen_server/routers/tag_resource_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for the link between tags and resources.""" -from typing import List from fastapi import APIRouter, Security @@ -64,9 +63,9 @@ def create_tag_resource( ) @async_fastapi_endpoint_wrapper def batch_create_tag_resource( - tag_resources: List[TagResourceRequest], + tag_resources: list[TagResourceRequest], _: AuthContext = Security(authorize), -) -> List[TagResourceResponse]: +) -> list[TagResourceResponse]: """Attach different tags to different resources. Args: @@ -104,7 +103,7 @@ def delete_tag_resource( ) @async_fastapi_endpoint_wrapper def batch_delete_tag_resource( - tag_resources: List[TagResourceRequest], + tag_resources: list[TagResourceRequest], _: AuthContext = Security(authorize), ) -> None: """Detach different tags from different resources. diff --git a/src/zenml/zen_server/routers/users_endpoints.py b/src/zenml/zen_server/routers/users_endpoints.py index 9f84afea9d8..47596ac3047 100644 --- a/src/zenml/zen_server/routers/users_endpoints.py +++ b/src/zenml/zen_server/routers/users_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for users.""" -from typing import List, Optional, Union from uuid import UUID from fastapi import APIRouter, Depends, Security @@ -165,7 +164,7 @@ def create_user( # 2. 
Create a new user without a password and have it activated at a # later time with an activation token - token: Optional[str] = None + token: str | None = None if user.password is None: user.active = False token = user.generate_activation_token() @@ -195,7 +194,7 @@ def create_user( ) @async_fastapi_endpoint_wrapper def get_user( - user_name_or_id: Union[str, UUID], + user_name_or_id: str | UUID, hydrate: bool = True, auth_context: AuthContext = Security(authorize), ) -> UserResponse: @@ -239,7 +238,7 @@ def get_user( ) @async_fastapi_endpoint_wrapper def activate_user( - user_name_or_id: Union[str, UUID], + user_name_or_id: str | UUID, user_update: UserUpdate, ) -> UserResponse: """Activates a specific user. @@ -299,7 +298,7 @@ def activate_user( ) @async_fastapi_endpoint_wrapper def deactivate_user( - user_name_or_id: Union[str, UUID], + user_name_or_id: str | UUID, auth_context: AuthContext = Security(authorize), ) -> UserResponse: """Deactivates a user and generates a new activation token for it. @@ -347,7 +346,7 @@ def deactivate_user( ) @async_fastapi_endpoint_wrapper def delete_user( - user_name_or_id: Union[str, UUID], + user_name_or_id: str | UUID, auth_context: AuthContext = Security(authorize), ) -> None: """Deletes a specific user. 
@@ -389,7 +388,7 @@ def delete_user( ) @async_fastapi_endpoint_wrapper def email_opt_in_response( - user_name_or_id: Union[str, UUID], + user_name_or_id: str | UUID, user_response: UserUpdate, auth_context: AuthContext = Security(authorize), ) -> UserResponse: @@ -441,7 +440,7 @@ def email_opt_in_response( ) @async_fastapi_endpoint_wrapper def update_user( - user_name_or_id: Union[str, UUID], + user_name_or_id: str | UUID, user_update: UserUpdate, request: Request, auth_context: AuthContext = Security(authorize), @@ -709,9 +708,9 @@ def update_myself( def update_user_resource_membership( resource_type: str, resource_id: UUID, - actions: List[str], - user_id: Optional[str] = None, - team_id: Optional[str] = None, + actions: list[str], + user_id: str | None = None, + team_id: str | None = None, auth_context: AuthContext = Security(authorize), ) -> None: """Updates resource memberships of a user. diff --git a/src/zenml/zen_server/routers/webhook_endpoints.py b/src/zenml/zen_server/routers/webhook_endpoints.py index e6e0c94f195..ecb8543178a 100644 --- a/src/zenml/zen_server/routers/webhook_endpoints.py +++ b/src/zenml/zen_server/routers/webhook_endpoints.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. """Endpoint definitions for webhooks.""" -from typing import Dict from uuid import UUID from fastapi import APIRouter, BackgroundTasks, Depends, Request @@ -53,7 +52,7 @@ async def get_body(request: Request) -> bytes: @router.post( "/{event_source_id}", - response_model=Dict[str, str], + response_model=dict[str, str], ) @async_fastapi_endpoint_wrapper def webhook( @@ -61,7 +60,7 @@ def webhook( request: Request, background_tasks: BackgroundTasks, raw_body: bytes = Depends(get_body), -) -> Dict[str, str]: +) -> dict[str, str]: """Webhook to receive events from external event sources. 
Args: diff --git a/src/zenml/zen_server/secure_headers.py b/src/zenml/zen_server/secure_headers.py index 8cff2187182..39716f0c5ee 100644 --- a/src/zenml/zen_server/secure_headers.py +++ b/src/zenml/zen_server/secure_headers.py @@ -13,13 +13,12 @@ # permissions and limitations under the License. """Secure headers for the ZenML Server.""" -from typing import Optional import secure from zenml.zen_server.utils import server_config -_secure_headers: Optional[secure.Secure] = None +_secure_headers: secure.Secure | None = None def secure_headers() -> secure.Secure: @@ -51,7 +50,7 @@ def initialize_secure_headers() -> None: # - if set to a string, we use the string as the value for the header # - if set to `False`, we don't set the header - server: Optional[secure.Server] = None + server: secure.Server | None = None if config.secure_headers_server: server = secure.Server() if isinstance(config.secure_headers_server, str): @@ -59,43 +58,43 @@ def initialize_secure_headers() -> None: else: server.set(str(config.deployment_id)) - hsts: Optional[secure.StrictTransportSecurity] = None + hsts: secure.StrictTransportSecurity | None = None if config.secure_headers_hsts: hsts = secure.StrictTransportSecurity() if isinstance(config.secure_headers_hsts, str): hsts.set(config.secure_headers_hsts) - xfo: Optional[secure.XFrameOptions] = None + xfo: secure.XFrameOptions | None = None if config.secure_headers_xfo: xfo = secure.XFrameOptions() if isinstance(config.secure_headers_xfo, str): xfo.set(config.secure_headers_xfo) - csp: Optional[secure.ContentSecurityPolicy] = None + csp: secure.ContentSecurityPolicy | None = None if config.secure_headers_csp: csp = secure.ContentSecurityPolicy() if isinstance(config.secure_headers_csp, str): csp.set(config.secure_headers_csp) - xcto: Optional[secure.XContentTypeOptions] = None + xcto: secure.XContentTypeOptions | None = None if config.secure_headers_content: xcto = secure.XContentTypeOptions() if isinstance(config.secure_headers_content, str): 
xcto.set(config.secure_headers_content) - referrer: Optional[secure.ReferrerPolicy] = None + referrer: secure.ReferrerPolicy | None = None if config.secure_headers_referrer: referrer = secure.ReferrerPolicy() if isinstance(config.secure_headers_referrer, str): referrer.set(config.secure_headers_referrer) - cache: Optional[secure.CacheControl] = None + cache: secure.CacheControl | None = None if config.secure_headers_cache: cache = secure.CacheControl() if isinstance(config.secure_headers_cache, str): cache.set(config.secure_headers_cache) - permissions: Optional[secure.PermissionsPolicy] = None + permissions: secure.PermissionsPolicy | None = None if config.secure_headers_permissions: permissions = secure.PermissionsPolicy() if isinstance(config.secure_headers_permissions, str): diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index 2575a880602..1ab9ae3d3da 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -24,17 +24,11 @@ from typing import ( TYPE_CHECKING, Any, - Awaitable, - Callable, - Dict, - List, Optional, - Tuple, - Type, TypeVar, - Union, overload, ) +from collections.abc import Awaitable, Callable from uuid import UUID import psutil @@ -84,13 +78,13 @@ logger = get_logger(__name__) _zen_store: Optional["SqlZenStore"] = None -_rbac: Optional[RBACInterface] = None -_feature_gate: Optional[FeatureGateInterface] = None -_workload_manager: Optional[WorkloadManagerInterface] = None +_rbac: RBACInterface | None = None +_feature_gate: FeatureGateInterface | None = None +_workload_manager: WorkloadManagerInterface | None = None _snapshot_executor: Optional["BoundedThreadPoolExecutor"] = None -_plugin_flavor_registry: Optional[PluginFlavorRegistry] = None -_memcache: Optional[MemoryCache] = None -_request_manager: Optional[RequestManager] = None +_plugin_flavor_registry: PluginFlavorRegistry | None = None +_memcache: MemoryCache | None = None +_request_manager: RequestManager | None = None def zen_store() -> 
"SqlZenStore": @@ -207,7 +201,7 @@ def initialize_workload_manager() -> None: from zenml.utils import source_utils try: - workload_manager_class: Type[WorkloadManagerInterface] = ( + workload_manager_class: type[WorkloadManagerInterface] = ( source_utils.load_and_validate_class( source=source, expected_class=WorkloadManagerInterface ) @@ -301,7 +295,7 @@ def memcache() -> MemoryCache: return _memcache -_server_config: Optional[ServerConfiguration] = None +_server_config: ServerConfiguration | None = None def server_config() -> ServerConfiguration: @@ -358,18 +352,18 @@ def async_fastapi_endpoint_wrapper( @overload def async_fastapi_endpoint_wrapper( - *, deduplicate: Optional[bool] = None + *, deduplicate: bool | None = None ) -> Callable[[Callable[P, R]], Callable[P, Awaitable[Any]]]: ... def async_fastapi_endpoint_wrapper( - func: Optional[Callable[P, R]] = None, + func: Callable[P, R] | None = None, *, - deduplicate: Optional[bool] = None, -) -> Union[ - Callable[P, Awaitable[Any]], - Callable[[Callable[P, R]], Callable[P, Awaitable[Any]]], -]: + deduplicate: bool | None = None, +) -> ( + Callable[P, Awaitable[Any]] | + Callable[[Callable[P, R]], Callable[P, Awaitable[Any]]] +): """Decorator for FastAPI endpoints. This decorator for FastAPI endpoints does the following: @@ -429,7 +423,7 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> Any: # Code from https://github.com/tiangolo/fastapi/issues/1474#issuecomment-1160633178 # to send 422 response when receiving invalid query parameters -def make_dependable(cls: Type[BaseModel]) -> Callable[..., Any]: +def make_dependable(cls: type[BaseModel]) -> Callable[..., Any]: """This function makes a pydantic model usable for fastapi query parameters. 
Additionally, it converts `InternalServerError`s that would happen due to @@ -484,7 +478,7 @@ def init_cls_and_handle_errors(*args: Any, **kwargs: Any) -> BaseModel: return init_cls_and_handle_errors -def get_ip_location(ip_address: str) -> Tuple[str, str, str]: +def get_ip_location(ip_address: str) -> tuple[str, str, str]: """Get the location of the given IP address. Args: @@ -509,8 +503,8 @@ def get_ip_location(ip_address: str) -> Tuple[str, str, str]: def verify_admin_status_if_no_rbac( - admin_status: Optional[bool], - action: Optional[str] = None, + admin_status: bool | None, + action: str | None = None, ) -> None: """Validate the admin status for sensitive requests. @@ -552,7 +546,7 @@ def is_user_request(request: "Request") -> bool: True if it's a user request, False otherwise. """ # Define system paths that should be excluded - system_paths: List[str] = [ + system_paths: list[str] = [ "/health", "/ready", "/metrics", @@ -652,7 +646,7 @@ def is_same_or_subdomain(source_domain: str, target_domain: str) -> bool: return False -def get_zenml_headers() -> Dict[str, str]: +def get_zenml_headers() -> dict[str, str]: """Get the ZenML specific headers to be included in requests made by the server. Returns: @@ -671,7 +665,7 @@ def get_zenml_headers() -> Dict[str, str]: def set_filter_project_scope( filter_model: ProjectScopedFilter, - project_name_or_id: Optional[Union[UUID, str]] = None, + project_name_or_id: UUID | str | None = None, ) -> None: """Set the project scope of the filter model. @@ -688,7 +682,7 @@ def set_filter_project_scope( process = psutil.Process() -fd_limit: Union[int, str] = "N/A" +fd_limit: int | str = "N/A" if sys.platform != "win32": import resource @@ -698,7 +692,7 @@ def set_filter_project_scope( pass -def get_system_metrics() -> Dict[str, Any]: +def get_system_metrics() -> dict[str, Any]: """Get comprehensive system metrics. 
Returns: @@ -711,7 +705,7 @@ def get_system_metrics() -> Dict[str, Any]: memory = process.memory_info() # File descriptors - open_fds: Union[int, str] = "N/A" + open_fds: int | str = "N/A" try: open_fds = process.num_fds() if hasattr(process, "num_fds") else "N/A" except Exception: @@ -757,7 +751,7 @@ def get_system_metrics_log_str(request: Optional["Request"] = None) -> str: ) -event_loop_lag_monitor_task: Optional[asyncio.Task[None]] = None +event_loop_lag_monitor_task: asyncio.Task[None] | None = None def start_event_loop_lag_monitor(threshold_ms: int = 50) -> None: diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 46feb51c476..d4785e38bb6 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -24,7 +24,7 @@ import os from asyncio.log import logger from genericpath import isfile -from typing import Any, List +from typing import Any from anyio import to_thread from fastapi import FastAPI, HTTPException, Request @@ -310,7 +310,7 @@ async def dashboard(request: Request) -> Any: app.include_router(users_endpoints.activation_router) -def get_root_static_files() -> List[str]: +def get_root_static_files() -> list[str]: """Get the list of static files in the root dashboard directory. 
These files are static files that are not in the /static subdirectory diff --git a/src/zenml/zen_stores/base_zen_store.py b/src/zenml/zen_stores/base_zen_store.py index 1f369a6e817..74e7daddf9a 100644 --- a/src/zenml/zen_stores/base_zen_store.py +++ b/src/zenml/zen_stores/base_zen_store.py @@ -18,10 +18,6 @@ from typing import ( Any, ClassVar, - Dict, - Optional, - Tuple, - Type, ) from uuid import UUID @@ -73,12 +69,12 @@ class BaseZenStore( config: StoreConfiguration TYPE: ClassVar[StoreType] - CONFIG_TYPE: ClassVar[Type[StoreConfiguration]] + CONFIG_TYPE: ClassVar[type[StoreConfiguration]] @model_validator(mode="before") @classmethod @before_validator_handler - def convert_config(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def convert_config(cls, data: dict[str, Any]) -> dict[str, Any]: """Method to infer the correct type of the config and convert. Args: @@ -142,7 +138,7 @@ def __init__( logger.debug("Skipping database initialization") @staticmethod - def get_store_class(store_type: StoreType) -> Type["BaseZenStore"]: + def get_store_class(store_type: StoreType) -> type["BaseZenStore"]: """Returns the class of the given store type. Args: @@ -178,7 +174,7 @@ def get_store_class(store_type: StoreType) -> Type["BaseZenStore"]: @staticmethod def get_store_config_class( store_type: StoreType, - ) -> Type["StoreConfiguration"]: + ) -> type["StoreConfiguration"]: """Returns the store config class of the given store type. Args: @@ -294,10 +290,10 @@ def type(self) -> StoreType: def validate_active_config( self, - active_project_id: Optional[UUID] = None, - active_stack_id: Optional[UUID] = None, + active_project_id: UUID | None = None, + active_stack_id: UUID | None = None, config_name: str = "", - ) -> Tuple[Optional[ProjectResponse], StackResponse]: + ) -> tuple[ProjectResponse | None, StackResponse]: """Validate the active configuration. 
Call this method to validate the supplied active project and active @@ -317,7 +313,7 @@ def validate_active_config( Returns: A tuple containing the active project and active stack. """ - active_project: Optional[ProjectResponse] = None + active_project: ProjectResponse | None = None if active_project_id: try: diff --git a/src/zenml/zen_stores/dag_generator.py b/src/zenml/zen_stores/dag_generator.py index f939393b35c..dd4da4df88f 100644 --- a/src/zenml/zen_stores/dag_generator.py +++ b/src/zenml/zen_stores/dag_generator.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """DAG generator helper.""" -from typing import Any, Dict, List, Optional +from typing import Any from uuid import UUID from zenml.enums import ExecutionStatus @@ -25,10 +25,10 @@ class DAGGeneratorHelper: def __init__(self) -> None: """Initialize the DAG generator helper.""" - self.step_nodes: Dict[str, PipelineRunDAG.Node] = {} - self.artifact_nodes: Dict[str, PipelineRunDAG.Node] = {} - self.triggered_run_nodes: Dict[str, PipelineRunDAG.Node] = {} - self.edges: List[PipelineRunDAG.Edge] = [] + self.step_nodes: dict[str, PipelineRunDAG.Node] = {} + self.artifact_nodes: dict[str, PipelineRunDAG.Node] = {} + self.triggered_run_nodes: dict[str, PipelineRunDAG.Node] = {} + self.edges: list[PipelineRunDAG.Edge] = [] def get_step_node_id(self, name: str) -> str: """Get the ID of a step node. @@ -81,7 +81,7 @@ def add_step_node( self, node_id: str, name: str, - id: Optional[UUID] = None, + id: UUID | None = None, **metadata: Any, ) -> PipelineRunDAG.Node: """Add a step node to the DAG. @@ -109,7 +109,7 @@ def add_artifact_node( self, node_id: str, name: str, - id: Optional[UUID] = None, + id: UUID | None = None, **metadata: Any, ) -> PipelineRunDAG.Node: """Add an artifact node to the DAG. 
@@ -137,7 +137,7 @@ def add_triggered_run_node( self, node_id: str, name: str, - id: Optional[UUID] = None, + id: UUID | None = None, **metadata: Any, ) -> PipelineRunDAG.Node: """Add a triggered run node to the DAG. diff --git a/src/zenml/zen_stores/migrations/alembic.py b/src/zenml/zen_stores/migrations/alembic.py index 99cafa6da79..9b515ad777b 100644 --- a/src/zenml/zen_stores/migrations/alembic.py +++ b/src/zenml/zen_stores/migrations/alembic.py @@ -19,7 +19,8 @@ """ from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Union +from collections.abc import Callable, Sequence from alembic.config import Config from alembic.runtime.environment import EnvironmentContext @@ -39,7 +40,7 @@ def include_object( - object: Any, name: Optional[str], type_: str, *args: Any, **kwargs: Any + object: Any, name: str | None, type_: str, *args: Any, **kwargs: Any ) -> bool: """Function used to exclude tables from the migration scripts. @@ -79,7 +80,7 @@ def __init__( self, engine: Engine, metadata: MetaData = SQLModel.metadata, - context: Optional[EnvironmentContext] = None, + context: EnvironmentContext | None = None, **kwargs: Any, ) -> None: """Initialize the Alembic wrapper. @@ -122,7 +123,7 @@ def db_is_empty(self) -> bool: def run_migrations( self, - fn: Optional[Callable[[_RevIdType, MigrationContext], List[Any]]], + fn: Callable[[_RevIdType, MigrationContext], list[Any]] | None, ) -> None: """Run an online migration function in the current migration context. @@ -130,7 +131,7 @@ def run_migrations( fn: Migration function to run. If not set, the function configured externally by the Alembic CLI command is used. 
""" - fn_context_args: Dict[Any, Any] = {} + fn_context_args: dict[Any, Any] = {} if fn is not None: fn_context_args["fn"] = fn @@ -149,15 +150,15 @@ def run_migrations( with self.environment_context.begin_transaction(): self.environment_context.run_migrations() - def head_revisions(self) -> List[str]: + def head_revisions(self) -> list[str]: """Get the head database revisions. Returns: List of head revisions. """ - head_revisions: List[str] = [] + head_revisions: list[str] = [] - def do_get_head_rev(rev: _RevIdType, context: Any) -> List[Any]: + def do_get_head_rev(rev: _RevIdType, context: Any) -> list[Any]: nonlocal head_revisions for r in self.script_directory.get_heads(): @@ -170,15 +171,15 @@ def do_get_head_rev(rev: _RevIdType, context: Any) -> List[Any]: return head_revisions - def current_revisions(self) -> List[str]: + def current_revisions(self) -> list[str]: """Get the current database revisions. Returns: List of head revisions. """ - current_revisions: List[str] = [] + current_revisions: list[str] = [] - def do_get_current_rev(rev: _RevIdType, context: Any) -> List[Any]: + def do_get_current_rev(rev: _RevIdType, context: Any) -> list[Any]: nonlocal current_revisions # Handle rev parameter in a way that's compatible with different alembic versions @@ -206,7 +207,7 @@ def stamp(self, revision: str) -> None: revision: String revision target. """ - def do_stamp(rev: _RevIdType, context: Any) -> List[Any]: + def do_stamp(rev: _RevIdType, context: Any) -> list[Any]: # Handle rev parameter in a way that's compatible with different alembic versions if isinstance(rev, str): return self.script_directory._stamp_revs(revision, rev) @@ -224,7 +225,7 @@ def upgrade(self, revision: str = "heads") -> None: revision: String revision target. 
""" - def do_upgrade(rev: _RevIdType, context: Any) -> List[Any]: + def do_upgrade(rev: _RevIdType, context: Any) -> list[Any]: # Handle rev parameter in a way that's compatible with different alembic versions if isinstance(rev, str): return self.script_directory._upgrade_revs(revision, rev) @@ -245,7 +246,7 @@ def downgrade(self, revision: str) -> None: revision: String revision target. """ - def do_downgrade(rev: _RevIdType, context: Any) -> List[Any]: + def do_downgrade(rev: _RevIdType, context: Any) -> list[Any]: # Handle rev parameter in a way that's compatible with different alembic versions if isinstance(rev, str): return self.script_directory._downgrade_revs(revision, rev) diff --git a/src/zenml/zen_stores/migrations/utils.py b/src/zenml/zen_stores/migrations/utils.py index f1947383a5c..0b650705a53 100644 --- a/src/zenml/zen_stores/migrations/utils.py +++ b/src/zenml/zen_stores/migrations/utils.py @@ -19,13 +19,9 @@ import shutil from typing import ( Any, - Callable, - Dict, - Generator, - List, - Optional, cast, ) +from collections.abc import Callable, Generator import pymysql from pydantic import BaseModel, ConfigDict @@ -50,13 +46,13 @@ class MigrationUtils(BaseModel): """Utilities for database migration, backup and recovery.""" url: URL - connect_args: Dict[str, Any] - engine_args: Dict[str, Any] + connect_args: dict[str, Any] + engine_args: dict[str, Any] - _engine: Optional[Engine] = None - _master_engine: Optional[Engine] = None + _engine: Engine | None = None + _master_engine: Engine | None = None - def create_engine(self, database: Optional[str] = None) -> Engine: + def create_engine(self, database: str | None = None) -> Engine: """Get the SQLAlchemy engine for a database. Args: @@ -115,7 +111,7 @@ def is_mysql_missing_database_error(cls, error: OperationalError) -> bool: def database_exists( self, - database: Optional[str] = None, + database: str | None = None, ) -> bool: """Check if a database exists. 
@@ -147,7 +143,7 @@ def database_exists( def drop_database( self, - database: Optional[str] = None, + database: str | None = None, ) -> None: """Drops a mysql database. @@ -163,7 +159,7 @@ def drop_database( def create_database( self, - database: Optional[str] = None, + database: str | None = None, drop: bool = False, ) -> None: """Creates a mysql database. @@ -182,7 +178,7 @@ def create_database( conn.execute(text(f"CREATE DATABASE IF NOT EXISTS `{database}`")) def backup_database_to_storage( - self, store_db_info: Callable[[Dict[str, Any]], None] + self, store_db_info: Callable[[dict[str, Any]], None] ) -> None: """Backup the database to a storage location. @@ -358,7 +354,7 @@ def backup_database_to_storage( ) def restore_database_from_storage( - self, load_db_info: Callable[[], Generator[Dict[str, Any], None, None]] + self, load_db_info: Callable[[], Generator[dict[str, Any], None, None]] ) -> None: """Restore the database from a backup storage location. @@ -385,7 +381,7 @@ def restore_database_from_storage( with self.engine.begin() as connection: # read the DB information one JSON object at a time - self_references: Dict[str, bool] = {} + self_references: dict[str, bool] = {} for table_dump in load_db_info(): table_name = table_dump["table"] if "create_stmt" in table_dump: @@ -515,7 +511,7 @@ def backup_database_to_file(self, dump_file: str) -> None: with open(dump_file, "w") as f: - def json_dump(obj: Dict[str, Any]) -> None: + def json_dump(obj: dict[str, Any]) -> None: """Dump a JSON object to the dump file. Args: @@ -566,9 +562,9 @@ def restore_database_from_file(self, dump_file: str) -> None: return # read the DB dump file one JSON object at a time - with open(dump_file, "r") as f: + with open(dump_file) as f: - def json_load() -> Generator[Dict[str, Any], None, None]: + def json_load() -> Generator[dict[str, Any], None, None]: """Generator that loads the JSON objects in the dump file. 
Yields: @@ -590,7 +586,7 @@ def json_load() -> Generator[Dict[str, Any], None, None]: logger.info(f"Database successfully restored from '{dump_file}'") - def backup_database_to_memory(self) -> List[Dict[str, Any]]: + def backup_database_to_memory(self) -> list[dict[str, Any]]: """Backup the database in memory. Returns: @@ -605,9 +601,9 @@ def backup_database_to_memory(self) -> List[Dict[str, Any]]: "In-memory backup is not supported for sqlite databases." ) - db_dump: List[Dict[str, Any]] = [] + db_dump: list[dict[str, Any]] = [] - def store_in_mem(obj: Dict[str, Any]) -> None: + def store_in_mem(obj: dict[str, Any]) -> None: """Store a JSON object in the in-memory database backup. Args: @@ -624,7 +620,7 @@ def store_in_mem(obj: Dict[str, Any]) -> None: return db_dump def restore_database_from_memory( - self, db_dump: List[Dict[str, Any]] + self, db_dump: list[dict[str, Any]] ) -> None: """Restore the database from an in-memory backup. @@ -641,14 +637,13 @@ def restore_database_from_memory( "In-memory backup is not supported for sqlite databases." ) - def load_from_mem() -> Generator[Dict[str, Any], None, None]: + def load_from_mem() -> Generator[dict[str, Any], None, None]: """Generator that loads the JSON objects from the in-memory backup. Yields: The loaded JSON objects. 
""" - for obj in db_dump: - yield obj + yield from db_dump # Call the generic restore method with a function that loads the # JSON objects from the in-memory database backup diff --git a/src/zenml/zen_stores/migrations/versions/0.33.0_release.py b/src/zenml/zen_stores/migrations/versions/0.33.0_release.py index b380d4a4924..79687430301 100644 --- a/src/zenml/zen_stores/migrations/versions/0.33.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.33.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.34.0_release.py b/src/zenml/zen_stores/migrations/versions/0.34.0_release.py index 01aef0dd999..c9d5e352c38 100644 --- a/src/zenml/zen_stores/migrations/versions/0.34.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.34.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.35.0_release.py b/src/zenml/zen_stores/migrations/versions/0.35.0_release.py index 0e46e613fde..d449bae2af7 100644 --- a/src/zenml/zen_stores/migrations/versions/0.35.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.35.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.35.1_release.py b/src/zenml/zen_stores/migrations/versions/0.35.1_release.py index d243777bcf7..4d58f8f0d18 100644 --- a/src/zenml/zen_stores/migrations/versions/0.35.1_release.py +++ 
b/src/zenml/zen_stores/migrations/versions/0.35.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.36.0_release.py b/src/zenml/zen_stores/migrations/versions/0.36.0_release.py index fe8c7edabd9..a579a67f6d6 100644 --- a/src/zenml/zen_stores/migrations/versions/0.36.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.36.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.36.1_release.py b/src/zenml/zen_stores/migrations/versions/0.36.1_release.py index 6db7ccf9d27..6ae4dcbf504 100644 --- a/src/zenml/zen_stores/migrations/versions/0.36.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.36.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.37.0_release.py b/src/zenml/zen_stores/migrations/versions/0.37.0_release.py index 44236f01412..b01a7722325 100644 --- a/src/zenml/zen_stores/migrations/versions/0.37.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.37.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.38.0_release.py b/src/zenml/zen_stores/migrations/versions/0.38.0_release.py index 
a27f0a27fb7..8d510432e01 100644 --- a/src/zenml/zen_stores/migrations/versions/0.38.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.38.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.39.0_release.py b/src/zenml/zen_stores/migrations/versions/0.39.0_release.py index ddf3b10a451..23ded6dce97 100644 --- a/src/zenml/zen_stores/migrations/versions/0.39.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.39.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.39.1_release.py b/src/zenml/zen_stores/migrations/versions/0.39.1_release.py index 0916bdc3c94..4ce84d9e5f9 100644 --- a/src/zenml/zen_stores/migrations/versions/0.39.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.39.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.40.0_release.py b/src/zenml/zen_stores/migrations/versions/0.40.0_release.py index 2b4a95fc8b6..b6d3a99d6e5 100644 --- a/src/zenml/zen_stores/migrations/versions/0.40.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.40.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git 
a/src/zenml/zen_stores/migrations/versions/0.40.1_release.py b/src/zenml/zen_stores/migrations/versions/0.40.1_release.py index a5dfaab67a6..fe8dfa9f27f 100644 --- a/src/zenml/zen_stores/migrations/versions/0.40.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.40.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.40.2_release.py b/src/zenml/zen_stores/migrations/versions/0.40.2_release.py index 11ba04c19a1..4ef6e817a51 100644 --- a/src/zenml/zen_stores/migrations/versions/0.40.2_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.40.2_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.40.3_release.py b/src/zenml/zen_stores/migrations/versions/0.40.3_release.py index b718a83d1fa..d034cc47402 100644 --- a/src/zenml/zen_stores/migrations/versions/0.40.3_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.40.3_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.41.0_release.py b/src/zenml/zen_stores/migrations/versions/0.41.0_release.py index f9ab6290109..e0461d902ff 100644 --- a/src/zenml/zen_stores/migrations/versions/0.41.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.41.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> 
None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.42.0_release.py b/src/zenml/zen_stores/migrations/versions/0.42.0_release.py index 1823dd37a16..6c6cb0278b4 100644 --- a/src/zenml/zen_stores/migrations/versions/0.42.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.42.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.42.1_release.py b/src/zenml/zen_stores/migrations/versions/0.42.1_release.py index e2bf939f1f8..582a2e15db8 100644 --- a/src/zenml/zen_stores/migrations/versions/0.42.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.42.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.43.0_release.py b/src/zenml/zen_stores/migrations/versions/0.43.0_release.py index 47f129c66a0..426a93a4ca9 100644 --- a/src/zenml/zen_stores/migrations/versions/0.43.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.43.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.44.0_release.py b/src/zenml/zen_stores/migrations/versions/0.44.0_release.py index 56e1ee89a9e..94314f88a37 100644 --- a/src/zenml/zen_stores/migrations/versions/0.44.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.44.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> 
None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.44.1_release.py b/src/zenml/zen_stores/migrations/versions/0.44.1_release.py index 0dea85995c0..a89f1f53eda 100644 --- a/src/zenml/zen_stores/migrations/versions/0.44.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.44.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.44.2_release.py b/src/zenml/zen_stores/migrations/versions/0.44.2_release.py index a6be88aaba0..c0d9d07df10 100644 --- a/src/zenml/zen_stores/migrations/versions/0.44.2_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.44.2_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.44.3_release.py b/src/zenml/zen_stores/migrations/versions/0.44.3_release.py index f0f10076a81..1b6ab3867b0 100644 --- a/src/zenml/zen_stores/migrations/versions/0.44.3_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.44.3_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.45.0_release.py b/src/zenml/zen_stores/migrations/versions/0.45.0_release.py index b71b4516122..d61c4861190 100644 --- a/src/zenml/zen_stores/migrations/versions/0.45.0_release.py +++ 
b/src/zenml/zen_stores/migrations/versions/0.45.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.45.1_release_0_45_1.py b/src/zenml/zen_stores/migrations/versions/0.45.1_release_0_45_1.py index e214e67d76c..a32298a9d78 100644 --- a/src/zenml/zen_stores/migrations/versions/0.45.1_release_0_45_1.py +++ b/src/zenml/zen_stores/migrations/versions/0.45.1_release_0_45_1.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.45.2_release.py b/src/zenml/zen_stores/migrations/versions/0.45.2_release.py index 93aaa24b7e3..c3a3836496e 100644 --- a/src/zenml/zen_stores/migrations/versions/0.45.2_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.45.2_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.45.3_release.py b/src/zenml/zen_stores/migrations/versions/0.45.3_release.py index b5068767a9a..9d6c8a03943 100644 --- a/src/zenml/zen_stores/migrations/versions/0.45.3_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.45.3_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.45.4_release.py 
b/src/zenml/zen_stores/migrations/versions/0.45.4_release.py index 94b572a735f..09d7a3317b3 100644 --- a/src/zenml/zen_stores/migrations/versions/0.45.4_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.45.4_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.45.5_release.py b/src/zenml/zen_stores/migrations/versions/0.45.5_release.py index 9442bbef904..94537751032 100644 --- a/src/zenml/zen_stores/migrations/versions/0.45.5_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.45.5_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.45.6_release.py b/src/zenml/zen_stores/migrations/versions/0.45.6_release.py index cd62d8c207f..c9cb6465d7c 100644 --- a/src/zenml/zen_stores/migrations/versions/0.45.6_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.45.6_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.46.0_release.py b/src/zenml/zen_stores/migrations/versions/0.46.0_release.py index 3cfffac6a38..4ea7e9a0fe1 100644 --- a/src/zenml/zen_stores/migrations/versions/0.46.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.46.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the 
previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.46.1_release.py b/src/zenml/zen_stores/migrations/versions/0.46.1_release.py index 753945c46a5..0542df8c0d8 100644 --- a/src/zenml/zen_stores/migrations/versions/0.46.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.46.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.47.0_release.py b/src/zenml/zen_stores/migrations/versions/0.47.0_release.py index 99dc7c2834c..959c1754e21 100644 --- a/src/zenml/zen_stores/migrations/versions/0.47.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.47.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.50.0_release.py b/src/zenml/zen_stores/migrations/versions/0.50.0_release.py index 280d6c9adfd..1c10fe77cb6 100644 --- a/src/zenml/zen_stores/migrations/versions/0.50.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.50.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.51.0_release.py b/src/zenml/zen_stores/migrations/versions/0.51.0_release.py index 1d2e3deb2d1..6f5f365928e 100644 --- a/src/zenml/zen_stores/migrations/versions/0.51.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.51.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new 
revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.52.0_release.py b/src/zenml/zen_stores/migrations/versions/0.52.0_release.py index a952d8f11e1..cbbeb0d45b5 100644 --- a/src/zenml/zen_stores/migrations/versions/0.52.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.52.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.53.0_release.py b/src/zenml/zen_stores/migrations/versions/0.53.0_release.py index c6b4e0de3a2..a827d23abda 100644 --- a/src/zenml/zen_stores/migrations/versions/0.53.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.53.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.53.1_release.py b/src/zenml/zen_stores/migrations/versions/0.53.1_release.py index 7e492c17676..573d0b21ced 100644 --- a/src/zenml/zen_stores/migrations/versions/0.53.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.53.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.54.0_release.py b/src/zenml/zen_stores/migrations/versions/0.54.0_release.py index 7fad2a6d5e6..ebb75e4e27d 100644 --- a/src/zenml/zen_stores/migrations/versions/0.54.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.54.0_release.py 
@@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.54.1_release.py b/src/zenml/zen_stores/migrations/versions/0.54.1_release.py index 96df433f3a7..ef7823383a8 100644 --- a/src/zenml/zen_stores/migrations/versions/0.54.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.54.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.55.0_release.py b/src/zenml/zen_stores/migrations/versions/0.55.0_release.py index 929366179a8..c59647efaa6 100644 --- a/src/zenml/zen_stores/migrations/versions/0.55.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.55.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.55.1_release.py b/src/zenml/zen_stores/migrations/versions/0.55.1_release.py index f4f3b4135ee..5e4ecb329c8 100644 --- a/src/zenml/zen_stores/migrations/versions/0.55.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.55.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.55.2_release.py b/src/zenml/zen_stores/migrations/versions/0.55.2_release.py index fe61721c976..5d21511c9cf 100644 --- 
a/src/zenml/zen_stores/migrations/versions/0.55.2_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.55.2_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.55.3_release.py b/src/zenml/zen_stores/migrations/versions/0.55.3_release.py index 645586b5df4..7c910d8e012 100644 --- a/src/zenml/zen_stores/migrations/versions/0.55.3_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.55.3_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.55.4_release.py b/src/zenml/zen_stores/migrations/versions/0.55.4_release.py index 611a1ab30ad..05003e8182f 100644 --- a/src/zenml/zen_stores/migrations/versions/0.55.4_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.55.4_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.55.5_release.py b/src/zenml/zen_stores/migrations/versions/0.55.5_release.py index f1175c4884b..860dff4da35 100644 --- a/src/zenml/zen_stores/migrations/versions/0.55.5_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.55.5_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.56.0_release.py 
b/src/zenml/zen_stores/migrations/versions/0.56.0_release.py index 85dc2ccdf5e..e00dd88ad2f 100644 --- a/src/zenml/zen_stores/migrations/versions/0.56.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.56.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.56.1_release.py b/src/zenml/zen_stores/migrations/versions/0.56.1_release.py index d1eb6c0c982..f636ef94722 100644 --- a/src/zenml/zen_stores/migrations/versions/0.56.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.56.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.56.2_release.py b/src/zenml/zen_stores/migrations/versions/0.56.2_release.py index 47431e949fe..acd6edfa546 100644 --- a/src/zenml/zen_stores/migrations/versions/0.56.2_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.56.2_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.56.3_release.py b/src/zenml/zen_stores/migrations/versions/0.56.3_release.py index c3eb51e9db6..1a184053bac 100644 --- a/src/zenml/zen_stores/migrations/versions/0.56.3_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.56.3_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the 
previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.56.4_release.py b/src/zenml/zen_stores/migrations/versions/0.56.4_release.py index e93afba0402..f3846775d4a 100644 --- a/src/zenml/zen_stores/migrations/versions/0.56.4_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.56.4_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.57.0.rc1_release.py b/src/zenml/zen_stores/migrations/versions/0.57.0.rc1_release.py index 1498a0e3ac9..8f241b4ade2 100644 --- a/src/zenml/zen_stores/migrations/versions/0.57.0.rc1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.57.0.rc1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.57.0.rc2_release.py b/src/zenml/zen_stores/migrations/versions/0.57.0.rc2_release.py index 2e8d1e87f34..065d3870aa4 100644 --- a/src/zenml/zen_stores/migrations/versions/0.57.0.rc2_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.57.0.rc2_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.57.0_release.py b/src/zenml/zen_stores/migrations/versions/0.57.0_release.py index 9ab354421d3..a0a22ce66f4 100644 --- a/src/zenml/zen_stores/migrations/versions/0.57.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.57.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database 
schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.57.1_release.py b/src/zenml/zen_stores/migrations/versions/0.57.1_release.py index 4fe7e5d7b15..a7bf58c0947 100644 --- a/src/zenml/zen_stores/migrations/versions/0.57.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.57.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.58.0_release.py b/src/zenml/zen_stores/migrations/versions/0.58.0_release.py index fdb901c58dc..b48ef306e75 100644 --- a/src/zenml/zen_stores/migrations/versions/0.58.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.58.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.58.1_release.py b/src/zenml/zen_stores/migrations/versions/0.58.1_release.py index 87131e68b31..8c52d3e0afb 100644 --- a/src/zenml/zen_stores/migrations/versions/0.58.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.58.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.58.2_release.py b/src/zenml/zen_stores/migrations/versions/0.58.2_release.py index 5d3aa61b7cd..27ac0fd6787 100644 --- a/src/zenml/zen_stores/migrations/versions/0.58.2_release.py +++ 
b/src/zenml/zen_stores/migrations/versions/0.58.2_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.60.0_release.py b/src/zenml/zen_stores/migrations/versions/0.60.0_release.py index 540e9d6771b..9077f083072 100644 --- a/src/zenml/zen_stores/migrations/versions/0.60.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.60.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.61.0_release.py b/src/zenml/zen_stores/migrations/versions/0.61.0_release.py index 8df6be4bd60..8b6821fc3d9 100644 --- a/src/zenml/zen_stores/migrations/versions/0.61.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.61.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.62.0_release.py b/src/zenml/zen_stores/migrations/versions/0.62.0_release.py index 5eaabfd328c..0e5e4d86a87 100644 --- a/src/zenml/zen_stores/migrations/versions/0.62.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.62.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.63.0_release.py b/src/zenml/zen_stores/migrations/versions/0.63.0_release.py index 
eb72d0880de..e240848bead 100644 --- a/src/zenml/zen_stores/migrations/versions/0.63.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.63.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.64.0_release.py b/src/zenml/zen_stores/migrations/versions/0.64.0_release.py index 717272cedf4..0696d274404 100644 --- a/src/zenml/zen_stores/migrations/versions/0.64.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.64.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.65.0_release.py b/src/zenml/zen_stores/migrations/versions/0.65.0_release.py index 349479d76f0..1fbeed5bd7e 100644 --- a/src/zenml/zen_stores/migrations/versions/0.65.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.65.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.66.0_release.py b/src/zenml/zen_stores/migrations/versions/0.66.0_release.py index 4b0d81cc095..38ccfc51151 100644 --- a/src/zenml/zen_stores/migrations/versions/0.66.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.66.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git 
a/src/zenml/zen_stores/migrations/versions/0.67.0_release.py b/src/zenml/zen_stores/migrations/versions/0.67.0_release.py index cbebe847b30..13d79f1fc5b 100644 --- a/src/zenml/zen_stores/migrations/versions/0.67.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.67.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.68.0_release.py b/src/zenml/zen_stores/migrations/versions/0.68.0_release.py index fa9c8267f68..4f23b947225 100644 --- a/src/zenml/zen_stores/migrations/versions/0.68.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.68.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.68.1_release.py b/src/zenml/zen_stores/migrations/versions/0.68.1_release.py index 748fb0750b6..3d1122cad4f 100644 --- a/src/zenml/zen_stores/migrations/versions/0.68.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.68.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.70.0_release.py b/src/zenml/zen_stores/migrations/versions/0.70.0_release.py index b9f977e6569..51796c70e9b 100644 --- a/src/zenml/zen_stores/migrations/versions/0.70.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.70.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> 
None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.71.0_release.py b/src/zenml/zen_stores/migrations/versions/0.71.0_release.py index 84623865544..4bca6065b0a 100644 --- a/src/zenml/zen_stores/migrations/versions/0.71.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.71.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.72.0_release.py b/src/zenml/zen_stores/migrations/versions/0.72.0_release.py index 3d2887e5504..ee2bb3b899b 100644 --- a/src/zenml/zen_stores/migrations/versions/0.72.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.72.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.73.0_release.py b/src/zenml/zen_stores/migrations/versions/0.73.0_release.py index b1c649bff39..013879ea963 100644 --- a/src/zenml/zen_stores/migrations/versions/0.73.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.73.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.74.0_release.py b/src/zenml/zen_stores/migrations/versions/0.74.0_release.py index ca951b82dbd..b56b2542e82 100644 --- a/src/zenml/zen_stores/migrations/versions/0.74.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.74.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> 
None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.75.0_release.py b/src/zenml/zen_stores/migrations/versions/0.75.0_release.py index b181e84ec25..e464d118b8a 100644 --- a/src/zenml/zen_stores/migrations/versions/0.75.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.75.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.80.0_release.py b/src/zenml/zen_stores/migrations/versions/0.80.0_release.py index 00e844f4bf9..f7ab2b579f1 100644 --- a/src/zenml/zen_stores/migrations/versions/0.80.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.80.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.80.1_release.py b/src/zenml/zen_stores/migrations/versions/0.80.1_release.py index 393961a6e63..7b8f44813a5 100644 --- a/src/zenml/zen_stores/migrations/versions/0.80.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.80.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.80.2_release.py b/src/zenml/zen_stores/migrations/versions/0.80.2_release.py index c170789fcb2..87f97006dac 100644 --- a/src/zenml/zen_stores/migrations/versions/0.80.2_release.py +++ 
b/src/zenml/zen_stores/migrations/versions/0.80.2_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.81.0_release.py b/src/zenml/zen_stores/migrations/versions/0.81.0_release.py index 2d1218e3161..c09c887fb92 100644 --- a/src/zenml/zen_stores/migrations/versions/0.81.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.81.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.82.0_release.py b/src/zenml/zen_stores/migrations/versions/0.82.0_release.py index 07e1a830ff9..0cc0b5ac91b 100644 --- a/src/zenml/zen_stores/migrations/versions/0.82.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.82.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.82.1_release.py b/src/zenml/zen_stores/migrations/versions/0.82.1_release.py index 7bde7d94eef..4cb0ae9f0dc 100644 --- a/src/zenml/zen_stores/migrations/versions/0.82.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.82.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.83.0_release.py b/src/zenml/zen_stores/migrations/versions/0.83.0_release.py index 
031d4b6fcd0..1bdd4a49e72 100644 --- a/src/zenml/zen_stores/migrations/versions/0.83.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.83.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.83.1_release.py b/src/zenml/zen_stores/migrations/versions/0.83.1_release.py index 364710d1bbe..87af7111ebb 100644 --- a/src/zenml/zen_stores/migrations/versions/0.83.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.83.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.84.0_release.py b/src/zenml/zen_stores/migrations/versions/0.84.0_release.py index f16badb9f50..36c9d1e40bd 100644 --- a/src/zenml/zen_stores/migrations/versions/0.84.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.84.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.84.1_release.py b/src/zenml/zen_stores/migrations/versions/0.84.1_release.py index b662ed5f58a..0ce3931db47 100644 --- a/src/zenml/zen_stores/migrations/versions/0.84.1_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.84.1_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git 
a/src/zenml/zen_stores/migrations/versions/0.84.2_release.py b/src/zenml/zen_stores/migrations/versions/0.84.2_release.py index bb6d7a55c08..18bc624acfc 100644 --- a/src/zenml/zen_stores/migrations/versions/0.84.2_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.84.2_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.84.3_release.py b/src/zenml/zen_stores/migrations/versions/0.84.3_release.py index b2f4b83a467..0a42ded875a 100644 --- a/src/zenml/zen_stores/migrations/versions/0.84.3_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.84.3_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.85.0_release.py b/src/zenml/zen_stores/migrations/versions/0.85.0_release.py index 60f31cf0e36..0e42e74b349 100644 --- a/src/zenml/zen_stores/migrations/versions/0.85.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.85.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.90.0_release.py b/src/zenml/zen_stores/migrations/versions/0.90.0_release.py index 6e5549a7ff9..223b55194b1 100644 --- a/src/zenml/zen_stores/migrations/versions/0.90.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.90.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> 
None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.90.0rc0_release.py b/src/zenml/zen_stores/migrations/versions/0.90.0rc0_release.py index 36dcb2a645e..398429c52ca 100644 --- a/src/zenml/zen_stores/migrations/versions/0.90.0rc0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.90.0rc0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/0.91.0_release.py b/src/zenml/zen_stores/migrations/versions/0.91.0_release.py index b47df69e22f..37cb34ca092 100644 --- a/src/zenml/zen_stores/migrations/versions/0.91.0_release.py +++ b/src/zenml/zen_stores/migrations/versions/0.91.0_release.py @@ -15,9 +15,7 @@ def upgrade() -> None: """Upgrade database schema and/or data, creating a new revision.""" - pass def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/1041bc644e0d_remove_secrets_manager.py b/src/zenml/zen_stores/migrations/versions/1041bc644e0d_remove_secrets_manager.py index 93a9979644b..8c6391a46eb 100644 --- a/src/zenml/zen_stores/migrations/versions/1041bc644e0d_remove_secrets_manager.py +++ b/src/zenml/zen_stores/migrations/versions/1041bc644e0d_remove_secrets_manager.py @@ -50,4 +50,3 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/26b776ad583e_redesign_artifacts.py b/src/zenml/zen_stores/migrations/versions/26b776ad583e_redesign_artifacts.py index b536402d57a..242b283d1b6 100644 --- a/src/zenml/zen_stores/migrations/versions/26b776ad583e_redesign_artifacts.py +++ 
b/src/zenml/zen_stores/migrations/versions/26b776ad583e_redesign_artifacts.py @@ -6,7 +6,7 @@ """ -from typing import Any, Dict +from typing import Any import sqlalchemy as sa import sqlmodel @@ -108,7 +108,7 @@ def _find_produced_artifact(cached_artifact: Any) -> Any: # For each cached artifact, find the ID of the original produced artifact # and link all input artifact entries to the produced artifact. - cached_to_produced_mapping: Dict[str, str] = {} + cached_to_produced_mapping: dict[str, str] = {} for cached_artifact in cached_artifacts: produced_artifact = _find_produced_artifact(cached_artifact) cached_to_produced_mapping[cached_artifact.id] = produced_artifact.id diff --git a/src/zenml/zen_stores/migrations/versions/279d55228d28_add_default_deployer.py b/src/zenml/zen_stores/migrations/versions/279d55228d28_add_default_deployer.py index e13d0d7bb38..eea3678125c 100644 --- a/src/zenml/zen_stores/migrations/versions/279d55228d28_add_default_deployer.py +++ b/src/zenml/zen_stores/migrations/versions/279d55228d28_add_default_deployer.py @@ -49,7 +49,7 @@ def upgrade() -> None: return now = utc_now() - empty_config = base64.b64encode("{}".encode("utf-8")) + empty_config = base64.b64encode(b"{}") # If an existing default deployer exists, rename it to avoid conflicts existing_default = connection.execute( @@ -90,4 +90,3 @@ def upgrade() -> None: def downgrade() -> None: """No downgrade supported for this data migration.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/2d201872e23c_remove_db_dependency_loop.py b/src/zenml/zen_stores/migrations/versions/2d201872e23c_remove_db_dependency_loop.py index a2c1ea0e869..a379a2dff4c 100644 --- a/src/zenml/zen_stores/migrations/versions/2d201872e23c_remove_db_dependency_loop.py +++ b/src/zenml/zen_stores/migrations/versions/2d201872e23c_remove_db_dependency_loop.py @@ -26,4 +26,3 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass 
diff --git a/src/zenml/zen_stores/migrations/versions/3944116bbd56_rename_project_to_workspace.py b/src/zenml/zen_stores/migrations/versions/3944116bbd56_rename_project_to_workspace.py index c2e60b13b3a..94432fb2255 100644 --- a/src/zenml/zen_stores/migrations/versions/3944116bbd56_rename_project_to_workspace.py +++ b/src/zenml/zen_stores/migrations/versions/3944116bbd56_rename_project_to_workspace.py @@ -6,7 +6,6 @@ """ -from typing import Set import sqlalchemy as sa from alembic import op @@ -31,7 +30,7 @@ def _fk_constraint_name(table: str, column: str) -> str: return f"fk_{table}_{column}_workspace" -def _get_changed_tables() -> Set[str]: +def _get_changed_tables() -> set[str]: return { "artifact", "flavor", diff --git a/src/zenml/zen_stores/migrations/versions/5330ba58bf20_rename_tables_and_foreign_keys.py b/src/zenml/zen_stores/migrations/versions/5330ba58bf20_rename_tables_and_foreign_keys.py index a05b88a03c5..01b2c300d25 100644 --- a/src/zenml/zen_stores/migrations/versions/5330ba58bf20_rename_tables_and_foreign_keys.py +++ b/src/zenml/zen_stores/migrations/versions/5330ba58bf20_rename_tables_and_foreign_keys.py @@ -7,7 +7,6 @@ """ from collections import defaultdict -from typing import Dict, List, Tuple import sqlalchemy as sa from alembic import op @@ -87,12 +86,12 @@ def _create_fk_constraint( ) -def _get_changes() -> Tuple[ - List[str], - List[str], - List[str], - List[Tuple[str, str, str, str, str]], - List[Tuple[str, str, str, str, str]], +def _get_changes() -> tuple[ + list[str], + list[str], + list[str], + list[tuple[str, str, str, str, str]], + list[tuple[str, str, str, str, str]], ]: """Define the data that should be changed in the schema. 
@@ -108,7 +107,7 @@ def _get_changes() -> Tuple[ (source, target, source_column, target_column, ondelete) """ # Define all the tables that should be renamed - table_name_mapping: Dict[str, str] = { + table_name_mapping: dict[str, str] = { "roleschema": "role", "stepinputartifactschema": "step_run_input_artifact", "userroleassignmentschema": "user_role_assignment", @@ -140,7 +139,7 @@ def _get_changes() -> Tuple[ "pipeline_run", "stack", ] - new_fk_constraints: List[Tuple[str, str, str, str, str]] = [ + new_fk_constraints: list[tuple[str, str, str, str, str]] = [ *[ (source, "workspace", "project_id", "id", "CASCADE") for source in project_user_fk_tables @@ -240,7 +239,7 @@ def upgrade() -> None: # foreign key number. if engine_name == "mysql": old_fk_constraints.sort(key=lambda x: (x[0], x[2])) - source_table_fk_constraint_counts: Dict[str, int] = defaultdict(int) + source_table_fk_constraint_counts: dict[str, int] = defaultdict(int) # Drop old foreign key constraints. for source, target, source_column, _, _ in old_fk_constraints: diff --git a/src/zenml/zen_stores/migrations/versions/6119cd9b93c2_tags_table.py b/src/zenml/zen_stores/migrations/versions/6119cd9b93c2_tags_table.py index 8453bec806b..ea9f58852f5 100644 --- a/src/zenml/zen_stores/migrations/versions/6119cd9b93c2_tags_table.py +++ b/src/zenml/zen_stores/migrations/versions/6119cd9b93c2_tags_table.py @@ -10,7 +10,6 @@ import random from collections import defaultdict from datetime import datetime -from typing import Set from uuid import uuid4 import sqlalchemy as sa @@ -53,7 +52,7 @@ def upgrade() -> None: ) ) # find unique tags and de-json tags - unique_tags: Set[str] = set() + unique_tags: set[str] = set() model_tags_prepared = [] for id_, tags in model_tags: try: diff --git a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py index 5f391ff957d..affcef7f11b 100644 --- 
a/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py +++ b/src/zenml/zen_stores/migrations/versions/7500f434b71c_remove_shared_columns.py @@ -8,7 +8,6 @@ import base64 from collections import defaultdict -from typing import Dict, Optional, Set from uuid import uuid4 import sqlalchemy as sa @@ -24,7 +23,7 @@ def _rename_duplicate_entities( - table: sa.Table, reserved_names: Optional[Set[str]] = None + table: sa.Table, reserved_names: set[str] | None = None ) -> None: """Include owner id in the name of duplicate entities. @@ -75,7 +74,7 @@ def _rename_duplicate_components(table: sa.Table) -> None: table.c.user_id, ) - names_per_type: Dict[str, Set[str]] = defaultdict(lambda: {"default"}) + names_per_type: dict[str, set[str]] = defaultdict(lambda: {"default"}) for id, type_, name, user_id in connection.execute(query).fetchall(): if user_id is None: @@ -138,7 +137,7 @@ def resolve_duplicate_names() -> None: "name": "default", "type": "artifact_store", "flavor": "local", - "configuration": base64.b64encode("{}".encode("utf-8")), + "configuration": base64.b64encode(b"{}"), "is_shared": True, "created": utcnow, "updated": utcnow, @@ -150,7 +149,7 @@ def resolve_duplicate_names() -> None: "name": "default", "type": "orchestrator", "flavor": "local", - "configuration": base64.b64encode("{}".encode("utf-8")), + "configuration": base64.b64encode(b"{}"), "is_shared": True, "created": utcnow, "updated": utcnow, diff --git a/src/zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py b/src/zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py index 7bb4f130e07..9cb862a6479 100644 --- a/src/zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py +++ b/src/zenml/zen_stores/migrations/versions/7b651bf6822e_track_secrets_in_db.py @@ -8,7 +8,7 @@ from abc import abstractmethod from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any from uuid import UUID import 
sqlalchemy as sa @@ -62,8 +62,8 @@ class BaseSecretsStoreBackend(BaseModel): def _get_secret_metadata( self, - secret_id: Optional[UUID] = None, - ) -> Dict[str, str]: + secret_id: UUID | None = None, + ) -> dict[str, str]: """Get a dictionary with metadata that can be used as tags/labels. This utility method can be used with Secrets Managers that can @@ -86,7 +86,7 @@ def _get_secret_metadata( # from other secrets that might be stored in the same backend and # to distinguish between different ZenML deployments using the same # backend. - metadata: Dict[str, str] = {ZENML_SECRET_LABEL: str(self.server_id)} + metadata: dict[str, str] = {ZENML_SECRET_LABEL: str(self.server_id)} # Include the secret ID if provided. if secret_id is not None: @@ -96,7 +96,7 @@ def _get_secret_metadata( def _create_secret_from_metadata( self, - metadata: Dict[str, str], + metadata: dict[str, str], created: datetime, updated: datetime, ) -> ZenMLSecretMetadata: @@ -147,7 +147,7 @@ def _create_secret_from_metadata( @abstractmethod def list_secrets( self, - ) -> List[ZenMLSecretMetadata]: + ) -> list[ZenMLSecretMetadata]: """List all ZenML secrets in the secrets store backend. Returns: @@ -160,8 +160,8 @@ class AWSSecretsStoreBackend(BaseSecretsStoreBackend): @staticmethod def _get_aws_secret_filters( - metadata: Dict[str, str], - ) -> List[Dict[str, str]]: + metadata: dict[str, str], + ) -> list[dict[str, str]]: """Convert ZenML secret metadata to AWS secret filters. Args: @@ -170,7 +170,7 @@ def _get_aws_secret_filters( Returns: The AWS secret filters. """ - aws_filters: List[Dict[str, Any]] = [] + aws_filters: list[dict[str, Any]] = [] for k, v in metadata.items(): aws_filters.append( { @@ -193,7 +193,7 @@ def _get_aws_secret_filters( def list_secrets( self, - ) -> List[ZenMLSecretMetadata]: + ) -> list[ZenMLSecretMetadata]: """List all ZenML secrets in the AWS secrets store backend. 
Returns: @@ -211,7 +211,7 @@ def list_secrets( metadata = self._get_secret_metadata() aws_filters = self._get_aws_secret_filters(metadata) - results: List[ZenMLSecretMetadata] = [] + results: list[ZenMLSecretMetadata] = [] try: # AWS Secrets Manager API pagination is wrapped around the @@ -231,7 +231,7 @@ def list_secrets( for secret in page["SecretList"]: try: # Convert the AWS secret tags to a metadata dictionary. - unpacked_metadata: Dict[str, str] = { + unpacked_metadata: dict[str, str] = { tag["Key"]: tag["Value"] for tag in secret["Tags"] } @@ -274,7 +274,7 @@ def parent_name(self) -> str: def _convert_gcp_secret( self, - labels: Dict[str, str], + labels: dict[str, str], ) -> ZenMLSecretMetadata: """Create a ZenML secret model from data stored in an GCP secret. @@ -317,7 +317,7 @@ def _convert_gcp_secret( updated=updated, ) - def list_secrets(self) -> List[ZenMLSecretMetadata]: + def list_secrets(self) -> list[ZenMLSecretMetadata]: """List all secrets. Returns: @@ -361,7 +361,7 @@ class AzureSecretsStoreBackend(BaseSecretsStoreBackend): def _convert_azure_secret( self, - tags: Dict[str, str], + tags: dict[str, str], ) -> ZenMLSecretMetadata: """Create a ZenML secret model from data stored in an Azure secret. @@ -395,7 +395,7 @@ def _convert_azure_secret( updated=updated, ) - def list_secrets(self) -> List[ZenMLSecretMetadata]: + def list_secrets(self) -> list[ZenMLSecretMetadata]: """List all secrets. Returns: @@ -410,7 +410,7 @@ def list_secrets(self) -> List[ZenMLSecretMetadata]: assert isinstance(self.client, SecretClient) - results: List[ZenMLSecretMetadata] = [] + results: list[ZenMLSecretMetadata] = [] try: all_secrets = self.client.list_properties_of_secrets() @@ -449,7 +449,7 @@ class HashiCorpVaultSecretsStoreBackend(BaseSecretsStoreBackend): def _convert_vault_secret( self, - vault_secret: Dict[str, Any], + vault_secret: dict[str, Any], ) -> ZenMLSecretMetadata: """Create a ZenML secret model from data stored in an HashiCorp Vault secret. 
@@ -484,7 +484,7 @@ def _convert_vault_secret( updated=updated, ) - def list_secrets(self) -> List[ZenMLSecretMetadata]: + def list_secrets(self) -> list[ZenMLSecretMetadata]: """List all secrets. Note that returned secrets do not include any secret values. To fetch @@ -505,7 +505,7 @@ def list_secrets(self) -> List[ZenMLSecretMetadata]: assert isinstance(self.client, Client) - results: List[ZenMLSecretMetadata] = [] + results: list[ZenMLSecretMetadata] = [] try: # List all ZenML secrets in the Vault @@ -724,4 +724,3 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/904464ea4041_add_pipeline_model_run_unique_constraints.py b/src/zenml/zen_stores/migrations/versions/904464ea4041_add_pipeline_model_run_unique_constraints.py index 8a9441a157e..35a892fbe5f 100644 --- a/src/zenml/zen_stores/migrations/versions/904464ea4041_add_pipeline_model_run_unique_constraints.py +++ b/src/zenml/zen_stores/migrations/versions/904464ea4041_add_pipeline_model_run_unique_constraints.py @@ -7,7 +7,7 @@ """ from collections import defaultdict -from typing import Any, Dict, Set +from typing import Any import sqlalchemy as sa from alembic import op @@ -38,7 +38,7 @@ def resolve_duplicate_entities() -> None: result = connection.execute( sa.select(table.c.id, table.c.name, table.c.workspace_id) ).all() - existing: Dict[str, Set[str]] = defaultdict(set) + existing: dict[str, set[str]] = defaultdict(set) for id_, name, workspace_id in result: names_in_workspace = existing[workspace_id] @@ -71,8 +71,8 @@ def resolve_duplicate_entities() -> None: ) ).all() - existing_names: Dict[str, Set[str]] = defaultdict(set) - existing_numbers: Dict[str, Set[int]] = defaultdict(set) + existing_names: dict[str, set[str]] = defaultdict(set) + existing_numbers: dict[str, set[int]] = defaultdict(set) needs_update = [] @@ -99,7 +99,7 @@ def resolve_duplicate_entities() -> None: 
needs_new_name, needs_new_number, ) in needs_update: - values: Dict[str, Any] = {} + values: dict[str, Any] = {} is_numeric_version = str(number) == name next_numeric_version = max(existing_numbers[model_id]) + 1 diff --git a/src/zenml/zen_stores/migrations/versions/a39c4184c8ce_remove_secrets_manager_flavors.py b/src/zenml/zen_stores/migrations/versions/a39c4184c8ce_remove_secrets_manager_flavors.py index 04f0296c01c..f2848cd3799 100644 --- a/src/zenml/zen_stores/migrations/versions/a39c4184c8ce_remove_secrets_manager_flavors.py +++ b/src/zenml/zen_stores/migrations/versions/a39c4184c8ce_remove_secrets_manager_flavors.py @@ -29,4 +29,3 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" - pass diff --git a/src/zenml/zen_stores/migrations/versions/b4fca5241eea_migrate_onboarding_state.py b/src/zenml/zen_stores/migrations/versions/b4fca5241eea_migrate_onboarding_state.py index d081aa49654..c529690611c 100644 --- a/src/zenml/zen_stores/migrations/versions/b4fca5241eea_migrate_onboarding_state.py +++ b/src/zenml/zen_stores/migrations/versions/b4fca5241eea_migrate_onboarding_state.py @@ -7,7 +7,6 @@ """ import json -from typing import Dict, List import sqlalchemy as sa from alembic import op @@ -48,11 +47,11 @@ def upgrade() -> None: # -> Migrate to the new server keys state = json.loads(existing_onboarding_state) - if isinstance(state, Dict): + if isinstance(state, dict): for key in state.keys(): if key in ONBOARDING_KEY_MAPPING: new_state.extend(ONBOARDING_KEY_MAPPING[key]) - elif isinstance(state, List): + elif isinstance(state, list): # Somehow the state is already converted, probably shouldn't happen return @@ -163,5 +162,4 @@ def upgrade() -> None: def downgrade() -> None: """Downgrade database schema and/or data back to the previous revision.""" # ### commands auto generated by Alembic - please adjust! 
### - pass # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/migrations/versions/b59aa68fdb1f_simplify_pipelines.py b/src/zenml/zen_stores/migrations/versions/b59aa68fdb1f_simplify_pipelines.py index c9dcfb8a307..5ed6cc389ed 100644 --- a/src/zenml/zen_stores/migrations/versions/b59aa68fdb1f_simplify_pipelines.py +++ b/src/zenml/zen_stores/migrations/versions/b59aa68fdb1f_simplify_pipelines.py @@ -6,7 +6,6 @@ """ -from typing import Dict, Optional import sqlalchemy as sa import sqlmodel @@ -62,8 +61,8 @@ def upgrade() -> None: def _migrate_pipeline_columns( pipeline_id: str, - version_hash: Optional[str], - pipeline_spec: Optional[str], + version_hash: str | None, + pipeline_spec: str | None, ) -> None: connection.execute( sa.update(pipeline_deployment_table) @@ -87,7 +86,7 @@ def _update_pipeline_fks(pipeline_id: str, replacement_id: str) -> None: ) all_pipelines = connection.execute(sa.select(pipeline_table)).fetchall() - replacement_mapping: Dict[str, str] = {} + replacement_mapping: dict[str, str] = {} for pipeline in all_pipelines: _migrate_pipeline_columns( diff --git a/src/zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py b/src/zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py index 701b802a83a..9884879b722 100644 --- a/src/zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py +++ b/src/zenml/zen_stores/migrations/versions/c22561cbb3a9_add_artifact_unique_constraints.py @@ -7,7 +7,6 @@ """ from collections import defaultdict -from typing import Dict, Set import sqlalchemy as sa from alembic import op @@ -36,7 +35,7 @@ def resolve_duplicate_versions() -> None: artifact_version_table.c.version, ) - versions_per_artifact: Dict[str, Set[str]] = defaultdict(set) + versions_per_artifact: dict[str, set[str]] = defaultdict(set) for id, artifact_id, version in connection.execute(query).fetchall(): versions = versions_per_artifact[artifact_id] diff --git 
a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 4e4cc8b489f..d6d91f34325 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -22,15 +22,10 @@ from typing import ( Any, ClassVar, - Dict, - List, - Optional, - Sequence, - Tuple, - Type, TypeVar, Union, ) +from collections.abc import Sequence from urllib.parse import urlparse from uuid import UUID, uuid4 @@ -293,7 +288,7 @@ logger = get_logger(__name__) # type alias for possible json payloads (the Anys are recursive Json instances) -Json = Union[Dict[str, Any], List[Any], str, int, float, bool, None] +Json = Union[dict[str, Any], list[Any], str, int, float, bool, None] AnyRequest = TypeVar("AnyRequest", bound=BaseRequest) @@ -323,7 +318,7 @@ class RestZenStoreConfiguration(StoreConfiguration): type: StoreType = StoreType.REST - verify_ssl: Union[bool, str] = Field( + verify_ssl: bool | str = Field( default=True, union_mode="left_to_right" ) http_timeout: int = DEFAULT_HTTP_TIMEOUT @@ -361,8 +356,8 @@ def validate_url(cls, url: str) -> str: @field_validator("verify_ssl") @classmethod def validate_verify_ssl( - cls, verify_ssl: Union[bool, str] - ) -> Union[bool, str]: + cls, verify_ssl: bool | str + ) -> bool | str: """Validates that the verify_ssl either points to a file or is a bool. Args: @@ -381,7 +376,7 @@ def validate_verify_ssl( return verify_ssl if os.path.isfile(verify_ssl): - with open(verify_ssl, "r") as f: + with open(verify_ssl) as f: cert_content = f.read() fileio.makedirs(str(secret_folder)) @@ -408,7 +403,7 @@ def supports_url_scheme(cls, url: str) -> bool: @model_validator(mode="before") @classmethod @before_validator_handler - def _move_credentials(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def _move_credentials(cls, data: dict[str, Any]) -> dict[str, Any]: """Moves credentials (API keys, API tokens, passwords) from the config to the credentials store. 
Args: @@ -459,12 +454,12 @@ class RestZenStore(BaseZenStore): config: RestZenStoreConfiguration TYPE: ClassVar[StoreType] = StoreType.REST - CONFIG_TYPE: ClassVar[Type[StoreConfiguration]] = RestZenStoreConfiguration - _api_token: Optional[APIToken] = None - _session: Optional[requests.Session] = None - _server_info: Optional[ServerModel] = None + CONFIG_TYPE: ClassVar[type[StoreConfiguration]] = RestZenStoreConfiguration + _api_token: APIToken | None = None + _session: requests.Session | None = None + _server_info: ServerModel | None = None _session_lock: RLock = PrivateAttr(default_factory=RLock) - _last_authenticated: Optional[datetime] = None + _last_authenticated: datetime | None = None # ==================================== # ZenML Store interface implementation @@ -715,7 +710,7 @@ def create_api_key( def get_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, hydrate: bool = True, ) -> APIKeyResponse: """Get an API key for a service account. @@ -766,7 +761,7 @@ def list_api_keys( def update_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, api_key_update: APIKeyUpdate, ) -> APIKeyResponse: """Update an API key for a service account. @@ -790,7 +785,7 @@ def update_api_key( def rotate_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, rotate_request: APIKeyRotateRequest, ) -> APIKeyResponse: """Rotate an API key for a service account. @@ -813,7 +808,7 @@ def rotate_api_key( def delete_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, ) -> None: """Delete an API key for a service account. 
@@ -1019,8 +1014,8 @@ def create_artifact_version( ) def batch_create_artifact_versions( - self, artifact_versions: List[ArtifactVersionRequest] - ) -> List[ArtifactVersionResponse]: + self, artifact_versions: list[ArtifactVersionRequest] + ) -> list[ArtifactVersionResponse]: """Creates a batch of artifact versions. Args: @@ -1112,7 +1107,7 @@ def delete_artifact_version(self, artifact_version_id: UUID) -> None: def prune_artifact_versions( self, - project_name_or_id: Union[str, UUID], + project_name_or_id: str | UUID, only_versions: bool = True, ) -> None: """Prunes unused artifact versions and their artifacts. @@ -1659,8 +1654,8 @@ def get_snapshot( self, snapshot_id: UUID, hydrate: bool = True, - step_configuration_filter: Optional[List[str]] = None, - include_config_schema: Optional[bool] = None, + step_configuration_filter: list[str] | None = None, + include_config_schema: bool | None = None, ) -> PipelineSnapshotResponse: """Get a snapshot with a given ID. @@ -2032,7 +2027,7 @@ def delete_run_template(self, template_id: UUID) -> None: def run_template( self, template_id: UUID, - run_configuration: Optional[PipelineRunConfiguration] = None, + run_configuration: PipelineRunConfiguration | None = None, ) -> PipelineRunResponse: """Run a template. @@ -2160,7 +2155,7 @@ def delete_event_source(self, event_source_id: UUID) -> None: def get_or_create_run( self, pipeline_run: PipelineRunRequest - ) -> Tuple[PipelineRunResponse, bool]: + ) -> tuple[PipelineRunResponse, bool]: """Gets or creates a pipeline run. If a run with the same ID or name already exists, it is returned. @@ -2498,7 +2493,7 @@ def backup_secrets( this flag effectively moves all secrets from the primary secrets store to the backup secrets store. 
""" - params: Dict[str, Any] = { + params: dict[str, Any] = { "ignore_errors": ignore_errors, "delete_secrets": delete_secrets, } @@ -2520,7 +2515,7 @@ def restore_secrets( this flag effectively moves all secrets from the backup secrets store to the primary secrets store. """ - params: Dict[str, Any] = { + params: dict[str, Any] = { "ignore_errors": ignore_errors, "delete_secrets": delete_secrets, } @@ -2550,7 +2545,7 @@ def create_service_account( def get_service_account( self, - service_account_name_or_id: Union[str, UUID], + service_account_name_or_id: str | UUID, hydrate: bool = True, ) -> ServiceAccountResponse: """Gets a specific service account. @@ -2594,7 +2589,7 @@ def list_service_accounts( def update_service_account( self, - service_account_name_or_id: Union[str, UUID], + service_account_name_or_id: str | UUID, service_account_update: ServiceAccountUpdate, ) -> ServiceAccountResponse: """Updates an existing service account. @@ -2617,7 +2612,7 @@ def update_service_account( def delete_service_account( self, - service_account_name_or_id: Union[str, UUID], + service_account_name_or_id: str | UUID, ) -> None: """Delete a service account. @@ -2792,9 +2787,9 @@ def delete_service_connector(self, service_connector_id: UUID) -> None: def _populate_connector_type( self, - *connector_models: Union[ - ServiceConnectorResponse, ServiceConnectorResourcesModel - ], + *connector_models: ( + ServiceConnectorResponse | ServiceConnectorResourcesModel + ), ) -> None: """Populates or updates the connector type of the given connector or resource models. @@ -2877,8 +2872,8 @@ def verify_service_connector_config( def verify_service_connector( self, service_connector_id: UUID, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, list_resources: bool = True, ) -> ServiceConnectorResourcesModel: """Verifies if a service connector instance has access to one or more resources. 
@@ -2895,7 +2890,7 @@ def verify_service_connector( The list of resources that the service connector has access to, scoped to the supplied resource type and ID, if provided. """ - params: Dict[str, Any] = {"list_resources": list_resources} + params: dict[str, Any] = {"list_resources": list_resources} if resource_type: params["resource_type"] = resource_type if resource_id: @@ -2918,8 +2913,8 @@ def verify_service_connector( def get_service_connector_client( self, service_connector_id: UUID, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, ) -> ServiceConnectorResponse: """Get a service connector client for a service connector and given resource. @@ -2957,7 +2952,7 @@ def get_service_connector_client( def list_service_connector_resources( self, filter_model: ServiceConnectorFilter, - ) -> List[ServiceConnectorResourcesModel]: + ) -> list[ServiceConnectorResourcesModel]: """List resources that can be accessed by service connectors. Args: @@ -3026,10 +3021,10 @@ def list_service_connector_resources( def list_service_connector_types( self, - connector_type: Optional[str] = None, - resource_type: Optional[str] = None, - auth_method: Optional[str] = None, - ) -> List[ServiceConnectorTypeModel]: + connector_type: str | None = None, + resource_type: str | None = None, + auth_method: str | None = None, + ) -> list[ServiceConnectorTypeModel]: """Get a list of service connector types. Args: @@ -3100,7 +3095,7 @@ def get_service_connector_type( """ # Use the local registry to get the service connector type, if it # exists. 
- local_connector_type: Optional[ServiceConnectorTypeModel] = None + local_connector_type: ServiceConnectorTypeModel | None = None if service_connector_registry.is_registered(connector_type): local_connector_type = ( service_connector_registry.get_service_connector_type( @@ -3242,7 +3237,7 @@ def get_stack_deployment_config( self, provider: StackDeploymentProvider, stack_name: str, - location: Optional[str] = None, + location: str | None = None, ) -> StackDeploymentConfig: """Return the cloud provider console URL and configuration needed to deploy the ZenML stack. @@ -3268,9 +3263,9 @@ def get_stack_deployment_stack( self, provider: StackDeploymentProvider, stack_name: str, - location: Optional[str] = None, - date_start: Optional[datetime] = None, - ) -> Optional[DeployedStack]: + location: str | None = None, + date_start: datetime | None = None, + ) -> DeployedStack | None: """Return a matching ZenML stack that was deployed and registered. Args: @@ -3552,7 +3547,7 @@ def create_user(self, user: UserRequest) -> UserResponse: def get_user( self, - user_name_or_id: Optional[Union[str, UUID]] = None, + user_name_or_id: str | UUID | None = None, include_private: bool = False, hydrate: bool = True, ) -> UserResponse: @@ -3626,7 +3621,7 @@ def update_user( ) def deactivate_user( - self, user_name_or_id: Union[str, UUID] + self, user_name_or_id: str | UUID ) -> UserResponse: """Deactivates a user. @@ -3642,7 +3637,7 @@ def deactivate_user( return UserResponse.model_validate(response_body) - def delete_user(self, user_name_or_id: Union[str, UUID]) -> None: + def delete_user(self, user_name_or_id: str | UUID) -> None: """Deletes a user. Args: @@ -3671,7 +3666,7 @@ def create_project(self, project: ProjectRequest) -> ProjectResponse: ) def get_project( - self, project_name_or_id: Union[UUID, str], hydrate: bool = True + self, project_name_or_id: UUID | str, hydrate: bool = True ) -> ProjectResponse: """Get an existing project by name or ID. 
@@ -3732,7 +3727,7 @@ def update_project( response_model=ProjectResponse, ) - def delete_project(self, project_name_or_id: Union[str, UUID]) -> None: + def delete_project(self, project_name_or_id: str | UUID) -> None: """Deletes a project. Args: @@ -3975,7 +3970,7 @@ def list_model_version_artifact_links( def delete_model_version_artifact_link( self, model_version_id: UUID, - model_version_artifact_link_name_or_id: Union[str, UUID], + model_version_artifact_link_name_or_id: str | UUID, ) -> None: """Deletes a model version to artifact link. @@ -4055,7 +4050,7 @@ def list_model_version_pipeline_run_links( def delete_model_version_pipeline_run_link( self, model_version_id: UUID, - model_version_pipeline_run_link_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_name_or_id: str | UUID, ) -> None: """Deletes a model version to pipeline run link. @@ -4145,10 +4140,10 @@ def delete_authorized_device(self, device_id: UUID) -> None: def get_api_token( self, token_type: APITokenType = APITokenType.WORKLOAD, - expires_in: Optional[int] = None, - schedule_id: Optional[UUID] = None, - pipeline_run_id: Optional[UUID] = None, - deployment_id: Optional[UUID] = None, + expires_in: int | None = None, + schedule_id: UUID | None = None, + pipeline_run_id: UUID | None = None, + deployment_id: UUID | None = None, ) -> str: """Get an API token. @@ -4165,7 +4160,7 @@ def get_api_token( Raises: ValueError: if the server response is not valid. """ - params: Dict[str, Any] = { + params: dict[str, Any] = { "token_type": token_type.value, } if expires_in: @@ -4303,8 +4298,8 @@ def create_tag_resource( ) def batch_create_tag_resource( - self, tag_resources: List[TagResourceRequest] - ) -> List[TagResourceResponse]: + self, tag_resources: list[TagResourceRequest] + ) -> list[TagResourceResponse]: """Create a batch of tag resource relationships. 
Args: @@ -4331,7 +4326,7 @@ def delete_tag_resource( self.delete(path=TAG_RESOURCES, body=tag_resource) def batch_delete_tag_resource( - self, tag_resources: List[TagResourceRequest] + self, tag_resources: list[TagResourceRequest] ) -> None: """Delete a batch of tag resources. @@ -4375,11 +4370,11 @@ def get_or_generate_api_token(self) -> str: f"Authentication token for {self.url} expired; refreshing..." ) - data: Optional[Dict[str, str]] = None + data: dict[str, str] | None = None # Use a custom user agent to identify the ZenML client in the server # logs. - headers: Dict[str, str] = { + headers: dict[str, str] = { "User-Agent": "zenml/" + zenml.__version__, } @@ -4685,8 +4680,8 @@ def _request( self, method: str, url: str, - params: Optional[Dict[str, Any]] = None, - timeout: Optional[int] = None, + params: dict[str, Any] | None = None, + timeout: int | None = None, **kwargs: Any, ) -> Json: """Make a request to the REST API. @@ -4831,8 +4826,8 @@ def _request( def get( self, path: str, - params: Optional[Dict[str, Any]] = None, - timeout: Optional[int] = None, + params: dict[str, Any] | None = None, + timeout: int | None = None, **kwargs: Any, ) -> Json: """Make a GET request to the given endpoint path. @@ -4857,9 +4852,9 @@ def get( def delete( self, path: str, - body: Optional[BaseModel] = None, - params: Optional[Dict[str, Any]] = None, - timeout: Optional[int] = None, + body: BaseModel | None = None, + params: dict[str, Any] | None = None, + timeout: int | None = None, **kwargs: Any, ) -> Json: """Make a DELETE request to the given endpoint path. @@ -4887,8 +4882,8 @@ def post( self, path: str, body: BaseModel, - params: Optional[Dict[str, Any]] = None, - timeout: Optional[int] = None, + params: dict[str, Any] | None = None, + timeout: int | None = None, **kwargs: Any, ) -> Json: """Make a POST request to the given endpoint path. 
@@ -4915,9 +4910,9 @@ def post( def put( self, path: str, - body: Optional[BaseModel] = None, - params: Optional[Dict[str, Any]] = None, - timeout: Optional[int] = None, + body: BaseModel | None = None, + params: dict[str, Any] | None = None, + timeout: int | None = None, **kwargs: Any, ) -> Json: """Make a PUT request to the given endpoint path. @@ -4947,9 +4942,9 @@ def put( def _create_resource( self, resource: AnyRequest, - response_model: Type[AnyResponse], + response_model: type[AnyResponse], route: str, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> AnyResponse: """Create a new resource. @@ -4969,11 +4964,11 @@ def _create_resource( def _batch_create_resources( self, - resources: List[AnyRequest], - response_model: Type[AnyResponse], + resources: list[AnyRequest], + response_model: type[AnyResponse], route: str, - params: Optional[Dict[str, Any]] = None, - ) -> List[AnyResponse]: + params: dict[str, Any] | None = None, + ) -> list[AnyResponse]: """Create a new batch of resources. Args: @@ -5004,10 +4999,10 @@ def _batch_create_resources( def _get_or_create_resource( self, resource: AnyRequest, - response_model: Type[AnyResponse], + response_model: type[AnyResponse], route: str, - params: Optional[Dict[str, Any]] = None, - ) -> Tuple[AnyResponse, bool]: + params: dict[str, Any] | None = None, + ) -> tuple[AnyResponse, bool]: """Get or create a resource. Args: @@ -5053,10 +5048,10 @@ def _get_or_create_resource( def _get_resource( self, - resource_id: Union[str, int, UUID], + resource_id: str | int | UUID, route: str, - response_model: Type[AnyResponse], - params: Optional[Dict[str, Any]] = None, + response_model: type[AnyResponse], + params: dict[str, Any] | None = None, ) -> AnyResponse: """Retrieve a single resource. 
@@ -5075,9 +5070,9 @@ def _get_resource( def _list_paginated_resources( self, route: str, - response_model: Type[AnyResponse], + response_model: type[AnyResponse], filter_model: BaseFilter, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> Page[AnyResponse]: """Retrieve a list of resources filtered by some criteria. @@ -5113,9 +5108,9 @@ def _list_paginated_resources( def _list_resources( self, route: str, - response_model: Type[AnyResponse], + response_model: type[AnyResponse], **filters: Any, - ) -> List[AnyResponse]: + ) -> list[AnyResponse]: """Retrieve a list of resources filtered by some criteria. Args: @@ -5140,11 +5135,11 @@ def _list_resources( def _update_resource( self, - resource_id: Union[str, int, UUID], + resource_id: str | int | UUID, resource_update: BaseModel, - response_model: Type[AnyResponse], + response_model: type[AnyResponse], route: str, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> AnyResponse: """Update an existing resource. @@ -5167,9 +5162,9 @@ def _update_resource( def _delete_resource( self, - resource_id: Union[str, UUID], + resource_id: str | UUID, route: str, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> None: """Delete a resource. @@ -5184,7 +5179,7 @@ def _batch_delete_resources( self, resources: Sequence[BaseModel], route: str, - params: Optional[Dict[str, Any]] = None, + params: dict[str, Any] | None = None, ) -> None: """Delete a batch of resources. 
diff --git a/src/zenml/zen_stores/schemas/action_schemas.py b/src/zenml/zen_stores/schemas/action_schemas.py index de15011450f..e9b74cdded5 100644 --- a/src/zenml/zen_stores/schemas/action_schemas.py +++ b/src/zenml/zen_stores/schemas/action_schemas.py @@ -15,7 +15,8 @@ import base64 import json -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Sequence from uuid import UUID from pydantic.json import pydantic_encoder @@ -65,7 +66,7 @@ class ActionSchema(NamedSchema, table=True): ) project: "ProjectSchema" = Relationship(back_populates="actions") - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -78,7 +79,7 @@ class ActionSchema(NamedSchema, table=True): sa_relationship_kwargs={"foreign_keys": "[ActionSchema.user_id]"}, ) - triggers: List["TriggerSchema"] = Relationship(back_populates="action") + triggers: list["TriggerSchema"] = Relationship(back_populates="action") service_account_id: UUID = build_foreign_key_field( source=__tablename__, @@ -98,7 +99,7 @@ class ActionSchema(NamedSchema, table=True): flavor: str = Field(nullable=False) plugin_subtype: str = Field(nullable=False) - description: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) + description: str | None = Field(sa_column=Column(TEXT, nullable=True)) configuration: bytes diff --git a/src/zenml/zen_stores/schemas/api_key_schemas.py b/src/zenml/zen_stores/schemas/api_key_schemas.py index 7dd0c50c7de..f2b6d1fe6bc 100644 --- a/src/zenml/zen_stores/schemas/api_key_schemas.py +++ b/src/zenml/zen_stores/schemas/api_key_schemas.py @@ -15,7 +15,8 @@ from datetime import datetime from secrets import token_hex -from typing import Any, Optional, Sequence, Tuple +from typing import Any +from collections.abc import Sequence from uuid import UUID from passlib.context import 
CryptContext @@ -55,11 +56,11 @@ class APIKeySchema(NamedSchema, table=True): description: str = Field(sa_column=Column(TEXT)) key: str - previous_key: Optional[str] = Field(default=None, nullable=True) + previous_key: str | None = Field(default=None, nullable=True) retain_period: int = Field(default=0) active: bool = Field(default=True) - last_login: Optional[datetime] = None - last_rotated: Optional[datetime] = None + last_login: datetime | None = None + last_rotated: datetime | None = None service_account_id: UUID = build_foreign_key_field( source=__tablename__, @@ -123,7 +124,7 @@ def from_request( cls, service_account_id: UUID, request: APIKeyRequest, - ) -> Tuple["APIKeySchema", str]: + ) -> tuple["APIKeySchema", str]: """Convert a `APIKeyRequest` to a `APIKeySchema`. Args: @@ -257,7 +258,7 @@ def internal_update(self, update: APIKeyInternalUpdate) -> "APIKeySchema": def rotate( self, rotate_request: APIKeyRotateRequest, - ) -> Tuple["APIKeySchema", str]: + ) -> tuple["APIKeySchema", str]: """Rotate the key for an `APIKeySchema`. 
Args: diff --git a/src/zenml/zen_stores/schemas/api_transaction_schemas.py b/src/zenml/zen_stores/schemas/api_transaction_schemas.py index 14a524505ed..8de5f116cc0 100644 --- a/src/zenml/zen_stores/schemas/api_transaction_schemas.py +++ b/src/zenml/zen_stores/schemas/api_transaction_schemas.py @@ -14,7 +14,7 @@ """SQLModel implementation of idempotent API transaction tables.""" from datetime import datetime, timedelta -from typing import Any, Optional +from typing import Any from uuid import UUID from sqlalchemy import TEXT, Column, String @@ -53,7 +53,7 @@ class ApiTransactionSchema(BaseSchema, table=True): method: str url: str = Field(sa_column=Column(TEXT, nullable=False)) completed: bool = Field(default=False) - result: Optional[str] = Field( + result: str | None = Field( default=None, sa_column=Column( String(length=MEDIUMTEXT_MAX_LENGTH).with_variant( @@ -62,7 +62,7 @@ class ApiTransactionSchema(BaseSchema, table=True): nullable=True, ), ) - expired: Optional[datetime] = Field(default=None, nullable=True) + expired: datetime | None = Field(default=None, nullable=True) user_id: UUID = build_foreign_key_field( source=__tablename__, diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index 17884129511..7f6e29973dc 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. 
"""SQLModel implementation of artifact table.""" -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Sequence from uuid import UUID from pydantic import ValidationError @@ -84,11 +85,11 @@ class ArtifactSchema(NamedSchema, table=True): # Fields has_custom_name: bool - versions: List["ArtifactVersionSchema"] = Relationship( + versions: list["ArtifactVersionSchema"] = Relationship( back_populates="artifact", sa_relationship_kwargs={"cascade": "delete"}, ) - tags: List["TagSchema"] = Relationship( + tags: list["TagSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.ARTIFACT.value}', foreign(TagResourceSchema.resource_id)==ArtifactSchema.id)", secondary="tag_resource", @@ -108,7 +109,7 @@ class ArtifactSchema(NamedSchema, table=True): ) project: "ProjectSchema" = Relationship() - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -284,12 +285,12 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True): # Fields version: str - version_number: Optional[int] + version_number: int | None type: str uri: str = Field(sa_column=Column(TEXT, nullable=False)) materializer: str = Field(sa_column=Column(TEXT, nullable=False)) data_type: str = Field(sa_column=Column(TEXT, nullable=False)) - tags: List["TagSchema"] = Relationship( + tags: list["TagSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.ARTIFACT_VERSION.value}', foreign(TagResourceSchema.resource_id)==ArtifactVersionSchema.id)", secondary="tag_resource", @@ -299,7 +300,7 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True): ), ) save_type: str = Field(sa_column=Column(TEXT, 
nullable=False)) - content_hash: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) + content_hash: str | None = Field(sa_column=Column(TEXT, nullable=True)) # Foreign keys artifact_id: UUID = build_foreign_key_field( @@ -310,7 +311,7 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True): ondelete="CASCADE", nullable=False, ) - artifact_store_id: Optional[UUID] = build_foreign_key_field( + artifact_store_id: UUID | None = build_foreign_key_field( source=__tablename__, target=StackComponentSchema.__tablename__, source_column="artifact_store_id", @@ -318,7 +319,7 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True): ondelete="SET NULL", nullable=True, ) - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -341,7 +342,7 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True): back_populates="artifact_versions" ) project: "ProjectSchema" = Relationship(back_populates="artifact_versions") - run_metadata: List["RunMetadataSchema"] = Relationship( + run_metadata: list["RunMetadataSchema"] = Relationship( sa_relationship_kwargs=dict( secondary="run_metadata_resource", primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.ARTIFACT_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ArtifactVersionSchema.id)", @@ -349,23 +350,23 @@ class ArtifactVersionSchema(BaseSchema, RunMetadataInterface, table=True): overlaps="run_metadata", ), ) - visualizations: List["ArtifactVisualizationSchema"] = Relationship( + visualizations: list["ArtifactVisualizationSchema"] = Relationship( back_populates="artifact_version", sa_relationship_kwargs={"cascade": "delete"}, ) # Needed for cascade deletion behavior - model_versions_artifacts_links: List["ModelVersionArtifactSchema"] = ( + model_versions_artifacts_links: 
list["ModelVersionArtifactSchema"] = ( Relationship( back_populates="artifact_version", sa_relationship_kwargs={"cascade": "delete"}, ) ) - output_of_step_runs: List["StepRunOutputArtifactSchema"] = Relationship( + output_of_step_runs: list["StepRunOutputArtifactSchema"] = Relationship( back_populates="artifact_version", sa_relationship_kwargs={"cascade": "delete"}, ) - input_of_step_runs: List["StepRunInputArtifactSchema"] = Relationship( + input_of_step_runs: list["StepRunInputArtifactSchema"] = Relationship( back_populates="artifact_version", sa_relationship_kwargs={"cascade": "delete"}, ) @@ -410,7 +411,7 @@ def get_query_options( return options @property - def producer_run_ids(self) -> Optional[Tuple[UUID, UUID]]: + def producer_run_ids(self) -> tuple[UUID, UUID] | None: """Fetch the producer run IDs for this artifact version. Raises: diff --git a/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py b/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py index d7e1576f71d..62978d1de09 100644 --- a/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. 
"""SQLModel implementation of artifact visualization table.""" -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any from uuid import UUID from sqlalchemy import TEXT, Column @@ -60,7 +60,7 @@ class ArtifactVisualizationSchema(BaseSchema, table=True): artifact_version: ArtifactVersionSchema = Relationship( back_populates="visualizations" ) - curated_visualizations: List["CuratedVisualizationSchema"] = Relationship( + curated_visualizations: list["CuratedVisualizationSchema"] = Relationship( back_populates="artifact_visualization", sa_relationship_kwargs=dict( order_by="CuratedVisualizationSchema.display_order", diff --git a/src/zenml/zen_stores/schemas/base_schemas.py b/src/zenml/zen_stores/schemas/base_schemas.py index 412cd806387..fd7a929896a 100644 --- a/src/zenml/zen_stores/schemas/base_schemas.py +++ b/src/zenml/zen_stores/schemas/base_schemas.py @@ -14,7 +14,8 @@ """Base classes for SQLModel schemas.""" from datetime import datetime -from typing import TYPE_CHECKING, Any, Sequence, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar +from collections.abc import Sequence from uuid import UUID, uuid4 from sqlalchemy.sql.base import ExecutableOption diff --git a/src/zenml/zen_stores/schemas/code_repository_schemas.py b/src/zenml/zen_stores/schemas/code_repository_schemas.py index fd616b80cb3..1126eb39be2 100644 --- a/src/zenml/zen_stores/schemas/code_repository_schemas.py +++ b/src/zenml/zen_stores/schemas/code_repository_schemas.py @@ -14,7 +14,8 @@ """SQL Model Implementations for code repositories.""" import json -from typing import Any, Optional, Sequence +from typing import Any, Optional +from collections.abc import Sequence from uuid import UUID from sqlalchemy import TEXT, Column, UniqueConstraint @@ -64,7 +65,7 @@ class CodeRepositorySchema(NamedSchema, table=True): ) project: "ProjectSchema" = Relationship(back_populates="code_repositories") - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = 
build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -79,8 +80,8 @@ class CodeRepositorySchema(NamedSchema, table=True): config: str = Field(sa_column=Column(TEXT, nullable=False)) source: str = Field(sa_column=Column(TEXT, nullable=False)) - logo_url: Optional[str] = Field() - description: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) + logo_url: str | None = Field() + description: str | None = Field(sa_column=Column(TEXT, nullable=True)) @classmethod def get_query_options( diff --git a/src/zenml/zen_stores/schemas/component_schemas.py b/src/zenml/zen_stores/schemas/component_schemas.py index 8165f2bcc51..8f2668d7a94 100644 --- a/src/zenml/zen_stores/schemas/component_schemas.py +++ b/src/zenml/zen_stores/schemas/component_schemas.py @@ -15,7 +15,8 @@ import base64 import json -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Sequence from uuid import UUID from sqlalchemy import UniqueConstraint @@ -66,10 +67,10 @@ class StackComponentSchema(NamedSchema, table=True): type: str flavor: str configuration: bytes - labels: Optional[bytes] - environment: Optional[bytes] = Field(default=None) + labels: bytes | None + environment: bytes | None = Field(default=None) - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -79,14 +80,14 @@ class StackComponentSchema(NamedSchema, table=True): ) user: Optional["UserSchema"] = Relationship(back_populates="components") - stacks: List["StackSchema"] = Relationship( + stacks: list["StackSchema"] = Relationship( back_populates="components", link_model=StackCompositionSchema ) - schedules: List["ScheduleSchema"] = Relationship( + schedules: list["ScheduleSchema"] = Relationship( back_populates="orchestrator", ) - run_metadata: 
List["RunMetadataSchema"] = Relationship( + run_metadata: list["RunMetadataSchema"] = Relationship( back_populates="stack_component", ) flavor_schema: Optional["FlavorSchema"] = Relationship( @@ -97,12 +98,12 @@ class StackComponentSchema(NamedSchema, table=True): }, ) - run_or_step_logs: List["LogsSchema"] = Relationship( + run_or_step_logs: list["LogsSchema"] = Relationship( back_populates="artifact_store", sa_relationship_kwargs={"cascade": "delete", "uselist": True}, ) - connector_id: Optional[UUID] = build_foreign_key_field( + connector_id: UUID | None = build_foreign_key_field( source=__tablename__, target=ServiceConnectorSchema.__tablename__, source_column="connector_id", @@ -114,8 +115,8 @@ class StackComponentSchema(NamedSchema, table=True): back_populates="components" ) - connector_resource_id: Optional[str] - secrets: List["SecretSchema"] = Relationship( + connector_resource_id: str | None + secrets: list["SecretSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=f"and_(foreign(SecretResourceSchema.resource_type)=='{SecretResourceTypes.STACK_COMPONENT.value}', foreign(SecretResourceSchema.resource_id)==StackComponentSchema.id)", secondary="secret_resource", @@ -166,7 +167,7 @@ def get_query_options( def from_request( cls, request: "ComponentRequest", - service_connector: Optional[ServiceConnectorSchema] = None, + service_connector: ServiceConnectorSchema | None = None, ) -> "StackComponentSchema": """Create a component schema from a request. diff --git a/src/zenml/zen_stores/schemas/curated_visualization_schemas.py b/src/zenml/zen_stores/schemas/curated_visualization_schemas.py index e2490fc971a..92dc55ccdcf 100644 --- a/src/zenml/zen_stores/schemas/curated_visualization_schemas.py +++ b/src/zenml/zen_stores/schemas/curated_visualization_schemas.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. 
"""SQLModel implementation of curated visualization tables.""" -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, Any +from collections.abc import Sequence from uuid import UUID from sqlalchemy import UniqueConstraint @@ -74,8 +75,8 @@ class CuratedVisualizationSchema(BaseSchema, table=True): custom_constraint_name="fk_curated_visualization_artifact_visualization_id", ) - display_name: Optional[str] = Field(default=None) - display_order: Optional[int] = Field(default=None) + display_name: str | None = Field(default=None) + display_order: int | None = Field(default=None) layout_size: str = Field( default=CuratedVisualizationSize.FULL_WIDTH.value, nullable=False, @@ -106,7 +107,7 @@ def get_query_options( Returns: A list of query options. """ - options: List[ExecutableOption] = [] + options: list[ExecutableOption] = [] if include_resources: options.append(selectinload(jl_arg(cls.artifact_visualization))) diff --git a/src/zenml/zen_stores/schemas/deployment_schemas.py b/src/zenml/zen_stores/schemas/deployment_schemas.py index 98d7327e575..591be21ae42 100644 --- a/src/zenml/zen_stores/schemas/deployment_schemas.py +++ b/src/zenml/zen_stores/schemas/deployment_schemas.py @@ -14,7 +14,8 @@ """SQLModel implementation of pipeline deployments table.""" import json -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Sequence from uuid import UUID from sqlalchemy import TEXT, Column, UniqueConstraint @@ -80,7 +81,7 @@ class DeploymentSchema(NamedSchema, table=True): ) project: "ProjectSchema" = Relationship(back_populates="deployments") - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -91,11 +92,11 @@ class DeploymentSchema(NamedSchema, table=True): user: Optional["UserSchema"] = 
Relationship(back_populates="deployments") status: str - url: Optional[str] = Field( + url: str | None = Field( default=None, sa_column=Column(TEXT, nullable=True), ) - auth_key: Optional[str] = Field( + auth_key: str | None = Field( default=None, sa_column=Column(TEXT, nullable=True), ) @@ -108,7 +109,7 @@ class DeploymentSchema(NamedSchema, table=True): nullable=False, ), ) - snapshot_id: Optional[UUID] = build_foreign_key_field( + snapshot_id: UUID | None = build_foreign_key_field( source=__tablename__, target=PipelineSnapshotSchema.__tablename__, source_column="snapshot_id", @@ -120,7 +121,7 @@ class DeploymentSchema(NamedSchema, table=True): back_populates="deployment", ) - deployer_id: Optional[UUID] = build_foreign_key_field( + deployer_id: UUID | None = build_foreign_key_field( source=__tablename__, target=StackComponentSchema.__tablename__, source_column="deployer_id", @@ -130,7 +131,7 @@ class DeploymentSchema(NamedSchema, table=True): ) deployer: Optional["StackComponentSchema"] = Relationship() - tags: List["TagSchema"] = Relationship( + tags: list["TagSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.DEPLOYMENT.value}', foreign(TagResourceSchema.resource_id)==DeploymentSchema.id)", secondary="tag_resource", @@ -140,7 +141,7 @@ class DeploymentSchema(NamedSchema, table=True): ), ) - visualizations: List["CuratedVisualizationSchema"] = Relationship( + visualizations: list["CuratedVisualizationSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=( "and_(CuratedVisualizationSchema.resource_type" @@ -204,7 +205,7 @@ def to_model( Returns: The created `DeploymentResponse`. 
""" - status: Optional[DeploymentStatus] = None + status: DeploymentStatus | None = None if self.status in DeploymentStatus.values(): status = DeploymentStatus(self.status) elif self.status is not None: diff --git a/src/zenml/zen_stores/schemas/device_schemas.py b/src/zenml/zen_stores/schemas/device_schemas.py index 8f99e6fb2f8..1fcc40acbdd 100644 --- a/src/zenml/zen_stores/schemas/device_schemas.py +++ b/src/zenml/zen_stores/schemas/device_schemas.py @@ -15,7 +15,8 @@ from datetime import datetime, timedelta from secrets import token_hex -from typing import Any, Optional, Sequence, Tuple +from typing import Any, Optional +from collections.abc import Sequence from uuid import UUID from passlib.context import CryptContext @@ -51,19 +52,19 @@ class OAuthDeviceSchema(BaseSchema, table=True): device_code: str status: str failed_auth_attempts: int = 0 - expires: Optional[datetime] = None - last_login: Optional[datetime] = None + expires: datetime | None = None + last_login: datetime | None = None trusted_device: bool = False - os: Optional[str] = None - ip_address: Optional[str] = None - hostname: Optional[str] = None - python_version: Optional[str] = None - zenml_version: Optional[str] = None - city: Optional[str] = None - region: Optional[str] = None - country: Optional[str] = None - - user_id: Optional[UUID] = build_foreign_key_field( + os: str | None = None + ip_address: str | None = None + hostname: str | None = None + python_version: str | None = None + zenml_version: str | None = None + city: str | None = None + region: str | None = None + country: str | None = None + + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -137,7 +138,7 @@ def _get_hashed_code(cls, code: str) -> str: @classmethod def from_request( cls, request: OAuthDeviceInternalRequest - ) -> Tuple["OAuthDeviceSchema", str, str]: + ) -> tuple["OAuthDeviceSchema", str, str]: """Create an authorized device DB entry from 
a device authorization request. Args: @@ -199,7 +200,7 @@ def update(self, device_update: OAuthDeviceUpdate) -> "OAuthDeviceSchema": def internal_update( self, device_update: OAuthDeviceInternalUpdate - ) -> Tuple["OAuthDeviceSchema", Optional[str], Optional[str]]: + ) -> tuple["OAuthDeviceSchema", str | None, str | None]: """Update an authorized device from an internal device update model. Args: @@ -210,8 +211,8 @@ def internal_update( code, if they were generated. """ now = utc_now() - user_code: Optional[str] = None - device_code: Optional[str] = None + user_code: str | None = None + device_code: str | None = None # This call also takes care of setting fields that have the same # name in the internal model and the schema. diff --git a/src/zenml/zen_stores/schemas/event_source_schemas.py b/src/zenml/zen_stores/schemas/event_source_schemas.py index 96dcc354163..e7c14979ae0 100644 --- a/src/zenml/zen_stores/schemas/event_source_schemas.py +++ b/src/zenml/zen_stores/schemas/event_source_schemas.py @@ -15,7 +15,8 @@ import base64 import json -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, cast +from typing import TYPE_CHECKING, Any, Optional, cast +from collections.abc import Sequence from uuid import UUID from sqlalchemy import TEXT, Column, UniqueConstraint @@ -69,7 +70,7 @@ class EventSourceSchema(NamedSchema, table=True): ) project: "ProjectSchema" = Relationship(back_populates="event_sources") - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -79,7 +80,7 @@ class EventSourceSchema(NamedSchema, table=True): ) user: Optional["UserSchema"] = Relationship(back_populates="event_sources") - triggers: List["TriggerSchema"] = Relationship( + triggers: list["TriggerSchema"] = Relationship( back_populates="event_source" ) diff --git a/src/zenml/zen_stores/schemas/flavor_schemas.py 
b/src/zenml/zen_stores/schemas/flavor_schemas.py index 0bdbcc849d1..0bf2bbe5fd3 100644 --- a/src/zenml/zen_stores/schemas/flavor_schemas.py +++ b/src/zenml/zen_stores/schemas/flavor_schemas.py @@ -14,7 +14,8 @@ """SQL Model Implementations for Flavors.""" import json -from typing import Any, Optional, Sequence +from typing import Any, Optional +from collections.abc import Sequence from uuid import UUID from sqlalchemy import TEXT, Column, UniqueConstraint @@ -59,12 +60,12 @@ class FlavorSchema(NamedSchema, table=True): type: str source: str config_schema: str = Field(sa_column=Column(TEXT, nullable=False)) - integration: Optional[str] = Field(default="") - connector_type: Optional[str] - connector_resource_type: Optional[str] - connector_resource_id_attr: Optional[str] + integration: str | None = Field(default="") + connector_type: str | None + connector_resource_type: str | None + connector_resource_id_attr: str | None - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -74,11 +75,11 @@ class FlavorSchema(NamedSchema, table=True): ) user: Optional["UserSchema"] = Relationship(back_populates="flavors") - logo_url: Optional[str] = Field() + logo_url: str | None = Field() - docs_url: Optional[str] = Field() + docs_url: str | None = Field() - sdk_docs_url: Optional[str] = Field() + sdk_docs_url: str | None = Field() is_custom: bool = Field(default=True) diff --git a/src/zenml/zen_stores/schemas/logs_schemas.py b/src/zenml/zen_stores/schemas/logs_schemas.py index 1900b0bc488..7cf223ff051 100644 --- a/src/zenml/zen_stores/schemas/logs_schemas.py +++ b/src/zenml/zen_stores/schemas/logs_schemas.py @@ -49,7 +49,7 @@ class LogsSchema(BaseSchema, table=True): source: str = Field(sa_column=Column(VARCHAR(255), nullable=False)) # Foreign Keys - pipeline_run_id: Optional[UUID] = build_foreign_key_field( + pipeline_run_id: UUID | None = 
build_foreign_key_field( source=__tablename__, target=PipelineRunSchema.__tablename__, source_column="pipeline_run_id", @@ -57,7 +57,7 @@ class LogsSchema(BaseSchema, table=True): ondelete="CASCADE", nullable=True, ) - step_run_id: Optional[UUID] = build_foreign_key_field( + step_run_id: UUID | None = build_foreign_key_field( source=__tablename__, target=StepRunSchema.__tablename__, source_column="step_run_id", diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 96856f1cdbd..d63faabd05d 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. """SQLModel implementation of model tables.""" -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, cast +from typing import TYPE_CHECKING, Any, Optional, cast +from collections.abc import Sequence from uuid import UUID, uuid4 from pydantic import ConfigDict @@ -99,7 +100,7 @@ class ModelSchema(NamedSchema, table=True): ) project: "ProjectSchema" = Relationship(back_populates="models") - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -119,7 +120,7 @@ class ModelSchema(NamedSchema, table=True): save_models_to_registry: bool = Field( sa_column=Column(BOOLEAN, nullable=False) ) - tags: List["TagSchema"] = Relationship( + tags: list["TagSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL.value}', foreign(TagResourceSchema.resource_id)==ModelSchema.id)", secondary="tag_resource", @@ -128,11 +129,11 @@ class ModelSchema(NamedSchema, table=True): overlaps="tags", ), ) - model_versions: List["ModelVersionSchema"] = Relationship( + model_versions: list["ModelVersionSchema"] = Relationship( back_populates="model", 
sa_relationship_kwargs={"cascade": "delete"}, ) - visualizations: List["CuratedVisualizationSchema"] = Relationship( + visualizations: list["CuratedVisualizationSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=( "and_(CuratedVisualizationSchema.resource_type" @@ -357,7 +358,7 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): ) project: "ProjectSchema" = Relationship(back_populates="model_versions") - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -378,7 +379,7 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): nullable=False, ) model: "ModelSchema" = Relationship(back_populates="model_versions") - tags: List["TagSchema"] = Relationship( + tags: list["TagSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.MODEL_VERSION.value}', foreign(TagResourceSchema.resource_id)==ModelVersionSchema.id)", secondary="tag_resource", @@ -388,7 +389,7 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): ), ) - services: List["ServiceSchema"] = Relationship( + services: list["ServiceSchema"] = Relationship( back_populates="model_version", ) @@ -396,7 +397,7 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): description: str = Field(sa_column=Column(TEXT, nullable=True)) stage: str = Field(sa_column=Column(TEXT, nullable=True)) - run_metadata: List["RunMetadataSchema"] = Relationship( + run_metadata: list["RunMetadataSchema"] = Relationship( sa_relationship_kwargs=dict( secondary="run_metadata_resource", primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.MODEL_VERSION.value}', foreign(RunMetadataResourceSchema.resource_id)==ModelVersionSchema.id)", @@ -404,10 +405,10 @@ class ModelVersionSchema(NamedSchema, 
RunMetadataInterface, table=True): overlaps="run_metadata", ), ) - pipeline_runs: List["PipelineRunSchema"] = Relationship( + pipeline_runs: list["PipelineRunSchema"] = Relationship( back_populates="model_version", ) - step_runs: List["StepRunSchema"] = Relationship( + step_runs: list["StepRunSchema"] = Relationship( back_populates="model_version" ) @@ -423,11 +424,11 @@ class ModelVersionSchema(NamedSchema, RunMetadataInterface, table=True): producer_run_id_if_numeric: UUID # Needed for cascade deletion behavior - artifact_links: List["ModelVersionArtifactSchema"] = Relationship( + artifact_links: list["ModelVersionArtifactSchema"] = Relationship( back_populates="model_version", sa_relationship_kwargs={"cascade": "delete"}, ) - pipeline_run_links: List["ModelVersionPipelineRunSchema"] = Relationship( + pipeline_run_links: list["ModelVersionPipelineRunSchema"] = Relationship( back_populates="model_version", sa_relationship_kwargs={"cascade": "delete"}, ) @@ -486,7 +487,7 @@ def from_request( cls, model_version_request: ModelVersionRequest, model_version_number: int, - producer_run_id: Optional[UUID] = None, + producer_run_id: UUID | None = None, ) -> "ModelVersionSchema": """Convert an `ModelVersionRequest` to an `ModelVersionSchema`. @@ -578,9 +579,9 @@ def to_model( def update( self, - target_stage: Optional[str] = None, - target_name: Optional[str] = None, - target_description: Optional[str] = None, + target_stage: str | None = None, + target_name: str | None = None, + target_description: str | None = None, ) -> "ModelVersionSchema": """Updates a `ModelVersionSchema` to a target stage. 
diff --git a/src/zenml/zen_stores/schemas/pipeline_build_schemas.py b/src/zenml/zen_stores/schemas/pipeline_build_schemas.py index 75563baac1a..60bdcb31790 100644 --- a/src/zenml/zen_stores/schemas/pipeline_build_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_build_schemas.py @@ -14,7 +14,8 @@ """SQLModel implementation of pipeline build tables.""" import json -from typing import Any, Optional, Sequence +from typing import Any, Optional +from collections.abc import Sequence from uuid import UUID from sqlalchemy import Column, String @@ -46,7 +47,7 @@ class PipelineBuildSchema(BaseSchema, table=True): __tablename__ = "pipeline_build" - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -66,7 +67,7 @@ class PipelineBuildSchema(BaseSchema, table=True): ) project: "ProjectSchema" = Relationship(back_populates="builds") - stack_id: Optional[UUID] = build_foreign_key_field( + stack_id: UUID | None = build_foreign_key_field( source=__tablename__, target=StackSchema.__tablename__, source_column="stack_id", @@ -76,7 +77,7 @@ class PipelineBuildSchema(BaseSchema, table=True): ) stack: Optional["StackSchema"] = Relationship(back_populates="builds") - pipeline_id: Optional[UUID] = build_foreign_key_field( + pipeline_id: UUID | None = build_foreign_key_field( source=__tablename__, target=PipelineSchema.__tablename__, source_column="pipeline_id", @@ -100,12 +101,12 @@ class PipelineBuildSchema(BaseSchema, table=True): is_local: bool contains_code: bool - zenml_version: Optional[str] - python_version: Optional[str] - checksum: Optional[str] - stack_checksum: Optional[str] + zenml_version: str | None + python_version: str | None + checksum: str | None + stack_checksum: str | None # Build duration in seconds - duration: Optional[int] = None + duration: int | None = None @classmethod def get_query_options( diff --git 
a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index e26e1eac710..537becd2ab4 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -15,7 +15,8 @@ import json from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Sequence from uuid import UUID from pydantic import ConfigDict @@ -108,18 +109,18 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): ) # Fields - orchestrator_run_id: Optional[str] = Field(nullable=True) - start_time: Optional[datetime] = Field(nullable=True) - end_time: Optional[datetime] = Field(nullable=True, default=None) + orchestrator_run_id: str | None = Field(nullable=True) + start_time: datetime | None = Field(nullable=True) + end_time: datetime | None = Field(nullable=True, default=None) in_progress: bool = Field(nullable=False) status: str = Field(nullable=False) - status_reason: Optional[str] = Field(nullable=True) - orchestrator_environment: Optional[str] = Field( + status_reason: str | None = Field(nullable=True) + orchestrator_environment: str | None = Field( sa_column=Column(TEXT, nullable=True) ) # Foreign keys - snapshot_id: Optional[UUID] = build_foreign_key_field( + snapshot_id: UUID | None = build_foreign_key_field( source=__tablename__, target=PipelineSnapshotSchema.__tablename__, source_column="snapshot_id", @@ -127,7 +128,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): ondelete="CASCADE", nullable=True, ) - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -143,7 +144,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): ondelete="CASCADE", nullable=False, ) - 
pipeline_id: Optional[UUID] = build_foreign_key_field( + pipeline_id: UUID | None = build_foreign_key_field( source=__tablename__, target=PipelineSchema.__tablename__, source_column="pipeline_id", @@ -166,7 +167,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): ) project: "ProjectSchema" = Relationship(back_populates="runs") user: Optional["UserSchema"] = Relationship(back_populates="runs") - run_metadata: List["RunMetadataSchema"] = Relationship( + run_metadata: list["RunMetadataSchema"] = Relationship( sa_relationship_kwargs=dict( secondary="run_metadata_resource", primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.PIPELINE_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==PipelineRunSchema.id)", @@ -174,11 +175,11 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): overlaps="run_metadata", ), ) - logs: List["LogsSchema"] = Relationship( + logs: list["LogsSchema"] = Relationship( back_populates="pipeline_run", sa_relationship_kwargs={"cascade": "delete"}, ) - step_runs: List["StepRunSchema"] = Relationship( + step_runs: list["StepRunSchema"] = Relationship( sa_relationship_kwargs={"cascade": "delete"}, ) model_version: "ModelVersionSchema" = Relationship( @@ -186,14 +187,14 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): ) # Temporary fields and foreign keys to be deprecated - pipeline_configuration: Optional[str] = Field( + pipeline_configuration: str | None = Field( sa_column=Column(TEXT, nullable=True) ) - client_environment: Optional[str] = Field( + client_environment: str | None = Field( sa_column=Column(TEXT, nullable=True) ) - stack_id: Optional[UUID] = build_foreign_key_field( + stack_id: UUID | None = build_foreign_key_field( source=__tablename__, target=StackSchema.__tablename__, source_column="stack_id", @@ -201,7 +202,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): ondelete="SET NULL", 
nullable=True, ) - build_id: Optional[UUID] = build_foreign_key_field( + build_id: UUID | None = build_foreign_key_field( source=__tablename__, target=PipelineBuildSchema.__tablename__, source_column="build_id", @@ -209,7 +210,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): ondelete="SET NULL", nullable=True, ) - schedule_id: Optional[UUID] = build_foreign_key_field( + schedule_id: UUID | None = build_foreign_key_field( source=__tablename__, target=ScheduleSchema.__tablename__, source_column="schedule_id", @@ -217,7 +218,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): ondelete="SET NULL", nullable=True, ) - trigger_execution_id: Optional[UUID] = build_foreign_key_field( + trigger_execution_id: UUID | None = build_foreign_key_field( source=__tablename__, target=TriggerExecutionSchema.__tablename__, source_column="trigger_execution_id", @@ -231,13 +232,13 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): schedule: Optional["ScheduleSchema"] = Relationship() pipeline: Optional["PipelineSchema"] = Relationship() trigger_execution: Optional["TriggerExecutionSchema"] = Relationship() - triggered_by: Optional[UUID] = None - triggered_by_type: Optional[str] = None + triggered_by: UUID | None = None + triggered_by_type: str | None = None - services: List["ServiceSchema"] = Relationship( + services: list["ServiceSchema"] = Relationship( back_populates="pipeline_run", ) - tags: List["TagSchema"] = Relationship( + tags: list["TagSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE_RUN.value}', foreign(TagResourceSchema.resource_id)==PipelineRunSchema.id)", secondary="tag_resource", @@ -246,7 +247,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): overlaps="tags", ), ) - visualizations: List["CuratedVisualizationSchema"] = Relationship( + visualizations: 
list["CuratedVisualizationSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=( "and_(CuratedVisualizationSchema.resource_type" @@ -260,7 +261,7 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): ) # Needed for cascade deletion - model_versions_pipeline_runs_links: List[ + model_versions_pipeline_runs_links: list[ "ModelVersionPipelineRunSchema" ] = Relationship( back_populates="pipeline_run", @@ -436,7 +437,7 @@ def get_step_configuration(self, step_name: str) -> Step: else: raise RuntimeError("Pipeline run has no snapshot.") - def get_upstream_steps(self) -> Dict[str, List[str]]: + def get_upstream_steps(self) -> dict[str, list[str]]: """Get the list of all the upstream steps for each step. Returns: @@ -459,7 +460,7 @@ def get_upstream_steps(self) -> Dict[str, List[str]]: def fetch_metadata_collection( self, include_full_metadata: bool = False, **kwargs: Any - ) -> Dict[str, List[RunMetadataEntry]]: + ) -> dict[str, list[RunMetadataEntry]]: """Fetches all the metadata entries related to the pipeline run. Args: @@ -566,7 +567,7 @@ def to_model( client_environment.pop("python_packages", None) orchestrator_environment.pop("python_packages", None) - trigger_info: Optional[PipelineRunTriggerInfo] = None + trigger_info: PipelineRunTriggerInfo | None = None if self.triggered_by and self.triggered_by_type: if ( self.triggered_by_type diff --git a/src/zenml/zen_stores/schemas/pipeline_schemas.py b/src/zenml/zen_stores/schemas/pipeline_schemas.py index cf0923d641c..87e388fc088 100644 --- a/src/zenml/zen_stores/schemas/pipeline_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_schemas.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. 
"""SQL Model Implementations for Pipelines and Pipeline Runs.""" -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Sequence from uuid import UUID from sqlalchemy import TEXT, Column, UniqueConstraint @@ -64,7 +65,7 @@ class PipelineSchema(NamedSchema, table=True): ), ) # Fields - description: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) + description: str | None = Field(sa_column=Column(TEXT, nullable=True)) # Foreign keys project_id: UUID = build_foreign_key_field( @@ -75,7 +76,7 @@ class PipelineSchema(NamedSchema, table=True): ondelete="CASCADE", nullable=False, ) - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -88,17 +89,17 @@ class PipelineSchema(NamedSchema, table=True): user: Optional["UserSchema"] = Relationship(back_populates="pipelines") project: "ProjectSchema" = Relationship(back_populates="pipelines") - schedules: List["ScheduleSchema"] = Relationship( + schedules: list["ScheduleSchema"] = Relationship( back_populates="pipeline", ) - builds: List["PipelineBuildSchema"] = Relationship( + builds: list["PipelineBuildSchema"] = Relationship( back_populates="pipeline" ) - snapshots: List["PipelineSnapshotSchema"] = Relationship( + snapshots: list["PipelineSnapshotSchema"] = Relationship( back_populates="pipeline", sa_relationship_kwargs={"cascade": "delete"}, ) - tags: List["TagSchema"] = Relationship( + tags: list["TagSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE.value}', foreign(TagResourceSchema.resource_id)==PipelineSchema.id)", secondary="tag_resource", @@ -107,7 +108,7 @@ class PipelineSchema(NamedSchema, table=True): overlaps="tags", ), ) - visualizations: List["CuratedVisualizationSchema"] = Relationship( + 
visualizations: list["CuratedVisualizationSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=( "and_(CuratedVisualizationSchema.resource_type" diff --git a/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py b/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py index 668b0ef04e4..75e50e59144 100644 --- a/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py @@ -14,7 +14,8 @@ """Pipeline snapshot schemas.""" import json -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Sequence from uuid import UUID from sqlalchemy import TEXT, Column, String, UniqueConstraint @@ -78,8 +79,8 @@ class PipelineSnapshotSchema(BaseSchema, table=True): ) # Fields - name: Optional[str] = Field(nullable=True) - description: Optional[str] = Field( + name: str | None = Field(nullable=True) + description: str | None = Field( sa_column=Column( String(length=MEDIUMTEXT_MAX_LENGTH).with_variant( MEDIUMTEXT, "mysql" @@ -100,8 +101,8 @@ class PipelineSnapshotSchema(BaseSchema, table=True): run_name_template: str = Field(nullable=False) client_version: str = Field(nullable=True) server_version: str = Field(nullable=True) - pipeline_version_hash: Optional[str] = Field(nullable=True, default=None) - pipeline_spec: Optional[str] = Field( + pipeline_version_hash: str | None = Field(nullable=True, default=None) + pipeline_spec: str | None = Field( sa_column=Column( String(length=MEDIUMTEXT_MAX_LENGTH).with_variant( MEDIUMTEXT, "mysql" @@ -109,10 +110,10 @@ class PipelineSnapshotSchema(BaseSchema, table=True): nullable=True, ) ) - code_path: Optional[str] = Field(nullable=True) + code_path: str | None = Field(nullable=True) # Foreign keys - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", 
@@ -128,7 +129,7 @@ class PipelineSnapshotSchema(BaseSchema, table=True): ondelete="CASCADE", nullable=False, ) - stack_id: Optional[UUID] = build_foreign_key_field( + stack_id: UUID | None = build_foreign_key_field( source=__tablename__, target=StackSchema.__tablename__, source_column="stack_id", @@ -144,7 +145,7 @@ class PipelineSnapshotSchema(BaseSchema, table=True): ondelete="CASCADE", nullable=False, ) - schedule_id: Optional[UUID] = build_foreign_key_field( + schedule_id: UUID | None = build_foreign_key_field( source=__tablename__, target=ScheduleSchema.__tablename__, source_column="schedule_id", @@ -152,7 +153,7 @@ class PipelineSnapshotSchema(BaseSchema, table=True): ondelete="SET NULL", nullable=True, ) - build_id: Optional[UUID] = build_foreign_key_field( + build_id: UUID | None = build_foreign_key_field( source=__tablename__, target=PipelineBuildSchema.__tablename__, source_column="build_id", @@ -160,7 +161,7 @@ class PipelineSnapshotSchema(BaseSchema, table=True): ondelete="SET NULL", nullable=True, ) - code_reference_id: Optional[UUID] = build_foreign_key_field( + code_reference_id: UUID | None = build_foreign_key_field( source=__tablename__, target=CodeReferenceSchema.__tablename__, source_column="code_reference_id", @@ -170,10 +171,10 @@ class PipelineSnapshotSchema(BaseSchema, table=True): ) # This is not a foreign key to remove a cycle which messes with our DB # backup process - source_snapshot_id: Optional[UUID] = None + source_snapshot_id: UUID | None = None # Deprecated, remove once we remove run templates entirely - template_id: Optional[UUID] = None + template_id: UUID | None = None # SQLModel Relationships source_snapshot: Optional["PipelineSnapshotSchema"] = Relationship( @@ -196,13 +197,13 @@ class PipelineSnapshotSchema(BaseSchema, table=True): ) code_reference: Optional["CodeReferenceSchema"] = Relationship() - pipeline_runs: List["PipelineRunSchema"] = Relationship( + pipeline_runs: list["PipelineRunSchema"] = Relationship( 
sa_relationship_kwargs={"cascade": "delete"} ) - step_runs: List["StepRunSchema"] = Relationship( + step_runs: list["StepRunSchema"] = Relationship( sa_relationship_kwargs={"cascade": "delete"} ) - step_configurations: List["StepConfigurationSchema"] = Relationship( + step_configurations: list["StepConfigurationSchema"] = Relationship( sa_relationship_kwargs={ "cascade": "delete", "order_by": "asc(StepConfigurationSchema.index)", @@ -212,7 +213,7 @@ class PipelineSnapshotSchema(BaseSchema, table=True): back_populates="snapshot" ) step_count: int - tags: List["TagSchema"] = Relationship( + tags: list["TagSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.PIPELINE_SNAPSHOT.value}', foreign(TagResourceSchema.resource_id)==PipelineSnapshotSchema.id)", secondary="tag_resource", @@ -221,7 +222,7 @@ class PipelineSnapshotSchema(BaseSchema, table=True): overlaps="tags", ), ) - visualizations: List["CuratedVisualizationSchema"] = Relationship( + visualizations: list["CuratedVisualizationSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=( "and_(CuratedVisualizationSchema.resource_type" @@ -281,8 +282,8 @@ def latest_run(self) -> Optional["PipelineRunSchema"]: ) def get_step_configurations( - self, include: Optional[List[str]] = None - ) -> List["StepConfigurationSchema"]: + self, include: list[str] | None = None + ) -> list["StepConfigurationSchema"]: """Get step configurations for the snapshot. Args: @@ -382,7 +383,7 @@ def get_query_options( def from_request( cls, request: PipelineSnapshotRequest, - code_reference_id: Optional[UUID], + code_reference_id: UUID | None, ) -> "PipelineSnapshotSchema": """Create schema from request. 
@@ -460,8 +461,8 @@ def to_model( include_metadata: bool = False, include_resources: bool = False, include_python_packages: bool = False, - include_config_schema: Optional[bool] = None, - step_configuration_filter: Optional[List[str]] = None, + include_config_schema: bool | None = None, + step_configuration_filter: list[str] | None = None, **kwargs: Any, ) -> PipelineSnapshotResponse: """Convert schema to response. diff --git a/src/zenml/zen_stores/schemas/project_schemas.py b/src/zenml/zen_stores/schemas/project_schemas.py index a4fe796e1a3..fa36ee1e4cb 100644 --- a/src/zenml/zen_stores/schemas/project_schemas.py +++ b/src/zenml/zen_stores/schemas/project_schemas.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """SQL Model Implementations for projects.""" -from typing import TYPE_CHECKING, Any, List +from typing import TYPE_CHECKING, Any from sqlalchemy import UniqueConstraint from sqlmodel import Relationship @@ -65,71 +65,71 @@ class ProjectSchema(NamedSchema, table=True): display_name: str description: str - pipelines: List["PipelineSchema"] = Relationship( + pipelines: list["PipelineSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - schedules: List["ScheduleSchema"] = Relationship( + schedules: list["ScheduleSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - runs: List["PipelineRunSchema"] = Relationship( + runs: list["PipelineRunSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - step_runs: List["StepRunSchema"] = Relationship( + step_runs: list["StepRunSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - builds: List["PipelineBuildSchema"] = Relationship( + builds: list["PipelineBuildSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - artifact_versions: List["ArtifactVersionSchema"] = 
Relationship( + artifact_versions: list["ArtifactVersionSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - run_metadata: List["RunMetadataSchema"] = Relationship( + run_metadata: list["RunMetadataSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - actions: List["ActionSchema"] = Relationship( + actions: list["ActionSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - triggers: List["TriggerSchema"] = Relationship( + triggers: list["TriggerSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - event_sources: List["EventSourceSchema"] = Relationship( + event_sources: list["EventSourceSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - snapshots: List["PipelineSnapshotSchema"] = Relationship( + snapshots: list["PipelineSnapshotSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - code_repositories: List["CodeRepositorySchema"] = Relationship( + code_repositories: list["CodeRepositorySchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - services: List["ServiceSchema"] = Relationship( + services: list["ServiceSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - models: List["ModelSchema"] = Relationship( + models: list["ModelSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - model_versions: List["ModelVersionSchema"] = Relationship( + model_versions: list["ModelVersionSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": "delete"}, ) - deployments: List["DeploymentSchema"] = Relationship( + deployments: list["DeploymentSchema"] = Relationship( back_populates="project", sa_relationship_kwargs={"cascade": 
"delete"}, ) - visualizations: List["CuratedVisualizationSchema"] = Relationship( + visualizations: list["CuratedVisualizationSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=( "and_(CuratedVisualizationSchema.resource_type" diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index 829b676ed2f..0e3cc8ca894 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -35,7 +35,7 @@ class RunMetadataSchema(BaseSchema, table=True): __tablename__ = "run_metadata" - stack_component_id: Optional[UUID] = build_foreign_key_field( + stack_component_id: UUID | None = build_foreign_key_field( source=__tablename__, target=StackComponentSchema.__tablename__, source_column="stack_component_id", @@ -47,7 +47,7 @@ class RunMetadataSchema(BaseSchema, table=True): back_populates="run_metadata" ) - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -71,7 +71,7 @@ class RunMetadataSchema(BaseSchema, table=True): value: str = Field(sa_column=Column(TEXT, nullable=False)) type: str - publisher_step_id: Optional[UUID] = build_foreign_key_field( + publisher_step_id: UUID | None = build_foreign_key_field( source=__tablename__, target=StepRunSchema.__tablename__, source_column="publisher_step_id", diff --git a/src/zenml/zen_stores/schemas/run_template_schemas.py b/src/zenml/zen_stores/schemas/run_template_schemas.py index 91c96543154..0922cc5a871 100644 --- a/src/zenml/zen_stores/schemas/run_template_schemas.py +++ b/src/zenml/zen_stores/schemas/run_template_schemas.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. 
"""SQLModel implementation of run template tables.""" -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Sequence from uuid import UUID from sqlalchemy import Column, String, UniqueConstraint @@ -64,7 +65,7 @@ class RunTemplateSchema(NamedSchema, table=True): title="Whether the run template is hidden.", ) - description: Optional[str] = Field( + description: str | None = Field( sa_column=Column( String(length=MEDIUMTEXT_MAX_LENGTH).with_variant( MEDIUMTEXT, "mysql" @@ -73,7 +74,7 @@ class RunTemplateSchema(NamedSchema, table=True): ) ) - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -89,7 +90,7 @@ class RunTemplateSchema(NamedSchema, table=True): ondelete="CASCADE", nullable=False, ) - source_snapshot_id: Optional[UUID] = build_foreign_key_field( + source_snapshot_id: UUID | None = build_foreign_key_field( source=__tablename__, target="pipeline_snapshot", source_column="source_snapshot_id", @@ -108,7 +109,7 @@ class RunTemplateSchema(NamedSchema, table=True): } ) - tags: List["TagSchema"] = Relationship( + tags: list["TagSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=f"and_(foreign(TagResourceSchema.resource_type)=='{TaggableResourceTypes.RUN_TEMPLATE.value}', foreign(TagResourceSchema.resource_id)==RunTemplateSchema.id)", secondary="tag_resource", diff --git a/src/zenml/zen_stores/schemas/schedule_schema.py b/src/zenml/zen_stores/schemas/schedule_schema.py index 5c1c8cbfedd..cccf24fc07e 100644 --- a/src/zenml/zen_stores/schemas/schedule_schema.py +++ b/src/zenml/zen_stores/schemas/schedule_schema.py @@ -14,7 +14,8 @@ """SQL Model Implementations for Pipeline Schedules.""" from datetime import datetime, timedelta -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, 
Optional +from collections.abc import Sequence from uuid import UUID from sqlalchemy import UniqueConstraint @@ -74,7 +75,7 @@ class ScheduleSchema(NamedSchema, RunMetadataInterface, table=True): ) project: "ProjectSchema" = Relationship(back_populates="schedules") - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -84,7 +85,7 @@ class ScheduleSchema(NamedSchema, RunMetadataInterface, table=True): ) user: Optional["UserSchema"] = Relationship(back_populates="schedules") - pipeline_id: Optional[UUID] = build_foreign_key_field( + pipeline_id: UUID | None = build_foreign_key_field( source=__tablename__, target=PipelineSchema.__tablename__, source_column="pipeline_id", @@ -97,7 +98,7 @@ class ScheduleSchema(NamedSchema, RunMetadataInterface, table=True): back_populates="schedule" ) - orchestrator_id: Optional[UUID] = build_foreign_key_field( + orchestrator_id: UUID | None = build_foreign_key_field( source=__tablename__, target=StackComponentSchema.__tablename__, source_column="orchestrator_id", @@ -109,7 +110,7 @@ class ScheduleSchema(NamedSchema, RunMetadataInterface, table=True): back_populates="schedules" ) - run_metadata: List["RunMetadataSchema"] = Relationship( + run_metadata: list["RunMetadataSchema"] = Relationship( sa_relationship_kwargs=dict( secondary="run_metadata_resource", primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.SCHEDULE.value}', foreign(RunMetadataResourceSchema.resource_id)==ScheduleSchema.id)", @@ -119,12 +120,12 @@ class ScheduleSchema(NamedSchema, RunMetadataInterface, table=True): ) active: bool - cron_expression: Optional[str] = Field(nullable=True) - start_time: Optional[datetime] = Field(nullable=True) - end_time: Optional[datetime] = Field(nullable=True) - interval_second: Optional[float] = Field(nullable=True) + cron_expression: str | None = 
Field(nullable=True) + start_time: datetime | None = Field(nullable=True) + end_time: datetime | None = Field(nullable=True) + interval_second: float | None = Field(nullable=True) catchup: bool - run_once_start_time: Optional[datetime] = Field(nullable=True) + run_once_start_time: datetime | None = Field(nullable=True) @classmethod def get_query_options( diff --git a/src/zenml/zen_stores/schemas/schema_utils.py b/src/zenml/zen_stores/schemas/schema_utils.py index c179f5543df..f1189bf8cd1 100644 --- a/src/zenml/zen_stores/schemas/schema_utils.py +++ b/src/zenml/zen_stores/schemas/schema_utils.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Utility functions for SQLModel schemas.""" -from typing import Any, List, Optional +from typing import Any from sqlalchemy import Column, ForeignKey, Index from sqlmodel import Field @@ -45,7 +45,7 @@ def build_foreign_key_field( target_column: str, ondelete: str, nullable: bool, - custom_constraint_name: Optional[str] = None, + custom_constraint_name: str | None = None, **sa_column_kwargs: Any, ) -> Any: """Build a SQLModel foreign key field. @@ -94,7 +94,7 @@ def build_foreign_key_field( ) -def get_index_name(table_name: str, column_names: List[str]) -> str: +def get_index_name(table_name: str, column_names: list[str]) -> str: """Get the name for an index. Args: @@ -110,7 +110,7 @@ def get_index_name(table_name: str, column_names: List[str]) -> str: def build_index( - table_name: str, column_names: List[str], **kwargs: Any + table_name: str, column_names: list[str], **kwargs: Any ) -> Index: """Build an index object. 
diff --git a/src/zenml/zen_stores/schemas/secret_schemas.py b/src/zenml/zen_stores/schemas/secret_schemas.py index 67c62d1efee..6721e4ad8ef 100644 --- a/src/zenml/zen_stores/schemas/secret_schemas.py +++ b/src/zenml/zen_stores/schemas/secret_schemas.py @@ -15,7 +15,8 @@ import base64 import json -from typing import Any, Dict, Optional, Sequence, cast +from typing import Any, cast +from collections.abc import Sequence from uuid import UUID from sqlalchemy import TEXT, VARCHAR, Column, UniqueConstraint @@ -72,7 +73,7 @@ class SecretSchema(NamedSchema, table=True): internal: bool = Field(default=False) - values: Optional[bytes] = Field(sa_column=Column(TEXT, nullable=True)) + values: bytes | None = Field(sa_column=Column(TEXT, nullable=True)) user_id: UUID = build_foreign_key_field( source=__tablename__, @@ -112,7 +113,7 @@ def get_query_options( @classmethod def _dump_secret_values( - cls, values: Dict[str, str], encryption_engine: Optional[AesGcmEngine] + cls, values: dict[str, str], encryption_engine: AesGcmEngine | None ) -> bytes: """Dump the secret values to a string. @@ -149,8 +150,8 @@ def _dump_secret_values( def _load_secret_values( cls, encrypted_values: bytes, - encryption_engine: Optional[AesGcmEngine] = None, - ) -> Dict[str, str]: + encryption_engine: AesGcmEngine | None = None, + ) -> dict[str, str]: """Load the secret values from a base64 encoded byte string. Args: @@ -183,7 +184,7 @@ def _load_secret_values( try: return cast( - Dict[str, str], + dict[str, str], json.loads(serialized_values), ) except json.JSONDecodeError as e: @@ -288,8 +289,8 @@ def to_model( def get_secret_values( self, - encryption_engine: Optional[AesGcmEngine] = None, - ) -> Dict[str, str]: + encryption_engine: AesGcmEngine | None = None, + ) -> dict[str, str]: """Get the secret values for this secret. 
This method is used by the SQL secrets store to load the secret values @@ -315,8 +316,8 @@ def get_secret_values( def set_secret_values( self, - secret_values: Dict[str, str], - encryption_engine: Optional[AesGcmEngine] = None, + secret_values: dict[str, str], + encryption_engine: AesGcmEngine | None = None, ) -> None: """Create a `SecretSchema` from a `SecretRequest`. diff --git a/src/zenml/zen_stores/schemas/server_settings_schemas.py b/src/zenml/zen_stores/schemas/server_settings_schemas.py index e201ff93ac1..73253b0f6bd 100644 --- a/src/zenml/zen_stores/schemas/server_settings_schemas.py +++ b/src/zenml/zen_stores/schemas/server_settings_schemas.py @@ -15,7 +15,7 @@ import json from datetime import datetime -from typing import Any, Optional, Set +from typing import Any from uuid import UUID from sqlalchemy import TEXT, Column @@ -38,12 +38,12 @@ class ServerSettingsSchema(SQLModel, table=True): id: UUID = Field(primary_key=True) server_name: str - logo_url: Optional[str] = Field(nullable=True) + logo_url: str | None = Field(nullable=True) active: bool = Field(default=False) enable_analytics: bool = Field(default=False) - display_announcements: Optional[bool] = Field(nullable=True) - display_updates: Optional[bool] = Field(nullable=True) - onboarding_state: Optional[str] = Field( + display_announcements: bool | None = Field(nullable=True) + display_updates: bool | None = Field(nullable=True) + onboarding_state: str | None = Field( sa_column=Column(TEXT, nullable=True) ) last_user_activity: datetime = Field(default_factory=utc_now) @@ -72,7 +72,7 @@ def update( return self def update_onboarding_state( - self, completed_steps: Set[str] + self, completed_steps: set[str] ) -> "ServerSettingsSchema": """Update the onboarding state. 
diff --git a/src/zenml/zen_stores/schemas/service_connector_schemas.py b/src/zenml/zen_stores/schemas/service_connector_schemas.py index e095491ef4b..365cc026d32 100644 --- a/src/zenml/zen_stores/schemas/service_connector_schemas.py +++ b/src/zenml/zen_stores/schemas/service_connector_schemas.py @@ -16,7 +16,8 @@ import base64 import json from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, cast +from typing import TYPE_CHECKING, Any, Optional, cast +from collections.abc import Sequence from uuid import UUID from sqlalchemy import TEXT, Column, UniqueConstraint @@ -58,16 +59,16 @@ class ServiceConnectorSchema(NamedSchema, table=True): description: str auth_method: str = Field(sa_column=Column(TEXT)) resource_types: bytes - resource_id: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) + resource_id: str | None = Field(sa_column=Column(TEXT, nullable=True)) supports_instances: bool - configuration: Optional[bytes] - secret_id: Optional[UUID] - expires_at: Optional[datetime] - expires_skew_tolerance: Optional[int] - expiration_seconds: Optional[int] - labels: Optional[bytes] - - user_id: Optional[UUID] = build_foreign_key_field( + configuration: bytes | None + secret_id: UUID | None + expires_at: datetime | None + expires_skew_tolerance: int | None + expiration_seconds: int | None + labels: bytes | None + + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -78,7 +79,7 @@ class ServiceConnectorSchema(NamedSchema, table=True): user: Optional["UserSchema"] = Relationship( back_populates="service_connectors" ) - components: List["StackComponentSchema"] = Relationship( + components: list["StackComponentSchema"] = Relationship( back_populates="connector", ) @@ -109,7 +110,7 @@ def get_query_options( return options @property - def resource_types_list(self) -> List[str]: + def resource_types_list(self) -> list[str]: """Returns the 
resource types as a list. Returns: @@ -122,7 +123,7 @@ def resource_types_list(self) -> List[str]: return resource_types @property - def labels_dict(self) -> Dict[str, str]: + def labels_dict(self) -> dict[str, str]: """Returns the labels as a dictionary. Returns: @@ -131,9 +132,9 @@ def labels_dict(self) -> Dict[str, str]: if self.labels is None: return {} labels_dict = json.loads(base64.b64decode(self.labels).decode()) - return cast(Dict[str, str], labels_dict) + return cast(dict[str, str], labels_dict) - def has_labels(self, labels: Dict[str, Optional[str]]) -> bool: + def has_labels(self, labels: dict[str, str | None]) -> bool: """Checks if the connector has the given labels. Args: @@ -156,7 +157,7 @@ def has_labels(self, labels: Dict[str, Optional[str]]) -> bool: def from_request( cls, connector_request: ServiceConnectorRequest, - secret_id: Optional[UUID] = None, + secret_id: UUID | None = None, ) -> "ServiceConnectorSchema": """Create a `ServiceConnectorSchema` from a `ServiceConnectorRequest`. @@ -200,7 +201,7 @@ def from_request( def update( self, connector_update: ServiceConnectorUpdate, - secret_id: Optional[UUID] = None, + secret_id: UUID | None = None, ) -> "ServiceConnectorSchema": """Updates a `ServiceConnectorSchema` from a `ServiceConnectorUpdate`. 
diff --git a/src/zenml/zen_stores/schemas/service_schemas.py b/src/zenml/zen_stores/schemas/service_schemas.py index 6ea7dbeddc0..16443b36cd8 100644 --- a/src/zenml/zen_stores/schemas/service_schemas.py +++ b/src/zenml/zen_stores/schemas/service_schemas.py @@ -15,7 +15,8 @@ import base64 import json -from typing import Any, Optional, Sequence +from typing import Any, Optional +from collections.abc import Sequence from uuid import UUID from pydantic import ConfigDict @@ -58,7 +59,7 @@ class ServiceSchema(NamedSchema, table=True): ) project: "ProjectSchema" = Relationship(back_populates="services") - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -67,29 +68,29 @@ class ServiceSchema(NamedSchema, table=True): nullable=True, ) user: Optional["UserSchema"] = Relationship(back_populates="services") - service_source: Optional[str] = Field( + service_source: str | None = Field( sa_column=Column(TEXT, nullable=True) ) service_type: str = Field(sa_column=Column(TEXT, nullable=False)) type: str = Field(sa_column=Column(TEXT, nullable=False)) flavor: str = Field(sa_column=Column(TEXT, nullable=False)) - admin_state: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) - state: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) - labels: Optional[bytes] + admin_state: str | None = Field(sa_column=Column(TEXT, nullable=True)) + state: str | None = Field(sa_column=Column(TEXT, nullable=True)) + labels: bytes | None config: bytes - status: Optional[bytes] - endpoint: Optional[bytes] - prediction_url: Optional[str] = Field( + status: bytes | None + endpoint: bytes | None + prediction_url: str | None = Field( sa_column=Column(TEXT, nullable=True) ) - health_check_url: Optional[str] = Field( + health_check_url: str | None = Field( sa_column=Column(TEXT, nullable=True) ) - pipeline_name: Optional[str] = Field(sa_column=Column(TEXT, 
nullable=True)) - pipeline_step_name: Optional[str] = Field( + pipeline_name: str | None = Field(sa_column=Column(TEXT, nullable=True)) + pipeline_step_name: str | None = Field( sa_column=Column(TEXT, nullable=True) ) - model_version_id: Optional[UUID] = build_foreign_key_field( + model_version_id: UUID | None = build_foreign_key_field( source=__tablename__, target=ModelVersionSchema.__tablename__, source_column="model_version_id", @@ -100,7 +101,7 @@ class ServiceSchema(NamedSchema, table=True): model_version: Optional["ModelVersionSchema"] = Relationship( back_populates="services", ) - pipeline_run_id: Optional[UUID] = build_foreign_key_field( + pipeline_run_id: UUID | None = build_foreign_key_field( source=__tablename__, target="pipeline_run", source_column="pipeline_run_id", diff --git a/src/zenml/zen_stores/schemas/stack_schemas.py b/src/zenml/zen_stores/schemas/stack_schemas.py index c74227c21b9..7f68443fd1f 100644 --- a/src/zenml/zen_stores/schemas/stack_schemas.py +++ b/src/zenml/zen_stores/schemas/stack_schemas.py @@ -15,7 +15,8 @@ import base64 import json -from typing import TYPE_CHECKING, Any, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Sequence from uuid import UUID from sqlalchemy import UniqueConstraint @@ -90,12 +91,12 @@ class StackSchema(NamedSchema, table=True): ), ) - description: Optional[str] = Field(default=None) - stack_spec_path: Optional[str] - labels: Optional[bytes] - environment: Optional[bytes] = Field(default=None) + description: str | None = Field(default=None) + stack_spec_path: str | None + labels: bytes | None + environment: bytes | None = Field(default=None) - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -105,15 +106,15 @@ class StackSchema(NamedSchema, table=True): ) user: Optional["UserSchema"] = Relationship(back_populates="stacks") 
- components: List["StackComponentSchema"] = Relationship( + components: list["StackComponentSchema"] = Relationship( back_populates="stacks", link_model=StackCompositionSchema, ) - builds: List["PipelineBuildSchema"] = Relationship(back_populates="stack") - snapshots: List["PipelineSnapshotSchema"] = Relationship( + builds: list["PipelineBuildSchema"] = Relationship(back_populates="stack") + snapshots: list["PipelineSnapshotSchema"] = Relationship( back_populates="stack", ) - secrets: List["SecretSchema"] = Relationship( + secrets: list["SecretSchema"] = Relationship( sa_relationship_kwargs=dict( primaryjoin=f"and_(foreign(SecretResourceSchema.resource_type)=='{SecretResourceTypes.STACK.value}', foreign(SecretResourceSchema.resource_id)==StackSchema.id)", secondary="secret_resource", @@ -225,7 +226,7 @@ def get_query_options( def update( self, stack_update: "StackUpdate", - components: List["StackComponentSchema"], + components: list["StackComponentSchema"], ) -> "StackSchema": """Updates a stack schema with a stack update model. 
diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 2dced06c2e5..0412c9e7c8e 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -15,7 +15,8 @@ import json from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional +from collections.abc import Sequence from uuid import UUID from pydantic import ConfigDict @@ -84,15 +85,15 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): ) # Fields - start_time: Optional[datetime] = Field(nullable=True) - end_time: Optional[datetime] = Field(nullable=True) + start_time: datetime | None = Field(nullable=True) + end_time: datetime | None = Field(nullable=True) status: str = Field(nullable=False) - docstring: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) - cache_key: Optional[str] = Field(nullable=True) - cache_expires_at: Optional[datetime] = Field(nullable=True) - source_code: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) - code_hash: Optional[str] = Field(nullable=True) + docstring: str | None = Field(sa_column=Column(TEXT, nullable=True)) + cache_key: str | None = Field(nullable=True) + cache_expires_at: datetime | None = Field(nullable=True) + source_code: str | None = Field(sa_column=Column(TEXT, nullable=True)) + code_hash: str | None = Field(nullable=True) version: int = Field(nullable=False) is_retriable: bool = Field(nullable=False) @@ -104,7 +105,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): nullable=True, ) ) - exception_info: Optional[str] = Field( + exception_info: str | None = Field( sa_column=Column( String(length=MEDIUMTEXT_MAX_LENGTH).with_variant( MEDIUMTEXT, "mysql" @@ -114,7 +115,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): ) # Foreign keys - original_step_run_id: Optional[UUID] = 
build_foreign_key_field( + original_step_run_id: UUID | None = build_foreign_key_field( source=__tablename__, target=__tablename__, source_column="original_step_run_id", @@ -122,7 +123,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): ondelete="SET NULL", nullable=True, ) - snapshot_id: Optional[UUID] = build_foreign_key_field( + snapshot_id: UUID | None = build_foreign_key_field( source=__tablename__, target=PipelineSnapshotSchema.__tablename__, source_column="snapshot_id", @@ -138,7 +139,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): ondelete="CASCADE", nullable=False, ) - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -154,7 +155,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): ondelete="CASCADE", nullable=False, ) - model_version_id: Optional[UUID] = build_foreign_key_field( + model_version_id: UUID | None = build_foreign_key_field( source=__tablename__, target=MODEL_VERSION_TABLENAME, source_column="model_version_id", @@ -169,7 +170,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): snapshot: Optional["PipelineSnapshotSchema"] = Relationship( back_populates="step_runs" ) - run_metadata: List["RunMetadataSchema"] = Relationship( + run_metadata: list["RunMetadataSchema"] = Relationship( sa_relationship_kwargs=dict( secondary="run_metadata_resource", primaryjoin=f"and_(foreign(RunMetadataResourceSchema.resource_type)=='{MetadataResourceTypes.STEP_RUN.value}', foreign(RunMetadataResourceSchema.resource_id)==StepRunSchema.id)", @@ -177,17 +178,17 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): overlaps="run_metadata", ), ) - input_artifacts: List["StepRunInputArtifactSchema"] = Relationship( + input_artifacts: list["StepRunInputArtifactSchema"] = Relationship( sa_relationship_kwargs={"cascade": "delete"} ) - 
output_artifacts: List["StepRunOutputArtifactSchema"] = Relationship( + output_artifacts: list["StepRunOutputArtifactSchema"] = Relationship( sa_relationship_kwargs={"cascade": "delete"} ) logs: Optional["LogsSchema"] = Relationship( back_populates="step_run", sa_relationship_kwargs={"cascade": "delete", "uselist": False}, ) - parents: List["StepRunParentsSchema"] = Relationship( + parents: list["StepRunParentsSchema"] = Relationship( sa_relationship_kwargs={ "cascade": "delete", "primaryjoin": "StepRunParentsSchema.child_id == StepRunSchema.id", @@ -199,7 +200,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): model_version: "ModelVersionSchema" = Relationship( back_populates="step_runs", ) - triggered_runs: List["PipelineRunSchema"] = Relationship( + triggered_runs: list["PipelineRunSchema"] = Relationship( sa_relationship_kwargs={ "viewonly": True, "primaryjoin": f"and_(foreign(PipelineRunSchema.triggered_by) == StepRunSchema.id, foreign(PipelineRunSchema.triggered_by_type) == '{PipelineRunTriggeredByType.STEP_RUN.value}')", @@ -297,7 +298,7 @@ def get_query_options( def from_request( cls, request: StepRunRequest, - snapshot_id: Optional[UUID], + snapshot_id: UUID | None, version: int, is_retriable: bool, ) -> "StepRunSchema": @@ -437,7 +438,7 @@ def to_model( if self.model_version: model_version = self.model_version.to_model() - input_artifacts: Dict[str, List[StepRunInputResponse]] = {} + input_artifacts: dict[str, list[StepRunInputResponse]] = {} for input_artifact in self.input_artifacts: if input_artifact.name not in input_artifacts: input_artifacts[input_artifact.name] = [] @@ -447,7 +448,7 @@ def to_model( ) input_artifacts[input_artifact.name].append(step_run_input) - output_artifacts: Dict[str, List["ArtifactVersionResponse"]] = {} + output_artifacts: dict[str, list["ArtifactVersionResponse"]] = {} for output_artifact in self.output_artifacts: if output_artifact.name not in output_artifacts: output_artifacts[output_artifact.name] = 
[] diff --git a/src/zenml/zen_stores/schemas/tag_schemas.py b/src/zenml/zen_stores/schemas/tag_schemas.py index b2fc5378c67..df72f92f1a1 100644 --- a/src/zenml/zen_stores/schemas/tag_schemas.py +++ b/src/zenml/zen_stores/schemas/tag_schemas.py @@ -13,7 +13,8 @@ # permissions and limitations under the License. """SQLModel implementation of tag tables.""" -from typing import Any, List, Optional, Sequence +from typing import Any, Optional +from collections.abc import Sequence from uuid import UUID from sqlalchemy import VARCHAR, Column, UniqueConstraint @@ -54,7 +55,7 @@ class TagSchema(NamedSchema, table=True): ), ) - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -67,7 +68,7 @@ class TagSchema(NamedSchema, table=True): color: str = Field(sa_column=Column(VARCHAR(255), nullable=False)) exclusive: bool = Field(default=False) - links: List["TagResourceSchema"] = Relationship( + links: list["TagResourceSchema"] = Relationship( back_populates="tag", sa_relationship_kwargs={"overlaps": "tags", "cascade": "delete"}, ) diff --git a/src/zenml/zen_stores/schemas/trigger_schemas.py b/src/zenml/zen_stores/schemas/trigger_schemas.py index 63f7538a9ac..2fea6e8d806 100644 --- a/src/zenml/zen_stores/schemas/trigger_schemas.py +++ b/src/zenml/zen_stores/schemas/trigger_schemas.py @@ -15,7 +15,8 @@ import base64 import json -from typing import Any, List, Optional, Sequence, cast +from typing import Any, Optional, cast +from collections.abc import Sequence from uuid import UUID from sqlalchemy import TEXT, Column, UniqueConstraint @@ -74,7 +75,7 @@ class TriggerSchema(NamedSchema, table=True): ) project: "ProjectSchema" = Relationship(back_populates="triggers") - user_id: Optional[UUID] = build_foreign_key_field( + user_id: UUID | None = build_foreign_key_field( source=__tablename__, target=UserSchema.__tablename__, source_column="user_id", @@ -87,7 
+88,7 @@ class TriggerSchema(NamedSchema, table=True): sa_relationship_kwargs={"foreign_keys": "[TriggerSchema.user_id]"}, ) - event_source_id: Optional[UUID] = build_foreign_key_field( + event_source_id: UUID | None = build_foreign_key_field( source=__tablename__, target=EventSourceSchema.__tablename__, source_column="event_source_id", @@ -113,12 +114,12 @@ class TriggerSchema(NamedSchema, table=True): ) action: "ActionSchema" = Relationship(back_populates="triggers") - executions: List["TriggerExecutionSchema"] = Relationship( + executions: list["TriggerExecutionSchema"] = Relationship( back_populates="trigger" ) event_filter: bytes - schedule: Optional[bytes] = Field(nullable=True) + schedule: bytes | None = Field(nullable=True) description: str = Field(sa_column=Column(TEXT, nullable=True)) is_active: bool = Field(nullable=False) @@ -301,7 +302,7 @@ class TriggerExecutionSchema(BaseSchema, table=True): ) trigger: TriggerSchema = Relationship(back_populates="executions") - event_metadata: Optional[bytes] = None + event_metadata: bytes | None = None @classmethod def from_request( diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index 6e65d978e6d..20955b7f932 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -14,7 +14,7 @@ """SQLModel implementation of user tables.""" import json -from typing import TYPE_CHECKING, Any, List, Optional, Union +from typing import TYPE_CHECKING, Any from uuid import UUID from sqlalchemy import TEXT, Column, UniqueConstraint @@ -76,21 +76,21 @@ class UserSchema(NamedSchema, table=True): is_service_account: bool = Field(default=False) full_name: str - description: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) - email: Optional[str] = Field(nullable=True) - avatar_url: Optional[str] = Field( + description: str | None = Field(sa_column=Column(TEXT, nullable=True)) + email: str | None = Field(nullable=True) + 
avatar_url: str | None = Field( default=None, sa_column=Column(TEXT, nullable=True), ) active: bool - password: Optional[str] = Field(nullable=True) - activation_token: Optional[str] = Field(nullable=True) - email_opted_in: Optional[bool] = Field(nullable=True) - external_user_id: Optional[UUID] = Field(nullable=True) + password: str | None = Field(nullable=True) + activation_token: str | None = Field(nullable=True) + email_opted_in: bool | None = Field(nullable=True) + external_user_id: UUID | None = Field(nullable=True) is_admin: bool = Field(default=False) - user_metadata: Optional[str] = Field(nullable=True) + user_metadata: str | None = Field(nullable=True) - default_project_id: Optional[UUID] = build_foreign_key_field( + default_project_id: UUID | None = build_foreign_key_field( source=__tablename__, target="project", source_column="default_project_id", @@ -99,84 +99,84 @@ class UserSchema(NamedSchema, table=True): nullable=True, ) - stacks: List["StackSchema"] = Relationship(back_populates="user") - components: List["StackComponentSchema"] = Relationship( + stacks: list["StackSchema"] = Relationship(back_populates="user") + components: list["StackComponentSchema"] = Relationship( back_populates="user", ) - flavors: List["FlavorSchema"] = Relationship(back_populates="user") - actions: List["ActionSchema"] = Relationship( + flavors: list["FlavorSchema"] = Relationship(back_populates="user") + actions: list["ActionSchema"] = Relationship( back_populates="user", sa_relationship_kwargs={ "cascade": "delete", "primaryjoin": "UserSchema.id==ActionSchema.user_id", }, ) - event_sources: List["EventSourceSchema"] = Relationship( + event_sources: list["EventSourceSchema"] = Relationship( back_populates="user" ) - pipelines: List["PipelineSchema"] = Relationship(back_populates="user") - schedules: List["ScheduleSchema"] = Relationship( + pipelines: list["PipelineSchema"] = Relationship(back_populates="user") + schedules: list["ScheduleSchema"] = Relationship( 
back_populates="user", ) - runs: List["PipelineRunSchema"] = Relationship(back_populates="user") - run_templates: List["RunTemplateSchema"] = Relationship( + runs: list["PipelineRunSchema"] = Relationship(back_populates="user") + run_templates: list["RunTemplateSchema"] = Relationship( back_populates="user", ) - step_runs: List["StepRunSchema"] = Relationship(back_populates="user") - builds: List["PipelineBuildSchema"] = Relationship(back_populates="user") - artifacts: List["ArtifactSchema"] = Relationship(back_populates="user") - artifact_versions: List["ArtifactVersionSchema"] = Relationship( + step_runs: list["StepRunSchema"] = Relationship(back_populates="user") + builds: list["PipelineBuildSchema"] = Relationship(back_populates="user") + artifacts: list["ArtifactSchema"] = Relationship(back_populates="user") + artifact_versions: list["ArtifactVersionSchema"] = Relationship( back_populates="user" ) - run_metadata: List["RunMetadataSchema"] = Relationship( + run_metadata: list["RunMetadataSchema"] = Relationship( back_populates="user" ) - secrets: List["SecretSchema"] = Relationship( + secrets: list["SecretSchema"] = Relationship( back_populates="user", sa_relationship_kwargs={"cascade": "delete"}, ) - triggers: List["TriggerSchema"] = Relationship( + triggers: list["TriggerSchema"] = Relationship( back_populates="user", sa_relationship_kwargs={ "cascade": "delete", "primaryjoin": "UserSchema.id==TriggerSchema.user_id", }, ) - auth_actions: List["ActionSchema"] = Relationship( + auth_actions: list["ActionSchema"] = Relationship( back_populates="service_account", sa_relationship_kwargs={ "cascade": "delete", "primaryjoin": "UserSchema.id==ActionSchema.service_account_id", }, ) - snapshots: List["PipelineSnapshotSchema"] = Relationship( + snapshots: list["PipelineSnapshotSchema"] = Relationship( back_populates="user", ) - code_repositories: List["CodeRepositorySchema"] = Relationship( + code_repositories: list["CodeRepositorySchema"] = Relationship( 
back_populates="user", ) - services: List["ServiceSchema"] = Relationship(back_populates="user") - service_connectors: List["ServiceConnectorSchema"] = Relationship( + services: list["ServiceSchema"] = Relationship(back_populates="user") + service_connectors: list["ServiceConnectorSchema"] = Relationship( back_populates="user", ) - models: List["ModelSchema"] = Relationship( + models: list["ModelSchema"] = Relationship( back_populates="user", ) - model_versions: List["ModelVersionSchema"] = Relationship( + model_versions: list["ModelVersionSchema"] = Relationship( back_populates="user", ) - auth_devices: List["OAuthDeviceSchema"] = Relationship( + auth_devices: list["OAuthDeviceSchema"] = Relationship( back_populates="user", sa_relationship_kwargs={"cascade": "delete"}, ) - api_keys: List["APIKeySchema"] = Relationship( + api_keys: list["APIKeySchema"] = Relationship( back_populates="service_account", sa_relationship_kwargs={"cascade": "delete"}, ) - deployments: List["DeploymentSchema"] = Relationship( + deployments: list["DeploymentSchema"] = Relationship( back_populates="user", ) - tags: List["TagSchema"] = Relationship( + tags: list["TagSchema"] = Relationship( back_populates="user", ) @@ -209,7 +209,7 @@ def from_user_request(cls, model: UserRequest) -> "UserSchema": @classmethod def from_service_account_request( - cls, model: Union[ServiceAccountRequest, ServiceAccountInternalRequest] + cls, model: ServiceAccountRequest | ServiceAccountInternalRequest ) -> "UserSchema": """Create a `UserSchema` from a Service Account request. 
diff --git a/src/zenml/zen_stores/schemas/utils.py b/src/zenml/zen_stores/schemas/utils.py index 8c28addee16..aa420b28282 100644 --- a/src/zenml/zen_stores/schemas/utils.py +++ b/src/zenml/zen_stores/schemas/utils.py @@ -15,7 +15,7 @@ import json import math -from typing import Any, Dict, List, Type, TypeVar, cast +from typing import Any, TypeVar, cast from sqlalchemy.orm import InstrumentedAttribute from sqlmodel import Relationship @@ -40,8 +40,8 @@ def jl_arg(column: Any) -> InstrumentedAttribute[Any]: def get_page_from_list( - items_list: List[S], - response_model: Type[BaseResponse], # type: ignore[type-arg] + items_list: list[S], + response_model: type[BaseResponse], # type: ignore[type-arg] size: int = 5, page: int = 1, include_resources: bool = False, @@ -92,7 +92,7 @@ class RunMetadataInterface: def fetch_metadata_collection( self, **kwargs: Any - ) -> Dict[str, List[RunMetadataEntry]]: + ) -> dict[str, list[RunMetadataEntry]]: """Fetches all the metadata entries related to the entity. Args: @@ -102,7 +102,7 @@ def fetch_metadata_collection( A dictionary, where the key is the key of the metadata entry and the values represent the list of entries with this key. """ - metadata_collection: Dict[str, List[RunMetadataEntry]] = {} + metadata_collection: dict[str, list[RunMetadataEntry]] = {} for rm in self.run_metadata: if rm.key not in metadata_collection: @@ -116,7 +116,7 @@ def fetch_metadata_collection( return metadata_collection - def fetch_metadata(self, **kwargs: Any) -> Dict[str, MetadataType]: + def fetch_metadata(self, **kwargs: Any) -> dict[str, MetadataType]: """Fetches the latest metadata entry related to the entity. Args: @@ -133,7 +133,7 @@ def fetch_metadata(self, **kwargs: Any) -> Dict[str, MetadataType]: } -def get_resource_type_name(schema_class: Type[BaseSchema]) -> str: +def get_resource_type_name(schema_class: type[BaseSchema]) -> str: """Get the name of a resource from a schema class. 
Args: diff --git a/src/zenml/zen_stores/secrets_stores/aws_secrets_store.py b/src/zenml/zen_stores/secrets_stores/aws_secrets_store.py index c1c02ad4713..c28d09c900b 100644 --- a/src/zenml/zen_stores/secrets_stores/aws_secrets_store.py +++ b/src/zenml/zen_stores/secrets_stores/aws_secrets_store.py @@ -17,9 +17,6 @@ from typing import ( Any, ClassVar, - Dict, - List, - Type, ) from uuid import UUID @@ -78,7 +75,7 @@ def region(self) -> str: @model_validator(mode="before") @classmethod @before_validator_handler - def populate_config(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def populate_config(cls, data: dict[str, Any]) -> dict[str, Any]: """Populate the connector configuration from legacy attributes. Args: @@ -147,7 +144,7 @@ class AWSSecretsStore(ServiceConnectorSecretsStore): config: AWSSecretsStoreConfiguration TYPE: ClassVar[SecretsStoreType] = SecretsStoreType.AWS - CONFIG_TYPE: ClassVar[Type[ServiceConnectorSecretsStoreConfiguration]] = ( + CONFIG_TYPE: ClassVar[type[ServiceConnectorSecretsStoreConfiguration]] = ( AWSSecretsStoreConfiguration ) SERVICE_CONNECTOR_TYPE: ClassVar[str] = AWS_CONNECTOR_TYPE @@ -201,8 +198,8 @@ def _get_aws_secret_id( @staticmethod def _get_aws_secret_tags( - metadata: Dict[str, str], - ) -> List[Dict[str, str]]: + metadata: dict[str, str], + ) -> list[dict[str, str]]: """Convert ZenML secret metadata to AWS secret tags. Args: @@ -211,7 +208,7 @@ def _get_aws_secret_tags( Returns: The AWS secret tags. """ - aws_tags: List[Dict[str, str]] = [] + aws_tags: list[dict[str, str]] = [] for k, v in metadata.items(): aws_tags.append( { @@ -225,7 +222,7 @@ def _get_aws_secret_tags( def store_secret_values( self, secret_id: UUID, - secret_values: Dict[str, str], + secret_values: dict[str, str], ) -> None: """Store secret values for a new secret. 
@@ -255,7 +252,7 @@ def store_secret_values( logger.debug(f"Created AWS secret: {aws_secret_id}") - def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: + def get_secret_values(self, secret_id: UUID) -> dict[str, str]: """Get the secret values for an existing secret. Args: @@ -297,7 +294,7 @@ def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: ) # Convert the AWS secret tags to a metadata dictionary. - metadata: Dict[str, str] = { + metadata: dict[str, str] = { tag["Key"]: tag["Value"] for tag in describe_secret_response["Tags"] } @@ -328,7 +325,7 @@ def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: def update_secret_values( self, secret_id: UUID, - secret_values: Dict[str, str], + secret_values: dict[str, str], ) -> None: """Updates secret values for an existing secret. diff --git a/src/zenml/zen_stores/secrets_stores/azure_secrets_store.py b/src/zenml/zen_stores/secrets_stores/azure_secrets_store.py index 81c8e7e7ae5..1ffe7296bec 100644 --- a/src/zenml/zen_stores/secrets_stores/azure_secrets_store.py +++ b/src/zenml/zen_stores/secrets_stores/azure_secrets_store.py @@ -18,8 +18,6 @@ from typing import ( Any, ClassVar, - Dict, - Type, cast, ) from uuid import UUID @@ -69,7 +67,7 @@ class AzureSecretsStoreConfiguration( @model_validator(mode="before") @classmethod @before_validator_handler - def populate_config(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def populate_config(cls, data: dict[str, Any]) -> dict[str, Any]: """Populate the connector configuration from legacy attributes. 
Args: @@ -126,7 +124,7 @@ class AzureSecretsStore(ServiceConnectorSecretsStore): config: AzureSecretsStoreConfiguration TYPE: ClassVar[SecretsStoreType] = SecretsStoreType.AZURE - CONFIG_TYPE: ClassVar[Type[ServiceConnectorSecretsStoreConfiguration]] = ( + CONFIG_TYPE: ClassVar[type[ServiceConnectorSecretsStoreConfiguration]] = ( AzureSecretsStoreConfiguration ) SERVICE_CONNECTOR_TYPE: ClassVar[str] = AZURE_CONNECTOR_TYPE @@ -197,7 +195,7 @@ def _get_azure_secret_id( def store_secret_values( self, secret_id: UUID, - secret_values: Dict[str, str], + secret_values: dict[str, str], ) -> None: """Store secret values for a new secret. @@ -227,7 +225,7 @@ def store_secret_values( logger.debug(f"Created Azure secret: {azure_secret_id}") - def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: + def get_secret_values(self, secret_id: UUID) -> dict[str, str]: """Get the secret values for an existing secret. Args: @@ -283,7 +281,7 @@ def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: def update_secret_values( self, secret_id: UUID, - secret_values: Dict[str, str], + secret_values: dict[str, str], ) -> None: """Updates secret values for an existing secret. 
diff --git a/src/zenml/zen_stores/secrets_stores/base_secrets_store.py b/src/zenml/zen_stores/secrets_stores/base_secrets_store.py index 779b0f09c0c..df674edcceb 100644 --- a/src/zenml/zen_stores/secrets_stores/base_secrets_store.py +++ b/src/zenml/zen_stores/secrets_stores/base_secrets_store.py @@ -18,9 +18,7 @@ TYPE_CHECKING, Any, ClassVar, - Dict, Optional, - Type, ) from uuid import UUID @@ -57,12 +55,12 @@ class BaseSecretsStore(BaseModel, SecretsStoreInterface, ABC): _zen_store: Optional["BaseZenStore"] = None TYPE: ClassVar[SecretsStoreType] - CONFIG_TYPE: ClassVar[Type[SecretsStoreConfiguration]] + CONFIG_TYPE: ClassVar[type[SecretsStoreConfiguration]] @model_validator(mode="before") @classmethod @before_validator_handler - def convert_config(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def convert_config(cls, data: dict[str, Any]) -> dict[str, Any]: """Method to infer the correct type of the config and convert. Args: @@ -164,7 +162,7 @@ def __init__( @staticmethod def _load_custom_store_class( store_config: SecretsStoreConfiguration, - ) -> Type["BaseSecretsStore"]: + ) -> type["BaseSecretsStore"]: """Loads the custom secrets store class from the given config. Args: @@ -196,7 +194,7 @@ def _load_custom_store_class( @staticmethod def get_store_class( store_config: SecretsStoreConfiguration, - ) -> Type["BaseSecretsStore"]: + ) -> type["BaseSecretsStore"]: """Returns the class of the given secrets store type. Args: @@ -301,8 +299,8 @@ def zen_store(self) -> "BaseZenStore": def _get_secret_metadata( self, - secret_id: Optional[UUID] = None, - ) -> Dict[str, str]: + secret_id: UUID | None = None, + ) -> dict[str, str]: """Get a dictionary with metadata that can be used as tags/labels. This utility method can be used with Secrets Managers that can @@ -325,7 +323,7 @@ def _get_secret_metadata( # from other secrets that might be stored in the same backend and # to distinguish between different ZenML deployments using the same # backend. 
- metadata: Dict[str, str] = { + metadata: dict[str, str] = { ZENML_SECRET_LABEL: str(self.zen_store.get_store_info().id) } @@ -338,7 +336,7 @@ def _get_secret_metadata( def _verify_secret_metadata( self, secret_id: UUID, - metadata: Dict[str, str], + metadata: dict[str, str], ) -> None: """Verify that the given metadata corresponds to a valid ZenML secret. diff --git a/src/zenml/zen_stores/secrets_stores/gcp_secrets_store.py b/src/zenml/zen_stores/secrets_stores/gcp_secrets_store.py index ab640e7a76b..01d59b8b437 100644 --- a/src/zenml/zen_stores/secrets_stores/gcp_secrets_store.py +++ b/src/zenml/zen_stores/secrets_stores/gcp_secrets_store.py @@ -18,9 +18,6 @@ from typing import ( Any, ClassVar, - Dict, - Optional, - Type, cast, ) from uuid import UUID @@ -82,7 +79,7 @@ def project_id(self) -> str: @model_validator(mode="before") @classmethod @before_validator_handler - def populate_config(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def populate_config(cls, data: dict[str, Any]) -> dict[str, Any]: """Populate the connector configuration from legacy attributes. Args: @@ -132,13 +129,13 @@ class GCPSecretsStore(ServiceConnectorSecretsStore): config: GCPSecretsStoreConfiguration TYPE: ClassVar[SecretsStoreType] = SecretsStoreType.GCP - CONFIG_TYPE: ClassVar[Type[ServiceConnectorSecretsStoreConfiguration]] = ( + CONFIG_TYPE: ClassVar[type[ServiceConnectorSecretsStoreConfiguration]] = ( GCPSecretsStoreConfiguration ) SERVICE_CONNECTOR_TYPE: ClassVar[str] = GCP_CONNECTOR_TYPE SERVICE_CONNECTOR_RESOURCE_TYPE: ClassVar[str] = GCP_RESOURCE_TYPE - _client: Optional[SecretManagerServiceClient] = None + _client: SecretManagerServiceClient | None = None @property def client(self) -> SecretManagerServiceClient: @@ -203,7 +200,7 @@ def _get_gcp_secret_name( def store_secret_values( self, secret_id: UUID, - secret_values: Dict[str, str], + secret_values: dict[str, str], ) -> None: """Store secret values for a new secret. 
@@ -246,7 +243,7 @@ def store_secret_values( logger.debug(f"Created GCP secret {gcp_secret.name}") - def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: + def get_secret_values(self, secret_id: UUID) -> dict[str, str]: """Get the secret values for an existing secret. Args: @@ -312,7 +309,7 @@ def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: def update_secret_values( self, secret_id: UUID, - secret_values: Dict[str, str], + secret_values: dict[str, str], ) -> None: """Updates secret values for an existing secret. diff --git a/src/zenml/zen_stores/secrets_stores/hashicorp_secrets_store.py b/src/zenml/zen_stores/secrets_stores/hashicorp_secrets_store.py index 7925eaf9aae..be24146d89a 100644 --- a/src/zenml/zen_stores/secrets_stores/hashicorp_secrets_store.py +++ b/src/zenml/zen_stores/secrets_stores/hashicorp_secrets_store.py @@ -17,9 +17,6 @@ from typing import ( Any, ClassVar, - Dict, - Optional, - Type, ) from uuid import UUID @@ -91,15 +88,15 @@ class HashiCorpVaultSecretsStoreConfiguration(SecretsStoreConfiguration): type: SecretsStoreType = SecretsStoreType.HASHICORP vault_addr: str - vault_namespace: Optional[str] = None - mount_point: Optional[str] = None + vault_namespace: str | None = None + mount_point: str | None = None auth_method: HashiCorpVaultAuthMethod = HashiCorpVaultAuthMethod.TOKEN - auth_mount_point: Optional[str] = None - vault_token: Optional[PlainSerializedSecretStr] = None - app_role_id: Optional[str] = None - app_secret_id: Optional[str] = None - aws_role: Optional[str] = None - aws_header_value: Optional[str] = None + auth_mount_point: str | None = None + vault_token: PlainSerializedSecretStr | None = None + app_role_id: str | None = None + app_secret_id: str | None = None + aws_role: str | None = None + aws_header_value: str | None = None max_versions: int = 1 model_config = ConfigDict(extra="ignore") @@ -131,13 +128,13 @@ class HashiCorpVaultSecretsStore(BaseSecretsStore): config: 
HashiCorpVaultSecretsStoreConfiguration TYPE: ClassVar[SecretsStoreType] = SecretsStoreType.HASHICORP - CONFIG_TYPE: ClassVar[Type[SecretsStoreConfiguration]] = ( + CONFIG_TYPE: ClassVar[type[SecretsStoreConfiguration]] = ( HashiCorpVaultSecretsStoreConfiguration ) - _client: Optional[hvac.Client] = None - _expires_at: Optional[datetime] = None - _renew_at: Optional[datetime] = None + _client: hvac.Client | None = None + _expires_at: datetime | None = None + _renew_at: datetime | None = None @property def client(self) -> hvac.Client: @@ -150,7 +147,7 @@ def client(self) -> hvac.Client: ValueError: If the configuration is invalid. """ - def update_ttls(response: Dict[str, Any]) -> None: + def update_ttls(response: dict[str, Any]) -> None: """Update the TTLs for the client. Args: @@ -286,7 +283,7 @@ def _get_vault_secret_id( def store_secret_values( self, secret_id: UUID, - secret_values: Dict[str, str], + secret_values: dict[str, str], ) -> None: """Store secret values for a new secret. @@ -319,7 +316,7 @@ def store_secret_values( logger.debug(f"Created HashiCorp Vault secret: {vault_secret_id}") - def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: + def get_secret_values(self, secret_id: UUID) -> dict[str, str]: """Get the secret values for an existing secret. Args: @@ -385,7 +382,7 @@ def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: def update_secret_values( self, secret_id: UUID, - secret_values: Dict[str, str], + secret_values: dict[str, str], ) -> None: """Updates secret values for an existing secret. 
diff --git a/src/zenml/zen_stores/secrets_stores/secrets_store_interface.py b/src/zenml/zen_stores/secrets_stores/secrets_store_interface.py index efe5b66f493..39ad9025382 100644 --- a/src/zenml/zen_stores/secrets_stores/secrets_store_interface.py +++ b/src/zenml/zen_stores/secrets_stores/secrets_store_interface.py @@ -14,7 +14,6 @@ """ZenML secrets store interface.""" from abc import ABC, abstractmethod -from typing import Dict from uuid import UUID @@ -44,7 +43,7 @@ def _initialize(self) -> None: def store_secret_values( self, secret_id: UUID, - secret_values: Dict[str, str], + secret_values: dict[str, str], ) -> None: """Store secret values for a new secret. @@ -54,7 +53,7 @@ def store_secret_values( """ @abstractmethod - def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: + def get_secret_values(self, secret_id: UUID) -> dict[str, str]: """Get the secret values for an existing secret. Args: @@ -72,7 +71,7 @@ def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: def update_secret_values( self, secret_id: UUID, - secret_values: Dict[str, str], + secret_values: dict[str, str], ) -> None: """Updates secret values for an existing secret. 
diff --git a/src/zenml/zen_stores/secrets_stores/service_connector_secrets_store.py b/src/zenml/zen_stores/secrets_stores/service_connector_secrets_store.py index 2899f94f590..f67682be7ca 100644 --- a/src/zenml/zen_stores/secrets_stores/service_connector_secrets_store.py +++ b/src/zenml/zen_stores/secrets_stores/service_connector_secrets_store.py @@ -19,9 +19,6 @@ from typing import ( Any, ClassVar, - Dict, - Optional, - Type, ) from pydantic import Field, model_validator @@ -53,12 +50,12 @@ class ServiceConnectorSecretsStoreConfiguration(SecretsStoreConfiguration): """ auth_method: str - auth_config: Dict[str, Any] = Field(default_factory=dict) + auth_config: dict[str, Any] = Field(default_factory=dict) @model_validator(mode="before") @classmethod @before_validator_handler - def validate_auth_config(cls, data: Dict[str, Any]) -> Dict[str, Any]: + def validate_auth_config(cls, data: dict[str, Any]) -> dict[str, Any]: """Convert the authentication configuration if given in JSON format. Args: @@ -99,13 +96,13 @@ class ServiceConnectorSecretsStore(BaseSecretsStore): """ config: ServiceConnectorSecretsStoreConfiguration - CONFIG_TYPE: ClassVar[Type[ServiceConnectorSecretsStoreConfiguration]] + CONFIG_TYPE: ClassVar[type[ServiceConnectorSecretsStoreConfiguration]] SERVICE_CONNECTOR_TYPE: ClassVar[str] SERVICE_CONNECTOR_RESOURCE_TYPE: ClassVar[str] - _connector: Optional[ServiceConnector] = None - _client: Optional[Any] = None - _lock: Optional[Lock] = None + _connector: ServiceConnector | None = None + _client: Any | None = None + _lock: Lock | None = None def _initialize(self) -> None: """Initialize the secrets store.""" diff --git a/src/zenml/zen_stores/secrets_stores/sql_secrets_store.py b/src/zenml/zen_stores/secrets_stores/sql_secrets_store.py index 54e9a1c12d8..68d9e140ad4 100644 --- a/src/zenml/zen_stores/secrets_stores/sql_secrets_store.py +++ b/src/zenml/zen_stores/secrets_stores/sql_secrets_store.py @@ -28,9 +28,6 @@ TYPE_CHECKING, Any, ClassVar, - Dict, - 
Optional, - Type, ) from uuid import UUID @@ -74,7 +71,7 @@ class SqlSecretsStoreConfiguration(SecretsStoreConfiguration): """ type: SecretsStoreType = SecretsStoreType.SQL - encryption_key: Optional[PlainSerializedSecretStr] = None + encryption_key: PlainSerializedSecretStr | None = None model_config = ConfigDict( # Don't validate attributes when assigning them. This is necessary # because the certificate attributes can be expanded to the contents @@ -99,11 +96,11 @@ class SqlSecretsStore(BaseSecretsStore): config: SqlSecretsStoreConfiguration TYPE: ClassVar[SecretsStoreType] = SecretsStoreType.SQL - CONFIG_TYPE: ClassVar[Type[SecretsStoreConfiguration]] = ( + CONFIG_TYPE: ClassVar[type[SecretsStoreConfiguration]] = ( SqlSecretsStoreConfiguration ) - _encryption_engine: Optional[AesGcmEngine] = None + _encryption_engine: AesGcmEngine | None = None def __init__( self, @@ -185,7 +182,7 @@ def _initialize(self) -> None: def store_secret_values( self, secret_id: UUID, - secret_values: Dict[str, str], + secret_values: dict[str, str], ) -> None: """Store secret values for a new secret. @@ -212,7 +209,7 @@ def store_secret_values( session.add(secret_in_db) session.commit() - def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: + def get_secret_values(self, secret_id: UUID) -> dict[str, str]: """Get the secret values for an existing secret. Args: @@ -246,7 +243,7 @@ def get_secret_values(self, secret_id: UUID) -> Dict[str, str]: def update_secret_values( self, secret_id: UUID, - secret_values: Dict[str, str], + secret_values: dict[str, str], ) -> None: """Updates secret values for an existing secret. 
diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 451007e05eb..88f3b7a1fed 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -41,24 +41,17 @@ from typing import ( TYPE_CHECKING, Any, - Callable, ClassVar, - Dict, ForwardRef, - List, Literal, NoReturn, Optional, - Sequence, - Set, - Tuple, - Type, TypeVar, - Union, cast, get_origin, overload, ) +from collections.abc import Callable, Sequence from uuid import UUID from packaging import version @@ -463,7 +456,7 @@ def exponential_backoff_with_jitter( class Session(SqlModelSession): """Session subclass that automatically tracks duration and calling context.""" - def _get_metrics(self) -> Dict[str, Any]: + def _get_metrics(self) -> dict[str, Any]: """Get the metrics for the session. Returns: @@ -548,9 +541,9 @@ def __enter__(self) -> "Session": def __exit__( self, - exc_type: Optional[Any], - exc_val: Optional[Any], - exc_tb: Optional[Any], + exc_type: Any | None, + exc_val: Any | None, + exc_tb: Any | None, ) -> None: """Exit the context manager. 
@@ -621,19 +614,19 @@ class SqlZenStoreConfiguration(StoreConfiguration): type: StoreType = StoreType.SQL - secrets_store: Optional[SerializeAsAny[SecretsStoreConfiguration]] = None - backup_secrets_store: Optional[ + secrets_store: SerializeAsAny[SecretsStoreConfiguration] | None = None + backup_secrets_store: None | ( SerializeAsAny[SecretsStoreConfiguration] - ] = None + ) = None - driver: Optional[SQLDatabaseDriver] = None - database: Optional[str] = None - username: Optional[PlainSerializedSecretStr] = None - password: Optional[PlainSerializedSecretStr] = None + driver: SQLDatabaseDriver | None = None + database: str | None = None + username: PlainSerializedSecretStr | None = None + password: PlainSerializedSecretStr | None = None ssl: bool = False - ssl_ca: Optional[PlainSerializedSecretStr] = None - ssl_cert: Optional[PlainSerializedSecretStr] = None - ssl_key: Optional[PlainSerializedSecretStr] = None + ssl_ca: PlainSerializedSecretStr | None = None + ssl_cert: PlainSerializedSecretStr | None = None + ssl_key: PlainSerializedSecretStr | None = None ssl_verify_server_cert: bool = False pool_size: int = 20 max_overflow: int = 20 @@ -647,12 +640,12 @@ class SqlZenStoreConfiguration(StoreConfiguration): SQL_STORE_BACKUP_DIRECTORY_NAME, ) ) - backup_database: Optional[str] = None + backup_database: str | None = None @field_validator("secrets_store") @classmethod def validate_secrets_store( - cls, secrets_store: Optional[SecretsStoreConfiguration] + cls, secrets_store: SecretsStoreConfiguration | None ) -> SecretsStoreConfiguration: """Ensures that the secrets store is initialized with a default SQL secrets store. @@ -762,8 +755,8 @@ def _validate_url(self) -> "SqlZenStoreConfiguration": if sql_url.query: def _get_query_result( - result: Union[str, Tuple[str, ...]], - ) -> Optional[str]: + result: str | tuple[str, ...], + ) -> str | None: """Returns the only or the first result of a query. 
Args: @@ -875,8 +868,8 @@ def supports_url_scheme(cls, url: str) -> bool: def get_sqlalchemy_config( self, - database: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any], Dict[str, Any]]: + database: str | None = None, + ) -> tuple[URL, dict[str, Any], dict[str, Any]]: """Get the SQLAlchemy engine configuration for the SQL ZenML store. Args: @@ -890,7 +883,7 @@ def get_sqlalchemy_config( NotImplementedError: If the SQL driver is not supported. """ sql_url = make_url(self.url) - sqlalchemy_connect_args: Dict[str, Any] = {} + sqlalchemy_connect_args: dict[str, Any] = {} engine_args = {} if sql_url.drivername == SQLDatabaseDriver.SQLITE: assert self.database is not None @@ -922,7 +915,7 @@ def get_sqlalchemy_config( database=database, ) - sqlalchemy_ssl_args: Dict[str, Any] = {} + sqlalchemy_ssl_args: dict[str, Any] = {} # Handle SSL params if self.ssl: @@ -973,16 +966,16 @@ class SqlZenStore(BaseZenStore): config: SqlZenStoreConfiguration skip_migrations: bool = False TYPE: ClassVar[StoreType] = StoreType.SQL - CONFIG_TYPE: ClassVar[Type[StoreConfiguration]] = SqlZenStoreConfiguration + CONFIG_TYPE: ClassVar[type[StoreConfiguration]] = SqlZenStoreConfiguration - _engine: Optional[Engine] = None - _migration_utils: Optional[MigrationUtils] = None - _alembic: Optional[Alembic] = None - _secrets_store: Optional[BaseSecretsStore] = None - _backup_secrets_store: Optional[BaseSecretsStore] = None + _engine: Engine | None = None + _migration_utils: MigrationUtils | None = None + _alembic: Alembic | None = None + _secrets_store: BaseSecretsStore | None = None + _backup_secrets_store: BaseSecretsStore | None = None _should_send_user_enriched_events: bool = False - _cached_onboarding_state: Optional[Set[str]] = None - _default_user: Optional[UserResponse] = None + _cached_onboarding_state: set[str] | None = None + _default_user: UserResponse | None = None @property def secrets_store(self) -> "BaseSecretsStore": @@ -1108,22 +1101,22 @@ def 
_send_user_enriched_events_if_necessary(self) -> None: def filter_and_paginate( cls, session: Session, - query: Union[Select[Any], SelectOfScalar[Any]], - table: Type[AnySchema], + query: Select[Any] | SelectOfScalar[Any], + table: type[AnySchema], filter_model: BaseFilter, - custom_schema_to_model_conversion: Optional[ + custom_schema_to_model_conversion: None | ( Callable[..., AnyResponse] - ] = None, - custom_fetch: Optional[ + ) = None, + custom_fetch: None | ( Callable[ [ Session, - Union[Select[Any], SelectOfScalar[Any]], + Select[Any] | SelectOfScalar[Any], BaseFilter, ], Sequence[Any], ] - ] = None, + ) = None, hydrate: bool = False, apply_query_options_from_schema: bool = False, ) -> Page[AnyResponse]: @@ -1161,7 +1154,7 @@ def filter_and_paginate( query = query.distinct() # Get the total amount of items in the database for a given query - custom_fetch_result: Optional[Sequence[Any]] = None + custom_fetch_result: Sequence[Any] | None = None if custom_fetch: custom_fetch_result = custom_fetch(session, query, filter_model) total = len(custom_fetch_result) @@ -1213,7 +1206,7 @@ def filter_and_paginate( item_schemas = query_result.all() # Convert this page of items from schemas to models. - items: List[AnyResponse] = [] + items: list[AnyResponse] = [] for schema in item_schemas: # If a custom conversion function is provided, use it. if custom_schema_to_model_conversion: @@ -1377,10 +1370,10 @@ def _get_db_backup_file_path(self) -> str: def backup_database( self, - strategy: Optional[DatabaseBackupStrategy] = None, - location: Optional[str] = None, + strategy: DatabaseBackupStrategy | None = None, + location: str | None = None, overwrite: bool = False, - ) -> Tuple[str, Any]: + ) -> tuple[str, Any]: """Backup the database. 
Args: @@ -1450,8 +1443,8 @@ def backup_database( def restore_database( self, - strategy: Optional[DatabaseBackupStrategy] = None, - location: Optional[Any] = None, + strategy: DatabaseBackupStrategy | None = None, + location: Any | None = None, cleanup: bool = False, ) -> None: """Restore the database. @@ -1506,8 +1499,8 @@ def restore_database( def cleanup_database_backup( self, - strategy: Optional[DatabaseBackupStrategy] = None, - location: Optional[Any] = None, + strategy: DatabaseBackupStrategy | None = None, + location: Any | None = None, ) -> None: """Delete the database backup. @@ -1615,8 +1608,8 @@ def migrate_database(self) -> None: self.config.backup_strategy != DatabaseBackupStrategy.DISABLED and set(current_revisions) != set(head_revisions) ) - backup_location: Optional[Any] = None - backup_location_msg: Optional[str] = None + backup_location: Any | None = None + backup_location_msg: str | None = None if backup_enabled: try: @@ -1919,7 +1912,7 @@ def _update_last_user_activity_timestamp( session.commit() session.refresh(settings) - def get_onboarding_state(self) -> List[str]: + def get_onboarding_state(self) -> list[str]: """Get the server onboarding state. Returns: @@ -1936,7 +1929,7 @@ def get_onboarding_state(self) -> List[str]: return [] def _update_onboarding_state( - self, completed_steps: Set[str], session: Session + self, completed_steps: set[str], session: Session ) -> None: """Update the server onboarding state. @@ -1962,7 +1955,7 @@ def _update_onboarding_state( json.loads(settings.onboarding_state) ) - def update_onboarding_state(self, completed_steps: Set[str]) -> None: + def update_onboarding_state(self, completed_steps: set[str]) -> None: """Update the server onboarding state. 
Args: @@ -1975,7 +1968,7 @@ def update_onboarding_state(self, completed_steps: Set[str]) -> None: def activate_server( self, request: ServerActivationRequest - ) -> Optional[UserResponse]: + ) -> UserResponse | None: """Activate the server and optionally create the default admin user. Args: @@ -2212,7 +2205,7 @@ def delete_action(self, action_id: UUID) -> None: def _get_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, session: Session, ) -> APIKeySchema: """Helper method to fetch an API key by name or ID. @@ -2311,7 +2304,7 @@ def create_api_key( def get_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, hydrate: bool = True, ) -> APIKeyResponse: """Get an API key for a service account. @@ -2400,7 +2393,7 @@ def list_api_keys( def update_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, api_key_update: APIKeyUpdate, ) -> APIKeyResponse: """Update an API key for a service account. @@ -2490,7 +2483,7 @@ def update_internal_api_key( def rotate_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, rotate_request: APIKeyRotateRequest, ) -> APIKeyResponse: """Rotate an API key for a service account. @@ -2529,7 +2522,7 @@ def rotate_api_key( def delete_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, ) -> None: """Delete an API key for a service account. @@ -2554,8 +2547,8 @@ def _get_api_transaction( self, api_transaction_id: UUID, session: Session, - method: Optional[str] = None, - url: Optional[str] = None, + method: str | None = None, + url: str | None = None, ) -> ApiTransactionSchema: """Retrieve or create a new API transaction. 
@@ -2627,7 +2620,7 @@ def _cleanup_expired_api_transactions(self, session: Session) -> None: def get_or_create_api_transaction( self, api_transaction: ApiTransactionRequest - ) -> Tuple[ApiTransactionResponse, bool]: + ) -> tuple[ApiTransactionResponse, bool]: """Retrieve or create a new API transaction. Args: @@ -3177,7 +3170,7 @@ def create_artifact_version( assert artifact_version.artifact_id - artifact_version_schema: Optional[ArtifactVersionSchema] = None + artifact_version_schema: ArtifactVersionSchema | None = None if artifact_version.version is None: # No explicit version in the request -> We will try to @@ -3274,8 +3267,8 @@ def create_artifact_version( # Save metadata of the artifact if artifact_version.metadata: - values: Dict[str, "MetadataType"] = {} - types: Dict[str, "MetadataTypeEnum"] = {} + values: dict[str, "MetadataType"] = {} + types: dict[str, "MetadataTypeEnum"] = {} for key, value in artifact_version.metadata.items(): # Skip metadata that is too large to be stored in the DB. if len(json.dumps(value)) > TEXT_FIELD_MAX_LENGTH: @@ -3317,8 +3310,8 @@ def create_artifact_version( ) def batch_create_artifact_versions( - self, artifact_versions: List[ArtifactVersionRequest] - ) -> List[ArtifactVersionResponse]: + self, artifact_versions: list[ArtifactVersionRequest] + ) -> list[ArtifactVersionResponse]: """Creates a batch of artifact versions. Args: @@ -3449,7 +3442,7 @@ def delete_artifact_version(self, artifact_version_id: UUID) -> None: def prune_artifact_versions( self, - project_name_or_id: Union[str, UUID], + project_name_or_id: str | UUID, only_versions: bool = True, ) -> None: """Prunes unused artifact versions and their artifacts. @@ -4012,7 +4005,7 @@ def delete_stack_component(self, component_id: UUID) -> None: session.commit() def count_stack_components( - self, filter_model: Optional[ComponentFilter] = None + self, filter_model: ComponentFilter | None = None ) -> int: """Count all components. 
@@ -4136,8 +4129,8 @@ def get_authorized_device( def get_internal_authorized_device( self, - device_id: Optional[UUID] = None, - client_id: Optional[UUID] = None, + device_id: UUID | None = None, + client_id: UUID | None = None, hydrate: bool = True, ) -> OAuthDeviceInternalResponse: """Gets a specific OAuth 2.0 authorized device for internal use. @@ -4842,7 +4835,7 @@ def _create_or_reuse_code_reference( session: Session, project_id: UUID, code_reference: Optional["CodeReferenceRequest"], - ) -> Optional[UUID]: + ) -> UUID | None: """Creates or reuses a code reference. Args: @@ -5063,8 +5056,8 @@ def get_snapshot( self, snapshot_id: UUID, hydrate: bool = True, - step_configuration_filter: Optional[List[str]] = None, - include_config_schema: Optional[bool] = None, + step_configuration_filter: list[str] | None = None, + include_config_schema: bool | None = None, ) -> PipelineSnapshotResponse: """Get a snapshot with a given ID. @@ -5464,8 +5457,8 @@ def _assert_curated_visualization_display_order_unique( *, resource_id: UUID, resource_type: VisualizationResourceTypes, - display_order: Optional[int], - exclude_visualization_id: Optional[UUID] = None, + display_order: int | None, + exclude_visualization_id: UUID | None = None, ) -> None: """Ensure curated visualizations per resource use unique display orders. @@ -5542,8 +5535,8 @@ def create_curated_visualization( ) project_id = visualization.project - resource_schema_map: Dict[ - VisualizationResourceTypes, Type[BaseSchema] + resource_schema_map: dict[ + VisualizationResourceTypes, type[BaseSchema] ] = { VisualizationResourceTypes.DEPLOYMENT: DeploymentSchema, VisualizationResourceTypes.MODEL: ModelSchema, @@ -5873,7 +5866,7 @@ def delete_run_template(self, template_id: UUID) -> None: def run_template( self, template_id: UUID, - run_configuration: Optional[PipelineRunConfiguration] = None, + run_configuration: PipelineRunConfiguration | None = None, ) -> NoReturn: """Run a template. 
@@ -6105,8 +6098,8 @@ def get_pipeline_run_dag(self, pipeline_run_id: UUID) -> PipelineRunDAG: ) for config_table in snapshot.step_configurations } - regular_output_artifact_nodes: Dict[ - str, Dict[str, PipelineRunDAG.Node] + regular_output_artifact_nodes: dict[ + str, dict[str, PipelineRunDAG.Node] ] = defaultdict(dict) def _get_regular_output_artifact_node( @@ -6124,7 +6117,7 @@ def _get_regular_output_artifact_node( upstream_steps = set(step.spec.upstream_steps) step_id = None - metadata: Dict[str, Any] = {} + metadata: dict[str, Any] = {} step_run = step_runs.get(step_name) if step_run: @@ -6297,7 +6290,7 @@ def _get_regular_output_artifact_node( ] = artifact_node for triggered_run in step_run.triggered_runs: - triggered_run_metadata: Dict[str, Any] = { + triggered_run_metadata: dict[str, Any] = { "status": triggered_run.status, } @@ -6637,7 +6630,7 @@ def get_run_status(self, run_id: UUID) -> ExecutionStatus: def _check_if_run_in_progress( self, run_id: UUID - ) -> Tuple[bool, Optional[datetime]]: + ) -> tuple[bool, datetime | None]: """Check if a pipeline run is in progress. Args: @@ -6659,7 +6652,7 @@ def _replace_placeholder_run( self, pipeline_run: PipelineRunRequest, session: Session, - pre_replacement_hook: Optional[Callable[[], None]] = None, + pre_replacement_hook: Callable[[], None] | None = None, ) -> PipelineRunResponse: """Replace a placeholder run with the requested pipeline run. @@ -6784,8 +6777,8 @@ def _get_run_by_orchestrator_run_id( def get_or_create_run( self, pipeline_run: PipelineRunRequest, - pre_creation_hook: Optional[Callable[[], None]] = None, - ) -> Tuple[PipelineRunResponse, bool]: + pre_creation_hook: Callable[[], None] | None = None, + ) -> tuple[PipelineRunResponse, bool]: """Gets or creates a pipeline run. If a run with the same ID or name already exists, it is returned. 
@@ -7116,7 +7109,7 @@ def create_run_metadata(self, run_metadata: RunMetadataRequest) -> None: ) for resource in run_metadata.resources: - reference_schema: Type[BaseSchema] + reference_schema: type[BaseSchema] if resource.type == MetadataResourceTypes.PIPELINE_RUN: reference_schema = PipelineRunSchema elif resource.type == MetadataResourceTypes.STEP_RUN: @@ -7324,8 +7317,8 @@ def _check_sql_secret_scope( secret_name: str, private: bool, user: UUID, - exclude_secret_id: Optional[UUID] = None, - ) -> Tuple[bool, str]: + exclude_secret_id: UUID | None = None, + ) -> tuple[bool, str]: """Checks if a secret with the given name already exists with the given private status. This method enforces the following private status rules: @@ -7381,7 +7374,7 @@ def _check_sql_secret_scope( return False, "" def _set_secret_values( - self, secret_id: UUID, values: Dict[str, str], backup: bool = True + self, secret_id: UUID, values: dict[str, str], backup: bool = True ) -> None: """Sets the values of a secret in the configured secrets store. @@ -7430,7 +7423,7 @@ def do_backup() -> bool: do_backup() def _backup_secret_values( - self, secret_id: UUID, values: Dict[str, str] + self, secret_id: UUID, values: dict[str, str] ) -> None: """Backs up the values of a secret in the configured backup secrets store. @@ -7458,7 +7451,7 @@ def _backup_secret_values( def _get_secret_values( self, secret_id: UUID, use_backup: bool = True - ) -> Dict[str, str]: + ) -> dict[str, str]: """Gets the values of a secret from the configured secrets store. Args: @@ -7511,7 +7504,7 @@ def _get_secret_values( ) raise - def _get_backup_secret_values(self, secret_id: UUID) -> Dict[str, str]: + def _get_backup_secret_values(self, secret_id: UUID) -> dict[str, str]: """Gets the backup values of a secret from the configured backup secrets store. 
Args: @@ -7535,10 +7528,10 @@ def _get_backup_secret_values(self, secret_id: UUID) -> Dict[str, str]: def _update_secret_values( self, secret_id: UUID, - values: Dict[str, Optional[str]], + values: dict[str, str | None], overwrite: bool = False, backup: bool = True, - ) -> Dict[str, str]: + ) -> dict[str, str]: """Updates the values of a secret in the configured secrets store. This method will update the existing values with the new values @@ -7574,7 +7567,7 @@ def _update_secret_values( # store backend or when the secrets store backend is reconfigured to # a different account, provider, region etc. without migrating # the actual existing secrets themselves. - new_values: Dict[str, str] = { + new_values: dict[str, str] = { k: v for k, v in values.items() if v is not None } self._set_secret_values( @@ -7705,7 +7698,7 @@ def _delete_backup_secret_values( def _link_secrets_to_resource( self, - secrets: Optional[Sequence[Union[str, UUID]]], + secrets: Sequence[str | UUID] | None, resource: BaseSchema, session: Session, ) -> None: @@ -7749,11 +7742,10 @@ def _link_secrets_to_resource( # The secret resource already exists, so we rollback the session # and do nothing. session.rollback() - pass def _unlink_secrets_from_resource( self, - secrets: Optional[Sequence[Union[str, UUID]]], + secrets: Sequence[str | UUID] | None, resource: BaseSchema, session: Session, ) -> None: @@ -7917,7 +7909,7 @@ def get_secret( def get_secret_by_name_or_id( self, - secret_name_or_id: Union[str, UUID], + secret_name_or_id: str | UUID, include_secret_values: bool = False, ) -> SecretResponse: """Get a secret by name or ID. 
@@ -8249,7 +8241,7 @@ def restore_secrets( try: self._update_secret_values( secret_id=secret.id, - values=cast(Dict[str, Optional[str]], values), + values=cast(dict[str, Optional[str]], values), overwrite=True, backup=False, ) @@ -8279,9 +8271,9 @@ def restore_secrets( @track_decorator(AnalyticsEvent.CREATED_SERVICE_ACCOUNT) def create_service_account( self, - service_account: Union[ - ServiceAccountRequest, ServiceAccountInternalRequest - ], + service_account: ( + ServiceAccountRequest | ServiceAccountInternalRequest + ), ) -> ServiceAccountResponse: """Creates a new service account. @@ -8325,7 +8317,7 @@ def create_service_account( def get_service_account( self, - service_account_name_or_id: Union[str, UUID], + service_account_name_or_id: str | UUID, hydrate: bool = True, ) -> ServiceAccountResponse: """Gets a specific service account. @@ -8386,7 +8378,7 @@ def list_service_accounts( def update_service_account( self, - service_account_name_or_id: Union[str, UUID], + service_account_name_or_id: str | UUID, service_account_update: ServiceAccountUpdate, ) -> ServiceAccountResponse: """Updates an existing service account. @@ -8444,7 +8436,7 @@ def update_service_account( def delete_service_account( self, - service_account_name_or_id: Union[str, UUID], + service_account_name_or_id: str | UUID, ) -> None: """Delete a service account. @@ -8621,10 +8613,10 @@ def list_service_connectors( def fetch_connectors( session: Session, - query: Union[ - Select[ServiceConnectorSchema], - SelectOfScalar[ServiceConnectorSchema], - ], + query: ( + Select[ServiceConnectorSchema] | + SelectOfScalar[ServiceConnectorSchema] + ), filter_model: BaseFilter, ) -> Sequence[ServiceConnectorSchema]: """Custom fetch function for connector filtering and pagination. 
@@ -8879,9 +8871,9 @@ def delete_service_connector(self, service_connector_id: UUID) -> None: def _create_connector_secret( self, connector_name: str, - secrets: Dict[str, PlainSerializedSecretStr], + secrets: dict[str, PlainSerializedSecretStr], session: Session, - ) -> Optional[UUID]: + ) -> UUID | None: """Creates a new secret to store the service connector secret credentials. Args: @@ -8954,10 +8946,10 @@ def _populate_connector_type( @staticmethod def _list_filtered_service_connectors( session: Session, - query: Union[ - Select[ServiceConnectorSchema], - SelectOfScalar[ServiceConnectorSchema], - ], + query: ( + Select[ServiceConnectorSchema] | + SelectOfScalar[ServiceConnectorSchema] + ), filter_model: ServiceConnectorFilter, ) -> Sequence[ServiceConnectorSchema]: """Refine a service connector query. @@ -8993,10 +8985,10 @@ def _list_filtered_service_connectors( def _update_connector_secret( self, connector_name: str, - existing_secret_id: Optional[UUID], - secrets: Dict[str, PlainSerializedSecretStr], + existing_secret_id: UUID | None, + secrets: dict[str, PlainSerializedSecretStr], session: Session, - ) -> Optional[UUID]: + ) -> UUID | None: """Updates the secret for a service connector. Args: @@ -9062,8 +9054,8 @@ def verify_service_connector_config( def verify_service_connector( self, service_connector_id: UUID, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, list_resources: bool = True, ) -> ServiceConnectorResourcesModel: """Verifies if a service connector instance has access to one or more resources. 
@@ -9097,8 +9089,8 @@ def verify_service_connector( def get_service_connector_client( self, service_connector_id: UUID, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, ) -> ServiceConnectorResponse: """Get a service connector client for a service connector and given resource. @@ -9139,7 +9131,7 @@ def get_service_connector_client( def list_service_connector_resources( self, filter_model: ServiceConnectorFilter, - ) -> List[ServiceConnectorResourcesModel]: + ) -> list[ServiceConnectorResourcesModel]: """List resources that can be accessed by service connectors. Args: @@ -9160,7 +9152,7 @@ def list_service_connector_resources( filter_model=filter_model ).items - resource_list: List[ServiceConnectorResourcesModel] = [] + resource_list: list[ServiceConnectorResourcesModel] = [] for connector in service_connectors: if not service_connector_registry.is_registered(connector.type): @@ -9223,10 +9215,10 @@ def list_service_connector_resources( def list_service_connector_types( self, - connector_type: Optional[str] = None, - resource_type: Optional[str] = None, - auth_method: Optional[str] = None, - ) -> List[ServiceConnectorTypeModel]: + connector_type: str | None = None, + resource_type: str | None = None, + auth_method: str | None = None, + ) -> list[ServiceConnectorTypeModel]: """Get a list of service connector types. 
Args: @@ -9285,8 +9277,8 @@ def create_stack(self, stack: StackRequest) -> StackResponse: self._set_request_user_id(request_model=stack, session=session) # For clean-up purposes, each created entity is tracked here - service_connectors_created_ids: List[UUID] = [] - components_created_ids: List[UUID] = [] + service_connectors_created_ids: list[UUID] = [] + components_created_ids: list[UUID] = [] try: # Validate the name of the new stack @@ -9296,7 +9288,7 @@ def create_stack(self, stack: StackRequest) -> StackResponse: stack.labels = {} # Service Connectors - service_connectors: List[ServiceConnectorResponse] = [] + service_connectors: list[ServiceConnectorResponse] = [] orchestrator_components = stack.components[ StackComponentType.ORCHESTRATOR @@ -9386,7 +9378,7 @@ def create_stack(self, stack: StackRequest) -> StackResponse: continue # Stack Components - components_mapping: Dict[StackComponentType, List[UUID]] = {} + components_mapping: dict[StackComponentType, list[UUID]] = {} for ( component_type, components, @@ -9672,7 +9664,7 @@ def update_stack( session=session, ) - components: List["StackComponentSchema"] = [] + components: list["StackComponentSchema"] = [] if stack_update.components: for ( component_type, @@ -9734,7 +9726,7 @@ def delete_stack(self, stack_id: UUID) -> None: session.delete(stack) session.commit() - def count_stacks(self, filter_model: Optional[StackFilter]) -> int: + def count_stacks(self, filter_model: StackFilter | None) -> int: """Count all stacks. Args: @@ -9853,7 +9845,7 @@ def get_stack_deployment_config( self, provider: StackDeploymentProvider, stack_name: str, - location: Optional[str] = None, + location: str | None = None, ) -> StackDeploymentConfig: """Return the cloud provider console URL and configuration needed to deploy the ZenML stack. 
@@ -9874,9 +9866,9 @@ def get_stack_deployment_stack( self, provider: StackDeploymentProvider, stack_name: str, - location: Optional[str] = None, - date_start: Optional[datetime] = None, - ) -> Optional[DeployedStack]: + location: str | None = None, + date_start: datetime | None = None, + ) -> DeployedStack | None: """Return a matching ZenML stack that was deployed and registered. Args: @@ -10688,7 +10680,7 @@ def _update_pipeline_run_status( **stack_metadata, } - completed_onboarding_steps: Set[str] = { + completed_onboarding_steps: set[str] = { OnboardingStep.PIPELINE_RUN, OnboardingStep.OSS_ONBOARDING_COMPLETED, } @@ -10991,7 +10983,7 @@ def delete_trigger_execution(self, trigger_execution_id: UUID) -> None: @lru_cache(maxsize=1) def _get_resource_references( cls, - ) -> List[Tuple[Type[SQLModel], str]]: + ) -> list[tuple[type[SQLModel], str]]: """Get a list of all other table columns that reference the user table. Given that this list doesn't change at runtime, we cache it. @@ -11024,7 +11016,7 @@ def _get_resource_references( # To create this query, we need a list of all tables and their foreign # keys that point to the user table. - foreign_keys: List[Tuple[Type[SQLModel], str]] = [] + foreign_keys: list[tuple[type[SQLModel], str]] = [] for resource_attr in resource_attrs: # Extract the target schema from the annotation annotation = UserSchema.__annotations__[resource_attr] @@ -11220,7 +11212,7 @@ def create_user(self, user: UserRequest) -> UserResponse: def get_user( self, - user_name_or_id: Optional[Union[str, UUID]] = None, + user_name_or_id: str | UUID | None = None, include_private: bool = False, hydrate: bool = True, ) -> UserResponse: @@ -11252,7 +11244,7 @@ def get_user( # If a UUID is passed, we also allow fetching service accounts # with that ID. 
- service_account: Optional[bool] = False + service_account: bool | None = False if uuid_utils.is_valid_uuid(user_name_or_id): service_account = None user = self._get_account_schema( @@ -11268,7 +11260,7 @@ def get_user( ) def get_auth_user( - self, user_name_or_id: Union[str, UUID] + self, user_name_or_id: str | UUID ) -> UserAuthModel: """Gets the auth model to a specific user. @@ -11425,7 +11417,7 @@ def update_user( return updated_user - def delete_user(self, user_name_or_id: Union[str, UUID]) -> None: + def delete_user(self, user_name_or_id: str | UUID) -> None: """Deletes a user. Args: @@ -11573,7 +11565,7 @@ def create_project(self, project: ProjectRequest) -> ProjectResponse: return project_model def get_project( - self, project_name_or_id: Union[str, UUID], hydrate: bool = True + self, project_name_or_id: str | UUID, hydrate: bool = True ) -> ProjectResponse: """Get an existing project by name or ID. @@ -11670,7 +11662,7 @@ def update_project( include_metadata=True, include_resources=True ) - def delete_project(self, project_name_or_id: Union[str, UUID]) -> None: + def delete_project(self, project_name_or_id: str | UUID) -> None: """Deletes a project. Args: @@ -11698,7 +11690,7 @@ def delete_project(self, project_name_or_id: Union[str, UUID]) -> None: session.commit() def count_projects( - self, filter_model: Optional[ProjectFilter] = None + self, filter_model: ProjectFilter | None = None ) -> int: """Count all projects. @@ -11715,7 +11707,7 @@ def count_projects( def set_filter_project_id( self, filter_model: ProjectScopedFilter, - project_name_or_id: Optional[Union[UUID, str]] = None, + project_name_or_id: UUID | str | None = None, ) -> None: """Set the project ID on a filter model. 
@@ -11777,8 +11769,8 @@ def _default_project_enabled(self) -> bool: def _count_entity( self, - schema: Type[BaseSchema], - filter_model: Optional[BaseFilter] = None, + schema: type[BaseSchema], + filter_model: BaseFilter | None = None, ) -> int: """Return count of a given entity. @@ -11806,7 +11798,7 @@ def _count_entity( return int(entity_count) if entity_count else 0 def entity_exists( - self, entity_id: UUID, schema_class: Type[AnySchema] + self, entity_id: UUID, schema_class: type[AnySchema] ) -> bool: """Check whether an entity exists in the database. @@ -11825,8 +11817,8 @@ def entity_exists( return False if schema is None else True def get_entity_by_id( - self, entity_id: UUID, schema_class: Type[AnySchema] - ) -> Optional[AnyIdentifiedResponse]: + self, entity_id: UUID, schema_class: type[AnySchema] + ) -> AnyIdentifiedResponse | None: """Get an entity by ID. Args: @@ -11859,11 +11851,11 @@ def get_entity_by_id( @staticmethod def _get_schema_by_id( resource_id: UUID, - schema_class: Type[AnySchema], + schema_class: type[AnySchema], session: Session, - resource_type: Optional[str] = None, - project_id: Optional[UUID] = None, - query_options: Optional[Sequence[ExecutableOption]] = None, + resource_type: str | None = None, + project_id: UUID | None = None, + query_options: Sequence[ExecutableOption] | None = None, ) -> AnySchema: """Query a schema by its 'id' field. @@ -11911,10 +11903,10 @@ def _get_schema_by_id( @staticmethod def _get_schema_by_name_or_id( - object_name_or_id: Union[str, UUID], - schema_class: Type[AnyNamedSchema], + object_name_or_id: str | UUID, + schema_class: type[AnyNamedSchema], session: Session, - project_name_or_id: Optional[Union[UUID, str]] = None, + project_name_or_id: UUID | str | None = None, ) -> AnyNamedSchema: """Query a schema by its 'name' or 'id' field. 
@@ -11972,30 +11964,30 @@ def _get_schema_by_name_or_id( def _get_reference_schema_by_id( self, session: Session, - resource: Union[BaseRequest, BaseSchema], - reference_schema: Type[AnySchema], + resource: BaseRequest | BaseSchema, + reference_schema: type[AnySchema], reference_id: UUID, - reference_type: Optional[str] = None, + reference_type: str | None = None, ) -> AnySchema: ... @overload def _get_reference_schema_by_id( self, session: Session, - resource: Union[BaseRequest, BaseSchema], - reference_schema: Type[AnySchema], + resource: BaseRequest | BaseSchema, + reference_schema: type[AnySchema], reference_id: None, - reference_type: Optional[str] = None, + reference_type: str | None = None, ) -> None: ... def _get_reference_schema_by_id( self, session: Session, - resource: Union[BaseRequest, BaseSchema], - reference_schema: Type[AnySchema], - reference_id: Optional[UUID] = None, - reference_type: Optional[str] = None, - ) -> Optional[AnySchema]: + resource: BaseRequest | BaseSchema, + reference_schema: type[AnySchema], + reference_id: UUID | None = None, + reference_type: str | None = None, + ) -> AnySchema | None: """Fetch a referenced resource and verify scope relationship rules. This helper function is used for two things: @@ -12051,8 +12043,8 @@ def _get_reference_schema_by_id( if isinstance(resource, BaseSchema): operation = "updated" - resource_project_id: Optional[UUID] = None - resource_project_name: Optional[str] = None + resource_project_id: UUID | None = None + resource_project_name: str | None = None if isinstance(resource, ProjectScopedRequest): resource_project_id = resource.project resource_project_name = str(resource.project) @@ -12121,7 +12113,7 @@ def _set_filter_project_id( self, filter_model: ProjectScopedFilter, session: Session, - project_name_or_id: Optional[Union[UUID, str]] = None, + project_name_or_id: UUID | str | None = None, ) -> None: """Set the project ID on a filter model. 
@@ -12167,8 +12159,8 @@ def _set_filter_project_id( def _verify_name_uniqueness( self, - resource: Union[BaseRequest, BaseUpdate], - schema: Union[Type[AnyNamedSchema], AnyNamedSchema], + resource: BaseRequest | BaseUpdate, + schema: type[AnyNamedSchema] | AnyNamedSchema, session: Session, ) -> None: """Check the name uniqueness constraint for a given entity. @@ -12216,7 +12208,7 @@ def _verify_name_uniqueness( raise RuntimeError(f"Schema {schema_class.__name__} has no name.") operation: Literal["create", "update"] = "create" - project_id: Optional[UUID] = None + project_id: UUID | None = None if isinstance(resource, BaseRequest): # Create operation if isinstance(resource, ProjectScopedRequest): @@ -12273,9 +12265,9 @@ def _verify_name_uniqueness( def _get_account_schema( self, - account_name_or_id: Union[str, UUID], + account_name_or_id: str | UUID, session: Session, - service_account: Optional[bool] = None, + service_account: bool | None = None, ) -> UserSchema: """Gets a user account or a service account schema by name or ID. @@ -12410,7 +12402,7 @@ def get_model( def get_model_by_name_or_id( self, - model_name_or_id: Union[str, UUID], + model_name_or_id: str | UUID, project: UUID, hydrate: bool = True, ) -> ModelResponse: @@ -12543,7 +12535,7 @@ def update_model( def _get_or_create_model( self, model_request: ModelRequest - ) -> Tuple[bool, ModelResponse]: + ) -> tuple[bool, ModelResponse]: """Get or create a model. Args: @@ -12588,8 +12580,8 @@ def _model_version_exists( self, session: Session, model_id: UUID, - version: Optional[str] = None, - producer_run_id: Optional[UUID] = None, + version: str | None = None, + producer_run_id: UUID | None = None, ) -> bool: """Check if a model version with a certain version exists. 
@@ -12622,8 +12614,8 @@ def _model_version_exists( def _get_model_version( self, model_id: UUID, - version_name: Optional[str] = None, - producer_run_id: Optional[UUID] = None, + version_name: str | None = None, + producer_run_id: UUID | None = None, ) -> ModelVersionResponse: """Get a model version. @@ -12700,8 +12692,8 @@ def _get_model_version( def _get_or_create_model_version( self, model_version_request: ModelVersionRequest, - producer_run_id: Optional[UUID] = None, - ) -> Tuple[bool, ModelVersionResponse]: + producer_run_id: UUID | None = None, + ) -> tuple[bool, ModelVersionResponse]: """Get or create a model version. Args: @@ -12741,8 +12733,8 @@ def _get_or_create_model_version( ) def _get_or_create_model_version_for_run( - self, pipeline_or_step_run: Union[PipelineRunSchema, StepRunSchema] - ) -> Optional[UUID]: + self, pipeline_or_step_run: PipelineRunSchema | StepRunSchema + ) -> UUID | None: """Get or create a model version for a pipeline or step run. Args: @@ -12821,7 +12813,7 @@ def _get_or_create_model_version_for_run( def _create_model_version( self, model_version: ModelVersionRequest, - producer_run_id: Optional[UUID] = None, + producer_run_id: UUID | None = None, ) -> ModelVersionResponse: """Creates a new model version. @@ -12870,7 +12862,7 @@ def _create_model_version( ) assert model is not None - model_version_schema: Optional[ModelVersionSchema] = None + model_version_schema: ModelVersionSchema | None = None remaining_tries = MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION while remaining_tries > 0: @@ -13215,7 +13207,7 @@ def list_model_version_artifact_links( def delete_model_version_artifact_link( self, model_version_id: UUID, - model_version_artifact_link_name_or_id: Union[str, UUID], + model_version_artifact_link_name_or_id: str | UUID, ) -> None: """Deletes a model version to artifact link. 
@@ -13380,7 +13372,7 @@ def list_model_version_pipeline_run_links( def delete_model_version_pipeline_run_link( self, model_version_id: UUID, - model_version_pipeline_run_link_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_name_or_id: str | UUID, ) -> None: """Deletes a model version to pipeline run link. @@ -13444,7 +13436,7 @@ def _get_taggable_resource_type( Raises: ValueError: If the resource type is not taggable. """ - resource_types: Dict[Type[BaseSchema], TaggableResourceTypes] = { + resource_types: dict[type[BaseSchema], TaggableResourceTypes] = { ArtifactSchema: TaggableResourceTypes.ARTIFACT, ArtifactVersionSchema: TaggableResourceTypes.ARTIFACT_VERSION, ModelSchema: TaggableResourceTypes.MODEL, @@ -13485,8 +13477,8 @@ def _get_schema_from_resource_type( RunTemplateSchema, ) - resource_type_to_schema_mapping: Dict[ - TaggableResourceTypes, Type[BaseSchema] + resource_type_to_schema_mapping: dict[ + TaggableResourceTypes, type[BaseSchema] ] = { TaggableResourceTypes.ARTIFACT: ArtifactSchema, TaggableResourceTypes.ARTIFACT_VERSION: ArtifactVersionSchema, @@ -13503,7 +13495,7 @@ def _get_schema_from_resource_type( def _get_tag_schema( self, - tag_name_or_id: Union[str, UUID], + tag_name_or_id: str | UUID, session: Session, ) -> TagSchema: """Gets a tag schema by name or ID. @@ -13526,8 +13518,8 @@ def _get_tag_schema( def _attach_tags_to_resources( self, - tags: Optional[Sequence[Union[str, tag_utils.Tag]]], - resources: Union[BaseSchema, Sequence[BaseSchema]], + tags: Sequence[str | tag_utils.Tag] | None, + resources: BaseSchema | Sequence[BaseSchema], session: Session, ) -> None: """Attaches multiple tags to multiple resources. 
@@ -13592,8 +13584,8 @@ def _attach_tags_to_resources( [resources] if isinstance(resources, BaseSchema) else resources ) - tag_resources: List[ - Tuple[TagSchema, TaggableResourceTypes, BaseSchema] + tag_resources: list[ + tuple[TagSchema, TaggableResourceTypes, BaseSchema] ] = [] for resource in resources: @@ -13607,8 +13599,8 @@ def _attach_tags_to_resources( def _detach_tags_from_resources( self, - tags: Optional[Sequence[Union[str, UUID, tag_utils.Tag]]], - resources: Union[BaseSchema, List[BaseSchema]], + tags: Sequence[str | UUID | tag_utils.Tag] | None, + resources: BaseSchema | list[BaseSchema], session: Session, ) -> None: """Detaches multiple tags from multiple resources. @@ -13928,7 +13920,7 @@ def _exclusive_check_for_existing_tags( resource_type: TaggableResourceTypes, resource_id_column: Any, scope_id_column: Any, - ) -> Tuple[bool, str]: + ) -> tuple[bool, str]: query = ( select(scope_id_column, func.count().label("count")) .join( @@ -13994,11 +13986,11 @@ def _get_tag_resource_schema( def _create_tag_resource_schemas( self, - tag_resources: List[ - Tuple[TagSchema, TaggableResourceTypes, BaseSchema] + tag_resources: list[ + tuple[TagSchema, TaggableResourceTypes, BaseSchema] ], session: Session, - ) -> List[TagResourceSchema]: + ) -> list[TagResourceSchema]: """Creates a set of tag resource relationships. 
Args: @@ -14017,8 +14009,8 @@ def _create_tag_resource_schemas( for _ in range(max_retries): tag_resource_schemas = [] - tag_resources_to_create: Set[ - Tuple[UUID, UUID, TaggableResourceTypes] + tag_resources_to_create: set[ + tuple[UUID, UUID, TaggableResourceTypes] ] = set() for tag_schema, resource_type, resource in tag_resources: @@ -14059,10 +14051,10 @@ def _create_tag_resource_schemas( # If the tag is an exclusive tag, apply the check and attach/detach accordingly if tag_schema.exclusive: - scope_ids: Dict[ - TaggableResourceTypes, List[Union[UUID, int]] + scope_ids: dict[ + TaggableResourceTypes, list[UUID | int] ] = defaultdict(list) - detach_resources: List[TagResourceRequest] = [] + detach_resources: list[TagResourceRequest] = [] if isinstance(resource, PipelineRunSchema): if resource.pipeline_id: @@ -14249,8 +14241,8 @@ def create_tag_resource( return self.batch_create_tag_resource(tag_resources=[tag_resource])[0] def batch_create_tag_resource( - self, tag_resources: List[TagResourceRequest] - ) -> List[TagResourceResponse]: + self, tag_resources: list[TagResourceRequest] + ) -> list[TagResourceResponse]: """Create a batch of tag resource relationships. Args: @@ -14260,8 +14252,8 @@ def batch_create_tag_resource( The newly created tag resource relationships. 
""" with Session(self.engine) as session: - resources: List[ - Tuple[TagSchema, TaggableResourceTypes, BaseSchema] + resources: list[ + tuple[TagSchema, TaggableResourceTypes, BaseSchema] ] = [] for tag_resource in tag_resources: resource_schema = self._get_schema_from_resource_type( @@ -14292,7 +14284,7 @@ def batch_create_tag_resource( def _delete_tag_resource_schemas( self, - tag_resources: List[TagResourceRequest], + tag_resources: list[TagResourceRequest], session: Session, commit: bool = True, ) -> None: @@ -14335,7 +14327,7 @@ def delete_tag_resource( self.batch_delete_tag_resource(tag_resources=[tag_resource]) def batch_delete_tag_resource( - self, tag_resources: List[TagResourceRequest] + self, tag_resources: list[TagResourceRequest] ) -> None: """Delete a batch of tag resource relationships. diff --git a/src/zenml/zen_stores/template_utils.py b/src/zenml/zen_stores/template_utils.py index c0fe9bec3c2..f2de2807c2c 100644 --- a/src/zenml/zen_stores/template_utils.py +++ b/src/zenml/zen_stores/template_utils.py @@ -14,7 +14,7 @@ """Utilities for run templates.""" from enum import Enum -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pydantic import create_model from pydantic.fields import FieldInfo @@ -91,8 +91,8 @@ def validate_snapshot_is_templatable( def generate_config_template( snapshot: PipelineSnapshotSchema, pipeline_configuration: "PipelineConfiguration", - step_configurations: Dict[str, "Step"], -) -> Dict[str, Any]: + step_configurations: dict[str, "Step"], +) -> dict[str, Any]: """Generate a run configuration template for a snapshot. Args: @@ -135,8 +135,8 @@ def generate_config_template( def generate_config_schema( snapshot: PipelineSnapshotSchema, - step_configurations: Dict[str, "Step"], -) -> Dict[str, Any]: + step_configurations: dict[str, "Step"], +) -> dict[str, Any]: """Generate a run configuration schema for the snapshot. 
Args: @@ -155,7 +155,7 @@ def generate_config_schema( experiment_trackers = [] step_operators = [] - settings_fields: Dict[str, Any] = { + settings_fields: dict[str, Any] = { "resources": (Optional[ResourceSettings], None) } for component in stack.components: @@ -187,7 +187,7 @@ def generate_config_schema( settings_model = create_model("Settings", **settings_fields) - generic_step_fields: Dict[str, Any] = {} + generic_step_fields: dict[str, Any] = {} for key, field_info in StepConfigurationUpdate.model_fields.items(): if key in [ @@ -223,12 +223,12 @@ def generate_config_schema( generic_step_fields["settings"] = (Optional[settings_model], None) - all_steps: Dict[str, Any] = {} + all_steps: dict[str, Any] = {} all_steps_required = False for step_name, step in step_configurations.items(): step_fields = generic_step_fields.copy() if step.config.parameters: - parameter_fields: Dict[str, Any] = {} + parameter_fields: dict[str, Any] = {} for parameter_name in step.config.parameters: # Pydantic doesn't allow field names to start with an underscore @@ -273,7 +273,7 @@ def generate_config_schema( all_steps_model = create_model("Steps", **all_steps) - top_level_fields: Dict[str, Any] = {} + top_level_fields: dict[str, Any] = {} for key, field_info in PipelineRunConfiguration.model_fields.items(): if key in ["schedule", "build", "steps", "settings", "parameters"]: diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 350a74c0387..ea008e156f0 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -15,7 +15,6 @@ import datetime from abc import ABC, abstractmethod -from typing import List, Optional, Tuple, Union from uuid import UUID from zenml.config.pipeline_run_configuration import PipelineRunConfiguration @@ -386,7 +385,7 @@ def create_api_key( def get_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, 
hydrate: bool = True, ) -> APIKeyResponse: """Get an API key for a service account. @@ -431,7 +430,7 @@ def list_api_keys( def update_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, api_key_update: APIKeyUpdate, ) -> APIKeyResponse: """Update an API key for a service account. @@ -456,7 +455,7 @@ def update_api_key( def rotate_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, rotate_request: APIKeyRotateRequest, ) -> APIKeyResponse: """Rotate an API key for a service account. @@ -479,7 +478,7 @@ def rotate_api_key( def delete_api_key( self, service_account_id: UUID, - api_key_name_or_id: Union[str, UUID], + api_key_name_or_id: str | UUID, ) -> None: """Delete an API key for a service account. @@ -669,8 +668,8 @@ def create_artifact_version( @abstractmethod def batch_create_artifact_versions( - self, artifact_versions: List[ArtifactVersionRequest] - ) -> List[ArtifactVersionResponse]: + self, artifact_versions: list[ArtifactVersionRequest] + ) -> list[ArtifactVersionResponse]: """Creates a batch of artifact versions. Args: @@ -750,7 +749,7 @@ def delete_artifact_version(self, artifact_version_id: UUID) -> None: @abstractmethod def prune_artifact_versions( self, - project_name_or_id: Union[str, UUID], + project_name_or_id: str | UUID, only_versions: bool = True, ) -> None: """Prunes unused artifact versions and their artifacts. @@ -1308,8 +1307,8 @@ def get_snapshot( self, snapshot_id: UUID, hydrate: bool = True, - step_configuration_filter: Optional[List[str]] = None, - include_config_schema: Optional[bool] = None, + step_configuration_filter: list[str] | None = None, + include_config_schema: bool | None = None, ) -> PipelineSnapshotResponse: """Get a snapshot with a given ID. 
@@ -1617,7 +1616,7 @@ def delete_run_template(self, template_id: UUID) -> None: def run_template( self, template_id: UUID, - run_configuration: Optional[PipelineRunConfiguration] = None, + run_configuration: PipelineRunConfiguration | None = None, ) -> PipelineRunResponse: """Run a template. @@ -1717,7 +1716,7 @@ def delete_event_source(self, event_source_id: UUID) -> None: @abstractmethod def get_or_create_run( self, pipeline_run: PipelineRunRequest - ) -> Tuple[PipelineRunResponse, bool]: + ) -> tuple[PipelineRunResponse, bool]: """Gets or creates a pipeline run. If a run with the same ID or name already exists, it is returned. @@ -2066,7 +2065,7 @@ def create_service_account( @abstractmethod def get_service_account( self, - service_account_name_or_id: Union[str, UUID], + service_account_name_or_id: str | UUID, hydrate: bool = True, ) -> ServiceAccountResponse: """Gets a specific service account. @@ -2105,7 +2104,7 @@ def list_service_accounts( @abstractmethod def update_service_account( self, - service_account_name_or_id: Union[str, UUID], + service_account_name_or_id: str | UUID, service_account_update: ServiceAccountUpdate, ) -> ServiceAccountResponse: """Updates an existing service account. @@ -2126,7 +2125,7 @@ def update_service_account( @abstractmethod def delete_service_account( self, - service_account_name_or_id: Union[str, UUID], + service_account_name_or_id: str | UUID, ) -> None: """Delete a service account. @@ -2275,8 +2274,8 @@ def verify_service_connector_config( def verify_service_connector( self, service_connector_id: UUID, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, list_resources: bool = True, ) -> ServiceConnectorResourcesModel: """Verifies if a service connector instance has access to one or more resources. 
@@ -2303,8 +2302,8 @@ def verify_service_connector( def get_service_connector_client( self, service_connector_id: UUID, - resource_type: Optional[str] = None, - resource_id: Optional[str] = None, + resource_type: str | None = None, + resource_id: str | None = None, ) -> ServiceConnectorResponse: """Get a service connector client for a service connector and given resource. @@ -2327,7 +2326,7 @@ def get_service_connector_client( def list_service_connector_resources( self, filter_model: ServiceConnectorFilter, - ) -> List[ServiceConnectorResourcesModel]: + ) -> list[ServiceConnectorResourcesModel]: """List resources that can be accessed by service connectors. Args: @@ -2342,10 +2341,10 @@ def list_service_connector_resources( @abstractmethod def list_service_connector_types( self, - connector_type: Optional[str] = None, - resource_type: Optional[str] = None, - auth_method: Optional[str] = None, - ) -> List[ServiceConnectorTypeModel]: + connector_type: str | None = None, + resource_type: str | None = None, + auth_method: str | None = None, + ) -> list[ServiceConnectorTypeModel]: """Get a list of service connector types. Args: @@ -2474,7 +2473,7 @@ def get_stack_deployment_config( self, provider: StackDeploymentProvider, stack_name: str, - location: Optional[str] = None, + location: str | None = None, ) -> StackDeploymentConfig: """Return the cloud provider console URL and configuration needed to deploy the ZenML stack. @@ -2493,9 +2492,9 @@ def get_stack_deployment_stack( self, provider: StackDeploymentProvider, stack_name: str, - location: Optional[str] = None, - date_start: Optional[datetime.datetime] = None, - ) -> Optional[DeployedStack]: + location: str | None = None, + date_start: datetime.datetime | None = None, + ) -> DeployedStack | None: """Return a matching ZenML stack that was deployed and registered. 
Args: @@ -2732,7 +2731,7 @@ def create_user(self, user: UserRequest) -> UserResponse: @abstractmethod def get_user( self, - user_name_or_id: Optional[Union[str, UUID]] = None, + user_name_or_id: str | UUID | None = None, include_private: bool = False, hydrate: bool = True, ) -> UserResponse: @@ -2787,7 +2786,7 @@ def update_user( """ @abstractmethod - def delete_user(self, user_name_or_id: Union[str, UUID]) -> None: + def delete_user(self, user_name_or_id: str | UUID) -> None: """Deletes a user. Args: @@ -2815,7 +2814,7 @@ def create_project(self, project: ProjectRequest) -> ProjectResponse: @abstractmethod def get_project( - self, project_name_or_id: Union[UUID, str], hydrate: bool = True + self, project_name_or_id: UUID | str, hydrate: bool = True ) -> ProjectResponse: """Get an existing project by name or ID. @@ -2867,7 +2866,7 @@ def update_project( """ @abstractmethod - def delete_project(self, project_name_or_id: Union[str, UUID]) -> None: + def delete_project(self, project_name_or_id: str | UUID) -> None: """Deletes a project. Args: @@ -3088,7 +3087,7 @@ def list_model_version_artifact_links( def delete_model_version_artifact_link( self, model_version_id: UUID, - model_version_artifact_link_name_or_id: Union[str, UUID], + model_version_artifact_link_name_or_id: str | UUID, ) -> None: """Deletes a model version to artifact link. @@ -3156,7 +3155,7 @@ def list_model_version_pipeline_run_links( def delete_model_version_pipeline_run_link( self, model_version_id: UUID, - model_version_pipeline_run_link_name_or_id: Union[str, UUID], + model_version_pipeline_run_link_name_or_id: str | UUID, ) -> None: """Deletes a model version to pipeline run link. @@ -3272,8 +3271,8 @@ def create_tag_resource( @abstractmethod def batch_create_tag_resource( - self, tag_resources: List[TagResourceRequest] - ) -> List[TagResourceResponse]: + self, tag_resources: list[TagResourceRequest] + ) -> list[TagResourceResponse]: """Create a new tag resource relationship. 
Args: @@ -3296,7 +3295,7 @@ def delete_tag_resource( @abstractmethod def batch_delete_tag_resource( - self, tag_resources: List[TagResourceRequest] + self, tag_resources: list[TagResourceRequest] ) -> None: """Delete a batch of tag resource relationships.