Skip to content

Commit efedb12

Browse files
committed
Avoid using global future
1 parent 9617092 commit efedb12

File tree

2 files changed

+40
-28
lines changed

2 files changed

+40
-28
lines changed

tests/nexus/test_workflow_caller_cancellation_types.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,7 @@
2929
class TestContext:
3030
__test__ = False
3131
cancellation_type: workflow.NexusOperationCancellationType
32-
caller_op_future_resolved: asyncio.Future[datetime] = field(
33-
default_factory=asyncio.Future
34-
)
32+
caller_workflow_id: str
3533
cancel_handler_released: asyncio.Future[datetime] = field(
3634
default_factory=asyncio.Future
3735
)
@@ -97,7 +95,13 @@ async def cancel(
9795
# by the caller server. At that point, the caller server will write
9896
# NexusOperationCancelRequestCompleted. For TRY_CANCEL we want to prove that the nexus
9997
# op handle future can be resolved as cancelled before any of that.
100-
await test_context.caller_op_future_resolved
98+
caller_wf = nexus.client().get_workflow_handle_for(
99+
CallerWorkflow.run,
100+
workflow_id=test_context.caller_workflow_id,
101+
)
102+
await caller_wf.execute_update(
103+
CallerWorkflow.wait_caller_op_future_resolved
104+
)
101105
test_context.cancel_handler_released.set_result(datetime.now(timezone.utc))
102106
await super().cancel(ctx, token)
103107

@@ -118,6 +122,7 @@ class Input:
118122
@dataclass
119123
class CancellationResult:
120124
operation_token: str
125+
caller_op_future_resolved: datetime
121126

122127

123128
@workflow.defn(sandboxed=False)
@@ -130,6 +135,7 @@ def __init__(self, input: Input):
130135
)
131136
self.released = False
132137
self.operation_token: Optional[str] = None
138+
self.caller_op_future_resolved: asyncio.Future[datetime] = asyncio.Future()
133139

134140
@workflow.signal
135141
def release(self):
@@ -141,6 +147,10 @@ async def get_operation_token(self) -> str:
141147
assert self.operation_token
142148
return self.operation_token
143149

