Skip to content

Commit eccde5a

Browse files
committed
Reworked the init/cleanup hooks and the run context
1 parent 032dced commit eccde5a

File tree

10 files changed

+230
-145
lines changed

10 files changed

+230
-145
lines changed

src/zenml/deployers/server/service.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
PipelineRunTriggerInfo,
4444
PipelineSnapshotResponse,
4545
)
46+
from zenml.orchestrators.base_orchestrator import BaseOrchestrator
4647
from zenml.orchestrators.local.local_orchestrator import (
4748
LocalOrchestrator,
4849
LocalOrchestratorConfig,
@@ -56,10 +57,12 @@
5657

5758

5859
class SharedLocalOrchestrator(LocalOrchestrator):
59-
"""Local orchestrator that uses a separate run id for each request.
60+
"""Local orchestrator tweaked for deployments.
6061
61-
This is a slight modification of the LocalOrchestrator to allow for
62-
request-scoped orchestrator run ids by storing them in contextvars.
62+
This is a slight modification of the LocalOrchestrator:
63+
- uses request-scoped orchestrator run ids by storing them in contextvars
64+
- bypasses the init/cleanup hook execution because they are run globally by
65+
the deployment service
6366
"""
6467

6568
# Use contextvars for thread-safe, request-scoped state
@@ -79,6 +82,28 @@ def get_orchestrator_run_id(self) -> str:
7982
self._shared_orchestrator_run_id.set(run_id)
8083
return run_id
8184

85+
@classmethod
86+
def run_init_hook(cls, snapshot: "PipelineSnapshotResponse") -> None:
87+
"""Skips the init hook (intentional no-op override).
88+
89+
Args:
90+
snapshot: The snapshot to run the init hook for.
91+
"""
92+
# Bypass the init hook execution because it is run globally by
93+
# the deployment service
94+
pass
95+
96+
@classmethod
97+
def run_cleanup_hook(cls, snapshot: "PipelineSnapshotResponse") -> None:
98+
"""Skips the cleanup hook (intentional no-op override).
99+
100+
Args:
101+
snapshot: The snapshot to run the cleanup hook for.
102+
"""
103+
# Bypass the cleanup hook execution because it is run globally by
104+
# the deployment service
105+
pass
106+
82107

83108
class PipelineDeploymentService:
84109
"""Pipeline deployment service."""
@@ -97,7 +122,6 @@ def __init__(self, deployment_id: Union[str, UUID]) -> None:
97122
deployment_id = UUID(deployment_id)
98123

99124
self._client = Client()
100-
self.pipeline_state: Optional[Any] = None
101125

102126
# Execution tracking
103127
self.service_start_time = time.time()
@@ -137,9 +161,7 @@ def initialize(self) -> None:
137161
"""
138162
try:
139163
# Execute init hook
140-
self._execute_init_hook()
141-
142-
self._orchestrator.set_shared_run_state(self.pipeline_state)
164+
BaseOrchestrator.run_init_hook(self.snapshot)
143165

144166
# Log success
145167
self._log_initialization_success()
@@ -150,28 +172,8 @@ def initialize(self) -> None:
150172
raise
151173

152174
def cleanup(self) -> None:
153-
"""Execute cleanup hook if present.
154-
155-
Raises:
156-
Exception: If the cleanup hook cannot be executed.
157-
"""
158-
cleanup_hook_source = (
159-
self.snapshot
160-
and self.snapshot.pipeline_configuration.cleanup_hook_source
161-
)
162-
163-
if not cleanup_hook_source:
164-
return
165-
166-
logger.info("Executing pipeline's cleanup hook...")
167-
try:
168-
with env_utils.temporary_environment(
169-
self.snapshot.pipeline_configuration.environment
170-
):
171-
load_and_run_hook(cleanup_hook_source)
172-
except Exception as e:
173-
logger.exception(f"Failed to execute cleanup hook: {e}")
174-
raise
175+
"""Execute cleanup hook if present."""
176+
BaseOrchestrator.run_cleanup_hook(self.snapshot)
175177

176178
def execute_pipeline(
177179
self,

src/zenml/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,7 @@ def __init__(
248248

249249
class HookValidationException(ZenMLBaseException):
250250
"""Exception raised when hook validation fails."""
251+
252+
253+
class HookExecutionException(ZenMLBaseException):
254+
"""Exception raised when hook execution fails."""

src/zenml/hooks/hook_validators.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,11 @@ def _parse_hook_inputs(
430430
resolved_type = resolve_type_annotation(arg_type) if arg_type else None
431431

432432
# Handle BaseException parameters - inject step_exception
433-
if resolved_type and issubclass(resolved_type, BaseException):
433+
if (
434+
resolved_type
435+
and isinstance(resolved_type, type)
436+
and issubclass(resolved_type, BaseException)
437+
):
434438
function_params[arg] = step_exception
435439
continue
436440

src/zenml/orchestrators/base_orchestrator.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,12 @@
3737
)
3838
from zenml.enums import ExecutionMode, ExecutionStatus, StackComponentType
3939
from zenml.exceptions import (
40+
HookExecutionException,
4041
IllegalOperationError,
4142
RunMonitoringError,
4243
RunStoppedException,
4344
)
45+
from zenml.hooks.hook_validators import load_and_run_hook
4446
from zenml.logger import get_logger
4547
from zenml.metadata.metadata_types import MetadataType
4648
from zenml.orchestrators.publish_utils import (
@@ -51,7 +53,8 @@
5153
from zenml.orchestrators.step_launcher import StepLauncher
5254
from zenml.orchestrators.utils import get_config_environment_vars
5355
from zenml.stack import Flavor, Stack, StackComponent, StackComponentConfig
54-
from zenml.steps.step_context import RunContext
56+
from zenml.steps.step_context import RunContext, get_or_create_run_context
57+
from zenml.utils.env_utils import temporary_environment
5558
from zenml.utils.pydantic_utils import before_validator_handler
5659

5760
if TYPE_CHECKING:
@@ -393,13 +396,11 @@ def run(
393396
def run_step(
394397
self,
395398
step: "Step",
396-
run_context: Optional[RunContext] = None,
397399
) -> None:
398400
"""Runs the given step.
399401
400402
Args:
401403
step: The step to run.
402-
run_context: A shared run context.
403404
404405
Raises:
405406
RunStoppedException: If the run was stopped.
@@ -413,7 +414,6 @@ def _launch_step() -> None:
413414
snapshot=self._active_snapshot,
414415
step=step,
415416
orchestrator_run_id=self.get_orchestrator_run_id(),
416-
run_context=run_context,
417417
)
418418
launcher.launch()
419419

@@ -498,6 +498,96 @@ def supported_execution_modes(self) -> List[ExecutionMode]:
498498
"""
499499
return [ExecutionMode.CONTINUE_ON_FAILURE]
500500

501+
@property
502+
def run_init_cleanup_at_step_level(self) -> bool:
503+
"""Whether the orchestrator runs the init and cleanup hooks at step level.
504+
505+
For orchestrators that run their steps in isolated step environments,
506+
the run context cannot be shared between steps. In this case, the init
507+
and cleanup hooks need to be run at step level for each individual step.
508+
509+
For orchestrators that run their steps in a shared environment with a
510+
shared memory (e.g. the local orchestrator), the init and cleanup hooks
511+
can be run at run level and this property should be overridden to return
512+
False.
513+
514+
Returns:
515+
Whether the orchestrator runs the init and cleanup hooks at step
516+
level.
517+
"""
518+
return True
519+
520+
@classmethod
521+
def run_init_hook(cls, snapshot: "PipelineSnapshotResponse") -> None:
522+
"""Runs the init hook.
523+
524+
Args:
525+
snapshot: The snapshot to run the init hook for.
526+
527+
Raises:
528+
HookExecutionException: If the init hook fails.
529+
"""
530+
# The lifetime of the run context starts when the init hook is executed
531+
# and ends when the cleanup hook is executed
532+
run_context = get_or_create_run_context()
533+
init_hook_source = snapshot.pipeline_configuration.init_hook_source
534+
init_hook_kwargs = snapshot.pipeline_configuration.init_hook_kwargs
535+
536+
# We only run the init hook once, if the (thread-local) run context
537+
# associated with the current run has not been initialized yet. This
538+
# allows us to run the init hook only once per run per execution
539+
# environment (process, container, etc.).
540+
if not run_context.initialized:
541+
if not init_hook_source:
542+
run_context.initialize(None)
543+
return
544+
545+
logger.info("Executing the pipeline's init hook...")
546+
try:
547+
with temporary_environment(
548+
snapshot.pipeline_configuration.environment
549+
):
550+
run_state = load_and_run_hook(
551+
init_hook_source,
552+
hook_parameters=init_hook_kwargs,
553+
raise_on_error=True,
554+
)
555+
except Exception as e:
556+
raise HookExecutionException(
557+
f"Failed to execute init hook for pipeline "
558+
f"{snapshot.pipeline_configuration.name}"
559+
) from e
560+
561+
run_context.initialize(run_state)
562+
563+
@classmethod
564+
def run_cleanup_hook(cls, snapshot: "PipelineSnapshotResponse") -> None:
565+
"""Runs the cleanup hook.
566+
567+
Args:
568+
snapshot: The snapshot to run the cleanup hook for.
569+
"""
570+
# The lifetime of the run context starts when the init hook is executed
571+
# and ends when the cleanup hook is executed
572+
if not RunContext._exists():
573+
return
574+
575+
if (
576+
cleanup_hook_source
577+
:= snapshot.pipeline_configuration.cleanup_hook_source
578+
):
579+
logger.info("Executing the pipeline's cleanup hook...")
580+
with temporary_environment(
581+
snapshot.pipeline_configuration.environment
582+
):
583+
load_and_run_hook(
584+
cleanup_hook_source,
585+
raise_on_error=False,
586+
)
587+
588+
# Destroy the run context, so it's created anew for the next run
589+
RunContext._clear()
590+
501591
def _validate_execution_mode(
502592
self, snapshot: "PipelineSnapshotResponse"
503593
) -> None:

src/zenml/orchestrators/local/local_orchestrator.py

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,10 @@
1414
"""Implementation of the ZenML local orchestrator."""
1515

1616
import time
17-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
17+
from typing import TYPE_CHECKING, Dict, List, Optional, Type
1818
from uuid import uuid4
1919

2020
from zenml.enums import ExecutionMode
21-
from zenml.hooks.hook_validators import load_and_run_hook
2221
from zenml.logger import get_logger
2322
from zenml.orchestrators import (
2423
BaseOrchestrator,
@@ -27,7 +26,6 @@
2726
SubmissionResult,
2827
)
2928
from zenml.stack import Stack
30-
from zenml.steps.step_context import RunContext
3129
from zenml.utils import string_utils
3230
from zenml.utils.env_utils import temporary_environment
3331

@@ -45,15 +43,25 @@ class LocalOrchestrator(BaseOrchestrator):
4543
"""
4644

4745
_orchestrator_run_id: Optional[str] = None
48-
_run_context: Optional[RunContext] = None
4946

50-
def set_shared_run_state(self, state: Optional[Any]) -> None:
51-
"""Sets the state to be shared between all steps of all runs executed by this orchestrator.
47+
@property
48+
def run_init_cleanup_at_step_level(self) -> bool:
49+
"""Whether the orchestrator runs the init and cleanup hooks at step level.
5250
53-
Args:
54-
state: the state to be shared
51+
For orchestrators that run their steps in isolated step environments,
52+
the run context cannot be shared between steps. In this case, the init
53+
and cleanup hooks need to be run at step level for each individual step.
54+
55+
For orchestrators that run their steps in a shared environment with a
56+
shared memory (e.g. the local orchestrator), the init and cleanup hooks
57+
can be run at run level and this property should be overridden to return
58+
False.
59+
60+
Returns:
61+
Whether the orchestrator runs the init and cleanup hooks at step
62+
level.
5563
"""
56-
self._run_context = RunContext(state=state)
64+
return False
5765

5866
def submit_pipeline(
5967
self,
@@ -100,25 +108,10 @@ def submit_pipeline(
100108
execution_mode = snapshot.pipeline_configuration.execution_mode
101109

102110
failed_steps: List[str] = []
111+
step_exception: Optional[Exception] = None
103112
skipped_steps: List[str] = []
104113

105-
# If the run context is not set globally, we initialize it by running
106-
# the init hook
107-
if self._run_context:
108-
run_context = self._run_context
109-
else:
110-
state = None
111-
if (
112-
init_hook_source
113-
:= snapshot.pipeline_configuration.init_hook_source
114-
):
115-
logger.info("Executing the pipeline's init hook...")
116-
state = load_and_run_hook(
117-
init_hook_source,
118-
hook_parameters=snapshot.pipeline_configuration.init_hook_kwargs,
119-
raise_on_error=True,
120-
)
121-
run_context = RunContext(state=state)
114+
self.run_init_hook(snapshot=snapshot)
122115

123116
# Run each step
124117
for step_name, step in snapshot.step_configurations.items():
@@ -170,32 +163,21 @@ def submit_pipeline(
170163
step_environment = step_environments[step_name]
171164
try:
172165
with temporary_environment(step_environment):
173-
self.run_step(step=step, run_context=run_context)
174-
except Exception:
166+
self.run_step(step=step)
167+
except Exception as e:
175168
logger.exception("Failed to execute step %s.", step_name)
176169
failed_steps.append(step_name)
177170
logger.exception("Step %s failed.", step_name)
178171

179172
if execution_mode == ExecutionMode.FAIL_FAST:
180-
raise
181-
182-
finally:
183-
try:
184-
# If the run context is not set globally, we also run the
185-
# cleanup hook
186-
if not self._run_context:
187-
if (
188-
cleanup_hook_source
189-
:= snapshot.pipeline_configuration.cleanup_hook_source
190-
):
191-
logger.info(
192-
"Executing the pipeline's cleanup hook..."
193-
)
194-
load_and_run_hook(
195-
cleanup_hook_source,
196-
)
197-
except Exception:
198-
logger.exception("Failed to execute cleanup hook.")
173+
step_exception = e
174+
break
175+
176+
self.run_cleanup_hook(snapshot=snapshot)
177+
178+
if execution_mode == ExecutionMode.FAIL_FAST and failed_steps:
179+
assert step_exception is not None
180+
raise step_exception
199181

200182
if failed_steps:
201183
raise RuntimeError(

0 commit comments

Comments
 (0)