Skip to content

Commit 123c1c4

Browse files
Feature:3963 Step HeartBeat components
- Backend heartbeat support (DB, API) - Heartbeat monitoring worker
1 parent 066ab58 commit 123c1c4

File tree

10 files changed

+334
-3
lines changed

10 files changed

+334
-3
lines changed

src/zenml/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
439439
STATUS = "/status"
440440
STEP_CONFIGURATION = "/step-configuration"
441441
STEPS = "/steps"
442+
HEARTBEAT = "heartbeat"
442443
STOP = "/stop"
443444
TAGS = "/tags"
444445
TAG_RESOURCES = "/tag_resources"

src/zenml/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,8 @@
331331
StepRunResponse,
332332
StepRunResponseBody,
333333
StepRunResponseMetadata,
334-
StepRunResponseResources
334+
StepRunResponseResources,
335+
StepHeartbeatResponse,
335336
)
336337
from zenml.models.v2.core.tag import (
337338
TagFilter,
@@ -874,4 +875,5 @@
874875
"ProjectStatistics",
875876
"PipelineRunDAG",
876877
"ExceptionInfo",
878+
"StepHeartbeatResponse",
877879
]

src/zenml/models/v2/core/step_run.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from uuid import UUID
2828

29-
from pydantic import ConfigDict, Field
29+
from pydantic import BaseModel, ConfigDict, Field
3030

