Skip to content

Commit 2d82ba9

Browse files
committed
Use context & add test for concurrent timer creation w/ summaries
1 parent 19c48f9 commit 2d82ba9

File tree

2 files changed

+84
-12
lines changed

2 files changed

+84
-12
lines changed

temporalio/worker/_workflow_instance.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -325,10 +325,6 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
325325
# For tracking the thread this workflow is running on (primarily for deadlock situations)
326326
self._current_thread_id: Optional[int] = None
327327

328-
# Since timer creation often happens indirectly through asyncio, we need some place to
329-
# temporarily store options for timers created by, ex `wait_condition`.
330-
self._next_timer_options: Optional[_TimerOptions] = None
331-
332328
# The current details (as opposed to static details on workflow start), returned in the
333329
# metadata query
334330
self._current_details = ""
@@ -1474,8 +1470,13 @@ async def workflow_wait_condition(
14741470
if timeout_summary
14751471
else None
14761472
)
1477-
self._next_timer_options = _TimerOptions(user_metadata=user_metadata)
1478-
await asyncio.wait_for(fut, timeout)
1473+
ctxvars = contextvars.copy_context()
1474+
1475+
async def in_context():
1476+
_TimerOptionsCtxVar.set(_TimerOptions(user_metadata=user_metadata))
1477+
await asyncio.wait_for(fut, timeout)
1478+
1479+
await ctxvars.run(in_context)
14791480

14801481
def workflow_get_current_details(self) -> str:
14811482
return self._current_details
@@ -2105,11 +2106,7 @@ def call_later(
21052106
*args: Any,
21062107
context: Optional[contextvars.Context] = None,
21072108
) -> asyncio.TimerHandle:
2108-
# Fetch options from the class field, erasing them afterward.
2109-
options = (
2110-
self._next_timer_options if self._next_timer_options else _TimerOptions()
2111-
)
2112-
self._next_timer_options = None
2109+
options = _TimerOptionsCtxVar.get()
21132110
return self._timer_impl(delay, options, callback, *args, context=context)
21142111

21152112
def call_at(
@@ -2332,6 +2329,11 @@ def start_local_activity(
23322329
return self._instance._outbound_schedule_activity(input)
23332330

23342331

2332+
_TimerOptionsCtxVar: contextvars.ContextVar[_TimerOptions] = contextvars.ContextVar(
2333+
"__temporal_timer_options"
2334+
)
2335+
2336+
23352337
@dataclass(frozen=True)
23362338
class _TimerOptions:
23372339
user_metadata: Optional[temporalio.api.sdk.v1.UserMetadata] = None

tests/worker/test_workflow.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
from temporalio.api.failure.v1 import Failure
4545
from temporalio.api.sdk.v1 import EnhancedStackTrace
4646
from temporalio.api.workflowservice.v1 import (
47-
DescribeWorkflowExecutionRequest,
4847
GetWorkflowExecutionHistoryRequest,
4948
ResetStickyTaskQueueRequest,
5049
)
@@ -6306,3 +6305,74 @@ async def test_workflow_sleep(client: Client):
63066305
task_queue=worker.task_queue,
63076306
)
63086307
assert (datetime.now() - start_time) >= timedelta(seconds=1)
6308+
6309+
6310+
@workflow.defn
6311+
class ConcurrentSleepsWorkflow:
6312+
@workflow.run
6313+
async def run(self) -> None:
6314+
sleeps_a = [workflow.sleep(0.1, summary=f"t{i}") for i in range(5)]
6315+
zero_a = workflow.sleep(0, summary="zero_timer")
6316+
wait_some = workflow.wait_condition(
6317+
lambda: False, timeout=0.1, timeout_summary="wait_some"
6318+
)
6319+
zero_b = workflow.wait_condition(
6320+
lambda: False, timeout=0, timeout_summary="zero_wait"
6321+
)
6322+
no_summ = workflow.sleep(0.1)
6323+
sleeps_b = [workflow.sleep(0.1, summary=f"t{i}") for i in range(5, 10)]
6324+
try:
6325+
await asyncio.gather(
6326+
*sleeps_a,
6327+
zero_a,
6328+
wait_some,
6329+
zero_b,
6330+
no_summ,
6331+
*sleeps_b,
6332+
return_exceptions=True,
6333+
)
6334+
except asyncio.TimeoutError:
6335+
pass
6336+
6337+
task_1 = asyncio.create_task(self.make_timers(100, 105))
6338+
task_2 = asyncio.create_task(self.make_timers(105, 110))
6339+
await asyncio.gather(task_1, task_2)
6340+
6341+
async def make_timers(self, start: int, end: int):
6342+
await asyncio.gather(
6343+
*[workflow.sleep(0.1, summary=f"m_t{i}") for i in range(start, end)]
6344+
)
6345+
6346+
6347+
async def test_concurrent_sleeps_use_proper_options(client: Client):
6348+
async with new_worker(client, ConcurrentSleepsWorkflow) as worker:
6349+
handle = await client.start_workflow(
6350+
ConcurrentSleepsWorkflow.run,
6351+
id=f"workflow-{uuid.uuid4()}",
6352+
task_queue=worker.task_queue,
6353+
)
6354+
await handle.result()
6355+
resp = await client.workflow_service.get_workflow_execution_history(
6356+
GetWorkflowExecutionHistoryRequest(
6357+
namespace=client.namespace,
6358+
execution=WorkflowExecution(workflow_id=handle.id),
6359+
)
6360+
)
6361+
timer_summaries = [
6362+
PayloadConverter.default.from_payload(e.user_metadata.summary)
6363+
if e.user_metadata.HasField("summary")
6364+
else "<no summ>"
6365+
for e in resp.history.events
6366+
if e.event_type == EventType.EVENT_TYPE_TIMER_STARTED
6367+
]
6368+
assert timer_summaries == [
6369+
*[f"t{i}" for i in range(5)],
6370+
"zero_timer",
6371+
"wait_some",
6372+
"<no summ>",
6373+
*[f"t{i}" for i in range(5, 10)],
6374+
*[f"m_t{i}" for i in range(100, 110)],
6375+
]
6376+
6377+
# Force replay with a query to ensure determinism
6378+
await handle.query("__temporal_workflow_metadata")

0 commit comments

Comments
 (0)