Skip to content

Commit a14cfbe

Browse files
committed
Add tests of asyncio.Lock and asyncio.Semaphore usage
1 parent 38d9eef commit a14cfbe

File tree

2 files changed

+322
-2
lines changed

2 files changed

+322
-2
lines changed

tests/helpers/__init__.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import uuid
55
from contextlib import closing
66
from datetime import timedelta
7-
from typing import Awaitable, Callable, Optional, Sequence, Type, TypeVar
7+
from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar
88

99
from temporalio.api.common.v1 import WorkflowExecution
1010
from temporalio.api.enums.v1 import IndexedValueType
@@ -14,11 +14,12 @@
1414
)
1515
from temporalio.api.update.v1 import UpdateRef
1616
from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest
17-
from temporalio.client import BuildIdOpAddNewDefault, Client
17+
from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle
1818
from temporalio.common import SearchAttributeKey
1919
from temporalio.service import RPCError, RPCStatusCode
2020
from temporalio.worker import Worker, WorkflowRunner
2121
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
22+
from temporalio.workflow import UpdateMethodMultiParam
2223

2324

2425
def new_worker(
@@ -128,3 +129,24 @@ async def workflow_update_exists(
128129
if err.status != RPCStatusCode.NOT_FOUND:
129130
raise
130131
return False
132+
133+
134+
# TODO: type update return value
135+
async def admitted_update_task(
136+
client: Client,
137+
handle: WorkflowHandle,
138+
update_method: UpdateMethodMultiParam,
139+
id: str,
140+
**kwargs,
141+
) -> asyncio.Task:
142+
"""
143+
Return an asyncio.Task for an update after waiting for it to be admitted.
144+
"""
145+
update_task = asyncio.create_task(
146+
handle.execute_update(update_method, id=id, **kwargs)
147+
)
148+
await assert_eq_eventually(
149+
True,
150+
lambda: workflow_update_exists(client, handle.id, id),
151+
)
152+
return update_task

tests/worker/test_workflow.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
WorkflowRunner,
106106
)
107107
from tests.helpers import (
108+
admitted_update_task,
108109
assert_eq_eventually,
109110
ensure_search_attributes_present,
110111
find_free_port,
@@ -5505,3 +5506,300 @@ def _unfinished_handler_warning_cls(self) -> Type:
55055506
"update": workflow.UnfinishedUpdateHandlersWarning,
55065507
"signal": workflow.UnfinishedSignalHandlersWarning,
55075508
}[self.handler_type]
5509+
5510+
5511+
# The following Lock and Semaphore tests test that asyncio concurrency primitives work as expected
5512+
# in workflow code. There is nothing Temporal-specific about the way that asyncio.Lock and
5513+
# asyncio.Semaphore are used here.
5514+
5515+
5516+
@activity.defn
5517+
async def noop_activity_for_lock_or_semaphore_tests() -> None:
5518+
return None
5519+
5520+
5521+
@dataclass
5522+
class LockOrSemaphoreWorkflowConcurrencySummary:
5523+
ever_in_critical_section: int
5524+
peak_in_critical_section: int
5525+
5526+
5527+
@dataclass
5528+
class UseLockOrSemaphoreWorkflowParameters:
5529+
n_coroutines: int = 0
5530+
semaphore_initial_value: Optional[int] = None
5531+
sleep: Optional[float] = None
5532+
timeout: Optional[float] = None
5533+
5534+
5535+
@workflow.defn
5536+
class CoroutinesUseLockWorkflow:
5537+
def __init__(self) -> None:
5538+
self.params: UseLockOrSemaphoreWorkflowParameters
5539+
self.lock_or_semaphore: Union[asyncio.Lock, asyncio.Semaphore]
5540+
self._currently_in_critical_section: set[str] = set()
5541+
self._ever_in_critical_section: set[str] = set()
5542+
self._peak_in_critical_section = 0
5543+
5544+
def init(self, params: UseLockOrSemaphoreWorkflowParameters):
5545+
self.params = params
5546+
if self.params.semaphore_initial_value is not None:
5547+
self.lock_or_semaphore = asyncio.Semaphore(
5548+
self.params.semaphore_initial_value
5549+
)
5550+
else:
5551+
self.lock_or_semaphore = asyncio.Lock()
5552+
5553+
@workflow.run
5554+
async def run(
5555+
self,
5556+
params: UseLockOrSemaphoreWorkflowParameters,
5557+
) -> LockOrSemaphoreWorkflowConcurrencySummary:
5558+
# TODO: Use workflow init method when it exists.
5559+
self.init(params)
5560+
await asyncio.gather(
5561+
*(self.coroutine(f"{i}") for i in range(self.params.n_coroutines))
5562+
)
5563+
assert not any(self._currently_in_critical_section)
5564+
return LockOrSemaphoreWorkflowConcurrencySummary(
5565+
len(self._ever_in_critical_section),
5566+
self._peak_in_critical_section,
5567+
)
5568+
5569+
async def coroutine(self, id: str):
5570+
if self.params.timeout:
5571+
try:
5572+
await asyncio.wait_for(
5573+
self.lock_or_semaphore.acquire(), self.params.timeout
5574+
)
5575+
except asyncio.TimeoutError:
5576+
return
5577+
else:
5578+
await self.lock_or_semaphore.acquire()
5579+
self._enters_critical_section(id)
5580+
try:
5581+
if self.params.sleep:
5582+
await asyncio.sleep(self.params.sleep)
5583+
else:
5584+
await workflow.execute_activity(
5585+
noop_activity_for_lock_or_semaphore_tests,
5586+
schedule_to_close_timeout=timedelta(seconds=30),
5587+
)
5588+
finally:
5589+
self.lock_or_semaphore.release()
5590+
self._exits_critical_section(id)
5591+
5592+
def _enters_critical_section(self, id: str) -> None:
5593+
self._currently_in_critical_section.add(id)
5594+
self._ever_in_critical_section.add(id)
5595+
self._peak_in_critical_section = max(
5596+
self._peak_in_critical_section,
5597+
len(self._currently_in_critical_section),
5598+
)
5599+
5600+
def _exits_critical_section(self, id: str) -> None:
5601+
self._currently_in_critical_section.remove(id)
5602+
5603+
5604+
@workflow.defn
5605+
class HandlerCoroutinesUseLockWorkflow(CoroutinesUseLockWorkflow):
5606+
def __init__(self) -> None:
5607+
super().__init__()
5608+
self.workflow_may_exit = False
5609+
5610+
@workflow.run
5611+
async def run(
5612+
self,
5613+
) -> LockOrSemaphoreWorkflowConcurrencySummary:
5614+
await workflow.wait_condition(lambda: self.workflow_may_exit)
5615+
return LockOrSemaphoreWorkflowConcurrencySummary(
5616+
len(self._ever_in_critical_section),
5617+
self._peak_in_critical_section,
5618+
)
5619+
5620+
@workflow.update
5621+
async def my_update(self, params: UseLockOrSemaphoreWorkflowParameters):
5622+
# TODO: Use workflow init method when it exists.
5623+
if not hasattr(self, "params"):
5624+
self.init(params)
5625+
assert (update_info := workflow.current_update_info())
5626+
await self.coroutine(update_info.id)
5627+
5628+
@workflow.signal
5629+
async def finish(self):
5630+
self.workflow_may_exit = True
5631+
5632+
5633+
async def _do_workflow_coroutines_lock_or_semaphore_test(
5634+
client: Client,
5635+
params: UseLockOrSemaphoreWorkflowParameters,
5636+
expectation: LockOrSemaphoreWorkflowConcurrencySummary,
5637+
):
5638+
async with new_worker(
5639+
client,
5640+
CoroutinesUseLockWorkflow,
5641+
activities=[noop_activity_for_lock_or_semaphore_tests],
5642+
) as worker:
5643+
summary = await client.execute_workflow(
5644+
CoroutinesUseLockWorkflow.run,
5645+
arg=params,
5646+
id=str(uuid.uuid4()),
5647+
task_queue=worker.task_queue,
5648+
)
5649+
assert summary == expectation
5650+
5651+
5652+
async def _do_update_handler_lock_or_semaphore_test(
5653+
client: Client,
5654+
env: WorkflowEnvironment,
5655+
params: UseLockOrSemaphoreWorkflowParameters,
5656+
n_updates: int,
5657+
expectation: LockOrSemaphoreWorkflowConcurrencySummary,
5658+
):
5659+
if env.supports_time_skipping:
5660+
pytest.skip(
5661+
"Java test server: https://github.com/temporalio/sdk-java/issues/1903"
5662+
)
5663+
5664+
task_queue = "tq"
5665+
handle = await client.start_workflow(
5666+
HandlerCoroutinesUseLockWorkflow.run,
5667+
id=f"wf-{str(uuid.uuid4())}",
5668+
task_queue=task_queue,
5669+
)
5670+
# Create updates in Admitted state, before the worker starts polling.
5671+
admitted_updates = [
5672+
await admitted_update_task(
5673+
client,
5674+
handle,
5675+
HandlerCoroutinesUseLockWorkflow.my_update,
5676+
arg=params,
5677+
id=f"update-{i}",
5678+
)
5679+
for i in range(n_updates)
5680+
]
5681+
async with new_worker(
5682+
client,
5683+
HandlerCoroutinesUseLockWorkflow,
5684+
activities=[noop_activity_for_lock_or_semaphore_tests],
5685+
task_queue=task_queue,
5686+
):
5687+
for update_task in admitted_updates:
5688+
await update_task
5689+
await handle.signal(HandlerCoroutinesUseLockWorkflow.finish)
5690+
summary = await handle.result()
5691+
assert summary == expectation
5692+
5693+
5694+
async def test_workflow_coroutines_can_use_lock(client: Client):
5695+
await _do_workflow_coroutines_lock_or_semaphore_test(
5696+
client,
5697+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5),
5698+
# The lock limits concurrency to 1
5699+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5700+
ever_in_critical_section=5, peak_in_critical_section=1
5701+
),
5702+
)
5703+
5704+
5705+
async def test_update_handler_can_use_lock_to_serialize_handler_executions(
5706+
client: Client, env: WorkflowEnvironment
5707+
):
5708+
await _do_update_handler_lock_or_semaphore_test(
5709+
client,
5710+
env,
5711+
UseLockOrSemaphoreWorkflowParameters(),
5712+
n_updates=5,
5713+
# The lock limits concurrency to 1
5714+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5715+
ever_in_critical_section=5, peak_in_critical_section=1
5716+
),
5717+
)
5718+
5719+
5720+
async def test_workflow_coroutines_lock_acquisition_respects_timeout(client: Client):
5721+
await _do_workflow_coroutines_lock_or_semaphore_test(
5722+
client,
5723+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, sleep=0.5, timeout=0.1),
5724+
# Second and subsequent coroutines fail to acquire the lock due to the timeout.
5725+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5726+
ever_in_critical_section=1, peak_in_critical_section=1
5727+
),
5728+
)
5729+
5730+
5731+
async def test_update_handler_lock_acquisition_respects_timeout(
5732+
client: Client, env: WorkflowEnvironment
5733+
):
5734+
await _do_update_handler_lock_or_semaphore_test(
5735+
client,
5736+
env,
5737+
# Second and subsequent handler executions fail to acquire the lock due to the timeout.
5738+
UseLockOrSemaphoreWorkflowParameters(sleep=0.5, timeout=0.1),
5739+
n_updates=5,
5740+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5741+
ever_in_critical_section=1, peak_in_critical_section=1
5742+
),
5743+
)
5744+
5745+
5746+
async def test_workflow_coroutines_can_use_semaphore(client: Client):
5747+
await _do_workflow_coroutines_lock_or_semaphore_test(
5748+
client,
5749+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, semaphore_initial_value=3),
5750+
# The semaphore limits concurrency to 3
5751+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5752+
ever_in_critical_section=5, peak_in_critical_section=3
5753+
),
5754+
)
5755+
5756+
5757+
async def test_update_handler_can_use_semaphore_to_control_handler_execution_concurrency(
5758+
client: Client, env: WorkflowEnvironment
5759+
):
5760+
await _do_update_handler_lock_or_semaphore_test(
5761+
client,
5762+
env,
5763+
# The semaphore limits concurrency to 3
5764+
UseLockOrSemaphoreWorkflowParameters(semaphore_initial_value=3),
5765+
n_updates=5,
5766+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5767+
ever_in_critical_section=5, peak_in_critical_section=3
5768+
),
5769+
)
5770+
5771+
5772+
async def test_workflow_coroutine_semaphore_acquisition_respects_timeout(
5773+
client: Client,
5774+
):
5775+
await _do_workflow_coroutines_lock_or_semaphore_test(
5776+
client,
5777+
UseLockOrSemaphoreWorkflowParameters(
5778+
n_coroutines=5, semaphore_initial_value=3, sleep=0.5, timeout=0.1
5779+
),
5780+
# Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore
5781+
# slot fail.
5782+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5783+
ever_in_critical_section=3, peak_in_critical_section=3
5784+
),
5785+
)
5786+
5787+
5788+
async def test_update_handler_semaphore_acquisition_respects_timeout(
5789+
client: Client, env: WorkflowEnvironment
5790+
):
5791+
await _do_update_handler_lock_or_semaphore_test(
5792+
client,
5793+
env,
5794+
# Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore
5795+
# slot fail.
5796+
UseLockOrSemaphoreWorkflowParameters(
5797+
semaphore_initial_value=3,
5798+
sleep=0.5,
5799+
timeout=0.1,
5800+
),
5801+
n_updates=5,
5802+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
5803+
ever_in_critical_section=3, peak_in_critical_section=3
5804+
),
5805+
)

0 commit comments

Comments
 (0)