diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 47c4e728151..30b7579fb2d 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -440,6 +440,7 @@ def handle_int_env_var(var: str, default: int = 0) -> int: STATUS = "/status" STEP_CONFIGURATION = "/step-configuration" STEPS = "/steps" +HEARTBEAT = "/heartbeat" STOP = "/stop" TAGS = "/tags" TAG_RESOURCES = "/tag_resources" diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index ef0737dfa83..e5468dd779e 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -346,7 +346,8 @@ StepRunResponse, StepRunResponseBody, StepRunResponseMetadata, - StepRunResponseResources + StepRunResponseResources, + StepHeartbeatResponse, ) from zenml.models.v2.core.tag import ( TagFilter, @@ -908,4 +909,5 @@ "StepRunIdentifier", "ArtifactVersionIdentifier", "ModelVersionIdentifier", + "StepHeartbeatResponse", ] diff --git a/src/zenml/models/v2/core/step_run.py b/src/zenml/models/v2/core/step_run.py index 8c831fca4a9..c51b408a3bb 100644 --- a/src/zenml/models/v2/core/step_run.py +++ b/src/zenml/models/v2/core/step_run.py @@ -26,7 +26,7 @@ ) from uuid import UUID -from pydantic import ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field from zenml.config.step_configurations import StepConfiguration, StepSpec from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH @@ -210,6 +210,10 @@ class StepRunResponseBody(ProjectScopedResponseBody): title="The end time of the step run.", default=None, ) + latest_heartbeat: Optional[datetime] = Field( + title="The latest heartbeat of the step run.", + default=None, + ) model_version_id: Optional[UUID] = Field( title="The ID of the model version that was " "configured by this step run explicitly.", @@ -589,6 +593,15 @@ def end_time(self) -> Optional[datetime]: """ return self.get_body().end_time + @property + def latest_heartbeat(self) -> Optional[datetime]: + """The `latest_heartbeat` property. + + Returns: + the value of the property. + """ + return self.get_body().latest_heartbeat + @property def logs(self) -> Optional["LogsResponse"]: """The `logs` property. @@ -795,3 +808,14 @@ def get_custom_filters( custom_filters.append(cache_expiration_filter) return custom_filters + + +# ------------------ Heartbeat Model --------------- + + +class StepHeartbeatResponse(BaseModel, use_enum_values=True): + """Light-weight model for Step Heartbeat responses.""" + + id: UUID + status: ExecutionStatus + latest_heartbeat: datetime diff --git a/src/zenml/orchestrators/step_launcher.py b/src/zenml/orchestrators/step_launcher.py index eacb6fd24d1..3f923f6fb0a 100644 --- a/src/zenml/orchestrators/step_launcher.py +++ b/src/zenml/orchestrators/step_launcher.py @@ -43,7 +43,9 @@ from zenml.orchestrators import utils as orchestrator_utils from zenml.orchestrators.step_runner import StepRunner from zenml.stack import Stack +from zenml.steps import StepHeartBeatTerminationException, StepHeartbeatWorker from zenml.utils import env_utils, exception_utils, string_utils +from zenml.utils.exception_utils import ContextReraise from zenml.utils.time_utils import utc_now if TYPE_CHECKING: @@ -167,7 +169,6 @@ def signal_handler(signum: int, frame: Any) -> None: try: client = Client() - pipeline_run = None if self._step_run: pipeline_run = client.get_pipeline_run( @@ -443,35 +444,59 @@ def _run_step( ) start_time = time.time() - try: - if self._step.config.step_operator: - step_operator_name = None - if isinstance(self._step.config.step_operator, str): - step_operator_name = self._step.config.step_operator - - self._run_step_with_step_operator( - step_operator_name=step_operator_name, - step_run_info=step_run_info, + + # To have a cross-platform compatible handling of main thread termination + # we use Python's interrupt_main instead of termination signals (not Windows supported). + # Since interrupt_main raises KeyboardInterrupt we want in this context to capture it + # and handle it as a custom exception. + + with ContextReraise( + source_exceptions=[KeyboardInterrupt], + target_exception=StepHeartBeatTerminationException, + message=f"Step {self._invocation_id} has been remotely stopped - terminating", + propagate_traceback=False, + ) as ctx_reraise: + logger.info( + f"Initiating heartbeat for step: {self._invocation_id}" + ) + + StepHeartbeatWorker(step_id=step_run.id).start() + + try: + if self._step.config.step_operator: + step_operator_name = None + if isinstance(self._step.config.step_operator, str): + step_operator_name = self._step.config.step_operator + + self._run_step_with_step_operator( + step_operator_name=step_operator_name, + step_run_info=step_run_info, + ) + else: + self._run_step_without_step_operator( + pipeline_run=pipeline_run, + step_run=step_run, + step_run_info=step_run_info, + input_artifacts=step_run.regular_inputs, + output_artifact_uris=output_artifact_uris, + ) + except StepHeartBeatTerminationException: + logger.info(ctx_reraise.message) + output_utils.remove_artifact_dirs( + artifact_uris=list(output_artifact_uris.values()) ) - else: - self._run_step_without_step_operator( - pipeline_run=pipeline_run, - step_run=step_run, - step_run_info=step_run_info, - input_artifacts=step_run.regular_inputs, - output_artifact_uris=output_artifact_uris, + raise + except: # noqa: E722 + output_utils.remove_artifact_dirs( + artifact_uris=list(output_artifact_uris.values()) ) - except: # noqa: E722 - output_utils.remove_artifact_dirs( - artifact_uris=list(output_artifact_uris.values()) - ) - raise + raise - duration = time.time() - start_time - logger.info( - f"Step `{self._invocation_id}` has finished in " - f"`{string_utils.get_human_readable_time(duration)}`." - ) + duration = time.time() - start_time + logger.info( + f"Step `{self._invocation_id}` has finished in " + f"`{string_utils.get_human_readable_time(duration)}`." + ) def _run_step_with_step_operator( self, diff --git a/src/zenml/steps/__init__.py b/src/zenml/steps/__init__.py index 72d6cab12fc..c63d8e9e773 100644 --- a/src/zenml/steps/__init__.py +++ b/src/zenml/steps/__init__.py @@ -28,6 +28,7 @@ from zenml.steps.base_step import BaseStep from zenml.config.resource_settings import ResourceSettings +from zenml.steps.heartbeat import StepHeartbeatWorker, StepHeartBeatTerminationException from zenml.steps.step_context import StepContext, get_step_context from zenml.steps.step_decorator import step @@ -36,5 +37,7 @@ "ResourceSettings", "StepContext", "step", - "get_step_context" + "get_step_context", + "StepHeartbeatWorker", + "StepHeartBeatTerminationException", ] diff --git a/src/zenml/steps/heartbeat.py b/src/zenml/steps/heartbeat.py new file mode 100644 index 00000000000..cec2142b028 --- /dev/null +++ b/src/zenml/steps/heartbeat.py @@ -0,0 +1,154 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""ZenML Step HeartBeat functionality.""" + +import _thread +import logging +import threading +import time +from uuid import UUID + +from zenml.enums import ExecutionStatus + +logger = logging.getLogger(__name__) + + +class StepHeartBeatTerminationException(Exception): + """Custom exception class for heartbeat termination.""" + pass + + +class StepHeartbeatWorker: + """Worker class implementing heartbeat polling and remote termination.""" + + STEP_HEARTBEAT_INTERVAL_SECONDS = 60 + + def __init__(self, step_id: UUID): + """Heartbeat worker constructor. + + Args: + step_id: The step id heartbeat is running for. + """ + self._step_id = step_id + + self._thread: threading.Thread | None = None + self._running: bool = False + self._terminated: bool = ( + False # one-shot guard to avoid repeated interrupts + ) + + # properties + + @property + def interval(self) -> int: + """Property function for heartbeat interval. + + Returns: + The heartbeat polling interval value. + """ + return self.STEP_HEARTBEAT_INTERVAL_SECONDS + + @property + def name(self) -> str: + """Property function for heartbeat worker name. + + Returns: + The name of the heartbeat worker. + """ + return f"HeartBeatWorker-{self.step_id}" + + @property + def step_id(self) -> UUID: + """Property function for heartbeat worker step ID. + + Returns: + The id of the step heartbeat is running for. + """ + return self._step_id + + # public functions + + def start(self) -> None: + """Start the heartbeat worker on a background thread.""" + if self._thread and self._thread.is_alive(): + return + + self._running = True + self._terminated = False + self._thread = threading.Thread( + target=self._run, name=self.name, daemon=True + ) + self._thread.start() + logger.debug( + "Daemon thread %s started (interval=%s)", self.name, self.interval + ) + + def stop(self) -> None: + """Stops the heartbeat worker.""" + if not self._running: + return + self._running = False + logger.debug("%s stop requested", self.name) + + def is_alive(self) -> bool: + """Liveness of the heartbeat worker thread. + + Returns: + True if the heartbeat worker thread is alive, False otherwise. + """ + t = self._thread + return bool(t and t.is_alive()) + + def _run(self) -> None: + logger.debug("%s run() loop entered", self.name) + try: + while self._running: + try: + self._heartbeat() + except StepHeartBeatTerminationException: + # One-shot: signal the main thread and stop the loop. + if not self._terminated: + self._terminated = True + logger.info( + "%s received HeartBeatTerminationException; " + "interrupting main thread", + self.name, + ) + _thread.interrupt_main() # raises KeyboardInterrupt in main thread + # Ensure we stop our own loop as well. + self._running = False + except Exception as exc: + # Log-and-continue policy for all other errors. + logger.debug( + "%s heartbeat() failed with %s", self.name, str(exc) + ) + # Sleep after each attempt (even after errors, unless stopped). + if self._running: + time.sleep(self.interval) + finally: + logger.debug("%s run() loop exiting", self.name) + + def _heartbeat(self) -> None: + from zenml.config.global_config import GlobalConfiguration + + store = GlobalConfiguration().zen_store + response = store.update_step_heartbeat(step_run_id=self.step_id) + + if response.status in { + ExecutionStatus.STOPPED, + ExecutionStatus.STOPPING, + }: + raise StepHeartBeatTerminationException( + f"Step {self.step_id} remotely stopped with status {response.status}." + ) diff --git a/src/zenml/utils/exception_utils.py b/src/zenml/utils/exception_utils.py index d4af51d838c..9121ddad056 100644 --- a/src/zenml/utils/exception_utils.py +++ b/src/zenml/utils/exception_utils.py @@ -18,7 +18,9 @@ import re import textwrap import traceback -from typing import TYPE_CHECKING, Optional +from contextlib import ContextDecorator +from types import TracebackType +from typing import TYPE_CHECKING, Optional, Type from zenml.constants import MEDIUMTEXT_MAX_LENGTH from zenml.logger import get_logger @@ -91,3 +93,84 @@ def collect_exception_information( traceback=tb_bytes.decode(errors="ignore"), step_code_line=line_number, ) + + +class ContextReraise(ContextDecorator): + """Utility class. Capture & reraise exceptions within a context.""" + def __init__( + self, + source_exceptions: list[Type[BaseException]], + target_exception: Type[BaseException], + message: str | None = None, + propagate_traceback: bool = True, + ) -> None: + """ContextReraise constructor. + + Args: + source_exceptions: A list of exception types to capture. + target_exception: The exception to re-raise. + message: The target exception message. If None, will be inferred from source exception arguments. + propagate_traceback: Whether to propagate exception traceback when re-raising. + """ + self._source_exceptions = source_exceptions + self._target_exception = target_exception + self._propagate_traceback = propagate_traceback + self._message = message + + def __enter__(self) -> "ContextReraise": + """Context manager enter magic method. + + Returns: + The context manager. + """ + return self + + @staticmethod + def _get_exc_message(exc: BaseException) -> str: + """Helper function that attempts to extract a formatted exception message. + + Args: + exc: The exception to extract the message from. + + Returns: + A formatted exception message (or a fallback stringified version). + """ + if len(exc.args) > 1 and isinstance(exc.args[0], str): + try: + return exc.args[0] % exc.args[1:] + except Exception: + pass + return str(exc) + + def __exit__( + self, + exc_type: Type[BaseException] | None, + exc_value: BaseException | None, + trace: TracebackType | None, + ) -> bool: + """Context manager exit magic method. + + Args: + exc_type: The exception type. + exc_value: The exception. + trace: The exception trace. + + Returns: + Return False if context exits normally or re-raises exception. + """ + if exc_type is None: + return False + + if any(isinstance(exc_value, exc) for exc in self._source_exceptions): + if self._message: + exc_ = self._target_exception(self._message) + else: + exc_ = self._target_exception(self._get_exc_message(exc_value)) + + if self._propagate_traceback: + exc_ = exc_.with_traceback(trace) + + raise exc_ + else: + raise exc_value + diff --git a/src/zenml/zen_server/routers/steps_endpoints.py b/src/zenml/zen_server/routers/steps_endpoints.py index 772d480348c..b8b15a79d74 100644 --- a/src/zenml/zen_server/routers/steps_endpoints.py +++ b/src/zenml/zen_server/routers/steps_endpoints.py @@ -16,10 +16,11 @@ from typing import Any, Dict, List from uuid import UUID -from fastapi import APIRouter, Depends, Security +from fastapi import APIRouter, Depends, HTTPException, Security from zenml.constants import ( API, + HEARTBEAT, LOGS, STATUS, STEP_CONFIGURATION, @@ -27,6 +28,7 @@ VERSION_1, ) from zenml.enums import ExecutionStatus +from zenml.exceptions import AuthorizationException from zenml.logging.step_logging import ( LogEntry, fetch_log_records, @@ -38,6 +40,7 @@ StepRunResponse, StepRunUpdate, ) +from zenml.models.v2.core.step_run import StepHeartbeatResponse from zenml.zen_server.auth import ( AuthContext, authorize, @@ -200,6 +203,59 @@ def update_step( return dehydrate_response_model(updated_step) +@router.put( + "/{step_run_id}" + HEARTBEAT, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@async_fastapi_endpoint_wrapper(deduplicate=True) +def update_heartbeat( + step_run_id: UUID, + auth_context: AuthContext = Security(authorize), +) -> StepHeartbeatResponse: + """Updates a step. + + Args: + step_run_id: ID of the step. + auth_context: Authorization/Authentication context. + + Returns: + The step heartbeat response (id, status, last_heartbeat). + """ + step = zen_store().get_run_step(step_run_id, hydrate=False) + + if step.status.is_finished: + raise HTTPException( + status_code=422, + detail=f"Step {step.id} is finished - can not update heartbeat.", + ) + + def validate_token_access(ctx: AuthContext, step_: StepRunResponse): + if ctx.access_token.pipeline_run_id: + if step_.pipeline_run_id != ctx.access_token.pipeline_run_id: + raise AuthorizationException( + f"Authentication token provided is invalid for step: {step_.id}" + ) + elif ctx.access_token.schedule_id: + pipeline_run = zen_store().get_run( + step_.pipeline_run_id, hydrate=False + ) + schedule = zen_store().get_schedule( + ctx.access_token.schedule_id, hydrate=True + ) + + if pipeline_run.pipeline.id != schedule.pipeline_id: + raise AuthorizationException( + f"Authentication token provided is invalid for step: {step_.id}" + ) + else: + # un-scoped token. Soon to-be-deprecated, we will ignore validation temporarily. + pass + + validate_token_access(ctx=auth_context, step_=step) + + return zen_store().update_step_heartbeat(step_run_id=step_run_id) + + @router.get( "/{step_id}" + STEP_CONFIGURATION, responses={401: error_response, 404: error_response, 422: error_response}, diff --git a/src/zenml/zen_stores/migrations/versions/a5a17015b681_add_heartbeat_column_for_step_runs.py b/src/zenml/zen_stores/migrations/versions/a5a17015b681_add_heartbeat_column_for_step_runs.py new file mode 100644 index 00000000000..fc7a6855931 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/a5a17015b681_add_heartbeat_column_for_step_runs.py @@ -0,0 +1,30 @@ +"""Add heartbeat column for step runs [a5a17015b681]. + +Revision ID: a5a17015b681 +Revises: 0.90.0 +Create Date: 2025-10-13 12:24:12.470803 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a5a17015b681" +down_revision = "0.91.0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + with op.batch_alter_table("step_run", schema=None) as batch_op: + batch_op.add_column( + sa.Column("latest_heartbeat", sa.DateTime(), nullable=True) + ) + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + with op.batch_alter_table("step_run", schema=None) as batch_op: + batch_op.drop_column("latest_heartbeat") diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 24a0e50b3c4..8dc0b7d6b77 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -75,6 +75,7 @@ ENV_ZENML_DISABLE_CLIENT_SERVER_MISMATCH_WARNING, EVENT_SOURCES, FLAVORS, + HEARTBEAT, INFO, LOGIN, LOGS, @@ -259,6 +260,7 @@ StackRequest, StackResponse, StackUpdate, + StepHeartbeatResponse, StepRunFilter, StepRunRequest, StepRunResponse, @@ -3382,6 +3384,23 @@ def update_run_step( route=STEPS, ) + def update_step_heartbeat( + self, step_run_id: UUID + ) -> StepHeartbeatResponse: + """Updates a step run heartbeat. + + Args: + step_run_id: The ID of the step to update. + + Returns: + The step heartbeat response. + """ + response_body = self.put( + path=f"{STEPS}/{str(step_run_id)}{HEARTBEAT}", timeout=5, + ) + + return StepHeartbeatResponse.model_validate(response_body) + # -------------------- Triggers -------------------- def create_trigger(self, trigger: TriggerRequest) -> TriggerResponse: diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 2dced06c2e5..475e988fa98 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -86,6 +86,10 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True): # Fields start_time: Optional[datetime] = Field(nullable=True) end_time: Optional[datetime] = Field(nullable=True) + latest_heartbeat: Optional[datetime] = Field( + nullable=True, + description="The latest execution heartbeat.", + ) status: str = Field(nullable=False) docstring: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) @@ -403,6 +407,7 @@ def to_model( is_retriable=self.is_retriable, start_time=self.start_time, end_time=self.end_time, + latest_heartbeat=self.latest_heartbeat, created=self.created, updated=self.updated, model_version_id=self.model_version_id, diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 451007e05eb..c302a5dabb5 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -13,6 +13,8 @@ # permissions and limitations under the License. """SQL Zen Store implementation.""" +from zenml.models.v2.core.step_run import StepHeartbeatResponse + try: import sqlalchemy # noqa except ImportError: @@ -35,7 +37,7 @@ import sys import time from collections import defaultdict -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from functools import lru_cache from pathlib import Path from typing import ( @@ -10297,6 +10299,37 @@ def list_run_steps( apply_query_options_from_schema=True, ) + def update_step_heartbeat( + self, step_run_id: UUID + ) -> StepHeartbeatResponse: + """Updates a step run heartbeat value. + + Lightweight function for fast updates as heartbeats may be received at bulk. + + Args: + step_run_id: ID of the step run. + + Returns: + Step heartbeat response (minimal info, id, status & latest_heartbeat). + """ + with Session(self.engine) as session: + existing_step_run = self._get_schema_by_id( + resource_id=step_run_id, + schema_class=StepRunSchema, + session=session, + ) + + existing_step_run.latest_heartbeat = datetime.now(timezone.utc) + + session.commit() + session.refresh(existing_step_run) + + return StepHeartbeatResponse( + id=existing_step_run.id, + status=existing_step_run.status, + latest_heartbeat=existing_step_run.latest_heartbeat, + ) + def update_run_step( self, step_run_id: UUID, diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 350a74c0387..9a9400a0b97 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -139,6 +139,7 @@ StackRequest, StackResponse, StackUpdate, + StepHeartbeatResponse, StepRunFilter, StepRunRequest, StepRunResponse, @@ -2581,6 +2582,19 @@ def update_run_step( KeyError: if the step run doesn't exist. """ + @abstractmethod + def update_step_heartbeat( + self, step_run_id: UUID + ) -> StepHeartbeatResponse: + """Updates a step run heartbeat. + + Args: + step_run_id: The ID of the step to update. + + Returns: + The step heartbeat response. + """ + # -------------------- Triggers -------------------- @abstractmethod diff --git a/tests/unit/utils/test_exception_utils.py b/tests/unit/utils/test_exception_utils.py index 9e932b9c70b..1f39abf7d8c 100644 --- a/tests/unit/utils/test_exception_utils.py +++ b/tests/unit/utils/test_exception_utils.py @@ -5,7 +5,10 @@ import pytest -from zenml.utils.exception_utils import collect_exception_information +from zenml.utils.exception_utils import ( + ContextReraise, + collect_exception_information, +) def test_regex_pattern_no_syntax_warning(): @@ -101,3 +104,68 @@ def test_regex_pattern_matches_windows_paths_and_special_chars(): r' File "C:\Other\path\file.py", line 123, in some_function' ) assert line_pattern_win.search(non_match_line) is None + + +def test_context_reraise(): + # test source errors are captured and re-raised + + with ContextReraise( + source_exceptions=[ValueError, TypeError], + target_exception=RuntimeError, + message="Oh no", + ): + print("Normal exit - should do nothing!") + + class CustomError(Exception): + pass + + with pytest.raises(CustomError): + with ContextReraise( + source_exceptions=[ValueError, TypeError], + target_exception=CustomError, + message="Oh no", + ): + raise ValueError("VALUE ERROR") + + # test other errors propagate normally + + with pytest.raises(ZeroDivisionError): + with ContextReraise( + source_exceptions=[ValueError, TypeError], + target_exception=RuntimeError, + message="Oh no", + ): + _ = 1 / 0 + + # test inheritance works + + class CustomValueError(ValueError): + pass + + class CustomTypeError(TypeError): + pass + + with pytest.raises(CustomTypeError): + with ContextReraise( + source_exceptions=[ValueError, TypeError], + target_exception=CustomTypeError, + message="Oh no", + ): + raise CustomValueError("VALUE ERROR") + + try: + with ContextReraise( + source_exceptions=[ValueError], target_exception=CustomTypeError + ): + raise CustomValueError("Oh no tester %s %s", "iason", "zenml") + except CustomTypeError as exc: + assert str(exc) == "Oh no tester iason zenml" + + try: + with ContextReraise( + source_exceptions=[ValueError, ZeroDivisionError], + target_exception=CustomTypeError, + ): + _ = 1 / 0 + except CustomTypeError as exc: + assert str(exc) == "division by zero"