Skip to content

Commit e4a8cc6

Browse files
committed
Cancellation types for Nexus operations invoked by workflows
1 parent b59c555 commit e4a8cc6

File tree

5 files changed

+408
-15
lines changed

5 files changed

+408
-15
lines changed

temporalio/worker/_interceptor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ class StartNexusOperationInput(Generic[InputT, OutputT]):
298298
operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]]
299299
input: InputT
300300
schedule_to_close_timeout: Optional[timedelta]
301+
cancellation_type: temporalio.workflow.NexusOperationCancellationType
301302
headers: Optional[Mapping[str, str]]
302303
output_type: Optional[Type[OutputT]] = None
303304

temporalio/worker/_workflow_instance.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,13 @@
5454
import temporalio.bridge.proto.activity_result
5555
import temporalio.bridge.proto.child_workflow
5656
import temporalio.bridge.proto.common
57+
import temporalio.bridge.proto.nexus
5758
import temporalio.bridge.proto.workflow_activation
5859
import temporalio.bridge.proto.workflow_commands
5960
import temporalio.bridge.proto.workflow_completion
6061
import temporalio.common
6162
import temporalio.converter
6263
import temporalio.exceptions
63-
import temporalio.nexus
6464
import temporalio.workflow
6565
from temporalio.service import __version__
6666

@@ -1502,9 +1502,10 @@ async def workflow_start_nexus_operation(
15021502
service: str,
15031503
operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]],
15041504
input: Any,
1505-
output_type: Optional[Type[OutputT]] = None,
1506-
schedule_to_close_timeout: Optional[timedelta] = None,
1507-
headers: Optional[Mapping[str, str]] = None,
1505+
output_type: Optional[Type[OutputT]],
1506+
schedule_to_close_timeout: Optional[timedelta],
1507+
cancellation_type: temporalio.workflow.NexusOperationCancellationType,
1508+
headers: Optional[Mapping[str, str]],
15081509
) -> temporalio.workflow.NexusOperationHandle[OutputT]:
15091510
# start_nexus_operation
15101511
return await self._outbound.start_nexus_operation(
@@ -1515,6 +1516,7 @@ async def workflow_start_nexus_operation(
15151516
input=input,
15161517
output_type=output_type,
15171518
schedule_to_close_timeout=schedule_to_close_timeout,
1519+
cancellation_type=cancellation_type,
15181520
headers=headers,
15191521
)
15201522
)
@@ -2757,7 +2759,7 @@ def _apply_schedule_command(
27572759
if self._input.retry_policy:
27582760
self._input.retry_policy.apply_to_proto(v.retry_policy)
27592761
v.cancellation_type = cast(
2760-
"temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType",
2762+
temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType,
27612763
int(self._input.cancellation_type),
27622764
)
27632765

@@ -2893,7 +2895,7 @@ def _apply_start_command(self) -> None:
28932895
if self._input.task_timeout:
28942896
v.workflow_task_timeout.FromTimedelta(self._input.task_timeout)
28952897
v.parent_close_policy = cast(
2896-
"temporalio.bridge.proto.child_workflow.ParentClosePolicy.ValueType",
2898+
temporalio.bridge.proto.child_workflow.ParentClosePolicy.ValueType,
28972899
int(self._input.parent_close_policy),
28982900
)
28992901
v.workflow_id_reuse_policy = cast(
@@ -2915,7 +2917,7 @@ def _apply_start_command(self) -> None:
29152917
self._input.search_attributes, v.search_attributes
29162918
)
29172919
v.cancellation_type = cast(
2918-
"temporalio.bridge.proto.child_workflow.ChildWorkflowCancellationType.ValueType",
2920+
temporalio.bridge.proto.child_workflow.ChildWorkflowCancellationType.ValueType,
29192921
int(self._input.cancellation_type),
29202922
)
29212923
if self._input.versioning_intent:
@@ -3011,11 +3013,6 @@ def __init__(
30113013

30123014
@property
30133015
def operation_token(self) -> Optional[str]:
3014-
# TODO(nexus-preview): How should this behave?
3015-
# Java has a separate class that only exists if the operation token exists:
3016-
# https://github.com/temporalio/sdk-java/blob/master/temporal-sdk/src/main/java/io/temporal/internal/sync/NexusOperationExecutionImpl.java#L26
3017-
# And Go similar:
3018-
# https://github.com/temporalio/sdk-go/blob/master/internal/workflow.go#L2770-L2771
30193016
try:
30203017
return self._start_fut.result()
30213018
except BaseException:
@@ -3064,6 +3061,11 @@ def _apply_schedule_command(self) -> None:
30643061
v.schedule_to_close_timeout.FromTimedelta(
30653062
self._input.schedule_to_close_timeout
30663063
)
3064+
v.cancellation_type = cast(
3065+
temporalio.bridge.proto.nexus.NexusOperationCancellationType.ValueType,
3066+
int(self._input.cancellation_type),
3067+
)
3068+
30673069
if self._input.headers:
30683070
for key, val in self._input.headers.items():
30693071
v.nexus_header[key] = val

temporalio/workflow.py

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -858,9 +858,10 @@ async def workflow_start_nexus_operation(
858858
service: str,
859859
operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]],
860860
input: Any,
861-
output_type: Optional[Type[OutputT]] = None,
862-
schedule_to_close_timeout: Optional[timedelta] = None,
863-
headers: Optional[Mapping[str, str]] = None,
861+
output_type: Optional[Type[OutputT]],
862+
schedule_to_close_timeout: Optional[timedelta],
863+
cancellation_type: temporalio.workflow.NexusOperationCancellationType,
864+
headers: Optional[Mapping[str, str]],
864865
) -> NexusOperationHandle[OutputT]: ...
865866

