Skip to content

Commit 5162911

Browse files
committed
Split futures
1 parent f63e3e4 commit 5162911

File tree

4 files changed

+125
-52
lines changed

4 files changed

+125
-52
lines changed

src/zenml/execution/pipeline/dynamic/outputs.py

Lines changed: 90 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,13 @@
1414
"""Dynamic pipeline execution outputs."""
1515

1616
from concurrent.futures import Future
17-
from typing import (
18-
TYPE_CHECKING,
19-
Any,
20-
Tuple,
21-
Union,
22-
)
17+
from typing import Any, List, Tuple, Union
2318

2419
from zenml.logger import get_logger
2520
from zenml.models import (
2621
ArtifactVersionResponse,
2722
)
2823

29-
3024
logger = get_logger(__name__)
3125

3226

@@ -37,18 +31,61 @@ class OutputArtifact(ArtifactVersionResponse):
3731
step_name: str
3832

3933

40-
StepRunOutputs = Union[
41-
None, OutputArtifact, Tuple[OutputArtifact, ...]
42-
]
34+
StepRunOutputs = Union[None, OutputArtifact, Tuple[OutputArtifact, ...]]
35+
36+
37+
class ArtifactFuture:
    """Future for a single output artifact of a step run.

    Wraps the future of the whole step run and picks one artifact out of the
    step run outputs by position.
    """

    def __init__(
        self, wrapped: Future[StepRunOutputs], invocation_id: str, index: int
    ) -> None:
        """Initialize the future.

        Args:
            wrapped: The wrapped future object.
            invocation_id: The invocation ID of the step run.
            index: The position of the artifact within the step run outputs.
        """
        self._wrapped = wrapped
        self._invocation_id = invocation_id
        self._index = index

    def _wait(self) -> None:
        """Wait for the wrapped future to complete, discarding its result."""
        self._wrapped.result()

    def result(self) -> OutputArtifact:
        """Get the step run output artifact.

        Returns:
            The step run output artifact.

        Raises:
            IndexError: If the step run returned a tuple of artifacts and the
                index of this future is out of range for it.
            RuntimeError: If the step run returned neither a single artifact
                nor a tuple of artifacts.
        """
        result = self._wrapped.result()

        if isinstance(result, OutputArtifact):
            # A single artifact is returned as-is, regardless of the index.
            return result

        if isinstance(result, tuple):
            try:
                return result[self._index]
            except IndexError:
                # Same exception type as the plain tuple lookup, but with
                # enough context to identify the offending step.
                raise IndexError(
                    f"Step {self._invocation_id} returned {len(result)} "
                    f"output(s), index {self._index} is out of range."
                ) from None

        raise RuntimeError(
            f"Step {self._invocation_id} returned an invalid output: {result}"
        )

    def load(self) -> Any:
        """Load the step run output artifact data.

        Returns:
            The step run output artifact data.
        """
        return self.result().load()
4379

4480

45-
# TODO: maybe one future per artifact? But for a step that doesn't return anything, the user wouldn't have a future to wait for.
46-
# Or that step returns a future that returns None? Would be similar to a python function.
4781
class StepRunOutputsFuture:
4882
"""Future for a step run output."""
4983

5084
def __init__(
51-
self, wrapped: Future[StepRunOutputs], invocation_id: str
85+
self,
86+
wrapped: Future[StepRunOutputs],
87+
invocation_id: str,
88+
output_keys: List[str],
5289
) -> None:
5390
"""Initialize the future.
5491
@@ -58,12 +95,12 @@ def __init__(
5895
"""
5996
self._wrapped = wrapped
6097
self._invocation_id = invocation_id
98+
self._output_keys = output_keys
6199

62-
def wait(self) -> None:
63-
"""Wait for the future to complete."""
100+
def _wait(self) -> None:
64101
self._wrapped.result()
65102

66-
def result(self) -> StepRunOutputs:
103+
def artifacts(self) -> StepRunOutputs:
67104
"""Get the step run output artifacts.
68105
69106
Returns:
@@ -80,7 +117,7 @@ def load(self) -> Any:
80117
Returns:
81118
The step run output artifact data.
82119
"""
83-
result = self.result()
120+
result = self.artifacts()
84121

