Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,11 @@ def _trace_procedure(
except RiverException as e:
span.record_exception(e, escaped=True)
_record_river_error(span_handle, RiverError(code=e.code, message=e.message))
raise e
raise
except BaseException as e:
span.record_exception(e, escaped=True)
span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}")
raise e
raise
finally:
span.end()

Expand Down
2 changes: 2 additions & 0 deletions src/replit_river/error_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,10 @@ class SessionClosedRiverServiceException(RiverException):
def __init__(
self,
message: str,
streamId: str,
) -> None:
super().__init__(SYNTHETIC_ERROR_CODE_SESSION_CLOSED, message)
self.streamId = streamId


def exception_from_message(code: str) -> type[RiverServiceException]:
Expand Down
21 changes: 20 additions & 1 deletion src/replit_river/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import logging
from typing import Coroutine, Set

from replit_river.error_schema import ERROR_CODE_STREAM_CLOSED, RiverException
from replit_river.error_schema import (
ERROR_CODE_STREAM_CLOSED,
RiverException,
SessionClosedRiverServiceException,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -37,6 +41,13 @@ async def cancel_task(
# If we cancel the task manager we will get called here as well,
# if we want to handle the cancellation differently we can do it here.
logger.debug("Task was cancelled %r", task_to_remove)
except SessionClosedRiverServiceException as e:
logger.warning(
"Session was closed",
extra={
"stream_id": e.streamId,
},
)
except RiverException as e:
if e.code == ERROR_CODE_STREAM_CLOSED:
# Task is cancelled
Expand Down Expand Up @@ -76,6 +87,14 @@ def _task_done_callback(
):
# Task is cancelled
pass
elif isinstance(exception, SessionClosedRiverServiceException):
# Session is closed, don't bother logging
logger.info(
"Session closed",
extra={
"stream_id": exception.streamId,
},
)
else:
logger.error(
"Exception on cancelling task",
Expand Down
4 changes: 2 additions & 2 deletions src/replit_river/v2/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ def _trace_procedure(
except RiverException as e:
span.record_exception(e, escaped=True)
_record_river_error(span_handle, RiverError(code=e.code, message=e.message))
raise e
raise
except BaseException as e:
span.record_exception(e, escaped=True)
span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}")
raise e
raise
finally:
span.end()

Expand Down
47 changes: 44 additions & 3 deletions src/replit_river/v2/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ async def _enqueue_message(
# session is closing / closed, raise
raise SessionClosedRiverServiceException(
"river session is closed, dropping message",
stream_id,
)

# Begin critical section: Avoid any await between here and _send_buffer.append
Expand Down Expand Up @@ -448,14 +449,15 @@ async def do_close() -> None:

await self._task_manager.cancel_all_tasks()

for stream_meta in self._streams.values():
for stream_id, stream_meta in self._streams.items():
stream_meta["output"].close()
# Wake up backpressured writers
try:
stream_meta["error_channel"].put_nowait(
reason
or SessionClosedRiverServiceException(
"river session is closed",
stream_id,
)
)
except ChannelFull:
Expand Down Expand Up @@ -751,6 +753,13 @@ async def send_rpc[R, A](
# Block for backpressure and emission errors from the ws
await backpressured_waiter()
result = await anext(output)
except asyncio.CancelledError:
await self._send_cancel_stream(
stream_id=stream_id,
message="RPC cancelled",
span=span,
)
raise
except asyncio.TimeoutError as e:
await self._send_cancel_stream(
stream_id=stream_id,
Expand Down Expand Up @@ -835,6 +844,13 @@ async def send_upload[I, R, A](
payload=payload,
span=span,
)
except asyncio.CancelledError:
await self._send_cancel_stream(
stream_id=stream_id,
message="Upload cancelled",
span=span,
)
raise
except Exception as e:
# If we get any exception other than WebsocketClosedException,
# cancel the stream.
Expand Down Expand Up @@ -916,6 +932,13 @@ async def send_subscription[I, E, A](
continue
yield response_deserializer(item["payload"])
await self._send_close_stream(stream_id, span)
except asyncio.CancelledError:
await self._send_cancel_stream(
stream_id=stream_id,
message="Subscription cancelled",
span=span,
)
raise
except Exception as e:
await self._send_cancel_stream(
stream_id=stream_id,
Expand Down Expand Up @@ -1002,6 +1025,15 @@ async def _encode_stream() -> None:
# ... block the outer function until the emitter is finished emitting,
# possibly raising a terminal exception.
await emitter_task
except asyncio.CancelledError as e:
await self._send_cancel_stream(
stream_id=stream_id,
message="Stream cancelled",
span=span,
)
if emitter_task.done() and (err := emitter_task.exception()):
raise e from err
raise
except Exception as e:
await self._send_cancel_stream(
stream_id=stream_id,
Expand Down Expand Up @@ -1288,6 +1320,7 @@ async def _recv_from_ws(
# the outer loop.
await transition_no_connection()
break
msg: TransportMessage | str | None = None
try:
msg = parse_transport_msg(message)
logger.debug(
Expand Down Expand Up @@ -1367,19 +1400,27 @@ async def _recv_from_ws(
stream_meta["output"].close()
except OutOfOrderMessageException:
logger.exception("Out of order message, closing connection")
stream_id = "unknown"
if isinstance(msg, TransportMessage):
stream_id = msg.streamId
close_session(
SessionClosedRiverServiceException(
"Out of order message, closing connection"
"Out of order message, closing connection",
stream_id,
)
)
continue
except InvalidMessageException:
logger.exception(
"Got invalid transport message, closing session",
)
stream_id = "unknown"
if isinstance(msg, TransportMessage):
stream_id = msg.streamId
close_session(
SessionClosedRiverServiceException(
"Out of order message, closing connection"
"Out of order message, closing connection",
stream_id,
)
)
continue
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
pytest_plugins = [
"tests.v1.river_fixtures.logging",
"tests.v1.river_fixtures.clientserver",
"tests.v2.fixtures",
"tests.v2.fixtures.bound_client",
"tests.v2.fixtures.raw_ws_server",
]

HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"]
Expand Down
File renamed without changes.
92 changes: 92 additions & 0 deletions tests/v2/fixtures/raw_ws_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import asyncio
from typing import (
AsyncIterator,
Awaitable,
Callable,
Literal,
TypeAlias,
TypedDict,
)

import pytest
from websockets import ConnectionClosed, ConnectionClosedOK, Data
from websockets.asyncio.server import ServerConnection, serve

from replit_river.transport_options import UriAndMetadata

WsServerFixture: TypeAlias = tuple[
Callable[[], Awaitable[UriAndMetadata[None]]],
asyncio.Queue[bytes],
Callable[[], ServerConnection | None],
]


class OuterPayload[A](TypedDict):
ok: Literal[True]
payload: A


class _WsServerState(TypedDict):
ipv4_laddr: tuple[str, int] | None


async def _ws_server_internal(
recv: asyncio.Queue[bytes],
set_conn: Callable[[ServerConnection], None],
state: _WsServerState,
) -> AsyncIterator[None]:
async def handle(websocket: ServerConnection) -> None:
set_conn(websocket)
datagram: Data
try:
while datagram := await websocket.recv(decode=False):
if isinstance(datagram, str):
continue
await recv.put(datagram)
except ConnectionClosedOK:
pass
except ConnectionClosed:
pass

port: int | None = None
if state["ipv4_laddr"]:
port = state["ipv4_laddr"][1]
async with serve(handle, "localhost", port=port) as server:
for sock in server.sockets:
if (pair := sock.getsockname())[0] == "127.0.0.1":
if state["ipv4_laddr"] is None:
state["ipv4_laddr"] = pair
serve_forever = asyncio.create_task(server.serve_forever())
yield None
server.close()
await server.wait_closed()
# "serve_forever" should always be done after wait_closed finishes
assert serve_forever.done()


@pytest.fixture
async def ws_server() -> AsyncIterator[WsServerFixture]:
recv: asyncio.Queue[bytes] = asyncio.Queue(maxsize=1)
connection: ServerConnection | None = None
state: _WsServerState = {"ipv4_laddr": None}

def set_conn(new_conn: ServerConnection) -> None:
nonlocal connection
connection = new_conn

server_generator = _ws_server_internal(recv, set_conn, state)
await anext(server_generator)

async def urimeta() -> UriAndMetadata[None]:
ipv4_laddr = state["ipv4_laddr"]
assert ipv4_laddr
return UriAndMetadata(uri="ws://%s:%d" % ipv4_laddr, metadata=None)

yield (urimeta, recv, lambda: connection)

connection = None

try:
await anext(server_generator)
except StopAsyncIteration:
pass
Loading
Loading