Skip to content

Commit bbb6964

Browse files
committed
Add tests of asyncio.Lock and asyncio.Semaphore usage
1 parent a5b9661 commit bbb6964

File tree

2 files changed

+322
-2
lines changed

2 files changed

+322
-2
lines changed

tests/helpers/__init__.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import uuid
55
from contextlib import closing
66
from datetime import timedelta
7-
from typing import Awaitable, Callable, Optional, Sequence, Type, TypeVar
7+
from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar
88

99
from temporalio.api.common.v1 import WorkflowExecution
1010
from temporalio.api.enums.v1 import IndexedValueType
@@ -14,11 +14,12 @@
1414
)
1515
from temporalio.api.update.v1 import UpdateRef
1616
from temporalio.api.workflowservice.v1 import PollWorkflowExecutionUpdateRequest
17-
from temporalio.client import BuildIdOpAddNewDefault, Client
17+
from temporalio.client import BuildIdOpAddNewDefault, Client, WorkflowHandle
1818
from temporalio.common import SearchAttributeKey
1919
from temporalio.service import RPCError, RPCStatusCode
2020
from temporalio.worker import Worker, WorkflowRunner
2121
from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner
22+
from temporalio.workflow import UpdateMethodMultiParam
2223

2324

2425
def new_worker(
@@ -128,3 +129,24 @@ async def workflow_update_exists(
128129
if err.status != RPCStatusCode.NOT_FOUND:
129130
raise
130131
return False
132+
133+
134+
# TODO: type update return value
135+
async def admitted_update_task(
136+
client: Client,
137+
handle: WorkflowHandle,
138+
update_method: UpdateMethodMultiParam,
139+
id: str,
140+
**kwargs,
141+
) -> asyncio.Task:
142+
"""
143+
Return an asyncio.Task for an update after waiting for it to be admitted.
144+
"""
145+
update_task = asyncio.create_task(
146+
handle.execute_update(update_method, id=id, **kwargs)
147+
)
148+
await assert_eq_eventually(
149+
True,
150+
lambda: workflow_update_exists(client, handle.id, id),
151+
)
152+
return update_task

tests/worker/test_workflow.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@
107107
WorkflowRunner,
108108
)
109109
from tests.helpers import (
110+
admitted_update_task,
110111
assert_eq_eventually,
111112
ensure_search_attributes_present,
112113
find_free_port,
@@ -5861,3 +5862,300 @@ async def test_timer_started_after_workflow_completion(client: Client):
58615862
)
58625863
await handle.signal(TimerStartedAfterWorkflowCompletionWorkflow.my_signal)
58635864
assert await handle.result() == "workflow-result"
5865+
5866+
5867+
# The following Lock and Semaphore tests test that asyncio concurrency primitives work as expected
5868+
# in workflow code. There is nothing Temporal-specific about the way that asyncio.Lock and
5869+
# asyncio.Semaphore are used here.
5870+
5871+
5872+
@activity.defn
5873+
async def noop_activity_for_lock_or_semaphore_tests() -> None:
5874+
return None
5875+
5876+
5877+
@dataclass
5878+
class LockOrSemaphoreWorkflowConcurrencySummary:
5879+
ever_in_critical_section: int
5880+
peak_in_critical_section: int
5881+
5882+
5883+
@dataclass
5884+
class UseLockOrSemaphoreWorkflowParameters:
5885+
n_coroutines: int = 0
5886+
semaphore_initial_value: Optional[int] = None
5887+
sleep: Optional[float] = None
5888+
timeout: Optional[float] = None
5889+
5890+
5891+
@workflow.defn
5892+
class CoroutinesUseLockWorkflow:
5893+
def __init__(self) -> None:
5894+
self.params: UseLockOrSemaphoreWorkflowParameters
5895+
self.lock_or_semaphore: Union[asyncio.Lock, asyncio.Semaphore]
5896+
self._currently_in_critical_section: set[str] = set()
5897+
self._ever_in_critical_section: set[str] = set()
5898+
self._peak_in_critical_section = 0
5899+
5900+
def init(self, params: UseLockOrSemaphoreWorkflowParameters):
5901+
self.params = params
5902+
if self.params.semaphore_initial_value is not None:
5903+
self.lock_or_semaphore = asyncio.Semaphore(
5904+
self.params.semaphore_initial_value
5905+
)
5906+
else:
5907+
self.lock_or_semaphore = asyncio.Lock()
5908+
5909+
@workflow.run
5910+
async def run(
5911+
self,
5912+
params: UseLockOrSemaphoreWorkflowParameters,
5913+
) -> LockOrSemaphoreWorkflowConcurrencySummary:
5914+
# TODO: Use workflow init method when it exists.
5915+
self.init(params)
5916+
await asyncio.gather(
5917+
*(self.coroutine(f"{i}") for i in range(self.params.n_coroutines))
5918+
)
5919+
assert not any(self._currently_in_critical_section)
5920+
return LockOrSemaphoreWorkflowConcurrencySummary(
5921+
len(self._ever_in_critical_section),
5922+
self._peak_in_critical_section,
5923+
)
5924+
5925+
async def coroutine(self, id: str):
5926+
if self.params.timeout:
5927+
try:
5928+
await asyncio.wait_for(
5929+
self.lock_or_semaphore.acquire(), self.params.timeout
5930+
)
5931+
except asyncio.TimeoutError:
5932+
return
5933+
else:
5934+
await self.lock_or_semaphore.acquire()
5935+
self._enters_critical_section(id)
5936+
try:
5937+
if self.params.sleep:
5938+
await asyncio.sleep(self.params.sleep)
5939+
else:
5940+
await workflow.execute_activity(
5941+
noop_activity_for_lock_or_semaphore_tests,
5942+
schedule_to_close_timeout=timedelta(seconds=30),
5943+
)
5944+
finally:
5945+
self.lock_or_semaphore.release()
5946+
self._exits_critical_section(id)
5947+
5948+
def _enters_critical_section(self, id: str) -> None:
5949+
self._currently_in_critical_section.add(id)
5950+
self._ever_in_critical_section.add(id)
5951+
self._peak_in_critical_section = max(
5952+
self._peak_in_critical_section,
5953+
len(self._currently_in_critical_section),
5954+
)
5955+
5956+
def _exits_critical_section(self, id: str) -> None:
5957+
self._currently_in_critical_section.remove(id)
5958+
5959+
5960+
@workflow.defn
5961+
class HandlerCoroutinesUseLockWorkflow(CoroutinesUseLockWorkflow):
5962+
def __init__(self) -> None:
5963+
super().__init__()
5964+
self.workflow_may_exit = False
5965+
5966+
@workflow.run
5967+
async def run(
5968+
self,
5969+
) -> LockOrSemaphoreWorkflowConcurrencySummary:
5970+
await workflow.wait_condition(lambda: self.workflow_may_exit)
5971+
return LockOrSemaphoreWorkflowConcurrencySummary(
5972+
len(self._ever_in_critical_section),
5973+
self._peak_in_critical_section,
5974+
)
5975+
5976+
@workflow.update
5977+
async def my_update(self, params: UseLockOrSemaphoreWorkflowParameters):
5978+
# TODO: Use workflow init method when it exists.
5979+
if not hasattr(self, "params"):
5980+
self.init(params)
5981+
assert (update_info := workflow.current_update_info())
5982+
await self.coroutine(update_info.id)
5983+
5984+
@workflow.signal
5985+
async def finish(self):
5986+
self.workflow_may_exit = True
5987+
5988+
5989+
async def _do_workflow_coroutines_lock_or_semaphore_test(
5990+
client: Client,
5991+
params: UseLockOrSemaphoreWorkflowParameters,
5992+
expectation: LockOrSemaphoreWorkflowConcurrencySummary,
5993+
):
5994+
async with new_worker(
5995+
client,
5996+
CoroutinesUseLockWorkflow,
5997+
activities=[noop_activity_for_lock_or_semaphore_tests],
5998+
) as worker:
5999+
summary = await client.execute_workflow(
6000+
CoroutinesUseLockWorkflow.run,
6001+
arg=params,
6002+
id=str(uuid.uuid4()),
6003+
task_queue=worker.task_queue,
6004+
)
6005+
assert summary == expectation
6006+
6007+
6008+
async def _do_update_handler_lock_or_semaphore_test(
6009+
client: Client,
6010+
env: WorkflowEnvironment,
6011+
params: UseLockOrSemaphoreWorkflowParameters,
6012+
n_updates: int,
6013+
expectation: LockOrSemaphoreWorkflowConcurrencySummary,
6014+
):
6015+
if env.supports_time_skipping:
6016+
pytest.skip(
6017+
"Java test server: https://github.com/temporalio/sdk-java/issues/1903"
6018+
)
6019+
6020+
task_queue = "tq"
6021+
handle = await client.start_workflow(
6022+
HandlerCoroutinesUseLockWorkflow.run,
6023+
id=f"wf-{str(uuid.uuid4())}",
6024+
task_queue=task_queue,
6025+
)
6026+
# Create updates in Admitted state, before the worker starts polling.
6027+
admitted_updates = [
6028+
await admitted_update_task(
6029+
client,
6030+
handle,
6031+
HandlerCoroutinesUseLockWorkflow.my_update,
6032+
arg=params,
6033+
id=f"update-{i}",
6034+
)
6035+
for i in range(n_updates)
6036+
]
6037+
async with new_worker(
6038+
client,
6039+
HandlerCoroutinesUseLockWorkflow,
6040+
activities=[noop_activity_for_lock_or_semaphore_tests],
6041+
task_queue=task_queue,
6042+
):
6043+
for update_task in admitted_updates:
6044+
await update_task
6045+
await handle.signal(HandlerCoroutinesUseLockWorkflow.finish)
6046+
summary = await handle.result()
6047+
assert summary == expectation
6048+
6049+
6050+
async def test_workflow_coroutines_can_use_lock(client: Client):
6051+
await _do_workflow_coroutines_lock_or_semaphore_test(
6052+
client,
6053+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5),
6054+
# The lock limits concurrency to 1
6055+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6056+
ever_in_critical_section=5, peak_in_critical_section=1
6057+
),
6058+
)
6059+
6060+
6061+
async def test_update_handler_can_use_lock_to_serialize_handler_executions(
6062+
client: Client, env: WorkflowEnvironment
6063+
):
6064+
await _do_update_handler_lock_or_semaphore_test(
6065+
client,
6066+
env,
6067+
UseLockOrSemaphoreWorkflowParameters(),
6068+
n_updates=5,
6069+
# The lock limits concurrency to 1
6070+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6071+
ever_in_critical_section=5, peak_in_critical_section=1
6072+
),
6073+
)
6074+
6075+
6076+
async def test_workflow_coroutines_lock_acquisition_respects_timeout(client: Client):
6077+
await _do_workflow_coroutines_lock_or_semaphore_test(
6078+
client,
6079+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, sleep=0.5, timeout=0.1),
6080+
# Second and subsequent coroutines fail to acquire the lock due to the timeout.
6081+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6082+
ever_in_critical_section=1, peak_in_critical_section=1
6083+
),
6084+
)
6085+
6086+
6087+
async def test_update_handler_lock_acquisition_respects_timeout(
6088+
client: Client, env: WorkflowEnvironment
6089+
):
6090+
await _do_update_handler_lock_or_semaphore_test(
6091+
client,
6092+
env,
6093+
# Second and subsequent handler executions fail to acquire the lock due to the timeout.
6094+
UseLockOrSemaphoreWorkflowParameters(sleep=0.5, timeout=0.1),
6095+
n_updates=5,
6096+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6097+
ever_in_critical_section=1, peak_in_critical_section=1
6098+
),
6099+
)
6100+
6101+
6102+
async def test_workflow_coroutines_can_use_semaphore(client: Client):
6103+
await _do_workflow_coroutines_lock_or_semaphore_test(
6104+
client,
6105+
UseLockOrSemaphoreWorkflowParameters(n_coroutines=5, semaphore_initial_value=3),
6106+
# The semaphore limits concurrency to 3
6107+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6108+
ever_in_critical_section=5, peak_in_critical_section=3
6109+
),
6110+
)
6111+
6112+
6113+
async def test_update_handler_can_use_semaphore_to_control_handler_execution_concurrency(
6114+
client: Client, env: WorkflowEnvironment
6115+
):
6116+
await _do_update_handler_lock_or_semaphore_test(
6117+
client,
6118+
env,
6119+
# The semaphore limits concurrency to 3
6120+
UseLockOrSemaphoreWorkflowParameters(semaphore_initial_value=3),
6121+
n_updates=5,
6122+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6123+
ever_in_critical_section=5, peak_in_critical_section=3
6124+
),
6125+
)
6126+
6127+
6128+
async def test_workflow_coroutine_semaphore_acquisition_respects_timeout(
6129+
client: Client,
6130+
):
6131+
await _do_workflow_coroutines_lock_or_semaphore_test(
6132+
client,
6133+
UseLockOrSemaphoreWorkflowParameters(
6134+
n_coroutines=5, semaphore_initial_value=3, sleep=0.5, timeout=0.1
6135+
),
6136+
# Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore
6137+
# slot fail.
6138+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6139+
ever_in_critical_section=3, peak_in_critical_section=3
6140+
),
6141+
)
6142+
6143+
6144+
async def test_update_handler_semaphore_acquisition_respects_timeout(
6145+
client: Client, env: WorkflowEnvironment
6146+
):
6147+
await _do_update_handler_lock_or_semaphore_test(
6148+
client,
6149+
env,
6150+
# Initial entry to the semaphore succeeds, but all subsequent attempts to acquire a semaphore
6151+
# slot fail.
6152+
UseLockOrSemaphoreWorkflowParameters(
6153+
semaphore_initial_value=3,
6154+
sleep=0.5,
6155+
timeout=0.1,
6156+
),
6157+
n_updates=5,
6158+
expectation=LockOrSemaphoreWorkflowConcurrencySummary(
6159+
ever_in_critical_section=3, peak_in_critical_section=3
6160+
),
6161+
)

0 commit comments

Comments
 (0)