Skip to content

Commit b49a73f

Browse files
Patching ws lifecycle
1 parent 06f6941 commit b49a73f

File tree

1 file changed

+34
-3
lines changed

1 file changed

+34
-3
lines changed

tests/v2/test_v2_session_lifecycle.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def consume_budget(self, user: str) -> None:
4242
async def test_connect(ws_server: WsServerFixture) -> None:
4343
(urimeta, recv, conn) = ws_server
4444

45+
ws_close: asyncio.Task | None = None
46+
4547
def trigger_close(
4648
signal_closing: Callable[[], None],
4749
task_manager: BackgroundTaskManager, # .cancel_all_tasks()
@@ -50,9 +52,21 @@ def trigger_close(
5052
ws: ClientConnection | None,
5153
become_closed: Callable[[], None],
5254
) -> asyncio.Event:
53-
event = asyncio.Event()
54-
event.set()
55-
return event
55+
nonlocal ws_close
56+
57+
closing_event = asyncio.Event()
58+
async def _do_close() -> None:
59+
signal_closing()
60+
await task_manager.cancel_all_tasks()
61+
terminate_remaining_output_streams()
62+
await join_output_streams_with_timeout()
63+
if ws:
64+
await ws.close()
65+
become_closed()
66+
closing_event.set()
67+
ws_close = asyncio.create_task(_do_close())
68+
69+
return closing_event
5670

5771
session = Session(
5872
server_id="SERVER",
@@ -69,6 +83,8 @@ def trigger_close(
6983
assert isinstance(msg, TransportMessage)
7084
assert msg.payload["type"] == "HANDSHAKE_REQ"
7185
await session.close().wait()
86+
assert ws_close is not None
87+
await ws_close
7288
await connecting
7389

7490

@@ -78,6 +94,7 @@ async def test_close_race(ws_server: WsServerFixture) -> None:
7894
callcount = 0
7995

8096
event: asyncio.Event | None = None
97+
ws_close: asyncio.Task | None = None
8198

8299
def trigger_close(
83100
signal_closing: Callable[[], None],
@@ -88,11 +105,25 @@ def trigger_close(
88105
become_closed: Callable[[], None],
89106
) -> asyncio.Event:
90107
nonlocal event
108+
nonlocal ws_close
109+
91110
if event is None:
92111
event = asyncio.Event()
93112
event.set()
94113
nonlocal callcount
95114
callcount += 1
115+
116+
async def _do_close() -> None:
117+
signal_closing()
118+
await task_manager.cancel_all_tasks()
119+
terminate_remaining_output_streams()
120+
await join_output_streams_with_timeout()
121+
if ws:
122+
await ws.close()
123+
become_closed()
124+
125+
ws_close = asyncio.create_task(_do_close())
126+
96127
return event
97128

98129
session = Session(

0 commit comments

Comments
 (0)