diff --git a/src/zenml/artifacts/in_memory_cache.py b/src/zenml/artifacts/in_memory_cache.py new file mode 100644 index 00000000000..91a659d5ffb --- /dev/null +++ b/src/zenml/artifacts/in_memory_cache.py @@ -0,0 +1,55 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""In-memory artifact cache.""" + +import contextvars +from typing import Any, Dict +from uuid import UUID + +from zenml.utils import context_utils + + +class InMemoryArtifactCache(context_utils.BaseContext): + """In-memory artifact cache.""" + + __context_var__ = contextvars.ContextVar("in_memory_artifact_cache") + + def __init__(self) -> None: + """Initialize the artifact cache.""" + super().__init__() + self._cache: Dict[UUID, Any] = {} + + def clear(self) -> None: + """Clear the artifact cache.""" + self._cache = {} + + def get_artifact_data(self, id_: UUID) -> Any: + """Get the artifact data. + + Args: + id_: The ID of the artifact to get the data for. + + Returns: + The artifact data. + """ + return self._cache.get(id_) + + def set_artifact_data(self, id_: UUID, data: Any) -> None: + """Set the artifact data. + + Args: + id_: The ID of the artifact to set the data for. + data: The artifact data to set. + """ + self._cache[id_] = data diff --git a/src/zenml/config/compiler.py b/src/zenml/config/compiler.py index dc2f140ca68..21acb874d7a 100644 --- a/src/zenml/config/compiler.py +++ b/src/zenml/config/compiler.py @@ -126,12 +126,18 @@ def compile( merge=False, ) + # If we're compiling a dynamic pipeline, the steps are only templates + # and might not have all inputs defined, so we skip the input + # validation. + skip_input_validation = pipeline.is_dynamic + steps = { invocation_id: self._compile_step_invocation( invocation=invocation, stack=stack, step_config=(run_configuration.steps or {}).get(invocation_id), pipeline_configuration=pipeline.configuration, + skip_input_validation=skip_input_validation, ) for invocation_id, invocation in self._get_sorted_invocations( pipeline=pipeline @@ -153,6 +159,7 @@ def compile( snapshot = PipelineSnapshotBase( run_name_template=run_name, + is_dynamic=pipeline.is_dynamic, pipeline_configuration=pipeline.configuration, step_configurations=steps, client_environment=get_run_environment_dict(), @@ -463,6 +470,7 @@ def _compile_step_invocation( stack: "Stack", step_config: Optional["StepConfigurationUpdate"], pipeline_configuration: "PipelineConfiguration", + skip_input_validation: bool = False, ) -> Step: """Compiles a ZenML step. @@ -471,6 +479,7 @@ def _compile_step_invocation( stack: The stack on which the pipeline will be run. step_config: Run configuration for the step. pipeline_configuration: Configuration for the pipeline. + skip_input_validation: If True, will skip the input validation. Returns: The compiled step. 
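The new `InMemoryArtifactCache` is a context-local, per-run store keyed by artifact version ID; the dynamic pipeline runner enters it for the duration of a run so that repeated loads of the same artifact version can be served from memory. A minimal usage sketch (assuming, as the `DynamicPipelineRunContext` subclass suggests, that `BaseContext.__enter__` registers the instance in the declared `ContextVar` and returns it):

from uuid import uuid4

from zenml.artifacts.in_memory_cache import InMemoryArtifactCache

artifact_id = uuid4()
with InMemoryArtifactCache() as cache:
    # Store and retrieve data for an artifact version ID.
    cache.set_artifact_data(artifact_id, {"rows": 100})
    assert cache.get_artifact_data(artifact_id) == {"rows": 100}
    # Unknown IDs return None rather than raising.
    assert cache.get_artifact_data(uuid4()) is None
    # clear() empties the cache without exiting the context.
    cache.clear()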
@@ -480,35 +489,41 @@ def _compile_step_invocation( invocation.step = copy.deepcopy(invocation.step) step = invocation.step - if step_config: - step._apply_configuration( - step_config, runtime_parameters=invocation.parameters - ) + with step._suspend_dynamic_configuration(): + if step_config: + step._apply_configuration( + step_config, runtime_parameters=invocation.parameters + ) - convert_component_shortcut_settings_keys( - step.configuration.settings, stack=stack - ) - step_spec = self._get_step_spec(invocation=invocation) - step_secrets = secret_utils.resolve_and_verify_secrets( - step.configuration.secrets - ) - step_settings = self._filter_and_validate_settings( - settings=step.configuration.settings, - configuration_level=ConfigurationLevel.STEP, - stack=stack, - ) - step.configure( - secrets=step_secrets, - settings=step_settings, - merge=False, - ) + # Apply the dynamic configuration (which happened while executing the + # pipeline function) after all other step-specific configurations. + step._merge_dynamic_configuration() - parameters_to_ignore = ( - set(step_config.parameters or {}) if step_config else set() - ) - step_configuration_overrides = invocation.finalize( - parameters_to_ignore=parameters_to_ignore - ) + convert_component_shortcut_settings_keys( + step.configuration.settings, stack=stack + ) + step_spec = self._get_step_spec(invocation=invocation) + step_secrets = secret_utils.resolve_and_verify_secrets( + step.configuration.secrets + ) + step_settings = self._filter_and_validate_settings( + settings=step.configuration.settings, + configuration_level=ConfigurationLevel.STEP, + stack=stack, + ) + step.configure( + secrets=step_secrets, + settings=step_settings, + merge=False, + ) + + parameters_to_ignore = ( + set(step_config.parameters or {}) if step_config else set() + ) + step_configuration_overrides = invocation.finalize( + parameters_to_ignore=parameters_to_ignore, + skip_input_validation=skip_input_validation, + ) full_step_config = ( step_configuration_overrides.apply_pipeline_configuration( pipeline_configuration=pipeline_configuration @@ -533,8 +548,15 @@ def _get_sorted_invocations( pipeline: The pipeline of which to sort the invocations Returns: - The sorted steps. + The sorted step invocations. """ + if pipeline.is_dynamic: + # In dynamic pipelines, we require the static invocations to be + # sorted the same way they were passed in `pipeline.depends_on`, as + # we index this list later to figure out the correct template for + # each step invocation. + return list(pipeline.invocations.items()) + from zenml.orchestrators.dag_runner import reverse_dag from zenml.orchestrators.topsort import topsorted_layers @@ -634,7 +656,7 @@ def _compute_pipeline_spec( Raises: ValueError: If the pipeline has no steps. """ - if not step_specs: + if not step_specs and not pipeline.is_dynamic: raise ValueError( f"Pipeline '{pipeline.name}' cannot be compiled because it has " f"no steps. Please make sure that your steps are decorated " diff --git a/src/zenml/config/step_configurations.py b/src/zenml/config/step_configurations.py index 597074cc82b..4e72929f646 100644 --- a/src/zenml/config/step_configurations.py +++ b/src/zenml/config/step_configurations.py @@ -214,6 +214,13 @@ class StepConfigurationUpdate(FrozenBaseModel): default=None, description="The cache policy for the step.", ) + in_process: Optional[bool] = Field( + default=None, + description="Whether to run the step in process. This is only " + "applicable for dynamic pipelines. 
If not set, the step will by " + "default run in-process unless it requires a different Docker image " + "or has special resource requirements.", + ) outputs: Mapping[str, PartialArtifactConfiguration] = {} @@ -254,6 +261,8 @@ class PartialStepConfiguration(StepConfigurationUpdate): """Class representing a partial step configuration.""" name: str + # TODO: maybe move to spec? + template: Optional[str] = None parameters: Dict[str, Any] = {} settings: Dict[str, SerializeAsAny[BaseSettings]] = {} environment: Dict[str, str] = {} diff --git a/src/zenml/config/step_run_info.py b/src/zenml/config/step_run_info.py index 9e3c86723c4..a6c6819601e 100644 --- a/src/zenml/config/step_run_info.py +++ b/src/zenml/config/step_run_info.py @@ -18,7 +18,11 @@ from zenml.config.frozen_base_model import FrozenBaseModel from zenml.config.pipeline_configurations import PipelineConfiguration -from zenml.config.step_configurations import StepConfiguration +from zenml.config.step_configurations import StepConfiguration, StepSpec +from zenml.logger import get_logger +from zenml.models import PipelineSnapshotResponse + +logger = get_logger(__name__) class StepRunInfo(FrozenBaseModel): @@ -30,7 +34,9 @@ class StepRunInfo(FrozenBaseModel): pipeline_step_name: str config: StepConfiguration + spec: StepSpec pipeline: PipelineConfiguration + snapshot: PipelineSnapshotResponse force_write_logs: Callable[..., Any] @@ -46,15 +52,22 @@ def get_image(self, key: str) -> str: Returns: The image name or digest. """ - from zenml.client import Client - - run = Client().get_pipeline_run(self.run_id) - if not run.build: + if not self.snapshot.build: raise RuntimeError( - f"Missing build for run {run.id}. This is probably because " - "the build was manually deleted." + f"Missing build for snapshot {self.snapshot.id}. This is " + "probably because the build was manually deleted." ) - return run.build.get_image( - component_key=key, step=self.pipeline_step_name - ) + if self.snapshot.is_dynamic: + step_key = self.config.template + if not step_key: + logger.warning( + "Unable to find config template for step %s. 
Falling " + "back to the pipeline image.", + self.pipeline_step_name, + ) + step_key = None + else: + step_key = self.pipeline_step_name + + return self.snapshot.build.get_image(component_key=key, step=step_key) diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 47c4e728151..7dab7f457ae 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -140,7 +140,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ENV_ZENML_LOGGING_VERBOSITY = "ZENML_LOGGING_VERBOSITY" ENV_ZENML_LOGGING_FORMAT = "ZENML_LOGGING_FORMAT" ENV_ZENML_REPOSITORY_PATH = "ZENML_REPOSITORY_PATH" -ENV_ZENML_PREVENT_PIPELINE_EXECUTION = "ZENML_PREVENT_PIPELINE_EXECUTION" ENV_ZENML_ENABLE_RICH_TRACEBACK = "ZENML_ENABLE_RICH_TRACEBACK" ENV_ZENML_ACTIVE_STACK_ID = "ZENML_ACTIVE_STACK_ID" ENV_ZENML_ACTIVE_PROJECT_ID = "ZENML_ACTIVE_PROJECT_ID" @@ -246,11 +245,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int: # ZenML Analytics Server - URL ANALYTICS_SERVER_URL = "https://analytics.zenml.io/" -# Container utils -SHOULD_PREVENT_PIPELINE_EXECUTION = handle_bool_env_var( - ENV_ZENML_PREVENT_PIPELINE_EXECUTION -) - # Repository and local store directory paths: REPOSITORY_DIRECTORY_NAME = ".zen" LOCAL_STORES_DIRECTORY_NAME = "local_stores" diff --git a/src/zenml/deployers/server/entrypoint_configuration.py b/src/zenml/deployers/server/entrypoint_configuration.py index ef9f1998dea..f0d282ebd4a 100644 --- a/src/zenml/deployers/server/entrypoint_configuration.py +++ b/src/zenml/deployers/server/entrypoint_configuration.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """ZenML Pipeline Deployment Entrypoint Configuration.""" -from typing import Any, List, Set +from typing import Any, Dict, List from uuid import UUID from zenml.client import Client @@ -39,14 +39,14 @@ class DeploymentEntrypointConfiguration(BaseEntrypointConfiguration): """ @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> Dict[str, bool]: """Gets all options required for the deployment entrypoint. Returns: Set of required option names """ return { - DEPLOYMENT_ID_OPTION, + DEPLOYMENT_ID_OPTION: True, } @classmethod @@ -113,7 +113,7 @@ def run(self) -> None: raise RuntimeError(f"Deployment {deployment.id} has no snapshot") # Download code if necessary (for remote execution environments) - self.download_code_if_necessary(snapshot=deployment.snapshot) + self.download_code_if_necessary() app_runner = BaseDeploymentAppRunner.load_app_runner(deployment) app_runner.run() diff --git a/src/zenml/entrypoints/base_entrypoint_configuration.py b/src/zenml/entrypoints/base_entrypoint_configuration.py index c8eb1a2d396..f48eb3cb10b 100644 --- a/src/zenml/entrypoints/base_entrypoint_configuration.py +++ b/src/zenml/entrypoints/base_entrypoint_configuration.py @@ -17,7 +17,7 @@ import os import sys from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Dict, List, NoReturn, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, NoReturn, Optional from uuid import UUID from zenml.client import Client @@ -34,6 +34,7 @@ if TYPE_CHECKING: from zenml.artifact_stores import BaseArtifactStore + from zenml.config import DockerSettings from zenml.models import CodeReferenceResponse, PipelineSnapshotResponse logger = get_logger(__name__) @@ -64,6 +65,7 @@ def __init__(self, arguments: List[str]): arguments: Command line arguments to configure this object. 
""" self.entrypoint_args = self._parse_arguments(arguments) + self._snapshot: Optional["PipelineSnapshotResponse"] = None @classmethod def get_entrypoint_command(cls) -> List[str]: @@ -83,18 +85,18 @@ def get_entrypoint_command(cls) -> List[str]: return DEFAULT_ENTRYPOINT_COMMAND @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> Dict[str, bool]: """Gets all options required for running with this configuration. Returns: - A set of strings with all required options. + A dictionary of options and whether they are required. """ return { # Importable source pointing to the entrypoint configuration class # that should be used inside the entrypoint. - ENTRYPOINT_CONFIG_SOURCE_OPTION, + ENTRYPOINT_CONFIG_SOURCE_OPTION: True, # ID of the pipeline snapshot to use in this entrypoint - SNAPSHOT_ID_OPTION, + SNAPSHOT_ID_OPTION: True, } @classmethod @@ -178,18 +180,59 @@ def error(self, message: str) -> NoReturn: parser = _CustomParser() - for option_name in cls.get_entrypoint_options(): + for option_name, required in cls.get_entrypoint_options().items(): if option_name == ENTRYPOINT_CONFIG_SOURCE_OPTION: # This option is already used by # `zenml.entrypoints.entrypoint` to read which config # class to use continue - parser.add_argument(f"--{option_name}", required=True) + parser.add_argument(f"--{option_name}", required=required) result, _ = parser.parse_known_args(arguments) return vars(result) - def load_snapshot(self) -> "PipelineSnapshotResponse": + @property + def snapshot(self) -> "PipelineSnapshotResponse": + """The snapshot configured for this entrypoint configuration. + + Returns: + The snapshot. + """ + if self._snapshot is None: + self._snapshot = self._load_snapshot() + return self._snapshot + + @property + def docker_settings(self) -> "DockerSettings": + """The Docker settings configured for this entrypoint configuration. + + Returns: + The Docker settings. + """ + return self.snapshot.pipeline_configuration.docker_settings + + @property + def should_download_code(self) -> bool: + """Whether code should be downloaded. + + Returns: + Whether code should be downloaded. + """ + if ( + self.snapshot.code_reference + and self.docker_settings.allow_download_from_code_repository + ): + return True + + if ( + self.snapshot.code_path + and self.docker_settings.allow_download_from_artifact_store + ): + return True + + return False + + def _load_snapshot(self) -> "PipelineSnapshotResponse": """Loads the snapshot. Returns: @@ -198,34 +241,19 @@ def load_snapshot(self) -> "PipelineSnapshotResponse": snapshot_id = UUID(self.entrypoint_args[SNAPSHOT_ID_OPTION]) return Client().zen_store.get_snapshot(snapshot_id=snapshot_id) - def download_code_if_necessary( - self, - snapshot: "PipelineSnapshotResponse", - step_name: Optional[str] = None, - ) -> None: + def download_code_if_necessary(self) -> None: """Downloads user code if necessary. - Args: - snapshot: The snapshot for which to download the code. - step_name: Name of the step to be run. This will be used to - determine whether code download is necessary. If not given, - the DockerSettings of the pipeline will be used to make that - decision instead. - Raises: CustomFlavorImportError: If the artifact store flavor can't be imported. RuntimeError: If the current environment requires code download but the snapshot does not have a reference to any code. 
""" - should_download_code = self._should_download_code( - snapshot=snapshot, step_name=step_name - ) - - if not should_download_code: + if not self.should_download_code: return - if code_path := snapshot.code_path: + if code_path := self.snapshot.code_path: # Load the artifact store not from the active stack but separately. # This is required in case the stack has custom flavor components # (other than the artifact store) for which the flavor @@ -247,7 +275,7 @@ def download_code_if_necessary( code_utils.download_code_from_artifact_store( code_path=code_path, artifact_store=artifact_store ) - elif code_reference := snapshot.code_reference: + elif code_reference := self.snapshot.code_reference: # TODO: This might fail if the code repository had unpushed changes # at the time the pipeline run was started. self.download_code_from_code_repository( @@ -294,43 +322,6 @@ def download_code_from_code_repository( sys.path.insert(0, download_dir) os.chdir(download_dir) - def _should_download_code( - self, - snapshot: "PipelineSnapshotResponse", - step_name: Optional[str] = None, - ) -> bool: - """Checks whether code should be downloaded. - - Args: - snapshot: The snapshot to check. - step_name: Name of the step to be run. This will be used to - determine whether code download is necessary. If not given, - the DockerSettings of the pipeline will be used to make that - decision instead. - - Returns: - Whether code should be downloaded. - """ - docker_settings = ( - snapshot.step_configurations[step_name].config.docker_settings - if step_name - else snapshot.pipeline_configuration.docker_settings - ) - - if ( - snapshot.code_reference - and docker_settings.allow_download_from_code_repository - ): - return True - - if ( - snapshot.code_path - and docker_settings.allow_download_from_artifact_store - ): - return True - - return False - def _load_active_artifact_store(self) -> "BaseArtifactStore": """Load the active artifact store. diff --git a/src/zenml/entrypoints/entrypoint.py b/src/zenml/entrypoints/entrypoint.py index 9dad3d9e29e..db5ed0879ff 100644 --- a/src/zenml/entrypoints/entrypoint.py +++ b/src/zenml/entrypoints/entrypoint.py @@ -17,11 +17,11 @@ import logging import sys -from zenml import constants from zenml.entrypoints.base_entrypoint_configuration import ( ENTRYPOINT_CONFIG_SOURCE_OPTION, BaseEntrypointConfiguration, ) +from zenml.execution.pipeline.utils import prevent_pipeline_execution from zenml.utils import source_utils @@ -35,23 +35,22 @@ def main() -> None: _setup_logging() # Make sure this entrypoint does not run an entire pipeline when - # importing user modules. This could happen if the `pipeline.run()` call - # is not wrapped in a function or an `if __name__== "__main__":` check) - constants.SHOULD_PREVENT_PIPELINE_EXECUTION = True - - # Read the source for the entrypoint configuration class from the command - # line arguments - parser = argparse.ArgumentParser() - parser.add_argument(f"--{ENTRYPOINT_CONFIG_SOURCE_OPTION}", required=True) - args, remaining_args = parser.parse_known_args() - - entrypoint_config_class = source_utils.load_and_validate_class( - args.entrypoint_config_source, - expected_class=BaseEntrypointConfiguration, - ) - entrypoint_config = entrypoint_config_class(arguments=remaining_args) - - entrypoint_config.run() + # importing user modules. 
This could happen if a pipeline is called in a + # module without an `if __name__ == "__main__":` check. + with prevent_pipeline_execution(): + parser = argparse.ArgumentParser() + parser.add_argument( + f"--{ENTRYPOINT_CONFIG_SOURCE_OPTION}", required=True + ) + args, remaining_args = parser.parse_known_args() + + entrypoint_config_class = source_utils.load_and_validate_class( + args.entrypoint_config_source, + expected_class=BaseEntrypointConfiguration, + ) + entrypoint_config = entrypoint_config_class(arguments=remaining_args) + + entrypoint_config.run() if __name__ == "__main__": diff --git a/src/zenml/entrypoints/pipeline_entrypoint_configuration.py b/src/zenml/entrypoints/pipeline_entrypoint_configuration.py index f89a1595d86..f0b6f6a4c86 100644 --- a/src/zenml/entrypoints/pipeline_entrypoint_configuration.py +++ b/src/zenml/entrypoints/pipeline_entrypoint_configuration.py @@ -25,13 +25,13 @@ class PipelineEntrypointConfiguration(BaseEntrypointConfiguration): def run(self) -> None: """Prepares the environment and runs the configured pipeline.""" - snapshot = self.load_snapshot() + snapshot = self.snapshot # Activate all the integrations. This makes sure that all materializers # and stack component flavors are registered. integration_registry.activate_integrations() - self.download_code_if_necessary(snapshot=snapshot) + self.download_code_if_necessary() orchestrator = Client().active_stack.orchestrator orchestrator._prepare_run(snapshot=snapshot) diff --git a/src/zenml/entrypoints/step_entrypoint_configuration.py b/src/zenml/entrypoints/step_entrypoint_configuration.py index 7ba49332df6..9a087d99d42 100644 --- a/src/zenml/entrypoints/step_entrypoint_configuration.py +++ b/src/zenml/entrypoints/step_entrypoint_configuration.py @@ -15,7 +15,7 @@ import os import sys -from typing import TYPE_CHECKING, Any, List, Set +from typing import TYPE_CHECKING, Any, Dict, List from uuid import UUID from zenml.client import Client @@ -27,6 +27,7 @@ from zenml.logger import get_logger if TYPE_CHECKING: + from zenml.config import DockerSettings from zenml.config.step_configurations import Step from zenml.models import PipelineSnapshotResponse @@ -115,14 +116,14 @@ def post_run( """ @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> Dict[str, bool]: """Gets all options required for running with this configuration. Returns: The superclass options as well as an option for the name of the step to run. """ - return super().get_entrypoint_options() | {STEP_NAME_OPTION} + return super().get_entrypoint_options() | {STEP_NAME_OPTION: True} @classmethod def get_entrypoint_arguments( @@ -149,7 +150,26 @@ def get_entrypoint_arguments( kwargs[STEP_NAME_OPTION], ] - def load_snapshot(self) -> "PipelineSnapshotResponse": + @property + def docker_settings(self) -> "DockerSettings": + """The Docker settings configured for this entrypoint configuration. + + Returns: + The Docker settings. + """ + return self.step.config.docker_settings + + @property + def step(self) -> "Step": + """The step configured for this entrypoint configuration. + + Returns: + The step. + """ + step_name = self.entrypoint_args[STEP_NAME_OPTION] + return self.snapshot.step_configurations[step_name] + + def _load_snapshot(self) -> "PipelineSnapshotResponse": """Loads the snapshot.
Returns: @@ -163,7 +183,7 @@ def load_snapshot(self) -> "PipelineSnapshotResponse": def run(self) -> None: """Prepares the environment and runs the configured step.""" - snapshot = self.load_snapshot() + snapshot = self.snapshot # Activate all the integrations. This makes sure that all materializers # and stack component flavors are registered. @@ -178,7 +198,7 @@ def run(self) -> None: os.makedirs("/app", exist_ok=True) os.chdir("/app") - self.download_code_if_necessary(snapshot=snapshot, step_name=step_name) + self.download_code_if_necessary() # If the working directory is not in the sys.path, we include it to make # sure user code gets correctly imported. @@ -188,8 +208,7 @@ def run(self) -> None: pipeline_name = snapshot.pipeline_configuration.name - step = snapshot.step_configurations[step_name] - self._run_step(step, snapshot=snapshot) + self._run_step(step=self.step, snapshot=snapshot) self.post_run( pipeline_name=pipeline_name, diff --git a/src/zenml/execution/__init__.py b/src/zenml/execution/__init__.py new file mode 100644 index 00000000000..2927d4efc3e --- /dev/null +++ b/src/zenml/execution/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Step and pipeline execution.""" \ No newline at end of file diff --git a/src/zenml/execution/pipeline/__init__.py b/src/zenml/execution/pipeline/__init__.py new file mode 100644 index 00000000000..a07e0ef516f --- /dev/null +++ b/src/zenml/execution/pipeline/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Pipeline execution.""" \ No newline at end of file diff --git a/src/zenml/execution/pipeline/dynamic/__init__.py b/src/zenml/execution/pipeline/dynamic/__init__.py new file mode 100644 index 00000000000..d762de555a7 --- /dev/null +++ b/src/zenml/execution/pipeline/dynamic/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
See the License for the specific language governing +# permissions and limitations under the License. +"""Dynamic pipeline execution.""" \ No newline at end of file diff --git a/src/zenml/execution/pipeline/dynamic/outputs.py b/src/zenml/execution/pipeline/dynamic/outputs.py new file mode 100644 index 00000000000..263aadf5348 --- /dev/null +++ b/src/zenml/execution/pipeline/dynamic/outputs.py @@ -0,0 +1,253 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Dynamic pipeline execution outputs.""" + +from concurrent.futures import Future +from typing import Any, List, Tuple, Union + +from zenml.logger import get_logger +from zenml.models import ( + ArtifactVersionResponse, +) + +logger = get_logger(__name__) + + +class OutputArtifact(ArtifactVersionResponse): + """Dynamic step run output artifact.""" + + output_name: str + step_name: str + + +StepRunOutputs = Union[None, OutputArtifact, Tuple[OutputArtifact, ...]] + + +class _BaseStepRunFuture: + """Base step run future.""" + + def __init__( + self, + wrapped: Future[StepRunOutputs], + invocation_id: str, + **kwargs: Any, + ) -> None: + """Initialize the dynamic step run future. + + Args: + wrapped: The wrapped future object. + invocation_id: The invocation ID of the step run. + **kwargs: Additional keyword arguments. + """ + self._wrapped = wrapped + self._invocation_id = invocation_id + + @property + def invocation_id(self) -> str: + """The step run invocation ID. + + Returns: + The step run invocation ID. + """ + return self._invocation_id + + def _wait(self) -> None: + """Wait for the step run future to complete.""" + self._wrapped.result() + + +class ArtifactFuture(_BaseStepRunFuture): + """Future for a step run output artifact.""" + + def __init__( + self, wrapped: Future[StepRunOutputs], invocation_id: str, index: int + ) -> None: + """Initialize the future. + + Args: + wrapped: The wrapped future object. + invocation_id: The invocation ID of the step run. + index: The index of the output artifact. + """ + super().__init__(wrapped=wrapped, invocation_id=invocation_id) + self._index = index + + def result(self) -> OutputArtifact: + """Get the step run output artifact. + + Raises: + RuntimeError: If the future returned an invalid output. + + Returns: + The step run output artifact. + """ + result = self._wrapped.result() + if isinstance(result, OutputArtifact): + return result + elif isinstance(result, tuple): + return result[self._index] + else: + raise RuntimeError( + f"Step {self._invocation_id} returned an invalid output: " + f"{result}." + ) + + def load(self, disable_cache: bool = False) -> Any: + """Load the step run output artifact data. + + Args: + disable_cache: Whether to disable the artifact cache. + + Returns: + The step run output artifact data. 
+ """ + return self.result().load(disable_cache=disable_cache) + + +class StepRunOutputsFuture(_BaseStepRunFuture): + """Future for a step run output.""" + + def __init__( + self, + wrapped: Future[StepRunOutputs], + invocation_id: str, + output_keys: List[str], + ) -> None: + """Initialize the future. + + Args: + wrapped: The wrapped future object. + invocation_id: The invocation ID of the step run. + output_keys: The output keys of the step run. + """ + super().__init__(wrapped=wrapped, invocation_id=invocation_id) + self._output_keys = output_keys + + def get_artifact(self, key: str) -> ArtifactFuture: + """Get an artifact future by key. + + Args: + key: The key of the artifact future. + + Raises: + KeyError: If no artifact for the given name exists. + + Returns: + The artifact future. + """ + if key not in self._output_keys: + raise KeyError( + f"Step run {self._invocation_id} does not have an output with " + f"the name: {key}." + ) + + return ArtifactFuture( + wrapped=self._wrapped, + invocation_id=self._invocation_id, + index=self._output_keys.index(key), + ) + + def artifacts(self) -> StepRunOutputs: + """Get the step run output artifacts. + + Returns: + The step run output artifacts. + """ + return self._wrapped.result() + + def load(self, disable_cache: bool = False) -> Any: + """Get the step run output artifact data. + + Args: + disable_cache: Whether to disable the artifact cache. + + Raises: + ValueError: If the step run output is invalid. + + Returns: + The step run output artifact data. + """ + result = self.artifacts() + + if result is None: + return None + elif isinstance(result, ArtifactVersionResponse): + return result.load(disable_cache=disable_cache) + elif isinstance(result, tuple): + return tuple( + item.load(disable_cache=disable_cache) for item in result + ) + else: + raise ValueError(f"Invalid step run output: {result}") + + def __getitem__(self, key: Any) -> ArtifactFuture: + """Get an artifact future by key or index. + + Args: + key: The key or index of the artifact future. + + Raises: + TypeError: If the key is not an integer. + IndexError: If the index is out of range. + + Returns: + The artifact future. + """ + if not isinstance(key, int): + raise TypeError(f"Invalid key type: {type(key)}") + + # Convert to positive index if necessary + if key < 0: + key += len(self._output_keys) + + if key > len(self._output_keys): + raise IndexError(f"Index out of range: {key}") + + return ArtifactFuture( + wrapped=self._wrapped, + invocation_id=self._invocation_id, + index=key, + ) + + def __iter__(self) -> Any: + """Iterate over the artifact futures. + + Raises: + ValueError: If the step does not return any outputs. + + Yields: + The artifact futures. + """ + if not self._output_keys: + raise ValueError( + f"Step {self._invocation_id} does not return any outputs." + ) + + for index in range(len(self._output_keys)): + yield ArtifactFuture( + wrapped=self._wrapped, + invocation_id=self._invocation_id, + index=index, + ) + + def __len__(self) -> int: + """Get the number of artifact futures. + + Returns: + The number of artifact futures. + """ + return len(self._output_keys) + + +StepRunFuture = Union[ArtifactFuture, StepRunOutputsFuture] diff --git a/src/zenml/execution/pipeline/dynamic/run_context.py b/src/zenml/execution/pipeline/dynamic/run_context.py new file mode 100644 index 00000000000..38d515140a1 --- /dev/null +++ b/src/zenml/execution/pipeline/dynamic/run_context.py @@ -0,0 +1,105 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Dynamic pipeline run context.""" + +import contextvars +from typing import TYPE_CHECKING + +from typing_extensions import Self + +from zenml.utils import context_utils + +if TYPE_CHECKING: + from zenml.execution.pipeline.dynamic.runner import DynamicPipelineRunner + from zenml.models import PipelineRunResponse, PipelineSnapshotResponse + from zenml.pipelines.dynamic.pipeline_definition import DynamicPipeline + + +class DynamicPipelineRunContext(context_utils.BaseContext): + """Dynamic pipeline run context.""" + + __context_var__ = contextvars.ContextVar("dynamic_pipeline_run_context") + + def __init__( + self, + pipeline: "DynamicPipeline", + snapshot: "PipelineSnapshotResponse", + run: "PipelineRunResponse", + runner: "DynamicPipelineRunner", + ) -> None: + """Initialize the dynamic pipeline run context. + + Args: + pipeline: The dynamic pipeline that is being executed. + snapshot: The snapshot of the pipeline. + run: The pipeline run. + runner: The dynamic pipeline runner. + """ + super().__init__() + self._pipeline = pipeline + self._snapshot = snapshot + self._run = run + self._runner = runner + + @property + def pipeline(self) -> "DynamicPipeline": + """The pipeline that is being executed. + + Returns: + The pipeline that is being executed. + """ + return self._pipeline + + @property + def run(self) -> "PipelineRunResponse": + """The pipeline run. + + Returns: + The pipeline run. + """ + return self._run + + @property + def snapshot(self) -> "PipelineSnapshotResponse": + """The snapshot of the pipeline. + + Returns: + The snapshot of the pipeline. + """ + return self._snapshot + + @property + def runner(self) -> "DynamicPipelineRunner": + """The runner executing the pipeline. + + Returns: + The runner executing the pipeline. + """ + return self._runner + + def __enter__(self) -> Self: + """Enter the dynamic pipeline run context. + + Raises: + RuntimeError: If the dynamic pipeline run context has already been + entered. + + Returns: + The dynamic pipeline run context object. + """ + if self._token is not None: + raise RuntimeError( + "Calling a pipeline within a dynamic pipeline is not allowed." + ) + return super().__enter__() diff --git a/src/zenml/execution/pipeline/dynamic/runner.py b/src/zenml/execution/pipeline/dynamic/runner.py new file mode 100644 index 00000000000..6a13ea87d69 --- /dev/null +++ b/src/zenml/execution/pipeline/dynamic/runner.py @@ -0,0 +1,497 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. 
See the License for the specific language governing +# permissions and limitations under the License. +"""Dynamic pipeline runner.""" + +import contextvars +import inspect +from concurrent.futures import ThreadPoolExecutor +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Literal, + Optional, + Sequence, + Tuple, + Union, + overload, +) +from uuid import UUID + +from zenml import ExternalArtifact +from zenml.artifacts.in_memory_cache import InMemoryArtifactCache +from zenml.client import Client +from zenml.config.compiler import Compiler +from zenml.config.step_configurations import Step +from zenml.enums import ExecutionMode +from zenml.execution.pipeline.dynamic.outputs import ( + ArtifactFuture, + OutputArtifact, + StepRunFuture, + StepRunOutputs, + StepRunOutputsFuture, + _BaseStepRunFuture, +) +from zenml.execution.pipeline.dynamic.run_context import ( + DynamicPipelineRunContext, +) +from zenml.execution.step.utils import launch_step +from zenml.logger import get_logger +from zenml.logging.step_logging import setup_pipeline_logging +from zenml.models import ( + ArtifactVersionResponse, + PipelineRunResponse, + PipelineSnapshotResponse, +) +from zenml.orchestrators.publish_utils import ( + publish_failed_pipeline_run, + publish_successful_pipeline_run, +) +from zenml.pipelines.dynamic.pipeline_definition import DynamicPipeline +from zenml.pipelines.run_utils import create_placeholder_run +from zenml.stack import Stack +from zenml.steps.entrypoint_function_utils import StepArtifact +from zenml.steps.utils import OutputSignature +from zenml.utils import source_utils + +if TYPE_CHECKING: + from zenml.config import DockerSettings + from zenml.config.step_configurations import Step + from zenml.steps import BaseStep + + +logger = get_logger(__name__) + + +class DynamicPipelineRunner: + """Dynamic pipeline runner.""" + + def __init__( + self, + snapshot: "PipelineSnapshotResponse", + run: Optional["PipelineRunResponse"], + ) -> None: + """Initialize the dynamic pipeline runner. + + Args: + snapshot: The snapshot of the pipeline. + run: The pipeline run. + + Raises: + RuntimeError: If the snapshot has no associated stack. + """ + if not snapshot.stack: + raise RuntimeError("Missing stack for snapshot.") + + if ( + snapshot.pipeline_configuration.execution_mode + != ExecutionMode.STOP_ON_FAILURE + ): + logger.warning( + "Only the STOP_ON_FAILURE execution mode is supported for " + "dynamic pipelines right now. " + "The execution mode `%s` will be ignored.", + snapshot.pipeline_configuration.execution_mode.value, + ) + + self._snapshot = snapshot + self._run = run + # TODO: make this configurable + self._executor = ThreadPoolExecutor(max_workers=10) + self._pipeline: Optional["DynamicPipeline"] = None + self._orchestrator = Stack.from_model(snapshot.stack).orchestrator + self._orchestrator_run_id = ( + self._orchestrator.get_orchestrator_run_id() + ) + self._futures: List[StepRunOutputsFuture] = [] + + @property + def pipeline(self) -> "DynamicPipeline": + """The pipeline that the runner is executing. + + Raises: + RuntimeError: If the pipeline can't be loaded. + + Returns: + The pipeline that the runner is executing. 
+ """ + if self._pipeline is None: + if ( + not self._snapshot.pipeline_spec + or not self._snapshot.pipeline_spec.source + ): + raise RuntimeError("Missing pipeline source for snapshot.") + + pipeline = source_utils.load(self._snapshot.pipeline_spec.source) + if not isinstance(pipeline, DynamicPipeline): + raise RuntimeError( + "Invalid pipeline source: " + f"{self._snapshot.pipeline_spec.source.import_path}" + ) + self._pipeline = pipeline + + return self._pipeline + + def run_pipeline(self) -> None: + """Run the pipeline.""" + with setup_pipeline_logging( + source="orchestrator", + snapshot=self._snapshot, + run_id=self._run.id if self._run else None, + ) as logs_request: + with InMemoryArtifactCache(): + run = self._run or create_placeholder_run( + snapshot=self._snapshot, + orchestrator_run_id=self._orchestrator_run_id, + logs=logs_request, + ) + + assert ( + self._snapshot.pipeline_spec + ) # Always exists for new snapshots + pipeline_parameters = self._snapshot.pipeline_spec.parameters + + with DynamicPipelineRunContext( + pipeline=self.pipeline, + run=run, + snapshot=self._snapshot, + runner=self, + ): + self._orchestrator.run_init_hook(snapshot=self._snapshot) + try: + # TODO: step logging isn't threadsafe + # TODO: what should be allowed as pipeline returns? + # (artifacts, json serializable, anything?) + # how do we show it in the UI? + self.pipeline._call_entrypoint(**pipeline_parameters) + except: + publish_failed_pipeline_run(run.id) + logger.error( + "Pipeline run failed. All in-progress step runs " + "will still finish executing." + ) + raise + finally: + self._orchestrator.run_cleanup_hook( + snapshot=self._snapshot + ) + self._executor.shutdown(wait=True, cancel_futures=True) + # self.await_all_step_run_futures() + publish_successful_pipeline_run(run.id) + + @overload + def launch_step( + self, + step: "BaseStep", + id: Optional[str], + args: Tuple[Any], + kwargs: Dict[str, Any], + after: Union["StepRunFuture", Sequence["StepRunFuture"], None] = None, + concurrent: Literal[False] = False, + ) -> StepRunOutputs: ... + + @overload + def launch_step( + self, + step: "BaseStep", + id: Optional[str], + args: Tuple[Any], + kwargs: Dict[str, Any], + after: Union["StepRunFuture", Sequence["StepRunFuture"], None] = None, + concurrent: Literal[True] = True, + ) -> "StepRunOutputsFuture": ... + + def launch_step( + self, + step: "BaseStep", + id: Optional[str], + args: Tuple[Any], + kwargs: Dict[str, Any], + after: Union["StepRunFuture", Sequence["StepRunFuture"], None] = None, + concurrent: bool = False, + ) -> Union[StepRunOutputs, "StepRunOutputsFuture"]: + """Launch a step. + + Args: + step: The step to launch. + id: The invocation ID of the step. + args: The arguments for the step function. + kwargs: The keyword arguments for the step function. + after: The step run output futures to wait for. + concurrent: Whether to launch the step concurrently. + + Returns: + The step run outputs or a future for the step run outputs. 
+ """ + step = step.copy() + compiled_step = compile_dynamic_step_invocation( + snapshot=self._snapshot, + pipeline=self.pipeline, + step=step, + id=id, + args=args, + kwargs=kwargs, + after=after, + ) + + def _launch() -> StepRunOutputs: + step_run = launch_step( + snapshot=self._snapshot, + step=compiled_step, + orchestrator_run_id=self._orchestrator_run_id, + retry=_should_retry_locally( + compiled_step, + self._snapshot.pipeline_configuration.docker_settings, + ), + ) + return _load_step_run_outputs(step_run.id) + + if concurrent: + ctx = contextvars.copy_context() + future = self._executor.submit(ctx.run, _launch) + compiled_step.config.outputs + step_run_future = StepRunOutputsFuture( + wrapped=future, + invocation_id=compiled_step.spec.invocation_id, + output_keys=list(compiled_step.config.outputs), + ) + self._futures.append(step_run_future) + return step_run_future + else: + return _launch() + + def await_all_step_run_futures(self) -> None: + """Await all step run output futures.""" + for future in self._futures: + future.artifacts() + self._futures = [] + + +def compile_dynamic_step_invocation( + snapshot: "PipelineSnapshotResponse", + pipeline: "DynamicPipeline", + step: "BaseStep", + id: Optional[str], + args: Tuple[Any], + kwargs: Dict[str, Any], + after: Union["StepRunFuture", Sequence["StepRunFuture"], None] = None, +) -> "Step": + """Compile a dynamic step invocation. + + Args: + snapshot: The snapshot. + pipeline: The dynamic pipeline. + step: The step to compile. + id: Custom invocation ID. + args: The arguments for the step function. + kwargs: The keyword arguments for the step function. + after: The step run output futures to wait for. + + Returns: + The compiled step. + """ + upstream_steps = set() + + if isinstance(after, _BaseStepRunFuture): + after._wait() + upstream_steps.add(after.invocation_id) + elif isinstance(after, Sequence): + for item in after: + item._wait() + upstream_steps.add(item.invocation_id) + + def _await_and_validate_input(input: Any) -> Any: + if isinstance(input, StepRunOutputsFuture): + if len(input._output_keys) != 1: + raise ValueError( + "Passing multiple step run outputs to another step is not " + "allowed." + ) + input = input.artifacts() + + if isinstance(input, ArtifactFuture): + input = input.result() + + if isinstance(input, OutputArtifact): + upstream_steps.add(input.step_name) + + return input + + args = tuple(_await_and_validate_input(arg) for arg in args) + kwargs = { + key: _await_and_validate_input(value) for key, value in kwargs.items() + } + + # TODO: we can validate the type of the inputs that are passed as raw data + signature = inspect.signature(step.entrypoint, follow_wrapped=True) + bound_args = signature.bind_partial(*args, **kwargs) + validated_args = bound_args.arguments + bound_args.apply_defaults() + default_parameters = { + key: value + for key, value in bound_args.arguments.items() + if key not in validated_args + } + + input_artifacts = {} + external_artifacts = {} + for name, value in validated_args.items(): + if isinstance(value, OutputArtifact): + input_artifacts[name] = StepArtifact( + invocation_id=value.step_name, + output_name=value.output_name, + annotation=OutputSignature(resolved_annotation=Any), + pipeline=pipeline, + ) + elif isinstance(value, (ArtifactVersionResponse, ExternalArtifact)): + external_artifacts[name] = value + else: + # TODO: should some of these be parameters? 
+ external_artifacts[name] = ExternalArtifact(value=value) + + if template := get_config_template(snapshot, step, pipeline): + step._configuration = template.config.model_copy( + update={"template": template.spec.invocation_id} + ) + + invocation_id = pipeline.add_step_invocation( + step=step, + custom_id=id, + allow_id_suffix=not id, + input_artifacts=input_artifacts, + external_artifacts=external_artifacts, + upstream_steps=upstream_steps, + default_parameters=default_parameters, + parameters={}, + model_artifacts_or_metadata={}, + client_lazy_loaders={}, + ) + return Compiler()._compile_step_invocation( + invocation=pipeline.invocations[invocation_id], + stack=Client().active_stack, + step_config=None, + pipeline_configuration=pipeline.configuration, + ) + + +def _load_step_run_outputs(step_run_id: UUID) -> StepRunOutputs: + """Load the outputs of a step run. + + Args: + step_run_id: The ID of the step run. + + Returns: + The outputs of the step run. + """ + step_run = Client().zen_store.get_run_step(step_run_id) + + def _convert_output_artifact( + output_name: str, artifact: ArtifactVersionResponse + ) -> OutputArtifact: + return OutputArtifact( + output_name=output_name, + step_name=step_run.name, + **artifact.model_dump(), + ) + + output_artifacts = step_run.regular_outputs + if len(output_artifacts) == 0: + return None + elif len(output_artifacts) == 1: + name, artifact = next(iter(output_artifacts.items())) + return _convert_output_artifact(output_name=name, artifact=artifact) + else: + return tuple( + _convert_output_artifact(output_name=name, artifact=artifact) + for name, artifact in output_artifacts.items() + ) + + +def _should_retry_locally( + step: "Step", pipeline_docker_settings: "DockerSettings" +) -> bool: + """Determine if a step should be retried locally. + + Args: + step: The step. + pipeline_docker_settings: The Docker settings of the parent pipeline. + + Returns: + Whether the step should be retried locally. + """ + if step.config.step_operator: + return True + + if should_run_in_process(step, pipeline_docker_settings): + return True + else: + # Running out of process with the orchestrator + return ( + not Client().active_stack.orchestrator.config.handles_step_retries + ) + + +def should_run_in_process( + step: "Step", pipeline_docker_settings: "DockerSettings" +) -> bool: + """Determine if a step should be run in process. + + Args: + step: The step. + pipeline_docker_settings: The Docker settings of the parent pipeline. + + Returns: + Whether the step should be run in process. + """ + if step.config.step_operator: + return False + + if not Client().active_stack.orchestrator.can_launch_dynamic_steps: + return True + + if step.config.in_process is False: + return False + elif step.config.in_process is None: + if not step.config.resource_settings.empty: + return False + + if step.config.docker_settings != pipeline_docker_settings: + return False + + return True + + +def get_config_template( + snapshot: "PipelineSnapshotResponse", + step: "BaseStep", + pipeline: "DynamicPipeline", +) -> Optional["Step"]: + """Get the config template for a step executed in a dynamic pipeline. + + Args: + snapshot: The snapshot of the pipeline. + step: The step to get the config template for. + pipeline: The dynamic pipeline that the step is being executed in. + + Returns: + The config template for the step. 
+ """ + for index, step_ in enumerate(pipeline.depends_on): + if step_._static_id == step._static_id: + break + else: + return None + + return list(snapshot.step_configurations.values())[index] diff --git a/src/zenml/execution/pipeline/utils.py b/src/zenml/execution/pipeline/utils.py new file mode 100644 index 00000000000..6ce3c064b6b --- /dev/null +++ b/src/zenml/execution/pipeline/utils.py @@ -0,0 +1,115 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Pipeline execution utilities.""" + +import contextvars +from contextlib import contextmanager +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Generator, + Optional, + Union, +) + +from zenml.client import Client +from zenml.config.step_configurations import StepConfigurationUpdate +from zenml.exceptions import RunMonitoringError +from zenml.logger import get_logger +from zenml.models import ( + PipelineRunResponse, + PipelineSnapshotResponse, +) +from zenml.orchestrators.publish_utils import publish_failed_pipeline_run +from zenml.stack import Stack + +if TYPE_CHECKING: + StepConfigurationUpdateOrDict = Union[ + Dict[str, Any], StepConfigurationUpdate + ] + +logger = get_logger(__name__) + + +_prevent_pipeline_execution = contextvars.ContextVar( + "prevent_pipeline_execution", default=False +) + + +def should_prevent_pipeline_execution() -> bool: + """Whether to prevent pipeline execution. + + Returns: + Whether to prevent pipeline execution. + """ + return _prevent_pipeline_execution.get() + + +@contextmanager +def prevent_pipeline_execution() -> Generator[None, None, None]: + """Context manager to prevent pipeline execution. + + Yields: + None. + """ + token = _prevent_pipeline_execution.set(True) + try: + yield + finally: + _prevent_pipeline_execution.reset(token) + + +def submit_pipeline( + snapshot: "PipelineSnapshotResponse", + stack: "Stack", + placeholder_run: Optional["PipelineRunResponse"] = None, +) -> None: + """Submit a snapshot for execution. + + Args: + snapshot: The snapshot to submit. + stack: The stack on which to submit the snapshot. + placeholder_run: An optional placeholder run for the snapshot. + + # noqa: DAR401 + Raises: + BaseException: Any exception that happened while submitting or running + (in case it happens synchronously) the pipeline. + """ + # Prevent execution of nested pipelines which might lead to + # unexpected behavior + with prevent_pipeline_execution(): + try: + stack.prepare_pipeline_submission(snapshot=snapshot) + stack.submit_pipeline( + snapshot=snapshot, + placeholder_run=placeholder_run, + ) + except RunMonitoringError as e: + # Don't mark the run as failed if the error happened during + # monitoring of the run. + raise e.original_exception from None + except BaseException as e: + if ( + placeholder_run + and not Client() + .get_pipeline_run(placeholder_run.id, hydrate=False) + .status.is_finished + ): + # We failed during/before the submission of the run, so we mark + # the run as failed if it's still in an unfinished state. 
+ publish_failed_pipeline_run(placeholder_run.id) + + raise e diff --git a/src/zenml/execution/step/__init__.py b/src/zenml/execution/step/__init__.py new file mode 100644 index 00000000000..f716f05d49b --- /dev/null +++ b/src/zenml/execution/step/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Step execution.""" \ No newline at end of file diff --git a/src/zenml/execution/step/utils.py b/src/zenml/execution/step/utils.py new file mode 100644 index 00000000000..40c35630aac --- /dev/null +++ b/src/zenml/execution/step/utils.py @@ -0,0 +1,103 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Step execution utilities.""" + +import time +from typing import ( + TYPE_CHECKING, +) + +from zenml.config.step_configurations import Step +from zenml.exceptions import RunStoppedException +from zenml.logger import get_logger +from zenml.models import ( + PipelineSnapshotResponse, +) +from zenml.models.v2.core.step_run import StepRunResponse +from zenml.orchestrators.step_launcher import StepLauncher + +if TYPE_CHECKING: + from zenml.config.step_configurations import Step + + +logger = get_logger(__name__) + + +def launch_step( + snapshot: "PipelineSnapshotResponse", + step: "Step", + orchestrator_run_id: str, + retry: bool = False, +) -> StepRunResponse: + """Launch a step. + + Args: + snapshot: The snapshot. + step: The step to run. + orchestrator_run_id: The orchestrator run ID. + retry: Whether to retry the step if it fails. + + Raises: + RunStoppedException: If the run was stopped. + BaseException: If the step failed all retries. + + Returns: + The step run response. 
+ """ + + def _launch_without_retry() -> StepRunResponse: + launcher = StepLauncher( + snapshot=snapshot, + step=step, + orchestrator_run_id=orchestrator_run_id, + ) + return launcher.launch() + + if not retry: + step_run = _launch_without_retry() + else: + retries = 0 + retry_config = step.config.retry + max_retries = retry_config.max_retries if retry_config else 0 + delay = retry_config.delay if retry_config else 0 + backoff = retry_config.backoff if retry_config else 1 + + while retries <= max_retries: + try: + step_run = _launch_without_retry() + except RunStoppedException: + # Don't retry if the run was stopped + raise + except BaseException: + retries += 1 + if retries <= max_retries: + logger.info( + "Sleeping for %d seconds before retrying step `%s`.", + delay, + step.config.name, + ) + time.sleep(delay) + delay *= backoff + else: + if max_retries > 0: + logger.error( + "Failed to run step `%s` after %d retries.", + step.config.name, + max_retries, + ) + raise + else: + break + + return step_run diff --git a/src/zenml/integrations/azure/orchestrators/azureml_orchestrator_entrypoint_config.py b/src/zenml/integrations/azure/orchestrators/azureml_orchestrator_entrypoint_config.py index a2006cef723..ffce732f7ab 100644 --- a/src/zenml/integrations/azure/orchestrators/azureml_orchestrator_entrypoint_config.py +++ b/src/zenml/integrations/azure/orchestrators/azureml_orchestrator_entrypoint_config.py @@ -15,7 +15,7 @@ import json import os -from typing import Any, List, Set +from typing import Any, Dict, List from zenml.entrypoints.step_entrypoint_configuration import ( StepEntrypointConfiguration, @@ -30,14 +30,14 @@ class AzureMLEntrypointConfiguration(StepEntrypointConfiguration): """Entrypoint configuration for ZenML AzureML pipeline steps.""" @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> Dict[str, bool]: """Gets all options required for running with this configuration. Returns: The superclass options as well as an option for the environmental variables. """ - return super().get_entrypoint_options() | {ZENML_ENV_VARIABLES} + return super().get_entrypoint_options() | {ZENML_ENV_VARIABLES: True} @classmethod def get_entrypoint_arguments(cls, **kwargs: Any) -> List[str]: diff --git a/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py b/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py index de2001e3fee..ec2523e2a6a 100644 --- a/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py +++ b/src/zenml/integrations/databricks/orchestrators/databricks_orchestrator_entrypoint_config.py @@ -16,7 +16,7 @@ import os import sys from importlib.metadata import distribution -from typing import Any, List, Set +from typing import Any, Dict, List from zenml.entrypoints.step_entrypoint_configuration import ( StepEntrypointConfiguration, @@ -38,17 +38,16 @@ class DatabricksEntrypointConfiguration(StepEntrypointConfiguration): """ @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> Dict[str, bool]: """Gets all options required for running with this configuration. Returns: The superclass options as well as an option for the wheel package. 
""" - return ( - super().get_entrypoint_options() - | {WHEEL_PACKAGE_OPTION} - | {DATABRICKS_JOB_ID_OPTION} - ) + return super().get_entrypoint_options() | { + WHEEL_PACKAGE_OPTION: True, + DATABRICKS_JOB_ID_OPTION: True, + } @classmethod def get_entrypoint_arguments( diff --git a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py index 384a80bece6..476fff5f273 100644 --- a/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +++ b/src/zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py @@ -32,9 +32,12 @@ import os import random +import socket +from contextlib import contextmanager from typing import ( TYPE_CHECKING, Dict, + Generator, List, Optional, Tuple, @@ -49,6 +52,7 @@ from zenml.config.base_settings import BaseSettings from zenml.constants import ( METADATA_ORCHESTRATOR_RUN_ID, + ORCHESTRATOR_DOCKER_IMAGE_KEY, ) from zenml.enums import ExecutionMode, ExecutionStatus, StackComponentType from zenml.integrations.kubernetes.constants import ( @@ -73,6 +77,7 @@ job_template_manifest_from_job, pod_template_manifest_from_pod, ) +from zenml.integrations.kubernetes.pod_settings import KubernetesPodSettings from zenml.logger import get_logger from zenml.metadata.metadata_types import MetadataType from zenml.models.v2.core.schedule import ScheduleUpdate @@ -80,6 +85,7 @@ from zenml.stack import StackValidator if TYPE_CHECKING: + from zenml.config.step_run_info import StepRunInfo from zenml.models import ( PipelineRunResponse, PipelineSnapshotBase, @@ -110,7 +116,10 @@ def should_build_pipeline_image( settings = cast( KubernetesOrchestratorSettings, self.get_settings(snapshot) ) - return settings.always_build_pipeline_image + if settings.always_build_pipeline_image: + return True + else: + return super().should_build_pipeline_image(snapshot) def get_kube_client( self, incluster: Optional[bool] = None @@ -401,12 +410,7 @@ def submit_pipeline( step_environments: Dict[str, Dict[str, str]], placeholder_run: Optional["PipelineRunResponse"] = None, ) -> Optional[SubmissionResult]: - """Submits a pipeline to the orchestrator. - - This method should only submit the pipeline and not wait for it to - complete. If the orchestrator is configured to wait for the pipeline run - to complete, a function that waits for the pipeline run to complete can - be passed as part of the submission result. + """Submit a static pipeline to the orchestrator. Args: snapshot: The pipeline snapshot to submit. @@ -418,10 +422,6 @@ def submit_pipeline( specific steps. placeholder_run: An optional placeholder run for the snapshot. - Raises: - RuntimeError: If a schedule without cron expression is given. - Exception: If the orchestrator pod fails to start. - Returns: Optional submission result. """ @@ -441,33 +441,86 @@ def submit_pipeline( "for the Kubernetes orchestrator." 
) - pipeline_name = snapshot.pipeline_configuration.name - settings = cast( - KubernetesOrchestratorSettings, self.get_settings(snapshot) + command = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_command() + args = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_arguments( + snapshot_id=snapshot.id, + run_id=placeholder_run.id if placeholder_run else None, ) - assert stack.container_registry + return self._submit_orchestrator_job( + snapshot=snapshot, + command=command, + args=args, + environment=base_environment, + placeholder_run=placeholder_run, + ) - # Get Docker image for the orchestrator pod - try: - image = self.get_image(snapshot=snapshot) - except KeyError: - # If no generic pipeline image exists (which means all steps have - # custom builds) we use a random step image as all of them include - # dependencies for the active stack - pipeline_step_name = next(iter(snapshot.step_configurations)) - image = self.get_image( - snapshot=snapshot, step_name=pipeline_step_name - ) + def submit_dynamic_pipeline( + self, + snapshot: "PipelineSnapshotResponse", + stack: "Stack", + environment: Dict[str, str], + placeholder_run: Optional["PipelineRunResponse"] = None, + ) -> Optional[SubmissionResult]: + """Submits a dynamic pipeline to the orchestrator. - # Build entrypoint command and args for the orchestrator pod. - # This will internally also build the command/args for all step pods. - command = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_command() - args = KubernetesOrchestratorEntrypointConfiguration.get_entrypoint_arguments( + Args: + snapshot: The snapshot of the pipeline. + stack: The stack to use for the pipeline. + environment: The environment variables to set in the pipeline. + placeholder_run: The placeholder run for the pipeline. + + Returns: + Optional submission result. + """ + from zenml.pipelines.dynamic.entrypoint_configuration import ( + DynamicPipelineEntrypointConfiguration, + ) + + command = ( + DynamicPipelineEntrypointConfiguration.get_entrypoint_command() + ) + args = DynamicPipelineEntrypointConfiguration.get_entrypoint_arguments( snapshot_id=snapshot.id, run_id=placeholder_run.id if placeholder_run else None, ) + return self._submit_orchestrator_job( + snapshot=snapshot, + command=command, + args=args, + environment=environment, + placeholder_run=placeholder_run, + ) + + def _prepare_job_manifest( + self, + name: str, + command: List[str], + args: List[str], + image: str, + environment: Dict[str, str], + labels: Dict[str, str], + annotations: Dict[str, str], + settings: KubernetesOrchestratorSettings, + pod_settings: Optional[KubernetesPodSettings] = None, + ) -> k8s_client.V1Job: + """Prepares the job manifest for a Kubernetes job. + + Args: + name: The name of the job. + command: The command to run in the job. + args: The arguments to pass to the job. + image: The image to use for the job. + environment: The environment variables to set in the job. + labels: The labels to add to the job. + annotations: The annotations to add to the job. + settings: Component settings for the orchestrator. + pod_settings: Optional settings for the pod. + + Returns: + The job manifest. + """ # Authorize pod to run Kubernetes commands inside the cluster. 
service_account_name = self._get_service_account_name(settings) @@ -476,55 +529,22 @@ def submit_pipeline( # takes up some memory resources itself and, if not specified, the pod # will be scheduled on any node regardless of available memory and risk # negatively impacting or even crashing the node due to memory pressure. - orchestrator_pod_settings = kube_utils.apply_default_resource_requests( + pod_settings = kube_utils.apply_default_resource_requests( memory="400Mi", cpu="100m", - pod_settings=settings.orchestrator_pod_settings, + pod_settings=pod_settings, ) - if self.config.pass_zenml_token_as_secret: - secret_name = self.get_token_secret_name(snapshot.id) - token = base_environment.pop("ZENML_STORE_API_TOKEN") - kube_utils.create_or_update_secret( - core_api=self._k8s_core_api, - namespace=self.config.kubernetes_namespace, - secret_name=secret_name, - data={KUBERNETES_SECRET_TOKEN_KEY_NAME: token}, - ) - orchestrator_pod_settings.env.append( - { - "name": "ZENML_STORE_API_TOKEN", - "valueFrom": { - "secretKeyRef": { - "name": secret_name, - "key": KUBERNETES_SECRET_TOKEN_KEY_NAME, - } - }, - } - ) - - orchestrator_pod_labels = { - "pipeline": kube_utils.sanitize_label(pipeline_name), - } - - if placeholder_run: - orchestrator_pod_labels["run_id"] = kube_utils.sanitize_label( - str(placeholder_run.id) - ) - orchestrator_pod_labels["run_name"] = kube_utils.sanitize_label( - placeholder_run.name - ) - pod_manifest = build_pod_manifest( pod_name=None, image_name=image, command=command, args=args, privileged=False, - pod_settings=orchestrator_pod_settings, + pod_settings=pod_settings, service_account_name=service_account_name, - env=base_environment, - labels=orchestrator_pod_labels, + env=environment, + labels=labels, mount_local_stores=self.config.is_local, termination_grace_period_seconds=settings.pod_stop_grace_period, ) @@ -558,103 +578,325 @@ def submit_pipeline( ] } - job_name = settings.job_name_prefix or "" - random_prefix = "".join(random.choices("0123456789abcdef", k=8)) - job_name += f"-{random_prefix}-{snapshot.pipeline_configuration.name}" - # The job name will be used as a label on the pods, so we need to make - # sure it doesn't exceed the label length limit - job_name = kube_utils.sanitize_label(job_name) - - job_manifest = build_job_manifest( - job_name=job_name, + return build_job_manifest( + job_name=name, pod_template=pod_template_manifest_from_pod(pod_manifest), backoff_limit=settings.orchestrator_job_backoff_limit, ttl_seconds_after_finished=settings.ttl_seconds_after_finished, active_deadline_seconds=settings.active_deadline_seconds, pod_failure_policy=pod_failure_policy, - labels=orchestrator_pod_labels, - annotations={ - ORCHESTRATOR_ANNOTATION_KEY: str(self.id), - }, + labels=labels, + annotations=annotations, ) - if snapshot.schedule: - if not snapshot.schedule.cron_expression: - raise RuntimeError( - "The Kubernetes orchestrator only supports scheduling via " - "CRON jobs, but the run was configured with a manual " - "schedule. Use `Schedule(cron_expression=...)` instead." + def _get_job_name( + self, + settings: KubernetesOrchestratorSettings, + pipeline_name: str, + step_name: Optional[str] = None, + ) -> str: + """Gets a job name for a Kubernetes job. + + Args: + settings: The settings for the orchestrator. + pipeline_name: The name of the pipeline. + step_name: The name of the step. + + Returns: + The job name. 
+ """ + job_name = settings.job_name_prefix or "" + random_prefix = "".join(random.choices("0123456789abcdef", k=8)) + job_name += f"-{random_prefix}-{pipeline_name}" + if step_name: + job_name += f"-{step_name}" + # The job name will be used as a label on the pods, so we need to make + # sure it doesn't exceed the label length limit + job_name = kube_utils.sanitize_label(job_name) + return job_name + + @contextmanager + def _create_auth_secret_if_necessary( + self, + snapshot: "PipelineSnapshotResponse", + environment: Dict[str, str], + pod_settings: KubernetesPodSettings, + ) -> Generator[None, None, None]: + """Creates an authentication secret if necessary. + + If the authentication secret is created and some exception is raised, + the secret will be deleted. + + Args: + snapshot: The pipeline snapshot. + environment: The environment variables to set. + pod_settings: The pod settings to update. + + Raises: + Exception: If an exception happens while the context manager is + active. + + Yields: + None. + """ + try: + if self.config.pass_zenml_token_as_secret: + secret_name = self.get_token_secret_name(snapshot.id) + token = environment.pop("ZENML_STORE_API_TOKEN") + kube_utils.create_or_update_secret( + core_api=self._k8s_core_api, + namespace=self.config.kubernetes_namespace, + secret_name=secret_name, + data={KUBERNETES_SECRET_TOKEN_KEY_NAME: token}, ) - cron_expression = snapshot.schedule.cron_expression - cron_job_manifest = build_cron_job_manifest( - cron_expression=cron_expression, - job_template=job_template_manifest_from_job(job_manifest), - successful_jobs_history_limit=settings.successful_jobs_history_limit, - failed_jobs_history_limit=settings.failed_jobs_history_limit, - ) + pod_settings.env.append( + { + "name": "ZENML_STORE_API_TOKEN", + "valueFrom": { + "secretKeyRef": { + "name": secret_name, + "key": KUBERNETES_SECRET_TOKEN_KEY_NAME, + } + }, + } + ) + yield + except Exception as e: + if self.config.pass_zenml_token_as_secret: + secret_name = self.get_token_secret_name(snapshot.id) + try: + kube_utils.delete_secret( + core_api=self._k8s_core_api, + namespace=self.config.kubernetes_namespace, + secret_name=secret_name, + ) + except Exception as cleanup_error: + logger.error( + "Error cleaning up secret %s: %s", + secret_name, + cleanup_error, + ) + raise e - cron_job = self._k8s_batch_api.create_namespaced_cron_job( - body=cron_job_manifest, - namespace=self.config.kubernetes_namespace, + def _submit_orchestrator_job( + self, + snapshot: "PipelineSnapshotResponse", + command: List[str], + args: List[str], + environment: Dict[str, str], + placeholder_run: Optional["PipelineRunResponse"] = None, + ) -> Optional[SubmissionResult]: + """Submits an orchestrator job to Kubernetes. + + Args: + snapshot: The pipeline snapshot. + command: The command to run in the job. + args: The arguments to pass to the job. + environment: The environment variables to set in the job. + placeholder_run: The placeholder run for the job. + + Raises: + RuntimeError: If a schedule without cron expression is given. + + Returns: + Optional submission result. 
+ """ + pipeline_name = snapshot.pipeline_configuration.name + settings = cast( + KubernetesOrchestratorSettings, self.get_settings(snapshot) + ) + orchestrator_pod_settings = ( + settings.orchestrator_pod_settings or KubernetesPodSettings() + ) + + try: + image = self.get_image(snapshot=snapshot) + except KeyError: + # If no generic pipeline image exists (which means all steps of a + # static pipeline have custom builds) we use a random step image as + # all of them include dependencies for the active stack + invocation_id = next(iter(snapshot.step_configurations)) + image = self.get_image(snapshot=snapshot, step_name=invocation_id) + + labels = { + "pipeline": kube_utils.sanitize_label(pipeline_name), + } + + if placeholder_run: + labels["run_id"] = kube_utils.sanitize_label( + str(placeholder_run.id) ) - logger.info( - f"Created Kubernetes CronJob `{cron_job.metadata.name}` " - f"with CRON expression `{cron_expression}`." + labels["run_name"] = kube_utils.sanitize_label( + placeholder_run.name ) - return SubmissionResult( - metadata={ - KUBERNETES_CRON_JOB_METADATA_KEY: cron_job.metadata.name, - } + + annotations = { + ORCHESTRATOR_ANNOTATION_KEY: str(self.id), + } + + job_name = self._get_job_name( + settings, pipeline_name=snapshot.pipeline_configuration.name + ) + + with self._create_auth_secret_if_necessary( + snapshot, environment, orchestrator_pod_settings + ): + job_manifest = self._prepare_job_manifest( + name=job_name, + command=command, + args=args, + image=image, + environment=environment, + labels=labels, + annotations=annotations, + settings=settings, + pod_settings=orchestrator_pod_settings, ) - else: - try: + + if snapshot.schedule: + if not snapshot.schedule.cron_expression: + raise RuntimeError( + "The Kubernetes orchestrator only supports scheduling via " + "CRON jobs, but the run was configured with a manual " + "schedule. Use `Schedule(cron_expression=...)` instead." + ) + cron_expression = snapshot.schedule.cron_expression + cron_job_manifest = build_cron_job_manifest( + cron_expression=cron_expression, + job_template=job_template_manifest_from_job(job_manifest), + successful_jobs_history_limit=settings.successful_jobs_history_limit, + failed_jobs_history_limit=settings.failed_jobs_history_limit, + ) + + cron_job = self._k8s_batch_api.create_namespaced_cron_job( + body=cron_job_manifest, + namespace=self.config.kubernetes_namespace, + ) + logger.info( + f"Created Kubernetes CronJob `{cron_job.metadata.name}` " + f"with CRON expression `{cron_expression}`." + ) + return SubmissionResult( + metadata={ + KUBERNETES_CRON_JOB_METADATA_KEY: cron_job.metadata.name, + } + ) + else: kube_utils.create_job( batch_api=self._k8s_batch_api, namespace=self.config.kubernetes_namespace, job_manifest=job_manifest, ) - except Exception as e: - if self.config.pass_zenml_token_as_secret: - secret_name = self.get_token_secret_name(snapshot.id) - try: - kube_utils.delete_secret( + + if settings.synchronous: + + def _wait_for_run_to_finish() -> None: + logger.info( + "Waiting for orchestrator job to finish..." 
+ ) + kube_utils.wait_for_job_to_finish( + batch_api=self._k8s_batch_api, core_api=self._k8s_core_api, namespace=self.config.kubernetes_namespace, - secret_name=secret_name, - ) - except Exception as cleanup_error: - logger.error( - "Error cleaning up secret %s: %s", - secret_name, - cleanup_error, + job_name=job_name, + backoff_interval=settings.job_monitoring_interval, + fail_on_container_waiting_reasons=settings.fail_on_container_waiting_reasons, + stream_logs=True, ) - raise e - if settings.synchronous: - - def _wait_for_run_to_finish() -> None: - logger.info("Waiting for orchestrator job to finish...") - kube_utils.wait_for_job_to_finish( - batch_api=self._k8s_batch_api, - core_api=self._k8s_core_api, - namespace=self.config.kubernetes_namespace, - job_name=job_name, - backoff_interval=settings.job_monitoring_interval, - fail_on_container_waiting_reasons=settings.fail_on_container_waiting_reasons, - stream_logs=True, + return SubmissionResult( + wait_for_completion=_wait_for_run_to_finish, ) + else: + logger.info( + f"Orchestrator job `{job_name}` started. " + f"Run the following command to inspect the logs: " + f"`kubectl -n {self.config.kubernetes_namespace} logs " + f"job/{job_name}`" + ) + return None - return SubmissionResult( - wait_for_completion=_wait_for_run_to_finish, - ) - else: - logger.info( - f"Orchestrator job `{job_name}` started. " - f"Run the following command to inspect the logs: " - f"`kubectl -n {self.config.kubernetes_namespace} logs " - f"job/{job_name}`" - ) - return None + def launch_dynamic_step( + self, step_run_info: "StepRunInfo", environment: Dict[str, str] + ) -> None: + """Launches a dynamic step on Kubernetes. + + Args: + step_run_info: The step run information. + environment: The environment variables to set. + """ + from zenml.step_operators.step_operator_entrypoint_configuration import ( + StepOperatorEntrypointConfiguration, + ) + + logger.info( + "Launching job for step `%s`.", + step_run_info.pipeline_step_name, + ) + + settings = cast( + KubernetesOrchestratorSettings, self.get_settings(step_run_info) + ) + image = step_run_info.get_image(key=ORCHESTRATOR_DOCKER_IMAGE_KEY) + command = StepOperatorEntrypointConfiguration.get_entrypoint_command() + args = StepOperatorEntrypointConfiguration.get_entrypoint_arguments( + step_name=step_run_info.pipeline_step_name, + snapshot_id=(step_run_info.snapshot.id), + step_run_id=str(step_run_info.step_run_id), + ) + + labels = { + "pipeline": kube_utils.sanitize_label(step_run_info.pipeline.name), + "run_id": kube_utils.sanitize_label(str(step_run_info.run_id)), + "run_name": kube_utils.sanitize_label(str(step_run_info.run_name)), + "step_run_id": kube_utils.sanitize_label( + str(step_run_info.step_run_id) + ), + "step_name": kube_utils.sanitize_label( + step_run_info.pipeline_step_name + ), + } + annotations = { + STEP_NAME_ANNOTATION_KEY: step_run_info.pipeline_step_name, + } + + job_name = self._get_job_name( + settings, + pipeline_name=step_run_info.pipeline.name, + step_name=step_run_info.pipeline_step_name, + ) + + job_manifest = self._prepare_job_manifest( + name=job_name, + command=command, + args=args, + image=image, + environment=environment, + labels=labels, + annotations=annotations, + settings=settings, + pod_settings=settings.pod_settings, + ) + + kube_utils.create_job( + batch_api=self._k8s_batch_api, + namespace=self.config.kubernetes_namespace, + job_manifest=job_manifest, + ) + + logger.info( + "Waiting for job `%s` to finish...", + job_name, + ) + kube_utils.wait_for_job_to_finish( + 
batch_api=self._k8s_batch_api, + core_api=self._k8s_core_api, + namespace=self.config.kubernetes_namespace, + job_name=job_name, + fail_on_container_waiting_reasons=settings.fail_on_container_waiting_reasons, + stream_logs=True, + ) + logger.info("Job completed.") def _get_service_account_name( self, settings: KubernetesOrchestratorSettings @@ -685,20 +927,15 @@ def _get_service_account_name( def get_orchestrator_run_id(self) -> str: """Returns the active orchestrator run id. - Raises: - RuntimeError: If the environment variable specifying the run id - is not set. - Returns: The orchestrator run id. """ try: return os.environ[ENV_ZENML_KUBERNETES_RUN_ID] except KeyError: - raise RuntimeError( - "Unable to read run id from environment variable " - f"{ENV_ZENML_KUBERNETES_RUN_ID}." - ) + # This means we're in a dynamic pipeline orchestration container, + # so we use the hostname (= pod name) as the run id + return socket.gethostname() def _stop_run( self, run: "PipelineRunResponse", graceful: bool = True diff --git a/src/zenml/logging/step_logging.py b/src/zenml/logging/step_logging.py index 52c27aeafec..473212b904d 100644 --- a/src/zenml/logging/step_logging.py +++ b/src/zenml/logging/step_logging.py @@ -20,12 +20,13 @@ import re import threading import time -from contextlib import nullcontext +from contextlib import contextmanager, nullcontext from contextvars import ContextVar from datetime import datetime from types import TracebackType from typing import ( Any, + Generator, Iterator, List, Optional, @@ -930,3 +931,71 @@ def setup_orchestrator_logging( f"Failed to setup orchestrator logging for run {run_id}: {e}" ) return nullcontext() + + +@contextmanager +def setup_pipeline_logging( + source: str, + snapshot: "PipelineSnapshotResponse", + run_id: Optional[UUID] = None, + logs_response: Optional[LogsResponse] = None, +) -> Generator[Optional[LogsRequest], None, None]: + """Set up logging for a pipeline run. + + Args: + source: The log source. + snapshot: The snapshot of the pipeline run. + run_id: The ID of the pipeline run. + logs_response: The logs response to continue from. + + Raises: + Exception: If updating the run with the logs request fails. + + Yields: + The logs request. 
+    """
+    logging_enabled = True
+
+    if handle_bool_env_var(ENV_ZENML_DISABLE_PIPELINE_LOGS_STORAGE, False):
+        logging_enabled = False
+    elif snapshot.pipeline_configuration.enable_pipeline_logs is not None:
+        logging_enabled = snapshot.pipeline_configuration.enable_pipeline_logs
+
+    if logging_enabled:
+        client = Client()
+        artifact_store = client.active_stack.artifact_store
+        logs_model = None
+
+        if logs_response:
+            logs_uri = logs_response.uri
+        else:
+            logs_uri = prepare_logs_uri(
+                artifact_store=artifact_store,
+            )
+            logs_model = LogsRequest(
+                uri=logs_uri,
+                source=source,
+                artifact_store_id=artifact_store.id,
+            )
+
+            if run_id:
+                try:
+                    run_update = PipelineRunUpdate(add_logs=[logs_model])
+                    client.zen_store.update_run(
+                        run_id=run_id, run_update=run_update
+                    )
+                except Exception as e:
+                    logger.error(
+                        f"Failed to add logs to the run {run_id}: {e}"
+                    )
+                    raise e
+
+        logging_context = PipelineLogsStorageContext(
+            logs_uri=logs_uri,
+            artifact_store=artifact_store,
+            prepend_step_name=False,
+        )
+        with logging_context:
+            yield logs_model
+    else:
+        yield None
diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py
index fcc5a0f753a..dfee646f5ab 100644
--- a/src/zenml/models/v2/core/artifact_version.py
+++ b/src/zenml/models/v2/core/artifact_version.py
@@ -440,15 +440,28 @@ def run(self) -> "PipelineRunResponse":
 
         return Client().get_pipeline_run(self.step.pipeline_run_id)
 
-    def load(self) -> Any:
+    def load(self, disable_cache: bool = False) -> Any:
         """Materializes (loads) the data stored in this artifact.
 
+        Args:
+            disable_cache: If True, bypass the in-memory artifact cache: the
+                data is loaded from the artifact store and not cached.
+
         Returns:
             The materialized data.
         """
+        from zenml.artifacts.in_memory_cache import InMemoryArtifactCache
         from zenml.artifacts.utils import load_artifact_from_response
 
-        return load_artifact_from_response(self)
+        cache = InMemoryArtifactCache.get()
+
+        # Only consult the cache if it is active and caching was not
+        # explicitly disabled for this load.
+        if (
+            cache
+            and not disable_cache
+            and (data := cache.get_artifact_data(self.id))
+        ):
+            logger.debug("Returning artifact data (%s) from cache", self.id)
+            return data
+
+        data = load_artifact_from_response(self)
+        if cache and not disable_cache:
+            cache.set_artifact_data(self.id, data)
+        return data
 
     def download_files(self, path: str, overwrite: bool = False) -> None:
         """Downloads data for an artifact with no materializing.
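Note: a minimal usage sketch of the new in-memory artifact cache in combination with `load(disable_cache=...)`. This is illustrative and not part of the diff; it assumes `context_utils.BaseContext` makes `InMemoryArtifactCache` usable as a context manager (which the `PipelineCompilationContext.__enter__` implementation further down in this diff suggests) and uses the standard `Client.get_artifact_version` lookup:

    from zenml.artifacts.in_memory_cache import InMemoryArtifactCache
    from zenml.client import Client

    artifact = Client().get_artifact_version("my_dataset")

    with InMemoryArtifactCache():
        first = artifact.load()   # loads from the artifact store, fills the cache
        second = artifact.load()  # served from the in-memory cache
        fresh = artifact.load(disable_cache=True)  # bypasses the cache entirely

    # Outside the context, InMemoryArtifactCache.get() returns None and every
    # load() goes straight to the artifact store again.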
diff --git a/src/zenml/models/v2/core/pipeline_run.py b/src/zenml/models/v2/core/pipeline_run.py index 1f74a9e99e3..f852c4a7a83 100644 --- a/src/zenml/models/v2/core/pipeline_run.py +++ b/src/zenml/models/v2/core/pipeline_run.py @@ -178,6 +178,10 @@ class PipelineRunUpdate(BaseUpdate): max_length=STR_FIELD_MAX_LENGTH, ) end_time: Optional[datetime] = None + is_finished: Optional[bool] = Field( + default=None, + title="Whether the pipeline run is finished.", + ) orchestrator_run_id: Optional[str] = None # TODO: we should maybe have a different update model here, the upper # three attributes should only be for internal use diff --git a/src/zenml/models/v2/core/pipeline_snapshot.py b/src/zenml/models/v2/core/pipeline_snapshot.py index c56cea969d9..1f1b513aa0a 100644 --- a/src/zenml/models/v2/core/pipeline_snapshot.py +++ b/src/zenml/models/v2/core/pipeline_snapshot.py @@ -104,6 +104,10 @@ class PipelineSnapshotBase(BaseZenModel): default=None, title="The pipeline spec of the snapshot.", ) + is_dynamic: bool = Field( + default=False, + title="Whether this is a snapshot of a dynamic pipeline.", + ) @property def should_prevent_build_reuse(self) -> bool: @@ -237,6 +241,9 @@ class PipelineSnapshotResponseBody(ProjectScopedResponseBody): deployable: bool = Field( title="If the snapshot can be deployed.", ) + is_dynamic: bool = Field( + title="Whether this is a snapshot of a dynamic pipeline.", + ) class PipelineSnapshotResponseMetadata(ProjectScopedResponseMetadata): @@ -389,6 +396,15 @@ def deployable(self) -> bool: """ return self.get_body().deployable + @property + def is_dynamic(self) -> bool: + """The `is_dynamic` property. + + Returns: + the value of the property. + """ + return self.get_body().is_dynamic + @property def description(self) -> Optional[str]: """The `description` property. diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index 8c831fca4a9..1205fc97784 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -28,7 +28,7 @@ from pydantic import ConfigDict, Field -from zenml.config.step_configurations import StepConfiguration, StepSpec +from zenml.config.step_configurations import Step, StepConfiguration, StepSpec from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH from zenml.enums import ( ArtifactSaveType, @@ -153,6 +153,10 @@ class StepRunRequest(ProjectScopedRequest): default=None, title="The exception information of the step run.", ) + dynamic_config: Optional["Step"] = Field( + title="The dynamic configuration of the step run.", + default=None, + ) model_config = ConfigDict(protected_namespaces=()) diff --git a/src/zenml/orchestrators/base_orchestrator.py b/src/zenml/orchestrators/base_orchestrator.py index 5e7422ff6ce..5f6ea2ff914 100644 --- a/src/zenml/orchestrators/base_orchestrator.py +++ b/src/zenml/orchestrators/base_orchestrator.py @@ -13,7 +13,6 @@ # permissions and limitations under the License. 
"""Base orchestrator class.""" -import time from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, @@ -40,7 +39,6 @@ HookExecutionException, IllegalOperationError, RunMonitoringError, - RunStoppedException, ) from zenml.hooks.hook_validators import load_and_run_hook from zenml.logger import get_logger @@ -50,7 +48,6 @@ publish_pipeline_run_status_update, publish_schedule_metadata, ) -from zenml.orchestrators.step_launcher import StepLauncher from zenml.orchestrators.utils import get_config_environment_vars from zenml.stack import Flavor, Stack, StackComponent, StackComponentConfig from zenml.steps.step_context import RunContext, get_or_create_run_context @@ -59,6 +56,7 @@ if TYPE_CHECKING: from zenml.config.step_configurations import Step + from zenml.config.step_run_info import StepRunInfo from zenml.models import ( PipelineRunResponse, PipelineSnapshotResponse, @@ -208,6 +206,27 @@ def submit_pipeline( """ return None + def submit_dynamic_pipeline( + self, + snapshot: "PipelineSnapshotResponse", + stack: "Stack", + environment: Dict[str, str], + placeholder_run: Optional["PipelineRunResponse"] = None, + ) -> Optional[SubmissionResult]: + """Submits a dynamic pipeline to the orchestrator. + + Args: + snapshot: The pipeline snapshot to submit. + stack: The stack the pipeline will run on. + environment: Environment variables to set in the orchestration + environment. + placeholder_run: An optional placeholder run. + + Returns: + Optional submission result. + """ + return None + def prepare_or_run_pipeline( self, deployment: "PipelineSnapshotResponse", @@ -271,6 +290,7 @@ def run( placeholder_run and self.config.supports_client_side_caching and not snapshot.schedule + and not snapshot.is_dynamic and not prevent_client_side_caching ): from zenml.orchestrators import cache_utils @@ -289,22 +309,10 @@ def run( else: logger.debug("Skipping client-side caching.") - step_environments = {} - for invocation_id, step in snapshot.step_configurations.items(): - from zenml.utils.env_utils import get_step_environment - - step_environment = get_step_environment( - step_config=step.config, - stack=stack, - ) - - combined_environment = base_environment.copy() - combined_environment.update(step_environment) - step_environments[invocation_id] = combined_environment - try: if ( - getattr(self.submit_pipeline, "__func__", None) + not snapshot.is_dynamic + and getattr(self.submit_pipeline, "__func__", None) is BaseOrchestrator.submit_pipeline ): logger.warning( @@ -336,13 +344,37 @@ def run( f"run metadata: {e}" ) else: - submission_result = self.submit_pipeline( - snapshot=snapshot, - stack=stack, - base_environment=base_environment, - step_environments=step_environments, - placeholder_run=placeholder_run, - ) + if snapshot.is_dynamic: + submission_result = self.submit_dynamic_pipeline( + snapshot=snapshot, + stack=stack, + environment=base_environment, + placeholder_run=placeholder_run, + ) + else: + step_environments = {} + for ( + invocation_id, + step, + ) in snapshot.step_configurations.items(): + from zenml.utils.env_utils import get_step_environment + + step_environment = get_step_environment( + step_config=step.config, + stack=stack, + ) + + combined_environment = base_environment.copy() + combined_environment.update(step_environment) + step_environments[invocation_id] = combined_environment + + submission_result = self.submit_pipeline( + snapshot=snapshot, + stack=stack, + base_environment=base_environment, + step_environments=step_environments, + placeholder_run=placeholder_run, + ) 
if placeholder_run: publish_pipeline_run_status_update( pipeline_run_id=placeholder_run.id, @@ -401,59 +433,60 @@ def run_step( Args: step: The step to run. + """ + from zenml.execution.step.utils import launch_step - Raises: - RunStoppedException: If the run was stopped. - BaseException: If the step failed all retries. + assert self._active_snapshot + + launch_step( + snapshot=self._active_snapshot, + step=step, + orchestrator_run_id=self.get_orchestrator_run_id(), + retry=not self.config.handles_step_retries, + ) + + @property + def supports_dynamic_pipelines(self) -> bool: + """Whether the orchestrator supports dynamic pipelines. + + Returns: + Whether the orchestrator supports dynamic pipelines. """ + return ( + getattr(self.submit_dynamic_pipeline, "__func__", None) + is not BaseOrchestrator.submit_dynamic_pipeline + ) - def _launch_step() -> None: - assert self._active_snapshot + @property + def can_launch_dynamic_steps(self) -> bool: + """Whether the orchestrator can launch dynamic steps. - launcher = StepLauncher( - snapshot=self._active_snapshot, - step=step, - orchestrator_run_id=self.get_orchestrator_run_id(), - ) - launcher.launch() + Returns: + Whether the orchestrator can launch dynamic steps. + """ + return ( + getattr(self.launch_dynamic_step, "__func__", None) + is not BaseOrchestrator.launch_dynamic_step + ) - if self.config.handles_step_retries: - _launch_step() - else: - # The orchestrator subclass doesn't handle step retries, so we - # handle it in-process instead - retries = 0 - retry_config = step.config.retry - max_retries = retry_config.max_retries if retry_config else 0 - delay = retry_config.delay if retry_config else 0 - backoff = retry_config.backoff if retry_config else 1 - - while retries <= max_retries: - try: - _launch_step() - except RunStoppedException: - # Don't retry if the run was stopped - raise - except BaseException: - retries += 1 - if retries <= max_retries: - logger.info( - "Sleeping for %d seconds before retrying step `%s`.", - delay, - step.config.name, - ) - time.sleep(delay) - delay *= backoff - else: - if max_retries > 0: - logger.error( - "Failed to run step `%s` after %d retries.", - step.config.name, - max_retries, - ) - raise - else: - break + def launch_dynamic_step( + self, step_run_info: "StepRunInfo", environment: Dict[str, str] + ) -> None: + """Launch a dynamic step. + + Args: + step_run_info: The step run information. + environment: The environment variables to set in the execution + environment. + + Raises: + NotImplementedError: If the orchestrator does not implement this + method. + """ + raise NotImplementedError( + "Launching dynamic steps is not implemented for " + f"the {self.__class__.__name__} orchestrator." + ) @staticmethod def requires_resources_in_orchestration_environment( diff --git a/src/zenml/orchestrators/containerized_orchestrator.py b/src/zenml/orchestrators/containerized_orchestrator.py index 80d50072463..175a0702f14 100644 --- a/src/zenml/orchestrators/containerized_orchestrator.py +++ b/src/zenml/orchestrators/containerized_orchestrator.py @@ -82,7 +82,9 @@ def should_build_pipeline_image( Returns: Whether to build the pipeline image. """ - return False + # When running a dynamic pipeline, we need an image for the + # orchestration container. 
+ return snapshot.is_dynamic def get_docker_builds( self, snapshot: "PipelineSnapshotBase" diff --git a/src/zenml/orchestrators/local/local_orchestrator.py b/src/zenml/orchestrators/local/local_orchestrator.py index e45c4c1918d..361b4dc0b46 100644 --- a/src/zenml/orchestrators/local/local_orchestrator.py +++ b/src/zenml/orchestrators/local/local_orchestrator.py @@ -96,13 +96,6 @@ def submit_pipeline( step. RuntimeError: If the pipeline run fails. """ - if snapshot.schedule: - logger.warning( - "Local Orchestrator currently does not support the " - "use of schedules. The `schedule` will be ignored " - "and the pipeline will be run immediately." - ) - self._orchestrator_run_id = str(uuid4()) start_time = time.time() @@ -193,6 +186,44 @@ def submit_pipeline( self._orchestrator_run_id = None return None + def submit_dynamic_pipeline( + self, + snapshot: "PipelineSnapshotResponse", + stack: "Stack", + environment: Dict[str, str], + placeholder_run: Optional["PipelineRunResponse"] = None, + ) -> Optional[SubmissionResult]: + """Submits a dynamic pipeline to the orchestrator. + + Args: + snapshot: The pipeline snapshot to submit. + stack: The stack the pipeline will run on. + environment: Environment variables to set in the orchestration + environment. + placeholder_run: An optional placeholder run. + + Returns: + Optional submission result. + """ + from zenml.execution.pipeline.dynamic.runner import ( + DynamicPipelineRunner, + ) + + self._orchestrator_run_id = str(uuid4()) + start_time = time.time() + + runner = DynamicPipelineRunner(snapshot=snapshot, run=placeholder_run) + with temporary_environment(environment): + runner.run_pipeline() + + run_duration = time.time() - start_time + logger.info( + "Pipeline run has finished in `%s`.", + string_utils.get_human_readable_time(run_duration), + ) + self._orchestrator_run_id = None + return None + def get_orchestrator_run_id(self) -> str: """Returns the active orchestrator run id. diff --git a/src/zenml/orchestrators/output_utils.py b/src/zenml/orchestrators/output_utils.py index 7e9ef2edb9f..9ffe0097368 100644 --- a/src/zenml/orchestrators/output_utils.py +++ b/src/zenml/orchestrators/output_utils.py @@ -111,6 +111,8 @@ def remove_artifact_dirs(artifact_uris: Sequence[str]) -> None: Args: artifact_uris: URIs of the artifacts to remove the directories for. """ + # TODO: maybe keep non-empty dirs here, if the step saved intermediate + # checkpoints? artifact_store = Client().active_stack.artifact_store for artifact_uri in artifact_uris: if artifact_store.isdir(artifact_uri): diff --git a/src/zenml/orchestrators/publish_utils.py b/src/zenml/orchestrators/publish_utils.py index 8956852f934..e2d8b7072fd 100644 --- a/src/zenml/orchestrators/publish_utils.py +++ b/src/zenml/orchestrators/publish_utils.py @@ -118,6 +118,27 @@ def publish_failed_step_run(step_run_id: "UUID") -> "StepRunResponse": ) +def publish_successful_pipeline_run( + pipeline_run_id: "UUID", +) -> "PipelineRunResponse": + """Publishes a successful pipeline run. + + Args: + pipeline_run_id: The ID of the pipeline run to update. + + Returns: + The updated pipeline run. 
+ """ + return Client().zen_store.update_run( + run_id=pipeline_run_id, + run_update=PipelineRunUpdate( + status=ExecutionStatus.COMPLETED, + end_time=utc_now(), + is_finished=True, + ), + ) + + def publish_failed_pipeline_run( pipeline_run_id: "UUID", ) -> "PipelineRunResponse": @@ -134,6 +155,7 @@ def publish_failed_pipeline_run( run_update=PipelineRunUpdate( status=ExecutionStatus.FAILED, end_time=utc_now(), + is_finished=True, ), ) @@ -179,6 +201,7 @@ def get_pipeline_run_status( run_status: ExecutionStatus, step_statuses: List[ExecutionStatus], num_steps: int, + is_dynamic_pipeline: bool, ) -> ExecutionStatus: """Gets the pipeline run status for the given step statuses. @@ -186,6 +209,7 @@ def get_pipeline_run_status( run_status: The status of the run. step_statuses: The status of steps in this run. num_steps: The total amount of steps in this run. + is_dynamic_pipeline: If the pipeline is dynamic. Returns: The run status. @@ -217,11 +241,10 @@ def get_pipeline_run_status( or ExecutionStatus.RETRYING in step_statuses ): return ExecutionStatus.RUNNING - - # If there are less steps than the total number of steps, it is running + elif is_dynamic_pipeline: + return run_status elif len(step_statuses) < num_steps: return ExecutionStatus.RUNNING - # Any other state is completed else: return ExecutionStatus.COMPLETED diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index eacb6fd24d1..820d6795d1e 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -246,12 +246,15 @@ def signal_handler(signum: int, frame: Any) -> None: logger.debug(f"Cannot register signal handlers: {e}") # Continue without signal handling - the step will still run - def launch(self) -> None: + def launch(self) -> StepRunResponse: """Launches the step. Raises: RunStoppedException: If the pipeline run is stopped by the user. BaseException: If the step preparation or execution fails. + + Returns: + The step run response. """ publish_utils.step_exception_info.set(None) pipeline_run, run_was_created = self._create_or_reuse_run() @@ -304,8 +307,10 @@ def launch(self) -> None: pipeline_run=pipeline_run, stack=self._stack, ) + dynamic_config = self._step if self._snapshot.is_dynamic else None step_run_request = request_factory.create_request( - invocation_id=self._invocation_id + invocation_id=self._invocation_id, + dynamic_config=dynamic_config, ) step_run_request.logs = logs_model @@ -377,6 +382,8 @@ def _bypass() -> None: model_version=model_version, ) + return step_run + def _create_or_reuse_run(self) -> Tuple[PipelineRunResponse, bool]: """Creates a pipeline run or reuses an existing one. 
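Note: a worked illustration (not part of the diff) of the `is_dynamic_pipeline` branch added to `get_pipeline_run_status` above. For dynamic pipelines the total number of steps is not known up front, so seeing fewer step runs than `num_steps` no longer implies the run is still in progress; instead, the run keeps its current status until it is explicitly marked as finished via `publish_successful_pipeline_run` or `publish_failed_pipeline_run`. Simplified, assuming no failed, stopped, or retrying steps:

    from zenml.enums import ExecutionStatus
    from zenml.orchestrators.publish_utils import get_pipeline_run_status

    steps_so_far = [ExecutionStatus.COMPLETED, ExecutionStatus.COMPLETED]

    # Static pipeline with 3 declared steps: only 2 step runs exist yet,
    # so the run is considered to still be running.
    get_pipeline_run_status(
        run_status=ExecutionStatus.RUNNING,
        step_statuses=steps_so_far,
        num_steps=3,
        is_dynamic_pipeline=False,
    )  # -> ExecutionStatus.RUNNING

    # Dynamic pipeline: more steps may still be spawned at runtime, so the
    # current run status is returned unchanged instead of inferring
    # completion from the step count.
    get_pipeline_run_status(
        run_status=ExecutionStatus.RUNNING,
        step_statuses=steps_so_far,
        num_steps=2,
        is_dynamic_pipeline=True,
    )  # -> ExecutionStatus.RUNNING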
@@ -427,12 +434,14 @@ def _run_step(
 
         step_run_info = StepRunInfo(
             config=self._step.config,
+            spec=self._step.spec,
             pipeline=self._snapshot.pipeline_configuration,
             run_name=pipeline_run.name,
             pipeline_step_name=self._invocation_id,
             run_id=pipeline_run.id,
             step_run_id=step_run.id,
             force_write_logs=force_write_logs,
+            snapshot=self._snapshot,
         )
 
         output_artifact_uris = output_utils.prepare_output_artifact_uris(
@@ -453,14 +462,44 @@ def _run_step(
                 step_operator_name=step_operator_name,
                 step_run_info=step_run_info,
             )
-        else:
-            self._run_step_without_step_operator(
+        elif not self._snapshot.is_dynamic:
+            self._run_step_in_current_thread(
                 pipeline_run=pipeline_run,
                 step_run=step_run,
                 step_run_info=step_run_info,
                 input_artifacts=step_run.regular_inputs,
                 output_artifact_uris=output_artifact_uris,
             )
+        else:
+            from zenml.execution.pipeline.dynamic.runner import (
+                should_run_in_process,
+            )
+
+            if should_run_in_process(
+                self._step,
+                self._snapshot.pipeline_configuration.docker_settings,
+            ):
+                if self._step.config.in_process is False:
+                    # The step was configured to run out of process, but
+                    # the orchestrator doesn't support it.
+                    logger.warning(
+                        "The %s does not support running dynamic steps out "
+                        "of process. Running step `%s` locally instead.",
+                        self._stack.orchestrator.__class__.__name__,
+                        self._invocation_id,
+                    )
+
+                self._run_step_in_current_thread(
+                    pipeline_run=pipeline_run,
+                    step_run=step_run,
+                    step_run_info=step_run_info,
+                    input_artifacts=step_run.regular_inputs,
+                    output_artifact_uris=output_artifact_uris,
+                )
+            else:
+                self._run_step_with_dynamic_orchestrator(
+                    step_run_info=step_run_info
+                )
         except:  # noqa: E722
             output_utils.remove_artifact_dirs(
                 artifact_uris=list(output_artifact_uris.values())
@@ -522,7 +561,27 @@ def _run_step_with_step_operator(
             environment=environment,
         )
 
-    def _run_step_without_step_operator(
+    def _run_step_with_dynamic_orchestrator(
+        self,
+        step_run_info: StepRunInfo,
+    ) -> None:
+        """Run a dynamic step by delegating it to the orchestrator.
+
+        Args:
+            step_run_info: Information about the step run.
+        """
+        environment, secrets = orchestrator_utils.get_config_environment_vars(
+            pipeline_run_id=step_run_info.run_id,
+        )
+        environment.update(secrets)
+
+        environment.update(
+            env_utils.get_step_environment(
+                step_config=step_run_info.config,
+                stack=self._stack,
+            )
+        )
+        self._stack.orchestrator.launch_dynamic_step(
+            step_run_info=step_run_info,
+            environment=environment,
+        )
+
+    def _run_step_in_current_thread(
         self,
         pipeline_run: PipelineRunResponse,
         step_run: StepRunResponse,
diff --git a/src/zenml/orchestrators/step_run_utils.py b/src/zenml/orchestrators/step_run_utils.py
index ea892203a6f..c3827f04e97 100644
--- a/src/zenml/orchestrators/step_run_utils.py
+++ b/src/zenml/orchestrators/step_run_utils.py
@@ -76,7 +76,9 @@ def has_caching_enabled(self, invocation_id: str) -> bool:
             is_enabled_on_pipeline=self.snapshot.pipeline_configuration.enable_cache,
         )
 
-    def create_request(self, invocation_id: str) -> StepRunRequest:
+    def create_request(
+        self, invocation_id: str, dynamic_config: Optional[Step] = None
+    ) -> StepRunRequest:
         """Create a step run request.
 
         This will only create a request with basic information and will not yet
@@ -86,6 +88,7 @@ def create_request(self, invocation_id: str) -> StepRunRequest:
 
         Args:
             invocation_id: The invocation ID for which to create the request.
+            dynamic_config: The dynamic configuration for the step.
 
         Returns:
             The step run request.
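Note: an illustrative sketch (not part of the diff) of how the `_run_step` routing above plays out for a dynamic step. `in_process` is a regular step configuration field in this PR, so configuring it via `with_options` is the assumed spelling; whether the out-of-process branch is actually taken also depends on `should_run_in_process` and on the active orchestrator implementing `launch_dynamic_step`:

    from zenml import step


    @step
    def train(lr: float) -> float:
        return lr * 2


    # Assumed spelling: request out-of-process execution for this step. On an
    # orchestrator that implements `launch_dynamic_step` (e.g. the Kubernetes
    # orchestrator in this diff), the step then runs as its own job; otherwise
    # the launcher logs a warning and falls back to running the step in the
    # current thread of the orchestration container.
    remote_train = train.with_options(in_process=False)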
@@ -96,6 +99,7 @@ def create_request(self, invocation_id: str) -> StepRunRequest: status=ExecutionStatus.RUNNING, start_time=utc_now(), project=Client().active_project.id, + dynamic_config=dynamic_config, ) def populate_request( @@ -111,7 +115,10 @@ def populate_request( input resolution. This will be updated in-place with newly fetched step runs. """ - step = self.snapshot.step_configurations[request.name] + step = ( + request.dynamic_config + or self.snapshot.step_configurations[request.name] + ) input_artifacts = input_utils.resolve_step_inputs( step=step, @@ -139,7 +146,7 @@ def populate_request( ( docstring, source_code, - ) = self._get_docstring_and_source_code(invocation_id=request.name) + ) = self._get_docstring_and_source_code(step=step) request.docstring = docstring request.source_code = source_code @@ -185,19 +192,16 @@ def populate_request( request.docstring = cached_step_run.docstring def _get_docstring_and_source_code( - self, invocation_id: str + self, step: "Step" ) -> Tuple[Optional[str], Optional[str]]: """Get the docstring and source code for the step. Args: - invocation_id: The step invocation ID for which to get the - docstring and source code. + step: The step for which to get the docstring and source code. Returns: The docstring and source code of the step. """ - step = self.snapshot.step_configurations[invocation_id] - try: return self._get_docstring_and_source_code_from_step_instance( step=step @@ -210,7 +214,7 @@ def _get_docstring_and_source_code( # We now try to fetch the docstring/source code from a step run of the # snapshot that was used to create the template return self._try_to_get_docstring_and_source_code_from_template( - invocation_id=invocation_id + invocation_id=step.spec.invocation_id ) @staticmethod diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index eb67b6f2a63..2ee2677a088 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -184,8 +184,6 @@ def run( self._stack.prepare_step_run(info=step_run_info) - # Initialize the step context singleton - StepContext._clear() step_context = StepContext( pipeline_run=pipeline_run, step_run=step_run, @@ -196,74 +194,78 @@ def run( }, ) - # Parse the inputs for the entrypoint function. - function_params = self._parse_inputs( - args=spec.args, - annotations=spec.annotations, - input_artifacts=input_artifacts, - ) + with step_context: + function_params = self._parse_inputs( + args=spec.args, + annotations=spec.annotations, + input_artifacts=input_artifacts, + ) - # Get all step environment variables. For most orchestrators, the - # non-secret environment variables have been set before by the - # orchestrator. But for some orchestrators, this is not possible and - # we therefore make sure to set them here so they're at least - # available for the user code. - step_environment = env_utils.get_step_environment( - step_config=step_run.config, stack=self._stack - ) - secret_environment = env_utils.get_step_secret_environment( - step_config=step_run.config, stack=self._stack - ) - step_environment.update(secret_environment) - - step_failed = False - try: - if ( - pipeline_run.snapshot - and self._stack.orchestrator.run_init_cleanup_at_step_level - ): - self._stack.orchestrator.run_init_hook( - snapshot=pipeline_run.snapshot - ) + # Get all step environment variables. For most orchestrators, the + # non-secret environment variables have been set before by the + # orchestrator. 
But for some orchestrators, this is not possible and + # we therefore make sure to set them here so they're at least + # available for the user code. + step_environment = env_utils.get_step_environment( + step_config=step_run.config, stack=self._stack + ) + secret_environment = env_utils.get_step_secret_environment( + step_config=step_run.config, stack=self._stack + ) + step_environment.update(secret_environment) - with env_utils.temporary_environment(step_environment): - return_values = step_instance.call_entrypoint( - **function_params - ) - except BaseException as step_exception: # noqa: E722 - step_failed = True + step_failed = False + try: + if ( + # TODO: do we need to disable this for dynamic pipelines? + pipeline_run.snapshot + and self._stack.orchestrator.run_init_cleanup_at_step_level + ): + self._stack.orchestrator.run_init_hook( + snapshot=pipeline_run.snapshot + ) - exception_info = exception_utils.collect_exception_information( - step_exception, step_instance - ) + with env_utils.temporary_environment(step_environment): + return_values = step_instance.call_entrypoint( + **function_params + ) + except BaseException as step_exception: # noqa: E722 + step_failed = True - if ENV_ZENML_STEP_OPERATOR in os.environ: - # We're running in a step operator environment, so we can't - # depend on the step launcher to publish the exception info - Client().zen_store.update_run_step( - step_run_id=step_run_info.step_run_id, - step_run_update=StepRunUpdate( - exception_info=exception_info, - ), + exception_info = ( + exception_utils.collect_exception_information( + step_exception, step_instance + ) ) - else: - # This will be published by the step launcher - step_exception_info.set(exception_info) - if not step_run.is_retriable: - if ( - failure_hook_source - := self.configuration.failure_hook_source - ): - logger.info("Detected failure hook. Running...") - with env_utils.temporary_environment(step_environment): - load_and_run_hook( - failure_hook_source, - step_exception=step_exception, - ) - raise - finally: - try: + if ENV_ZENML_STEP_OPERATOR in os.environ: + # We're running in a step operator environment, so we can't + # depend on the step launcher to publish the exception info + Client().zen_store.update_run_step( + step_run_id=step_run_info.step_run_id, + step_run_update=StepRunUpdate( + exception_info=exception_info, + ), + ) + else: + # This will be published by the step launcher + step_exception_info.set(exception_info) + + if not step_run.is_retriable: + if ( + failure_hook_source + := self.configuration.failure_hook_source + ): + logger.info("Detected failure hook. Running...") + with env_utils.temporary_environment( + step_environment + ): + load_and_run_hook( + failure_hook_source, + step_exception=step_exception, + ) + raise + finally: step_run_metadata = self._stack.get_step_run_metadata( info=step_run_info, ) @@ -337,12 +339,6 @@ def run( snapshot=pipeline_run.snapshot ) - finally: - step_context._cleanup_registry.execute_callbacks( - raise_on_exception=False - ) - StepContext._clear() # Remove the step context singleton - # Update the status and output artifacts of the step run. output_artifact_ids = { output_name: [ diff --git a/src/zenml/pipelines/compilation_context.py b/src/zenml/pipelines/compilation_context.py new file mode 100644 index 00000000000..eecaa0b6788 --- /dev/null +++ b/src/zenml/pipelines/compilation_context.py @@ -0,0 +1,68 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Pipeline compilation context.""" + +import contextvars +from typing import TYPE_CHECKING + +from typing_extensions import Self + +from zenml.utils import context_utils + +if TYPE_CHECKING: + from zenml.pipelines.pipeline_definition import Pipeline + + +class PipelineCompilationContext(context_utils.BaseContext): + """Pipeline compilation context.""" + + __context_var__ = contextvars.ContextVar("pipeline_compilation_context") + + def __init__( + self, + pipeline: "Pipeline", + ) -> None: + """Initialize the pipeline compilation context. + + Args: + pipeline: The pipeline that is being compiled. + """ + super().__init__() + self._pipeline = pipeline + + @property + def pipeline(self) -> "Pipeline": + """The pipeline that is being compiled. + + Returns: + The pipeline that is being compiled. + """ + return self._pipeline + + def __enter__(self) -> Self: + """Enter the pipeline compilation context. + + Raises: + RuntimeError: If the pipeline compilation context has already been + entered. + + Returns: + The pipeline compilation context object. + """ + if self._token is not None: + raise RuntimeError( + "Compiling a pipeline while another pipeline is being compiled " + "is not allowed." + ) + return super().__enter__() diff --git a/src/zenml/pipelines/dynamic/entrypoint_configuration.py b/src/zenml/pipelines/dynamic/entrypoint_configuration.py new file mode 100644 index 00000000000..9c039466488 --- /dev/null +++ b/src/zenml/pipelines/dynamic/entrypoint_configuration.py @@ -0,0 +1,74 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Entrypoint configuration to run a dynamic pipeline.""" + +from typing import Any, Dict, List +from uuid import UUID + +from zenml.client import Client +from zenml.entrypoints.base_entrypoint_configuration import ( + BaseEntrypointConfiguration, +) +from zenml.execution.pipeline.dynamic.runner import DynamicPipelineRunner +from zenml.integrations.registry import integration_registry + +RUN_ID_OPTION = "run_id" + + +class DynamicPipelineEntrypointConfiguration(BaseEntrypointConfiguration): + """Entrypoint configuration to run a dynamic pipeline.""" + + @classmethod + def get_entrypoint_options(cls) -> Dict[str, bool]: + """Gets all options required for running with this configuration. + + Returns: + All options required for running with this configuration. 
+ """ + return super().get_entrypoint_options() | {RUN_ID_OPTION: False} + + @classmethod + def get_entrypoint_arguments( + cls, + **kwargs: Any, + ) -> List[str]: + """Gets all arguments that the entrypoint command should be called with. + + Args: + **kwargs: Keyword arguments. + + Returns: + All arguments that the entrypoint command should be called with. + """ + args = super().get_entrypoint_arguments(**kwargs) + if run_id := kwargs.get(RUN_ID_OPTION, None): + args.extend([f"--{RUN_ID_OPTION}", str(run_id)]) + return args + + def run(self) -> None: + """Prepares the environment and runs the configured dynamic pipeline.""" + snapshot = self.snapshot + + # Activate all the integrations. This makes sure that all materializers + # and stack component flavors are registered. + integration_registry.activate_integrations() + + self.download_code_if_necessary() + + run = None + if run_id := self.entrypoint_args.get(RUN_ID_OPTION, None): + run = Client().get_pipeline_run(UUID(run_id)) + + runner = DynamicPipelineRunner(snapshot=snapshot, run=run) + runner.run_pipeline() diff --git a/src/zenml/pipelines/dynamic/pipeline_definition.py b/src/zenml/pipelines/dynamic/pipeline_definition.py new file mode 100644 index 00000000000..1e00a7b10e6 --- /dev/null +++ b/src/zenml/pipelines/dynamic/pipeline_definition.py @@ -0,0 +1,184 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Dynamic pipeline definition.""" + +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Type, +) + +from pydantic import BaseModel, ConfigDict, create_model + +from zenml.client import Client +from zenml.execution.pipeline.utils import ( + should_prevent_pipeline_execution, +) +from zenml.logger import get_logger +from zenml.models import PipelineRunResponse +from zenml.pipelines.pipeline_definition import Pipeline +from zenml.steps.utils import ( + parse_return_type_annotations, +) + +if TYPE_CHECKING: + from zenml.steps import BaseStep + +logger = get_logger(__name__) + + +class DynamicPipeline(Pipeline): + """Dynamic pipeline class.""" + + def __init__( + self, + *args: Any, + depends_on: Optional[List["BaseStep"]] = None, + **kwargs: Any, + ) -> None: + """Initialize the pipeline. + + Args: + *args: Pipeline constructor arguments. + depends_on: The steps that the pipeline depends on. + **kwargs: Pipeline constructor keyword arguments. + """ + super().__init__(*args, **kwargs) + self._depends_on = depends_on or [] + self._validate_depends_on(self._depends_on) + + def _validate_depends_on(self, depends_on: List["BaseStep"]) -> None: + """Validates the steps that the pipeline depends on. + + Args: + depends_on: The steps that the pipeline depends on. + + Raises: + RuntimeError: If some of the steps in `depends_on` are duplicated. + """ + static_ids = set() + for step in depends_on: + static_id = step._static_id + if static_id in static_ids: + raise RuntimeError( + f"The pipeline {self.name} depends on the same step " + f"({step.name}) multiple times. 
To fix this, remove the "
+                    "duplicate from the `depends_on` list. You can pass the "
+                    "same step function with multiple configurations by using "
+                    "the `step.with_options(...)` method."
+                )
+
+            static_ids.add(static_id)
+
+    @property
+    def depends_on(self) -> List["BaseStep"]:
+        """The steps that the pipeline depends on.
+
+        Returns:
+            The steps that the pipeline depends on.
+        """
+        return self._depends_on
+
+    @property
+    def is_dynamic(self) -> bool:
+        """If the pipeline is dynamic.
+
+        Returns:
+            If the pipeline is dynamic.
+        """
+        return True
+
+    def _prepare_invocations(self, **kwargs: Any) -> None:
+        """Prepares the invocations of the pipeline.
+
+        Args:
+            **kwargs: Keyword arguments.
+        """
+        for step in self._depends_on:
+            self.add_step_invocation(
+                step,
+                input_artifacts={},
+                external_artifacts={},
+                model_artifacts_or_metadata={},
+                client_lazy_loaders={},
+                parameters={},
+                default_parameters={},
+                upstream_steps=set(),
+            )
+
+    def __call__(
+        self, *args: Any, **kwargs: Any
+    ) -> Optional[PipelineRunResponse]:
+        """Run the pipeline on the active stack.
+
+        Args:
+            *args: Entrypoint function arguments.
+            **kwargs: Entrypoint function keyword arguments.
+
+        Raises:
+            RuntimeError: If the active orchestrator does not support running
+                dynamic pipelines.
+
+        Returns:
+            The pipeline run or `None` if running with a schedule.
+        """
+        if should_prevent_pipeline_execution():
+            logger.info("Preventing execution of pipeline '%s'.", self.name)
+            return None
+
+        stack = Client().active_stack
+        if not stack.orchestrator.supports_dynamic_pipelines:
+            raise RuntimeError(
+                f"The {stack.orchestrator.__class__.__name__} does not "
+                "support dynamic pipelines."
+            )
+
+        logger.warning(
+            "Dynamic pipelines are currently in an experimental stage. There "
+            "might be missing features or bugs, and the interface is subject "
+            "to change. If you encounter any issues or have feedback, please "
+            "let us know at https://github.com/zenml-io/zenml/issues."
+        )
+
+        self.prepare(*args, **kwargs)
+        return self._run()
+
+    def _compute_output_schema(self) -> Optional[Dict[str, Any]]:
+        """Computes the output schema for the pipeline.
+
+        Returns:
+            The output schema for the pipeline.
+        """
+        try:
+            outputs = parse_return_type_annotations(self.entrypoint)
+            model_fields: Dict[str, Any] = {
+                name: (output.resolved_annotation, ...)
+                for name, output in outputs.items()
+            }
+            output_model: Type[BaseModel] = create_model(
+                "PipelineOutput",
+                __config__=ConfigDict(extra="forbid"),
+                **model_fields,
+            )
+            return output_model.model_json_schema(mode="serialization")
+        except Exception as e:
+            logger.debug(
+                f"Failed to generate the output schema for pipeline "
+                f"`{self.name}`: {e}. This means that the pipeline cannot be "
+                "deployed.",
+            )
+            return None
diff --git a/src/zenml/pipelines/pipeline_context.py b/src/zenml/pipelines/pipeline_context.py
index ef92da92d46..4110e4fc635 100644
--- a/src/zenml/pipelines/pipeline_context.py
+++ b/src/zenml/pipelines/pipeline_context.py
@@ -33,9 +33,11 @@ def get_pipeline_context() -> "PipelineContext":
         RuntimeError: If no active pipeline is found.
         RuntimeError: If inside a running step.
""" - from zenml.pipelines.pipeline_definition import Pipeline + from zenml.pipelines.compilation_context import PipelineCompilationContext - if Pipeline.ACTIVE_PIPELINE is None: + context = PipelineCompilationContext.get() + + if context is None: try: from zenml.steps.step_context import get_step_context @@ -49,7 +51,7 @@ def get_pipeline_context() -> "PipelineContext": ) return PipelineContext( - pipeline_configuration=Pipeline.ACTIVE_PIPELINE.configuration + pipeline_configuration=context.pipeline.configuration ) diff --git a/src/zenml/pipelines/pipeline_decorator.py b/src/zenml/pipelines/pipeline_decorator.py index 97bd5c07b93..a07813d924c 100644 --- a/src/zenml/pipelines/pipeline_decorator.py +++ b/src/zenml/pipelines/pipeline_decorator.py @@ -35,6 +35,7 @@ from zenml.config.retry_config import StepRetryConfig from zenml.model.model import Model from zenml.pipelines.pipeline_definition import Pipeline + from zenml.steps.base_step import BaseStep from zenml.types import HookSpecification, InitHookSpecification from zenml.utils.tag_utils import Tag @@ -51,6 +52,8 @@ def pipeline(_func: "F") -> "Pipeline": ... def pipeline( *, name: Optional[str] = None, + dynamic: Optional[bool] = None, + depends_on: Optional[List["BaseStep"]] = None, enable_cache: Optional[bool] = None, enable_artifact_metadata: Optional[bool] = None, enable_step_logs: Optional[bool] = None, @@ -77,6 +80,8 @@ def pipeline( _func: Optional["F"] = None, *, name: Optional[str] = None, + dynamic: Optional[bool] = None, + depends_on: Optional[List["BaseStep"]] = None, enable_cache: Optional[bool] = None, enable_artifact_metadata: Optional[bool] = None, enable_step_logs: Optional[bool] = None, @@ -103,6 +108,8 @@ def pipeline( _func: The decorated function. name: The name of the pipeline. If left empty, the name of the decorated function will be used as a fallback. + dynamic: Whether this is a dynamic pipeline or not. + depends_on: The steps that this pipeline depends on. enable_cache: Whether to use caching or not. enable_artifact_metadata: Whether to enable artifact metadata or not. enable_step_logs: If step logs should be enabled for this pipeline. @@ -140,7 +147,26 @@ def pipeline( def inner_decorator(func: "F") -> "Pipeline": from zenml.pipelines.pipeline_definition import Pipeline - p = Pipeline( + PipelineClass = Pipeline + pipeline_args: Dict[str, Any] = {} + + if dynamic: + from zenml.pipelines.dynamic.pipeline_definition import ( + DynamicPipeline, + ) + + PipelineClass = DynamicPipeline + + pipeline_args = { + "depends_on": depends_on, + } + elif depends_on: + logger.warning( + "The `depends_on` argument is not supported " + "for static pipelines and will be ignored." 
+ ) + + p = PipelineClass( name=name or func.__name__, entrypoint=func, enable_cache=enable_cache, @@ -162,6 +188,7 @@ def inner_decorator(func: "F") -> "Pipeline": substitutions=substitutions, execution_mode=execution_mode, cache_policy=cache_policy, + **pipeline_args, ) p.__doc__ = func.__doc__ diff --git a/src/zenml/pipelines/pipeline_definition.py b/src/zenml/pipelines/pipeline_definition.py index f388dc5f581..b32fd8e216e 100644 --- a/src/zenml/pipelines/pipeline_definition.py +++ b/src/zenml/pipelines/pipeline_definition.py @@ -22,7 +22,6 @@ TYPE_CHECKING, Any, Callable, - ClassVar, Dict, Iterator, List, @@ -41,7 +40,6 @@ from pydantic import BaseModel, ConfigDict, ValidationError, create_model from typing_extensions import Self -from zenml import constants from zenml.analytics.enums import AnalyticsEvent from zenml.analytics.utils import track_handler from zenml.client import Client @@ -54,8 +52,16 @@ from zenml.config.pipeline_spec import PipelineSpec from zenml.config.schedule import Schedule from zenml.config.step_configurations import StepConfigurationUpdate +from zenml.constants import ( + ENV_ZENML_DISABLE_PIPELINE_LOGS_STORAGE, + handle_bool_env_var, +) from zenml.enums import StackComponentType from zenml.exceptions import EntityExistsError +from zenml.execution.pipeline.utils import ( + should_prevent_pipeline_execution, + submit_pipeline, +) from zenml.hooks.hook_validators import resolve_and_validate_hook from zenml.logger import get_logger from zenml.logging.step_logging import ( @@ -78,9 +84,9 @@ ScheduleRequest, ) from zenml.pipelines import build_utils +from zenml.pipelines.compilation_context import PipelineCompilationContext from zenml.pipelines.run_utils import ( create_placeholder_run, - submit_pipeline, upload_notebook_cell_code_if_necessary, ) from zenml.stack import Stack @@ -127,11 +133,6 @@ class Pipeline: """ZenML pipeline class.""" - # The active pipeline is the pipeline to which step invocations will be - # added when a step is called. It is set using a context manager when a - # pipeline is called (see Pipeline.__call__ for more context) - ACTIVE_PIPELINE: ClassVar[Optional["Pipeline"]] = None - def __init__( self, name: str, @@ -156,6 +157,7 @@ def __init__( substitutions: Optional[Dict[str, str]] = None, execution_mode: Optional["ExecutionMode"] = None, cache_policy: Optional["CachePolicyOrString"] = None, + **kwargs: Any, ) -> None: """Initializes a pipeline. @@ -195,6 +197,7 @@ def __init__( substitutions: Extra placeholders to use in the name templates. execution_mode: The execution mode of the pipeline. cache_policy: Cache policy for this pipeline. + **kwargs: Additional keyword arguments. """ self._invocations: Dict[str, StepInvocation] = {} self._run_args: Dict[str, Any] = {} @@ -241,6 +244,15 @@ def name(self) -> str: """ return self.configuration.name + @property + def is_dynamic(self) -> bool: + """If the pipeline is dynamic. + + Returns: + If the pipeline is dynamic. + """ + return False + @property def enable_cache(self) -> Optional[bool]: """If caching is enabled for the pipeline. @@ -277,7 +289,12 @@ def resolve(self) -> "Source": Returns: The pipeline source. """ - return source_utils.resolve(self.entrypoint, skip_validation=True) + # We need to validate that the source is loadable for dynamic pipelines, + # as the orchestration environment will need to load the source. 
+ skip_validation = not self.is_dynamic + return source_utils.resolve( + self.entrypoint, skip_validation=skip_validation + ) @property def source_object(self) -> Any: @@ -307,8 +324,6 @@ def model(self) -> "PipelineResponse": Raises: RuntimeError: If the pipeline has not been registered yet. """ - self._prepare_if_possible() - pipelines = Client().list_pipelines(name=self.name) if len(pipelines) == 1: return pipelines.items[0] @@ -568,25 +583,76 @@ def prepare(self, *args: Any, **kwargs: Any) -> None: Args: *args: Pipeline entrypoint input arguments. **kwargs: Pipeline entrypoint input keyword arguments. + """ + self._clear_state() + + kwargs = self._apply_config_parameters(kwargs) + self._parameters = self._validate_entrypoint_args(*args, **kwargs) + + with PipelineCompilationContext(pipeline=self): + self._prepare_invocations(**self._parameters) + + def _validate_entrypoint_args( + self, *args: Any, **kwargs: Any + ) -> Dict[str, Any]: + """Validates the arguments for the pipeline entrypoint. + + Args: + *args: Entrypoint function arguments. + **kwargs: Entrypoint function keyword arguments. Raises: - RuntimeError: If the pipeline has parameters configured differently in - configuration file and code. + ValueError: If the arguments are invalid or missing. + + Returns: + The validated arguments. """ - self._parameters = {} - self._invocations = {} - self._output_artifacts = [] + try: + validated_args = pydantic_utils.validate_function_args( + self.entrypoint, + ConfigDict(arbitrary_types_allowed=False), + *args, + **kwargs, + ) + except ValidationError as e: + raise ValueError( + "Invalid or missing pipeline function entrypoint arguments. " + "Only JSON serializable inputs are allowed as pipeline inputs. " + "Check out the pydantic error above for more details." + ) from e + + return validated_args + + def _apply_config_parameters( + self, kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + """Applies the configuration parameters to the code arguments. + + Args: + kwargs: The code arguments to apply the configuration parameters to. + + Raises: + RuntimeError: If different values for the same key are passed in + code and configuration. + Returns: + The merged arguments. + """ + kwargs = kwargs.copy() conflicting_parameters = {} - parameters_ = (self.configuration.parameters or {}).copy() + config_parameters = self.configuration.parameters or {} if from_file_ := self._from_config_file.get("parameters", None): - parameters_ = dict_utils.recursive_update(parameters_, from_file_) - if parameters_: + config_parameters = dict_utils.recursive_update( + config_parameters, from_file_ + ) + + if config_parameters: for k, v_runtime in kwargs.items(): - if k in parameters_: - v_config = parameters_[k] + if k in config_parameters: + v_config = config_parameters[k] if v_config != v_runtime: conflicting_parameters[k] = (v_config, v_runtime) + if conflicting_parameters: is_plural = "s" if len(conflicting_parameters) > 1 else "" msg = f"Configured parameter{is_plural} for the pipeline `{self.name}` conflict{'' if not is_plural else 's'} with parameter{is_plural} passed in runtime:\n" @@ -610,15 +676,35 @@ def pipeline_(param_name: str): To avoid this consider setting pipeline parameters only in one place (config or code). """ raise RuntimeError(msg) - for k, v_config in parameters_.items(): + + for k, v_config in config_parameters.items(): if k not in kwargs: kwargs[k] = v_config - with self: - # Enter the context manager, so we become the active pipeline. 
This - # means that all steps that get called while the entrypoint function - # is executed will be added as invocation to this pipeline instance. - self._call_entrypoint(*args, **kwargs) + return kwargs + + def _prepare_invocations(self, **kwargs: Any) -> None: + """Prepares the invocations of the pipeline. + + Args: + **kwargs: Keyword arguments. + """ + outputs = self._call_entrypoint(**kwargs) + + output_artifacts = [] + if isinstance(outputs, StepArtifact): + output_artifacts = [outputs] + elif isinstance(outputs, tuple): + for v in outputs: + if isinstance(v, StepArtifact): + output_artifacts.append(v) + else: + logger.debug( + "Ignore pipeline output that is not a step artifact: %s", + v, + ) + + self._output_artifacts = output_artifacts def register(self) -> "PipelineResponse": """Register the pipeline in the server. @@ -626,25 +712,39 @@ def register(self) -> "PipelineResponse": Returns: The registered pipeline model. """ - # Activating the built-in integrations to load all materializers - from zenml.integrations.registry import integration_registry + client = Client() - self._prepare_if_possible() - integration_registry.activate_integrations() + def _get() -> PipelineResponse: + matching_pipelines = client.list_pipelines( + name=self.name, + size=1, + sort_by="desc:created", + ) - if self.configuration.model_dump( - exclude_defaults=True, exclude={"name"} - ): - logger.warning( - f"The pipeline `{self.name}` that you're registering has " - "custom configurations applied to it. These will not be " - "registered with the pipeline and won't be set when you build " - "images or run the pipeline from the CLI. To provide these " - "configurations, use the `--config` option of the `zenml " - "pipeline build/run` commands." + if matching_pipelines.total: + registered_pipeline = matching_pipelines.items[0] + return registered_pipeline + raise RuntimeError("No matching pipelines found.") + + try: + return _get() + except RuntimeError: + request = PipelineRequest( + project=client.active_project.id, + name=self.name, ) - return self._register() + try: + registered_pipeline = client.zen_store.create_pipeline( + pipeline=request + ) + logger.info( + "Registered new pipeline: `%s`.", + registered_pipeline.name, + ) + return registered_pipeline + except EntityExistsError: + return _get() def build( self, @@ -682,7 +782,7 @@ def build( compile_args["settings"] = settings snapshot, _, _ = self._compile(**compile_args) - pipeline_id = self._register().id + pipeline_id = self.register().id local_repo = code_repository_utils.find_active_code_repository() code_repository = build_utils.verify_local_repository_context( @@ -798,7 +898,7 @@ def _create_snapshot( extra=extra, ) - pipeline_id = self._register().id + pipeline_id = self.register().id stack = Client().active_stack stack.validate() @@ -926,17 +1026,8 @@ def _run( Returns: The pipeline run or `None` if running with a schedule. """ - if constants.SHOULD_PREVENT_PIPELINE_EXECUTION: - # An environment variable was set to stop the execution of - # pipelines. This is done to prevent execution of module-level - # pipeline.run() calls when importing modules needed to run a step. - logger.info( - "Preventing execution of pipeline '%s'. 
If this is not " - "intended behavior, make sure to unset the environment " - "variable '%s'.", - self.name, - constants.ENV_ZENML_PREVENT_PIPELINE_EXECUTION, - ) + if should_prevent_pipeline_execution(): + logger.info("Preventing execution of pipeline '%s'.", self.name) return None logger.info(f"Initiating a new run for the pipeline: `{self.name}`.") @@ -949,8 +1040,8 @@ def _run( # Pipeline runs scheduled to run in the future are not logged # via the client. logging_enabled = False - elif constants.handle_bool_env_var( - constants.ENV_ZENML_DISABLE_PIPELINE_LOGS_STORAGE, False + elif handle_bool_env_var( + ENV_ZENML_DISABLE_PIPELINE_LOGS_STORAGE, False ): logging_enabled = False else: @@ -1253,46 +1344,6 @@ def _compile( return snapshot, run_config.schedule, run_config.build - def _register(self) -> "PipelineResponse": - """Register the pipeline in the server. - - Returns: - The registered pipeline model. - """ - client = Client() - - def _get() -> PipelineResponse: - matching_pipelines = client.list_pipelines( - name=self.name, - size=1, - sort_by="desc:created", - ) - - if matching_pipelines.total: - registered_pipeline = matching_pipelines.items[0] - return registered_pipeline - raise RuntimeError("No matching pipelines found.") - - try: - return _get() - except RuntimeError: - request = PipelineRequest( - project=client.active_project.id, - name=self.name, - ) - - try: - registered_pipeline = client.zen_store.create_pipeline( - pipeline=request - ) - logger.info( - "Registered new pipeline: `%s`.", - registered_pipeline.name, - ) - return registered_pipeline - except EntityExistsError: - return _get() - def _compute_unique_identifier(self, pipeline_spec: PipelineSpec) -> str: """Computes a unique identifier from the pipeline spec and steps. @@ -1357,7 +1408,15 @@ def add_step_invocation( Returns: The step invocation ID. """ - if Pipeline.ACTIVE_PIPELINE != self: + from zenml.execution.pipeline.dynamic.run_context import ( + DynamicPipelineRunContext, + ) + + context = ( + PipelineCompilationContext.get() or DynamicPipelineRunContext.get() + ) + + if not context or context.pipeline != self: raise RuntimeError( "A step invocation can only be added to an active pipeline." ) @@ -1425,32 +1484,6 @@ def _compute_invocation_id( raise RuntimeError("Unable to find step ID") - def __enter__(self) -> Self: - """Activate the pipeline context. - - Raises: - RuntimeError: If a different pipeline is already active. - - Returns: - The pipeline instance. - """ - if Pipeline.ACTIVE_PIPELINE: - raise RuntimeError( - "Unable to enter pipeline context. A different pipeline " - f"{Pipeline.ACTIVE_PIPELINE.name} is already active." - ) - - Pipeline.ACTIVE_PIPELINE = self - return self - - def __exit__(self, *args: Any) -> None: - """Deactivates the pipeline context. - - Args: - *args: The arguments passed to the context exit handler. - """ - Pipeline.ACTIVE_PIPELINE = None - def _parse_config_file( self, config_path: Optional[str], matcher: List[str] ) -> Dict[str, Any]: @@ -1586,7 +1619,7 @@ def __call__( `entrypoint` method. Otherwise, returns the pipeline run or `None` if running with a schedule. 
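The `PipelineCompilationContext.is_active()` branch keeps the established composition behavior: a pipeline called while another pipeline is being compiled simply returns its entrypoint outputs instead of starting a run. A rough sketch with hypothetical step and pipeline names:

```python
from zenml import pipeline, step


@step
def preprocess() -> int:
    return 42


@step
def consume(value: int) -> None:
    print(value)


@pipeline
def child_pipeline() -> int:
    # When called from `parent_pipeline`, this body is inlined: its step
    # invocations are added to the pipeline currently being compiled.
    return preprocess()


@pipeline
def parent_pipeline() -> None:
    value = child_pipeline()  # entrypoint outputs, not a pipeline run
    consume(value)
```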
""" - if Pipeline.ACTIVE_PIPELINE: + if PipelineCompilationContext.is_active(): # Calling a pipeline inside a pipeline, we return the potential # outputs of the entrypoint function @@ -1598,48 +1631,19 @@ def __call__( self.prepare(*args, **kwargs) return self._run() - def _call_entrypoint(self, *args: Any, **kwargs: Any) -> None: + def _call_entrypoint(self, *args: Any, **kwargs: Any) -> Any: """Calls the pipeline entrypoint function with the given arguments. Args: *args: Entrypoint function arguments. **kwargs: Entrypoint function keyword arguments. - Raises: - ValueError: If an input argument is missing or not JSON - serializable. + Returns: + The return value of the entrypoint function. """ - try: - validated_args = pydantic_utils.validate_function_args( - self.entrypoint, - ConfigDict(arbitrary_types_allowed=False), - *args, - **kwargs, - ) - except ValidationError as e: - raise ValueError( - "Invalid or missing pipeline function entrypoint arguments. " - "Only JSON serializable inputs are allowed as pipeline inputs. " - "Check out the pydantic error above for more details." - ) from e - - self._parameters = validated_args - return_value = self.entrypoint(**validated_args) - - output_artifacts = [] - if isinstance(return_value, StepArtifact): - output_artifacts = [return_value] - elif isinstance(return_value, tuple): - for v in return_value: - if isinstance(v, StepArtifact): - output_artifacts.append(v) - else: - logger.debug( - "Ignore pipeline output that is not a step artifact: %s", - v, - ) - - self._output_artifacts = output_artifacts + self._clear_state() + self._parameters = self._validate_entrypoint_args(*args, **kwargs) + return self.entrypoint(**self._parameters) def _prepare_if_possible(self) -> None: """Prepares the pipeline if possible. @@ -1853,3 +1857,9 @@ def _compute_input_schema(self) -> Optional[Dict[str, Any]]: ) return None + + def _clear_state(self) -> None: + """Clears the state of the pipeline.""" + self._invocations = {} + self._parameters = {} + self._output_artifacts = [] diff --git a/src/zenml/pipelines/run_utils.py b/src/zenml/pipelines/run_utils.py index 0adedd83468..4637b3faa46 100644 --- a/src/zenml/pipelines/run_utils.py +++ b/src/zenml/pipelines/run_utils.py @@ -1,18 +1,37 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. 
"""Utility functions for running pipelines.""" import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Set, + Union, +) from uuid import UUID from pydantic import BaseModel -from zenml import constants from zenml.client import Client from zenml.config.pipeline_run_configuration import PipelineRunConfiguration from zenml.config.source import Source, SourceType from zenml.config.step_configurations import StepConfigurationUpdate from zenml.enums import ExecutionStatus -from zenml.exceptions import RunMonitoringError from zenml.logger import get_logger from zenml.models import ( FlavorFilter, @@ -24,9 +43,13 @@ PipelineSnapshotResponse, StackResponse, ) -from zenml.orchestrators.publish_utils import publish_failed_pipeline_run from zenml.stack import Flavor, Stack -from zenml.utils import code_utils, notebook_utils, source_utils, string_utils +from zenml.utils import ( + code_utils, + notebook_utils, + source_utils, + string_utils, +) from zenml.utils.time_utils import utc_now from zenml.zen_stores.base_zen_store import BaseZenStore @@ -95,53 +118,6 @@ def create_placeholder_run( return run -def submit_pipeline( - snapshot: "PipelineSnapshotResponse", - stack: "Stack", - placeholder_run: Optional["PipelineRunResponse"] = None, -) -> None: - """Submit a snapshot for execution. - - Args: - snapshot: The snapshot to submit. - stack: The stack on which to submit the snapshot. - placeholder_run: An optional placeholder run for the snapshot. - - # noqa: DAR401 - Raises: - BaseException: Any exception that happened while submitting or running - (in case it happens synchronously) the pipeline. - """ - # Prevent execution of nested pipelines which might lead to - # unexpected behavior - previous_value = constants.SHOULD_PREVENT_PIPELINE_EXECUTION - constants.SHOULD_PREVENT_PIPELINE_EXECUTION = True - try: - stack.prepare_pipeline_submission(snapshot=snapshot) - stack.submit_pipeline( - snapshot=snapshot, - placeholder_run=placeholder_run, - ) - except RunMonitoringError as e: - # Don't mark the run as failed if the error happened during monitoring - # of the run. - raise e.original_exception from None - except BaseException as e: - if ( - placeholder_run - and not Client() - .get_pipeline_run(placeholder_run.id, hydrate=False) - .status.is_finished - ): - # We failed during/before the submission of the run, so we mark the - # run as failed if it is still in an initializing/running state. - publish_failed_pipeline_run(placeholder_run.id) - - raise e - finally: - constants.SHOULD_PREVENT_PIPELINE_EXECUTION = previous_value - - def wait_for_pipeline_run_to_finish(run_id: UUID) -> "PipelineRunResponse": """Waits until a pipeline run is finished. diff --git a/src/zenml/step_operators/step_operator_entrypoint_configuration.py b/src/zenml/step_operators/step_operator_entrypoint_configuration.py index cb3b71c9c74..cfbb81057fc 100644 --- a/src/zenml/step_operators/step_operator_entrypoint_configuration.py +++ b/src/zenml/step_operators/step_operator_entrypoint_configuration.py @@ -13,10 +13,11 @@ # permissions and limitations under the License. 
"""Abstract base class for entrypoint configurations that run a single step.""" -from typing import TYPE_CHECKING, Any, List, Set +from typing import TYPE_CHECKING, Any, Dict, List, Optional from uuid import UUID from zenml.client import Client +from zenml.config.step_configurations import Step from zenml.config.step_run_info import StepRunInfo from zenml.entrypoints.step_entrypoint_configuration import ( STEP_NAME_OPTION, @@ -26,8 +27,7 @@ from zenml.orchestrators.step_runner import StepRunner if TYPE_CHECKING: - from zenml.config.step_configurations import Step - from zenml.models import PipelineSnapshotResponse + from zenml.models import PipelineSnapshotResponse, StepRunResponse STEP_RUN_ID_OPTION = "step_run_id" @@ -35,15 +35,25 @@ class StepOperatorEntrypointConfiguration(StepEntrypointConfiguration): """Base class for step operator entrypoint configurations.""" + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize the step operator entrypoint configuration. + + Args: + *args: The arguments to pass to the superclass. + **kwargs: The keyword arguments to pass to the superclass. + """ + super().__init__(*args, **kwargs) + self._step_run: Optional["StepRunResponse"] = None + @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> Dict[str, bool]: """Gets all options required for running with this configuration. Returns: The superclass options as well as an option for the step run id. """ return super().get_entrypoint_options() | { - STEP_RUN_ID_OPTION, + STEP_RUN_ID_OPTION: True, } @classmethod @@ -64,6 +74,31 @@ def get_entrypoint_arguments( kwargs[STEP_RUN_ID_OPTION], ] + @property + def step_run(self) -> "StepRunResponse": + """The step run configured for this entrypoint configuration. + + Returns: + The step run. + """ + if self._step_run is None: + step_run_id = UUID(self.entrypoint_args[STEP_RUN_ID_OPTION]) + self._step_run = Client().zen_store.get_run_step(step_run_id) + return self._step_run + + @property + def step(self) -> "Step": + """The step configured for this entrypoint configuration. + + Returns: + The step. + """ + return Step( + spec=self.step_run.spec, + config=self.step_run.config, + step_config_overrides=self.step_run.config, + ) + def _run_step( self, step: "Step", @@ -75,17 +110,18 @@ def _run_step( step: The step to run. snapshot: The snapshot configuration. 
""" - step_run_id = UUID(self.entrypoint_args[STEP_RUN_ID_OPTION]) - step_run = Client().zen_store.get_run_step(step_run_id) + step_run = self.step_run pipeline_run = Client().get_pipeline_run(step_run.pipeline_run_id) step_run_info = StepRunInfo( config=step.config, + spec=step.spec, + snapshot=snapshot, pipeline=snapshot.pipeline_configuration, run_name=pipeline_run.name, pipeline_step_name=self.entrypoint_args[STEP_NAME_OPTION], run_id=pipeline_run.id, - step_run_id=step_run_id, + step_run_id=step_run.id, force_write_logs=lambda: None, ) diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index ef18b99603e..8c34f6eadd6 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -18,10 +18,12 @@ import inspect from abc import abstractmethod from collections import defaultdict +from contextlib import contextmanager from typing import ( TYPE_CHECKING, Any, Dict, + Generator, List, Mapping, Optional, @@ -30,6 +32,7 @@ Type, TypeVar, Union, + cast, ) from uuid import UUID @@ -91,6 +94,12 @@ Mapping[str, Sequence["MaterializerClassOrSource"]], ] + from zenml.execution.pipeline.dynamic.outputs import ( + StepRunFuture, + StepRunOutputsFuture, + ) + + logger = get_logger(__name__) T = TypeVar("T", bound="BaseStep") @@ -122,6 +131,7 @@ def __init__( retry: Optional[StepRetryConfig] = None, substitutions: Optional[Dict[str, str]] = None, cache_policy: Optional[CachePolicyOrString] = None, + in_process: Optional[bool] = None, ) -> None: """Initializes a step. @@ -155,6 +165,8 @@ def __init__( retry: Configuration for retrying the step in case of failure. substitutions: Extra placeholders to use in the name template. cache_policy: Cache policy for this step. + in_process: Whether to run the step in process. This is only + applicable for dynamic pipelines. """ from zenml.config.step_configurations import PartialStepConfiguration @@ -163,6 +175,7 @@ def __init__( reserved_arguments=["after", "id"], ) + self._static_id = id(self) name = name or self.__class__.__name__ logger.debug( @@ -197,14 +210,15 @@ def __init__( }, ) - self._configuration = PartialStepConfiguration( - name=name, + self._configuration = PartialStepConfiguration(name=name) + self._dynamic_configuration: Optional["StepConfigurationUpdate"] = None + self._capture_dynamic_configuration = True + + self.configure( enable_cache=enable_cache, enable_artifact_metadata=enable_artifact_metadata, enable_artifact_visualization=enable_artifact_visualization, enable_step_logs=enable_step_logs, - ) - self.configure( experiment_tracker=experiment_tracker, step_operator=step_operator, output_materializers=output_materializers, @@ -219,6 +233,7 @@ def __init__( retry=retry, substitutions=substitutions, cache_policy=cache_policy, + in_process=in_process, ) notebook_utils.try_to_save_notebook_cell_code(self.source_object) @@ -454,7 +469,11 @@ def __call__( *args: Any, id: Optional[str] = None, after: Union[ - str, StepArtifact, Sequence[Union[str, StepArtifact]], None + str, + StepArtifact, + "StepRunFuture", + Sequence[Union[str, StepArtifact, "StepRunFuture"]], + None, ] = None, **kwargs: Any, ) -> Any: @@ -474,10 +493,48 @@ def __call__( Returns: The outputs of the entrypoint function call. 
""" - from zenml.pipelines.pipeline_definition import Pipeline + from zenml import get_step_context + from zenml.execution.pipeline.dynamic.run_context import ( + DynamicPipelineRunContext, + ) + from zenml.pipelines.compilation_context import ( + PipelineCompilationContext, + ) - if not Pipeline.ACTIVE_PIPELINE: - from zenml import constants, get_step_context + try: + step_context = get_step_context() + except RuntimeError: + step_context = None + + if step_context: + # We're currently inside the execution of a different step + # -> We don't want to launch another single step pipeline here, + # but instead just call the step function + return self.call_entrypoint(*args, **kwargs) + + if run_context := DynamicPipelineRunContext.get(): + after = cast( + Union[ + "StepRunFuture", + Sequence["StepRunFuture"], + None, + ], + after, + ) + return run_context.runner.launch_step( + step=self, + id=id, + args=args, + kwargs=kwargs, + after=after, + concurrent=False, + ) + + compilation_context = PipelineCompilationContext.get() + if not compilation_context: + from zenml.execution.pipeline.utils import ( + should_prevent_pipeline_execution, + ) # If the environment variable was set to explicitly not run on the # stack, we do that. @@ -487,21 +544,8 @@ def __call__( if run_without_stack: return self.call_entrypoint(*args, **kwargs) - try: - get_step_context() - except RuntimeError: - pass - else: - # We're currently inside the execution of a different step - # -> We don't want to launch another single step pipeline here, - # but instead just call the step function - return self.call_entrypoint(*args, **kwargs) - - if constants.SHOULD_PREVENT_PIPELINE_EXECUTION: - logger.info( - "Preventing execution of step '%s'.", - self.name, - ) + if should_prevent_pipeline_execution(): + logger.info("Preventing execution of step '%s'.", self.name) return return run_as_single_step_pipeline(self, *args, **kwargs) @@ -529,7 +573,7 @@ def __call__( elif isinstance(item, StepArtifact): upstream_steps.add(item.invocation_id) - invocation_id = Pipeline.ACTIVE_PIPELINE.add_step_invocation( + invocation_id = compilation_context.pipeline.add_step_invocation( step=self, input_artifacts=input_artifacts, external_artifacts=external_artifacts, @@ -548,7 +592,7 @@ def __call__( invocation_id=invocation_id, output_name=key, annotation=annotation, - pipeline=Pipeline.ACTIVE_PIPELINE, + pipeline=compilation_context.pipeline, ) outputs.append(output) return outputs[0] if len(outputs) == 1 else outputs @@ -582,6 +626,48 @@ def call_entrypoint(self, *args: Any, **kwargs: Any) -> Any: return self.entrypoint(**validated_args) + def submit( + self, + *args: Any, + id: Optional[str] = None, + after: Union["StepRunFuture", Sequence["StepRunFuture"], None] = None, + **kwargs: Any, + ) -> "StepRunOutputsFuture": + """Submit the step to run concurrently in a separate thread. + + Args: + *args: The arguments to pass to the step function. + id: The invocation ID of the step. + after: The step run output futures to wait for before executing the + step. + **kwargs: The keyword arguments to pass to the step function. + + Raises: + RuntimeError: If this method is called outside of a dynamic + pipeline. + + Returns: + The step run output future. + """ + from zenml.execution.pipeline.dynamic.run_context import ( + DynamicPipelineRunContext, + ) + + context = DynamicPipelineRunContext.get() + if not context: + raise RuntimeError( + "Submitting a step is only possible within a dynamic pipeline." 
+ ) + + return context.runner.launch_step( + step=self, + id=id, + args=args, + kwargs=kwargs, + after=after, + concurrent=True, + ) + @property def name(self) -> str: """The name of the step. @@ -631,6 +717,7 @@ def configure( retry: Optional[StepRetryConfig] = None, substitutions: Optional[Dict[str, str]] = None, cache_policy: Optional[CachePolicyOrString] = None, + in_process: Optional[bool] = None, merge: bool = True, ) -> T: """Configures the step. @@ -674,6 +761,8 @@ def configure( retry: Configuration for retrying the step in case of failure. substitutions: Extra placeholders to use in the name template. cache_policy: Cache policy for this step. + in_process: Whether to run the step in process. This is only + applicable for dynamic pipelines. merge: If `True`, will merge the given dictionary configurations like `parameters` and `settings` with existing configurations. If `False` the given configurations will @@ -752,6 +841,7 @@ def _convert_to_tuple(value: Any) -> Tuple[Source, ...]: "retry": retry, "substitutions": substitutions, "cache_policy": cache_policy, + "in_process": in_process, } ) config = StepConfigurationUpdate(**values) @@ -780,6 +870,7 @@ def with_options( retry: Optional[StepRetryConfig] = None, substitutions: Optional[Dict[str, str]] = None, cache_policy: Optional[CachePolicyOrString] = None, + in_process: Optional[bool] = None, merge: bool = True, ) -> "BaseStep": """Copies the step and applies the given configurations. @@ -813,6 +904,8 @@ def with_options( retry: Configuration for retrying the step in case of failure. substitutions: Extra placeholders for the step name. cache_policy: Cache policy for this step. + in_process: Whether to run the step in process. This is only + applicable for dynamic pipelines. merge: If `True`, will merge the given dictionary configurations like `parameters` and `settings` with existing configurations. If `False` the given configurations will @@ -842,6 +935,7 @@ def with_options( retry=retry, substitutions=substitutions, cache_policy=cache_policy, + in_process=in_process, merge=merge, ) return step_copy @@ -852,7 +946,32 @@ def copy(self) -> "BaseStep": Returns: The step copy. """ - return copy.deepcopy(self) + from zenml.execution.pipeline.dynamic.run_context import ( + DynamicPipelineRunContext, + ) + + step_copy = copy.deepcopy(self) + + if not DynamicPipelineRunContext.is_active(): + # If we're not in a dynamic pipeline, we generate a new static ID + # for the step copy + step_copy._static_id = id(step_copy) + + return step_copy + + @contextmanager + def _suspend_dynamic_configuration(self) -> Generator[None, None, None]: + """Context manager to suspend applying to the dynamic configuration. + + Yields: + None. + """ + previous_value = self._capture_dynamic_configuration + self._capture_dynamic_configuration = False + try: + yield + finally: + self._capture_dynamic_configuration = previous_value def _apply_configuration( self, @@ -869,8 +988,24 @@ def _apply_configuration( or not. See the `BaseStep.configure(...)` method for a detailed explanation. 
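Taken together with the `after` handling in `__call__`, `submit` lets a dynamic pipeline fan out concurrent step runs, while `in_process` opts individual steps out of separate execution environments. A sketch with hypothetical steps; treating the returned `StepRunOutputsFuture` objects as valid `after` barriers is an assumption beyond what this diff shows:

```python
from zenml import pipeline, step


@step
def score(shard: int) -> float:
    return float(shard)


@step(in_process=True)  # assumption: runs inside the orchestration process
def report(num_shards: int) -> None:
    print(f"scored {num_shards} shards")


@pipeline(dynamic=True, depends_on=[score, report])
def fan_out() -> None:
    # `submit` returns a future immediately instead of blocking the
    # pipeline function the way a direct call presumably does.
    futures = [score.submit(shard=i) for i in range(3)]
    # Futures can serve as ordering barriers for later steps.
    report(3, after=futures)
```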
""" + from zenml.execution.pipeline.dynamic.run_context import ( + DynamicPipelineRunContext, + ) + self._validate_configuration(config, runtime_parameters) + if ( + self._capture_dynamic_configuration + and DynamicPipelineRunContext.is_active() + ): + if self._dynamic_configuration is None: + self._dynamic_configuration = config + else: + self._dynamic_configuration = pydantic_utils.update_model( + self._dynamic_configuration, update=config, recursive=merge + ) + return + self._configuration = pydantic_utils.update_model( self._configuration, update=config, recursive=merge ) @@ -878,6 +1013,15 @@ def _apply_configuration( logger.debug("Updated step configuration:") logger.debug(self._configuration) + def _merge_dynamic_configuration(self) -> None: + """Merges the dynamic configuration into the static configuration.""" + if self._dynamic_configuration: + with self._suspend_dynamic_configuration(): + self._apply_configuration( + config=self._dynamic_configuration, merge=True + ) + logger.debug("Merged dynamic configuration.") + def _validate_configuration( self, config: "StepConfigurationUpdate", @@ -1031,6 +1175,7 @@ def _finalize_configuration( external_artifacts: Dict[str, "ExternalArtifactConfiguration"], model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], client_lazy_loaders: Dict[str, "ClientLazyLoader"], + skip_input_validation: bool = False, ) -> "StepConfiguration": """Finalizes the configuration after the step was called. @@ -1045,6 +1190,7 @@ def _finalize_configuration( model_artifacts_or_metadata: The model artifacts or metadata of this step. client_lazy_loaders: The client lazy loaders of this step. + skip_input_validation: If True, will skip the input validation. Raises: StepInterfaceError: If explicit materializers were specified for an @@ -1134,12 +1280,13 @@ def _finalize_configuration( parameters = self._finalize_parameters() self.configure(parameters=parameters, merge=False) - self._validate_inputs( - input_artifacts=input_artifacts, - external_artifacts=external_artifacts, - model_artifacts_or_metadata=model_artifacts_or_metadata, - client_lazy_loaders=client_lazy_loaders, - ) + if not skip_input_validation: + self._validate_inputs( + input_artifacts=input_artifacts, + external_artifacts=external_artifacts, + model_artifacts_or_metadata=model_artifacts_or_metadata, + client_lazy_loaders=client_lazy_loaders, + ) values = dict_utils.remove_none_values({"outputs": outputs or None}) config = StepConfigurationUpdate(**values) diff --git a/src/zenml/steps/step_context.py b/src/zenml/steps/step_context.py index 196727fc647..a60eeef5b76 100644 --- a/src/zenml/steps/step_context.py +++ b/src/zenml/steps/step_context.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """Step context class.""" +import contextvars from typing import ( TYPE_CHECKING, Any, @@ -24,10 +25,13 @@ Type, ) +from typing_extensions import Self + from zenml.exceptions import StepContextError from zenml.logger import get_logger +from zenml.utils import context_utils from zenml.utils.callback_registry import CallbackRegistry -from zenml.utils.singleton import SingletonMetaClass, ThreadLocalSingleton +from zenml.utils.singleton import SingletonMetaClass if TYPE_CHECKING: from zenml.artifacts.artifact_config import ArtifactConfig @@ -54,8 +58,9 @@ def get_step_context() -> "StepContext": Raises: RuntimeError: If no step is currently running. 
""" - if StepContext._exists(): - return StepContext() # type: ignore + if ctx := StepContext.get(): + return ctx + raise RuntimeError( "The step context is only available inside a step function." ) @@ -110,12 +115,9 @@ def initialize(self, state: Optional[Any]) -> None: self.initialized = True -class StepContext(metaclass=ThreadLocalSingleton): +class StepContext(context_utils.BaseContext): """Provides additional context inside a step function. - This singleton class is used to access information about the current run, - step run, or its outputs inside a step function. - Usage example: ```python @@ -138,6 +140,8 @@ def my_trainer_step() -> Any: ``` """ + __context_var__ = contextvars.ContextVar("step_context") + def __init__( self, pipeline_run: "PipelineRunResponse", @@ -164,6 +168,8 @@ def __init__( """ from zenml.client import Client + super().__init__() + try: pipeline_run = Client().get_pipeline_run(pipeline_run.id) except KeyError: @@ -461,6 +467,30 @@ def remove_output_tags( return output.tags = [tag for tag in output.tags if tag not in tags] + def __enter__(self) -> Self: + """Enter the step context. + + Raises: + RuntimeError: If the step context has already been entered. + + Returns: + The step context object. + """ + if self._token is not None: + raise RuntimeError( + "Running a step from within another step is not allowed." + ) + return super().__enter__() + + def __exit__(self, *_: Any) -> None: + """Exit the step context. + + Args: + *_: Unused keyword arguments. + """ + self._cleanup_registry.execute_callbacks(raise_on_exception=False) + super().__exit__(*_) + class StepContextOutput: """Represents a step output in the step context.""" diff --git a/src/zenml/steps/step_decorator.py b/src/zenml/steps/step_decorator.py index 3ea538c57a7..c682fb33011 100644 --- a/src/zenml/steps/step_decorator.py +++ b/src/zenml/steps/step_decorator.py @@ -80,6 +80,7 @@ def step( retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, cache_policy: Optional["CachePolicyOrString"] = None, + in_process: Optional[bool] = None, ) -> Callable[["F"], "BaseStep"]: ... @@ -104,6 +105,7 @@ def step( retry: Optional["StepRetryConfig"] = None, substitutions: Optional[Dict[str, str]] = None, cache_policy: Optional["CachePolicyOrString"] = None, + in_process: Optional[bool] = None, ) -> Union["BaseStep", Callable[["F"], "BaseStep"]]: """Decorator to create a ZenML step. @@ -139,6 +141,8 @@ def step( retry: configuration of step retry in case of step failure. substitutions: Extra placeholders for the step name. cache_policy: Cache policy for this step. + in_process: Whether to run the step in process. This is only + applicable for dynamic pipelines. Returns: The step instance. @@ -176,6 +180,7 @@ def inner_decorator(func: "F") -> "BaseStep": retry=retry, substitutions=substitutions, cache_policy=cache_policy, + in_process=in_process, ) return step_instance diff --git a/src/zenml/steps/step_invocation.py b/src/zenml/steps/step_invocation.py index 17341d40845..4ef3a8ba677 100644 --- a/src/zenml/steps/step_invocation.py +++ b/src/zenml/steps/step_invocation.py @@ -71,7 +71,11 @@ def __init__( self.upstream_steps = upstream_steps self.pipeline = pipeline - def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration": + def finalize( + self, + parameters_to_ignore: Set[str], + skip_input_validation: bool = False, + ) -> "StepConfiguration": """Finalizes a step invocation. 
It will validate the upstream steps and run final configurations on the @@ -80,6 +84,7 @@ def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration": Args: parameters_to_ignore: Set of parameters that should not be applied to the step instance. + skip_input_validation: If True, will skip the input validation. Returns: The finalized step configuration. @@ -119,4 +124,5 @@ def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration": external_artifacts=external_artifacts, model_artifacts_or_metadata=self.model_artifacts_or_metadata, client_lazy_loaders=self.client_lazy_loaders, + skip_input_validation=skip_input_validation, ) diff --git a/src/zenml/utils/context_utils.py b/src/zenml/utils/context_utils.py index 2add3bc862c..185bcdfdacb 100644 --- a/src/zenml/utils/context_utils.py +++ b/src/zenml/utils/context_utils.py @@ -13,13 +13,69 @@ # permissions and limitations under the License. """Context variable utilities.""" +import contextvars import threading from contextvars import ContextVar -from typing import Generic, List, Optional, TypeVar +from typing import Any, ClassVar, Generic, List, Optional, TypeVar + +from typing_extensions import Self T = TypeVar("T") +class BaseContext: + """Base context class.""" + + __context_var__: ClassVar[contextvars.ContextVar[Self]] + + def __init__(self) -> None: + """Initialize the context.""" + self._token: Optional[contextvars.Token[Any]] = None + + @classmethod + def get(cls: type[Self]) -> Optional[Self]: + """Get the active context. + + Returns: + The active context. + """ + return cls.__context_var__.get(None) + + @classmethod + def is_active(cls: type[Self]) -> bool: + """Check if the context is active. + + Returns: + True if the context is active, False otherwise. + """ + return cls.get() is not None + + def __enter__(self) -> Self: + """Enter the context. + + Returns: + The context object. + """ + self._token = self.__context_var__.set(self) + return self + + def __exit__(self, *_: Any) -> None: + """Exit the context. + + Args: + *_: Unused keyword arguments. + + Raises: + RuntimeError: If the context has not been entered. + """ + if not self._token: + raise RuntimeError( + f"Can't exit {self.__class__.__name__} because it has not been " + "entered." + ) + self.__context_var__.reset(self._token) + + class ContextVarList(Generic[T]): """Thread-safe wrapper around ContextVar[List] with atomic add/remove operations.""" diff --git a/src/zenml/zen_server/pipeline_execution/runner_entrypoint_configuration.py b/src/zenml/zen_server/pipeline_execution/runner_entrypoint_configuration.py index 635c8b13e97..db149161092 100644 --- a/src/zenml/zen_server/pipeline_execution/runner_entrypoint_configuration.py +++ b/src/zenml/zen_server/pipeline_execution/runner_entrypoint_configuration.py @@ -13,14 +13,14 @@ # permissions and limitations under the License. 
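The new `context_utils.BaseContext` above gives `StepContext`, the compilation context, and the dynamic run context one uniform enter/get/exit lifecycle built on `contextvars`, so concurrent threads and tasks each see their own active context. A minimal usage sketch with a toy subclass:

```python
import contextvars

from zenml.utils import context_utils


class MyContext(context_utils.BaseContext):
    """Toy context carrying a single value."""

    __context_var__ = contextvars.ContextVar("my_context")

    def __init__(self, value: str) -> None:
        super().__init__()
        self.value = value


assert MyContext.get() is None

with MyContext("active") as ctx:
    # Anywhere further down the call stack:
    assert MyContext.is_active()
    assert MyContext.get() is ctx

assert MyContext.get() is None
```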
"""Runner entrypoint configuration.""" -from typing import Any, List, Set +from typing import Any, Dict, List from uuid import UUID from zenml.client import Client from zenml.entrypoints.base_entrypoint_configuration import ( BaseEntrypointConfiguration, ) -from zenml.pipelines.run_utils import submit_pipeline +from zenml.execution.pipeline.utils import submit_pipeline PLACEHOLDER_RUN_ID_OPTION = "placeholder_run_id" @@ -29,14 +29,16 @@ class RunnerEntrypointConfiguration(BaseEntrypointConfiguration): """Runner entrypoint configuration.""" @classmethod - def get_entrypoint_options(cls) -> Set[str]: + def get_entrypoint_options(cls) -> Dict[str, bool]: """Gets all options required for running with this configuration. Returns: The superclass options as well as an option for the name of the step to run. """ - return super().get_entrypoint_options() | {PLACEHOLDER_RUN_ID_OPTION} + return super().get_entrypoint_options() | { + PLACEHOLDER_RUN_ID_OPTION: True + } @classmethod def get_entrypoint_arguments( @@ -63,7 +65,7 @@ def run(self) -> None: This method runs the pipeline defined by the snapshot given as input to the entrypoint configuration. """ - snapshot = self.load_snapshot() + snapshot = self.snapshot placeholder_run_id = UUID( self.entrypoint_args[PLACEHOLDER_RUN_ID_OPTION] ) diff --git a/src/zenml/zen_server/pipeline_execution/utils.py b/src/zenml/zen_server/pipeline_execution/utils.py index 12d613de954..9fea72bc09b 100644 --- a/src/zenml/zen_server/pipeline_execution/utils.py +++ b/src/zenml/zen_server/pipeline_execution/utils.py @@ -175,6 +175,12 @@ def run_snapshot( "not have an associated build. This is probably because the " "build has been deleted." ) + if snapshot.is_dynamic: + raise ValueError( + "Snapshots of dynamic pipelines can not be run via the server " + "yet." + ) + raise ValueError("This snapshot can not be run via the server.") # Guaranteed by the `runnable` check above diff --git a/src/zenml/zen_stores/migrations/versions/af27025fe19c_dynamic_pipelines.py b/src/zenml/zen_stores/migrations/versions/af27025fe19c_dynamic_pipelines.py new file mode 100644 index 00000000000..3ef8931d496 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/af27025fe19c_dynamic_pipelines.py @@ -0,0 +1,95 @@ +"""Dynamic pipelines [af27025fe19c]. + +Revision ID: af27025fe19c +Revises: 0.91.0 +Create Date: 2025-10-27 13:23:40.442485 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "af27025fe19c" +down_revision = "0.91.0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("pipeline_snapshot", schema=None) as batch_op: + batch_op.add_column( + sa.Column("is_dynamic", sa.Boolean(), nullable=True) + ) + + op.execute( + "UPDATE pipeline_snapshot SET is_dynamic = FALSE WHERE is_dynamic IS NULL" + ) + with op.batch_alter_table("pipeline_snapshot", schema=None) as batch_op: + batch_op.alter_column( + "is_dynamic", existing_type=sa.Boolean(), nullable=False + ) + + with op.batch_alter_table("step_configuration", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "step_run_id", sqlmodel.sql.sqltypes.GUID(), nullable=True + ) + ) + batch_op.alter_column( + "snapshot_id", existing_type=sa.CHAR(length=32), nullable=True + ) + batch_op.drop_constraint( + "unique_step_name_for_snapshot", type_="unique" + ) + batch_op.create_unique_constraint( + "unique_step_name_for_snapshot", + ["snapshot_id", "step_run_id", "name"], + ) + batch_op.create_foreign_key( + "fk_step_configuration_step_run_id_step_run", + "step_run", + ["step_run_id"], + ["id"], + ondelete="CASCADE", + ) + + batch_op.create_check_constraint( + "ck_step_configuration_snapshot_step_run_exclusivity", + "((snapshot_id IS NULL AND step_run_id IS NOT NULL) OR " + "(snapshot_id IS NOT NULL AND step_run_id IS NULL))", + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint( + "ck_step_configuration_snapshot_step_run_exclusivity", + "step_configuration", + type_="check", + ) + with op.batch_alter_table("step_configuration", schema=None) as batch_op: + batch_op.drop_constraint( + "fk_step_configuration_step_run_id_step_run", type_="foreignkey" + ) + batch_op.drop_constraint( + "unique_step_name_for_snapshot", type_="unique" + ) + batch_op.create_unique_constraint( + "unique_step_name_for_snapshot", ["snapshot_id", "name"] + ) + batch_op.alter_column( + "snapshot_id", existing_type=sa.CHAR(length=32), nullable=False + ) + batch_op.drop_column("step_run_id") + + with op.batch_alter_table("pipeline_snapshot", schema=None) as batch_op: + batch_op.drop_column("is_dynamic") + + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index e26e1eac710..4f7a869a4b5 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -179,7 +179,10 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True): sa_relationship_kwargs={"cascade": "delete"}, ) step_runs: List["StepRunSchema"] = Relationship( - sa_relationship_kwargs={"cascade": "delete"}, + sa_relationship_kwargs={ + "cascade": "delete", + "order_by": "asc(StepRunSchema.start_time)", + }, ) model_version: "ModelVersionSchema" = Relationship( back_populates="pipeline_runs", @@ -709,7 +712,15 @@ def update(self, run_update: "PipelineRunUpdate") -> "PipelineRunSchema": if run_update.status_reason: self.status_reason = run_update.status_reason - self.in_progress = self._check_if_run_in_progress() + if run_update.is_finished: + self.in_progress = False + elif self.snapshot and self.snapshot.is_dynamic: + # In dynamic pipelines, we can't actually check if the run is + # in progress by inspecting the DAG. We only know for sure once + # the orchestration container finishes. 
+ pass + else: + self.in_progress = self._check_if_run_in_progress() if run_update.orchestrator_run_id: self.orchestrator_run_id = run_update.orchestrator_run_id diff --git a/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py b/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py index 668b0ef04e4..caa406a4800 100644 --- a/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_snapshot_schemas.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Sequence from uuid import UUID -from sqlalchemy import TEXT, Column, String, UniqueConstraint +from sqlalchemy import TEXT, CheckConstraint, Column, String, UniqueConstraint from sqlalchemy.dialects.mysql import MEDIUMTEXT from sqlalchemy.orm import joinedload, object_session, selectinload from sqlalchemy.sql.base import ExecutableOption @@ -87,6 +87,7 @@ class PipelineSnapshotSchema(BaseSchema, table=True): nullable=True, ) ) + is_dynamic: bool = Field(nullable=False, default=False) pipeline_configuration: str = Field( sa_column=Column( @@ -409,6 +410,7 @@ def from_request( return cls( name=name, description=request.description, + is_dynamic=request.is_dynamic, stack_id=request.stack, project_id=request.project, pipeline_id=request.pipeline, @@ -480,11 +482,21 @@ def to_model( The response. """ runnable = False - if self.build and not self.build.is_local and self.build.stack_id: + if ( + not self.is_dynamic + and self.build + and not self.build.is_local + and self.build.stack_id + ): runnable = True deployable = False - if self.build and self.stack and self.stack.has_deployer: + if ( + not self.is_dynamic + and self.build + and self.stack + and self.stack.has_deployer + ): deployable = True body = PipelineSnapshotResponseBody( @@ -494,6 +506,7 @@ def to_model( updated=self.updated, runnable=runnable, deployable=deployable, + is_dynamic=self.is_dynamic, ) metadata = None if include_metadata: @@ -612,9 +625,15 @@ class StepConfigurationSchema(BaseSchema, table=True): __table_args__ = ( UniqueConstraint( "snapshot_id", + "step_run_id", "name", name="unique_step_name_for_snapshot", ), + CheckConstraint( + "(snapshot_id IS NULL AND step_run_id IS NOT NULL) OR " + "(snapshot_id IS NOT NULL AND step_run_id IS NULL)", + name="ck_step_configuration_snapshot_step_run_exclusivity", + ), ) index: int @@ -634,5 +653,13 @@ class StepConfigurationSchema(BaseSchema, table=True): source_column="snapshot_id", target_column="id", ondelete="CASCADE", - nullable=False, + nullable=True, + ) + step_run_id: UUID = build_foreign_key_field( + source=__tablename__, + target="step_run", + source_column="step_run_id", + target_column="id", + ondelete="CASCADE", + nullable=True, ) diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 2dced06c2e5..bb589a2feb4 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -96,14 +96,6 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): version: int = Field(nullable=False) is_retriable: bool = Field(nullable=False) - step_configuration: str = Field( - sa_column=Column( - String(length=MEDIUMTEXT_MAX_LENGTH).with_variant( - MEDIUMTEXT, "mysql" - ), - nullable=True, - ) - ) exception_info: Optional[str] = Field( sa_column=Column( String(length=MEDIUMTEXT_MAX_LENGTH).with_variant( @@ -209,12 +201,25 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): original_step_run: Optional["StepRunSchema"] 
= Relationship( sa_relationship_kwargs={"remote_side": "StepRunSchema.id"} ) - step_configuration_schema: Optional["StepConfigurationSchema"] = ( - Relationship( - sa_relationship_kwargs=dict( - viewonly=True, - primaryjoin="and_(foreign(StepConfigurationSchema.name) == StepRunSchema.name, foreign(StepConfigurationSchema.snapshot_id) == StepRunSchema.snapshot_id)", + # In static pipelines, we use the config that is compiled in the snapshot. + static_config: Optional["StepConfigurationSchema"] = Relationship( + sa_relationship_kwargs=dict( + viewonly=True, + primaryjoin="and_(foreign(StepConfigurationSchema.name) == StepRunSchema.name, foreign(StepConfigurationSchema.snapshot_id) == StepRunSchema.snapshot_id)", + ), + ) + # In dynamic pipelines, the config is dynamically generated and cannot be + # included in the compiled snapshot. In this case, we link it directly to + # the step run. + dynamic_config: Optional["StepConfigurationSchema"] = Relationship() + # In legacy pipelines (before snapshots, former deployments), the config + # is stored as a string in the step run. + step_configuration: str = Field( + sa_column=Column( + String(length=MEDIUMTEXT_MAX_LENGTH).with_variant( + MEDIUMTEXT, "mysql" ), + nullable=True, ) ) @@ -251,7 +256,8 @@ def get_query_options( selectinload(jl_arg(StepRunSchema.pipeline_run)).load_only( jl_arg(PipelineRunSchema.start_time) ), - joinedload(jl_arg(StepRunSchema.step_configuration_schema)), + joinedload(jl_arg(StepRunSchema.static_config)), + joinedload(jl_arg(StepRunSchema.dynamic_config)), ] if include_metadata: @@ -346,7 +352,7 @@ def get_step_configuration(self) -> Step: step = None if self.snapshot is not None: - if self.step_configuration_schema: + if config_schema := (self.dynamic_config or self.static_config): pipeline_configuration = ( PipelineConfiguration.model_validate_json( self.snapshot.pipeline_configuration @@ -357,7 +363,7 @@ def get_step_configuration(self) -> Step: inplace=True, ) step = Step.from_dict( - json.loads(self.step_configuration_schema.config), + json.loads(config_schema.config), pipeline_configuration=pipeline_configuration, ) if not step and self.step_configuration: diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 451007e05eb..c3e37bc6dc1 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -6053,6 +6053,7 @@ def get_pipeline_run_dag(self, pipeline_run_id: UUID) -> PipelineRunDAG: """ helper = DAGGeneratorHelper() with Session(self.engine) as session: + # TODO: better loads for dynamic/static pipelines run = self._get_schema_by_id( resource_id=pipeline_run_id, schema_class=PipelineRunSchema, @@ -6060,6 +6061,7 @@ def get_pipeline_run_dag(self, pipeline_run_id: UUID) -> PipelineRunDAG: query_options=[ selectinload(jl_arg(PipelineRunSchema.snapshot)).load_only( jl_arg(PipelineSnapshotSchema.pipeline_configuration), + jl_arg(PipelineSnapshotSchema.is_dynamic), ), selectinload( jl_arg(PipelineRunSchema.snapshot) @@ -6072,6 +6074,9 @@ def get_pipeline_run_dag(self, pipeline_run_id: UUID) -> PipelineRunDAG: selectinload( jl_arg(PipelineRunSchema.step_runs) ).selectinload(jl_arg(StepRunSchema.output_artifacts)), + selectinload( + jl_arg(PipelineRunSchema.step_runs) + ).selectinload(jl_arg(StepRunSchema.dynamic_config)), selectinload(jl_arg(PipelineRunSchema.step_runs)) .selectinload(jl_arg(StepRunSchema.triggered_runs)) .load_only( @@ -6098,13 +6103,23 @@ def get_pipeline_run_dag(self, pipeline_run_id: UUID) -> PipelineRunDAG: 
diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py
index 451007e05eb..c3e37bc6dc1 100644
--- a/src/zenml/zen_stores/sql_zen_store.py
+++ b/src/zenml/zen_stores/sql_zen_store.py
@@ -6053,6 +6053,7 @@ def get_pipeline_run_dag(self, pipeline_run_id: UUID) -> PipelineRunDAG:
         """
         helper = DAGGeneratorHelper()
         with Session(self.engine) as session:
+            # TODO: better loads for dynamic/static pipelines
             run = self._get_schema_by_id(
                 resource_id=pipeline_run_id,
                 schema_class=PipelineRunSchema,
@@ -6060,6 +6061,7 @@ def get_pipeline_run_dag(self, pipeline_run_id: UUID) -> PipelineRunDAG:
                 query_options=[
                     selectinload(jl_arg(PipelineRunSchema.snapshot)).load_only(
                         jl_arg(PipelineSnapshotSchema.pipeline_configuration),
+                        jl_arg(PipelineSnapshotSchema.is_dynamic),
                     ),
                     selectinload(
                         jl_arg(PipelineRunSchema.snapshot)
@@ -6072,6 +6074,9 @@ def get_pipeline_run_dag(self, pipeline_run_id: UUID) -> PipelineRunDAG:
                     selectinload(
                         jl_arg(PipelineRunSchema.step_runs)
                     ).selectinload(jl_arg(StepRunSchema.output_artifacts)),
+                    selectinload(
+                        jl_arg(PipelineRunSchema.step_runs)
+                    ).selectinload(jl_arg(StepRunSchema.dynamic_config)),
                     selectinload(jl_arg(PipelineRunSchema.step_runs))
                     .selectinload(jl_arg(StepRunSchema.triggered_runs))
                     .load_only(
@@ -6098,13 +6103,23 @@
                 start_time=run.start_time, inplace=True
             )
 
-            steps = {
-                config_table.name: Step.from_dict(
-                    json.loads(config_table.config),
-                    pipeline_configuration=pipeline_configuration,
-                )
-                for config_table in snapshot.step_configurations
-            }
+            if snapshot.is_dynamic:
+                # Ignore static config templates for dynamic pipeline DAGs
+                steps = {
+                    name: Step.from_dict(
+                        json.loads(step_run.dynamic_config.config),  # type: ignore[union-attr]
+                        pipeline_configuration=pipeline_configuration,
+                    )
+                    for name, step_run in step_runs.items()
+                }
+            else:
+                steps = {
+                    config_table.name: Step.from_dict(
+                        json.loads(config_table.config),
+                        pipeline_configuration=pipeline_configuration,
+                    )
+                    for config_table in snapshot.step_configurations
+                }
             regular_output_artifact_nodes: Dict[
                 str, Dict[str, PipelineRunDAG.Node]
             ] = defaultdict(dict)
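The branch above only changes where the `Step` objects for the DAG come from: static runs read the templates compiled into the snapshot, while dynamic runs read the per-step-run config rows, since their steps may never have existed as templates. A condensed sketch of that selection with plain dicts in place of the schema objects (illustrative only):

import json
from typing import Any, Dict


def build_dag_steps(
    is_dynamic: bool,
    snapshot_configs: Dict[str, str],  # step name -> config JSON (templates)
    step_run_configs: Dict[str, str],  # step name -> config JSON (dynamic rows)
) -> Dict[str, Dict[str, Any]]:
    # For dynamic runs, the step runs themselves are the source of truth.
    source = step_run_configs if is_dynamic else snapshot_configs
    return {name: json.loads(config) for name, config in source.items()}


static_steps = build_dag_steps(False, {"trainer": "{}"}, {})
dynamic_steps = build_dag_steps(True, {}, {"trainer_0": "{}", "trainer_1": "{}"})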
@@ -9959,7 +9974,10 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse:
                 session=session,
                 reference_type="original step run",
             )
-            step_config = run.get_step_configuration(step_name=step_run.name)
+            step_config = (
+                step_run.dynamic_config
+                or run.get_step_configuration(step_name=step_run.name)
+            )
 
             # Release the read locks of the previous two queries before we
             # try to acquire more exclusive locks
@@ -10211,6 +10229,26 @@ def create_run_step(self, step_run: StepRunRequest) -> StepRunResponse:
                 pipeline_run_id=step_run.pipeline_run_id, session=session
             )
 
+            if step_run.dynamic_config:
+                if not run.snapshot or not run.snapshot.is_dynamic:
+                    raise IllegalOperationError(
+                        "Dynamic step configurations are not allowed for "
+                        "static pipelines."
+                    )
+
+                step_configuration_schema = StepConfigurationSchema(
+                    index=0,
+                    name=step_run.name,
+                    # Don't include the merged config in the step
+                    # configurations; we reconstruct it in the `to_model`
+                    # method using the pipeline configuration.
+                    config=step_run.dynamic_config.model_dump_json(
+                        exclude={"config"}
+                    ),
+                    step_run_id=step_schema.id,
+                )
+                session.add(step_configuration_schema)
+                session.commit()
 
             session.refresh(
                 step_schema, ["input_artifacts", "output_artifacts"]
             )
@@ -10622,12 +10660,14 @@ def _update_pipeline_run_status(
         # Snapshots always exist for pipeline runs of newer versions
         assert pipeline_run.snapshot
         num_steps = pipeline_run.snapshot.step_count
+        is_dynamic_pipeline = pipeline_run.snapshot.is_dynamic
         new_status = get_pipeline_run_status(
             run_status=ExecutionStatus(pipeline_run.status),
             step_statuses=[
                 ExecutionStatus(status) for status in step_run_statuses
             ],
             num_steps=num_steps,
+            is_dynamic_pipeline=is_dynamic_pipeline,
         )
 
         if new_status == pipeline_run.status or (
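The new `is_dynamic_pipeline` argument matters because a dynamic run has no reliable upfront step count: `step_count` reflects only the static templates, while the actual number of step runs is decided at runtime. A rough sketch of the distinction; the real `get_pipeline_run_status` is not shown in this diff, so its exact rules are assumptions here:

from enum import Enum
from typing import List


class Status(str, Enum):
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"


def run_status_sketch(
    step_statuses: List[Status], num_steps: int, is_dynamic_pipeline: bool
) -> Status:
    if any(s is Status.FAILED for s in step_statuses):
        return Status.FAILED
    if all(s is Status.COMPLETED for s in step_statuses):
        # A static run is complete once every compiled step has finished.
        # A dynamic run cannot rely on num_steps, since more steps may still
        # be created, so completion has to be signaled some other way.
        if not is_dynamic_pipeline and len(step_statuses) == num_steps:
            return Status.COMPLETED
    return Status.RUNNING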
diff --git a/tests/integration/functional/model/test_model_version.py b/tests/integration/functional/model/test_model_version.py
index 1959bd2aa10..2f72266bd55 100644
--- a/tests/integration/functional/model/test_model_version.py
+++ b/tests/integration/functional/model/test_model_version.py
@@ -560,7 +560,7 @@ def test_link_artifact_via_function(self):
 
         @pipeline
         def _inner_pipeline(
-            model: Model = None,
+            model: Optional[Model] = None,
             artifact_type: Optional[ArtifactType] = None,
         ):
             artifact_linker(
diff --git a/tests/integration/functional/test_client.py b/tests/integration/functional/test_client.py
index d594991fefc..8648dd47f45 100644
--- a/tests/integration/functional/test_client.py
+++ b/tests/integration/functional/test_client.py
@@ -735,6 +735,7 @@ def test_listing_snapshots(clean_client):
         client_version="0.12.3",
         server_version="0.12.3",
         pipeline=pipeline.id,
+        is_dynamic=False,
     )
     response = clean_client.zen_store.create_snapshot(request)
 
@@ -765,6 +766,7 @@ def test_getting_snapshots(clean_client):
         client_version="0.12.3",
         server_version="0.12.3",
         pipeline=pipeline.id,
+        is_dynamic=False,
     )
     response = clean_client.zen_store.create_snapshot(request)
 
@@ -793,6 +795,7 @@ def test_deleting_snapshots(clean_client):
         client_version="0.12.3",
         server_version="0.12.3",
         pipeline=pipeline.id,
+        is_dynamic=False,
     )
     response = clean_client.zen_store.create_snapshot(request)
diff --git a/tests/integration/functional/zen_stores/test_zen_store.py b/tests/integration/functional/zen_stores/test_zen_store.py
index 4c4b82f2265..a62101389c9 100644
--- a/tests/integration/functional/zen_stores/test_zen_store.py
+++ b/tests/integration/functional/zen_stores/test_zen_store.py
@@ -5597,6 +5597,7 @@ def test_metadata_full_cycle_with_cascade_deletion(
                     config=StepConfiguration(name=step_name),
                 )
             },
+            is_dynamic=False,
         )
     )
     pr, _ = client.zen_store.get_or_create_run(
@@ -5657,6 +5658,7 @@ def test_metadata_full_cycle_with_cascade_deletion(
                 )
             },
             schedule=resource.id,
+            is_dynamic=False,
         )
     )
     else:
@@ -5890,6 +5892,7 @@ def create_artifact_version():
                     config=StepConfiguration(name=step_name),
                 )
             },
+            is_dynamic=False,
         )
     )
diff --git a/tests/integration/functional/zen_stores/utils.py b/tests/integration/functional/zen_stores/utils.py
index 036838b7558..e6de109185f 100644
--- a/tests/integration/functional/zen_stores/utils.py
+++ b/tests/integration/functional/zen_stores/utils.py
@@ -755,6 +755,7 @@ def __enter__(self):
             pipeline_configuration={"name": "pipeline_name"},
             client_version="0.12.3",
             server_version="0.12.3",
+            is_dynamic=False,
         ),
     )
     self.snapshots.append(snapshot)
@@ -1292,6 +1293,7 @@ def cleanup(self) -> None:
             server_version="0.12.3",
             pipeline_version_hash="random_hash",
             pipeline_spec=PipelineSpec(steps=[]),
+            is_dynamic=False,
         ),
         filter_model=PipelineSnapshotFilter,
         entity_name="snapshot",
@@ -1371,6 +1373,7 @@ def cleanup(self) -> None:
             server_version="0.12.3",
             pipeline_version_hash="random_hash",
             pipeline_spec=PipelineSpec(steps=[]),
+            is_dynamic=False,
         ),
         filter_model=PipelineSnapshotFilter,
         entity_name="snapshot",
diff --git a/tests/unit/config/test_compiler.py b/tests/unit/config/test_compiler.py
index 6a3976ba5c2..8fc930ae637 100644
--- a/tests/unit/config/test_compiler.py
+++ b/tests/unit/config/test_compiler.py
@@ -34,8 +34,7 @@ def test_compiling_pipeline_with_invalid_run_name_fails(
 ):
     """Tests that compiling a pipeline with an invalid run name fails."""
     pipeline_instance = empty_pipeline
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     with pytest.raises(ValueError):
         Compiler().compile(
             pipeline=pipeline_instance,
@@ -54,8 +53,7 @@ def _no_step_pipeline():
 def test_compiling_pipeline_without_steps_fails(local_stack):
     """Tests that compiling a pipeline without steps fails."""
     pipeline_instance = _no_step_pipeline
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     with pytest.raises(ValueError):
         Compiler().compile(
             pipeline=pipeline_instance,
@@ -71,8 +69,7 @@ def test_compiling_pipeline_with_missing_step_operator(
     pipeline_instance = one_step_pipeline(
         empty_step.configure(step_operator="s")
     )
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     with pytest.raises(StackValidationError):
         Compiler().compile(
             pipeline=pipeline_instance,
@@ -89,8 +86,7 @@ def test_compiling_pipeline_with_missing_experiment_tracker(
     pipeline_instance = one_step_pipeline(
         empty_step.configure(experiment_tracker="e")
    )
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     with pytest.raises(StackValidationError):
         Compiler().compile(
             pipeline=pipeline_instance,
@@ -113,8 +109,7 @@ def test_pipeline_and_steps_dont_get_modified_during_compilation(
             "_empty_step": StepConfigurationUpdate(extra={"key": "new_value"})
         },
     )
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     Compiler().compile(
         pipeline=pipeline_instance,
         stack=local_stack,
@@ -148,8 +143,7 @@ def pipeline_instance():
         empty_step(id="step_1")
         empty_step(id="step_2", after="step_1")
 
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     snapshot = Compiler().compile(
         pipeline=pipeline_instance,
         stack=local_stack,
@@ -204,8 +198,7 @@ class StubSettings(BaseSettings):
             )
         },
     )
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     snapshot = Compiler().compile(
         pipeline=pipeline_instance,
         stack=local_stack,
@@ -252,8 +245,7 @@ def test_general_settings_merging(one_step_pipeline, empty_step, local_stack):
             )
         },
     )
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     snapshot = Compiler().compile(
         pipeline=pipeline_instance,
         stack=local_stack,
@@ -299,8 +291,7 @@ def test_extra_merging(one_step_pipeline, empty_step, local_stack):
         steps={"_empty_step": StepConfigurationUpdate(extra=run_step_extra)},
     )
 
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
 
     snapshot = Compiler().compile(
         pipeline=pipeline_instance,
@@ -358,8 +349,7 @@ def test_success_hook_merging(
         },
     )
 
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     snapshot = Compiler().compile(
         pipeline=pipeline_instance,
         stack=local_stack,
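The recurring test change in this file replaces the manual `with pipeline_instance: pipeline_instance.entrypoint()` idiom with a single `prepare()` call. Assuming `prepare()` simply bundles those two steps (entering the pipeline context and running the entrypoint to build the invocations), the idiom can be sketched like this; `PipelineStub` is a toy stand-in, not the real class:

class PipelineStub:
    def __init__(self) -> None:
        self.invocations_built = False

    def __enter__(self) -> "PipelineStub":
        return self

    def __exit__(self, *exc: object) -> None:
        pass

    def entrypoint(self) -> None:
        self.invocations_built = True

    def prepare(self) -> None:
        # Assumed behavior: the old two-step idiom, in one call.
        with self:
            self.entrypoint()


pipeline_instance = PipelineStub()
pipeline_instance.prepare()
assert pipeline_instance.invocations_built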
@@ -409,8 +399,7 @@ def test_failure_hook_merging(
         },
     )
 
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     snapshot = Compiler().compile(
         pipeline=pipeline_instance,
         stack=local_stack,
@@ -453,8 +442,7 @@ def test_stack_component_settings_for_missing_component_are_ignored(
         steps={"_empty_step": StepConfigurationUpdate(settings=settings)},
     )
 
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     snapshot = Compiler().compile(
         pipeline=pipeline_instance,
         stack=local_stack,
@@ -500,8 +488,7 @@ class StubSettings(BaseSettings):
         steps={"_empty_step": StepConfigurationUpdate(settings=settings)},
     )
 
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     snapshot = Compiler().compile(
         pipeline=pipeline_instance,
         stack=local_stack,
@@ -551,8 +538,7 @@ class StubSettings(BaseSettings):
         steps={"_empty_step": StepConfigurationUpdate(settings=settings)},
     )
 
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     snapshot = Compiler().compile(
         pipeline=pipeline_instance,
         stack=local_stack,
@@ -585,8 +571,7 @@ def test_spec_compilation(local_stack):
     def pipeline_instance():
         s2(s1())
 
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
     spec = (
         Compiler()
         .compile(
@@ -652,8 +637,7 @@ class StubSettings(BaseSettings):
         settings={"orchestrator": shortcut_settings}
     )
 
-    with pipeline_instance_with_shortcut_settings:
-        pipeline_instance_with_shortcut_settings.entrypoint()
+    pipeline_instance_with_shortcut_settings.prepare()
 
     with does_not_raise():
         snapshot = Compiler().compile(
@@ -677,8 +661,7 @@ class StubSettings(BaseSettings):
         }
     )
 
-    with pipeline_instance_with_duplicate_settings:
-        pipeline_instance_with_duplicate_settings.entrypoint()
+    pipeline_instance_with_duplicate_settings.prepare()
 
     with pytest.raises(ValueError):
         snapshot = Compiler().compile(
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 663ab001114..0fdf040e825 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -294,7 +294,6 @@ def step_context_with_no_output(
     sample_pipeline_run: PipelineRunResponse,
     sample_step_run: StepRunResponse,
 ) -> StepContext:
-    StepContext._clear()
     return StepContext(
         pipeline_run=sample_pipeline_run,
         step_run=sample_step_run,
@@ -312,7 +311,6 @@ def step_context_with_single_output(
     materializers = {"output_1": (BaseMaterializer,)}
     artifact_uris = {"output_1": ""}
     artifact_configs = {"output_1": None}
-    StepContext._clear()
     return StepContext(
         pipeline_run=sample_pipeline_run,
         step_run=sample_step_run,
@@ -337,7 +335,6 @@ def step_context_with_two_outputs(
     }
     artifact_configs = {"output_1": None, "output_2": None}
 
-    StepContext._clear()
     return StepContext(
         pipeline_run=sample_pipeline_run,
         step_run=sample_step_run,
@@ -456,6 +453,7 @@ def sample_pipeline_snapshot_request_model() -> PipelineSnapshotRequest:
         server_version="0.12.3",
         stack=uuid4(),
         pipeline=uuid4(),
+        is_dynamic=False,
     )
 
@@ -634,6 +632,7 @@ def sample_snapshot_response_model(
         updated=datetime.now(),
         runnable=True,
         deployable=True,
+        is_dynamic=False,
     ),
     metadata=PipelineSnapshotResponseMetadata(
         run_name_template="",
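The `StepContext._clear()` calls disappear from these fixtures because the context is no longer a process-wide singleton; as the `test_step_context.py` changes further below show, it is now entered and exited explicitly. A minimal sketch of that pattern on top of `contextvars`, with illustrative names:

import contextvars
from typing import Optional

_active_ctx: contextvars.ContextVar[Optional["ContextStub"]] = (
    contextvars.ContextVar("active_ctx", default=None)
)


class ContextStub:
    """Only visible between __enter__ and __exit__."""

    def __enter__(self) -> "ContextStub":
        self._token = _active_ctx.set(self)
        return self

    def __exit__(self, *exc: object) -> None:
        _active_ctx.reset(self._token)


def get_ctx() -> "ContextStub":
    if (ctx := _active_ctx.get()) is None:
        raise RuntimeError("No active context.")
    return ctx


with ContextStub() as ctx:
    assert get_ctx() is ctx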
diff --git a/tests/unit/entrypoints/test_base_entrypoint_configuration.py b/tests/unit/entrypoints/test_base_entrypoint_configuration.py
index 1736dd32d70..23afe488338 100644
--- a/tests/unit/entrypoints/test_base_entrypoint_configuration.py
+++ b/tests/unit/entrypoints/test_base_entrypoint_configuration.py
@@ -69,4 +69,4 @@ def test_loading_the_snapshot(clean_client):
         arguments=["--snapshot_id", str(snapshot.id)]
     )
 
-    assert entrypoint_config.load_snapshot() == snapshot
+    assert entrypoint_config.snapshot == snapshot
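The test now reads `entrypoint_config.snapshot` instead of calling `load_snapshot()`, which suggests the snapshot is exposed as a property, plausibly cached so repeated accesses do not re-fetch it. A sketch of that pattern under those assumptions; the class and loading logic here are stand-ins, not the real entrypoint configuration:

from functools import cached_property
from typing import Any, Dict


class EntrypointConfigStub:
    def __init__(self, snapshot_id: str) -> None:
        self._snapshot_id = snapshot_id

    @cached_property
    def snapshot(self) -> Dict[str, Any]:
        # Assumed behavior: load on first access, reuse afterwards.
        return self._load_snapshot()

    def _load_snapshot(self) -> Dict[str, Any]:
        return {"id": self._snapshot_id}


config = EntrypointConfigStub("some-id")
assert config.snapshot is config.snapshot  # loaded once, then cached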
diff --git a/tests/unit/orchestrators/test_publish_utils.py b/tests/unit/orchestrators/test_publish_utils.py
index 688a4ba11a1..5ad34835684 100644
--- a/tests/unit/orchestrators/test_publish_utils.py
+++ b/tests/unit/orchestrators/test_publish_utils.py
@@ -100,6 +100,7 @@ def test_pipeline_run_status_computation(
             run_status=ExecutionStatus.RUNNING,
             step_statuses=step_statuses,
             num_steps=num_steps,
+            is_dynamic_pipeline=False,
         )
         == expected_run_status
     )
diff --git a/tests/unit/orchestrators/test_step_runner.py b/tests/unit/orchestrators/test_step_runner.py
index 2cce79e48c7..a43ad92cb9b 100644
--- a/tests/unit/orchestrators/test_step_runner.py
+++ b/tests/unit/orchestrators/test_step_runner.py
@@ -22,7 +22,11 @@
 from zenml.config.step_configurations import Step
 from zenml.config.step_run_info import StepRunInfo
 from zenml.enums import ArtifactSaveType
-from zenml.models import PipelineRunResponse, StepRunResponse
+from zenml.models import (
+    PipelineRunResponse,
+    PipelineSnapshotResponse,
+    StepRunResponse,
+)
 from zenml.orchestrators.step_launcher import StepRunner
 from zenml.stack import Stack
 from zenml.steps import step
@@ -43,6 +47,7 @@ def test_running_a_successful_step(
     local_stack,
     sample_pipeline_run: PipelineRunResponse,
     sample_step_run: StepRunResponse,
+    sample_snapshot_response_model: PipelineSnapshotResponse,
 ):
     """Tests that running a successful step runs the step entrypoint and
     correctly prepares/cleans up."""
@@ -74,7 +79,9 @@ def test_running_a_successful_step(
         run_name="run_name",
         pipeline_step_name="step_name",
         config=step.config,
+        spec=step.spec,
         pipeline=pipeline_config,
+        snapshot=sample_snapshot_response_model,
         force_write_logs=lambda: None,
     )
 
@@ -98,6 +105,7 @@ def test_running_a_failing_step(
     local_stack,
     sample_pipeline_run: PipelineRunResponse,
     sample_step_run: StepRunResponse,
+    sample_snapshot_response_model: PipelineSnapshotResponse,
 ):
     """Tests that running a failing step runs the step entrypoint and
     correctly prepares/cleans up."""
@@ -130,7 +138,9 @@ def test_running_a_failing_step(
         run_name="run_name",
         pipeline_step_name="step_name",
         config=step.config,
+        spec=step.spec,
         pipeline=pipeline_config,
+        snapshot=sample_snapshot_response_model,
         force_write_logs=lambda: None,
     )
diff --git a/tests/unit/pipelines/test_base_pipeline.py b/tests/unit/pipelines/test_base_pipeline.py
index 281d8ff3dc9..f33abbc5531 100644
--- a/tests/unit/pipelines/test_base_pipeline.py
+++ b/tests/unit/pipelines/test_base_pipeline.py
@@ -429,8 +429,7 @@ def test_compiling_a_pipeline_merges_schedule(
     config_path.write_text(run_config.yaml())
 
     pipeline_instance = empty_pipeline
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
 
     _, schedule, _ = pipeline_instance._compile(
         config_path=str(config_path),
@@ -462,8 +461,7 @@ def test_compiling_a_pipeline_merges_build(
     in_code_build_id = uuid4()
 
     pipeline_instance = empty_pipeline
-    with pipeline_instance:
-        pipeline_instance.entrypoint()
+    pipeline_instance.prepare()
 
     # Config with ID
     _, _, build = pipeline_instance._compile(
diff --git a/tests/unit/stack/test_stack.py b/tests/unit/stack/test_stack.py
index 31a465a24b6..25cadd47af0 100644
--- a/tests/unit/stack/test_stack.py
+++ b/tests/unit/stack/test_stack.py
@@ -149,8 +149,7 @@ def test_stack_submission(
     components."""
     # Mock the pipeline run registering which tries (and fails) to serialize
    # our mock objects
-    with empty_pipeline:
-        empty_pipeline.entrypoint()
+    empty_pipeline.prepare()
     snapshot = Compiler().compile(
         pipeline=empty_pipeline,
         stack=stack_with_mock_components,
diff --git a/tests/unit/steps/test_step_context.py b/tests/unit/steps/test_step_context.py
index 0e2de8dbae0..3fb34a22a36 100644
--- a/tests/unit/steps/test_step_context.py
+++ b/tests/unit/steps/test_step_context.py
@@ -22,22 +22,15 @@
 from zenml.steps import StepContext
 
 
-def test_step_context_is_singleton(step_context_with_no_output):
-    """Tests that the step context is a singleton."""
-    assert StepContext() is step_context_with_no_output
-
-
 def test_get_step_context(step_context_with_no_output):
     """Unit test for `get_step_context()`."""
-    # step context exists -> returns the step context
-    assert get_step_context() is StepContext()
-
-    # step context does not exist -> raises an exception
-    StepContext._clear()
     with pytest.raises(RuntimeError):
         get_step_context()
 
+    with step_context_with_no_output:
+        assert get_step_context() is step_context_with_no_output
+
 
 def test_initialize_step_context_with_mismatched_keys(
     sample_pipeline_run,
@@ -49,7 +42,6 @@ def test_initialize_step_context_with_mismatched_keys(
     artifact_configs = {"some_yet_another_output_name": None}
 
     with pytest.raises(StepContextError):
-        StepContext._clear()
         StepContext(
             pipeline_run=sample_pipeline_run,
             step_run=sample_step_run,
@@ -69,7 +61,6 @@ def test_initialize_step_context_with_matching_keys(
     artifact_configs = {"some_output_name": None}
 
     with does_not_raise():
-        StepContext._clear()
         StepContext(
             pipeline_run=sample_pipeline_run,
             step_run=sample_step_run,