55from typing import Optional , Tuple
66
77import pytest
8+ from aiohttp .test_utils import RawTestServer
9+ from asyncstdlib import sync
810
911from reactivestreams .publisher import Publisher
1012from rsocket .awaitable .awaitable_rsocket import AwaitableRSocket
1719from rsocket .rsocket_client import RSocketClient
1820from rsocket .rsocket_server import RSocketServer
1921from rsocket .streams .stream_from_async_generator import StreamFromAsyncGenerator
22+ from rsocket .transports .aiohttp_websocket import websocket_handler_factory , TransportAioHttpClient
2023from rsocket .transports .tcp import TransportTCP
2124from rsocket .transports .transport import Transport
2225from tests .rsocket .helpers import future_from_payload , IdentifiedHandlerFactory , \
23- IdentifiedHandler , force_closing_connection
26+ IdentifiedHandler , force_closing_connection , ServerContainer
2427
2528
2629class ServerHandler (IdentifiedHandler ):
@@ -45,7 +48,7 @@ async def test_connection_lost(unused_tcp_port):
4548 index_iterator = iter (range (1 , 3 ))
4649
4750 wait_for_server = Event ()
48- server_connection : Optional [Tuple ] = None
51+ transport : Optional [Transport ] = None
4952 client_connection : Optional [Tuple ] = None
5053
5154 class ClientHandler (BaseRequestHandler ):
@@ -54,9 +57,9 @@ async def on_connection_lost(self, rsocket, exception: Exception):
5457 await rsocket .reconnect ()
5558
5659 def session (* connection ):
57- nonlocal server , server_connection
58- server_connection = connection
59- server = RSocketServer (TransportTCP ( * connection ) ,
60+ nonlocal server , transport
61+ transport = TransportTCP ( * connection )
62+ server = RSocketServer (transport ,
6063 IdentifiedHandlerFactory (next (index_iterator ), ServerHandler ).factory )
6164 wait_for_server .set ()
6265
@@ -89,7 +92,9 @@ async def transport_provider():
8992 await wait_for_server .wait ()
9093 wait_for_server .clear ()
9194 response1 = await connection .request_response (Payload (b'request 1' ))
92- await force_closing_connection (server_connection )
95+
96+ await force_closing_connection (transport )
97+
9398 await server .close () # cleanup async tasks from previous server to avoid errors (?)
9499 await wait_for_server .wait ()
95100 response2 = await connection .request_response (Payload (b'request 2' ))
@@ -118,11 +123,11 @@ async def close(self):
118123
119124
120125@pytest .mark .allow_error_log (regex_filter = 'Connection error' )
121- async def test_connection_failure (unused_tcp_port : int ):
126+ async def test_tcp_connection_failure (unused_tcp_port : int ):
122127 index_iterator = iter (range (1 , 3 ))
123128
124129 wait_for_server = Event ()
125- server_connection : Optional [Tuple ] = None
130+ transport : Optional [Transport ] = None
126131 client_connection : Optional [Tuple ] = None
127132
128133 class ClientHandler (BaseRequestHandler ):
@@ -131,9 +136,9 @@ async def on_connection_lost(self, rsocket, exception: Exception):
131136 await rsocket .reconnect ()
132137
133138 def session (* connection ):
134- nonlocal server , server_connection
135- server_connection = connection
136- server = RSocketServer (TransportTCP ( * connection ) ,
139+ nonlocal server , transport
140+ transport = TransportTCP ( * connection )
141+ server = RSocketServer (transport ,
137142 IdentifiedHandlerFactory (next (index_iterator ), ServerHandler ).factory )
138143 wait_for_server .set ()
139144
@@ -172,7 +177,7 @@ async def transport_provider():
172177
173178 response1 = await connection .request_response (Payload (b'request 1' ))
174179
175- await force_closing_connection (server_connection )
180+ await force_closing_connection (transport )
176181
177182 await server .close () # cleanup async tasks from previous server to avoid errors (?)
178183 await wait_for_server .wait ()
@@ -187,75 +192,115 @@ async def transport_provider():
187192 service .close ()
188193
189194
190- @pytest .mark .allow_error_log (regex_filter = 'Connection error' )
191- async def test_connection_failure_during_stream (unused_tcp_port ):
192- index_iterator = iter (range (1 , 3 ))
195+ class ClientHandler (BaseRequestHandler ):
196+ async def on_connection_lost (self , rsocket , exception : Exception ):
197+ logger ().info ('Test Reconnecting' )
198+ await rsocket .reconnect ()
193199
194- wait_for_server = Event ()
195- server_connection : Optional [Tuple ] = None
196- client_connection : Optional [Tuple ] = None
197200
198- class ClientHandler (BaseRequestHandler ):
199- async def on_connection_lost (self , rsocket , exception : Exception ):
200- logger ().info ('Test Reconnecting' )
201- await rsocket .reconnect ()
201+ async def start_tcp_service (waiter : asyncio .Event , container , port : int ):
202+ index_iterator = iter (range (1 , 3 ))
202203
203204 def session (* connection ):
204- nonlocal server , server_connection
205- server_connection = connection
206- server = RSocketServer (TransportTCP (* connection ),
207- IdentifiedHandlerFactory (next (index_iterator ),
208- ServerHandler ,
209- delay = timedelta (seconds = 1 )).factory )
210- wait_for_server .set ()
205+ container .transport = TransportTCP (* connection )
206+ container .server = RSocketServer (container .transport ,
207+ IdentifiedHandlerFactory (next (index_iterator ),
208+ ServerHandler ,
209+ delay = timedelta (seconds = 1 )).factory )
210+ waiter .set ()
211211
212- async def start ():
213- nonlocal service , client
214- service = await asyncio .start_server (session , host , port )
212+ service = await asyncio .start_server (session , 'localhost' , port )
213+ return sync (service .close )
215214
216- async def transport_provider ():
217- try :
218- nonlocal client_connection
219- client_connection = await asyncio .open_connection (host , port )
220- yield TransportTCP (* client_connection )
221215
222- yield FailingTransportTCP ()
216+ async def start_tcp_client (port : int ) -> RSocketClient :
217+ async def transport_provider ():
218+ try :
219+ client_connection = await asyncio .open_connection ('localhost' , port )
220+ yield TransportTCP (* client_connection )
223221
224- client_connection = await asyncio .open_connection (host , port )
225- yield TransportTCP (* client_connection )
226- except Exception :
227- logger ().error ('Client connection error' , exc_info = True )
228- raise
222+ yield FailingTransportTCP ()
229223
230- client = RSocketClient (transport_provider (), handler_factory = ClientHandler )
224+ client_connection = await asyncio .open_connection ('localhost' , port )
225+ yield TransportTCP (* client_connection )
226+ except Exception :
227+ logger ().error ('Client connection error' , exc_info = True )
228+ raise
231229
232- service : Optional [Server ] = None
233- server : Optional [RSocketServer ] = None
234- client : Optional [RSocketClient ] = None
235- port = unused_tcp_port
236- host = 'localhost'
230+ return RSocketClient (transport_provider (), handler_factory = ClientHandler )
237231
238- await start ()
232+
233+ async def start_websocket_service (waiter : asyncio .Event , container , port : int ):
234+ index_iterator = iter (range (1 , 3 ))
235+
236+ def handler_factory (* args , ** kwargs ):
237+ return IdentifiedHandlerFactory (
238+ next (index_iterator ),
239+ ServerHandler ,
240+ delay = timedelta (seconds = 1 )).factory (* args , ** kwargs )
241+
242+ def on_server_create (server ):
243+ container .server = server
244+ container .transport = server ._transport
245+ waiter .set ()
246+
247+ server = RawTestServer (websocket_handler_factory (on_server_create = on_server_create ,
248+ handler_factory = handler_factory ), port = port )
249+ await server .start_server ()
250+ return server .close
251+
252+
253+ async def start_websocket_client (port : int ) -> RSocketClient :
254+ url = 'http://localhost:{}' .format (port )
255+
256+ async def transport_provider ():
257+ try :
258+ yield TransportAioHttpClient (url )
259+
260+ yield FailingTransportTCP ()
261+
262+ yield TransportAioHttpClient (url )
263+ except Exception :
264+ logger ().error ('Client connection error' , exc_info = True )
265+ raise
266+
267+ return RSocketClient (transport_provider (), handler_factory = ClientHandler )
268+
269+
270+ @pytest .mark .allow_error_log (regex_filter = 'Connection error' )
271+ @pytest .mark .parametrize (
272+ 'start_service, start_client' ,
273+ (
274+ (start_tcp_service , start_tcp_client ),
275+ (start_websocket_service , start_websocket_client ),
276+ )
277+ )
278+ async def test_connection_failure_during_stream (unused_tcp_port , start_service , start_client ):
279+ server_container = ServerContainer ()
280+ wait_for_server = Event ()
281+
282+ service_closer = await start_service (wait_for_server , server_container , unused_tcp_port )
283+ client = await start_client (unused_tcp_port )
239284
240285 try :
241- async with AwaitableRSocket (client ) as connection :
286+ async with AwaitableRSocket (client ) as a_client :
242287 await wait_for_server .wait ()
243288 wait_for_server .clear ()
244289
245290 with pytest .raises (RSocketProtocolError ) as exc_info :
246291 await asyncio .gather (
247- connection .request_stream (Payload (b'request 1' )),
248- force_closing_connection (server_connection , timedelta (seconds = 2 )))
292+ a_client .request_stream (Payload (b'request 1' )),
293+ force_closing_connection (server_container . transport , timedelta (seconds = 2 )))
249294
250295 assert exc_info .value .data == 'Connection error'
251296 assert exc_info .value .error_code == ErrorCode .CONNECTION_ERROR
252297
253- await server .close () # cleanup async tasks from previous server to avoid errors (?)
298+ await server_container . server .close () # cleanup async tasks from previous server to avoid errors (?)
254299 await wait_for_server .wait ()
255- response2 = await connection .request_response (Payload (b'request 2' ))
300+ response2 = await a_client .request_response (Payload (b'request 2' ))
256301
257302 assert response2 .data == b'data: request 2 server 2'
258303 finally :
259- await server .close ()
304+ await server_container . server .close ()
260305
261- service . close ()
306+ await service_closer ()
0 commit comments