Skip to content

Commit a114671

Browse files
Add workflow caller tests that confirm both async and non-async cancel methods are invoked correctly (#1281)
Co-authored-by: Thomas Hardy <thestaffofmoses@gmail.com>
1 parent 0481785 commit a114671

File tree

1 file changed

+154
-4
lines changed

1 file changed

+154
-4
lines changed

tests/nexus/test_workflow_caller.py

Lines changed: 154 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import concurrent.futures
55
import dataclasses
6+
import threading
67
import uuid
78
from collections.abc import Awaitable, Callable
89
from dataclasses import dataclass
@@ -39,7 +40,11 @@
3940
)
4041
from temporalio.common import WorkflowIDConflictPolicy
4142
from temporalio.converter import PayloadConverter
42-
from temporalio.exceptions import ApplicationError, CancelledError, NexusOperationError
43+
from temporalio.exceptions import (
44+
ApplicationError,
45+
CancelledError,
46+
NexusOperationError,
47+
)
4348
from temporalio.nexus import WorkflowRunOperationContext, workflow_run_operation
4449
from temporalio.runtime import (
4550
BUFFERED_METRIC_KIND_COUNTER,
@@ -1212,9 +1217,6 @@ class ServiceClassNameOutput:
12121217
name: str
12131218

12141219

1215-
# TODO(nexus-prerelease): async and non-async cancel methods
1216-
1217-
12181220
@nexusrpc.service
12191221
class ServiceInterfaceWithoutNameOverride:
12201222
op: nexusrpc.Operation[None, ServiceClassNameOutput]
@@ -2004,3 +2006,151 @@ async def test_workflow_caller_buffered_metrics(
20042006
and update.value == 30
20052007
for update in updates
20062008
)
2009+
2010+
2011+
@workflow.defn()
2012+
class CancelTestCallerWorkflow:
2013+
def __init__(self) -> None:
2014+
self.released = False
2015+
2016+
@workflow.run
2017+
async def run(self, use_async_cancel: bool, task_queue: str) -> str:
2018+
nexus_client = workflow.create_nexus_client(
2019+
service=TestAsyncAndNonAsyncCancel.CancelTestService,
2020+
endpoint=make_nexus_endpoint_name(task_queue),
2021+
)
2022+
2023+
op = (
2024+
TestAsyncAndNonAsyncCancel.CancelTestService.async_cancel_op
2025+
if use_async_cancel
2026+
else TestAsyncAndNonAsyncCancel.CancelTestService.non_async_cancel_op
2027+
)
2028+
2029+
# Start the operation and immediately request cancellation
2030+
# Use WAIT_REQUESTED since we just need to verify the cancel handler was called
2031+
handle = await nexus_client.start_operation(
2032+
op,
2033+
None,
2034+
cancellation_type=workflow.NexusOperationCancellationType.WAIT_REQUESTED,
2035+
)
2036+
2037+
# Cancel the handle to trigger the cancel method on the handler
2038+
handle.cancel()
2039+
2040+
try:
2041+
await handle
2042+
except NexusOperationError:
2043+
# Wait for release signal before completing
2044+
await workflow.wait_condition(lambda: self.released)
2045+
return "cancelled_successfully"
2046+
2047+
return "unexpected_completion"
2048+
2049+
@workflow.signal
2050+
def release(self) -> None:
2051+
self.released = True
2052+
2053+
2054+
@pytest.fixture(scope="class")
2055+
def cancel_test_events(request: pytest.FixtureRequest):
2056+
if request.cls:
2057+
request.cls.called_async = asyncio.Event()
2058+
request.cls.called_non_async = threading.Event()
2059+
yield
2060+
2061+
2062+
@pytest.mark.usefixtures("cancel_test_events")
2063+
class TestAsyncAndNonAsyncCancel:
2064+
called_async: asyncio.Event # pyright: ignore[reportUninitializedInstanceVariable]
2065+
called_non_async: threading.Event # pyright: ignore[reportUninitializedInstanceVariable]
2066+
2067+
class OpWithAsyncCancel(OperationHandler[None, str]):
2068+
def __init__(self, evt: asyncio.Event) -> None:
2069+
self.evt = evt
2070+
2071+
async def start(
2072+
self, ctx: StartOperationContext, input: None
2073+
) -> StartOperationResultAsync:
2074+
return StartOperationResultAsync("test-token")
2075+
2076+
async def cancel(self, ctx: CancelOperationContext, token: str) -> None:
2077+
self.evt.set()
2078+
2079+
class OpWithNonAsyncCancel(OperationHandler[None, str]):
2080+
def __init__(self, evt: threading.Event) -> None:
2081+
self.evt = evt
2082+
2083+
def start(
2084+
self, ctx: StartOperationContext, input: None
2085+
) -> StartOperationResultAsync:
2086+
return StartOperationResultAsync("test-token")
2087+
2088+
def cancel(self, ctx: CancelOperationContext, token: str) -> None:
2089+
self.evt.set()
2090+
2091+
@nexusrpc.service
2092+
class CancelTestService:
2093+
async_cancel_op: nexusrpc.Operation[None, str]
2094+
non_async_cancel_op: nexusrpc.Operation[None, str]
2095+
2096+
@service_handler(service=CancelTestService)
2097+
class CancelTestServiceHandler:
2098+
def __init__(
2099+
self, async_evt: asyncio.Event, non_async_evt: threading.Event
2100+
) -> None:
2101+
self.async_evt = async_evt
2102+
self.non_async_evt = non_async_evt
2103+
2104+
@operation_handler
2105+
def async_cancel_op(self) -> OperationHandler[None, str]:
2106+
return TestAsyncAndNonAsyncCancel.OpWithAsyncCancel(self.async_evt)
2107+
2108+
@operation_handler
2109+
def non_async_cancel_op(self) -> OperationHandler[None, str]:
2110+
return TestAsyncAndNonAsyncCancel.OpWithNonAsyncCancel(self.non_async_evt)
2111+
2112+
@pytest.mark.parametrize("use_async_cancel", [True, False])
2113+
async def test_task_executor_operation_cancel_method(
2114+
self, client: Client, env: WorkflowEnvironment, use_async_cancel: bool
2115+
):
2116+
"""Test that both async and non-async cancel methods work for TaskExecutor-based operations."""
2117+
if env.supports_time_skipping:
2118+
pytest.skip("Nexus tests don't work with time-skipping server")
2119+
2120+
task_queue = str(uuid.uuid4())
2121+
async with Worker(
2122+
client,
2123+
task_queue=task_queue,
2124+
workflows=[CancelTestCallerWorkflow],
2125+
nexus_service_handlers=[
2126+
TestAsyncAndNonAsyncCancel.CancelTestServiceHandler(
2127+
self.called_async, self.called_non_async
2128+
)
2129+
],
2130+
nexus_task_executor=concurrent.futures.ThreadPoolExecutor(),
2131+
):
2132+
await create_nexus_endpoint(task_queue, client)
2133+
2134+
caller_wf_handle = await client.start_workflow(
2135+
CancelTestCallerWorkflow.run,
2136+
args=[use_async_cancel, task_queue],
2137+
id=f"caller-wf-{uuid.uuid4()}",
2138+
task_queue=task_queue,
2139+
)
2140+
2141+
# Wait for the cancel method to be called
2142+
fut = (
2143+
self.called_async.wait()
2144+
if use_async_cancel
2145+
else asyncio.get_running_loop().run_in_executor(
2146+
None, self.called_non_async.wait
2147+
)
2148+
)
2149+
await asyncio.wait_for(fut, timeout=30)
2150+
2151+
# Release the workflow to complete
2152+
await caller_wf_handle.signal(CancelTestCallerWorkflow.release)
2153+
2154+
# Verify the workflow completed successfully
2155+
result = await caller_wf_handle.result()
2156+
assert result == "cancelled_successfully"

0 commit comments

Comments
 (0)