Skip to content

Commit e5bc027

Browse files
committed
Implement Barrier
1 parent 529921e commit e5bc027

File tree

1 file changed

+33
-5
lines changed

1 file changed

+33
-5
lines changed

tests/nexus/test_worker.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import sys
54
import uuid
65
from datetime import timedelta
76
from typing import Any
@@ -53,10 +52,7 @@ async def test_max_concurrent_nexus_tasks(
5352
if env.supports_time_skipping:
5453
pytest.skip("Nexus tests don't work with Javas test server")
5554

56-
if sys.version_info < (3, 11):
57-
pytest.skip("Test requires Python 3.11+")
58-
59-
barrier = asyncio.Barrier(num_nexus_operations) # type: ignore
55+
barrier = Barrier(num_nexus_operations)
6056

6157
@nexusrpc.handler.service_handler
6258
class MaxConcurrentTestService:
@@ -93,3 +89,35 @@ async def op(
9389
f"max_concurrent_nexus_tasks={max_concurrent_nexus_tasks}, "
9490
f"num_nexus_operations={num_nexus_operations}"
9591
)
92+
93+
94+
# Minimal implementation of asyncio.Barrier for Python 3.9+ compatibility
95+
96+
97+
class Barrier:
98+
"""Minimal implementation of asyncio.Barrier for Python 3.9+ compatibility.
99+
100+
This is a simplified version that only implements the wait() method needed
101+
for this test. All tasks block until exactly 'parties' tasks have called
102+
wait(), then all are released simultaneously.
103+
"""
104+
105+
def __init__(self, parties: int):
106+
"""Create a barrier for 'parties' tasks."""
107+
if parties < 1:
108+
raise ValueError("parties must be > 0")
109+
self._parties = parties
110+
self._count = 0
111+
self._lock = asyncio.Lock()
112+
self._event = asyncio.Event()
113+
114+
async def wait(self) -> None:
115+
"""Wait for all parties to reach the barrier."""
116+
async with self._lock:
117+
self._count += 1
118+
if self._count == self._parties:
119+
# Last one in, release everyone
120+
self._event.set()
121+
122+
# Wait for all parties to arrive
123+
await self._event.wait()

0 commit comments

Comments
 (0)