866867
@abstractmethod
@@ -5137,6 +5138,46 @@ def _to_proto(self) -> temporalio.bridge.proto.common.VersioningIntent.ValueType
51375138
ServiceT = TypeVar("ServiceT")
51385139

51395140

5141+
class NexusOperationCancellationType(IntEnum):
5142+
"""Defines behavior of a Nexus operation when the caller workflow initiates cancellation.
5143+
5144+
Pass one of these values to :py:meth:`NexusClient.start_operation` to define cancellation
5145+
behavior.
5146+
5147+
To initiate cancellation, use :py:meth:`NexusOperationHandle.cancel` and then `await` the
5148+
operation handle. This will result in a :py:class:`exceptions.NexusOperationError`. The values
5149+
of this enum define what is guaranteed to have happened by that point.
5150+
"""
5151+
5152+
ABANDON = int(temporalio.bridge.proto.nexus.NexusOperationCancellationType.ABANDON)
5153+
"""Do not send any cancellation request to the operation handler; just report cancellation to the caller"""
5154+
5155+
TRY_CANCEL = int(
5156+
temporalio.bridge.proto.nexus.NexusOperationCancellationType.TRY_CANCEL
5157+
)
5158+
"""Send a cancellation request but immediately report cancellation to the caller. Note that this
5159+
does not guarantee that cancellation is delivered to the operation handler if the caller exits
5160+
before the delivery is done.
5161+
"""
5162+
5163+
# TODO(nexus-preview): core needs to be updated to handle
5164+
# NexusOperationCancelRequestCompleted and NexusOperationCancelRequestFailed
5165+
# see https://github.com/temporalio/sdk-core/issues/911
5166+
# WAIT_REQUESTED = int(
5167+
# temporalio.bridge.proto.nexus.NexusOperationCancellationType.WAIT_CANCELLATION_REQUESTED
5168+
# )
5169+
# """Send a cancellation request and wait for confirmation that the request was received.
5170+
# Does not wait for the operation to complete.
5171+
# """
5172+
5173+
WAIT_COMPLETED = int(
5174+
temporalio.bridge.proto.nexus.NexusOperationCancellationType.WAIT_CANCELLATION_COMPLETED
5175+
)
5176+
"""Send a cancellation request and wait for the operation to complete.
5177+
Note that the operation may not complete as cancelled (for example, if it catches the
5178+
:py:exc:`asyncio.CancelledError` resulting from the cancellation request)."""
5179+
5180+
51405181
class NexusClient(ABC, Generic[ServiceT]):
51415182
"""A client for invoking Nexus operations.
51425183
@@ -5167,6 +5208,7 @@ async def start_operation(
51675208
*,
51685209
output_type: Optional[Type[OutputT]] = None,
51695210
schedule_to_close_timeout: Optional[timedelta] = None,
5211+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
51705212
headers: Optional[Mapping[str, str]] = None,
51715213
) -> NexusOperationHandle[OutputT]: ...
51725214

@@ -5180,6 +5222,7 @@ async def start_operation(
51805222
*,
51815223
output_type: Optional[Type[OutputT]] = None,
51825224
schedule_to_close_timeout: Optional[timedelta] = None,
5225+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
51835226
headers: Optional[Mapping[str, str]] = None,
51845227
) -> NexusOperationHandle[OutputT]: ...
51855228

@@ -5196,6 +5239,7 @@ async def start_operation(
51965239
*,
51975240
output_type: Optional[Type[OutputT]] = None,
51985241
schedule_to_close_timeout: Optional[timedelta] = None,
5242+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
51995243
headers: Optional[Mapping[str, str]] = None,
52005244
) -> NexusOperationHandle[OutputT]: ...
52015245

@@ -5212,6 +5256,7 @@ async def start_operation(
52125256
*,
52135257
output_type: Optional[Type[OutputT]] = None,
52145258
schedule_to_close_timeout: Optional[timedelta] = None,
5259+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
52155260
headers: Optional[Mapping[str, str]] = None,
52165261
) -> NexusOperationHandle[OutputT]: ...
52175262

@@ -5228,6 +5273,7 @@ async def start_operation(
52285273
*,
52295274
output_type: Optional[Type[OutputT]] = None,
52305275
schedule_to_close_timeout: Optional[timedelta] = None,
5276+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
52315277
headers: Optional[Mapping[str, str]] = None,
52325278
) -> NexusOperationHandle[OutputT]: ...
52335279

@@ -5239,6 +5285,7 @@ async def start_operation(
52395285
*,
52405286
output_type: Optional[Type[OutputT]] = None,
52415287
schedule_to_close_timeout: Optional[timedelta] = None,
5288+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
52425289
headers: Optional[Mapping[str, str]] = None,
52435290
) -> Any:
52445291
"""Start a Nexus operation and return its handle.
@@ -5268,6 +5315,7 @@ async def execute_operation(
52685315
*,
52695316
output_type: Optional[Type[OutputT]] = None,
52705317
schedule_to_close_timeout: Optional[timedelta] = None,
5318+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
52715319
headers: Optional[Mapping[str, str]] = None,
52725320
) -> OutputT: ...
52735321

