Commit b13c693

Refactoring
1 parent 438b003 commit b13c693

10 files changed: +321 -337 lines changed

src/zenml/config/compiler.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -489,15 +489,15 @@ def _compile_step_invocation(
         invocation.step = copy.deepcopy(invocation.step)
 
         step = invocation.step
-        with step._skip_dynamic_configuration():
+        with step._suspend_dynamic_configuration():
             if step_config:
                 step._apply_configuration(
                     step_config, runtime_parameters=invocation.parameters
                 )
 
         # Apply the dynamic configuration (which happened while executing the
         # pipeline function) after all other step-specific configurations.
-        step._apply_dynamic_configuration()
+        step._merge_dynamic_configuration()
 
         convert_component_shortcut_settings_keys(
             step.configuration.settings, stack=stack
```
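
The renames here (`_skip_dynamic_configuration` → `_suspend_dynamic_configuration`, `_apply_dynamic_configuration` → `_merge_dynamic_configuration`) better describe what the compiler does: dynamic configuration collected while the pipeline function executes is not discarded, only paused while the static step config is applied, and then merged on top. The helpers' internals are not part of this diff; what follows is a minimal sketch of that suspend-then-merge pattern, with every name invented for illustration:

```python
# A minimal sketch (invented names, not ZenML's real internals) of the
# suspend-then-merge semantics the renamed helpers suggest: while a dynamic
# pipeline runs, configure() collects "dynamic" options on the side; the
# compiler suspends that collection to apply static config directly, then
# merges the dynamic options on top afterwards.
from contextlib import contextmanager
from typing import Any, Dict, Iterator


class ConfigurableStep:
    def __init__(self) -> None:
        self.configuration: Dict[str, Any] = {}
        self._dynamic_updates: Dict[str, Any] = {}
        self._dynamic_suspended = False

    @contextmanager
    def _suspend_dynamic_configuration(self) -> Iterator[None]:
        # Temporarily route configure() straight to the static config.
        self._dynamic_suspended = True
        try:
            yield
        finally:
            self._dynamic_suspended = False

    def configure(self, **options: Any) -> None:
        if self._dynamic_suspended:
            self.configuration.update(options)
        else:
            self._dynamic_updates.update(options)

    def _merge_dynamic_configuration(self) -> None:
        # Dynamic options win over the static configuration applied above.
        self.configuration.update(self._dynamic_updates)
        self._dynamic_updates.clear()


step = ConfigurableStep()
step.configure(retries=3)  # collected as dynamic configuration
with step._suspend_dynamic_configuration():
    step.configure(retries=1, timeout=60)  # static config, applied directly
step._merge_dynamic_configuration()
assert step.configuration == {"retries": 3, "timeout": 60}
```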

src/zenml/orchestrators/base_orchestrator.py

Lines changed: 37 additions & 6 deletions
```diff
@@ -213,7 +213,18 @@ def submit_dynamic_pipeline(
         environment: Dict[str, str],
         placeholder_run: Optional["PipelineRunResponse"] = None,
     ) -> Optional[SubmissionResult]:
-        """Submits a dynamic pipeline to the orchestrator."""
+        """Submits a dynamic pipeline to the orchestrator.
+
+        Args:
+            snapshot: The pipeline snapshot to submit.
+            stack: The stack the pipeline will run on.
+            environment: Environment variables to set in the orchestration
+                environment.
+            placeholder_run: An optional placeholder run.
+
+        Returns:
+            Optional submission result.
+        """
         return None
 
     def prepare_or_run_pipeline(
@@ -429,6 +440,8 @@ def run_step(
         """
         from zenml.pipelines.dynamic.runner import _run_step_sync
 
+        assert self._active_snapshot
+
         _run_step_sync(
             snapshot=self._active_snapshot,
             step=step,
@@ -438,23 +451,41 @@ def run_step(
 
     @property
     def supports_dynamic_pipelines(self) -> bool:
+        """Whether the orchestrator supports dynamic pipelines.
+
+        Returns:
+            Whether the orchestrator supports dynamic pipelines.
+        """
         return (
             getattr(self.submit_dynamic_pipeline, "__func__", None)
             is not BaseOrchestrator.submit_dynamic_pipeline
         )
 
     @property
-    def supports_dynamic_out_of_process_steps(self) -> bool:
+    def can_launch_dynamic_steps(self) -> bool:
+        """Whether the orchestrator can launch dynamic steps.
+
+        Returns:
+            Whether the orchestrator can launch dynamic steps.
+        """
         return (
-            getattr(self.run_dynamic_out_of_process_step, "__func__", None)
-            is not BaseOrchestrator.run_dynamic_out_of_process_step
+            getattr(self.launch_dynamic_step, "__func__", None)
+            is not BaseOrchestrator.launch_dynamic_step
         )
 
-    def run_dynamic_out_of_process_step(
+    def launch_dynamic_step(
         self, step_run_info: "StepRunInfo", environment: Dict[str, str]
     ) -> None:
+        """Launch a dynamic step.
+
+        Args:
+            step_run_info: The step run information.
+            environment: The environment variables to set in the execution
+                environment.
+        """
         raise NotImplementedError(
-            "Running dynamic out of process steps is not implemented for the orchestrator."
+            "Launching dynamic steps is not implemented for "
+            f"the {self.__class__.__name__} orchestrator."
        )
 
     @staticmethod
```
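
Both `supports_dynamic_pipelines` and the renamed `can_launch_dynamic_steps` use the same idiom: comparing a bound method's `__func__` against the function stored on `BaseOrchestrator` detects whether a subclass overrides the hook. A self-contained illustration, with made-up class names:

```python
# Self-contained illustration of the override-detection idiom used by
# `supports_dynamic_pipelines` and `can_launch_dynamic_steps`; the class
# names here are invented.
class Base:
    def submit(self) -> None:
        """Default no-op; subclasses override this to opt in."""

    @property
    def supports_submit(self) -> bool:
        # A bound method exposes its underlying function via `__func__`.
        # Comparing it to the function stored on the base class tells us
        # whether a subclass has overridden the hook.
        return getattr(self.submit, "__func__", None) is not Base.submit


class WithSubmit(Base):
    def submit(self) -> None:
        print("submitting")


assert not Base().supports_submit
assert WithSubmit().supports_submit
```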

src/zenml/orchestrators/step_launcher.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -481,8 +481,7 @@ def _run_step(
             # the orchestrator doesn't support it.
             logger.warning(
                 "The %s does not support running dynamic out of "
-                "process steps. Running step `%s` in current "
-                "thread instead.",
+                "process steps. Running step `%s` locally instead.",
                 self._stack.orchestrator.__class__.__name__,
                 self._invocation_id,
             )
@@ -574,7 +573,7 @@ def _run_step_with_dynamic_orchestrator(
                 stack=self._stack,
             )
         )
-        self._stack.orchestrator.run_dynamic_out_of_process_step(
+        self._stack.orchestrator.launch_dynamic_step(
             step_run_info=step_run_info,
             environment=environment,
         )
```
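
The renamed `launch_dynamic_step` is paired with the capability check and local fallback warned about in `_run_step` above. A hedged sketch of that dispatch, using illustrative stand-in names rather than ZenML's real launcher classes:

```python
# Hedged sketch of the dispatch implied by this change (illustrative names):
# use the orchestrator's `launch_dynamic_step` when it advertises support,
# otherwise emit the warning above and fall back to in-process execution.
import logging

logger = logging.getLogger(__name__)


class LocalOnlyOrchestrator:
    can_launch_dynamic_steps = False  # a property in the real base class

    def launch_dynamic_step(self, step_run_info, environment) -> None:
        raise NotImplementedError


def run_dynamic_step(orchestrator, step_run_info, environment, run_locally):
    if orchestrator.can_launch_dynamic_steps:
        orchestrator.launch_dynamic_step(
            step_run_info=step_run_info, environment=environment
        )
    else:
        logger.warning(
            "The %s does not support running dynamic out of "
            "process steps. Running step `%s` locally instead.",
            orchestrator.__class__.__name__,
            step_run_info,
        )
        run_locally()


run_dynamic_step(
    LocalOnlyOrchestrator(),
    step_run_info="trainer",
    environment={},
    run_locally=lambda: print("running step locally"),
)
```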

src/zenml/pipelines/dynamic/pipeline_definition.py

Lines changed: 44 additions & 176 deletions
```diff
@@ -19,54 +19,70 @@
     Dict,
     List,
     Optional,
-    Set,
     Type,
-    Union,
 )
 
-from pydantic import BaseModel, ConfigDict, ValidationError, create_model
+from pydantic import BaseModel, ConfigDict, create_model
 
-from zenml import ExternalArtifact
 from zenml.client import Client
 from zenml.logger import get_logger
-from zenml.models import ArtifactVersionResponse, PipelineRunResponse
+from zenml.models import PipelineRunResponse
 from zenml.pipelines.pipeline_definition import Pipeline
 from zenml.pipelines.run_utils import (
     should_prevent_pipeline_execution,
 )
-from zenml.steps.step_invocation import StepInvocation
 from zenml.steps.utils import (
     parse_return_type_annotations,
 )
-from zenml.utils import dict_utils, pydantic_utils
 
 if TYPE_CHECKING:
     from zenml.steps import BaseStep
-    from zenml.steps.entrypoint_function_utils import StepArtifact
 
 logger = get_logger(__name__)
 
 
 class DynamicPipeline(Pipeline):
     """Dynamic pipeline class."""
 
-    def __init__(self, *args: Any, **kwargs: Any) -> None:
+    def __init__(
+        self,
+        *args: Any,
+        depends_on: Optional[List["BaseStep"]] = None,
+        **kwargs: Any,
+    ) -> None:
         """Initialize the pipeline.
 
         Args:
             *args: Pipeline constructor arguments.
+            depends_on: The steps that the pipeline depends on.
             **kwargs: Pipeline constructor keyword arguments.
+        """
+        super().__init__(*args, **kwargs)
+        self._depends_on = depends_on or []
+        self._validate_depends_on(self._depends_on)
+
+    def _validate_depends_on(self, depends_on: List["BaseStep"]) -> None:
+        """Validates the steps that the pipeline depends on.
+
+        Args:
+            depends_on: The steps that the pipeline depends on.
 
         Raises:
-            ValueError: If some of the steps in `depends_on` are duplicated.
+            RuntimeError: If some of the steps in `depends_on` are duplicated.
         """
-        self._depends_on = kwargs.pop("depends_on", None) or []
-        if self._depends_on:
-            static_ids = [step._static_id for step in self._depends_on]
-            if len(static_ids) != len(set(static_ids)):
-                raise ValueError("Duplicate steps in depends_on.")
+        static_ids = set()
+        for step in depends_on:
+            static_id = step._static_id
+            if static_id in static_ids:
+                raise RuntimeError(
+                    f"The pipeline {self.name} depends on the same step "
+                    f"({step.name}) multiple times. To fix this, remove the "
+                    "duplicate from the `depends_on` list. You can pass the "
+                    "same step function with multiple configurations by using "
+                    "the `step.with_options(...)` method."
+                )
 
-        super().__init__(*args, **kwargs)
+            static_ids.add(static_id)
 
     @property
     def depends_on(self) -> List["BaseStep"]:
@@ -86,137 +102,19 @@ def is_dynamic(self) -> bool:
         """
         return True
 
-    @property
-    def is_prepared(self) -> bool:
-        """If the pipeline is prepared.
-
-        Prepared means that the pipeline entrypoint has been called and the
-        pipeline is fully defined.
-
-        Returns:
-            If the pipeline is prepared.
-        """
-        return False
-
-    def prepare(self, *args: Any, **kwargs: Any) -> None:
-        """Prepares the pipeline.
-
-        Args:
-            *args: Pipeline entrypoint input arguments.
-            **kwargs: Pipeline entrypoint input keyword arguments.
-
-        Raises:
-            RuntimeError: If the pipeline has parameters configured differently in
-                configuration file and code.
-        """
-        conflicting_parameters = {}
-        parameters_ = (self.configuration.parameters or {}).copy()
-        if from_file_ := self._from_config_file.get("parameters", None):
-            parameters_ = dict_utils.recursive_update(parameters_, from_file_)
-        if parameters_:
-            for k, v_runtime in kwargs.items():
-                if k in parameters_:
-                    v_config = parameters_[k]
-                    if v_config != v_runtime:
-                        conflicting_parameters[k] = (v_config, v_runtime)
-        if conflicting_parameters:
-            is_plural = "s" if len(conflicting_parameters) > 1 else ""
-            msg = f"Configured parameter{is_plural} for the pipeline `{self.name}` conflict{'' if not is_plural else 's'} with parameter{is_plural} passed in runtime:\n"
-            for key, values in conflicting_parameters.items():
-                msg += f"`{key}`: config=`{values[0]}` | runtime=`{values[1]}`\n"
-            msg += """This happens, if you define values for pipeline parameters in configuration file and pass same parameters from the code. Example:
-```
-# config.yaml
-parameters:
-    param_name: value1
-
-
-# pipeline.py
-@pipeline
-def pipeline_(param_name: str):
-    step_name()
-
-if __name__=="__main__":
-    pipeline_.with_options(config_path="config.yaml")(param_name="value2")
-```
-To avoid this consider setting pipeline parameters only in one place (config or code).
-"""
-            raise RuntimeError(msg)
-        for k, v_config in parameters_.items():
-            if k not in kwargs:
-                kwargs[k] = v_config
-
-        try:
-            validated_args = pydantic_utils.validate_function_args(
-                self.entrypoint,
-                ConfigDict(arbitrary_types_allowed=False),
-                *args,
-                **kwargs,
+    def _prepare_invocations(self, **kwargs: Any) -> None:
+        """Prepares the invocations of the pipeline."""
+        for step in self._depends_on:
+            self.add_step_invocation(
+                step,
+                input_artifacts={},
+                external_artifacts={},
+                model_artifacts_or_metadata={},
+                client_lazy_loaders={},
+                parameters={},
+                default_parameters={},
+                upstream_steps=set(),
            )
-        except ValidationError as e:
-            raise ValueError(
-                "Invalid or missing pipeline function entrypoint arguments. "
-                "Only JSON serializable inputs are allowed as pipeline inputs. "
-                "Check out the pydantic error above for more details."
-            ) from e
-
-        self._parameters = validated_args
-        self._invocations = {}
-        with self:
-            for step in self._depends_on:
-                self.add_step_invocation(
-                    step,
-                    input_artifacts={},
-                    external_artifacts={},
-                    model_artifacts_or_metadata={},
-                    client_lazy_loaders={},
-                    parameters={},
-                    default_parameters={},
-                    upstream_steps=set(),
-                )
-
-    def add_dynamic_invocation(
-        self,
-        step: "BaseStep",
-        custom_id: Optional[str] = None,
-        allow_id_suffix: bool = True,
-        upstream_steps: Optional[Set[str]] = None,
-        input_artifacts: Dict[str, "StepArtifact"] = {},
-        external_artifacts: Dict[
-            str, Union[ExternalArtifact, "ArtifactVersionResponse"]
-        ] = {},
-    ) -> str:
-        """Adds a dynamic invocation to the pipeline.
-
-        Args:
-            step: The step for which to add an invocation.
-            custom_id: Custom ID to use for the invocation.
-            allow_id_suffix: Whether a suffix can be appended to the invocation
-                ID.
-            upstream_steps: The upstream steps for the invocation.
-            input_artifacts: The input artifacts for the invocation.
-            external_artifacts: The external artifacts for the invocation.
-
-        Returns:
-            The invocation ID.
-        """
-        invocation_id = self._compute_invocation_id(
-            step=step, custom_id=custom_id, allow_suffix=allow_id_suffix
-        )
-        invocation = StepInvocation(
-            id=invocation_id,
-            step=step,
-            input_artifacts=input_artifacts,
-            external_artifacts=external_artifacts,
-            model_artifacts_or_metadata={},
-            client_lazy_loaders={},
-            parameters=step.configuration.parameters,
-            default_parameters={},
-            upstream_steps=upstream_steps or set(),
-            pipeline=self,
-        )
-        self._invocations[invocation_id] = invocation
-        return invocation_id
 
     def __call__(
         self, *args: Any, **kwargs: Any
@@ -248,36 +146,6 @@ def __call__(
         self.prepare(*args, **kwargs)
         return self._run()
 
-    def _call_entrypoint(self, *args: Any, **kwargs: Any) -> None:
-        """Calls the pipeline entrypoint function with the given arguments.
-
-        Args:
-            *args: Entrypoint function arguments.
-            **kwargs: Entrypoint function keyword arguments.
-
-        Raises:
-            ValueError: If an input argument is missing or not JSON
-                serializable.
-        """
-        try:
-            validated_args = pydantic_utils.validate_function_args(
-                self.entrypoint,
-                ConfigDict(arbitrary_types_allowed=False),
-                *args,
-                **kwargs,
-            )
-        except ValidationError as e:
-            raise ValueError(
-                "Invalid or missing pipeline function entrypoint arguments. "
-                "Only JSON serializable inputs are allowed as pipeline inputs. "
-                "Check out the pydantic error above for more details."
-            ) from e
-
-        # Clear the invocations as they might still contain invocations from
-        # the compilation phase.
-        self._invocations = {}
-        self.entrypoint(**validated_args)
-
     def _compute_output_schema(self) -> Optional[Dict[str, Any]]:
         """Computes the output schema for the pipeline.
 
```
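
The reworded duplicate-step error now names its remedy: reuse a step function under different configurations via `with_options`, which yields copies with distinct static IDs. A hypothetical usage sketch follows; the decorator keywords are assumptions, only `with_options(...)` is confirmed by the error message itself:

```python
# Hypothetical usage sketch for the stricter `depends_on` validation.
# Assumptions: the `@pipeline` decorator forwards `depends_on` (and a
# `dynamic=True` flag) to `DynamicPipeline.__init__`.
from zenml import pipeline, step


@step
def train_model(learning_rate: float = 0.01) -> None:
    ...


# depends_on=[train_model, train_model] would now raise a RuntimeError,
# since both entries share the same static ID. Distinct configurations of
# the same step function are allowed:
fast_trainer = train_model.with_options(parameters={"learning_rate": 0.1})
slow_trainer = train_model.with_options(parameters={"learning_rate": 0.001})


@pipeline(dynamic=True, depends_on=[fast_trainer, slow_trainer])  # assumed kwargs
def tuning_pipeline() -> None:
    ...
```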