|
1 | 1 | import asyncio |
2 | 2 | import logging |
3 | | -import sys |
| 3 | +import ssl |
4 | 4 | from dataclasses import dataclass |
5 | 5 | from datetime import timedelta |
6 | 6 | from typing import Optional |
7 | 7 |
|
| 8 | +import asyncclick as click |
| 9 | +from aiohttp import web |
| 10 | + |
8 | 11 | from examples.example_fixtures import large_data1 |
| 12 | +from examples.fixtures import cert_gen |
9 | 13 | from examples.response_channel import response_stream_1, LoggingSubscriber |
10 | 14 | from response_stream import response_stream_2 |
11 | 15 | from rsocket.extensions.authentication import Authentication, AuthenticationSimple |
|
15 | 19 | from rsocket.routing.request_router import RequestRouter |
16 | 20 | from rsocket.routing.routing_request_handler import RoutingRequestHandler |
17 | 21 | from rsocket.rsocket_server import RSocketServer |
| 22 | +from rsocket.transports.aiohttp_websocket import TransportAioHttpWebsocket |
18 | 23 | from rsocket.transports.tcp import TransportTCP |
19 | 24 |
|
20 | 25 | router = RequestRouter() |
@@ -106,16 +111,49 @@ def handle_client(reader, writer): |
106 | 111 | RSocketServer(TransportTCP(reader, writer), handler_factory=handler_factory) |
107 | 112 |
|
108 | 113 |
|
109 | | -async def run_server(server_port): |
110 | | - logging.info('Starting server at localhost:%s', server_port) |
| 114 | +def websocket_handler_factory(**kwargs): |
| 115 | + async def websocket_handler(request): |
| 116 | + ws = web.WebSocketResponse() |
| 117 | + await ws.prepare(request) |
| 118 | + transport = TransportAioHttpWebsocket(ws) |
| 119 | + RSocketServer(transport, **kwargs) |
| 120 | + await transport.handle_incoming_ws_messages() |
| 121 | + return ws |
| 122 | + |
| 123 | + return websocket_handler |
| 124 | + |
| 125 | + |
| 126 | +@click.command() |
| 127 | +@click.option('--port', help='Port to listen on', default=6565, type=int) |
| 128 | +@click.option('--with-ssl', is_flag=True, help='Enable SSL mode') |
| 129 | +@click.option('--transport', is_flag=False, default='tcp') |
| 130 | +async def start_server(with_ssl: bool, port: int, transport: str): |
| 131 | + logging.basicConfig(level=logging.DEBUG) |
| 132 | + |
| 133 | + logging.info(f'Starting {transport} server at localhost:{port}') |
| 134 | + |
| 135 | + if transport in ['ws', 'wss']: |
| 136 | + app = web.Application() |
| 137 | + app.add_routes([web.get('/', websocket_handler_factory(handler_factory=handler_factory))]) |
| 138 | + |
| 139 | + if with_ssl: |
| 140 | + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) |
| 141 | + |
| 142 | + with cert_gen() as (certificate, key): |
| 143 | + ssl_context.load_cert_chain(certificate, key) |
| 144 | + else: |
| 145 | + ssl_context = None |
111 | 146 |
|
112 | | - server = await asyncio.start_server(handle_client, 'localhost', server_port) |
| 147 | + await web._run_app(app, port=port, ssl_context=ssl_context) |
| 148 | + elif transport == 'tcp': |
113 | 149 |
|
114 | | - async with server: |
115 | | - await server.serve_forever() |
| 150 | + server = await asyncio.start_server(handle_client, 'localhost', port) |
| 151 | + |
| 152 | + async with server: |
| 153 | + await server.serve_forever() |
| 154 | + else: |
| 155 | + raise Exception(f'Unsupported transport {transport}') |
116 | 156 |
|
117 | 157 |
|
118 | 158 | if __name__ == '__main__': |
119 | | - port = sys.argv[1] if len(sys.argv) > 1 else 6565 |
120 | | - logging.basicConfig(level=logging.DEBUG) |
121 | | - asyncio.run(run_server(port)) |
| 159 | + start_server() |
0 commit comments