150+
@workflow.update
151+
async def wait_caller_op_future_resolved(self) -> None:
152+
await self.caller_op_future_resolved
153+
144154
@workflow.run
145155
async def run(self, input: Input) -> CancellationResult:
146156
op_handle = await (
@@ -189,7 +199,7 @@ async def run(self, input: Input) -> CancellationResult:
189199
try:
190200
await op_handle
191201
except exceptions.NexusOperationError:
192-
test_context.caller_op_future_resolved.set_result(workflow.now())
202+
self.caller_op_future_resolved.set_result(workflow.now())
193203
assert op_handle.operation_token
194204
if input.cancellation_type in [
195205
workflow.NexusOperationCancellationType.TRY_CANCEL,
@@ -209,6 +219,7 @@ async def run(self, input: Input) -> CancellationResult:
209219
await workflow.wait_condition(lambda: self.released)
210220
return CancellationResult(
211221
operation_token=op_handle.operation_token,
222+
caller_op_future_resolved=self.caller_op_future_resolved.result(),
212223
)
213224
else:
214225
pytest.fail("Expected NexusOperationError")
@@ -232,7 +243,10 @@ async def test_cancellation_type(
232243

233244
cancellation_type = workflow.NexusOperationCancellationType[cancellation_type_name]
234245
global test_context
235-
test_context = TestContext(cancellation_type=cancellation_type)
246+
test_context = TestContext(
247+
cancellation_type=cancellation_type,
248+
caller_workflow_id="caller-wf-" + str(uuid.uuid4()),
249+
)
236250

237251
client = env.client
238252

@@ -252,7 +266,7 @@ async def test_cancellation_type(
252266
endpoint=make_nexus_endpoint_name(worker.task_queue),
253267
cancellation_type=cancellation_type,
254268
),
255-
id="caller-wf-" + str(uuid.uuid4()),
269+
id=test_context.caller_workflow_id,
256270
task_queue=worker.task_queue,
257271
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
258272
)
@@ -327,19 +341,18 @@ async def check_behavior_for_try_cancel(
327341
else:
328342
pytest.fail("Expected WorkflowFailureError")
329343
await caller_wf.signal(CallerWorkflow.release)
330-
await caller_wf.result()
344+
result = await caller_wf.result()
331345

332346
handler_status = (await handler_wf.describe()).status
333347
assert handler_status == WorkflowExecutionStatus.CANCELED
334-
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
335348

336349
await print_interleaved_histories(
337350
[caller_wf, handler_wf],
338351
extra_events=[
339352
(
340353
caller_wf,
341354
"Caller op future resolved",
342-
caller_op_future_resolved,
355+
result.caller_op_future_resolved,
343356
)
344357
],
345358
)
@@ -354,7 +367,7 @@ async def check_behavior_for_try_cancel(
354367
caller_wf,
355368
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED,
356369
)
357-
assert caller_op_future_resolved < op_cancel_requested_event
370+
assert result.caller_op_future_resolved < op_cancel_requested_event
358371

359372

360373
async def check_behavior_for_wait_cancellation_requested(
@@ -375,7 +388,7 @@ async def check_behavior_for_wait_cancellation_requested(
375388
pytest.fail("Expected WorkflowFailureError")
376389

377390
await caller_wf.signal(CallerWorkflow.release)
378-
await caller_wf.result()
391+
result = await caller_wf.result()
379392

380393
handler_status = (await handler_wf.describe()).status
381394
assert handler_status == WorkflowExecutionStatus.CANCELED
@@ -386,7 +399,6 @@ async def check_behavior_for_wait_cancellation_requested(
386399
(caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED),
387400
]
388401
)
389-
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
390402
op_cancel_request_completed = await get_event_time(
391403
caller_wf,
392404
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED,
@@ -395,7 +407,7 @@ async def check_behavior_for_wait_cancellation_requested(
395407
handler_wf,
396408
EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED,
397409
)
398-
assert op_cancel_request_completed < caller_op_future_resolved < op_canceled
410+
assert op_cancel_request_completed < result.caller_op_future_resolved < op_canceled
399411

400412

401413
async def check_behavior_for_wait_cancellation_completed(
@@ -417,7 +429,7 @@ async def check_behavior_for_wait_cancellation_completed(
417429
assert handler_status == WorkflowExecutionStatus.CANCELED
418430

419431
await caller_wf.signal(CallerWorkflow.release)
420-
await caller_wf.result()
432+
result = await caller_wf.result()
421433

422434
await assert_event_subsequence(
423435
[
@@ -432,12 +444,11 @@ async def check_behavior_for_wait_cancellation_completed(
432444
(caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED),
433445
]
434446
)
435-
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
436447
handler_wf_canceled_event_time = await get_event_time(
437448
handler_wf,
438449
EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED,
439450
)
440-
assert caller_op_future_resolved > handler_wf_canceled_event_time
451+
assert result.caller_op_future_resolved > handler_wf_canceled_event_time
441452

442453

443454
async def has_event(wf_handle: WorkflowHandle, event_type: EventType.ValueType):

tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@
3636
class TestContext:
3737
__test__ = False
3838
cancellation_type: workflow.NexusOperationCancellationType
39-
caller_op_future_resolved: asyncio.Future[datetime] = field(
40-
default_factory=asyncio.Future
41-
)
4239
cancel_handler_released: asyncio.Future[datetime] = field(
4340
default_factory=asyncio.Future
4441
)
@@ -130,6 +127,7 @@ class Input:
130127
@dataclass
131128
class CancellationResult:
132129
operation_token: str
130+
caller_op_future_resolved: datetime
133131
error_type: Optional[str] = None
134132
error_cause_type: Optional[str] = None
135133

@@ -144,6 +142,7 @@ def __init__(self, input: Input):
144142
)
145143
self.released = False
146144
self.operation_token: Optional[str] = None
145+
self.caller_op_future_resolved: asyncio.Future[datetime] = asyncio.Future()
147146

148147
@workflow.signal
149148
def release(self):
@@ -185,13 +184,14 @@ async def run(self, input: Input) -> CancellationResult:
185184
error_type = err.__class__.__name__
186185
error_cause_type = err.__cause__.__class__.__name__
187186

188-
test_context.caller_op_future_resolved.set_result(workflow.now())
187+
self.caller_op_future_resolved.set_result(workflow.now())
189188
assert op_handle.operation_token
190189
await workflow.wait_condition(lambda: self.released)
191190
return CancellationResult(
192191
operation_token=op_handle.operation_token,
193192
error_type=error_type,
194193
error_cause_type=error_cause_type,
194+
caller_op_future_resolved=self.caller_op_future_resolved.result(),
195195
)
196196

197197

@@ -301,14 +301,13 @@ async def check_behavior_for_try_cancel(
301301
assert result.error_type == "NexusOperationError"
302302
assert result.error_cause_type == "CancelledError"
303303

304-
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
305304
await print_interleaved_histories(
306305
[caller_wf, handler_wf],
307306
extra_events=[
308307
(
309308
caller_wf,
310309
"Caller op future resolved",
311-
caller_op_future_resolved,
310+
result.caller_op_future_resolved,
312311
)
313312
],
314313
)
@@ -329,7 +328,7 @@ async def check_behavior_for_try_cancel(
329328
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED,
330329
)
331330
assert (
332-
caller_op_future_resolved
331+
result.caller_op_future_resolved
333332
< op_cancel_requested_event
334333
< op_cancel_request_failed_event
335334
)
@@ -353,7 +352,6 @@ async def check_behavior_for_wait_cancellation_requested(
353352
(caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED),
354353
]
355354
)
356-
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
357355
op_cancel_request_failed = await get_event_time(
358356
caller_wf,
359357
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED,
@@ -362,7 +360,11 @@ async def check_behavior_for_wait_cancellation_requested(
362360
handler_wf,
363361
EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED,
364362
)
365-
assert op_cancel_request_failed < caller_op_future_resolved < handler_wf_completed
363+
assert (
364+
op_cancel_request_failed
365+
< result.caller_op_future_resolved
366+
< handler_wf_completed
367+
)
366368

367369

368370
async def check_behavior_for_wait_cancellation_completed(
@@ -385,9 +387,8 @@ async def check_behavior_for_wait_cancellation_completed(
385387
(caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_COMPLETED),
386388
]
387389
)
388-
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
389390
handler_wf_completed = await get_event_time(
390391
handler_wf,
391392
EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED,
392393
)
393-
assert handler_wf_completed < caller_op_future_resolved
394+
assert handler_wf_completed < result.caller_op_future_resolved

0 commit comments

Comments
 (0)