3737)
3838from zenml .enums import ExecutionMode , ExecutionStatus , StackComponentType
3939from zenml .exceptions import (
40+ HookExecutionException ,
4041 IllegalOperationError ,
4142 RunMonitoringError ,
4243 RunStoppedException ,
4344)
45+ from zenml .hooks .hook_validators import load_and_run_hook
4446from zenml .logger import get_logger
4547from zenml .metadata .metadata_types import MetadataType
4648from zenml .orchestrators .publish_utils import (
5153from zenml .orchestrators .step_launcher import StepLauncher
5254from zenml .orchestrators .utils import get_config_environment_vars
5355from zenml .stack import Flavor , Stack , StackComponent , StackComponentConfig
54- from zenml .steps .step_context import RunContext
56+ from zenml .steps .step_context import RunContext , get_or_create_run_context
57+ from zenml .utils .env_utils import temporary_environment
5558from zenml .utils .pydantic_utils import before_validator_handler
5659
5760if TYPE_CHECKING :
@@ -393,13 +396,11 @@ def run(
393396 def run_step (
394397 self ,
395398 step : "Step" ,
396- run_context : Optional [RunContext ] = None ,
397399 ) -> None :
398400 """Runs the given step.
399401
400402 Args:
401403 step: The step to run.
402- run_context: A shared run context.
403404
404405 Raises:
405406 RunStoppedException: If the run was stopped.
@@ -413,7 +414,6 @@ def _launch_step() -> None:
413414 snapshot = self ._active_snapshot ,
414415 step = step ,
415416 orchestrator_run_id = self .get_orchestrator_run_id (),
416- run_context = run_context ,
417417 )
418418 launcher .launch ()
419419
@@ -498,6 +498,96 @@ def supported_execution_modes(self) -> List[ExecutionMode]:
498498 """
499499 return [ExecutionMode .CONTINUE_ON_FAILURE ]
500500
501+ @property
502+ def run_init_cleanup_at_step_level (self ) -> bool :
503+ """Whether the orchestrator runs the init and cleanup hooks at step level.
504+
505+ For orchestrators that run their steps in isolated step environments,
506+ the run context cannot be shared between steps. In this case, the init
507+ and cleanup hooks need to be run at step level for each individual step.
508+
509+ For orchestrators that run their steps in a shared environment with a
510+ shared memory (e.g. the local orchestrator), the init and cleanup hooks
511+ can be run at run level and this property should be overridden to return
512+ True.
513+
514+ Returns:
515+ Whether the orchestrator runs the init and cleanup hooks at step
516+ level.
517+ """
518+ return True
519+
520+ @classmethod
521+ def run_init_hook (cls , snapshot : "PipelineSnapshotResponse" ) -> None :
522+ """Runs the init hook.
523+
524+ Args:
525+ snapshot: The snapshot to run the init hook for.
526+
527+ Raises:
528+ HookExecutionException: If the init hook fails.
529+ """
530+ # The lifetime of the run context starts when the init hook is executed
531+ # and ends when the cleanup hook is executed
532+ run_context = get_or_create_run_context ()
533+ init_hook_source = snapshot .pipeline_configuration .init_hook_source
534+ init_hook_kwargs = snapshot .pipeline_configuration .init_hook_kwargs
535+
536+ # We only run the init hook once, if the (thread-local) run context
537+ # associated with the current run has not been initialized yet. This
538+ # allows us to run the init hook only once per run per execution
539+ # environment (process, container, etc.).
540+ if not run_context .initialized :
541+ if not init_hook_source :
542+ run_context .initialize (None )
543+ return
544+
545+ logger .info ("Executing the pipeline's init hook..." )
546+ try :
547+ with temporary_environment (
548+ snapshot .pipeline_configuration .environment
549+ ):
550+ run_state = load_and_run_hook (
551+ init_hook_source ,
552+ hook_parameters = init_hook_kwargs ,
553+ raise_on_error = True ,
554+ )
555+ except Exception as e :
556+ raise HookExecutionException (
557+ f"Failed to execute init hook for pipeline "
558+ f"{ snapshot .pipeline_configuration .name } "
559+ ) from e
560+
561+ run_context .initialize (run_state )
562+
563+ @classmethod
564+ def run_cleanup_hook (cls , snapshot : "PipelineSnapshotResponse" ) -> None :
565+ """Runs the cleanup hook.
566+
567+ Args:
568+ snapshot: The snapshot to run the cleanup hook for.
569+ """
570+ # The lifetime of the run context starts when the init hook is executed
571+ # and ends when the cleanup hook is executed
572+ if not RunContext ._exists ():
573+ return
574+
575+ if (
576+ cleanup_hook_source
577+ := snapshot .pipeline_configuration .cleanup_hook_source
578+ ):
579+ logger .info ("Executing the pipeline's cleanup hook..." )
580+ with temporary_environment (
581+ snapshot .pipeline_configuration .environment
582+ ):
583+ load_and_run_hook (
584+ cleanup_hook_source ,
585+ raise_on_error = False ,
586+ )
587+
588+ # Destroy the run context, so it's created anew for the next run
589+ RunContext ._clear ()
590+
501591 def _validate_execution_mode (
502592 self , snapshot : "PipelineSnapshotResponse"
503593 ) -> None :
0 commit comments