Skip to content

Commit c08a124

Browse files
committed
Add tests of asyncio.Lock and asyncio.Semaphore usage
1 parent 1a68b58 commit c08a124

File tree

2 files changed

+322
-2
lines changed

2 files changed

+322
-2
lines changed

tests/helpers/__init__.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import uuid
55
from contextlib import closing
66
from datetime import timedelta
7-
from typing import Awaitable, Callable, Optional, Sequence, Type, TypeVar
7+
from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar
88

99
from temporalio.api.common.v1 import WorkflowExecution
1010
from temporalio.api.enums.v1 import IndexedValueType
@@ -14,11 +14,12 @@
1414
)
1515
from temporalio.api.update.v1 import UpdateRef
1616
from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest
17-
from temporalio.client import BuildIdOpAddNewDefault, Client
17+
from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle
1818
from temporalio.common import SearchAttributeKey
1919
from temporalio.service import RPCError, RPCStatusCode
2020
from temporalio.worker import Worker, WorkflowRunner
2121
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
22+
from temporalio.workflow import UpdateMethodMultiParam
2223

2324

2425
def new_worker(
@@ -128,3 +129,24 @@ async def workflow_update_exists(
128129
if err.status != RPCStatusCode.NOT_FOUND:
129130
raise
130131
return False
132+
133+
134+
# TODO: type update return value
135+
async def admitted_update_task(
136+
client: Client,
137+
handle: WorkflowHandle,
138+
update_method: UpdateMethodMultiParam,
139+
id: str,
140+
**kwargs,
141+
) -> asyncio.Task:
142+
"""
143+
Return an asyncio.Task for an update after waiting for it to be admitted.
144+
"""
145+
update_task = asyncio.create_task(
146+
handle.execute_update(update_method, id=id, **kwargs)
147+
)
148+
await assert_eq_eventually(
149+
True,
150+
lambda: workflow_update_exists(client, handle.id, id),
151+
)
152+
return update_task

tests/worker/test_workflow.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
WorkflowRunner,
108108
)
109109
from tests.helpers import (
110+
admitted_update_task,
110111
assert_eq_eventually,
111112
ensure_search_attributes_present,
112113
find_free_port,
@@ -6611,3 +6612,300 @@ async def test_alternate_async_loop_ordering(client: Client, env: WorkflowEnviro
66116612
task_queue=task_queue,
66126613
):
66136614
await handle.result()
6615+
6616+
6617+
# The following Lock and Semaphore tests test that asyncio concurrency primitives work as expected
6618+
# in workflow code. There is nothing Temporal-specific about the way that asyncio.Lock and
6619+
# asyncio.Semaphore are used here.
6620+
6621+
6622+
@activity.defn
6623+
async def noop_activity_for_lock_or_semaphore_tests() -> None:
6624+
return None
6625+
6626+
6627+
@dataclass
6628+
class LockOrSemaphoreWorkflowConcurrencySummary:
6629+
ever_in_critical_section: int
6630+
peak_in_critical_section: int
6631+
6632+
6633+
@dataclass
6634+
class UseLockOrSemaphoreWorkflowParameters:
6635+
n_coroutines: int = 0
6636+
semaphore_initial_value: Optional[int] = None
6637+
sleep: Optional[float] = None
6638+
timeout: Optional[float] = None
6639+
6640+
6641+
@workflow.defn
6642+
class CoroutinesUseLockWorkflow:
6643+
def __init__(self) -> None:
6644+
self.params: UseLockOrSemaphoreWorkflowParameters
6645+
self.lock_or_semaphore: Union[asyncio.Lock, asyncio.Semaphore]
6646+
self._currently_in_critical_section: set[str] = set()
6647+
self._ever_in_critical_section: set[str] = set()
6648+
self._peak_in_critical_section = 0
6649+
6650+
def init(self, params: UseLockOrSemaphoreWorkflowParameters):
6651+
self.params = params
6652+
if self.params.semaphore_initial_value is not None:
6653+
self.lock_or_semaphore = asyncio.Semaphore(
6654+
self.params.semaphore_initial_value
6655+
)
6656+
else:
6657+
self.lock_or_semaphore = asyncio.Lock()
6658+
6659+
@workflow.run
6660+
async def run(
6661+
self,
6662+
params: UseLockOrSemaphoreWorkflowParameters,
6663+
) -> LockOrSemaphoreWorkflowConcurrencySummary:
6664+
# TODO: Use workflow init method when it exists.
6665+
self.init(params)
6666+
await asyncio.gather(
6667+
*(self.coroutine(f"{i}") for i in range(self.params.n_coroutines))
6668+
)
6669+
assert not any(self._currently_in_critical_section)
6670+
return LockOrSemaphoreWorkflowConcurrencySummary(
6671+
len(self._ever_in_critical_section),
6672+
self._peak_in_critical_section,
6673+
)
6674+
6675+
async def coroutine(self, id: str):
6676+
if self.params.timeout:
6677+
try:
6678+
await asyncio.wait_for(
6679+
self.lock_or_semaphore.acquire(), self.params.timeout
6680+
)
6681+
except asyncio.TimeoutError:
6682+
return
6683+
else:
6684+
await self.lock_or_semaphore.acquire()
6685+
self._enters_critical_section(id)
6686+
try:
6687+
if self.params.sleep:
6688+
await asyncio.sleep(self.params.sleep)
6689+
else:
6690+
await workflow.execute_activity(
6691+
noop_activity_for_lock_or_semaphore_tests,
6692+
schedule_to_close_timeout=timedelta(seconds=30),
6693+
)
6694+
finally:
6695+
self.lock_or_semaphore.release()
6696+
self._exits_critical_section(id)
6697+
6698+
def _enters_critical_section(self, id: str) -> None:
6699+
self._currently_in_critical_section.add(id)
6700+
self._ever_in_critical_section.add(id)
6701+
self._peak_in_critical_section = max(
6702+
self._peak_in_critical_section,
6703+
len(self._currently_in_critical_section),
6704+
)
6705+
6706+
def _exits_critical_section(self, id: str) -> None:
6707+
self._currently_in_critical_section.remove(id)
6708+
6709+
6710+
@workflow.defn
6711+
class HandlerCoroutinesUseLockWorkflow(CoroutinesUseLockWorkflow):
6712+
def __init__(self) -> None:
6713+
super().__init__()
6714+
self.workflow_may_exit = False
6715+
6716+
@workflow.run
6717+
async def run(
6718+
self,
6719+
) -> LockOrSemaphoreWorkflowConcurrencySummary:
6720+
await workflow.wait_condition(lambda: self.workflow_may_exit)
6721+
return LockOrSemaphoreWorkflowConcurrencySummary(
6722+
len(self._ever_in_critical_section),
6723+
self._peak_in_critical_section,
6724+
)
6725+
6726+
@workflow.update
6727+
async def my_update(self, params: UseLockOrSemaphoreWorkflowParameters):
6728+
# TODO: Use workflow init method when it exists.
6729+
if not hasattr(self, "params"):
6730+
self.init(params)
6731+
assert (update_info := workflow.current_update_info())
6732+
await self.coroutine(update_info.id)
6733+
6734+
@workflow.signal
6735+
async def finish(self):
6736+
self.workflow_may_exit = True
6737+
6738+
6739+
async def _do_workflow_coroutines_lock_or_semaphore_test(
6740+
client: Client,
6741+
params: UseLockOrSemaphoreWorkflowParameters,
6742+
expectation: LockOrSemaphoreWorkflowConcurrencySummary,
6743+
):
6744+
async with new_worker(
6745+
client,
6746+
CoroutinesUseLockWorkflow,
6747+
activities=[noop_activity_for_lock_or_semaphore_tests],
6748+
) as worker:
6749+
summary = await client.execute_workflow(
6750+
CoroutinesUseLockWorkflow.run,
6751+
arg=params,
6752+
id=str(uuid.uuid4()),
6753+
task_queue=worker.task_queue,
6754+
)
6755+
assert summary == expectation
6756+
6757+
6758+
async def _do_update_handler_lock_or_semaphore_test(
6759+
client: Client,
6760+
env: WorkflowEnvironment,
6761+
params: UseLockOrSemaphoreWorkflowParameters,
6762+
n_updates: int,
6763+
expectation: LockOrSemaphoreWorkflowConcurrencySummary,
6764+
):
6765+
if env.supports_time_skipping:
6766+
pytest.skip(
6767+
"Java test server: https://github.com/temporalio/sdk-java/issues/1903"
6768+
)
6769+
6770+
task_queue = "tq"
6771+
handle = await client.start_workflow(
6772+
HandlerCoroutinesUseLockWorkflow.run,
6773+
id=f"wf-{str(uuid.uuid4())}",
6774+
task_queue=task_queue,
6775+
)
6776+
# Create updates in Admitted state, before the worker starts polling.
6777+
admitted_updates = [
6778+
await admitted_update_task(
6779+
client,
6780+
handle,
6781+
HandlerCoroutinesUseLockWorkflow.my_update,
6782+
arg=params,
6783+
id=f"update-{i}",
6784+
)
6785+
for i in range(n_updates)
6786+
]
6787+
async with new_worker(
6788+
client,
6789+
HandlerCoroutinesUseLockWorkflow,
6790+
activities=[noop_activity_for_lock_or_semaphore_tests],
6791+
task_queue=task_queue,
6792+
):
6793+
for update_task in admitted_updates:
6794+
await update_task
6795+
await handle.signal(HandlerCoroutinesUseLockWorkflow.finish)
6796+
summary = await handle.result()
6797+
assert summary == expectation
6798+
6799+
6800+
async def test_workflow_coroutines_can_use_lock(client: Client):
6801+
await _do_workflow_coroutines_lock_or_semaphore_test(
6802+
client,
6803+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5),
6804+
# The lock limits concurrency to 1
6805+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6806+
ever_in_critical_section=5, peak_in_critical_section=1
6807+
),
6808+
)
6809+
6810+
6811+
async def test_update_handler_can_use_lock_to_serialize_handler_executions(
6812+
client: Client, env: WorkflowEnvironment
6813+
):
6814+
await _do_update_handler_lock_or_semaphore_test(
6815+
client,
6816+
env,
6817+
UseLockOrSemaphoreWorkflowParameters(),
6818+
n_updates=5,
6819+
# The lock limits concurrency to 1
6820+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6821+
ever_in_critical_section=5, peak_in_critical_section=1
6822+
),
6823+
)
6824+
6825+
6826+
async def test_workflow_coroutines_lock_acquisition_respects_timeout(client: Client):
6827+
await _do_workflow_coroutines_lock_or_semaphore_test(
6828+
client,
6829+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, sleep=0.5, timeout=0.1),
6830+
# Second and subsequent coroutines fail to acquire the lock due to the timeout.
6831+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6832+
ever_in_critical_section=1, peak_in_critical_section=1
6833+
),
6834+
)
6835+
6836+
6837+
async def test_update_handler_lock_acquisition_respects_timeout(
6838+
client: Client, env: WorkflowEnvironment
6839+
):
6840+
await _do_update_handler_lock_or_semaphore_test(
6841+
client,
6842+
env,
6843+
# Second and subsequent handler executions fail to acquire the lock due to the timeout.
6844+
UseLockOrSemaphoreWorkflowParameters(sleep=0.5, timeout=0.1),
6845+
n_updates=5,
6846+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6847+
ever_in_critical_section=1, peak_in_critical_section=1
6848+
),
6849+
)
6850+
6851+
6852+
async def test_workflow_coroutines_can_use_semaphore(client: Client):
6853+
await _do_workflow_coroutines_lock_or_semaphore_test(
6854+
client,
6855+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, semaphore_initial_value=3),
6856+
# The semaphore limits concurrency to 3
6857+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6858+
ever_in_critical_section=5, peak_in_critical_section=3
6859+
),
6860+
)
6861+
6862+
6863+
async def test_update_handler_can_use_semaphore_to_control_handler_execution_concurrency(
6864+
client: Client, env: WorkflowEnvironment
6865+
):
6866+
await _do_update_handler_lock_or_semaphore_test(
6867+
client,
6868+
env,
6869+
# The semaphore limits concurrency to 3
6870+
UseLockOrSemaphoreWorkflowParameters(semaphore_initial_value=3),
6871+
n_updates=5,
6872+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6873+
ever_in_critical_section=5, peak_in_critical_section=3
6874+
),
6875+
)
6876+
6877+
6878+
async def test_workflow_coroutine_semaphore_acquisition_respects_timeout(
6879+
client: Client,
6880+
):
6881+
await _do_workflow_coroutines_lock_or_semaphore_test(
6882+
client,
6883+
UseLockOrSemaphoreWorkflowParameters(
6884+
n_coroutines=5, semaphore_initial_value=3, sleep=0.5, timeout=0.1
6885+
),
6886+
# Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore
6887+
# slot fail.
6888+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6889+
ever_in_critical_section=3, peak_in_critical_section=3
6890+
),
6891+
)
6892+
6893+
6894+
async def test_update_handler_semaphore_acquisition_respects_timeout(
6895+
client: Client, env: WorkflowEnvironment
6896+
):
6897+
await _do_update_handler_lock_or_semaphore_test(
6898+
client,
6899+
env,
6900+
# Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore
6901+
# slot fail.
6902+
UseLockOrSemaphoreWorkflowParameters(
6903+
semaphore_initial_value=3,
6904+
sleep=0.5,
6905+
timeout=0.1,
6906+
),
6907+
n_updates=5,
6908+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6909+
ever_in_critical_section=3, peak_in_critical_section=3
6910+
),
6911+
)

0 commit comments

Comments
 (0)