@@ -5281,6 +5329,7 @@ async def execute_operation(
52815329
*,
52825330
output_type: Optional[Type[OutputT]] = None,
52835331
schedule_to_close_timeout: Optional[timedelta] = None,
5332+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
52845333
headers: Optional[Mapping[str, str]] = None,
52855334
) -> OutputT: ...
52865335

@@ -5297,6 +5346,7 @@ async def execute_operation(
52975346
*,
52985347
output_type: Optional[Type[OutputT]] = None,
52995348
schedule_to_close_timeout: Optional[timedelta] = None,
5349+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
53005350
headers: Optional[Mapping[str, str]] = None,
53015351
) -> OutputT: ...
53025352

@@ -5316,6 +5366,7 @@ async def execute_operation(
53165366
*,
53175367
output_type: Optional[Type[OutputT]] = None,
53185368
schedule_to_close_timeout: Optional[timedelta] = None,
5369+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
53195370
headers: Optional[Mapping[str, str]] = None,
53205371
) -> OutputT: ...
53215372

@@ -5332,6 +5383,7 @@ async def execute_operation(
53325383
*,
53335384
output_type: Optional[Type[OutputT]] = None,
53345385
schedule_to_close_timeout: Optional[timedelta] = None,
5386+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
53355387
headers: Optional[Mapping[str, str]] = None,
53365388
) -> OutputT: ...
53375389

