Skip to content

Commit 1d19a9f

Browse files
fixup! Improvements and bug fixes
- Updates migration down revision refs - context-reraise exception - changes in the step-heartbeat logic - fix null heartbeat in list/get endpoints
1 parent 6a15b66 commit 1d19a9f

File tree

11 files changed

+290
-80
lines changed

11 files changed

+290
-80
lines changed

src/zenml/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
440440
STATUS = "/status"
441441
STEP_CONFIGURATION = "/step-configuration"
442442
STEPS = "/steps"
443-
HEARTBEAT = "heartbeat"
443+
HEARTBEAT = "/heartbeat"
444444
STOP = "/stop"
445445
TAGS = "/tags"
446446
TAG_RESOURCES = "/tag_resources"

src/zenml/models/v2/core/step_run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -813,9 +813,9 @@ def get_custom_filters(
813813
# ------------------ Heartbeat Model ---------------
814814

815815

816-
class StepHeartbeatResponse(BaseModel):
816+
class StepHeartbeatResponse(BaseModel, use_enum_values=True):
817817
"""Light-weight model for Step Heartbeat responses."""
818818

819819
id: UUID
820-
status: str
820+
status: ExecutionStatus
821821
latest_heartbeat: datetime

src/zenml/orchestrators/step_launcher.py

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@
4343
from zenml.orchestrators import utils as orchestrator_utils
4444
from zenml.orchestrators.step_runner import StepRunner
4545
from zenml.stack import Stack
46+
from zenml.steps import StepHeartBeatTerminationException, StepHeartbeatWorker
4647
from zenml.utils import env_utils, exception_utils, string_utils
48+
from zenml.utils.exception_utils import ContextReraise
4749
from zenml.utils.time_utils import utc_now
4850

4951
if TYPE_CHECKING:
@@ -167,7 +169,6 @@ def signal_handler(signum: int, frame: Any) -> None:
167169

168170
try:
169171
client = Client()
170-
pipeline_run = None
171172

172173
if self._step_run:
173174
pipeline_run = client.get_pipeline_run(
@@ -443,35 +444,66 @@ def _run_step(
443444
)
444445

445446
start_time = time.time()
446-
try:
447-
if self._step.config.step_operator:
448-
step_operator_name = None
449-
if isinstance(self._step.config.step_operator, str):
450-
step_operator_name = self._step.config.step_operator
451-
452-
self._run_step_with_step_operator(
453-
step_operator_name=step_operator_name,
454-
step_run_info=step_run_info,
447+
448+
# To have a cross-platform compatible handling of main thread termination
449+
# we use Python's interrupt_main instead of termination signals (not Windows supported).
450+
# Since interrupt_main raises KeyboardInterrupt we want in this context to capture it
451+
# and handle it as a custom exception.
452+
453+
with ContextReraise(
454+
source_exceptions=[KeyboardInterrupt],
455+
target_exception=StepHeartBeatTerminationException,
456+
message=f"Step {self._invocation_id} has been remotely stopped - terminating",
457+
propagate_traceback=False,
458+
) as ctx_reraise:
459+
logger.info(
460+
f"Initiating heartbeat for step: {self._invocation_id}"
461+
)
462+
463+
heartbeat_worker = StepHeartbeatWorker(step_id=step_run.id)
464+
heartbeat_worker.start()
465+
466+
try:
467+
if self._step.config.step_operator:
468+
step_operator_name = None
469+
if isinstance(self._step.config.step_operator, str):
470+
step_operator_name = self._step.config.step_operator
471+
472+
self._run_step_with_step_operator(
473+
step_operator_name=step_operator_name,
474+
step_run_info=step_run_info,
475+
)
476+
else:
477+
self._run_step_without_step_operator(
478+
pipeline_run=pipeline_run,
479+
step_run=step_run,
480+
step_run_info=step_run_info,
481+
input_artifacts=step_run.regular_inputs,
482+
output_artifact_uris=output_artifact_uris,
483+
)
484+
except StepHeartBeatTerminationException as exc:
485+
logger.info(ctx_reraise.message)
486+
output_utils.remove_artifact_dirs(
487+
artifact_uris=list(output_artifact_uris.values())
455488
)
456-
else:
457-
self._run_step_without_step_operator(
458-
pipeline_run=pipeline_run,
459-
step_run=step_run,
460-
step_run_info=step_run_info,
461-
input_artifacts=step_run.regular_inputs,
462-
output_artifact_uris=output_artifact_uris,
489+
raise (
490+
exc
491+
if heartbeat_worker.is_terminated
492+
else KeyboardInterrupt
463493
)
464-
except: # noqa: E722
465-
output_utils.remove_artifact_dirs(
466-
artifact_uris=list(output_artifact_uris.values())
467-
)
468-
raise
494+
except: # noqa: E722
495+
output_utils.remove_artifact_dirs(
496+
artifact_uris=list(output_artifact_uris.values())
497+
)
498+
raise
469499

470-
duration = time.time() - start_time
471-
logger.info(
472-
f"Step `{self._invocation_id}` has finished in "
473-
f"`{string_utils.get_human_readable_time(duration)}`."
474-
)
500+
heartbeat_worker.stop()
501+
502+
duration = time.time() - start_time
503+
logger.info(
504+
f"Step `{self._invocation_id}` has finished in "
505+
f"`{string_utils.get_human_readable_time(duration)}`."
506+
)
475507

476508
def _run_step_with_step_operator(
477509
self,

src/zenml/steps/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from zenml.steps.base_step import BaseStep
3030
from zenml.config.resource_settings import ResourceSettings
31+
from zenml.steps.heartbeat import StepHeartbeatWorker, StepHeartBeatTerminationException
3132
from zenml.steps.step_context import StepContext, get_step_context
3233
from zenml.steps.step_decorator import step
3334

@@ -36,5 +37,7 @@
3637
"ResourceSettings",
3738
"StepContext",
3839
"step",
39-
"get_step_context"
40+
"get_step_context",
41+
"StepHeartbeatWorker",
42+
"StepHeartBeatTerminationException",
4043
]

src/zenml/steps/heartbeat.py

Lines changed: 25 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,8 @@
1717
import logging
1818
import threading
1919
import time
20-
from typing import Annotated
2120
from uuid import UUID
2221

23-
from pydantic import BaseModel, conint, model_validator
24-
2522
from zenml.enums import ExecutionStatus
2623

2724
logger = logging.getLogger(__name__)
@@ -33,36 +30,18 @@ class StepHeartBeatTerminationException(Exception):
3330
pass
3431

3532

36-
class StepHeartBeatOptions(BaseModel):
37-
"""Options group for step heartbeat execution."""
38-
39-
step_id: UUID
40-
interval: Annotated[int, conint(ge=10, le=60)]
41-
name: str | None = None
42-
43-
@model_validator(mode="after")
44-
def set_default_name(self) -> "StepHeartBeatOptions":
45-
"""Model validator - set name value if missing.
46-
47-
Returns:
48-
The validated step heartbeat options.
49-
"""
50-
if not self.name:
51-
self.name = f"HeartBeatWorker-{self.step_id}"
52-
53-
return self
54-
55-
56-
class HeartbeatWorker:
33+
class StepHeartbeatWorker:
5734
"""Worker class implementing heartbeat polling and remote termination."""
5835

59-
def __init__(self, options: StepHeartBeatOptions):
36+
STEP_HEARTBEAT_INTERVAL_SECONDS = 60
37+
38+
def __init__(self, step_id: UUID):
6039
"""Heartbeat worker constructor.
6140
6241
Args:
63-
options: Parameter group - polling interval, step id, etc.
42+
step_id: The step id heartbeat is running for.
6443
"""
65-
self.options = options
44+
self._step_id = step_id
6645

6746
self._thread: threading.Thread | None = None
6847
self._running: bool = False
@@ -72,14 +51,23 @@ def __init__(self, options: StepHeartBeatOptions):
7251

7352
# properties
7453

54+
@property
55+
def is_terminated(self) -> bool:
56+
"""Property function for termination signal.
57+
58+
Returns:
59+
True if the worker has been terminated.
60+
"""
61+
return self._terminated
62+
7563
@property
7664
def interval(self) -> int:
7765
"""Property function for heartbeat interval.
7866
7967
Returns:
8068
The heartbeat polling interval value.
8169
"""
82-
return self.options.interval
70+
return self.STEP_HEARTBEAT_INTERVAL_SECONDS
8371

8472
@property
8573
def name(self) -> str:
@@ -88,7 +76,7 @@ def name(self) -> str:
8876
Returns:
8977
The name of the heartbeat worker.
9078
"""
91-
return str(self.options.name)
79+
return f"HeartBeatWorker-{self.step_id}"
9280

9381
@property
9482
def step_id(self) -> UUID:
@@ -97,14 +85,13 @@ def step_id(self) -> UUID:
9785
Returns:
9886
The id of the step heartbeat is running for.
9987
"""
100-
return self.options.step_id
88+
return self._step_id
10189

10290
# public functions
10391

10492
def start(self) -> None:
10593
"""Start the heartbeat worker on a background thread."""
10694
if self._thread and self._thread.is_alive():
107-
logger.info("%s already running; start() is a no-op", self.name)
10895
return
10996

11097
self._running = True
@@ -113,7 +100,7 @@ def start(self) -> None:
113100
target=self._run, name=self.name, daemon=True
114101
)
115102
self._thread.start()
116-
logger.info(
103+
logger.debug(
117104
"Daemon thread %s started (interval=%s)", self.name, self.interval
118105
)
119106

@@ -122,7 +109,7 @@ def stop(self) -> None:
122109
if not self._running:
123110
return
124111
self._running = False
125-
logger.info("%s stop requested", self.name)
112+
logger.debug("%s stop requested", self.name)
126113

127114
def is_alive(self) -> bool:
128115
"""Liveness of the heartbeat worker thread.
@@ -134,7 +121,7 @@ def is_alive(self) -> bool:
134121
return bool(t and t.is_alive())
135122

136123
def _run(self) -> None:
137-
logger.info("%s run() loop entered", self.name)
124+
logger.debug("%s run() loop entered", self.name)
138125
try:
139126
while self._running:
140127
try:
@@ -151,22 +138,21 @@ def _run(self) -> None:
151138
_thread.interrupt_main() # raises KeyboardInterrupt in main thread
152139
# Ensure we stop our own loop as well.
153140
self._running = False
154-
except Exception:
141+
except Exception as exc:
155142
# Log-and-continue policy for all other errors.
156-
logger.exception(
157-
"%s heartbeat() failed; continuing", self.name
143+
logger.debug(
144+
"%s heartbeat() failed with %s", self.name, str(exc)
158145
)
159146
# Sleep after each attempt (even after errors, unless stopped).
160147
if self._running:
161148
time.sleep(self.interval)
162149
finally:
163-
logger.info("%s run() loop exiting", self.name)
150+
logger.debug("%s run() loop exiting", self.name)
164151

165152
def _heartbeat(self) -> None:
166153
from zenml.config.global_config import GlobalConfiguration
167154

168155
store = GlobalConfiguration().zen_store
169-
170156
response = store.update_step_heartbeat(step_run_id=self.step_id)
171157

172158
if response.status in {

0 commit comments

Comments
 (0)