Skip to content

Commit 9010fbb

Browse files
Patching ws lifecycle
1 parent caeea3e commit 9010fbb

File tree

1 file changed

+36
-3
lines changed

1 file changed

+36
-3
lines changed

tests/v2/test_v2_session_lifecycle.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ def consume_budget(self, user: str) -> None:
4141
async def test_connect(ws_server: WsServerFixture) -> None:
4242
(urimeta, recv, conn) = ws_server
4343

44+
ws_close: asyncio.Task | None = None
45+
4446
def trigger_close(
4547
signal_closing: Callable[[], None],
4648
task_manager: BackgroundTaskManager, # .cancel_all_tasks()
@@ -49,9 +51,23 @@ def trigger_close(
4951
ws: ClientConnection | None,
5052
become_closed: Callable[[], None],
5153
) -> asyncio.Event:
52-
event = asyncio.Event()
53-
event.set()
54-
return event
54+
nonlocal ws_close
55+
56+
closing_event = asyncio.Event()
57+
58+
async def _do_close() -> None:
59+
signal_closing()
60+
await task_manager.cancel_all_tasks()
61+
terminate_remaining_output_streams()
62+
await join_output_streams_with_timeout()
63+
if ws:
64+
await ws.close()
65+
become_closed()
66+
closing_event.set()
67+
68+
ws_close = asyncio.create_task(_do_close())
69+
70+
return closing_event
5571

5672
session = Session(
5773
server_id="SERVER",
@@ -68,6 +84,8 @@ def trigger_close(
6884
assert isinstance(msg, TransportMessage)
6985
assert msg.payload["type"] == "HANDSHAKE_REQ"
7086
await session.close().wait()
87+
assert ws_close is not None
88+
await ws_close
7189
await connecting
7290

7391

@@ -77,6 +95,7 @@ async def test_close_race(ws_server: WsServerFixture) -> None:
7795
callcount = 0
7896

7997
event: asyncio.Event | None = None
98+
ws_close: asyncio.Task | None = None
8099

81100
def trigger_close(
82101
signal_closing: Callable[[], None],
@@ -87,11 +106,25 @@ def trigger_close(
87106
become_closed: Callable[[], None],
88107
) -> asyncio.Event:
89108
nonlocal event
109+
nonlocal ws_close
110+
90111
if event is None:
91112
event = asyncio.Event()
92113
event.set()
93114
nonlocal callcount
94115
callcount += 1
116+
117+
async def _do_close() -> None:
118+
signal_closing()
119+
await task_manager.cancel_all_tasks()
120+
terminate_remaining_output_streams()
121+
await join_output_streams_with_timeout()
122+
if ws:
123+
await ws.close()
124+
become_closed()
125+
126+
ws_close = asyncio.create_task(_do_close())
127+
95128
return event
96129

97130
session = Session(

0 commit comments

Comments
 (0)