Skip to content

Commit 22e7a9d

Browse files
committed
Mypy and docstrings
1 parent 050970f commit 22e7a9d

File tree

14 files changed

+109
-30
lines changed

14 files changed

+109
-30
lines changed

src/zenml/execution/pipeline/dynamic/outputs.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,28 +34,61 @@ class OutputArtifact(ArtifactVersionResponse):
3434
StepRunOutputs = Union[None, OutputArtifact, Tuple[OutputArtifact, ...]]
3535

3636

37-
class ArtifactFuture:
38-
"""Future for a step run output artifact."""
37+
class _BaseStepRunFuture:
38+
"""Base step run future."""
3939

4040
def __init__(
41-
self, wrapped: Future[StepRunOutputs], invocation_id: str, index: int
41+
self,
42+
wrapped: Future[StepRunOutputs],
43+
invocation_id: str,
44+
**kwargs: Any,
4245
) -> None:
43-
"""Initialize the future.
46+
"""Initialize the dynamic step run future.
4447
4548
Args:
4649
wrapped: The wrapped future object.
4750
invocation_id: The invocation ID of the step run.
51+
**kwargs: Additional keyword arguments.
4852
"""
4953
self._wrapped = wrapped
5054
self._invocation_id = invocation_id
51-
self._index = index
55+
56+
@property
57+
def invocation_id(self) -> str:
58+
"""The step run invocation ID.
59+
60+
Returns:
61+
The step run invocation ID.
62+
"""
63+
return self._invocation_id
5264

5365
def _wait(self) -> None:
66+
"""Wait for the step run future to complete."""
5467
self._wrapped.result()
5568

69+
70+
class ArtifactFuture(_BaseStepRunFuture):
71+
"""Future for a step run output artifact."""
72+
73+
def __init__(
74+
self, wrapped: Future[StepRunOutputs], invocation_id: str, index: int
75+
) -> None:
76+
"""Initialize the future.
77+
78+
Args:
79+
wrapped: The wrapped future object.
80+
invocation_id: The invocation ID of the step run.
81+
index: The index of the output artifact.
82+
"""
83+
super().__init__(wrapped=wrapped, invocation_id=invocation_id)
84+
self._index = index
85+
5686
def result(self) -> OutputArtifact:
5787
"""Get the step run output artifact.
5888
89+
Raises:
90+
RuntimeError: If the future returned an invalid output.
91+
5992
Returns:
6093
The step run output artifact.
6194
"""
@@ -66,7 +99,8 @@ def result(self) -> OutputArtifact:
6699
return result[self._index]
67100
else:
68101
raise RuntimeError(
69-
f"Step {self._invocation_id} returned an invalid output: {result}"
102+
f"Step {self._invocation_id} returned an invalid output: "
103+
f"{result}."
70104
)
71105

72106
def load(self) -> Any:
@@ -78,7 +112,7 @@ def load(self) -> Any:
78112
return self.result().load()
79113

80114

81-
class StepRunOutputsFuture:
115+
class StepRunOutputsFuture(_BaseStepRunFuture):
82116
"""Future for a step run output."""
83117

84118
def __init__(
@@ -92,14 +126,11 @@ def __init__(
92126
Args:
93127
wrapped: The wrapped future object.
94128
invocation_id: The invocation ID of the step run.
129+
output_keys: The output keys of the step run.
95130
"""
96-
self._wrapped = wrapped
97-
self._invocation_id = invocation_id
131+
super().__init__(wrapped=wrapped, invocation_id=invocation_id)
98132
self._output_keys = output_keys
99133

100-
def _wait(self) -> None:
101-
self._wrapped.result()
102-
103134
def artifacts(self) -> StepRunOutputs:
104135
"""Get the step run output artifacts.
105136

src/zenml/execution/pipeline/dynamic/runner.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
StepRunFuture,
4242
StepRunOutputs,
4343
StepRunOutputsFuture,
44+
_BaseStepRunFuture,
4445
)
4546
from zenml.execution.pipeline.dynamic.run_context import (
4647
DynamicPipelineRunContext,
@@ -292,22 +293,22 @@ def compile_dynamic_step_invocation(
292293
pipeline: The dynamic pipeline.
293294
step: The step to compile.
294295
id: Custom invocation ID.
295-
upstream_steps: The upstream steps.
296-
inputs: Inputs to the step function.
297-
default_parameters: Default parameters of the step function.
296+
args: The arguments for the step function.
297+
kwargs: The keyword arguments for the step function.
298+
after: The step run output futures to wait for.
298299
299300
Returns:
300301
The compiled step.
301302
"""
302303
upstream_steps = set()
303304

