diff --git a/xconn/async_session.py b/xconn/async_session.py index 7265629..d855cf9 100644 --- a/xconn/async_session.py +++ b/xconn/async_session.py @@ -76,6 +76,49 @@ def __init__(self, base_session: types.IAsyncBaseSession): self._loop = get_event_loop() self.wait_task = self._loop.create_task(self._wait()) + async def _handle_invocation( + self, + msg: messages.Invocation, + endpoint: Union[ + Callable[[types.Invocation], types.Result], Callable[[types.Invocation], Awaitable[types.Result]] + ], + ): + try: + result = await endpoint(types.Invocation(msg.args, msg.kwargs, msg.details)) + + if result is None: + data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id))) + elif isinstance(result, types.Result): + data = self._session.send_message( + messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details)) + ) + else: + message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str( + type(result) + ) + msg_to_send = messages.Error( + messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message]) + ) + data = self._session.send_message(msg_to_send) + + await self._base_session.send(data) + except ApplicationError as e: + msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args)) + data = self._session.send_message(msg_to_send) + await self._base_session.send(data) + except Exception as e: + msg_to_send = messages.Error( + messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_RUNTIME_ERROR, [e.__str__()]) + ) + data = self._session.send_message(msg_to_send) + await self._base_session.send(data) + + async def _handle_event(self, msg: messages.Event, endpoint: Callable[[types.Event], Awaitable[None]]): + try: + await endpoint(types.Event(msg.args, msg.kwargs, msg.details)) + except Exception as e: + print(e) + async def _wait(self): while await self._base_session.transport.is_connected(): try: @@ -84,12 +127,11 @@ async def _wait(self): print(e) break - task = self._loop.create_task(self._process_incoming_message(self._session.receive(data))) - self._tasks.add(task) - task.add_done_callback(self._tasks.discard) + await self._process_incoming_message(self._session.receive(data)) - for callback in self._disconnect_callback: - await callback() + if self._disconnect_callback: + callbacks = [callback() for callback in self._disconnect_callback] + await asyncio.gather(*callbacks) async def _process_incoming_message(self, msg: messages.Message): if isinstance(msg, messages.Registered): @@ -104,36 +146,10 @@ async def _process_incoming_message(self, msg: messages.Message): request = self._call_requests.pop(msg.request_id) request.set_result(types.Result(msg.args, msg.kwargs, msg.details)) elif isinstance(msg, messages.Invocation): - try: - endpoint = self._registrations[msg.registration_id] - result = await endpoint(types.Invocation(msg.args, msg.kwargs, msg.details)) - - if result is None: - data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id))) - elif isinstance(result, types.Result): - data = self._session.send_message( - messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details)) - ) - else: - message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str( - type(result) - ) - msg_to_send = messages.Error( - messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message]) - ) - data = self._session.send_message(msg_to_send) - - await self._base_session.send(data) - except ApplicationError as e: - msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args)) - data = self._session.send_message(msg_to_send) - await self._base_session.send(data) - except Exception as e: - msg_to_send = messages.Error( - messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_RUNTIME_ERROR, [e.__str__()]) - ) - data = self._session.send_message(msg_to_send) - await self._base_session.send(data) + endpoint = self._registrations[msg.registration_id] + task = self._loop.create_task(self._handle_invocation(msg, endpoint)) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) elif isinstance(msg, messages.Subscribed): request = self._subscribe_requests.pop(msg.request_id) self._subscriptions[msg.subscription_id] = request.endpoint @@ -147,10 +163,9 @@ async def _process_incoming_message(self, msg: messages.Message): request.set_result(None) elif isinstance(msg, messages.Event): endpoint = self._subscriptions[msg.subscription_id] - try: - await endpoint(types.Event(msg.args, msg.kwargs, msg.details)) - except Exception as e: - print(e) + task = self._loop.create_task(self._handle_event(msg, endpoint)) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) elif isinstance(msg, messages.Error): match msg.message_type: case messages.Call.TYPE: diff --git a/xconn/session.py b/xconn/session.py index b13aed7..4a4ea64 100644 --- a/xconn/session.py +++ b/xconn/session.py @@ -1,7 +1,8 @@ from __future__ import annotations -from concurrent.futures import Future -from threading import Thread +from concurrent.futures import Future, ThreadPoolExecutor, wait +import threading +from os import cpu_count from typing import Callable, Any from dataclasses import dataclass @@ -67,8 +68,12 @@ def __init__(self, base_session: types.BaseSession): self._session = session.WAMPSession(base_session.serializer) self._disconnect_callback: list[Callable[[], None] | None] = [] + self._stopped = threading.Event() - thread = Thread(target=self._wait, daemon=False) + # callback executor thread-pool + self._executor = ThreadPoolExecutor(max_workers=(cpu_count() or 1) * 4) + + thread = threading.Thread(target=self._wait, daemon=True) thread.start() def _wait(self): @@ -80,8 +85,54 @@ def _wait(self): self._process_incoming_message(self._session.receive(data)) - for callback in self._disconnect_callback: - callback() + # Shut down executor, cancelling anything still running + self._executor.shutdown(cancel_futures=True, wait=False) + + if self._disconnect_callback: + with ThreadPoolExecutor(max_workers=len(self._disconnect_callback)) as executor: + # Trigger disconnect callbacks concurrently + futures = [executor.submit(cb) for cb in self._disconnect_callback] + # Wait up to 1 second for them to finish + wait(futures, timeout=1) + + self._stopped.set() + + def _handle_invocation(self, msg: messages.Invocation, endpoint: Callable[[types.Invocation], types.Result]): + try: + result = endpoint(types.Invocation(msg.args, msg.kwargs, msg.details)) + + if result is None: + data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id))) + elif isinstance(result, types.Result): + data = self._session.send_message( + messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details)) + ) + else: + message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str( + type(result) + ) + msg_to_send = messages.Error( + messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message]) + ) + data = self._session.send_message(msg_to_send) + + self._base_session.send(data) + except ApplicationError as e: + msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args)) + data = self._session.send_message(msg_to_send) + self._base_session.send(data) + except Exception as e: + msg_to_send = messages.Error( + messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_RUNTIME_ERROR, [e.__str__()]) + ) + data = self._session.send_message(msg_to_send) + self._base_session.send(data) + + def _handle_event(self, msg: messages.Event, endpoint: Callable[[types.Event], None]): + try: + endpoint(types.Event(msg.args, msg.kwargs, msg.details)) + except Exception as e: + print(e) def _process_incoming_message(self, msg: messages.Message): if isinstance(msg, messages.Registered): @@ -98,28 +149,7 @@ def _process_incoming_message(self, msg: messages.Message): elif isinstance(msg, messages.Invocation): try: endpoint = self._registrations[msg.registration_id] - result = endpoint(types.Invocation(msg.args, msg.kwargs, msg.details)) - - if result is None: - data = self._session.send_message(messages.Yield(messages.YieldFields(msg.request_id))) - elif isinstance(result, types.Result): - data = self._session.send_message( - messages.Yield(messages.YieldFields(msg.request_id, result.args, result.kwargs, result.details)) - ) - else: - message = "Endpoint returned invalid result type. Expected types.Result or None, got: " + str( - type(result) - ) - msg_to_send = messages.Error( - messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_INTERNAL_ERROR, [message]) - ) - data = self._session.send_message(msg_to_send) - - self._base_session.send(data) - except ApplicationError as e: - msg_to_send = messages.Error(messages.ErrorFields(msg.TYPE, msg.request_id, e.message, e.args)) - data = self._session.send_message(msg_to_send) - self._base_session.send(data) + self._executor.submit(self._handle_invocation, msg, endpoint) except Exception as e: msg_to_send = messages.Error( messages.ErrorFields(msg.TYPE, msg.request_id, xconn_uris.ERROR_RUNTIME_ERROR, [e.__str__()]) @@ -140,7 +170,7 @@ def _process_incoming_message(self, msg: messages.Message): elif isinstance(msg, messages.Event): try: endpoint = self._subscriptions[msg.subscription_id] - endpoint(types.Event(msg.args, msg.kwargs, msg.details)) + self._executor.submit(self._handle_event, msg, endpoint) except Exception as e: print(e) elif isinstance(msg, messages.Error): @@ -295,3 +325,12 @@ def ping(self, timeout: int = 10) -> float: def _on_disconnect(self, callback: Callable[[], None]) -> None: if callback is not None: self._disconnect_callback.append(callback) + + def run_forever(self): + """Block until the session is closed/disconnected.""" + print("[Session] Running forever — press Ctrl+C to exit.") + try: + self._stopped.wait() + except KeyboardInterrupt: + print("[Session] Interrupted — shutting down...") + self.leave() diff --git a/xconn/transports.py b/xconn/transports.py index c0812ed..d781863 100644 --- a/xconn/transports.py +++ b/xconn/transports.py @@ -29,6 +29,19 @@ # Applies to handshake and message itself. RAW_SOCKET_HEADER_LENGTH = 4 +_ASYNC_CONNECTION_ERRORS = ( + asyncio.IncompleteReadError, + BrokenPipeError, + ConnectionResetError, + OSError, +) + +_CONNECTION_ERRORS = ( + BrokenPipeError, + ConnectionResetError, + OSError, +) + @dataclass class PendingPing: @@ -63,6 +76,7 @@ class RawSocketTransport(ITransport): def __init__(self, sock: socket.socket): super().__init__() self._sock = sock + self._connected = True self._pending_pings: dict[bytes, PendingPing] = {} self._write_lock = threading.Lock() @@ -97,6 +111,10 @@ def connect( return RawSocketTransport(sock) + def _mark_disconnected(self, _: Exception | None): + if self._connected: + self._connected = False + def read(self) -> str | bytes: msg_header_bytes = _recv_exactly(self._sock, RAW_SOCKET_HEADER_LENGTH) msg_header = MessageHeader.from_bytes(msg_header_bytes) @@ -105,9 +123,14 @@ def read(self) -> str | bytes: return _recv_exactly(self._sock, msg_header.length) elif msg_header.kind == MSG_TYPE_PING: ping_payload = _recv_exactly(self._sock, msg_header.length) - pong = MessageHeader(MSG_TYPE_PONG, msg_header.length) - self._sock.sendall(pong.to_bytes()) - self._sock.sendall(ping_payload) + pong_header = MessageHeader(MSG_TYPE_PONG, msg_header.length) + + try: + with self._write_lock: + self._sock.sendall(pong_header.to_bytes() + ping_payload) + except _CONNECTION_ERRORS as e: + self._mark_disconnected(e) + raise return self.read() elif msg_header.kind == MSG_TYPE_PONG: @@ -122,30 +145,36 @@ def read(self) -> str | bytes: raise ValueError(f"Unsupported message type {msg_header.kind}") def write(self, data: str | bytes): - msg_header = MessageHeader(MSG_TYPE_WAMP, len(data)) payload = data.encode() if isinstance(data, str) else data + msg_header = MessageHeader(MSG_TYPE_WAMP, len(payload)) - with self._write_lock: # ensure exclusive access - self._sock.sendall(msg_header.to_bytes()) - self._sock.sendall(payload) + try: + with self._write_lock: + self._sock.sendall(msg_header.to_bytes() + payload) + except _CONNECTION_ERRORS as e: + self._mark_disconnected(e) + raise def close(self): - self._sock.close() + try: + self._sock.close() + finally: + self._mark_disconnected(None) def is_connected(self) -> bool: - try: - self._sock.send(b"") # Send zero bytes - return True - except (BrokenPipeError, ConnectionResetError, OSError): - return False + return self._connected def ping(self, timeout: int = 10) -> float: f: ConcurrentFuture[int] = ConcurrentFuture() payload, ping_header, created_at = create_ping() self._pending_pings[payload] = PendingPing(f, created_at) - self._sock.sendall(ping_header.to_bytes()) - self._sock.sendall(payload) + try: + with self._write_lock: + self._sock.sendall(ping_header.to_bytes() + payload) + except _CONNECTION_ERRORS as e: + self._mark_disconnected(e) + raise return f.result(timeout) @@ -156,6 +185,7 @@ def __init__(self, reader: StreamReader, writer: StreamWriter): self._reader = reader self._writer = writer + self._connected = True self._pending_pings: dict[bytes, PendingPing] = {} @staticmethod @@ -189,23 +219,49 @@ async def connect( return AsyncRawSocketTransport(reader, writer) + def _mark_disconnected(self, _: Exception | None): + if self._connected: + self._connected = False + async def read(self) -> str | bytes: - msg_header_bytes = await self._reader.readexactly(RAW_SOCKET_HEADER_LENGTH) + try: + msg_header_bytes = await self._reader.readexactly(RAW_SOCKET_HEADER_LENGTH) + except _ASYNC_CONNECTION_ERRORS as e: + self._mark_disconnected(e) + raise + msg_header = MessageHeader.from_bytes(msg_header_bytes) if msg_header.kind == MSG_TYPE_WAMP: - return await self._reader.readexactly(msg_header.length) + try: + return await self._reader.readexactly(msg_header.length) + except _ASYNC_CONNECTION_ERRORS as e: + self._mark_disconnected(e) + raise elif msg_header.kind == MSG_TYPE_PING: - ping_payload = await self._reader.readexactly(msg_header.length) - pong = MessageHeader(MSG_TYPE_PONG, msg_header.length) - self._writer.write(pong.to_bytes()) - await self._writer.drain() - self._writer.write(ping_payload) - await self._writer.drain() + try: + ping_payload = await self._reader.readexactly(msg_header.length) + except _ASYNC_CONNECTION_ERRORS as e: + self._mark_disconnected(e) + raise + + pong_header = MessageHeader(MSG_TYPE_PONG, msg_header.length) + + try: + self._writer.write(pong_header.to_bytes() + ping_payload) + await self._writer.drain() + except _CONNECTION_ERRORS as e: + self._mark_disconnected(e) + raise return await self.read() elif msg_header.kind == MSG_TYPE_PONG: - pong_payload = await self._reader.readexactly(msg_header.length) + try: + pong_payload = await self._reader.readexactly(msg_header.length) + except _ASYNC_CONNECTION_ERRORS as e: + self._mark_disconnected(e) + raise + pending_ping = self._pending_pings.pop(pong_payload, None) if pending_ping is not None: received_at = time.time() * 1000 @@ -216,39 +272,36 @@ async def read(self) -> str | bytes: raise ValueError(f"Unsupported message type {msg_header.kind}") async def write(self, data: str | bytes): - msg_header = MessageHeader(MSG_TYPE_WAMP, len(data)) - - self._writer.write(msg_header.to_bytes()) - await self._writer.drain() - - if isinstance(data, str): - self._writer.write(data.encode()) - else: - self._writer.write(data) + payload = data.encode() if isinstance(data, str) else data + msg_header = MessageHeader(MSG_TYPE_WAMP, len(payload)) - await self._writer.drain() + try: + self._writer.write(msg_header.to_bytes() + payload) + await self._writer.drain() + except _CONNECTION_ERRORS as e: + self._mark_disconnected(e) + raise async def close(self): - self._writer.close() + try: + self._writer.close() + finally: + self._mark_disconnected(None) async def is_connected(self) -> bool: - try: - self._writer.write(b"") # Send zero bytes - await self._writer.drain() - return True - except (BrokenPipeError, ConnectionResetError, OSError): - return False + return self._connected async def ping(self, timeout: int = 10) -> float: f: Future[int] = Future() payload, ping_header, created_at = create_ping() self._pending_pings[payload] = PendingPing(f, created_at) - self._writer.write(ping_header.to_bytes()) - await self._writer.drain() - - self._writer.write(payload) - await self._writer.drain() + try: + self._writer.write(ping_header.to_bytes() + payload) + await self._writer.drain() + except _CONNECTION_ERRORS as e: + self._mark_disconnected(e) + raise return await asyncio.wait_for(f, timeout)