|
| 1 | +import asyncio |
| 2 | +from time import sleep |
| 3 | +from typing import Tuple |
| 4 | + |
| 5 | +import pytest |
| 6 | +import reactivex |
| 7 | +from reactivex import operators, Observable |
| 8 | +from reactivex.operators._tofuture import to_future_ |
| 9 | +from reactivex.scheduler import ThreadPoolScheduler |
| 10 | +from reactivex.scheduler.eventloop import AsyncIOScheduler |
| 11 | + |
| 12 | +from rsocket.exceptions import RSocketProtocolError |
| 13 | +from rsocket.frame_helpers import ensure_bytes |
| 14 | +from rsocket.payload import Payload |
| 15 | +from rsocket.reactivex.reactivex_client import ReactiveXClient |
| 16 | +from rsocket.reactivex.reactivex_handler import BaseReactivexHandler |
| 17 | +from rsocket.reactivex.reactivex_handler_adapter import reactivex_handler_factory |
| 18 | +from rsocket.rsocket_client import RSocketClient |
| 19 | +from rsocket.rsocket_server import RSocketServer |
| 20 | + |
| 21 | + |
| 22 | +async def test_serve_reactivex_stream_disconnected(pipe: Tuple[RSocketServer, RSocketClient]): |
| 23 | + sent_counter = 0 |
| 24 | + |
| 25 | + def increment_sent_counter(): |
| 26 | + nonlocal sent_counter |
| 27 | + sent_counter += 1 |
| 28 | + |
| 29 | + def delayed(message): |
| 30 | + sleep(0.3) |
| 31 | + return message |
| 32 | + |
| 33 | + class Handler(BaseReactivexHandler): |
| 34 | + |
| 35 | + async def request_stream(self, payload: Payload) -> Observable: |
| 36 | + return reactivex.from_((delayed('Feed Item: {}'.format(index)) for index in range(10)), |
| 37 | + ThreadPoolScheduler()).pipe( |
| 38 | + operators.do_action(on_next=lambda _: increment_sent_counter()), |
| 39 | + operators.map(lambda _: Payload(ensure_bytes(_))) |
| 40 | + ) |
| 41 | + |
| 42 | + server, client = pipe |
| 43 | + |
| 44 | + server.set_handler_using_factory(reactivex_handler_factory(Handler)) |
| 45 | + |
| 46 | + async def request(): |
| 47 | + await ReactiveXClient(client).request_stream(Payload(b'request text'), |
| 48 | + request_limit=2).pipe( |
| 49 | + operators.map(lambda payload: payload.data), |
| 50 | + operators.to_list(), |
| 51 | + to_future_(scheduler=AsyncIOScheduler(loop=asyncio.get_event_loop())) |
| 52 | + ) |
| 53 | + |
| 54 | + task = asyncio.create_task(request()) |
| 55 | + |
| 56 | + await asyncio.sleep(1) |
| 57 | + |
| 58 | + await client.close() |
| 59 | + |
| 60 | + assert 0 < sent_counter < 5 |
| 61 | + |
| 62 | + with pytest.raises(RSocketProtocolError): |
| 63 | + await task |
0 commit comments