Skip to content

Commit 7956099

Browse files
committed
Dynamic config
1 parent 34aeb41 commit 7956099

File tree

4 files changed

+45
-18
lines changed

4 files changed

+45
-18
lines changed

src/zenml/pipelines/dynamic/pipeline_definition.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,19 @@
5050
class DynamicPipeline(Pipeline):
5151
"""ZenML pipeline class."""
5252

53+
def __init__(self, *args: Any, **kwargs: Any) -> None:
54+
self._depends_on = kwargs.pop("depends_on", None) or []
55+
if self._depends_on:
56+
# TODO: This doesn't really work, as `step.with_options`
57+
sources = [step.resolve().import_path for step in self._depends_on]
58+
if len(sources) != len(set(sources)):
59+
raise ValueError("Duplicate steps in depends_on.")
60+
61+
super().__init__(*args, **kwargs)
62+
5363
@property
5464
def depends_on(self) -> List["BaseStep"]:
55-
# TODO: Even with this, it will not be possible to define all potential
56-
# docker builds:
57-
# If a step will be called multiple times, once without step operator
58-
# and once with, it might require multiple docker images. Even if this
59-
# list had two copies of the same step, how would we map at runtime
60-
# which one to use?
61-
# TODO: maybe this needs to be a dict, and steps can select which
62-
# "template" config (including the docker image) to use?
63-
return getattr(self, "_depends_on", [])
65+
return self._depends_on
6466

6567
@property
6668
def is_dynamic(self) -> bool:
@@ -148,7 +150,7 @@ def pipeline_(param_name: str):
148150
self._parameters = validated_args
149151
self._invocations = {}
150152
with self:
151-
for step in self.depends_on:
153+
for step in self._depends_on:
152154
self.add_step_invocation(
153155
step,
154156
input_artifacts={},
@@ -181,7 +183,7 @@ def add_dynamic_invocation(
181183
external_artifacts=external_artifacts,
182184
model_artifacts_or_metadata={},
183185
client_lazy_loaders={},
184-
parameters={},
186+
parameters=step.configuration.parameters,
185187
default_parameters={},
186188
upstream_steps=upstream_steps or set(),
187189
pipeline=self,

src/zenml/pipelines/dynamic/runner.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from zenml import ExternalArtifact
1919
from zenml.client import Client
2020
from zenml.config.compiler import Compiler
21-
from zenml.config.step_configurations import Step, StepConfiguration
21+
from zenml.config.step_configurations import Step
2222
from zenml.exceptions import RunStoppedException
2323
from zenml.logger import get_logger
2424
from zenml.logging.step_logging import setup_pipeline_logging
@@ -281,14 +281,15 @@ def _compile_step(
281281
annotation=Any,
282282
pipeline=pipeline,
283283
)
284-
elif isinstance(value, ArtifactVersionResponse):
284+
elif isinstance(value, (ArtifactVersionResponse, ExternalArtifact)):
285285
external_artifacts[name] = value
286286
else:
287287
external_artifacts[name] = ExternalArtifact(value=value)
288288

289-
if config := get_dynamic_step_configuration(snapshot, step):
290-
step._configuration = config
289+
if template := get_static_step_template(snapshot, step):
290+
step._configuration = template.config
291291

292+
step._apply_dynamic_configuration()
292293
invocation_id = pipeline.add_dynamic_invocation(
293294
step=step,
294295
custom_id=id,
@@ -413,14 +414,14 @@ def _runs_in_process(step: "Step") -> bool:
413414
return True
414415

415416

416-
def get_dynamic_step_configuration(
417+
def get_static_step_template(
417418
snapshot: "PipelineSnapshotResponse",
418419
step: "BaseStep",
419-
) -> Optional["StepConfiguration"]:
420+
) -> Optional["Step"]:
420421
step_source = step.resolve().import_path
421422

422423
for compiled_step in snapshot.step_configurations.values():
423424
if compiled_step.spec.source.import_path == step_source:
424-
return compiled_step.config
425+
return compiled_step
425426

426427
return None

src/zenml/pipelines/pipeline_decorator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from zenml.config.retry_config import StepRetryConfig
3636
from zenml.model.model import Model
3737
from zenml.pipelines.pipeline_definition import Pipeline
38+
from zenml.steps.base_step import BaseStep
3839
from zenml.types import HookSpecification, InitHookSpecification
3940
from zenml.utils.tag_utils import Tag
4041

@@ -51,6 +52,7 @@ def pipeline(_func: "F") -> "Pipeline": ...
5152
def pipeline(
5253
*,
5354
name: Optional[str] = None,
55+
depends_on: Optional[List["BaseStep"]] = None,
5456
enable_cache: Optional[bool] = None,
5557
enable_artifact_metadata: Optional[bool] = None,
5658
enable_step_logs: Optional[bool] = None,
@@ -77,6 +79,7 @@ def pipeline(
7779
_func: Optional["F"] = None,
7880
*,
7981
name: Optional[str] = None,
82+
depends_on: Optional[List["BaseStep"]] = None,
8083
enable_cache: Optional[bool] = None,
8184
enable_artifact_metadata: Optional[bool] = None,
8285
enable_step_logs: Optional[bool] = None,
@@ -142,6 +145,7 @@ def inner_decorator(func: "F") -> "Pipeline":
142145

143146
p = DynamicPipeline(
144147
name=name or func.__name__,
148+
depends_on=depends_on,
145149
entrypoint=func,
146150
enable_cache=enable_cache,
147151
enable_artifact_metadata=enable_artifact_metadata,

src/zenml/steps/base_step.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def __init__(
201201
)
202202

203203
self._configuration = PartialStepConfiguration(name=name)
204+
self._dynamic_configuration: Optional["StepConfigurationUpdate"] = None
204205
self.configure(
205206
enable_cache=enable_cache,
206207
enable_artifact_metadata=enable_artifact_metadata,
@@ -897,15 +898,34 @@ def _apply_configuration(
897898
or not. See the `BaseStep.configure(...)` method for a detailed
898899
explanation.
899900
"""
901+
from zenml.pipelines.dynamic.context import DynamicPipelineRunContext
902+
900903
self._validate_configuration(config, runtime_parameters)
901904

905+
if DynamicPipelineRunContext.get():
906+
if self._dynamic_configuration is None:
907+
self._dynamic_configuration = config
908+
else:
909+
self._dynamic_configuration = pydantic_utils.update_model(
910+
self._dynamic_configuration, update=config, recursive=merge
911+
)
912+
return
913+
902914
self._configuration = pydantic_utils.update_model(
903915
self._configuration, update=config, recursive=merge
904916
)
905917

906918
logger.debug("Updated step configuration:")
907919
logger.debug(self._configuration)
908920

921+
def _apply_dynamic_configuration(self) -> None:
922+
if self._dynamic_configuration:
923+
self._configuration = pydantic_utils.update_model(
924+
self._configuration,
925+
update=self._dynamic_configuration,
926+
recursive=True,
927+
)
928+
909929
def _validate_configuration(
910930
self,
911931
config: "StepConfigurationUpdate",

0 commit comments

Comments
 (0)