66
77import pytest
88from aiohttp .test_utils import RawTestServer
9+ from aioquic .quic .configuration import QuicConfiguration
910from asyncstdlib import sync
11+ from cryptography .hazmat .primitives import serialization
1012
1113from reactivestreams .publisher import Publisher
1214from rsocket .awaitable .awaitable_rsocket import AwaitableRSocket
2022from rsocket .rsocket_server import RSocketServer
2123from rsocket .streams .stream_from_async_generator import StreamFromAsyncGenerator
2224from rsocket .transports .aiohttp_websocket import websocket_handler_factory , TransportAioHttpClient
25+ from rsocket .transports .aioquic_transport import rsocket_connect , rsocket_serve
2326from rsocket .transports .tcp import TransportTCP
2427from rsocket .transports .transport import Transport
2528from tests .rsocket .helpers import future_from_payload , IdentifiedHandlerFactory , \
@@ -107,7 +110,7 @@ async def transport_provider():
107110 service .close ()
108111
109112
110- class FailingTransportTCP (Transport ):
113+ class FailingTransport (Transport ):
111114
112115 async def connect (self ):
113116 raise Exception
@@ -152,7 +155,7 @@ async def transport_provider():
152155 client_connection = await asyncio .open_connection (host , port )
153156 yield TransportTCP (* client_connection )
154157
155- yield FailingTransportTCP ()
158+ yield FailingTransport ()
156159
157160 client_connection = await asyncio .open_connection (host , port )
158161 yield TransportTCP (* client_connection )
@@ -198,7 +201,7 @@ async def on_connection_lost(self, rsocket, exception: Exception):
198201 await rsocket .reconnect ()
199202
200203
201- async def start_tcp_service (waiter : asyncio .Event , container , port : int ):
204+ async def start_tcp_service (waiter : asyncio .Event , container , port : int , generate_test_certificates ):
202205 index_iterator = iter (range (1 , 3 ))
203206
204207 def session (* connection ):
@@ -213,13 +216,13 @@ def session(*connection):
213216 return sync (service .close )
214217
215218
216- async def start_tcp_client (port : int ) -> RSocketClient :
219+ async def start_tcp_client (port : int , generate_test_certificates ) -> RSocketClient :
217220 async def transport_provider ():
218221 try :
219222 client_connection = await asyncio .open_connection ('localhost' , port )
220223 yield TransportTCP (* client_connection )
221224
222- yield FailingTransportTCP ()
225+ yield FailingTransport ()
223226
224227 client_connection = await asyncio .open_connection ('localhost' , port )
225228 yield TransportTCP (* client_connection )
@@ -230,7 +233,7 @@ async def transport_provider():
230233 return RSocketClient (transport_provider (), handler_factory = ClientHandler )
231234
232235
233- async def start_websocket_service (waiter : asyncio .Event , container , port : int ):
236+ async def start_websocket_service (waiter : asyncio .Event , container , port : int , generate_test_certificates ):
234237 index_iterator = iter (range (1 , 3 ))
235238
236239 def handler_factory (* args , ** kwargs ):
@@ -250,14 +253,14 @@ def on_server_create(server):
250253 return server .close
251254
252255
253- async def start_websocket_client (port : int ) -> RSocketClient :
256+ async def start_websocket_client (port : int , generate_test_certificates ) -> RSocketClient :
254257 url = 'http://localhost:{}' .format (port )
255258
256259 async def transport_provider ():
257260 try :
258261 yield TransportAioHttpClient (url )
259262
260- yield FailingTransportTCP ()
263+ yield FailingTransport ()
261264
262265 yield TransportAioHttpClient (url )
263266 except Exception :
@@ -267,20 +270,76 @@ async def transport_provider():
267270 return RSocketClient (transport_provider (), handler_factory = ClientHandler )
268271
269272
273+ async def start_quic_service (waiter : asyncio .Event , container , port : int , generate_test_certificates ):
274+ index_iterator = iter (range (1 , 3 ))
275+ certificate , private_key = generate_test_certificates
276+ server_configuration = QuicConfiguration (
277+ certificate = certificate ,
278+ private_key = private_key ,
279+ is_client = False
280+ )
281+
282+ def handler_factory (* args , ** kwargs ):
283+ return IdentifiedHandlerFactory (
284+ next (index_iterator ),
285+ ServerHandler ,
286+ delay = timedelta (seconds = 1 )).factory (* args , ** kwargs )
287+
288+ def on_server_create (server ):
289+ container .server = server
290+ container .transport = server ._transport
291+ waiter .set ()
292+
293+ quic_server = await rsocket_serve (host = 'localhost' ,
294+ port = port ,
295+ configuration = server_configuration ,
296+ on_server_create = on_server_create ,
297+ handler_factory = handler_factory )
298+ return sync (quic_server .close )
299+
300+
301+ async def start_quic_client (port : int , generate_test_certificates ) -> RSocketClient :
302+ certificate , private_key = generate_test_certificates
303+ client_configuration = QuicConfiguration (
304+ is_client = True
305+ )
306+ ca_data = certificate .public_bytes (serialization .Encoding .PEM )
307+ client_configuration .load_verify_locations (cadata = ca_data , cafile = None )
308+
309+ async def transport_provider ():
310+ try :
311+ async with rsocket_connect ('localhost' , port ,
312+ configuration = client_configuration ) as transport :
313+ yield transport
314+
315+ yield FailingTransport ()
316+
317+ async with rsocket_connect ('localhost' , port ,
318+ configuration = client_configuration ) as transport :
319+ yield transport
320+ except Exception :
321+ logger ().error ('Client connection error' , exc_info = True )
322+ raise
323+
324+ return RSocketClient (transport_provider (), handler_factory = ClientHandler )
325+
326+
270327@pytest .mark .allow_error_log (regex_filter = 'Connection error' )
271328@pytest .mark .parametrize (
272329 'start_service, start_client' ,
273330 (
274331 (start_tcp_service , start_tcp_client ),
275332 (start_websocket_service , start_websocket_client ),
333+ (start_quic_service , start_quic_client ),
276334 )
277335)
278- async def test_connection_failure_during_stream (unused_tcp_port , start_service , start_client ):
336+ async def test_connection_failure_during_stream (unused_tcp_port , generate_test_certificates ,
337+ start_service , start_client ):
279338 server_container = ServerContainer ()
280339 wait_for_server = Event ()
281340
282- service_closer = await start_service (wait_for_server , server_container , unused_tcp_port )
283- client = await start_client (unused_tcp_port )
341+ service_closer = await start_service (wait_for_server , server_container , unused_tcp_port , generate_test_certificates )
342+ client = await start_client (unused_tcp_port , generate_test_certificates )
284343
285344 try :
286345 async with AwaitableRSocket (client ) as a_client :
0 commit comments