Skip to content

Commit d416056

Browse files
fixup! Improvements and bug fixes
- Updates migration down revision refs - context-reraise exception - changes in the step-heartbeat logic - fix null heartbeat in list/get endpoints
1 parent 6a15b66 commit d416056

File tree

11 files changed

+267
-81
lines changed

11 files changed

+267
-81
lines changed

src/zenml/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
440440
STATUS = "/status"
441441
STEP_CONFIGURATION = "/step-configuration"
442442
STEPS = "/steps"
443-
HEARTBEAT = "heartbeat"
443+
HEARTBEAT = "/heartbeat"
444444
STOP = "/stop"
445445
TAGS = "/tags"
446446
TAG_RESOURCES = "/tag_resources"

src/zenml/models/v2/core/step_run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -813,9 +813,9 @@ def get_custom_filters(
813813
# ------------------ Heartbeat Model ---------------
814814

815815

816-
class StepHeartbeatResponse(BaseModel):
816+
class StepHeartbeatResponse(BaseModel, use_enum_values=True):
817817
"""Light-weight model for Step Heartbeat responses."""
818818

819819
id: UUID
820-
status: str
820+
status: ExecutionStatus
821821
latest_heartbeat: datetime

src/zenml/orchestrators/step_launcher.py

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@
4343
from zenml.orchestrators import utils as orchestrator_utils
4444
from zenml.orchestrators.step_runner import StepRunner
4545
from zenml.stack import Stack
46+
from zenml.steps import StepHeartBeatTerminationException, StepHeartbeatWorker
4647
from zenml.utils import env_utils, exception_utils, string_utils
48+
from zenml.utils.exception_utils import ContextReraise
4749
from zenml.utils.time_utils import utc_now
4850

4951
if TYPE_CHECKING:
@@ -167,7 +169,6 @@ def signal_handler(signum: int, frame: Any) -> None:
167169

168170
try:
169171
client = Client()
170-
pipeline_run = None
171172

172173
if self._step_run:
173174
pipeline_run = client.get_pipeline_run(
@@ -443,35 +444,59 @@ def _run_step(
443444
)
444445

445446
start_time = time.time()
446-
try:
447-
if self._step.config.step_operator:
448-
step_operator_name = None
449-
if isinstance(self._step.config.step_operator, str):
450-
step_operator_name = self._step.config.step_operator
451-
452-
self._run_step_with_step_operator(
453-
step_operator_name=step_operator_name,
454-
step_run_info=step_run_info,
447+
448+
# To have a cross-platform compatible handling of main thread termination
449+
# we use Python's interrupt_main instead of termination signals (which are not supported on Windows).
450+
# Since interrupt_main raises KeyboardInterrupt, in this context we want to capture it
451+
# and handle it as a custom exception.
452+
453+
with ContextReraise(
454+
source_exceptions=[KeyboardInterrupt],
455+
target_exception=StepHeartBeatTerminationException,
456+
message=f"Step {self._invocation_id} has been remotely stopped - terminating",
457+
propagate_traceback=False,
458+
) as ctx_reraise:
459+
logger.info(
460+
f"Initiating heartbeat for step: {self._invocation_id}"
461+
)
462+
463+
StepHeartbeatWorker(step_id=step_run.id).start()
464+
465+
try:
466+
if self._step.config.step_operator:
467+
step_operator_name = None
468+
if isinstance(self._step.config.step_operator, str):
469+
step_operator_name = self._step.config.step_operator
470+
471+
self._run_step_with_step_operator(
472+
step_operator_name=step_operator_name,
473+
step_run_info=step_run_info,
474+
)
475+
else:
476+
self._run_step_without_step_operator(
477+
pipeline_run=pipeline_run,
478+
step_run=step_run,
479+
step_run_info=step_run_info,
480+
input_artifacts=step_run.regular_inputs,
481+
output_artifact_uris=output_artifact_uris,
482+
)
483+
except StepHeartBeatTerminationException:
484+
logger.info(ctx_reraise.message)
485+
output_utils.remove_artifact_dirs(
486+
artifact_uris=list(output_artifact_uris.values())
455487
)
456-
else:
457-
self._run_step_without_step_operator(
458-
pipeline_run=pipeline_run,
459-
step_run=step_run,
460-
step_run_info=step_run_info,
461-
input_artifacts=step_run.regular_inputs,
462-
output_artifact_uris=output_artifact_uris,
488+
raise
489+
except: # noqa: E722
490+
output_utils.remove_artifact_dirs(
491+
artifact_uris=list(output_artifact_uris.values())
463492
)
464-
except: # noqa: E722
465-
output_utils.remove_artifact_dirs(
466-
artifact_uris=list(output_artifact_uris.values())
467-
)
468-
raise
493+
raise
469494

470-
duration = time.time() - start_time
471-
logger.info(
472-
f"Step `{self._invocation_id}` has finished in "
473-
f"`{string_utils.get_human_readable_time(duration)}`."
474-
)
495+
duration = time.time() - start_time
496+
logger.info(
497+
f"Step `{self._invocation_id}` has finished in "
498+
f"`{string_utils.get_human_readable_time(duration)}`."
499+
)
475500

476501
def _run_step_with_step_operator(
477502
self,

src/zenml/steps/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from zenml.steps.base_step import BaseStep
3030
from zenml.config.resource_settings import ResourceSettings
31+
from zenml.steps.heartbeat import StepHeartbeatWorker, StepHeartBeatTerminationException
3132
from zenml.steps.step_context import StepContext, get_step_context
3233
from zenml.steps.step_decorator import step
3334

@@ -36,5 +37,7 @@
3637
"ResourceSettings",
3738
"StepContext",
3839
"step",
39-
"get_step_context"
40+
"get_step_context",
41+
"StepHeartbeatWorker",
42+
"StepHeartBeatTerminationException",
4043
]

src/zenml/steps/heartbeat.py

Lines changed: 16 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,52 +17,30 @@
1717
import logging
1818
import threading
1919
import time
20-
from typing import Annotated
2120
from uuid import UUID
2221

23-
from pydantic import BaseModel, conint, model_validator
24-
2522
from zenml.enums import ExecutionStatus
2623

2724
logger = logging.getLogger(__name__)
2825

2926

3027
class StepHeartBeatTerminationException(Exception):
3128
"""Custom exception class for heartbeat termination."""
32-
3329
pass
3430

3531

36-
class StepHeartBeatOptions(BaseModel):
37-
"""Options group for step heartbeat execution."""
38-
39-
step_id: UUID
40-
interval: Annotated[int, conint(ge=10, le=60)]
41-
name: str | None = None
42-
43-
@model_validator(mode="after")
44-
def set_default_name(self) -> "StepHeartBeatOptions":
45-
"""Model validator - set name value if missing.
46-
47-
Returns:
48-
The validated step heartbeat options.
49-
"""
50-
if not self.name:
51-
self.name = f"HeartBeatWorker-{self.step_id}"
52-
53-
return self
54-
55-
56-
class HeartbeatWorker:
32+
class StepHeartbeatWorker:
5733
"""Worker class implementing heartbeat polling and remote termination."""
5834

59-
def __init__(self, options: StepHeartBeatOptions):
35+
STEP_HEARTBEAT_INTERVAL_SECONDS = 60
36+
37+
def __init__(self, step_id: UUID):
6038
"""Heartbeat worker constructor.
6139
6240
Args:
63-
options: Parameter group - polling interval, step id, etc.
41+
step_id: The step id heartbeat is running for.
6442
"""
65-
self.options = options
43+
self._step_id = step_id
6644

6745
self._thread: threading.Thread | None = None
6846
self._running: bool = False
@@ -79,7 +57,7 @@ def interval(self) -> int:
7957
Returns:
8058
The heartbeat polling interval value.
8159
"""
82-
return self.options.interval
60+
return self.STEP_HEARTBEAT_INTERVAL_SECONDS
8361

8462
@property
8563
def name(self) -> str:
@@ -88,7 +66,7 @@ def name(self) -> str:
8866
Returns:
8967
The name of the heartbeat worker.
9068
"""
91-
return str(self.options.name)
69+
return f"HeartBeatWorker-{self.step_id}"
9270

9371
@property
9472
def step_id(self) -> UUID:
@@ -97,14 +75,13 @@ def step_id(self) -> UUID:
9775
Returns:
9876
The id of the step heartbeat is running for.
9977
"""
100-
return self.options.step_id
78+
return self._step_id
10179

10280
# public functions
10381

10482
def start(self) -> None:
10583
"""Start the heartbeat worker on a background thread."""
10684
if self._thread and self._thread.is_alive():
107-
logger.info("%s already running; start() is a no-op", self.name)
10885
return
10986

11087
self._running = True
@@ -113,7 +90,7 @@ def start(self) -> None:
11390
target=self._run, name=self.name, daemon=True
11491
)
11592
self._thread.start()
116-
logger.info(
93+
logger.debug(
11794
"Daemon thread %s started (interval=%s)", self.name, self.interval
11895
)
11996

@@ -122,7 +99,7 @@ def stop(self) -> None:
12299
if not self._running:
123100
return
124101
self._running = False
125-
logger.info("%s stop requested", self.name)
102+
logger.debug("%s stop requested", self.name)
126103

127104
def is_alive(self) -> bool:
128105
"""Liveness of the heartbeat worker thread.
@@ -134,7 +111,7 @@ def is_alive(self) -> bool:
134111
return bool(t and t.is_alive())
135112

136113
def _run(self) -> None:
137-
logger.info("%s run() loop entered", self.name)
114+
logger.debug("%s run() loop entered", self.name)
138115
try:
139116
while self._running:
140117
try:
@@ -151,22 +128,21 @@ def _run(self) -> None:
151128
_thread.interrupt_main() # raises KeyboardInterrupt in main thread
152129
# Ensure we stop our own loop as well.
153130
self._running = False
154-
except Exception:
131+
except Exception as exc:
155132
# Log-and-continue policy for all other errors.
156-
logger.exception(
157-
"%s heartbeat() failed; continuing", self.name
133+
logger.debug(
134+
"%s heartbeat() failed with %s", self.name, str(exc)
158135
)
159136
# Sleep after each attempt (even after errors, unless stopped).
160137
if self._running:
161138
time.sleep(self.interval)
162139
finally:
163-
logger.info("%s run() loop exiting", self.name)
140+
logger.debug("%s run() loop exiting", self.name)
164141

165142
def _heartbeat(self) -> None:
166143
from zenml.config.global_config import GlobalConfiguration
167144

168145
store = GlobalConfiguration().zen_store
169-
170146
response = store.update_step_heartbeat(step_run_id=self.step_id)
171147

172148
if response.status in {

src/zenml/utils/exception_utils.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
import re
1919
import textwrap
2020
import traceback
21-
from typing import TYPE_CHECKING, Optional
21+
from contextlib import ContextDecorator
22+
from types import TracebackType
23+
from typing import TYPE_CHECKING, Optional, Type
2224

2325
from zenml.constants import MEDIUMTEXT_MAX_LENGTH
2426
from zenml.logger import get_logger
@@ -91,3 +93,84 @@ def collect_exception_information(
9193
traceback=tb_bytes.decode(errors="ignore"),
9294
step_code_line=line_number,
9395
)
96+
97+
98+
class ContextReraise(ContextDecorator):
99+
"""Utility class. Capture & reraise exceptions within a context."""
100+
def __init__(
101+
self,
102+
source_exceptions: list[Type[BaseException]],
103+
target_exception: Type[BaseException],
104+
message: str | None = None,
105+
propagate_traceback: bool = True,
106+
) -> None:
107+
"""ContextReraise constructor.
108+
109+
Args:
110+
source_exceptions: A list of exception types to capture.
111+
target_exception: The exception to re-raise.
112+
message: The target exception message. If None, will be inferred from source exception arguments.
113+
propagate_traceback: Whether to propagate exception traceback when re-raising.
114+
"""
115+
self._source_exceptions = source_exceptions
116+
self._target_exception = target_exception
117+
self._propagate_traceback = propagate_traceback
118+
self._message = message
119+
120+
def __enter__(self) -> "ContextReraise":
121+
"""Context manager enter magic method.
122+
123+
Returns:
124+
The context manager.
125+
"""
126+
return self
127+
128+
@staticmethod
129+
def _get_exc_message(exc: BaseException) -> str:
130+
"""Helper function that attempts to extract a formatted exception message.
131+
132+
Args:
133+
exc: The exception to extract the message from.
134+
135+
Returns:
136+
A formatted exception message (or a fallback stringified version).
137+
"""
138+
if len(exc.args) > 1 and isinstance(exc.args[0], str):
139+
try:
140+
return exc.args[0] % exc.args[1:]
141+
except Exception:
142+
pass
143+
return str(exc)
144+
145+
def __exit__(
146+
self,
147+
exc_type: Type[BaseException] | None,
148+
exc_value: BaseException | None,
149+
trace: TracebackType | None,
150+
) -> bool:
151+
"""Context manager exit magic method.
152+
153+
Args:
154+
exc_type: The exception type.
155+
exc_value: The exception.
156+
trace: The exception trace.
157+
158+
Returns:
159+
Return False if context exits normally or re-raises exception.
160+
"""
161+
if exc_type is None:
162+
return False
163+
164+
if any(isinstance(exc_value, exc) for exc in self._source_exceptions):
165+
if self._message:
166+
exc_ = self._target_exception(self._message)
167+
else:
168+
exc_ = self._target_exception(self._get_exc_message(exc_value))
169+
170+
if self._propagate_traceback:
171+
exc_ = exc_.with_traceback(trace)
172+
173+
raise exc_
174+
else:
175+
raise exc_value
176+

0 commit comments

Comments
 (0)