3131
from zenml.config.step_configurations import StepConfiguration, StepSpec
3232
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
@@ -200,6 +200,10 @@ class StepRunResponseBody(ProjectScopedResponseBody):
200200
title="The end time of the step run.",
201201
default=None,
202202
)
203+
latest_heartbeat: Optional[datetime] = Field(
204+
title="The latest heartbeat of the step run.",
205+
default=None,
206+
)
203207
model_version_id: Optional[UUID] = Field(
204208
title="The ID of the model version that was "
205209
"configured by this step run explicitly.",
@@ -565,6 +569,15 @@ def end_time(self) -> Optional[datetime]:
565569
"""
566570
return self.get_body().end_time
567571

572+
@property
573+
def latest_heartbeat(self) -> Optional[datetime]:
574+
"""The `latest_heartbeat` property.
575+
576+
Returns:
577+
the value of the property.
578+
"""
579+
return self.get_body().latest_heartbeat
580+
568581
@property
569582
def logs(self) -> Optional["LogsResponse"]:
570583
"""The `logs` property.
@@ -747,3 +760,14 @@ def get_custom_filters(
747760
)
748761

749762
return custom_filters
763+
764+
765+
# ------------------ Heartbeat Model ---------------
766+
767+
768+
class StepHeartbeatResponse(BaseModel):
769+
"""Light-weight model for Step Heartbeat responses."""
770+
771+
id: UUID
772+
status: str
773+
latest_heartbeat: datetime

src/zenml/steps/heartbeat.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright (c) ZenML GmbH 2022. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at:
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12+
# or implied. See the License for the specific language governing
13+
# permissions and limitations under the License.
14+
"""ZenML Step HeartBeat functionality."""
15+
16+
import _thread
17+
import logging
18+
import threading
19+
import time
20+
from typing import Annotated
21+
from uuid import UUID
22+
23+
from pydantic import BaseModel, conint, model_validator
24+
25+
from zenml.enums import ExecutionStatus
26+
27+
logger = logging.getLogger(__name__)
28+
29+
30+
class StepHeartBeatTerminationException(Exception):
31+
"""Custom exception class for heartbeat termination."""
32+
33+
pass
34+
35+
36+
class StepHeartBeatOptions(BaseModel):
37+
"""Options group for step heartbeat execution."""
38+
39+
step_id: UUID
40+
interval: Annotated[int, conint(ge=10, le=60)]
41+
name: str | None = None
42+
43+
@model_validator(mode="after")
44+
def set_default_name(self) -> "StepHeartBeatOptions":
45+
"""Model validator - set name value if missing.
46+
47+
Returns:
48+
The validated step heartbeat options.
49+
"""
50+
if not self.name:
51+
self.name = f"HeartBeatWorker-{self.step_id}"
52+
53+
return self
54+
55+
56+
class HeartbeatWorker:
57+
"""Worker class implementing heartbeat polling and remote termination."""
58+
59+
def __init__(self, options: StepHeartBeatOptions):
60+
"""Heartbeat worker constructor.
61+
62+
Args:
63+
options: Parameter group - polling interval, step id, etc.
64+
"""
65+
self.options = options
66+
67+
self._thread: threading.Thread | None = None
68+
self._running: bool = False
69+
self._terminated: bool = (
70+
False # one-shot guard to avoid repeated interrupts
71+
)
72+
73+
# properties
74+
75+
@property
76+
def interval(self) -> int:
77+
"""Property function for heartbeat interval.
78+
79+
Returns:
80+
The heartbeat polling interval value.
81+
"""
82+
return self.options.interval
83+
84+
@property
85+
def name(self) -> str:
86+
"""Property function for heartbeat worker name.
87+
88+
Returns:
89+
The name of the heartbeat worker.
90+
"""
91+
return str(self.options.name)
92+
93+
@property
94+
def step_id(self) -> UUID:
95+
"""Property function for heartbeat worker step ID.
96+
97+
Returns:
98+
The id of the step heartbeat is running for.
99+
"""
100+
return self.options.step_id
101+
102+
# public functions
103+
104+
def start(self) -> None:
105+
"""Start the heartbeat worker on a background thread."""
106+
if self._thread and self._thread.is_alive():
107+
logger.info("%s already running; start() is a no-op", self.name)
108+
return
109+
110+
self._running = True
111+
self._terminated = False
112+
self._thread = threading.Thread(
113+
target=self._run, name=self.name, daemon=True
114+
)
115+
self._thread.start()
116+
logger.info(
117+
"Daemon thread %s started (interval=%s)", self.name, self.interval
118+
)
119+
120+
def stop(self) -> None:
121+
"""Stops the heartbeat worker."""
122+
if not self._running:
123+
return
124+
self._running = False
125+
logger.info("%s stop requested", self.name)
126+
127+
def is_alive(self) -> bool:
128+
"""Liveness of the heartbeat worker thread.
129+
130+
Returns:
131+
True if the heartbeat worker thread is alive, False otherwise.
132+
"""
133+
t = self._thread
134+
return bool(t and t.is_alive())
135+
136+
def _run(self) -> None:
137+
logger.info("%s run() loop entered", self.name)
138+
try:
139+
while self._running:
140+
try:
141+
self._heartbeat()
142+
except StepHeartBeatTerminationException:
143+
# One-shot: signal the main thread and stop the loop.
144+
if not self._terminated:
145+
self._terminated = True
146+
logger.info(
147+
"%s received HeartBeatTerminationException; "
148+
"interrupting main thread",
149+
self.name,
150+
)
151+
_thread.interrupt_main() # raises KeyboardInterrupt in main thread
152+
# Ensure we stop our own loop as well.
153+
self._running = False
154+
except Exception:
155+
# Log-and-continue policy for all other errors.
156+
logger.exception(
157+
"%s heartbeat() failed; continuing", self.name
158+
)
159+
# Sleep after each attempt (even after errors, unless stopped).
160+
if self._running:
161+
time.sleep(self.interval)
162+
finally:
163+
logger.info("%s run() loop exiting", self.name)
164+
165+
def _heartbeat(self) -> None:
166+
from zenml.config.global_config import GlobalConfiguration
167+
168+
store = GlobalConfiguration().zen_store
169+
170+
response = store.update_step_heartbeat(step_run_id=self.step_id)
171+
172+
if response.status in {
173+
ExecutionStatus.STOPPED,
174+
ExecutionStatus.STOPPING,
175+
}:
176+
raise StepHeartBeatTerminationException(
177+
f"Step {self.step_id} remotely stopped with status {response.status}."
178+
)

src/zenml/zen_server/routers/steps_endpoints.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from zenml.constants import (
2222
API,
23+
HEARTBEAT,
2324
LOGS,
2425
STATUS,
2526
STEP_CONFIGURATION,
@@ -38,6 +39,7 @@
3839
StepRunResponse,
3940
StepRunUpdate,
4041
)
42+
from zenml.models.v2.core.step_run import StepHeartbeatResponse
4143
from zenml.zen_server.auth import (
4244
AuthContext,
4345
authorize,
@@ -200,6 +202,30 @@ def update_step(
200202
return dehydrate_response_model(updated_step)
201203

202204

205+
@router.put(
206+
"/{step_run_id}/" + HEARTBEAT,
207+
responses={401: error_response, 404: error_response, 422: error_response},
208+
)
209+
@async_fastapi_endpoint_wrapper(deduplicate=True)
210+
def update_heartbeat(
211+
step_run_id: UUID,
212+
_: AuthContext = Security(authorize),
213+
) -> StepHeartbeatResponse:
214+
"""Updates a step.
215+
216+
Args:
217+
step_run_id: ID of the step.
218+
219+
Returns:
220+
The step heartbeat response (id, status, last_heartbeat).
221+
"""
222+
step = zen_store().get_run_step(step_run_id, hydrate=True)
223+
pipeline_run = zen_store().get_run(step.pipeline_run_id)
224+
verify_permission_for_model(pipeline_run, action=Action.UPDATE)
225+
226+
return zen_store().update_step_heartbeat(step_run_id=step_run_id)
227+
228+
203229
@router.get(
204230
"/{step_id}" + STEP_CONFIGURATION,
205231
responses={401: error_response, 404: error_response, 422: error_response},
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Add heartbeat column for step runs [a5a17015b681].
2+
3+
Revision ID: a5a17015b681
4+
Revises: 0.90.0
5+
Create Date: 2025-10-13 12:24:12.470803
6+
7+
"""
8+
9+
import sqlalchemy as sa
10+
from alembic import op
11+
12+
# revision identifiers, used by Alembic.
13+
revision = "a5a17015b681"
14+
down_revision = "0.90.0"
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
20+
"""Upgrade database schema and/or data, creating a new revision."""
21+
with op.batch_alter_table("step_run", schema=None) as batch_op:
22+
batch_op.add_column(
23+
sa.Column("latest_heartbeat", sa.DateTime(), nullable=True)
24+
)
25+
26+
27+
def downgrade() -> None:
28+
"""Downgrade database schema and/or data back to the previous revision."""
29+
with op.batch_alter_table("step_run", schema=None) as batch_op:
30+
batch_op.drop_column("latest_heartbeat")

src/zenml/zen_stores/rest_zen_store.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
ENV_ZENML_DISABLE_CLIENT_SERVER_MISMATCH_WARNING,
7575
EVENT_SOURCES,
7676
FLAVORS,
77+
HEARTBEAT,
7778
INFO,
7879
LOGIN,
7980
LOGS,
@@ -254,6 +255,7 @@
254255
StackRequest,
255256
StackResponse,
256257
StackUpdate,
258+
StepHeartbeatResponse,
257259
StepRunFilter,
258260
StepRunRequest,
259261
StepRunResponse,
@@ -3303,6 +3305,23 @@ def update_run_step(
33033305
route=STEPS,
33043306
)
33053307

3308+
def update_step_heartbeat(
3309+
self, step_run_id: UUID
3310+
) -> StepHeartbeatResponse:
3311+
"""Updates a step run heartbeat.
3312+
3313+
Args:
3314+
step_run_id: The ID of the step to update.
3315+
3316+
Returns:
3317+
The step heartbeat response.
3318+
"""
3319+
response_body = self.put(
3320+
f"{STEPS}/{str(step_run_id)}/{HEARTBEAT}", body=None, params=None
3321+
)
3322+
3323+
return StepHeartbeatResponse.model_validate(response_body)
3324+
33063325
# -------------------- Triggers --------------------
33073326

33083327
def create_trigger(self, trigger: TriggerRequest) -> TriggerResponse:

src/zenml/zen_stores/schemas/step_run_schemas.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
8686
# Fields
8787
start_time: Optional[datetime] = Field(nullable=True)
8888
end_time: Optional[datetime] = Field(nullable=True)
89+
latest_heartbeat: Optional[datetime] = Field(
90+
nullable=True,
91+
description="The latest execution heartbeat.",
92+
)
8993
status: str = Field(nullable=False)
9094

9195
docstring: Optional[str] = Field(sa_column=Column(TEXT, nullable=True))

0 commit comments

Comments
 (0)