3939from zenml .pipelines .run_utils import create_placeholder_run
4040from zenml .stack import Stack
4141from zenml .steps .entrypoint_function_utils import StepArtifact
42+ from zenml .steps .utils import OutputSignature
4243from zenml .utils import source_utils
4344
4445if TYPE_CHECKING :
5152
5253
5354class DynamicStepRunOutput (ArtifactVersionResponse ):
55+ """Dynamic step run output artifact."""
56+
5457 output_name : str
5558 step_name : str
5659
@@ -61,17 +64,41 @@ class DynamicStepRunOutput(ArtifactVersionResponse):
6164
6265
6366class StepRunOutputsFuture :
64- def __init__ (self , wrapped : Future [StepRunOutputs ], invocation_id : str ):
67+ """Future for a step run output."""
68+
69+ def __init__ (
70+ self , wrapped : Future [StepRunOutputs ], invocation_id : str
71+ ) -> None :
72+ """Initialize the future.
73+
74+ Args:
75+ wrapped: The wrapped future object.
76+ invocation_id: The invocation ID of the step run.
77+ """
6578 self ._wrapped = wrapped
6679 self ._invocation_id = invocation_id
6780
6881 def wait (self ) -> None :
82+ """Wait for the future to complete."""
6983 self ._wrapped .result ()
7084
7185 def result (self ) -> StepRunOutputs :
86+ """Get the step run output artifacts.
87+
88+ Returns:
89+ The step run output artifacts.
90+ """
7291 return self ._wrapped .result ()
7392
7493 def load (self ) -> Any :
94+ """Get the step run output artifact data.
95+
96+ Raises:
97+ ValueError: If the step run output is invalid.
98+
99+ Returns:
100+ The step run output artifact data.
101+ """
75102 result = self .result ()
76103
77104 if result is None :
@@ -153,6 +180,7 @@ def run_pipeline(self) -> None:
153180 ):
154181 self ._orchestrator .run_init_hook (snapshot = self ._snapshot )
155182 try :
183+ # TODO: step logging isn't threadsafe
156184 self .pipeline ._call_entrypoint (** pipeline_parameters )
157185 except :
158186 publish_failed_pipeline_run (run .id )
@@ -240,6 +268,20 @@ def _prepare_step_run(
240268 "StepRunOutputsFuture" , Sequence ["StepRunOutputsFuture" ], None
241269 ] = None ,
242270) -> Tuple [Dict [str , Any ], Set [str ]]:
271+ """Prepare a step run.
272+
273+ Args:
274+ step: The step to prepare.
275+ args: The arguments for the step function.
276+ kwargs: The keyword arguments for the step function.
277+ after: The step run output futures to wait for.
278+
279+ Returns:
280+ A tuple containing the inputs and the upstream steps.
281+
282+ Raises:
283+ ValueError: If an invalid step function input was passed.
284+ """
243285 upstream_steps = set ()
244286
245287 if isinstance (after , StepRunOutputsFuture ):
@@ -250,7 +292,7 @@ def _prepare_step_run(
250292 item .wait ()
251293 upstream_steps .add (item ._invocation_id )
252294
253- def _await_and_validate_input (input : Any ):
295+ def _await_and_validate_input (input : Any ) -> Any :
254296 if isinstance (input , StepRunOutputsFuture ):
255297 input = input .result ()
256298
@@ -260,15 +302,16 @@ def _await_and_validate_input(input: Any):
260302 and isinstance (input [0 ], DynamicStepRunOutput )
261303 ):
262304 raise ValueError (
263- "Passing multiple step run outputs to another step is not allowed."
305+ "Passing multiple step run outputs to another step is not "
306+ "allowed."
264307 )
265308
266309 if isinstance (input , DynamicStepRunOutput ):
267310 upstream_steps .add (input .step_name )
268311
269312 return input
270313
271- args = [ _await_and_validate_input (arg ) for arg in args ]
314+ args = tuple ( _await_and_validate_input (arg ) for arg in args )
272315 kwargs = {
273316 key : _await_and_validate_input (value ) for key , value in kwargs .items ()
274317 }
@@ -295,7 +338,7 @@ def _compile_step(
295338 input_artifacts [name ] = StepArtifact (
296339 invocation_id = value .step_name ,
297340 output_name = value .output_name ,
298- annotation = Any ,
341+ annotation = OutputSignature ( resolved_annotation = Any ) ,
299342 pipeline = pipeline ,
300343 )
301344 elif isinstance (value , (ArtifactVersionResponse , ExternalArtifact )):
@@ -335,6 +378,18 @@ def _run_step_sync(
335378 orchestrator_run_id : str ,
336379 retry : bool = False ,
337380) -> StepRunResponse :
381+ """Run a step in the active thread.
382+
383+ Args:
384+ snapshot: The snapshot.
385+ step: The step to run.
386+ orchestrator_run_id: The orchestrator run ID.
387+ retry: Whether to retry the step if it fails.
388+
389+ Returns:
390+ The step run response.
391+ """
392+
338393 def _launch_step () -> StepRunResponse :
339394 launcher = StepLauncher (
340395 snapshot = snapshot ,
@@ -383,6 +438,14 @@ def _launch_step() -> StepRunResponse:
383438
384439
385440def _load_step_outputs (step_run_id : UUID ) -> StepRunOutputs :
441+ """Load the outputs of a step run.
442+
443+ Args:
444+ step_run_id: The ID of the step run.
445+
446+ Returns:
447+ The outputs of the step run.
448+ """
386449 step_run = Client ().zen_store .get_run_step (step_run_id )
387450
388451 def _convert_output_artifact (
@@ -410,6 +473,15 @@ def _convert_output_artifact(
410473def _should_retry_locally (
411474 step : "Step" , pipeline_docker_settings : "DockerSettings"
412475) -> bool :
476+ """Determine if a step should be retried locally.
477+
478+ Args:
479+ step: The step.
480+ pipeline_docker_settings: The Docker settings of the parent pipeline.
481+
482+ Returns:
483+ Whether the step should be retried locally.
484+ """
413485 if step .config .step_operator :
414486 return True
415487
@@ -425,6 +497,15 @@ def _should_retry_locally(
425497def should_run_in_process (
426498 step : "Step" , pipeline_docker_settings : "DockerSettings"
427499) -> bool :
500+ """Determine if a step should be run in process.
501+
502+ Args:
503+ step: The step.
504+ pipeline_docker_settings: The Docker settings of the parent pipeline.
505+
506+ Returns:
507+ Whether the step should be run in process.
508+ """
428509 if step .config .step_operator :
429510 return False
430511
@@ -448,6 +529,16 @@ def get_config_template(
448529 step : "BaseStep" ,
449530 pipeline : "DynamicPipeline" ,
450531) -> Optional ["Step" ]:
532+ """Get the config template for a step executed in a dynamic pipeline.
533+
534+ Args:
535+ snapshot: The snapshot of the pipeline.
536+ step: The step to get the config template for.
537+ pipeline: The dynamic pipeline that the step is being executed in.
538+
539+ Returns:
540+ The config template for the step.
541+ """
451542 for index , step_ in enumerate (pipeline .depends_on ):
452543 if step_ ._static_id == step ._static_id :
453544 break
0 commit comments