|
3 | 3 | import asyncio |
4 | 4 | import concurrent.futures |
5 | 5 | import dataclasses |
| 6 | +import threading |
6 | 7 | import uuid |
7 | 8 | from collections.abc import Awaitable, Callable |
8 | 9 | from dataclasses import dataclass |
|
39 | 40 | ) |
40 | 41 | from temporalio.common import WorkflowIDConflictPolicy |
41 | 42 | from temporalio.converter import PayloadConverter |
42 | | -from temporalio.exceptions import ApplicationError, CancelledError, NexusOperationError |
| 43 | +from temporalio.exceptions import ( |
| 44 | + ApplicationError, |
| 45 | + CancelledError, |
| 46 | + NexusOperationError, |
| 47 | +) |
43 | 48 | from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation |
44 | 49 | from temporalio.runtime import ( |
45 | 50 | BUFFERED_METRIC_KIND_COUNTER, |
@@ -1212,9 +1217,6 @@ class ServiceClassNameOutput: |
1212 | 1217 | name: str |
1213 | 1218 |
|
1214 | 1219 |
|
1215 | | -# TODO(nexus-prerelease): async and non-async cancel methods |
1216 | | - |
1217 | | - |
1218 | 1220 | @nexusrpc.service |
1219 | 1221 | class ServiceInterfaceWithoutNameOverride: |
1220 | 1222 | op: nexusrpc.Operation[None, ServiceClassNameOutput] |
@@ -2004,3 +2006,151 @@ async def test_workflow_caller_buffered_metrics( |
2004 | 2006 | and update.value == 30 |
2005 | 2007 | for update in updates |
2006 | 2008 | ) |
| 2009 | + |
| 2010 | + |
| 2011 | +@workflow.defn() |
| 2012 | +class CancelTestCallerWorkflow: |
| 2013 | + def __init__(self) -> None: |
| 2014 | + self.released = False |
| 2015 | + |
| 2016 | + @workflow.run |
| 2017 | + async def run(self, use_async_cancel: bool, task_queue: str) -> str: |
| 2018 | + nexus_client = workflow.create_nexus_client( |
| 2019 | + service=TestAsyncAndNonAsyncCancel.CancelTestService, |
| 2020 | + endpoint=make_nexus_endpoint_name(task_queue), |
| 2021 | + ) |
| 2022 | + |
| 2023 | + op = ( |
| 2024 | + TestAsyncAndNonAsyncCancel.CancelTestService.async_cancel_op |
| 2025 | + if use_async_cancel |
| 2026 | + else TestAsyncAndNonAsyncCancel.CancelTestService.non_async_cancel_op |
| 2027 | + ) |
| 2028 | + |
| 2029 | + # Start the operation and immediately request cancellation |
| 2030 | + # Use WAIT_REQUESTED since we just need to verify the cancel handler was called |
| 2031 | + handle = await nexus_client.start_operation( |
| 2032 | + op, |
| 2033 | + None, |
| 2034 | + cancellation_type=workflow.NexusOperationCancellationType.WAIT_REQUESTED, |
| 2035 | + ) |
| 2036 | + |
| 2037 | + # Cancel the handle to trigger the cancel method on the handler |
| 2038 | + handle.cancel() |
| 2039 | + |
| 2040 | + try: |
| 2041 | + await handle |
| 2042 | + except NexusOperationError: |
| 2043 | + # Wait for release signal before completing |
| 2044 | + await workflow.wait_condition(lambda: self.released) |
| 2045 | + return "cancelled_successfully" |
| 2046 | + |
| 2047 | + return "unexpected_completion" |
| 2048 | + |
| 2049 | + @workflow.signal |
| 2050 | + def release(self) -> None: |
| 2051 | + self.released = True |
| 2052 | + |
| 2053 | + |
| 2054 | +@pytest.fixture(scope="class") |
| 2055 | +def cancel_test_events(request: pytest.FixtureRequest): |
| 2056 | + if request.cls: |
| 2057 | + request.cls.called_async = asyncio.Event() |
| 2058 | + request.cls.called_non_async = threading.Event() |
| 2059 | + yield |
| 2060 | + |
| 2061 | + |
| 2062 | +@pytest.mark.usefixtures("cancel_test_events") |
| 2063 | +class TestAsyncAndNonAsyncCancel: |
| 2064 | + called_async: asyncio.Event # pyright: ignore[reportUninitializedInstanceVariable] |
| 2065 | + called_non_async: threading.Event # pyright: ignore[reportUninitializedInstanceVariable] |
| 2066 | + |
| 2067 | + class OpWithAsyncCancel(OperationHandler[None, str]): |
| 2068 | + def __init__(self, evt: asyncio.Event) -> None: |
| 2069 | + self.evt = evt |
| 2070 | + |
| 2071 | + async def start( |
| 2072 | + self, ctx: StartOperationContext, input: None |
| 2073 | + ) -> StartOperationResultAsync: |
| 2074 | + return StartOperationResultAsync("test-token") |
| 2075 | + |
| 2076 | + async def cancel(self, ctx: CancelOperationContext, token: str) -> None: |
| 2077 | + self.evt.set() |
| 2078 | + |
| 2079 | + class OpWithNonAsyncCancel(OperationHandler[None, str]): |
| 2080 | + def __init__(self, evt: threading.Event) -> None: |
| 2081 | + self.evt = evt |
| 2082 | + |
| 2083 | + def start( |
| 2084 | + self, ctx: StartOperationContext, input: None |
| 2085 | + ) -> StartOperationResultAsync: |
| 2086 | + return StartOperationResultAsync("test-token") |
| 2087 | + |
| 2088 | + def cancel(self, ctx: CancelOperationContext, token: str) -> None: |
| 2089 | + self.evt.set() |
| 2090 | + |
| 2091 | + @nexusrpc.service |
| 2092 | + class CancelTestService: |
| 2093 | + async_cancel_op: nexusrpc.Operation[None, str] |
| 2094 | + non_async_cancel_op: nexusrpc.Operation[None, str] |
| 2095 | + |
| 2096 | + @service_handler(service=CancelTestService) |
| 2097 | + class CancelTestServiceHandler: |
| 2098 | + def __init__( |
| 2099 | + self, async_evt: asyncio.Event, non_async_evt: threading.Event |
| 2100 | + ) -> None: |
| 2101 | + self.async_evt = async_evt |
| 2102 | + self.non_async_evt = non_async_evt |
| 2103 | + |
| 2104 | + @operation_handler |
| 2105 | + def async_cancel_op(self) -> OperationHandler[None, str]: |
| 2106 | + return TestAsyncAndNonAsyncCancel.OpWithAsyncCancel(self.async_evt) |
| 2107 | + |
| 2108 | + @operation_handler |
| 2109 | + def non_async_cancel_op(self) -> OperationHandler[None, str]: |
| 2110 | + return TestAsyncAndNonAsyncCancel.OpWithNonAsyncCancel(self.non_async_evt) |
| 2111 | + |
| 2112 | + @pytest.mark.parametrize("use_async_cancel", [True, False]) |
| 2113 | + async def test_task_executor_operation_cancel_method( |
| 2114 | + self, client: Client, env: WorkflowEnvironment, use_async_cancel: bool |
| 2115 | + ): |
| 2116 | + """Test that both async and non-async cancel methods work for TaskExecutor-based operations.""" |
| 2117 | + if env.supports_time_skipping: |
| 2118 | + pytest.skip("Nexus tests don't work with time-skipping server") |
| 2119 | + |
| 2120 | + task_queue = str(uuid.uuid4()) |
| 2121 | + async with Worker( |
| 2122 | + client, |
| 2123 | + task_queue=task_queue, |
| 2124 | + workflows=[CancelTestCallerWorkflow], |
| 2125 | + nexus_service_handlers=[ |
| 2126 | + TestAsyncAndNonAsyncCancel.CancelTestServiceHandler( |
| 2127 | + self.called_async, self.called_non_async |
| 2128 | + ) |
| 2129 | + ], |
| 2130 | + nexus_task_executor=concurrent.futures.ThreadPoolExecutor(), |
| 2131 | + ): |
| 2132 | + await create_nexus_endpoint(task_queue, client) |
| 2133 | + |
| 2134 | + caller_wf_handle = await client.start_workflow( |
| 2135 | + CancelTestCallerWorkflow.run, |
| 2136 | + args=[use_async_cancel, task_queue], |
| 2137 | + id=f"caller-wf-{uuid.uuid4()}", |
| 2138 | + task_queue=task_queue, |
| 2139 | + ) |
| 2140 | + |
| 2141 | + # Wait for the cancel method to be called |
| 2142 | + fut = ( |
| 2143 | + self.called_async.wait() |
| 2144 | + if use_async_cancel |
| 2145 | + else asyncio.get_running_loop().run_in_executor( |
| 2146 | + None, self.called_non_async.wait |
| 2147 | + ) |
| 2148 | + ) |
| 2149 | + await asyncio.wait_for(fut, timeout=30) |
| 2150 | + |
| 2151 | + # Release the workflow to complete |
| 2152 | + await caller_wf_handle.signal(CancelTestCallerWorkflow.release) |
| 2153 | + |
| 2154 | + # Verify the workflow completed successfully |
| 2155 | + result = await caller_wf_handle.result() |
| 2156 | + assert result == "cancelled_successfully" |
0 commit comments