1111import temporalio .nexus ._operation_handlers
1212from temporalio import exceptions , nexus , workflow
1313from temporalio .api .enums .v1 import EventType
14- from temporalio .api .history .v1 import HistoryEvent
1514from temporalio .client import (
1615 WithStartWorkflowOperation ,
1716 WorkflowExecutionStatus ,
2120from temporalio .testing import WorkflowEnvironment
2221from temporalio .worker import Worker
2322from tests .helpers .nexus import create_nexus_endpoint , make_nexus_endpoint_name
23+ from tests .nexus .test_workflow_caller_cancellation_types import (
24+ assert_event_subsequence ,
25+ get_event_time ,
26+ has_event ,
27+ )
2428
2529
2630@dataclass
@@ -257,13 +261,13 @@ async def check_behavior_for_abandon(
257261 assert result .error_type == "NexusOperationError"
258262 assert result .error_cause_type == "CancelledError"
259263
260- await _assert_event_subsequence (
264+ await assert_event_subsequence (
261265 [
262266 (caller_wf , EventType .EVENT_TYPE_WORKFLOW_EXECUTION_STARTED ),
263267 (caller_wf , EventType .EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED ),
264268 ]
265269 )
266- assert not await _has_event (
270+ assert not await has_event (
267271 caller_wf ,
268272 EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED ,
269273 )
@@ -280,18 +284,18 @@ async def check_behavior_for_try_cancel(
280284 assert result .error_cause_type == "CancelledError"
281285
282286 caller_op_future_resolved = test_context .caller_op_future_resolved .result ()
283- await _assert_event_subsequence (
287+ await assert_event_subsequence (
284288 [
285289 (caller_wf , EventType .EVENT_TYPE_WORKFLOW_EXECUTION_STARTED ),
286290 (caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED ),
287291 (caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED ),
288292 ]
289293 )
290- op_cancel_requested_event = await _get_event_time (
294+ op_cancel_requested_event = await get_event_time (
291295 caller_wf ,
292296 EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED ,
293297 )
294- op_cancel_request_failed_event = await _get_event_time (
298+ op_cancel_request_failed_event = await get_event_time (
295299 caller_wf ,
296300 EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED ,
297301 )
@@ -311,7 +315,7 @@ async def check_behavior_for_wait_cancellation_requested(
311315 result = await caller_wf .result ()
312316 assert result .error_type == "NexusOperationError"
313317 assert result .error_cause_type == "HandlerError"
314- await _assert_event_subsequence (
318+ await assert_event_subsequence (
315319 [
316320 (caller_wf , EventType .EVENT_TYPE_WORKFLOW_EXECUTION_STARTED ),
317321 (caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED ),
@@ -320,7 +324,7 @@ async def check_behavior_for_wait_cancellation_requested(
320324 ]
321325 )
322326 caller_op_future_resolved = test_context .caller_op_future_resolved .result ()
323- op_cancel_request_failed = await _get_event_time (
327+ op_cancel_request_failed = await get_event_time (
324328 caller_wf ,
325329 EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED ,
326330 )
@@ -335,7 +339,7 @@ async def check_behavior_for_wait_cancellation_completed(
335339 await caller_wf .signal (CallerWorkflow .release )
336340 result = await caller_wf .result ()
337341 assert not result .error_type
338- await _assert_event_subsequence (
342+ await assert_event_subsequence (
339343 [
340344 (caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED ),
341345 (caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED ),
@@ -344,72 +348,8 @@ async def check_behavior_for_wait_cancellation_completed(
344348 ]
345349 )
346350 caller_op_future_resolved = test_context .caller_op_future_resolved .result ()
347- handler_wf_completed = await _get_event_time (
351+ handler_wf_completed = await get_event_time (
348352 handler_wf ,
349353 EventType .EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED ,
350354 )
351355 assert handler_wf_completed < caller_op_future_resolved
352-
353-
354- async def _has_event (wf_handle : WorkflowHandle , event_type : EventType .ValueType ):
355- async for e in wf_handle .fetch_history_events ():
356- if e .event_type == event_type :
357- return True
358- return False
359-
360-
361- async def _get_event_time (
362- wf_handle : WorkflowHandle ,
363- event_type : EventType .ValueType ,
364- ) -> datetime :
365- async for event in wf_handle .fetch_history_events ():
366- if event .event_type == event_type :
367- return event .event_time .ToDatetime ().replace (tzinfo = timezone .utc )
368- event_type_name = EventType .Name (event_type ).removeprefix ("EVENT_TYPE_" )
369- assert False , f"Event { event_type_name } not found in { wf_handle .id } "
370-
371-
372- async def _assert_event_subsequence (
373- expected_events : list [tuple [WorkflowHandle , EventType .ValueType ]],
374- ) -> None :
375- """
376- Given a sequence of (WorkflowHandle, EventType) pairs, assert that the sorted sequence of events
377- from both workflows contains that subsequence.
378- """
379-
380- def _event_time (
381- item : tuple [WorkflowHandle , HistoryEvent ],
382- ) -> datetime :
383- return item [1 ].event_time .ToDatetime ()
384-
385- all_events = []
386- handles = {h for h , _ in expected_events }
387- for h in handles :
388- async for e in h .fetch_history_events ():
389- all_events .append ((h , e ))
390- _all_events = iter (sorted (all_events , key = _event_time ))
391- _expected_events = iter (expected_events )
392-
393- previous_expected_handle , previous_expected_event_type_name = None , None
394- for expected_handle , expected_event_type in _expected_events :
395- expected_event_type_name = EventType .Name (expected_event_type ).removeprefix (
396- "EVENT_TYPE_"
397- )
398- has_expected = next (
399- (
400- (h , e )
401- for h , e in _all_events
402- if h == expected_handle and e .event_type == expected_event_type
403- ),
404- None ,
405- )
406- if not has_expected :
407- if previous_expected_handle is not None :
408- prefix = f"After { previous_expected_event_type_name } in { previous_expected_handle .id } , "
409- else :
410- prefix = ""
411- pytest .fail (
412- f"{ prefix } expected { expected_event_type_name } in { expected_handle .id } "
413- )
414- previous_expected_event_type_name = expected_event_type_name
415- previous_expected_handle = expected_handle
0 commit comments