Skip to content

Commit 0a27ff6

Browse files
committed
More fixes
1 parent f87f325 commit 0a27ff6

File tree

6 files changed

+167
-18
lines changed

6 files changed

+167
-18
lines changed

src/zenml/orchestrators/step_launcher.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -459,21 +459,26 @@ def _run_step(
459459
step_operator_name=step_operator_name,
460460
step_run_info=step_run_info,
461461
)
462+
elif not self._snapshot.is_dynamic:
463+
self._run_step_in_current_thread(
464+
pipeline_run=pipeline_run,
465+
step_run=step_run,
466+
step_run_info=step_run_info,
467+
input_artifacts=step_run.regular_inputs,
468+
output_artifact_uris=output_artifact_uris,
469+
)
462470
else:
463471
from zenml.pipelines.dynamic.runner import (
464472
should_run_in_process,
465473
)
466474

467-
should_run_out_of_process = (
468-
self._snapshot.is_dynamic
469-
and self._step.config.in_process is False
470-
)
471-
472475
if should_run_in_process(
473476
self._step,
474477
self._snapshot.pipeline_configuration.docker_settings,
475478
):
476-
if should_run_out_of_process:
479+
if self._step.config.in_process is False:
480+
# The step was configured to run out of process, but
481+
# the orchestrator doesn't support it.
477482
logger.warning(
478483
"The %s does not support running dynamic out of "
479484
"process steps. Running step `%s` in current "

src/zenml/pipelines/dynamic/context.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import contextvars
2-
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Self
2+
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Self, cast
33

44
if TYPE_CHECKING:
55
from zenml.models import PipelineRunResponse, PipelineSnapshotResponse
@@ -8,20 +8,38 @@
88

99

1010
class BaseContext:
11+
"""Base context class."""
12+
1113
__context_var__: ClassVar[contextvars.ContextVar[Self]]
1214

13-
def __init__(self, *args: Any, **kwargs: Any) -> None:
15+
def __init__(self) -> None:
16+
"""Initialize the context."""
1417
self._token: Optional[contextvars.Token[Any]] = None
1518

1619
@classmethod
1720
def get(cls: type[Self]) -> Optional[Self]:
18-
return cls.__context_var__.get(None)
21+
"""Get the active context for the current thread.
22+
23+
Returns:
24+
The active context for the current thread.
25+
"""
26+
return cast(Optional[Self], cls.__context_var__.get(None))
1927

2028
def __enter__(self) -> Self:
29+
"""Enter the context.
30+
31+
Returns:
32+
The context object.
33+
"""
2134
self._token = self.__context_var__.set(self)
2235
return self
2336

2437
def __exit__(self, *_: Any) -> None:
38+
"""Exit the context.
39+
40+
Raises:
41+
RuntimeError: If the context has not been entered.
42+
"""
2543
if not self._token:
2644
raise RuntimeError(
2745
f"Can't exit {self.__class__.__name__} because it has not been "
@@ -31,6 +49,8 @@ def __exit__(self, *_: Any) -> None:
3149

3250

3351
class DynamicPipelineRunContext(BaseContext):
52+
"""Dynamic pipeline run context."""
53+
3454
__context_var__ = contextvars.ContextVar("dynamic_pipeline_run_context")
3555

3656
def __init__(
@@ -40,6 +60,14 @@ def __init__(
4060
run: "PipelineRunResponse",
4161
runner: "DynamicPipelineRunner",
4262
) -> None:
63+
"""Initialize the dynamic pipeline run context.
64+
65+
Args:
66+
pipeline: The dynamic pipeline that is being executed.
67+
snapshot: The snapshot of the pipeline.
68+
run: The pipeline run.
69+
runner: The dynamic pipeline runner.
70+
"""
4371
super().__init__()
4472
self._pipeline = pipeline
4573
self._snapshot = snapshot
@@ -63,6 +91,15 @@ def runner(self) -> "DynamicPipelineRunner":
6391
return self._runner
6492

6593
def __enter__(self) -> Self:
94+
"""Enter the dynamic pipeline run context.
95+
96+
Raises:
97+
RuntimeError: If the dynamic pipeline run context has already been
98+
entered.
99+
100+
Returns:
101+
The dynamic pipeline run context object.
102+
"""
66103
if self._token is not None:
67104
raise RuntimeError(
68105
"Calling a pipeline within a dynamic pipeline is not allowed."

src/zenml/pipelines/dynamic/pipeline_definition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def __call__(
236236
"""
237237
if should_prevent_pipeline_execution():
238238
logger.info("Preventing execution of pipeline '%s'.", self.name)
239-
return
239+
return None
240240

241241
stack = Client().active_stack
242242
if not stack.orchestrator.supports_dynamic_pipelines:
@@ -286,7 +286,7 @@ def _compute_output_schema(self) -> Optional[Dict[str, Any]]:
286286
"""
287287
try:
288288
outputs = parse_return_type_annotations(self.entrypoint)
289-
model_fields = {
289+
model_fields: Dict[str, Any] = {
290290
name: (output.resolved_annotation, ...)
291291
for name, output in outputs.items()
292292
}

src/zenml/pipelines/dynamic/runner.py

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from zenml.pipelines.run_utils import create_placeholder_run
4040
from zenml.stack import Stack
4141
from zenml.steps.entrypoint_function_utils import StepArtifact
42+
from zenml.steps.utils import OutputSignature
4243
from zenml.utils import source_utils
4344

4445
if TYPE_CHECKING:
@@ -51,6 +52,8 @@
5152

5253

5354
class DynamicStepRunOutput(ArtifactVersionResponse):
55+
"""Dynamic step run output artifact."""
56+
5457
output_name: str
5558
step_name: str
5659

@@ -61,17 +64,41 @@ class DynamicStepRunOutput(ArtifactVersionResponse):
6164

6265

6366
class StepRunOutputsFuture:
64-
def __init__(self, wrapped: Future[StepRunOutputs], invocation_id: str):
67+
"""Future for a step run output."""
68+
69+
def __init__(
70+
self, wrapped: Future[StepRunOutputs], invocation_id: str
71+
) -> None:
72+
"""Initialize the future.
73+
74+
Args:
75+
wrapped: The wrapped future object.
76+
invocation_id: The invocation ID of the step run.
77+
"""
6578
self._wrapped = wrapped
6679
self._invocation_id = invocation_id
6780

6881
def wait(self) -> None:
82+
"""Wait for the future to complete."""
6983
self._wrapped.result()
7084

7185
def result(self) -> StepRunOutputs:
86+
"""Get the step run output artifacts.
87+
88+
Returns:
89+
The step run output artifacts.
90+
"""
7291
return self._wrapped.result()
7392

7493
def load(self) -> Any:
94+
"""Get the step run output artifact data.
95+
96+
Raises:
97+
ValueError: If the step run output is invalid.
98+
99+
Returns:
100+
The step run output artifact data.
101+
"""
75102
result = self.result()
76103

77104
if result is None:
@@ -153,6 +180,7 @@ def run_pipeline(self) -> None:
153180
):
154181
self._orchestrator.run_init_hook(snapshot=self._snapshot)
155182
try:
183+
# TODO: step logging isn't threadsafe
156184
self.pipeline._call_entrypoint(**pipeline_parameters)
157185
except:
158186
publish_failed_pipeline_run(run.id)
@@ -240,6 +268,20 @@ def _prepare_step_run(
240268
"StepRunOutputsFuture", Sequence["StepRunOutputsFuture"], None
241269
] = None,
242270
) -> Tuple[Dict[str, Any], Set[str]]:
271+
"""Prepare a step run.
272+
273+
Args:
274+
step: The step to prepare.
275+
args: The arguments for the step function.
276+
kwargs: The keyword arguments for the step function.
277+
after: The step run output futures to wait for.
278+
279+
Returns:
280+
A tuple containing the inputs and the upstream steps.
281+
282+
Raises:
283+
ValueError: If an invalid step function input was passed.
284+
"""
243285
upstream_steps = set()
244286

245287
if isinstance(after, StepRunOutputsFuture):
@@ -250,7 +292,7 @@ def _prepare_step_run(
250292
item.wait()
251293
upstream_steps.add(item._invocation_id)
252294

253-
def _await_and_validate_input(input: Any):
295+
def _await_and_validate_input(input: Any) -> Any:
254296
if isinstance(input, StepRunOutputsFuture):
255297
input = input.result()
256298

@@ -260,15 +302,16 @@ def _await_and_validate_input(input: Any):
260302
and isinstance(input[0], DynamicStepRunOutput)
261303
):
262304
raise ValueError(
263-
"Passing multiple step run outputs to another step is not allowed."
305+
"Passing multiple step run outputs to another step is not "
306+
"allowed."
264307
)
265308

266309
if isinstance(input, DynamicStepRunOutput):
267310
upstream_steps.add(input.step_name)
268311

269312
return input
270313

271-
args = [_await_and_validate_input(arg) for arg in args]
314+
args = tuple(_await_and_validate_input(arg) for arg in args)
272315
kwargs = {
273316
key: _await_and_validate_input(value) for key, value in kwargs.items()
274317
}
@@ -295,7 +338,7 @@ def _compile_step(
295338
input_artifacts[name] = StepArtifact(
296339
invocation_id=value.step_name,
297340
output_name=value.output_name,
298-
annotation=Any,
341+
annotation=OutputSignature(resolved_annotation=Any),
299342
pipeline=pipeline,
300343
)
301344
elif isinstance(value, (ArtifactVersionResponse, ExternalArtifact)):
@@ -335,6 +378,18 @@ def _run_step_sync(
335378
orchestrator_run_id: str,
336379
retry: bool = False,
337380
) -> StepRunResponse:
381+
"""Run a step in the active thread.
382+
383+
Args:
384+
snapshot: The snapshot.
385+
step: The step to run.
386+
orchestrator_run_id: The orchestrator run ID.
387+
retry: Whether to retry the step if it fails.
388+
389+
Returns:
390+
The step run response.
391+
"""
392+
338393
def _launch_step() -> StepRunResponse:
339394
launcher = StepLauncher(
340395
snapshot=snapshot,
@@ -383,6 +438,14 @@ def _launch_step() -> StepRunResponse:
383438

384439

385440
def _load_step_outputs(step_run_id: UUID) -> StepRunOutputs:
441+
"""Load the outputs of a step run.
442+
443+
Args:
444+
step_run_id: The ID of the step run.
445+
446+
Returns:
447+
The outputs of the step run.
448+
"""
386449
step_run = Client().zen_store.get_run_step(step_run_id)
387450

388451
def _convert_output_artifact(
@@ -410,6 +473,15 @@ def _convert_output_artifact(
410473
def _should_retry_locally(
411474
step: "Step", pipeline_docker_settings: "DockerSettings"
412475
) -> bool:
476+
"""Determine if a step should be retried locally.
477+
478+
Args:
479+
step: The step.
480+
pipeline_docker_settings: The Docker settings of the parent pipeline.
481+
482+
Returns:
483+
Whether the step should be retried locally.
484+
"""
413485
if step.config.step_operator:
414486
return True
415487

@@ -425,6 +497,15 @@ def _should_retry_locally(
425497
def should_run_in_process(
426498
step: "Step", pipeline_docker_settings: "DockerSettings"
427499
) -> bool:
500+
"""Determine if a step should be run in process.
501+
502+
Args:
503+
step: The step.
504+
pipeline_docker_settings: The Docker settings of the parent pipeline.
505+
506+
Returns:
507+
Whether the step should be run in process.
508+
"""
428509
if step.config.step_operator:
429510
return False
430511

@@ -448,6 +529,16 @@ def get_config_template(
448529
step: "BaseStep",
449530
pipeline: "DynamicPipeline",
450531
) -> Optional["Step"]:
532+
"""Get the config template for a step executed in a dynamic pipeline.
533+
534+
Args:
535+
snapshot: The snapshot of the pipeline.
536+
step: The step to get the config template for.
537+
pipeline: The dynamic pipeline that the step is being executed in.
538+
539+
Returns:
540+
The config template for the step.
541+
"""
451542
for index, step_ in enumerate(pipeline.depends_on):
452543
if step_._static_id == step._static_id:
453544
break

src/zenml/pipelines/run_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,16 @@
22

33
import time
44
from contextlib import contextmanager
5-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
Dict,
9+
Generator,
10+
List,
11+
Optional,
12+
Set,
13+
Union,
14+
)
615
from uuid import UUID
716

817
from pydantic import BaseModel
@@ -60,7 +69,7 @@ def should_prevent_pipeline_execution() -> bool:
6069

6170

6271
@contextmanager
63-
def prevent_pipeline_execution():
72+
def prevent_pipeline_execution() -> Generator[None, None, None]:
6473
"""Context manager to prevent pipeline execution."""
6574
with env_utils.temporary_environment(
6675
{ENV_ZENML_PREVENT_PIPELINE_EXECUTION: "True"}

src/zenml/steps/base_step.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
Type,
3131
TypeVar,
3232
Union,
33+
cast,
3334
)
3435
from uuid import UUID
3536

@@ -489,6 +490,12 @@ def __call__(
489490
from zenml.pipelines.pipeline_definition import Pipeline
490491

491492
if context := DynamicPipelineRunContext.get():
493+
after = cast(
494+
Union[
495+
"StepRunOutputsFuture", Sequence["StepRunOutputsFuture"], None
496+
],
497+
after,
498+
)
492499
return context.runner.run_step_sync(self, id, args, kwargs, after)
493500

494501
if not Pipeline.ACTIVE_PIPELINE:

0 commit comments

Comments
 (0)