@@ -42,6 +42,8 @@ def consume_budget(self, user: str) -> None:
4242async def test_connect (ws_server : WsServerFixture ) -> None :
4343 (urimeta , recv , conn ) = ws_server
4444
45+ ws_close : asyncio .Task | None = None
46+
4547 def trigger_close (
4648 signal_closing : Callable [[], None ],
4749 task_manager : BackgroundTaskManager , # .cancel_all_tasks()
@@ -50,9 +52,21 @@ def trigger_close(
5052 ws : ClientConnection | None ,
5153 become_closed : Callable [[], None ],
5254 ) -> asyncio .Event :
53- event = asyncio .Event ()
54- event .set ()
55- return event
55+ nonlocal ws_close
56+
57+ closing_event = asyncio .Event ()
58+ async def _do_close () -> None :
59+ signal_closing ()
60+ await task_manager .cancel_all_tasks ()
61+ terminate_remaining_output_streams ()
62+ await join_output_streams_with_timeout ()
63+ if ws :
64+ await ws .close ()
65+ become_closed ()
66+ closing_event .set ()
67+ ws_close = asyncio .create_task (_do_close ())
68+
69+ return closing_event
5670
5771 session = Session (
5872 server_id = "SERVER" ,
@@ -69,6 +83,8 @@ def trigger_close(
6983 assert isinstance (msg , TransportMessage )
7084 assert msg .payload ["type" ] == "HANDSHAKE_REQ"
7185 await session .close ().wait ()
86+ assert ws_close is not None
87+ await ws_close
7288 await connecting
7389
7490
@@ -78,6 +94,7 @@ async def test_close_race(ws_server: WsServerFixture) -> None:
7894 callcount = 0
7995
8096 event : asyncio .Event | None = None
97+ ws_close : asyncio .Task | None = None
8198
8299 def trigger_close (
83100 signal_closing : Callable [[], None ],
@@ -88,11 +105,25 @@ def trigger_close(
88105 become_closed : Callable [[], None ],
89106 ) -> asyncio .Event :
90107 nonlocal event
108+ nonlocal ws_close
109+
91110 if event is None :
92111 event = asyncio .Event ()
93112 event .set ()
94113 nonlocal callcount
95114 callcount += 1
115+
116+ async def _do_close () -> None :
117+ signal_closing ()
118+ await task_manager .cancel_all_tasks ()
119+ terminate_remaining_output_streams ()
120+ await join_output_streams_with_timeout ()
121+ if ws :
122+ await ws .close ()
123+ become_closed ()
124+
125+ ws_close = asyncio .create_task (_do_close ())
126+
96127 return event
97128
98129 session = Session (
0 commit comments