Skip to content

Commit 6b51cc5

Browse files
committed
misc
1 parent 98d115e commit 6b51cc5

File tree

14 files changed

+271
-146
lines changed

14 files changed

+271
-146
lines changed

src/zenml/config/compiler.py

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from zenml.config.step_configurations import (
3636
InputSpec,
3737
Step,
38+
StepConfiguration,
3839
StepConfigurationUpdate,
3940
StepSpec,
4041
)
@@ -47,6 +48,7 @@
4748
if TYPE_CHECKING:
4849
from zenml.pipelines.pipeline_definition import Pipeline
4950
from zenml.stack import Stack, StackComponent
51+
from zenml.steps.base_step import BaseStep
5052
from zenml.steps.step_invocation import StepInvocation
5153

5254
from zenml.logger import get_logger
@@ -126,17 +128,29 @@ def compile(
126128
merge=False,
127129
)
128130

129-
steps = {
130-
invocation_id: self._compile_step_invocation(
131-
invocation=invocation,
132-
stack=stack,
133-
step_config=(run_configuration.steps or {}).get(invocation_id),
134-
pipeline_configuration=pipeline.configuration,
135-
)
136-
for invocation_id, invocation in self._get_sorted_invocations(
137-
pipeline=pipeline
138-
)
139-
}
131+
if pipeline.is_dynamic:
132+
step_templates = {
133+
step.name: self._compile_config_template(
134+
step=step, stack=stack
135+
)
136+
for step in pipeline.depends_on
137+
}
138+
steps = {}
139+
else:
140+
step_templates = None
141+
steps = {
142+
invocation_id: self._compile_step_invocation(
143+
invocation=invocation,
144+
stack=stack,
145+
step_config=(run_configuration.steps or {}).get(
146+
invocation_id
147+
),
148+
pipeline_configuration=pipeline.configuration,
149+
)
150+
for invocation_id, invocation in self._get_sorted_invocations(
151+
pipeline=pipeline
152+
)
153+
}
140154

141155
self._ensure_required_stack_components_exist(stack=stack, steps=steps)
142156