@@ -5343,6 +5395,7 @@ async def execute_operation(
53435395
*,
53445396
output_type: Optional[Type[OutputT]] = None,
53455397
schedule_to_close_timeout: Optional[timedelta] = None,
5398+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
53465399
headers: Optional[Mapping[str, str]] = None,
53475400
) -> Any:
53485401
"""Execute a Nexus operation and return its result.
@@ -5394,6 +5447,7 @@ async def start_operation(
53945447
*,
53955448
output_type: Optional[Type] = None,
53965449
schedule_to_close_timeout: Optional[timedelta] = None,
5450+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
53975451
headers: Optional[Mapping[str, str]] = None,
53985452
) -> Any:
53995453
return (
@@ -5404,6 +5458,7 @@ async def start_operation(
54045458
input=input,
54055459
output_type=output_type,
54065460
schedule_to_close_timeout=schedule_to_close_timeout,
5461+
cancellation_type=cancellation_type,
54075462
headers=headers,
54085463
)
54095464
)
@@ -5415,13 +5470,15 @@ async def execute_operation(
54155470
*,
54165471
output_type: Optional[Type] = None,
54175472
schedule_to_close_timeout: Optional[timedelta] = None,
5473+
cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED,
54185474
headers: Optional[Mapping[str, str]] = None,
54195475
) -> Any:
54205476
handle = await self.start_operation(
54215477
operation,
54225478
input,
54235479
output_type=output_type,
54245480
schedule_to_close_timeout=schedule_to_close_timeout,
5481+
cancellation_type=cancellation_type,
54255482
headers=headers,
54265483
)
54275484
return await handle

tests/helpers/__init__.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar
88

99
from temporalio.api.common.v1 import WorkflowExecution
10+
from temporalio.api.enums.v1 import EventType as EventType
1011
from temporalio.api.enums.v1 import IndexedValueType
12+
from temporalio.api.history.v1 import HistoryEvent
1113
from temporalio.api.operatorservice.v1 import (
1214
AddSearchAttributesRequest,
1315
ListSearchAttributesRequest,
@@ -287,3 +289,40 @@ async def check_unpaused() -> bool:
287289
return not info.paused
288290

289291
await assert_eventually(check_unpaused)
292+
293+
294+
async def print_history(handle: WorkflowHandle):
295+
i = 1
296+
async for evt in handle.fetch_history_events():
297+
event = EventType.Name(evt.event_type).removeprefix("EVENT_TYPE_")
298+
print(f"{i:2}: {event}")
299+
i += 1
300+
301+
302+
async def print_interleaved_histories(*handles: WorkflowHandle) -> None:
303+
"""
304+
Print the interleaved history events from multiple workflow handles in columns.
305+
"""
306+
all_events: list[tuple[WorkflowHandle, HistoryEvent, int]] = []
307+
for handle in handles:
308+
event_num = 1
309+
async for event in handle.fetch_history_events():
310+
all_events.append((handle, event, event_num))
311+
event_num += 1
312+
all_events.sort(key=lambda item: item[1].event_time.ToDatetime())
313+
col_width = 40
314+
315+
def _format_row(items: list[str], truncate: bool = False) -> str:
316+
if truncate:
317+
items = [item[: col_width - 3] for item in items]
318+
return " | ".join(f"{item:<{col_width - 3}}" for item in items)
319+
320+
headers = [handle.id for handle in handles]
321+
print("\n" + _format_row(headers, truncate=True))
322+
print("-" * (col_width * len(handles) + len(handles) - 1))
323+
for handle, event, event_num in all_events:
324+
event_type = EventType.Name(event.event_type).removeprefix("EVENT_TYPE_")
325+
row = [""] * len(handles)
326+
col_idx = handles.index(handle)
327+
row[col_idx] = f"{event_num:2}: {event_type[: col_width - 5]}"
328+
print(_format_row(row))

0 commit comments

Comments
 (0)