@@ -41,6 +41,8 @@ def consume_budget(self, user: str) -> None:
4141async def test_connect (ws_server : WsServerFixture ) -> None :
4242 (urimeta , recv , conn ) = ws_server
4343
44+ ws_close : asyncio .Task | None = None
45+
4446 def trigger_close (
4547 signal_closing : Callable [[], None ],
4648 task_manager : BackgroundTaskManager , # .cancel_all_tasks()
@@ -49,9 +51,23 @@ def trigger_close(
4951 ws : ClientConnection | None ,
5052 become_closed : Callable [[], None ],
5153 ) -> asyncio .Event :
52- event = asyncio .Event ()
53- event .set ()
54- return event
54+ nonlocal ws_close
55+
56+ closing_event = asyncio .Event ()
57+
58+ async def _do_close () -> None :
59+ signal_closing ()
60+ await task_manager .cancel_all_tasks ()
61+ terminate_remaining_output_streams ()
62+ await join_output_streams_with_timeout ()
63+ if ws :
64+ await ws .close ()
65+ become_closed ()
66+ closing_event .set ()
67+
68+ ws_close = asyncio .create_task (_do_close ())
69+
70+ return closing_event
5571
5672 session = Session (
5773 server_id = "SERVER" ,
@@ -68,6 +84,8 @@ def trigger_close(
6884 assert isinstance (msg , TransportMessage )
6985 assert msg .payload ["type" ] == "HANDSHAKE_REQ"
7086 await session .close ().wait ()
87+ assert ws_close is not None
88+ await ws_close
7189 await connecting
7290
7391
@@ -77,6 +95,7 @@ async def test_close_race(ws_server: WsServerFixture) -> None:
7795 callcount = 0
7896
7997 event : asyncio .Event | None = None
98+ ws_close : asyncio .Task | None = None
8099
81100 def trigger_close (
82101 signal_closing : Callable [[], None ],
@@ -87,11 +106,25 @@ def trigger_close(
87106 become_closed : Callable [[], None ],
88107 ) -> asyncio .Event :
89108 nonlocal event
109+ nonlocal ws_close
110+
90111 if event is None :
91112 event = asyncio .Event ()
92113 event .set ()
93114 nonlocal callcount
94115 callcount += 1
116+
117+ async def _do_close () -> None :
118+ signal_closing ()
119+ await task_manager .cancel_all_tasks ()
120+ terminate_remaining_output_streams ()
121+ await join_output_streams_with_timeout ()
122+ if ws :
123+ await ws .close ()
124+ become_closed ()
125+
126+ ws_close = asyncio .create_task (_do_close ())
127+
95128 return event
96129
97130 session = Session (
0 commit comments