85122
if result is None:
86123
return None
@@ -90,3 +127,39 @@ def load(self) -> Any:
90127
return tuple(item.load() for item in result)
91128
else:
92129
raise ValueError(f"Invalid step run output: {result}")
130+
131+
def __getitem__(self, key: Union[str, int]) -> "ArtifactFuture":
    """Get a future for a single output artifact of the step run.

    Args:
        key: Either the name of the output (looked up in the output keys)
            or its position. Negative positions count from the end.

    Returns:
        A future for the selected output artifact.

    Raises:
        ValueError: If the key is neither a string nor an integer, or if
            it is a string that is not one of the output keys.
        IndexError: If the position is out of range for the output keys.
    """
    if isinstance(key, str):
        index = self._output_keys.index(key)
    elif isinstance(key, int):
        index = key
    else:
        raise ValueError(f"Invalid key type: {type(key)}")

    # Valid positions are -n .. n-1. A plain `index > n` comparison would
    # accept `index == n` and any negative value, which only fails later
    # (or not at all) when the artifact is resolved.
    output_count = len(self._output_keys)
    if not -output_count <= index < output_count:
        raise IndexError(f"Index out of range: {index}")

    return ArtifactFuture(
        wrapped=self._wrapped,
        invocation_id=self._invocation_id,
        index=index,
    )
147+
148+
def __iter__(self) -> Any:
    """Iterate over one future per output artifact of the step run.

    The check for missing outputs runs as soon as the iterator is
    requested, not on the first `next()` call as it would in a generator
    function.

    Returns:
        An iterator yielding one artifact future per output key, in the
        order of the output keys.

    Raises:
        ValueError: If the step has no output keys.
    """
    if not self._output_keys:
        raise ValueError(
            f"Step {self._invocation_id} does not return any outputs."
        )

    return iter(
        [
            ArtifactFuture(
                wrapped=self._wrapped,
                invocation_id=self._invocation_id,
                index=index,
            )
            for index in range(len(self._output_keys))
        ]
    )
160+
161+
def __len__(self) -> int:
    """Get the number of outputs of the step run.

    Returns:
        The number of output keys this future was created with.
    """
    output_keys = self._output_keys
    return len(output_keys)
163+
164+
165+
StepRunFuture = Union[ArtifactFuture, StepRunOutputsFuture]

src/zenml/execution/pipeline/dynamic/runner.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
from zenml.config.compiler import Compiler
3636
from zenml.config.step_configurations import Step
3737
from zenml.execution.pipeline.dynamic.outputs import (
38+
ArtifactFuture,
3839
OutputArtifact,
40+
StepRunFuture,
3941
StepRunOutputs,
4042
StepRunOutputsFuture,
4143
)
@@ -153,6 +155,9 @@ def run_pipeline(self) -> None:
153155
self._orchestrator.run_init_hook(snapshot=self._snapshot)
154156
try:
155157
# TODO: step logging isn't threadsafe
158+
# TODO: what should be allowed as pipeline returns?
159+
# (artifacts, json serializable, anything?)
160+
# how do we show it in the UI?
156161
self.pipeline._call_entrypoint(**pipeline_parameters)
157162
except:
158163
publish_failed_pipeline_run(run.id)
@@ -171,9 +176,7 @@ def launch_step(
171176
id: Optional[str],
172177
args: Tuple[Any],
173178
kwargs: Dict[str, Any],
174-
after: Union[
175-
"StepRunOutputsFuture", Sequence["StepRunOutputsFuture"], None
176-
] = None,
179+
after: Union["StepRunFuture", Sequence["StepRunFuture"], None] = None,
177180
concurrent: Literal[False] = False,
178181
) -> StepRunOutputs: ...
179182

@@ -184,9 +187,7 @@ def launch_step(
184187
id: Optional[str],
185188
args: Tuple[Any],
186189
kwargs: Dict[str, Any],
187-
after: Union[
188-
"StepRunOutputsFuture", Sequence["StepRunOutputsFuture"], None
189-
] = None,
190+
after: Union["StepRunFuture", Sequence["StepRunFuture"], None] = None,
190191
concurrent: Literal[True] = True,
191192
) -> "StepRunOutputsFuture": ...
192193

