Skip to content

Commit 99c3106

Browse files
added test for websocket reconnect flow.
better reconnect flow
1 parent 01f4fbf commit 99c3106

File tree

7 files changed

+150
-72
lines changed

7 files changed

+150
-72
lines changed

rsocket/rsocket_base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(self,
7979
self._responder_lease = None
8080
self._requester_lease = None
8181
self._is_closing = False
82+
self._connecting = True
8283

8384
self._async_frame_handler_by_type: Dict[Type[Frame], Any] = {
8485
RequestResponseFrame: self.handle_request_response,
@@ -420,7 +421,7 @@ async def close(self):
420421
async def _close_transport(self):
421422
if self._current_transport().done():
422423
logger().debug('%s: Closing transport', self._log_identifier())
423-
transport = self._current_transport().result()
424+
transport = await self._current_transport()
424425

425426
if transport is not None:
426427
try:

rsocket/rsocket_client.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,17 @@ async def connect(self):
7070
return await super().connect()
7171

7272
async def _connect_new_transport(self):
73-
new_transport = await self._get_new_transport()
73+
try:
74+
new_transport = await self._get_new_transport()
7475

75-
if new_transport is None:
76-
raise RSocketNoAvailableTransport()
76+
if new_transport is None:
77+
raise RSocketNoAvailableTransport()
7778

78-
self._next_transport.set_result(new_transport)
79-
transport = await self._current_transport()
80-
await transport.connect()
79+
self._next_transport.set_result(new_transport)
80+
transport = await self._current_transport()
81+
await transport.connect()
82+
finally:
83+
self._connecting = False
8184

8285
async def _get_new_transport(self):
8386
try:
@@ -110,6 +113,13 @@ async def _reconnect_listener(self):
110113
try:
111114
while True:
112115
await self._connect_request_event.wait()
116+
117+
logger().debug('%s: Got reconnect request', self._log_identifier())
118+
119+
if self._connecting:
120+
continue
121+
122+
self._connecting = True
113123
self._connect_request_event.clear()
114124
await self._close(reconnect=True)
115125
self._next_transport = create_future()

rsocket/transports/abstract_messaging.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ def __init__(self):
1212
async def next_frame_generator(self):
1313
frame = await self._incoming_frame_queue.get()
1414

15+
if isinstance(frame, Exception):
16+
raise frame
17+
1518
async def frame_generator():
1619
yield frame
1720

rsocket/transports/aiohttp_websocket.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import aiohttp
55
from aiohttp import web
66

7+
from rsocket.exceptions import RSocketTransportError
78
from rsocket.frame import Frame
89
from rsocket.helpers import wrap_transport_exception, single_transport_provider
910
from rsocket.logger import logger
@@ -44,20 +45,29 @@ def __init__(self, url):
4445
self._ws_context = None
4546
self._ws = None
4647
self._message_handler = None
48+
self._connection_ready = asyncio.Event()
4749

4850
async def connect(self):
4951
self._session = aiohttp.ClientSession()
5052
self._ws_context = self._session.ws_connect(self._url)
5153
self._ws = await self._ws_context.__aenter__()
54+
self._connection_ready.set()
5255
self._message_handler = asyncio.create_task(self.handle_incoming_ws_messages())
5356

5457
async def handle_incoming_ws_messages(self):
55-
async for msg in self._ws:
56-
if msg.type == aiohttp.WSMsgType.BINARY:
57-
async for frame in self._frame_parser.receive_data(msg.data, 0):
58-
self._incoming_frame_queue.put_nowait(frame)
58+
await self._connection_ready.wait()
59+
try:
60+
async for msg in self._ws:
61+
if msg.type == aiohttp.WSMsgType.BINARY:
62+
async for frame in self._frame_parser.receive_data(msg.data, 0):
63+
self._incoming_frame_queue.put_nowait(frame)
64+
except asyncio.CancelledError:
65+
pass
66+
except Exception:
67+
self._incoming_frame_queue.put_nowait(RSocketTransportError())
5968

6069
async def send_frame(self, frame: Frame):
70+
await self._connection_ready.wait()
6171
with wrap_transport_exception():
6272
await self._ws.send_bytes(frame.serialize())
6373

tests/rsocket/helpers.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
from dataclasses import dataclass
23
from datetime import timedelta
34
from math import ceil
45
from typing import Type
@@ -8,6 +9,8 @@
89
from rsocket.payload import Payload
910
from rsocket.request_handler import BaseRequestHandler
1011
from rsocket.rsocket_base import RSocketBase
12+
from rsocket.rsocket_server import RSocketServer
13+
from rsocket.transports.transport import Transport
1114

1215

1316
def data_bits(data: bytes, name: str = None):
@@ -69,6 +72,12 @@ def factory(self, socket) -> BaseRequestHandler:
6972
return self._handler_factory(socket, self._server_id, self._delay)
7073

7174

72-
async def force_closing_connection(current_connection, delay=timedelta(0)):
75+
async def force_closing_connection(transport, delay=timedelta(0)):
7376
await asyncio.sleep(delay.total_seconds())
74-
current_connection[1].close()
77+
await transport.close()
78+
79+
80+
@dataclass
81+
class ServerContainer:
82+
server: RSocketServer = None
83+
transport: Transport = None

tests/rsocket/test_connection_lost.py

Lines changed: 102 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from typing import Optional, Tuple
66

77
import pytest
8+
from aiohttp.test_utils import RawTestServer
9+
from asyncstdlib import sync
810

911
from reactivestreams.publisher import Publisher
1012
from rsocket.awaitable.awaitable_rsocket import AwaitableRSocket
@@ -17,10 +19,11 @@
1719
from rsocket.rsocket_client import RSocketClient
1820
from rsocket.rsocket_server import RSocketServer
1921
from rsocket.streams.stream_from_async_generator import StreamFromAsyncGenerator
22+
from rsocket.transports.aiohttp_websocket import websocket_handler_factory, TransportAioHttpClient
2023
from rsocket.transports.tcp import TransportTCP
2124
from rsocket.transports.transport import Transport
2225
from tests.rsocket.helpers import future_from_payload, IdentifiedHandlerFactory, \
23-
IdentifiedHandler, force_closing_connection
26+
IdentifiedHandler, force_closing_connection, ServerContainer
2427

2528

2629
class ServerHandler(IdentifiedHandler):
@@ -45,7 +48,7 @@ async def test_connection_lost(unused_tcp_port):
4548
index_iterator = iter(range(1, 3))
4649

4750
wait_for_server = Event()
48-
server_connection: Optional[Tuple] = None
51+
transport: Optional[Transport] = None
4952
client_connection: Optional[Tuple] = None
5053

5154
class ClientHandler(BaseRequestHandler):
@@ -54,9 +57,9 @@ async def on_connection_lost(self, rsocket, exception: Exception):
5457
await rsocket.reconnect()
5558

5659
def session(*connection):
57-
nonlocal server, server_connection
58-
server_connection = connection
59-
server = RSocketServer(TransportTCP(*connection),
60+
nonlocal server, transport
61+
transport = TransportTCP(*connection)
62+
server = RSocketServer(transport,
6063
IdentifiedHandlerFactory(next(index_iterator), ServerHandler).factory)
6164
wait_for_server.set()
6265

@@ -89,7 +92,9 @@ async def transport_provider():
8992
await wait_for_server.wait()
9093
wait_for_server.clear()
9194
response1 = await connection.request_response(Payload(b'request 1'))
92-
await force_closing_connection(server_connection)
95+
96+
await force_closing_connection(transport)
97+
9398
await server.close() # cleanup async tasks from previous server to avoid errors (?)
9499
await wait_for_server.wait()
95100
response2 = await connection.request_response(Payload(b'request 2'))
@@ -118,11 +123,11 @@ async def close(self):
118123

119124

120125
@pytest.mark.allow_error_log(regex_filter='Connection error')
121-
async def test_connection_failure(unused_tcp_port: int):
126+
async def test_tcp_connection_failure(unused_tcp_port: int):
122127
index_iterator = iter(range(1, 3))
123128

124129
wait_for_server = Event()
125-
server_connection: Optional[Tuple] = None
130+
transport: Optional[Transport] = None
126131
client_connection: Optional[Tuple] = None
127132

128133
class ClientHandler(BaseRequestHandler):
@@ -131,9 +136,9 @@ async def on_connection_lost(self, rsocket, exception: Exception):
131136
await rsocket.reconnect()
132137

133138
def session(*connection):
134-
nonlocal server, server_connection
135-
server_connection = connection
136-
server = RSocketServer(TransportTCP(*connection),
139+
nonlocal server, transport
140+
transport = TransportTCP(*connection)
141+
server = RSocketServer(transport,
137142
IdentifiedHandlerFactory(next(index_iterator), ServerHandler).factory)
138143
wait_for_server.set()
139144

@@ -172,7 +177,7 @@ async def transport_provider():
172177

173178
response1 = await connection.request_response(Payload(b'request 1'))
174179

175-
await force_closing_connection(server_connection)
180+
await force_closing_connection(transport)
176181

177182
await server.close() # cleanup async tasks from previous server to avoid errors (?)
178183
await wait_for_server.wait()
@@ -187,75 +192,115 @@ async def transport_provider():
187192
service.close()
188193

189194

190-
@pytest.mark.allow_error_log(regex_filter='Connection error')
191-
async def test_connection_failure_during_stream(unused_tcp_port):
192-
index_iterator = iter(range(1, 3))
195+
class ClientHandler(BaseRequestHandler):
196+
async def on_connection_lost(self, rsocket, exception: Exception):
197+
logger().info('Test Reconnecting')
198+
await rsocket.reconnect()
193199

194-
wait_for_server = Event()
195-
server_connection: Optional[Tuple] = None
196-
client_connection: Optional[Tuple] = None
197200

198-
class ClientHandler(BaseRequestHandler):
199-
async def on_connection_lost(self, rsocket, exception: Exception):
200-
logger().info('Test Reconnecting')
201-
await rsocket.reconnect()
201+
async def start_tcp_service(waiter: asyncio.Event, container, port: int):
202+
index_iterator = iter(range(1, 3))
202203

203204
def session(*connection):
204-
nonlocal server, server_connection
205-
server_connection = connection
206-
server = RSocketServer(TransportTCP(*connection),
207-
IdentifiedHandlerFactory(next(index_iterator),
208-
ServerHandler,
209-
delay=timedelta(seconds=1)).factory)
210-
wait_for_server.set()
205+
container.transport = TransportTCP(*connection)
206+
container.server = RSocketServer(container.transport,
207+
IdentifiedHandlerFactory(next(index_iterator),
208+
ServerHandler,
209+
delay=timedelta(seconds=1)).factory)
210+
waiter.set()
211211

212-
async def start():
213-
nonlocal service, client
214-
service = await asyncio.start_server(session, host, port)
212+
service = await asyncio.start_server(session, 'localhost', port)
213+
return sync(service.close)
215214

216-
async def transport_provider():
217-
try:
218-
nonlocal client_connection
219-
client_connection = await asyncio.open_connection(host, port)
220-
yield TransportTCP(*client_connection)
221215

222-
yield FailingTransportTCP()
216+
async def start_tcp_client(port: int) -> RSocketClient:
217+
async def transport_provider():
218+
try:
219+
client_connection = await asyncio.open_connection('localhost', port)
220+
yield TransportTCP(*client_connection)
223221

224-
client_connection = await asyncio.open_connection(host, port)
225-
yield TransportTCP(*client_connection)
226-
except Exception:
227-
logger().error('Client connection error', exc_info=True)
228-
raise
222+
yield FailingTransportTCP()
229223

230-
client = RSocketClient(transport_provider(), handler_factory=ClientHandler)
224+
client_connection = await asyncio.open_connection('localhost', port)
225+
yield TransportTCP(*client_connection)
226+
except Exception:
227+
logger().error('Client connection error', exc_info=True)
228+
raise
231229

232-
service: Optional[Server] = None
233-
server: Optional[RSocketServer] = None
234-
client: Optional[RSocketClient] = None
235-
port = unused_tcp_port
236-
host = 'localhost'
230+
return RSocketClient(transport_provider(), handler_factory=ClientHandler)
237231

238-
await start()
232+
233+
async def start_websocket_service(waiter: asyncio.Event, container, port: int):
234+
index_iterator = iter(range(1, 3))
235+
236+
def handler_factory(*args, **kwargs):
237+
return IdentifiedHandlerFactory(
238+
next(index_iterator),
239+
ServerHandler,
240+
delay=timedelta(seconds=1)).factory(*args, **kwargs)
241+
242+
def on_server_create(server):
243+
container.server = server
244+
container.transport = server._transport
245+
waiter.set()
246+
247+
server = RawTestServer(websocket_handler_factory(on_server_create=on_server_create,
248+
handler_factory=handler_factory), port=port)
249+
await server.start_server()
250+
return server.close
251+
252+
253+
async def start_websocket_client(port: int) -> RSocketClient:
254+
url = 'http://localhost:{}'.format(port)
255+
256+
async def transport_provider():
257+
try:
258+
yield TransportAioHttpClient(url)
259+
260+
yield FailingTransportTCP()
261+
262+
yield TransportAioHttpClient(url)
263+
except Exception:
264+
logger().error('Client connection error', exc_info=True)
265+
raise
266+
267+
return RSocketClient(transport_provider(), handler_factory=ClientHandler)
268+
269+
270+
@pytest.mark.allow_error_log(regex_filter='Connection error')
271+
@pytest.mark.parametrize(
272+
'start_service, start_client',
273+
(
274+
(start_tcp_service, start_tcp_client),
275+
(start_websocket_service, start_websocket_client),
276+
)
277+
)
278+
async def test_connection_failure_during_stream(unused_tcp_port, start_service, start_client):
279+
server_container = ServerContainer()
280+
wait_for_server = Event()
281+
282+
service_closer = await start_service(wait_for_server, server_container, unused_tcp_port)
283+
client = await start_client(unused_tcp_port)
239284

240285
try:
241-
async with AwaitableRSocket(client) as connection:
286+
async with AwaitableRSocket(client) as a_client:
242287
await wait_for_server.wait()
243288
wait_for_server.clear()
244289

245290
with pytest.raises(RSocketProtocolError) as exc_info:
246291
await asyncio.gather(
247-
connection.request_stream(Payload(b'request 1')),
248-
force_closing_connection(server_connection, timedelta(seconds=2)))
292+
a_client.request_stream(Payload(b'request 1')),
293+
force_closing_connection(server_container.transport, timedelta(seconds=2)))
249294

250295
assert exc_info.value.data == 'Connection error'
251296
assert exc_info.value.error_code == ErrorCode.CONNECTION_ERROR
252297

253-
await server.close() # cleanup async tasks from previous server to avoid errors (?)
298+
await server_container.server.close() # cleanup async tasks from previous server to avoid errors (?)
254299
await wait_for_server.wait()
255-
response2 = await connection.request_response(Payload(b'request 2'))
300+
response2 = await a_client.request_response(Payload(b'request 2'))
256301

257302
assert response2.data == b'data: request 2 server 2'
258303
finally:
259-
await server.close()
304+
await server_container.server.close()
260305

261-
service.close()
306+
await service_closer()

0 commit comments

Comments
 (0)