66from websockets .asyncio .server import ServerConnection , serve
77from websockets .typing import Data
88
9+ from replit_river .common_session import SessionState
910from replit_river .messages import parse_transport_msg
1011from replit_river .rate_limiter import RateLimiter
1112from replit_river .rpc import TransportMessage
@@ -114,14 +115,20 @@ async def test_connect(ws_server: WsServerFixture) -> None:
114115 await connecting
115116
116117
117- async def test_reconnect (ws_server : WsServerFixture ) -> None :
118+ async def test_close_race (ws_server : WsServerFixture ) -> None :
118119 (urimeta , recv , conn ) = ws_server
119120
121+ callcount = 0
122+
123+ def close_session_callback (_session : Session ) -> None :
124+ nonlocal callcount
125+ callcount += 1
126+
120127 session = Session (
121128 server_id = "SERVER" ,
122129 session_id = "SESSION1" ,
123130 transport_options = TransportOptions (),
124- close_session_callback = lambda _ : None ,
131+ close_session_callback = close_session_callback ,
125132 client_id = "CLIENT1" ,
126133 rate_limiter = _PermissiveRateLimiter (),
127134 uri_and_metadata_factory = urimeta ,
@@ -132,4 +139,9 @@ async def test_reconnect(ws_server: WsServerFixture) -> None:
132139 assert isinstance (msg , TransportMessage )
133140 assert msg .payload ["type" ] == "HANDSHAKE_REQ"
134141 await session .close ()
142+ await session .close ()
143+ await session .close ()
144+ await session .close ()
135145 await connecting
146+ assert session ._state == SessionState .CLOSED
147+ assert callcount == 1
0 commit comments