@@ -196,9 +197,7 @@ def launch_step(
196197
id: Optional[str],
197198
args: Tuple[Any],
198199
kwargs: Dict[str, Any],
199-
after: Union[
200-
"StepRunOutputsFuture", Sequence["StepRunOutputsFuture"], None
201-
] = None,
200+
after: Union["StepRunFuture", Sequence["StepRunFuture"], None] = None,
202201
concurrent: bool = False,
203202
) -> Union[StepRunOutputs, "StepRunOutputsFuture"]:
204203
"""Launch a step.
@@ -240,8 +239,11 @@ def _launch() -> StepRunOutputs:
240239
if concurrent:
241240
ctx = contextvars.copy_context()
242241
future = self._executor.submit(ctx.run, _launch)
242+
compiled_step.config.outputs
243243
step_run_future = StepRunOutputsFuture(
244-
wrapped=future, invocation_id=compiled_step.spec.invocation_id
244+
wrapped=future,
245+
invocation_id=compiled_step.spec.invocation_id,
246+
output_keys=list(compiled_step.config.outputs),
245247
)
246248
self._futures.append(step_run_future)
247249
return step_run_future
@@ -251,7 +253,7 @@ def _launch() -> StepRunOutputs:
251253
def await_all_step_run_futures(self) -> None:
252254
"""Await all step run output futures."""
253255
for future in self._futures:
254-
future.wait()
256+
future.artifacts()
255257
self._futures = []
256258

257259

@@ -262,9 +264,7 @@ def compile_dynamic_step_invocation(
262264
id: Optional[str],
263265
args: Tuple[Any],
264266
kwargs: Dict[str, Any],
265-
after: Union[
266-
"StepRunOutputsFuture", Sequence["StepRunOutputsFuture"], None
267-
] = None,
267+
after: Union["StepRunFuture", Sequence["StepRunFuture"], None] = None,
268268
) -> "Step":
269269
"""Compile a dynamic step invocation.
270270
@@ -282,27 +282,25 @@ def compile_dynamic_step_invocation(
282282
"""
283283
upstream_steps = set()
284284

285-
if isinstance(after, StepRunOutputsFuture):
286-
after.wait()
285+
if isinstance(after, StepRunFuture):
286+
after._wait()
287287
upstream_steps.add(after._invocation_id)
288288
elif isinstance(after, Sequence):
289289
for item in after:
290-
item.wait()
290+
item._wait()
291291
upstream_steps.add(item._invocation_id)
292292

293293
def _await_and_validate_input(input: Any) -> Any:
294294
if isinstance(input, StepRunOutputsFuture):
295-
input = input.result()
295+
if len(input._output_keys) != 1:
296+
raise ValueError(
297+
"Passing multiple step run outputs to another step is not "
298+
"allowed."
299+
)
300+
input = input.artifacts()
296301

297-
if (
298-
input
299-
and isinstance(input, tuple)
300-
and isinstance(input[0], OutputArtifact)
301-
):
302-
raise ValueError(
303-
"Passing multiple step run outputs to another step is not "
304-
"allowed."
305-
)
302+
if isinstance(input, ArtifactFuture):
303+
input = input.result()
306304

307305
if isinstance(input, OutputArtifact):
308306
upstream_steps.add(input.step_name)

src/zenml/orchestrators/step_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def run(
219219
step_failed = False
220220
try:
221221
if (
222+
# TODO: do we need to disable this for dynamic pipelines?
222223
pipeline_run.snapshot
223224
and self._stack.orchestrator.run_init_cleanup_at_step_level
224225
):

src/zenml/steps/base_step.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,10 @@
9494
Mapping[str, Sequence["MaterializerClassOrSource"]],
9595
]
9696

97-
from zenml.execution.pipeline.dynamic.outputs import StepRunOutputsFuture
97+
from zenml.execution.pipeline.dynamic.outputs import (
98+
StepRunFuture,
99+
StepRunOutputsFuture,
100+
)
98101

99102

100103
logger = get_logger(__name__)
@@ -468,8 +471,8 @@ def __call__(
468471
after: Union[
469472
str,
470473
StepArtifact,
471-
"StepRunOutputsFuture",
472-
Sequence[Union[str, StepArtifact, "StepRunOutputsFuture"]],
474+
"StepRunFuture",
475+
Sequence[Union[str, StepArtifact, "StepRunFuture"]],
473476
None,
474477
] = None,
475478
**kwargs: Any,
@@ -512,8 +515,8 @@ def __call__(
512515
if run_context := DynamicPipelineRunContext.get():
513516
after = cast(
514517
Union[
515-
"StepRunOutputsFuture",
516-
Sequence["StepRunOutputsFuture"],
518+
"StepRunFuture",
519+
Sequence["StepRunFuture"],
517520
None,
518521
],
519522
after,
@@ -627,9 +630,7 @@ def submit(
627630
self,
628631
*args: Any,
629632
id: Optional[str] = None,
630-
after: Union[
631-
"StepRunOutputsFuture", Sequence["StepRunOutputsFuture"], None
632-
] = None,
633+
after: Union["StepRunFuture", Sequence["StepRunFuture"], None] = None,
633634
**kwargs: Any,
634635
) -> "StepRunOutputsFuture":
635636
from zenml.execution.pipeline.dynamic.run_context import (

0 commit comments

Comments
 (0)