2828class TestContext :
2929 __test__ = False
3030 cancellation_type : workflow .NexusOperationCancellationType
31- caller_op_future_resolved : asyncio .Future [datetime ] = field (
32- default_factory = asyncio .Future
33- )
31+ caller_workflow_id : str
3432 cancel_handler_released : asyncio .Future [datetime ] = field (
3533 default_factory = asyncio .Future
3634 )
@@ -96,7 +94,15 @@ async def cancel(
9694 # by the caller server. At that point, the caller server will write
9795 # NexusOperationCancelRequestCompleted. For TRY_CANCEL we want to prove that the nexus
9896 # op handle future can be resolved as cancelled before any of that.
99- await test_context .caller_op_future_resolved
97+ caller_wf : WorkflowHandle [Any , CancellationResult ] = (
98+ nexus .client ().get_workflow_handle_for (
99+ CallerWorkflow .run ,
100+ workflow_id = test_context .caller_workflow_id ,
101+ )
102+ )
103+ await caller_wf .execute_update (
104+ CallerWorkflow .wait_caller_op_future_resolved
105+ )
100106 test_context .cancel_handler_released .set_result (datetime .now (timezone .utc ))
101107 await super ().cancel (ctx , token )
102108
@@ -117,6 +123,7 @@ class Input:
117123@dataclass
118124class CancellationResult :
119125 operation_token : str
126+ caller_op_future_resolved : datetime
120127
121128
122129@workflow .defn (sandboxed = False )
@@ -129,6 +136,7 @@ def __init__(self, input: Input):
129136 )
130137 self .released = False
131138 self .operation_token : Optional [str ] = None
139+ self .caller_op_future_resolved : asyncio .Future [datetime ] = asyncio .Future ()
132140
133141 @workflow .signal
134142 def release (self ):
@@ -140,6 +148,10 @@ async def get_operation_token(self) -> str:
140148 assert self .operation_token
141149 return self .operation_token
142150
151+ @workflow .update
152+ async def wait_caller_op_future_resolved (self ) -> None :
153+ await self .caller_op_future_resolved
154+
143155 @workflow .run
144156 async def run (self , input : Input ) -> CancellationResult :
145157 op_handle = await (
@@ -188,9 +200,7 @@ async def run(self, input: Input) -> CancellationResult:
188200 try :
189201 await op_handle
190202 except exceptions .NexusOperationError :
191- test_context .caller_op_future_resolved .set_result (
192- datetime .now (timezone .utc )
193- )
203+ self .caller_op_future_resolved .set_result (workflow .now ())
194204 assert op_handle .operation_token
195205 if input .cancellation_type in [
196206 workflow .NexusOperationCancellationType .TRY_CANCEL ,
@@ -210,6 +220,7 @@ async def run(self, input: Input) -> CancellationResult:
210220 await workflow .wait_condition (lambda : self .released )
211221 return CancellationResult (
212222 operation_token = op_handle .operation_token ,
223+ caller_op_future_resolved = self .caller_op_future_resolved .result (),
213224 )
214225 else :
215226 pytest .fail ("Expected NexusOperationError" )
@@ -233,7 +244,10 @@ async def test_cancellation_type(
233244
234245 cancellation_type = workflow .NexusOperationCancellationType [cancellation_type_name ]
235246 global test_context
236- test_context = TestContext (cancellation_type = cancellation_type )
247+ test_context = TestContext (
248+ cancellation_type = cancellation_type ,
249+ caller_workflow_id = "caller-wf-" + str (uuid .uuid4 ()),
250+ )
237251
238252 client = env .client
239253
@@ -253,7 +267,7 @@ async def test_cancellation_type(
253267 endpoint = make_nexus_endpoint_name (worker .task_queue ),
254268 cancellation_type = cancellation_type ,
255269 ),
256- id = "caller-wf-" + str ( uuid . uuid4 ()) ,
270+ id = test_context . caller_workflow_id ,
257271 task_queue = worker .task_queue ,
258272 id_conflict_policy = WorkflowIDConflictPolicy .FAIL ,
259273 )
@@ -314,8 +328,12 @@ async def check_behavior_for_try_cancel(
314328) -> None :
315329 """
316330 Check that a cancellation request is sent and the caller workflow nexus op future is unblocked
317- as cancelled before the cancel handler returns (i.e. before the
318- NexusOperationCancelRequestCompleted in the caller workflow history).
331+ as cancelled before the caller server writes CANCEL_REQUESTED.
332+
333+ There is a race between (a) the caller server writing CANCEL_REQUEST_COMPLETED in response to
334+ the cancel handler returning, and (b) the caller server writing CANCELED in response to the
335+ handler workflow exiting as canceled. If (b) happens first then (a) may never happen, therefore
336+ we do not make any assertions regarding CANCEL_REQUEST_COMPLETED.
319337 """
320338 try :
321339 await handler_wf .result ()
@@ -324,31 +342,21 @@ async def check_behavior_for_try_cancel(
324342 else :
325343 pytest .fail ("Expected WorkflowFailureError" )
326344 await caller_wf .signal (CallerWorkflow .release )
327- await caller_wf .result ()
345+ result = await caller_wf .result ()
328346
329347 handler_status = (await handler_wf .describe ()).status
330348 assert handler_status == WorkflowExecutionStatus .CANCELED
331- caller_op_future_resolved = test_context .caller_op_future_resolved .result ()
332349 await assert_event_subsequence (
333350 [
334351 (caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED ),
335- (caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED ),
336352 (caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCELED ),
337353 ]
338354 )
339355 op_cancel_requested_event = await get_event_time (
340356 caller_wf ,
341357 EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED ,
342358 )
343- op_cancel_request_completed_event = await get_event_time (
344- caller_wf ,
345- EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED ,
346- )
347- assert (
348- caller_op_future_resolved
349- < op_cancel_requested_event
350- < op_cancel_request_completed_event
351- )
359+ assert result .caller_op_future_resolved < op_cancel_requested_event
352360
353361
354362async def check_behavior_for_wait_cancellation_requested (
@@ -369,7 +377,7 @@ async def check_behavior_for_wait_cancellation_requested(
369377 pytest .fail ("Expected WorkflowFailureError" )
370378
371379 await caller_wf .signal (CallerWorkflow .release )
372- await caller_wf .result ()
380+ result = await caller_wf .result ()
373381
374382 handler_status = (await handler_wf .describe ()).status
375383 assert handler_status == WorkflowExecutionStatus .CANCELED
@@ -380,7 +388,6 @@ async def check_behavior_for_wait_cancellation_requested(
380388 (caller_wf , EventType .EVENT_TYPE_NEXUS_OPERATION_CANCELED ),
381389 ]
382390 )
383- caller_op_future_resolved = test_context .caller_op_future_resolved .result ()
384391 op_cancel_request_completed = await get_event_time (
385392 caller_wf ,
386393 EventType .EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED ,
@@ -389,7 +396,7 @@ async def check_behavior_for_wait_cancellation_requested(
389396 handler_wf ,
390397 EventType .EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED ,
391398 )
392- assert op_cancel_request_completed < caller_op_future_resolved < op_canceled
399+ assert op_cancel_request_completed < result . caller_op_future_resolved < op_canceled
393400
394401
395402async def check_behavior_for_wait_cancellation_completed (
@@ -411,7 +418,7 @@ async def check_behavior_for_wait_cancellation_completed(
411418 assert handler_status == WorkflowExecutionStatus .CANCELED
412419
413420 await caller_wf .signal (CallerWorkflow .release )
414- await caller_wf .result ()
421+ result = await caller_wf .result ()
415422
416423 await assert_event_subsequence (
417424 [
@@ -426,12 +433,11 @@ async def check_behavior_for_wait_cancellation_completed(
426433 (caller_wf , EventType .EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED ),
427434 ]
428435 )
429- caller_op_future_resolved = test_context .caller_op_future_resolved .result ()
430- handler_wf_canceled_event_time = await get_event_time (
436+ handler_wf_canceled_event = await get_event_time (
431437 handler_wf ,
432438 EventType .EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED ,
433439 )
434- assert caller_op_future_resolved > handler_wf_canceled_event_time
440+ assert handler_wf_canceled_event < result . caller_op_future_resolved
435441
436442
437443async def has_event (wf_handle : WorkflowHandle , event_type : EventType .ValueType ):
0 commit comments