Skip to content

Commit 0deb61c

Browse files
[bug] Resolve session.close() races (#154)
1 parent 177cf27 commit 0deb61c

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

src/replit_river/v2/session.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,6 @@ class Session[HandshakeMetadata]:
169169

170170
# Terminating
171171
_terminating_task: asyncio.Task[None] | None
172-
_closing_waiter: asyncio.Event | None
173172

174173
def __init__(
175174
self,
@@ -229,7 +228,6 @@ def __init__(
229228

230229
# Terminating
231230
self._terminating_task = None
232-
self._closing_waiter = None
233231

234232
self._start_recv_from_ws()
235233
self._start_buffered_message_sender()
@@ -393,11 +391,11 @@ async def close(
393391
reason: Exception | None = None,
394392
) -> None:
395393
"""Close the session and all associated streams."""
396-
if self._closing_waiter:
394+
if self._terminating_task:
397395
try:
398396
logger.debug("Session already closing, waiting...")
399397
async with asyncio.timeout(SESSION_CLOSE_TIMEOUT_SEC):
400-
await self._closing_waiter.wait()
398+
await self._terminating_task
401399
except asyncio.TimeoutError:
402400
logger.warning(
403401
f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} "
@@ -436,7 +434,6 @@ async def do_close() -> None:
436434
f"ws: {self._ws}"
437435
)
438436
self._state = SessionState.CLOSING
439-
self._closing_waiter = asyncio.Event()
440437

441438
# We're closing, so we need to wake up...
442439
# ... tasks waiting for connection to be established
@@ -502,14 +499,11 @@ async def do_close() -> None:
502499
# This will get us GC'd, so this should be the last thing.
503500
self._close_session_callback(self)
504501

505-
# Release waiters, then release the event
506-
self._closing_waiter.set()
507-
self._closing_waiter = None
508-
509502
if self._terminating_task:
510503
return self._terminating_task
511504

512-
return asyncio.create_task(do_close())
505+
self._terminating_task = asyncio.create_task(do_close())
506+
return self._terminating_task
513507

514508
def _start_buffered_message_sender(
515509
self,

tests/v2/test_v2_session_lifecycle.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from websockets.asyncio.server import ServerConnection, serve
77
from websockets.typing import Data
88

9+
from replit_river.common_session import SessionState
910
from replit_river.messages import parse_transport_msg
1011
from replit_river.rate_limiter import RateLimiter
1112
from replit_river.rpc import TransportMessage
@@ -114,14 +115,20 @@ async def test_connect(ws_server: WsServerFixture) -> None:
114115
await connecting
115116

116117

117-
async def test_reconnect(ws_server: WsServerFixture) -> None:
118+
async def test_close_race(ws_server: WsServerFixture) -> None:
118119
(urimeta, recv, conn) = ws_server
119120

121+
callcount = 0
122+
123+
def close_session_callback(_session: Session) -> None:
124+
nonlocal callcount
125+
callcount += 1
126+
120127
session = Session(
121128
server_id="SERVER",
122129
session_id="SESSION1",
123130
transport_options=TransportOptions(),
124-
close_session_callback=lambda _: None,
131+
close_session_callback=close_session_callback,
125132
client_id="CLIENT1",
126133
rate_limiter=_PermissiveRateLimiter(),
127134
uri_and_metadata_factory=urimeta,
@@ -132,4 +139,9 @@ async def test_reconnect(ws_server: WsServerFixture) -> None:
132139
assert isinstance(msg, TransportMessage)
133140
assert msg.payload["type"] == "HANDSHAKE_REQ"
134141
await session.close()
142+
await session.close()
143+
await session.close()
144+
await session.close()
135145
await connecting
146+
assert session._state == SessionState.CLOSED
147+
assert callcount == 1

0 commit comments

Comments
 (0)