304-
if isinstance(after, StepRunFuture):
305+
if isinstance(after, _BaseStepRunFuture):
305306
after._wait()
306-
upstream_steps.add(after._invocation_id)
307+
upstream_steps.add(after.invocation_id)
307308
elif isinstance(after, Sequence):
308309
for item in after:
309310
item._wait()
310-
upstream_steps.add(item._invocation_id)
311+
upstream_steps.add(item.invocation_id)
311312

312313
def _await_and_validate_input(input: Any) -> Any:
313314
if isinstance(input, StepRunOutputsFuture):

src/zenml/execution/pipeline/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,11 @@ def should_prevent_pipeline_execution() -> bool:
6262

6363
@contextmanager
6464
def prevent_pipeline_execution() -> Generator[None, None, None]:
65-
"""Context manager to prevent pipeline execution."""
65+
"""Context manager to prevent pipeline execution.
66+
67+
Yields:
68+
None.
69+
"""
6670
with env_utils.temporary_environment(
6771
{ENV_ZENML_PREVENT_PIPELINE_EXECUTION: "True"}
6872
):

src/zenml/execution/step/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ def launch_step(
4848
orchestrator_run_id: The orchestrator run ID.
4949
retry: Whether to retry the step if it fails.
5050
51+
Raises:
52+
RunStoppedException: If the run was stopped.
53+
BaseException: If the step failed all retries.
54+
5155
Returns:
5256
The step run response.
5357
"""

src/zenml/orchestrators/base_orchestrator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -433,10 +433,6 @@ def run_step(
433433
434434
Args:
435435
step: The step to run.
436-
437-
Raises:
438-
RunStoppedException: If the run was stopped.
439-
BaseException: If the step failed all retries.
440436
"""
441437
from zenml.execution.step.utils import launch_step
442438

@@ -482,6 +478,10 @@ def launch_dynamic_step(
482478
step_run_info: The step run information.
483479
environment: The environment variables to set in the execution
484480
environment.
481+
482+
Raises:
483+
NotImplementedError: If the orchestrator does not implement this
484+
method.
485485
"""
486486
raise NotImplementedError(
487487
"Launching dynamic steps is not implemented for "

src/zenml/orchestrators/local/local_orchestrator.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,18 @@ def submit_dynamic_pipeline(
193193
environment: Dict[str, str],
194194
placeholder_run: Optional["PipelineRunResponse"] = None,
195195
) -> Optional[SubmissionResult]:
196-
"""Submits a dynamic pipeline to the orchestrator."""
196+
"""Submits a dynamic pipeline to the orchestrator.
197+
198+
Args:
199+
snapshot: The pipeline snapshot to submit.
200+
stack: The stack the pipeline will run on.
201+
environment: Environment variables to set in the orchestration
202+
environment.
203+
placeholder_run: An optional placeholder run.
204+
205+
Returns:
206+
Optional submission result.
207+
"""
197208
from zenml.execution.pipeline.dynamic.runner import (
198209
DynamicPipelineRunner,
199210
)

src/zenml/orchestrators/publish_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,9 @@ def publish_successful_pipeline_run(
125125
126126
Args:
127127
pipeline_run_id: The ID of the pipeline run to update.
128+
129+
Returns:
130+
The updated pipeline run.
128131
"""
129132
return Client().zen_store.update_run(
130133
run_id=pipeline_run_id,

src/zenml/orchestrators/step_launcher.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,9 @@ def launch(self) -> StepRunResponse:
252252
Raises:
253253
RunStoppedException: If the pipeline run is stopped by the user.
254254
BaseException: If the step preparation or execution fails.
255+
256+
Returns:
257+
The step run response.
255258
"""
256259
publish_utils.step_exception_info.set(None)
257260
pipeline_run, run_was_created = self._create_or_reuse_run()

src/zenml/pipelines/dynamic/pipeline_definition.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,11 @@ def is_dynamic(self) -> bool:
103103
return True
104104

105105
def _prepare_invocations(self, **kwargs: Any) -> None:
106-
"""Prepares the invocations of the pipeline."""
106+
"""Prepares the invocations of the pipeline.
107+
108+
Args:
109+
**kwargs: Keyword arguments.
110+
"""
107111
for step in self._depends_on:
108112
self.add_step_invocation(
109113
step,

src/zenml/pipelines/pipeline_decorator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def pipeline(
106106
_func: The decorated function.
107107
name: The name of the pipeline. If left empty, the name of the
108108
decorated function will be used as a fallback.
109+
depends_on: The steps that this pipeline depends on.
109110
enable_cache: Whether to use caching or not.
110111
enable_artifact_metadata: Whether to enable artifact metadata or not.
111112
enable_step_logs: If step logs should be enabled for this pipeline.

0 commit comments

Comments
 (0)