@@ -51,13 +51,23 @@ async def _assert_connection_dead(sock: AsyncioSocket) -> None:
5151class Server :
5252 def __init__ (self , sock : AsyncioSocket ):
5353 self ._sock = sock
54+ self ._sockets = []
5455
5556 async def accept (self ) -> AsyncioSocket :
5657 logger .debug (f'Accepting connections on { self ._sock .getsockname ()} ' )
5758 server_connection , _ = await self ._sock .accept ()
59+ self ._sockets .append (server_connection )
5860 logger .debug (f'Accepted a connection on { server_connection .getsockname ()} ' )
5961 return server_connection
6062
63+ async def __aenter__ (self ):
64+ return self
65+
66+ async def __aexit__ (self , exc_type , exc_value , traceback ):
67+ for sock in self ._sockets :
68+ sock .close ()
69+ self ._sock .close ()
70+
6171 def get_port (self ) -> int :
6272 return self ._sock .getsockname ()[1 ]
6373
@@ -67,13 +77,20 @@ async def _make_client(
6777 gate : chaos .TcpGate ,
6878 asyncio_socket : AsyncioSocketsFactory ,
6979):
80+ # collect all produced sockets to properly close them during teardown
81+ sockets = []
82+
7083 async def make_client ():
7184 sock = asyncio_socket .tcp ()
7285 sock .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 )
86+ sockets .append (sock )
7387 await sock .connect (gate .get_sockname_for_clients ())
7488 return sock
7589
76- return make_client
90+ yield make_client
91+
92+ for sock in sockets :
93+ sock .close ()
7794
7895
7996@pytest .fixture (name = 'tcp_server' )
@@ -82,8 +99,8 @@ async def _server(asyncio_socket: AsyncioSocketsFactory):
8299 sock .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
83100 sock .bind (('localhost' , 0 ))
84101 sock .listen ()
85- yield Server (sock )
86- sock . close ()
102+ async with Server (sock ) as server :
103+ yield server
87104
88105
89106@pytest .fixture (name = 'gate' )
@@ -99,18 +116,15 @@ async def _gate(tcp_server):
99116
100117@pytest .fixture (name = 'tcp_client' )
101118async def _client (make_client ):
102- sock = await make_client ()
103- yield sock
104- sock .close ()
119+ yield await make_client ()
105120
106121
107122@pytest .fixture (name = 'server_connection' )
108- async def _server_connection (tcp_server , gate ):
123+ async def _server_connection (tcp_server , gate , tcp_client ):
109124 sock = await tcp_server .accept ()
110125 await gate .wait_for_connections (count = 1 )
111126 assert gate .connections_count () >= 1
112- yield sock
113- sock .close ()
127+ return sock
114128
115129
116130async def test_basic (tcp_client , gate , server_connection ):
@@ -122,6 +136,14 @@ async def test_basic(tcp_client, gate, server_connection):
122136 assert gate .connections_count () == 0
123137
124138
139+ async def test_gate (gate ):
140+ pass
141+
142+
143+ async def test_server_connection (server_connection ):
144+ pass
145+
146+
125147async def test_to_client_noop (tcp_client , gate , server_connection ):
126148 await gate .to_client_noop ()
127149
0 commit comments