Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/zenml/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int:
STATUS = "/status"
STEP_CONFIGURATION = "/step-configuration"
STEPS = "/steps"
HEARTBEAT = "/heartbeat"
STOP = "/stop"
TAGS = "/tags"
TAG_RESOURCES = "/tag_resources"
Expand Down
4 changes: 3 additions & 1 deletion src/zenml/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,8 @@
StepRunResponse,
StepRunResponseBody,
StepRunResponseMetadata,
StepRunResponseResources
StepRunResponseResources,
StepHeartbeatResponse,
)
from zenml.models.v2.core.tag import (
TagFilter,
Expand Down Expand Up @@ -908,4 +909,5 @@
"StepRunIdentifier",
"ArtifactVersionIdentifier",
"ModelVersionIdentifier",
"StepHeartbeatResponse",
]
26 changes: 25 additions & 1 deletion src/zenml/models/v2/core/step_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from uuid import UUID

from pydantic import ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field

from zenml.config.step_configurations import StepConfiguration, StepSpec
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
Expand Down Expand Up @@ -210,6 +210,10 @@ class StepRunResponseBody(ProjectScopedResponseBody):
title="The end time of the step run.",
default=None,
)
latest_heartbeat: Optional[datetime] = Field(
title="The latest heartbeat of the step run.",
default=None,
)
model_version_id: Optional[UUID] = Field(
title="The ID of the model version that was "
"configured by this step run explicitly.",
Expand Down Expand Up @@ -589,6 +593,15 @@ def end_time(self) -> Optional[datetime]:
"""
return self.get_body().end_time

@property
def latest_heartbeat(self) -> Optional[datetime]:
    """The `latest_heartbeat` property.

    Returns:
        the value of the property.
    """
    # The value is stored on the response body; fetch the body first and
    # read the attribute from it.
    body = self.get_body()
    return body.latest_heartbeat

@property
def logs(self) -> Optional["LogsResponse"]:
"""The `logs` property.
Expand Down Expand Up @@ -795,3 +808,14 @@ def get_custom_filters(
custom_filters.append(cache_expiration_filter)

return custom_filters


# ------------------ Heartbeat Model ---------------


class StepHeartbeatResponse(BaseModel):
    """Light-weight model for Step Heartbeat responses."""

    # Presumably the ID of the step run the heartbeat belongs to — confirm
    # against the endpoint that builds this response.
    id: UUID
    # Status carried as a plain string.
    # NOTE(review): a reviewer suggested this should probably be typed as
    # `ExecutionStatus` instead of `str`.
    status: str
    # Timestamp of the latest heartbeat; required on this model.
    latest_heartbeat: datetime
78 changes: 51 additions & 27 deletions src/zenml/orchestrators/step_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
from zenml.orchestrators import utils as orchestrator_utils
from zenml.orchestrators.step_runner import StepRunner
from zenml.stack import Stack
from zenml.steps import StepHeartBeatTerminationException, StepHeartbeatWorker
from zenml.utils import env_utils, exception_utils, string_utils
from zenml.utils.exception_utils import ContextReraise
from zenml.utils.time_utils import utc_now

if TYPE_CHECKING:
Expand Down Expand Up @@ -167,7 +169,6 @@ def signal_handler(signum: int, frame: Any) -> None:

try:
client = Client()
pipeline_run = None

if self._step_run:
pipeline_run = client.get_pipeline_run(
Expand Down Expand Up @@ -443,35 +444,58 @@ def _run_step(
)

start_time = time.time()
try:
if self._step.config.step_operator:
step_operator_name = None
if isinstance(self._step.config.step_operator, str):
step_operator_name = self._step.config.step_operator

self._run_step_with_step_operator(
step_operator_name=step_operator_name,
step_run_info=step_run_info,

# For cross-platform handling of main-thread termination we use Python's
# interrupt_main instead of OS termination signals (which are not supported on Windows).
# Because interrupt_main raises KeyboardInterrupt in the main thread, we capture it
# in this context and re-raise it as a custom exception.

with ContextReraise(
source_exceptions=[KeyboardInterrupt],
target_exception=StepHeartBeatTerminationException,
message=f"Step {self._invocation_id} has been remotely stopped - terminating",
propagate_traceback=False,
) as ctx_reraise:

logger.info(f"Initiating heartbeat for step {self._invocation_id}")

StepHeartbeatWorker(step_id=step_run.id).start()

try:
if self._step.config.step_operator:
step_operator_name = None
if isinstance(self._step.config.step_operator, str):
step_operator_name = self._step.config.step_operator

self._run_step_with_step_operator(
step_operator_name=step_operator_name,
step_run_info=step_run_info,
)
else:
self._run_step_without_step_operator(
pipeline_run=pipeline_run,
step_run=step_run,
step_run_info=step_run_info,
input_artifacts=step_run.regular_inputs,
output_artifact_uris=output_artifact_uris,
)
except StepHeartBeatTerminationException:
logger.info(ctx_reraise.message)
output_utils.remove_artifact_dirs(
artifact_uris=list(output_artifact_uris.values())
)
else:
self._run_step_without_step_operator(
pipeline_run=pipeline_run,
step_run=step_run,
step_run_info=step_run_info,
input_artifacts=step_run.regular_inputs,
output_artifact_uris=output_artifact_uris,
raise
except: # noqa: E722
output_utils.remove_artifact_dirs(
artifact_uris=list(output_artifact_uris.values())
)
except: # noqa: E722
output_utils.remove_artifact_dirs(
artifact_uris=list(output_artifact_uris.values())
)
raise
raise

duration = time.time() - start_time
logger.info(
f"Step `{self._invocation_id}` has finished in "
f"`{string_utils.get_human_readable_time(duration)}`."
)
duration = time.time() - start_time
logger.info(
f"Step `{self._invocation_id}` has finished in "
f"`{string_utils.get_human_readable_time(duration)}`."
)

def _run_step_with_step_operator(
self,
Expand Down
5 changes: 4 additions & 1 deletion src/zenml/steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from zenml.steps.base_step import BaseStep
from zenml.config.resource_settings import ResourceSettings
from zenml.steps.heartbeat import StepHeartbeatWorker, StepHeartBeatTerminationException
from zenml.steps.step_context import StepContext, get_step_context
from zenml.steps.step_decorator import step

Expand All @@ -36,5 +37,7 @@
"ResourceSettings",
"StepContext",
"step",
"get_step_context"
"get_step_context",
"StepHeartbeatWorker",
"StepHeartBeatTerminationException",
]
158 changes: 158 additions & 0 deletions src/zenml/steps/heartbeat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright (c) ZenML GmbH 2022. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""ZenML Step HeartBeat functionality."""

import _thread
import logging
import threading
import time
from uuid import UUID

from zenml.enums import ExecutionStatus

logger = logging.getLogger(__name__)


class StepHeartBeatTerminationException(Exception):
    """Exception type used to signal a heartbeat-driven step termination."""


class StepHeartbeatWorker:
    """Worker class implementing heartbeat polling and remote termination.

    The worker runs a daemon thread that periodically calls
    `update_step_heartbeat` on the ZenML store. When the returned status is
    `STOPPED` or `STOPPING`, the worker interrupts the main thread once and
    ends its own polling loop.
    """

    STEP_HEARTBEAT_INTERVAL_SECONDS = 30

    def __init__(self, step_id: UUID):
        """Heartbeat worker constructor.

        Args:
            step_id: ID of the step run that heartbeats are sent for.
        """
        self._step_id = step_id

        self._thread: threading.Thread | None = None
        self._running: bool = False
        # One-shot guard to avoid repeated interrupts.
        self._terminated: bool = False
        # Set by `stop()` so the polling loop wakes up immediately instead
        # of sleeping out the remainder of the interval.
        self._stop_event = threading.Event()

    # properties

    @property
    def interval(self) -> int:
        """Property function for heartbeat interval.

        Returns:
            The heartbeat polling interval value, in seconds.
        """
        return self.STEP_HEARTBEAT_INTERVAL_SECONDS

    @property
    def name(self) -> str:
        """Property function for heartbeat worker name.

        Returns:
            The name of the heartbeat worker.
        """
        return f"HeartBeatWorker-{self.step_id}"

    @property
    def step_id(self) -> UUID:
        """Property function for heartbeat worker step ID.

        Returns:
            The id of the step heartbeat is running for.
        """
        # Must read the private attribute: returning `self.step_id` here
        # would re-enter this property and recurse without end.
        return self._step_id

    # public functions

    def start(self) -> None:
        """Start the heartbeat worker on a background thread.

        Calling this while the worker thread is alive is a no-op.
        """
        if self._thread and self._thread.is_alive():
            logger.info("%s already running; start() is a no-op", self.name)
            return

        self._running = True
        self._terminated = False
        self._stop_event.clear()
        self._thread = threading.Thread(
            target=self._run, name=self.name, daemon=True
        )
        self._thread.start()
        logger.info(
            "Daemon thread %s started (interval=%s)", self.name, self.interval
        )

    def stop(self) -> None:
        """Stops the heartbeat worker.

        Only requests the stop; it does not join the background thread.
        """
        if not self._running:
            return
        self._running = False
        # Wake the polling loop if it is currently waiting between beats.
        self._stop_event.set()
        logger.info("%s stop requested", self.name)

    def is_alive(self) -> bool:
        """Liveness of the heartbeat worker thread.

        Returns:
            True if the heartbeat worker thread is alive, False otherwise.
        """
        t = self._thread
        return bool(t and t.is_alive())

    def _run(self) -> None:
        """Polling loop executed on the background thread.

        Sends one heartbeat per iteration until the worker is stopped or a
        remote termination is detected. Any other error raised by a
        heartbeat is logged and the loop continues.
        """
        logger.info("%s run() loop entered", self.name)
        try:
            while self._running:
                try:
                    self._heartbeat()
                except StepHeartBeatTerminationException:
                    # One-shot: signal the main thread and stop the loop.
                    if not self._terminated:
                        self._terminated = True
                        logger.info(
                            "%s received HeartBeatTerminationException; "
                            "interrupting main thread",
                            self.name,
                        )
                        # NOTE(review): interrupt_main always targets the
                        # main thread. Reviewers pointed out this will not
                        # work once steps run in other threads; the thread
                        # that started the worker should be interrupted
                        # instead.
                        _thread.interrupt_main()  # raises KeyboardInterrupt in main thread
                    # Ensure we stop our own loop as well.
                    self._running = False
                except Exception:
                    # Log-and-continue policy for all other errors.
                    # TODO: Improve this, e.g. handle HTTP errors (such as
                    # the server returning status 500) separately to avoid
                    # excessive log generation.
                    logger.exception(
                        "%s heartbeat() failed; continuing", self.name
                    )
                # Wait after each attempt (even after errors, unless
                # stopped); returns early when stop() sets the event.
                if self._running:
                    self._stop_event.wait(self.interval)
        finally:
            logger.info("%s run() loop exiting", self.name)

    def _heartbeat(self) -> None:
        """Send one heartbeat and check for a remote stop.

        Raises:
            StepHeartBeatTerminationException: If the status in the store's
                response is `STOPPED` or `STOPPING`.
        """
        # Imported inside the method rather than at module level, as in the
        # original code.
        from zenml.config.global_config import GlobalConfiguration

        store = GlobalConfiguration().zen_store

        response = store.update_step_heartbeat(step_run_id=self.step_id)

        if response.status in {
            ExecutionStatus.STOPPED,
            ExecutionStatus.STOPPING,
        }:
            raise StepHeartBeatTerminationException(
                f"Step {self.step_id} remotely stopped with status {response.status}."
            )
Loading