From 6a15b66c468e2574b3341f600a6b4e89365d3718 Mon Sep 17 00:00:00 2001 From: Iason Andriopoulos Date: Mon, 13 Oct 2025 14:44:22 +0300 Subject: [PATCH 1/2] Feature:3963 Step HeartBeat components - Backend heartbeat support (DB, API) - Heartbeat monitoring worker --- src/zenml/constants.py | 1 + src/zenml/models/__init__.py | 4 +- src/zenml/models/v2/core/step_run.py | 26 ++- src/zenml/steps/heartbeat.py | 178 ++++++++++++++++++ .../zen_server/routers/steps_endpoints.py | 26 +++ ...b681_add_heartbeat_column_for_step_runs.py | 30 +++ src/zenml/zen_stores/rest_zen_store.py | 19 ++ .../zen_stores/schemas/step_run_schemas.py | 4 + src/zenml/zen_stores/sql_zen_store.py | 35 +++- src/zenml/zen_stores/zen_store_interface.py | 14 ++ 10 files changed, 334 insertions(+), 3 deletions(-) create mode 100644 src/zenml/steps/heartbeat.py create mode 100644 src/zenml/zen_stores/migrations/versions/a5a17015b681_add_heartbeat_column_for_step_runs.py diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 47c4e728151..538cc197920 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -440,6 +440,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: STATUS = "/status" STEP_CONFIGURATION = "/step-configuration" STEPS = "/steps" +HEARTBEAT = "heartbeat" STOP = "/stop" TAGS = "/tags" TAG_RESOURCES = "/tag_resources" diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index ef0737dfa83..e5468dd779e 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -346,7 +346,8 @@ StepRunResponse, StepRunResponseBody, StepRunResponseMetadata, - StepRunResponseResources + StepRunResponseResources, + StepHeartbeatResponse, ) from zenml.models.v2.core.tag import ( TagFilter, @@ -908,4 +909,5 @@ "StepRunIdentifier", "ArtifactVersionIdentifier", "ModelVersionIdentifier", + "StepHeartbeatResponse", ] diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index 8c831fca4a9..026a9900ccd 100644 --- 
a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -26,7 +26,7 @@ ) from uuid import UUID -from pydantic import ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field from zenml.config.step_configurations import StepConfiguration, StepSpec from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH @@ -210,6 +210,10 @@ class StepRunResponseBody(ProjectScopedResponseBody): title="The end time of the step run.", default=None, ) + latest_heartbeat: Optional[datetime] = Field( + title="The latest heartbeat of the step run.", + default=None, + ) model_version_id: Optional[UUID] = Field( title="The ID of the model version that was " "configured by this step run explicitly.", @@ -589,6 +593,15 @@ def end_time(self) -> Optional[datetime]: """ return self.get_body().end_time + @property + def latest_heartbeat(self) -> Optional[datetime]: + """The `latest_heartbeat` property. + + Returns: + the value of the property. + """ + return self.get_body().latest_heartbeat + @property def logs(self) -> Optional["LogsResponse"]: """The `logs` property. @@ -795,3 +808,14 @@ def get_custom_filters( custom_filters.append(cache_expiration_filter) return custom_filters + + +# ------------------ Heartbeat Model --------------- + + +class StepHeartbeatResponse(BaseModel): + """Light-weight model for Step Heartbeat responses.""" + + id: UUID + status: str + latest_heartbeat: datetime diff --git a/src/zenml/steps/heartbeat.py b/src/zenml/steps/heartbeat.py new file mode 100644 index 00000000000..fb189edf2a1 --- /dev/null +++ b/src/zenml/steps/heartbeat.py @@ -0,0 +1,178 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""ZenML Step HeartBeat functionality.""" + +import _thread +import logging +import threading +import time +from typing import Annotated +from uuid import UUID + +from pydantic import BaseModel, conint, model_validator + +from zenml.enums import ExecutionStatus + +logger = logging.getLogger(__name__) + + +class StepHeartBeatTerminationException(Exception): + """Custom exception class for heartbeat termination.""" + + pass + + +class StepHeartBeatOptions(BaseModel): + """Options group for step heartbeat execution.""" + + step_id: UUID + interval: Annotated[int, conint(ge=10, le=60)] + name: str | None = None + + @model_validator(mode="after") + def set_default_name(self) -> "StepHeartBeatOptions": + """Model validator - set name value if missing. + + Returns: + The validated step heartbeat options. + """ + if not self.name: + self.name = f"HeartBeatWorker-{self.step_id}" + + return self + + +class HeartbeatWorker: + """Worker class implementing heartbeat polling and remote termination.""" + + def __init__(self, options: StepHeartBeatOptions): + """Heartbeat worker constructor. + + Args: + options: Parameter group - polling interval, step id, etc. + """ + self.options = options + + self._thread: threading.Thread | None = None + self._running: bool = False + self._terminated: bool = ( + False # one-shot guard to avoid repeated interrupts + ) + + # properties + + @property + def interval(self) -> int: + """Property function for heartbeat interval. + + Returns: + The heartbeat polling interval value. 
+ """ + return self.options.interval + + @property + def name(self) -> str: + """Property function for heartbeat worker name. + + Returns: + The name of the heartbeat worker. + """ + return str(self.options.name) + + @property + def step_id(self) -> UUID: + """Property function for heartbeat worker step ID. + + Returns: + The id of the step heartbeat is running for. + """ + return self.options.step_id + + # public functions + + def start(self) -> None: + """Start the heartbeat worker on a background thread.""" + if self._thread and self._thread.is_alive(): + logger.info("%s already running; start() is a no-op", self.name) + return + + self._running = True + self._terminated = False + self._thread = threading.Thread( + target=self._run, name=self.name, daemon=True + ) + self._thread.start() + logger.info( + "Daemon thread %s started (interval=%s)", self.name, self.interval + ) + + def stop(self) -> None: + """Stops the heartbeat worker.""" + if not self._running: + return + self._running = False + logger.info("%s stop requested", self.name) + + def is_alive(self) -> bool: + """Liveness of the heartbeat worker thread. + + Returns: + True if the heartbeat worker thread is alive, False otherwise. + """ + t = self._thread + return bool(t and t.is_alive()) + + def _run(self) -> None: + logger.info("%s run() loop entered", self.name) + try: + while self._running: + try: + self._heartbeat() + except StepHeartBeatTerminationException: + # One-shot: signal the main thread and stop the loop. + if not self._terminated: + self._terminated = True + logger.info( + "%s received HeartBeatTerminationException; " + "interrupting main thread", + self.name, + ) + _thread.interrupt_main() # raises KeyboardInterrupt in main thread + # Ensure we stop our own loop as well. + self._running = False + except Exception: + # Log-and-continue policy for all other errors. 
+ logger.exception( + "%s heartbeat() failed; continuing", self.name + ) + # Sleep after each attempt (even after errors, unless stopped). + if self._running: + time.sleep(self.interval) + finally: + logger.info("%s run() loop exiting", self.name) + + def _heartbeat(self) -> None: + from zenml.config.global_config import GlobalConfiguration + + store = GlobalConfiguration().zen_store + + response = store.update_step_heartbeat(step_run_id=self.step_id) + + if response.status in { + ExecutionStatus.STOPPED, + ExecutionStatus.STOPPING, + }: + raise StepHeartBeatTerminationException( + f"Step {self.step_id} remotely stopped with status {response.status}." + ) diff --git a/src/zenml/zen_server/routers/steps_endpoints.py b/src/zenml/zen_server/routers/steps_endpoints.py index 772d480348c..c7655025ea8 100644 --- a/src/zenml/zen_server/routers/steps_endpoints.py +++ b/src/zenml/zen_server/routers/steps_endpoints.py @@ -20,6 +20,7 @@ from zenml.constants import ( API, + HEARTBEAT, LOGS, STATUS, STEP_CONFIGURATION, @@ -38,6 +39,7 @@ StepRunResponse, StepRunUpdate, ) +from zenml.models.v2.core.step_run import StepHeartbeatResponse from zenml.zen_server.auth import ( AuthContext, authorize, @@ -200,6 +202,30 @@ def update_step( return dehydrate_response_model(updated_step) +@router.put( + "/{step_run_id}/" + HEARTBEAT, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@async_fastapi_endpoint_wrapper(deduplicate=True) +def update_heartbeat( + step_run_id: UUID, + _: AuthContext = Security(authorize), +) -> StepHeartbeatResponse: + """Updates a step. + + Args: + step_run_id: ID of the step. + + Returns: + The step heartbeat response (id, status, last_heartbeat). 
+ """ + step = zen_store().get_run_step(step_run_id, hydrate=True) + pipeline_run = zen_store().get_run(step.pipeline_run_id) + verify_permission_for_model(pipeline_run, action=Action.UPDATE) + + return zen_store().update_step_heartbeat(step_run_id=step_run_id) + + @router.get( "/{step_id}" + STEP_CONFIGURATION, responses={401: error_response, 404: error_response, 422: error_response}, diff --git a/src/zenml/zen_stores/migrations/versions/a5a17015b681_add_heartbeat_column_for_step_runs.py b/src/zenml/zen_stores/migrations/versions/a5a17015b681_add_heartbeat_column_for_step_runs.py new file mode 100644 index 00000000000..63994a240e6 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/a5a17015b681_add_heartbeat_column_for_step_runs.py @@ -0,0 +1,30 @@ +"""Add heartbeat column for step runs [a5a17015b681]. + +Revision ID: a5a17015b681 +Revises: 0.90.0 +Create Date: 2025-10-13 12:24:12.470803 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = "a5a17015b681" +down_revision = "0.90.0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + with op.batch_alter_table("step_run", schema=None) as batch_op: + batch_op.add_column( + sa.Column("latest_heartbeat", sa.DateTime(), nullable=True) + ) + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + with op.batch_alter_table("step_run", schema=None) as batch_op: + batch_op.drop_column("latest_heartbeat") diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 24a0e50b3c4..a31446a9741 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -75,6 +75,7 @@ ENV_ZENML_DISABLE_CLIENT_SERVER_MISMATCH_WARNING, EVENT_SOURCES, FLAVORS, + HEARTBEAT, INFO, LOGIN, LOGS, @@ -259,6 +260,7 @@ StackRequest, StackResponse, StackUpdate, + StepHeartbeatResponse, StepRunFilter, StepRunRequest, StepRunResponse, @@ -3382,6 +3384,23 @@ def update_run_step( route=STEPS, ) + def update_step_heartbeat( + self, step_run_id: UUID + ) -> StepHeartbeatResponse: + """Updates a step run heartbeat. + + Args: + step_run_id: The ID of the step to update. + + Returns: + The step heartbeat response. 
+ """ + response_body = self.put( + f"{STEPS}/{str(step_run_id)}/{HEARTBEAT}", body=None, params=None + ) + + return StepHeartbeatResponse.model_validate(response_body) + # -------------------- Triggers -------------------- def create_trigger(self, trigger: TriggerRequest) -> TriggerResponse: diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 2dced06c2e5..91aa2d6fd11 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -86,6 +86,10 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): # Fields start_time: Optional[datetime] = Field(nullable=True) end_time: Optional[datetime] = Field(nullable=True) + latest_heartbeat: Optional[datetime] = Field( + nullable=True, + description="The latest execution heartbeat.", + ) status: str = Field(nullable=False) docstring: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 451007e05eb..c302a5dabb5 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -13,6 +13,8 @@ # permissions and limitations under the License. """SQL Zen Store implementation.""" +from zenml.models.v2.core.step_run import StepHeartbeatResponse + try: import sqlalchemy # noqa except ImportError: @@ -35,7 +37,7 @@ import sys import time from collections import defaultdict -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from functools import lru_cache from pathlib import Path from typing import ( @@ -10297,6 +10299,37 @@ def list_run_steps( apply_query_options_from_schema=True, ) + def update_step_heartbeat( + self, step_run_id: UUID + ) -> StepHeartbeatResponse: + """Updates a step run heartbeat value. + + Lightweight function for fast updates as heartbeats may be received at bulk. 
+ + Args: + step_run_id: ID of the step run. + + Returns: + Step heartbeat response (minimal info, id, status & latest_heartbeat). + """ + with Session(self.engine) as session: + existing_step_run = self._get_schema_by_id( + resource_id=step_run_id, + schema_class=StepRunSchema, + session=session, + ) + + existing_step_run.latest_heartbeat = datetime.now(timezone.utc) + + session.commit() + session.refresh(existing_step_run) + + return StepHeartbeatResponse( + id=existing_step_run.id, + status=existing_step_run.status, + latest_heartbeat=existing_step_run.latest_heartbeat, + ) + def update_run_step( self, step_run_id: UUID, diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 350a74c0387..9a9400a0b97 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -139,6 +139,7 @@ StackRequest, StackResponse, StackUpdate, + StepHeartbeatResponse, StepRunFilter, StepRunRequest, StepRunResponse, @@ -2581,6 +2582,19 @@ def update_run_step( KeyError: if the step run doesn't exist. """ + @abstractmethod + def update_step_heartbeat( + self, step_run_id: UUID + ) -> StepHeartbeatResponse: + """Updates a step run heartbeat. + + Args: + step_run_id: The ID of the step to update. + + Returns: + The step heartbeat response. + """ + # -------------------- Triggers -------------------- @abstractmethod From c40b9bfec2eea8f88b41b9094ebf142eb0695be7 Mon Sep 17 00:00:00 2001 From: Iason Andriopoulos Date: Mon, 27 Oct 2025 09:21:56 +0200 Subject: [PATCH 2/2] fixup! 
Improvements and bug fixes - Updates migration down revision refs - context-reraise exception - changes in the step-heartbeat logic - fix null heartbeat in list/get endpoints --- src/zenml/constants.py | 2 +- src/zenml/models/v2/core/step_run.py | 4 +- src/zenml/orchestrators/step_launcher.py | 86 +++++++++---- src/zenml/steps/__init__.py | 5 +- src/zenml/steps/heartbeat.py | 64 ++++------ src/zenml/utils/exception_utils.py | 85 ++++++++++++- .../zen_server/routers/steps_endpoints.py | 50 +++++++- ...b681_add_heartbeat_column_for_step_runs.py | 2 +- src/zenml/zen_stores/rest_zen_store.py | 3 +- .../zen_stores/schemas/step_run_schemas.py | 1 + .../functional/steps/test_heartbeat.py | 118 ++++++++++++++++++ tests/unit/utils/test_exception_utils.py | 70 ++++++++++- 12 files changed, 410 insertions(+), 80 deletions(-) create mode 100644 tests/integration/functional/steps/test_heartbeat.py diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 538cc197920..30b7579fb2d 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -440,7 +440,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: STATUS = "/status" STEP_CONFIGURATION = "/step-configuration" STEPS = "/steps" -HEARTBEAT = "heartbeat" +HEARTBEAT = "/heartbeat" STOP = "/stop" TAGS = "/tags" TAG_RESOURCES = "/tag_resources" diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index 026a9900ccd..c51b408a3bb 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -813,9 +813,9 @@ def get_custom_filters( # ------------------ Heartbeat Model --------------- -class StepHeartbeatResponse(BaseModel): +class StepHeartbeatResponse(BaseModel, use_enum_values=True): """Light-weight model for Step Heartbeat responses.""" id: UUID - status: str + status: ExecutionStatus latest_heartbeat: datetime diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index eacb6fd24d1..06292002584 
100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -43,7 +43,9 @@ from zenml.orchestrators import utils as orchestrator_utils from zenml.orchestrators.step_runner import StepRunner from zenml.stack import Stack +from zenml.steps import StepHeartBeatTerminationException, StepHeartbeatWorker from zenml.utils import env_utils, exception_utils, string_utils +from zenml.utils.exception_utils import ContextReraise from zenml.utils.time_utils import utc_now if TYPE_CHECKING: @@ -167,7 +169,6 @@ def signal_handler(signum: int, frame: Any) -> None: try: client = Client() - pipeline_run = None if self._step_run: pipeline_run = client.get_pipeline_run( @@ -443,35 +444,66 @@ def _run_step( ) start_time = time.time() - try: - if self._step.config.step_operator: - step_operator_name = None - if isinstance(self._step.config.step_operator, str): - step_operator_name = self._step.config.step_operator - - self._run_step_with_step_operator( - step_operator_name=step_operator_name, - step_run_info=step_run_info, + + # To have a cross-platform compatible handling of main thread termination + # we use Python's interrupt_main instead of termination signals (not Windows supported). + # Since interrupt_main raises KeyboardInterrupt we want in this context to capture it + # and handle it as a custom exception. 
+ + with ContextReraise( + source_exceptions=[KeyboardInterrupt], + target_exception=StepHeartBeatTerminationException, + message=f"Step {self._invocation_id} has been remotely stopped - terminating", + propagate_traceback=False, + ) as ctx_reraise: + logger.info( + f"Initiating heartbeat for step: {self._invocation_id}" + ) + + heartbeat_worker = StepHeartbeatWorker(step_id=step_run.id) + heartbeat_worker.start() + + try: + if self._step.config.step_operator: + step_operator_name = None + if isinstance(self._step.config.step_operator, str): + step_operator_name = self._step.config.step_operator + + self._run_step_with_step_operator( + step_operator_name=step_operator_name, + step_run_info=step_run_info, + ) + else: + self._run_step_without_step_operator( + pipeline_run=pipeline_run, + step_run=step_run, + step_run_info=step_run_info, + input_artifacts=step_run.regular_inputs, + output_artifact_uris=output_artifact_uris, + ) + except StepHeartBeatTerminationException as exc: + logger.info(ctx_reraise.message) + output_utils.remove_artifact_dirs( + artifact_uris=list(output_artifact_uris.values()) ) - else: - self._run_step_without_step_operator( - pipeline_run=pipeline_run, - step_run=step_run, - step_run_info=step_run_info, - input_artifacts=step_run.regular_inputs, - output_artifact_uris=output_artifact_uris, + raise ( + exc + if heartbeat_worker.is_terminated + else KeyboardInterrupt ) - except: # noqa: E722 - output_utils.remove_artifact_dirs( - artifact_uris=list(output_artifact_uris.values()) - ) - raise + except: # noqa: E722 + output_utils.remove_artifact_dirs( + artifact_uris=list(output_artifact_uris.values()) + ) + raise - duration = time.time() - start_time - logger.info( - f"Step `{self._invocation_id}` has finished in " - f"`{string_utils.get_human_readable_time(duration)}`." 
- ) + heartbeat_worker.stop() + + duration = time.time() - start_time + logger.info( + f"Step `{self._invocation_id}` has finished in " + f"`{string_utils.get_human_readable_time(duration)}`." + ) def _run_step_with_step_operator( self, diff --git a/src/zenml/steps/__init__.py b/src/zenml/steps/__init__.py index 72d6cab12fc..c63d8e9e773 100644 --- a/src/zenml/steps/__init__.py +++ b/src/zenml/steps/__init__.py @@ -28,6 +28,7 @@ from zenml.steps.base_step import BaseStep from zenml.config.resource_settings import ResourceSettings +from zenml.steps.heartbeat import StepHeartbeatWorker, StepHeartBeatTerminationException from zenml.steps.step_context import StepContext, get_step_context from zenml.steps.step_decorator import step @@ -36,5 +37,7 @@ "ResourceSettings", "StepContext", "step", - "get_step_context" + "get_step_context", + "StepHeartbeatWorker", + "StepHeartBeatTerminationException", ] diff --git a/src/zenml/steps/heartbeat.py b/src/zenml/steps/heartbeat.py index fb189edf2a1..58d293d0232 100644 --- a/src/zenml/steps/heartbeat.py +++ b/src/zenml/steps/heartbeat.py @@ -17,11 +17,8 @@ import logging import threading import time -from typing import Annotated from uuid import UUID -from pydantic import BaseModel, conint, model_validator - from zenml.enums import ExecutionStatus logger = logging.getLogger(__name__) @@ -33,36 +30,18 @@ class StepHeartBeatTerminationException(Exception): pass -class StepHeartBeatOptions(BaseModel): - """Options group for step heartbeat execution.""" - - step_id: UUID - interval: Annotated[int, conint(ge=10, le=60)] - name: str | None = None - - @model_validator(mode="after") - def set_default_name(self) -> "StepHeartBeatOptions": - """Model validator - set name value if missing. - - Returns: - The validated step heartbeat options. 
- """ - if not self.name: - self.name = f"HeartBeatWorker-{self.step_id}" - - return self - - -class HeartbeatWorker: +class StepHeartbeatWorker: """Worker class implementing heartbeat polling and remote termination.""" - def __init__(self, options: StepHeartBeatOptions): + STEP_HEARTBEAT_INTERVAL_SECONDS = 60 + + def __init__(self, step_id: UUID): """Heartbeat worker constructor. Args: - options: Parameter group - polling interval, step id, etc. + step_id: The step id heartbeat is running for. """ - self.options = options + self._step_id = step_id self._thread: threading.Thread | None = None self._running: bool = False @@ -72,6 +51,15 @@ def __init__(self, options: StepHeartBeatOptions): # properties + @property + def is_terminated(self) -> bool: + """Property function for termination signal. + + Returns: + True if the worker has been terminated. + """ + return self._terminated + @property def interval(self) -> int: """Property function for heartbeat interval. @@ -79,7 +67,7 @@ def interval(self) -> int: Returns: The heartbeat polling interval value. """ - return self.options.interval + return self.STEP_HEARTBEAT_INTERVAL_SECONDS @property def name(self) -> str: @@ -88,7 +76,7 @@ def name(self) -> str: Returns: The name of the heartbeat worker. """ - return str(self.options.name) + return f"HeartBeatWorker-{self.step_id}" @property def step_id(self) -> UUID: @@ -97,14 +85,13 @@ def step_id(self) -> UUID: Returns: The id of the step heartbeat is running for. 
""" - return self.options.step_id + return self._step_id # public functions def start(self) -> None: """Start the heartbeat worker on a background thread.""" if self._thread and self._thread.is_alive(): - logger.info("%s already running; start() is a no-op", self.name) return self._running = True @@ -113,7 +100,7 @@ def start(self) -> None: target=self._run, name=self.name, daemon=True ) self._thread.start() - logger.info( + logger.debug( "Daemon thread %s started (interval=%s)", self.name, self.interval ) @@ -122,7 +109,7 @@ def stop(self) -> None: if not self._running: return self._running = False - logger.info("%s stop requested", self.name) + logger.debug("%s stop requested", self.name) def is_alive(self) -> bool: """Liveness of the heartbeat worker thread. @@ -134,7 +121,7 @@ def is_alive(self) -> bool: return bool(t and t.is_alive()) def _run(self) -> None: - logger.info("%s run() loop entered", self.name) + logger.debug("%s run() loop entered", self.name) try: while self._running: try: @@ -151,22 +138,21 @@ def _run(self) -> None: _thread.interrupt_main() # raises KeyboardInterrupt in main thread # Ensure we stop our own loop as well. self._running = False - except Exception: + except Exception as exc: # Log-and-continue policy for all other errors. - logger.exception( - "%s heartbeat() failed; continuing", self.name + logger.debug( + "%s heartbeat() failed with %s", self.name, str(exc) ) # Sleep after each attempt (even after errors, unless stopped). 
if self._running: time.sleep(self.interval) finally: - logger.info("%s run() loop exiting", self.name) + logger.debug("%s run() loop exiting", self.name) def _heartbeat(self) -> None: from zenml.config.global_config import GlobalConfiguration store = GlobalConfiguration().zen_store - response = store.update_step_heartbeat(step_run_id=self.step_id) if response.status in { diff --git a/src/zenml/utils/exception_utils.py b/src/zenml/utils/exception_utils.py index d4af51d838c..f4b3a6820c7 100644 --- a/src/zenml/utils/exception_utils.py +++ b/src/zenml/utils/exception_utils.py @@ -18,7 +18,9 @@ import re import textwrap import traceback -from typing import TYPE_CHECKING, Optional +from contextlib import ContextDecorator +from types import TracebackType +from typing import TYPE_CHECKING, Literal, Optional, Type from zenml.constants import MEDIUMTEXT_MAX_LENGTH from zenml.logger import get_logger @@ -91,3 +93,84 @@ def collect_exception_information( traceback=tb_bytes.decode(errors="ignore"), step_code_line=line_number, ) + + +class ContextReraise(ContextDecorator): + """Utility class. Capture & reraise exceptions within a context.""" + + def __init__( + self, + source_exceptions: list[Type[BaseException]], + target_exception: Type[BaseException], + message: str | None = None, + propagate_traceback: bool = True, + ) -> None: + """ContextReraise constructor. + + Args: + source_exceptions: A list of exception types to capture. + target_exception: The exception to re-raise. + message: The target exception message. If None, will be inferred from source exception arguments. + propagate_traceback: Whether to propagate exception traceback when re-raising. + """ + self._source_exceptions = source_exceptions + self._target_exception = target_exception + self._propagate_traceback = propagate_traceback + self.message = message + + def __enter__(self) -> "ContextReraise": + """Context manager enter magic method. + + Returns: + The context manager. 
+ """ + return self + + @staticmethod + def _get_exc_message(exc: BaseException | None) -> str: + """Helper function that attempts to extract a formatted exception message. + + Args: + exc: The exception to extract the message from. + + Returns: + A formatted exception message (or a fallback stringified version). + """ + if exc and len(exc.args) > 1 and isinstance(exc.args[0], str): + try: + return exc.args[0] % exc.args[1:] + except Exception: + pass + return str(exc) + + def __exit__( + self, + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + trace: TracebackType | None, + ) -> Literal[False]: + """Context manager exit magic method. + + Args: + exc_type: The exception type. + exc_value: The exception. + trace: The exception trace. + + Returns: + Return False if context exits normally or re-raises exception. + """ + if exc_type is None: + return False + + if any(isinstance(exc_value, exc) for exc in self._source_exceptions): + if self.message: + exc_ = self._target_exception(self.message) + else: + exc_ = self._target_exception(self._get_exc_message(exc_value)) + + if self._propagate_traceback: + exc_ = exc_.with_traceback(trace) + + raise exc_ + + return False diff --git a/src/zenml/zen_server/routers/steps_endpoints.py b/src/zenml/zen_server/routers/steps_endpoints.py index c7655025ea8..a814d343866 100644 --- a/src/zenml/zen_server/routers/steps_endpoints.py +++ b/src/zenml/zen_server/routers/steps_endpoints.py @@ -16,7 +16,7 @@ from typing import Any, Dict, List from uuid import UUID -from fastapi import APIRouter, Depends, Security +from fastapi import APIRouter, Depends, HTTPException, Security from zenml.constants import ( API, @@ -28,6 +28,7 @@ VERSION_1, ) from zenml.enums import ExecutionStatus +from zenml.exceptions import AuthorizationException from zenml.logging.step_logging import ( LogEntry, fetch_log_records, @@ -203,25 +204,62 @@ def update_step( @router.put( - "/{step_run_id}/" + HEARTBEAT, + "/{step_run_id}" + HEARTBEAT, 
responses={401: error_response, 404: error_response, 422: error_response}, ) @async_fastapi_endpoint_wrapper(deduplicate=True) def update_heartbeat( step_run_id: UUID, - _: AuthContext = Security(authorize), + auth_context: AuthContext = Security(authorize), ) -> StepHeartbeatResponse: """Updates a step. Args: step_run_id: ID of the step. + auth_context: Authorization/Authentication context. Returns: The step heartbeat response (id, status, last_heartbeat). + + Raises: + HTTPException: If the step is finished raises with 422 status code. """ - step = zen_store().get_run_step(step_run_id, hydrate=True) - pipeline_run = zen_store().get_run(step.pipeline_run_id) - verify_permission_for_model(pipeline_run, action=Action.UPDATE) + step = zen_store().get_run_step(step_run_id, hydrate=False) + + if step.status.is_finished: + raise HTTPException( + status_code=422, + detail=f"Step {step.id} is finished - can not update heartbeat.", + ) + + def validate_token_access( + ctx: AuthContext, step_: StepRunResponse + ) -> None: + token_run_id = ctx.access_token.pipeline_run_id # type: ignore[union-attr] + token_schedule_id = ctx.access_token.schedule_id # type: ignore[union-attr] + + if token_run_id: + if step_.pipeline_run_id != token_run_id: + raise AuthorizationException( + f"Authentication token provided is invalid for step: {step_.id}" + ) + elif token_schedule_id: + pipeline_run = zen_store().get_run( + step_.pipeline_run_id, hydrate=False + ) + + if not ( + pipeline_run.schedule + and pipeline_run.schedule.id == token_schedule_id + ): + raise AuthorizationException( + f"Authentication token provided is invalid for step: {step_.id}" + ) + else: + # un-scoped token. Soon to-be-deprecated, we will ignore validation temporarily. 
+ pass + + validate_token_access(ctx=auth_context, step_=step) return zen_store().update_step_heartbeat(step_run_id=step_run_id)
diff --git a/src/zenml/zen_stores/migrations/versions/a5a17015b681_add_heartbeat_column_for_step_runs.py b/src/zenml/zen_stores/migrations/versions/a5a17015b681_add_heartbeat_column_for_step_runs.py
index 63994a240e6..fc7a6855931 100644
--- a/src/zenml/zen_stores/migrations/versions/a5a17015b681_add_heartbeat_column_for_step_runs.py
+++ b/src/zenml/zen_stores/migrations/versions/a5a17015b681_add_heartbeat_column_for_step_runs.py
@@ -11,7 +11,7 @@
 
 # revision identifiers, used by Alembic.
 revision = "a5a17015b681"
-down_revision = "0.90.0"
+down_revision = "0.91.0"
 branch_labels = None
 depends_on = None
 
diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py
index a31446a9741..a574570f530 100644
--- a/src/zenml/zen_stores/rest_zen_store.py
+++ b/src/zenml/zen_stores/rest_zen_store.py
@@ -3396,7 +3396,8 @@ def update_step_heartbeat(
             The step heartbeat response.
         """
         response_body = self.put(
-            f"{STEPS}/{str(step_run_id)}/{HEARTBEAT}", body=None, params=None
+            path=f"{STEPS}/{str(step_run_id)}{HEARTBEAT}",
+            timeout=5,
         )
         return StepHeartbeatResponse.model_validate(response_body)
 
diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py
index 91aa2d6fd11..475e988fa98 100644
--- a/src/zenml/zen_stores/schemas/step_run_schemas.py
+++ b/src/zenml/zen_stores/schemas/step_run_schemas.py
@@ -407,6 +407,7 @@ def to_model(
             is_retriable=self.is_retriable,
             start_time=self.start_time,
             end_time=self.end_time,
+            latest_heartbeat=self.latest_heartbeat,
             created=self.created,
             updated=self.updated,
             model_version_id=self.model_version_id,
diff --git a/tests/integration/functional/steps/test_heartbeat.py b/tests/integration/functional/steps/test_heartbeat.py
new file mode 100644
index 00000000000..8425c8e7412
--- /dev/null
+++ b/tests/integration/functional/steps/test_heartbeat.py
@@ -0,0 +1,125 @@
+from datetime import datetime
+
+import pytest
+from tests.integration.functional.utils import sample_name
+
+from zenml import (
+    PipelineRequest,
+    PipelineRunRequest,
+    PipelineSnapshotRequest,
+    StepRunRequest,
+    StepRunUpdate,
+    pipeline,
+    step,
+)
+from zenml.client import Client
+from zenml.config.global_config import GlobalConfiguration
+from zenml.config.pipeline_configurations import PipelineConfiguration
+from zenml.config.source import Source, SourceType
+from zenml.config.step_configurations import Step, StepConfiguration, StepSpec
+from zenml.enums import ExecutionStatus, StoreType
+
+
+@step()
+def greet():
+    """Sleep for five seconds and return nothing."""
+    import time
+
+    time.sleep(5)
+
+
+# Deliberately not named `test_*`, so pytest does not try to collect it.
+@pipeline()
+def heartbeat_pipeline():
+    """Pipeline that only runs the sleeping `greet` step."""
+    greet()
+
+
+def test_heartbeat_rest_functionality():
+    """Check heartbeat updates of a step run through the REST store."""
+    if GlobalConfiguration().zen_store.config.type != StoreType.REST:
+        pytest.skip("Heartbeat testing requires REST")
+
+    client = Client()
+
+    # Create the pipeline, snapshot, run and step run the heartbeat targets.
+    pipeline_model = client.zen_store.create_pipeline(
+        PipelineRequest(
+            name=sample_name("pipeline"),
+            project=client.active_project.id,
+        )
+    )
+
+    step_name = sample_name("foo")
+    snapshot = client.zen_store.create_snapshot(
+        PipelineSnapshotRequest(
+            project=client.active_project.id,
+            run_name_template=sample_name("foo"),
+            pipeline_configuration=PipelineConfiguration(
+                name=sample_name("foo")
+            ),
+            stack=client.active_stack.id,
+            pipeline=pipeline_model.id,
+            client_version="0.1.0",
+            server_version="0.1.0",
+            step_configurations={
+                step_name: Step(
+                    spec=StepSpec(
+                        source=Source(
+                            module="acme.foo",
+                            type=SourceType.INTERNAL,
+                        ),
+                        upstream_steps=[],
+                    ),
+                    config=StepConfiguration(name=step_name),
+                )
+            },
+        )
+    )
+    pr, _ = client.zen_store.get_or_create_run(
+        PipelineRunRequest(
+            project=client.active_project.id,
+            name=sample_name("foo"),
+            snapshot=snapshot.id,
+            status=ExecutionStatus.RUNNING,
+        )
+    )
+    step_run = client.zen_store.create_run_step(
+        StepRunRequest(
+            project=client.active_project.id,
+            name=step_name,
+            status=ExecutionStatus.RUNNING,
+            pipeline_run_id=pr.id,
+            start_time=datetime.now(),
+        )
+    )
+
+    # A fresh step run has no heartbeat yet, locally or when re-fetched.
+    assert step_run.latest_heartbeat is None
+
+    assert (
+        client.zen_store.get_run_step(step_run_id=step_run.id).latest_heartbeat
+        is None
+    )
+
+    # The first heartbeat sets the timestamp and reports the current status.
+    hb_response = client.zen_store.update_step_heartbeat(
+        step_run_id=step_run.id
+    )
+
+    assert hb_response.status == ExecutionStatus.RUNNING
+    assert hb_response.latest_heartbeat is not None
+
+    assert (
+        client.zen_store.get_run_step(step_run_id=step_run.id).latest_heartbeat
+        == hb_response.latest_heartbeat
+    )
+
+    # Heartbeats for a step run that already finished are expected to fail.
+    client.zen_store.update_run_step(
+        step_run_id=step_run.id,
+        step_run_update=StepRunUpdate(status=ExecutionStatus.COMPLETED),
+    )
+
+    with pytest.raises(ValueError):
+        client.zen_store.update_step_heartbeat(step_run_id=step_run.id)
diff --git a/tests/unit/utils/test_exception_utils.py b/tests/unit/utils/test_exception_utils.py
index 9e932b9c70b..1f39abf7d8c 100644
--- a/tests/unit/utils/test_exception_utils.py
+++
b/tests/unit/utils/test_exception_utils.py
@@ -5,7 +5,10 @@
 
 import pytest
 
-from zenml.utils.exception_utils import collect_exception_information
+from zenml.utils.exception_utils import (
+    ContextReraise,
+    collect_exception_information,
+)
 
 
 def test_regex_pattern_no_syntax_warning():
@@ -101,3 +104,72 @@ def test_regex_pattern_matches_windows_paths_and_special_chars():
         r' File "C:\Other\path\file.py", line 123, in some_function'
     )
     assert line_pattern_win.search(non_match_line) is None
+
+
+def test_context_reraise():
+    """Check how ContextReraise translates exceptions raised in its body."""
+    # a clean exit must not raise anything
+
+    with ContextReraise(
+        source_exceptions=[ValueError, TypeError],
+        target_exception=RuntimeError,
+        message="Oh no",
+    ):
+        print("Normal exit - should do nothing!")
+
+    # test source errors are captured and re-raised
+
+    class CustomError(Exception):
+        pass
+
+    with pytest.raises(CustomError):
+        with ContextReraise(
+            source_exceptions=[ValueError, TypeError],
+            target_exception=CustomError,
+            message="Oh no",
+        ):
+            raise ValueError("VALUE ERROR")
+
+    # test other errors propagate normally
+
+    with pytest.raises(ZeroDivisionError):
+        with ContextReraise(
+            source_exceptions=[ValueError, TypeError],
+            target_exception=RuntimeError,
+            message="Oh no",
+        ):
+            _ = 1 / 0
+
+    # test inheritance works
+
+    class CustomValueError(ValueError):
+        pass
+
+    class CustomTypeError(TypeError):
+        pass
+
+    with pytest.raises(CustomTypeError):
+        with ContextReraise(
+            source_exceptions=[ValueError, TypeError],
+            target_exception=CustomTypeError,
+            message="Oh no",
+        ):
+            raise CustomValueError("VALUE ERROR")
+
+    # test the target message is taken from the source error if none is given;
+    # pytest.raises (unlike try/except) fails if nothing is raised at all
+
+    with pytest.raises(CustomTypeError) as exc_info:
+        with ContextReraise(
+            source_exceptions=[ValueError], target_exception=CustomTypeError
+        ):
+            raise CustomValueError("Oh no tester %s %s", "iason", "zenml")
+    assert str(exc_info.value) == "Oh no tester iason zenml"
+
+    with pytest.raises(CustomTypeError) as exc_info:
+        with ContextReraise(
+            source_exceptions=[ValueError, ZeroDivisionError],
+            target_exception=CustomTypeError,
+        ):
+            _ = 1 / 0
+    assert str(exc_info.value) == "division by zero"