@@ -156,6 +170,7 @@ def compile(
156170
is_dynamic=pipeline.is_dynamic,
157171
pipeline_configuration=pipeline.configuration,
158172
step_configurations=steps,
173+
step_configuration_templates=step_templates,
159174
client_environment=get_run_environment_dict(),
160175
client_version=client_version,
161176
server_version=server_version,
@@ -521,6 +536,48 @@ def _compile_step_invocation(
521536
step_config_overrides=step_configuration_overrides,
522537
)
523538

539+
def _compile_config_template(
540+
self,
541+
step: "BaseStep",
542+
stack: "Stack",
543+
step_config: Optional["StepConfigurationUpdate"],
544+
) -> StepConfiguration:
545+
"""Compiles a ZenML step.
546+
547+
Args:
548+
invocation: The step invocation to compile.
549+
stack: The stack on which the pipeline will be run.
550+
step_config: Run configuration for the step.
551+
pipeline_configuration: Configuration for the pipeline.
552+
553+
Returns:
554+
The compiled step.
555+
"""
556+
if step_config:
557+
step._apply_configuration(step_config)
558+
559+
convert_component_shortcut_settings_keys(
560+
step.configuration.settings, stack=stack
561+
)
562+
step_secrets = secret_utils.resolve_and_verify_secrets(
563+
step.configuration.secrets
564+
)
565+
step_settings = self._filter_and_validate_settings(
566+
settings=step.configuration.settings,
567+
configuration_level=ConfigurationLevel.STEP,
568+
stack=stack,
569+
)
570+
step.configure(
571+
secrets=step_secrets,
572+
settings=step_settings,
573+
merge=False,
574+
)
575+
576+
# TODO: apply pipeline config
577+
return StepConfiguration.model_validate(
578+
step.configuration.model_dump()
579+
)
580+
524581
def _get_sorted_invocations(
525582
self,
526583
pipeline: "Pipeline",

src/zenml/constants.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
246246
# ZenML Analytics Server - URL
247247
ANALYTICS_SERVER_URL = "https://analytics.zenml.io/"
248248

249-
# Container utils
250-
SHOULD_PREVENT_PIPELINE_EXECUTION = handle_bool_env_var(
251-
ENV_ZENML_PREVENT_PIPELINE_EXECUTION
252-
)
253-
254249
# Repository and local store directory paths:
255250
REPOSITORY_DIRECTORY_NAME = ".zen"
256251
LOCAL_STORES_DIRECTORY_NAME = "local_stores"

src/zenml/entrypoints/entrypoint.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717
import logging
1818
import sys
1919

20-
from zenml import constants
2120
from zenml.entrypoints.base_entrypoint_configuration import (
2221
ENTRYPOINT_CONFIG_SOURCE_OPTION,
2322
BaseEntrypointConfiguration,
2423
)
24+
from zenml.pipelines.run_utils import prevent_pipeline_execution
2525
from zenml.utils import source_utils
2626

2727

@@ -35,23 +35,22 @@ def main() -> None:
3535
_setup_logging()
3636

3737
# Make sure this entrypoint does not run an entire pipeline when
38-
# importing user modules. This could happen if the `pipeline.run()` call
39-
# is not wrapped in a function or an `if __name__== "__main__":` check)
40-
constants.SHOULD_PREVENT_PIPELINE_EXECUTION = True
41-
42-
# Read the source for the entrypoint configuration class from the command
43-
# line arguments
44-
parser = argparse.ArgumentParser()
45-
parser.add_argument(f"--{ENTRYPOINT_CONFIG_SOURCE_OPTION}", required=True)
46-
args, remaining_args = parser.parse_known_args()
47-
48-
entrypoint_config_class = source_utils.load_and_validate_class(
49-
args.entrypoint_config_source,
50-
expected_class=BaseEntrypointConfiguration,
51-
)
52-
entrypoint_config = entrypoint_config_class(arguments=remaining_args)
53-
54-
entrypoint_config.run()
38+
# importing user modules. This could happen if a pipeline is called in a
39+
# module without an `if __name__== "__main__":` check)
40+
with prevent_pipeline_execution():
41+
parser = argparse.ArgumentParser()
42+
parser.add_argument(
43+
f"--{ENTRYPOINT_CONFIG_SOURCE_OPTION}", required=True
44+
)
45+
args, remaining_args = parser.parse_known_args()
46+
47+
entrypoint_config_class = source_utils.load_and_validate_class(
48+
args.entrypoint_config_source,
49+
expected_class=BaseEntrypointConfiguration,
50+
)
51+
entrypoint_config = entrypoint_config_class(arguments=remaining_args)
52+
53+
entrypoint_config.run()
5554

5655

5756
if __name__ == "__main__":

src/zenml/models/v2/core/pipeline_snapshot.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from zenml.config.pipeline_configurations import PipelineConfiguration
3232
from zenml.config.pipeline_run_configuration import PipelineRunConfiguration
3333
from zenml.config.pipeline_spec import PipelineSpec
34-
from zenml.config.step_configurations import Step
34+
from zenml.config.step_configurations import Step, StepConfiguration
3535
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
3636
from zenml.enums import ExecutionStatus, StackComponentType
3737
from zenml.models.v2.base.base import BaseUpdate, BaseZenModel
@@ -82,6 +82,12 @@ class PipelineSnapshotBase(BaseZenModel):
8282
step_configurations: Dict[str, Step] = Field(
8383
default={}, title="The step configurations for this snapshot."
8484
)
85+
step_configuration_templates: Optional[Dict[str, StepConfiguration]] = (
86+
Field(
87+
default=None,
88+
title="The step configuration templates of the snapshot.",
89+
)
90+
)
8591
client_environment: Dict[str, Any] = Field(
8692
default={}, title="The client environment for this snapshot."
8793
)

src/zenml/orchestrators/containerized_orchestrator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,24 @@ def get_docker_builds(
120120
builds.append(pipeline_build)
121121
included_pipeline_build = True
122122

123+
for name, step_config in snapshot.step_configuration_templates.items():
124+
step_settings = step_config.docker_settings
125+
126+
if step_settings != pipeline_settings:
127+
build = BuildConfiguration(
128+
key=ORCHESTRATOR_DOCKER_IMAGE_KEY,
129+
settings=step_settings,
130+
step_name=name,
131+
)
132+
builds.append(build)
133+
elif not included_pipeline_build:
134+
pipeline_build = BuildConfiguration(
135+
key=ORCHESTRATOR_DOCKER_IMAGE_KEY,
136+
settings=pipeline_settings,
137+
)
138+
builds.append(pipeline_build)
139+
included_pipeline_build = True
140+
123141
if not included_pipeline_build and self.should_build_pipeline_image(
124142
snapshot
125143
):

src/zenml/orchestrators/local/local_orchestrator.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,13 +96,6 @@ def submit_pipeline(
9696
step.
9797
RuntimeError: If the pipeline run fails.
9898
"""
99-
if snapshot.schedule:
100-
logger.warning(
101-
"Local Orchestrator currently does not support the "
102-
"use of schedules. The `schedule` will be ignored "
103-
"and the pipeline will be run immediately."
104-
)
105-
10699
self._orchestrator_run_id = str(uuid4())
107100
start_time = time.time()
108101

@@ -203,10 +196,19 @@ def submit_dynamic_pipeline(
203196
"""Submits a dynamic pipeline to the orchestrator."""
204197
from zenml.pipelines.dynamic.runner import DynamicPipelineRunner
205198

199+
self._orchestrator_run_id = str(uuid4())
200+
start_time = time.time()
201+
206202
runner = DynamicPipelineRunner(snapshot=snapshot, run=placeholder_run)
207203
with temporary_environment(environment):
208204
runner.run_pipeline()
209205

206+
run_duration = time.time() - start_time
207+
logger.info(
208+
"Pipeline run has finished in `%s`.",
209+
string_utils.get_human_readable_time(run_duration),
210+
)
211+
self._orchestrator_run_id = None
210212
return None
211213

212214
def get_orchestrator_run_id(self) -> str:

src/zenml/pipelines/dynamic/pipeline_definition.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,22 +18,25 @@
1818
Any,
1919
Callable,
2020
Dict,
21+
List,
2122
Optional,
2223
Set,
2324
TypeVar,
2425
Union,
2526
)
26-
from uuid import uuid4
2727

2828
from pydantic import ConfigDict, ValidationError
2929

30-
from zenml import ExternalArtifact, constants
30+
from zenml import ExternalArtifact
3131
from zenml.client import Client
3232
from zenml.logger import get_logger
3333
from zenml.logging.step_logging import setup_pipeline_logging
3434
from zenml.models import ArtifactVersionResponse, PipelineRunResponse
3535
from zenml.pipelines.pipeline_definition import Pipeline
36-
from zenml.pipelines.run_utils import create_placeholder_run
36+
from zenml.pipelines.run_utils import (
37+
create_placeholder_run,
38+
should_prevent_pipeline_execution,
39+
)
3740
from zenml.steps.step_invocation import StepInvocation
3841
from zenml.utils import dict_utils, pydantic_utils
3942

@@ -49,6 +52,10 @@
4952
class DynamicPipeline(Pipeline):
5053
"""ZenML pipeline class."""
5154

55+
@property
56+
def depends_on(self) -> List["BaseStep"]:
57+
return []
58+
5259
@property
5360
def is_dynamic(self) -> bool:
5461
"""If the pipeline is dynamic.
@@ -167,15 +174,15 @@ def add_dynamic_invocation(
167174
def __call__(
168175
self, *args: Any, **kwargs: Any
169176
) -> Optional[PipelineRunResponse]:
170-
should_prevent_execution = constants.SHOULD_PREVENT_PIPELINE_EXECUTION
171-
172-
if should_prevent_execution:
173-
logger.warning("Preventing execution of pipeline '%s'.", self.name)
177+
if should_prevent_pipeline_execution():
178+
logger.info("Preventing execution of pipeline '%s'.", self.name)
174179
return
175180

176-
if not Client().active_stack.orchestrator.supports_dynamic_pipelines:
181+
stack = Client().active_stack
182+
183+
if not stack.orchestrator.supports_dynamic_pipelines:
177184
raise RuntimeError(
178-
f"The {Client().active_stack.orchestrator.__class__.__name__} does not support dynamic pipelines. "
185+
f"The {stack.orchestrator.__class__.__name__} does not support dynamic pipelines. "
179186
)
180187

181188
self.prepare(*args, **kwargs)
@@ -186,12 +193,11 @@ def __call__(
186193
) as logs_request:
187194
run = create_placeholder_run(
188195
snapshot=snapshot,
189-
orchestrator_run_id=str(uuid4()),
190196
logs=logs_request,
191197
)
192-
Client().active_stack.orchestrator.run(
198+
stack.orchestrator.run(
193199
snapshot=snapshot,
194-
stack=Client().active_stack,
200+
stack=stack,
195201
placeholder_run=run,
196202
)
197203
return Client().get_pipeline_run(run.id)

0 commit comments

Comments
 (0)