Skip to content

Commit f41641b

Browse files
committed
backport barrier
1 parent e1477ec commit f41641b

File tree

2 files changed

+166
-5
lines changed

2 files changed

+166
-5
lines changed

temporalio/_asyncio_compat.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
"""Backports for asyncio functionality for older Python versions."""
2+
3+
from __future__ import annotations
4+
5+
import asyncio
6+
import sys
7+
from typing import Optional
8+
9+
# Only define these if we're on Python < 3.11
10+
if sys.version_info < (3, 11):
11+
12+
class BrokenBarrierError(Exception):
13+
"""Exception raised when a Barrier is broken."""
14+
pass
15+
16+
class Barrier:
17+
"""Backport of asyncio.Barrier for Python < 3.11.
18+
19+
A barrier is a synchronization primitive that allows a set number of tasks
20+
to wait until they have all reached the barrier before proceeding.
21+
"""
22+
23+
def __init__(self, parties: int) -> None:
24+
"""Initialize a barrier.
25+
26+
Args:
27+
parties: The number of tasks that must call wait() before any
28+
of them can proceed.
29+
30+
Raises:
31+
ValueError: If parties is less than 1.
32+
"""
33+
if parties < 1:
34+
raise ValueError("parties must be greater than 0")
35+
36+
self._parties = parties
37+
self._count = 0
38+
self._broken = False
39+
self._waiters: list[asyncio.Future[int]] = []
40+
self._state_lock = asyncio.Lock()
41+
42+
async def wait(self) -> int:
43+
"""Wait for all parties to reach the barrier.
44+
45+
Returns:
46+
The index of the current task (0 to parties-1).
47+
48+
Raises:
49+
BrokenBarrierError: If the barrier is broken or gets broken
50+
while waiting.
51+
"""
52+
async with self._state_lock:
53+
if self._broken:
54+
raise BrokenBarrierError("Barrier is broken")
55+
56+
index = self._count
57+
self._count += 1
58+
59+
if self._count == self._parties:
60+
# We're the last one, release everyone
61+
self._count = 0
62+
for i, waiter in enumerate(self._waiters):
63+
if not waiter.done():
64+
waiter.set_result(i)
65+
self._waiters.clear()
66+
return index
67+
else:
68+
# We need to wait
69+
fut = asyncio.get_running_loop().create_future()
70+
self._waiters.append(fut)
71+
72+
try:
73+
return await fut
74+
except asyncio.CancelledError:
75+
# If we're cancelled, we need to handle cleanup
76+
async with self._state_lock:
77+
# Remove our future from waiters if it's still there
78+
try:
79+
self._waiters.remove(fut)
80+
self._count -= 1
81+
except ValueError:
82+
# Future was already processed
83+
pass
84+
85+
# If we were the last waiter and got cancelled, break the barrier
86+
if self._count == 0 and not self._broken:
87+
self._broken = True
88+
self._break_barrier()
89+
raise
90+
91+
def reset(self) -> None:
92+
"""Reset the barrier to its initial state.
93+
94+
Any tasks currently waiting will receive BrokenBarrierError.
95+
"""
96+
# Note: This is synchronous in Python 3.11+, so we keep it synchronous
97+
# but need to handle the async context properly
98+
loop = asyncio.get_event_loop()
99+
if loop.is_running():
100+
# Schedule the reset to run in the event loop
101+
asyncio.create_task(self._reset_async())
102+
else:
103+
# If no loop is running, we can't really do much
104+
# This matches the behavior where reset() expects to be called
105+
# from within an async context
106+
raise RuntimeError("Cannot reset barrier outside of async context")
107+
108+
async def _reset_async(self) -> None:
109+
"""Async implementation of reset."""
110+
async with self._state_lock:
111+
self._count = 0
112+
self._broken = False
113+
self._break_barrier()
114+
self._waiters.clear()
115+
116+
def abort(self) -> None:
117+
"""Place the barrier into a broken state.
118+
119+
This causes any current or future calls to wait() to fail with
120+
BrokenBarrierError.
121+
"""
122+
# Note: This is synchronous in Python 3.11+, so we keep it synchronous
123+
loop = asyncio.get_event_loop()
124+
if loop.is_running():
125+
# Schedule the abort to run in the event loop
126+
asyncio.create_task(self._abort_async())
127+
else:
128+
raise RuntimeError("Cannot abort barrier outside of async context")
129+
130+
async def _abort_async(self) -> None:
131+
"""Async implementation of abort."""
132+
async with self._state_lock:
133+
self._broken = True
134+
self._break_barrier()
135+
136+
def _break_barrier(self) -> None:
137+
"""Break the barrier, causing all waiters to get BrokenBarrierError.
138+
139+
Must be called while holding the state lock.
140+
"""
141+
for waiter in self._waiters:
142+
if not waiter.done():
143+
waiter.set_exception(BrokenBarrierError("Barrier is broken"))
144+
145+
@property
146+
def parties(self) -> int:
147+
"""Return the number of parties required to pass the barrier."""
148+
return self._parties
149+
150+
@property
151+
def n_waiting(self) -> int:
152+
"""Return the number of tasks currently waiting at the barrier."""
153+
return self._count
154+
155+
@property
156+
def broken(self) -> bool:
157+
"""Return True if the barrier is in a broken state."""
158+
return self._broken
159+
160+
else:
161+
# Python 3.11+, use the built-in
162+
from asyncio import Barrier, BrokenBarrierError # type: ignore[attr-defined]
163+
164+
__all__ = ["Barrier", "BrokenBarrierError"]

tests/nexus/test_worker.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import sys
54
import uuid
65
from datetime import timedelta
76
from typing import Any
@@ -10,6 +9,7 @@
109
import pytest
1110

1211
from temporalio import workflow
12+
from temporalio._asyncio_compat import Barrier
1313
from temporalio.testing import WorkflowEnvironment
1414
from tests.helpers import new_worker
1515
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
@@ -53,10 +53,7 @@ async def test_max_concurrent_nexus_tasks(
5353
if env.supports_time_skipping:
5454
pytest.skip("Nexus tests don't work with Javas test server")
5555

56-
if sys.version_info < (3, 11):
57-
pytest.skip("Test requires Python 3.11+")
58-
59-
barrier = asyncio.Barrier(num_nexus_operations) # type: ignore
56+
barrier = Barrier(num_nexus_operations)
6057

6158
@nexusrpc.handler.service_handler
6259
class MaxConcurrentTestService:

0 commit comments

Comments
 (0)