Skip to content

Commit 251084d

Browse files
authored
Fix nexus cancellation type test flake (#1057)
* Restrict CI and add debugging output * Run multiple times * AI output * Add future resolution to output * Use workflow.now() for op future resolution time * Fix test assertions * Avoid using global future * Revert "AI output" This reverts commit 3066ea0. * Revert "Run multiple times" This reverts commit 784bf51. * Revert "Restrict CI and add debugging output" This reverts commit 6abdbb6. * Rename variable * Add type annotation
1 parent 8bb0b80 commit 251084d

File tree

2 files changed

+47
-40
lines changed

2 files changed

+47
-40
lines changed

tests/nexus/test_workflow_caller_cancellation_types.py

Lines changed: 36 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@
2828
class TestContext:
2929
__test__ = False
3030
cancellation_type: workflow.NexusOperationCancellationType
31-
caller_op_future_resolved: asyncio.Future[datetime] = field(
32-
default_factory=asyncio.Future
33-
)
31+
caller_workflow_id: str
3432
cancel_handler_released: asyncio.Future[datetime] = field(
3533
default_factory=asyncio.Future
3634
)
@@ -96,7 +94,15 @@ async def cancel(
9694
# by the caller server. At that point, the caller server will write
9795
# NexusOperationCancelRequestCompleted. For TRY_CANCEL we want to prove that the nexus
9896
# op handle future can be resolved as cancelled before any of that.
99-
await test_context.caller_op_future_resolved
97+
caller_wf: WorkflowHandle[Any, CancellationResult] = (
98+
nexus.client().get_workflow_handle_for(
99+
CallerWorkflow.run,
100+
workflow_id=test_context.caller_workflow_id,
101+
)
102+
)
103+
await caller_wf.execute_update(
104+
CallerWorkflow.wait_caller_op_future_resolved
105+
)
100106
test_context.cancel_handler_released.set_result(datetime.now(timezone.utc))
101107
await super().cancel(ctx, token)
102108

@@ -117,6 +123,7 @@ class Input:
117123
@dataclass
118124
class CancellationResult:
119125
operation_token: str
126+
caller_op_future_resolved: datetime
120127

121128

122129
@workflow.defn(sandboxed=False)
@@ -129,6 +136,7 @@ def __init__(self, input: Input):
129136
)
130137
self.released = False
131138
self.operation_token: Optional[str] = None
139+
self.caller_op_future_resolved: asyncio.Future[datetime] = asyncio.Future()
132140

133141
@workflow.signal
134142
def release(self):
@@ -140,6 +148,10 @@ async def get_operation_token(self) -> str:
140148
assert self.operation_token
141149
return self.operation_token
142150

151+
@workflow.update
152+
async def wait_caller_op_future_resolved(self) -> None:
153+
await self.caller_op_future_resolved
154+
143155
@workflow.run
144156
async def run(self, input: Input) -> CancellationResult:
145157
op_handle = await (
@@ -188,9 +200,7 @@ async def run(self, input: Input) -> CancellationResult:
188200
try:
189201
await op_handle
190202
except exceptions.NexusOperationError:
191-
test_context.caller_op_future_resolved.set_result(
192-
datetime.now(timezone.utc)
193-
)
203+
self.caller_op_future_resolved.set_result(workflow.now())
194204
assert op_handle.operation_token
195205
if input.cancellation_type in [
196206
workflow.NexusOperationCancellationType.TRY_CANCEL,
@@ -210,6 +220,7 @@ async def run(self, input: Input) -> CancellationResult:
210220
await workflow.wait_condition(lambda: self.released)
211221
return CancellationResult(
212222
operation_token=op_handle.operation_token,
223+
caller_op_future_resolved=self.caller_op_future_resolved.result(),
213224
)
214225
else:
215226
pytest.fail("Expected NexusOperationError")
@@ -233,7 +244,10 @@ async def test_cancellation_type(
233244

234245
cancellation_type = workflow.NexusOperationCancellationType[cancellation_type_name]
235246
global test_context
236-
test_context = TestContext(cancellation_type=cancellation_type)
247+
test_context = TestContext(
248+
cancellation_type=cancellation_type,
249+
caller_workflow_id="caller-wf-" + str(uuid.uuid4()),
250+
)
237251

238252
client = env.client
239253

@@ -253,7 +267,7 @@ async def test_cancellation_type(
253267
endpoint=make_nexus_endpoint_name(worker.task_queue),
254268
cancellation_type=cancellation_type,
255269
),
256-
id="caller-wf-" + str(uuid.uuid4()),
270+
id=test_context.caller_workflow_id,
257271
task_queue=worker.task_queue,
258272
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
259273
)
@@ -314,8 +328,12 @@ async def check_behavior_for_try_cancel(
314328
) -> None:
315329
"""
316330
Check that a cancellation request is sent and the caller workflow nexus op future is unblocked
317-
as cancelled before the cancel handler returns (i.e. before the
318-
NexusOperationCancelRequestCompleted in the caller workflow history).
331+
as cancelled before the caller server writes CANCEL_REQUESTED.
332+
333+
There is a race between (a) the caller server writing CANCEL_REQUEST_COMPLETED in response to
334+
the cancel handler returning, and (b) the caller server writing CANCELED in response to the
335+
handler workflow exiting as canceled. If (b) happens first then (a) may never happen, therefore
336+
we do not make any assertions regarding CANCEL_REQUEST_COMPLETED.
319337
"""
320338
try:
321339
await handler_wf.result()
@@ -324,31 +342,21 @@ async def check_behavior_for_try_cancel(
324342
else:
325343
pytest.fail("Expected WorkflowFailureError")
326344
await caller_wf.signal(CallerWorkflow.release)
327-
await caller_wf.result()
345+
result = await caller_wf.result()
328346

329347
handler_status = (await handler_wf.describe()).status
330348
assert handler_status == WorkflowExecutionStatus.CANCELED
331-
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
332349
await assert_event_subsequence(
333350
[
334351
(caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED),
335-
(caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED),
336352
(caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED),
337353
]
338354
)
339355
op_cancel_requested_event = await get_event_time(
340356
caller_wf,
341357
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED,
342358
)
343-
op_cancel_request_completed_event = await get_event_time(
344-
caller_wf,
345-
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED,
346-
)
347-
assert (
348-
caller_op_future_resolved
349-
< op_cancel_requested_event
350-
< op_cancel_request_completed_event
351-
)
359+
assert result.caller_op_future_resolved < op_cancel_requested_event
352360

353361

354362
async def check_behavior_for_wait_cancellation_requested(
@@ -369,7 +377,7 @@ async def check_behavior_for_wait_cancellation_requested(
369377
pytest.fail("Expected WorkflowFailureError")
370378

371379
await caller_wf.signal(CallerWorkflow.release)
372-
await caller_wf.result()
380+
result = await caller_wf.result()
373381

374382
handler_status = (await handler_wf.describe()).status
375383
assert handler_status == WorkflowExecutionStatus.CANCELED
@@ -380,7 +388,6 @@ async def check_behavior_for_wait_cancellation_requested(
380388
(caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED),
381389
]
382390
)
383-
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
384391
op_cancel_request_completed = await get_event_time(
385392
caller_wf,
386393
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED,
@@ -389,7 +396,7 @@ async def check_behavior_for_wait_cancellation_requested(
389396
handler_wf,
390397
EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED,
391398
)
392-
assert op_cancel_request_completed < caller_op_future_resolved < op_canceled
399+
assert op_cancel_request_completed < result.caller_op_future_resolved < op_canceled
393400

394401

395402
async def check_behavior_for_wait_cancellation_completed(
@@ -411,7 +418,7 @@ async def check_behavior_for_wait_cancellation_completed(
411418
assert handler_status == WorkflowExecutionStatus.CANCELED
412419

413420
await caller_wf.signal(CallerWorkflow.release)
414-
await caller_wf.result()
421+
result = await caller_wf.result()
415422

416423
await assert_event_subsequence(
417424
[
@@ -426,12 +433,11 @@ async def check_behavior_for_wait_cancellation_completed(
426433
(caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED),
427434
]
428435
)
429-
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
430-
handler_wf_canceled_event_time = await get_event_time(
436+
handler_wf_canceled_event = await get_event_time(
431437
handler_wf,
432438
EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED,
433439
)
434-
assert caller_op_future_resolved > handler_wf_canceled_event_time
440+
assert handler_wf_canceled_event < result.caller_op_future_resolved
435441

436442

437443
async def has_event(wf_handle: WorkflowHandle, event_type: EventType.ValueType):

tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@
3535
class TestContext:
3636
__test__ = False
3737
cancellation_type: workflow.NexusOperationCancellationType
38-
caller_op_future_resolved: asyncio.Future[datetime] = field(
39-
default_factory=asyncio.Future
40-
)
4138
cancel_handler_released: asyncio.Future[datetime] = field(
4239
default_factory=asyncio.Future
4340
)
@@ -129,6 +126,7 @@ class Input:
129126
@dataclass
130127
class CancellationResult:
131128
operation_token: str
129+
caller_op_future_resolved: datetime
132130
error_type: Optional[str] = None
133131
error_cause_type: Optional[str] = None
134132

@@ -143,6 +141,7 @@ def __init__(self, input: Input):
143141
)
144142
self.released = False
145143
self.operation_token: Optional[str] = None
144+
self.caller_op_future_resolved: asyncio.Future[datetime] = asyncio.Future()
146145

147146
@workflow.signal
148147
def release(self):
@@ -184,13 +183,14 @@ async def run(self, input: Input) -> CancellationResult:
184183
error_type = err.__class__.__name__
185184
error_cause_type = err.__cause__.__class__.__name__
186185

187-
test_context.caller_op_future_resolved.set_result(datetime.now(timezone.utc))
186+
self.caller_op_future_resolved.set_result(workflow.now())
188187
assert op_handle.operation_token
189188
await workflow.wait_condition(lambda: self.released)
190189
return CancellationResult(
191190
operation_token=op_handle.operation_token,
192191
error_type=error_type,
193192
error_cause_type=error_cause_type,
193+
caller_op_future_resolved=self.caller_op_future_resolved.result(),
194194
)
195195

196196

@@ -300,7 +300,6 @@ async def check_behavior_for_try_cancel(
300300
assert result.error_type == "NexusOperationError"
301301
assert result.error_cause_type == "CancelledError"
302302

303-
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
304303
await assert_event_subsequence(
305304
[
306305
(caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED),
@@ -317,7 +316,7 @@ async def check_behavior_for_try_cancel(
317316
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED,
318317
)
319318
assert (
320-
caller_op_future_resolved
319+
result.caller_op_future_resolved
321320
< op_cancel_requested_event
322321
< op_cancel_request_failed_event
323322
)
@@ -341,7 +340,6 @@ async def check_behavior_for_wait_cancellation_requested(
341340
(caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED),
342341
]
343342
)
344-
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
345343
op_cancel_request_failed = await get_event_time(
346344
caller_wf,
347345
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED,
@@ -350,7 +348,11 @@ async def check_behavior_for_wait_cancellation_requested(
350348
handler_wf,
351349
EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED,
352350
)
353-
assert op_cancel_request_failed < caller_op_future_resolved < handler_wf_completed
351+
assert (
352+
op_cancel_request_failed
353+
< result.caller_op_future_resolved
354+
< handler_wf_completed
355+
)
354356

355357

356358
async def check_behavior_for_wait_cancellation_completed(
@@ -373,9 +375,8 @@ async def check_behavior_for_wait_cancellation_completed(
373375
(caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_COMPLETED),
374376
]
375377
)
376-
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
377378
handler_wf_completed = await get_event_time(
378379
handler_wf,
379380
EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED,
380381
)
381-
assert handler_wf_completed < caller_op_future_resolved
382+
assert handler_wf_completed < result.caller_op_future_resolved

0 commit comments

Comments
 (0)