Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,11 @@ def _trace_procedure(
except RiverException as e:
span.record_exception(e, escaped=True)
_record_river_error(span_handle, RiverError(code=e.code, message=e.message))
raise e
raise
except BaseException as e:
span.record_exception(e, escaped=True)
span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}")
raise e
raise
finally:
span.end()

Expand Down
4 changes: 2 additions & 2 deletions src/replit_river/v2/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ def _trace_procedure(
except RiverException as e:
span.record_exception(e, escaped=True)
_record_river_error(span_handle, RiverError(code=e.code, message=e.message))
raise e
raise
except BaseException as e:
span.record_exception(e, escaped=True)
span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}")
raise e
raise
finally:
span.end()

Expand Down
28 changes: 28 additions & 0 deletions src/replit_river/v2/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,13 @@ async def send_rpc[R, A](
# Block for backpressure and emission errors from the ws
await backpressured_waiter()
result = await anext(output)
except asyncio.CancelledError:
await self._send_cancel_stream(
stream_id=stream_id,
message="RPC cancelled",
span=span,
)
raise
except asyncio.TimeoutError as e:
await self._send_cancel_stream(
stream_id=stream_id,
Expand Down Expand Up @@ -835,6 +842,13 @@ async def send_upload[I, R, A](
payload=payload,
span=span,
)
except asyncio.CancelledError:
await self._send_cancel_stream(
stream_id=stream_id,
message="Upload cancelled",
span=span,
)
raise
except Exception as e:
# If we get any exception other than WebsocketClosedException,
# cancel the stream.
Expand Down Expand Up @@ -916,6 +930,13 @@ async def send_subscription[I, E, A](
continue
yield response_deserializer(item["payload"])
await self._send_close_stream(stream_id, span)
except asyncio.CancelledError:
await self._send_cancel_stream(
stream_id=stream_id,
message="Subscription cancelled",
span=span,
)
raise
except Exception as e:
await self._send_cancel_stream(
stream_id=stream_id,
Expand Down Expand Up @@ -1002,6 +1023,13 @@ async def _encode_stream() -> None:
# ... block the outer function until the emitter is finished emitting,
# possibly raising a terminal exception.
await emitter_task
except asyncio.CancelledError:
await self._send_cancel_stream(
stream_id=stream_id,
message="Stream cancelled",
span=span,
)
raise
except Exception as e:
await self._send_cancel_stream(
stream_id=stream_id,
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
pytest_plugins = [
"tests.v1.river_fixtures.logging",
"tests.v1.river_fixtures.clientserver",
"tests.v2.fixtures",
"tests.v2.fixtures.bound_client",
"tests.v2.fixtures.raw_ws_server",
]

HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"]
Expand Down
File renamed without changes.
92 changes: 92 additions & 0 deletions tests/v2/fixtures/raw_ws_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import asyncio
from typing import (
AsyncIterator,
Awaitable,
Callable,
Literal,
TypeAlias,
TypedDict,
)

import pytest
from websockets import ConnectionClosed, ConnectionClosedOK, Data
from websockets.asyncio.server import ServerConnection, serve

from replit_river.transport_options import UriAndMetadata

WsServerFixture: TypeAlias = tuple[
Callable[[], Awaitable[UriAndMetadata[None]]],
asyncio.Queue[bytes],
Callable[[], ServerConnection | None],
]


class OuterPayload[A](TypedDict):
ok: Literal[True]
payload: A


class _WsServerState(TypedDict):
ipv4_laddr: tuple[str, int] | None


async def _ws_server_internal(
recv: asyncio.Queue[bytes],
set_conn: Callable[[ServerConnection], None],
state: _WsServerState,
) -> AsyncIterator[None]:
async def handle(websocket: ServerConnection) -> None:
set_conn(websocket)
datagram: Data
try:
while datagram := await websocket.recv(decode=False):
if isinstance(datagram, str):
continue
await recv.put(datagram)
except ConnectionClosedOK:
pass
except ConnectionClosed:
pass

port: int | None = None
if state["ipv4_laddr"]:
port = state["ipv4_laddr"][1]
async with serve(handle, "localhost", port=port) as server:
for sock in server.sockets:
if (pair := sock.getsockname())[0] == "127.0.0.1":
if state["ipv4_laddr"] is None:
state["ipv4_laddr"] = pair
serve_forever = asyncio.create_task(server.serve_forever())
yield None
server.close()
await server.wait_closed()
# "serve_forever" should always be done after wait_closed finishes
assert serve_forever.done()


@pytest.fixture
async def ws_server() -> AsyncIterator[WsServerFixture]:
recv: asyncio.Queue[bytes] = asyncio.Queue(maxsize=1)
connection: ServerConnection | None = None
state: _WsServerState = {"ipv4_laddr": None}

def set_conn(new_conn: ServerConnection) -> None:
nonlocal connection
connection = new_conn

server_generator = _ws_server_internal(recv, set_conn, state)
await anext(server_generator)

async def urimeta() -> UriAndMetadata[None]:
ipv4_laddr = state["ipv4_laddr"]
assert ipv4_laddr
return UriAndMetadata(uri="ws://%s:%d" % ipv4_laddr, metadata=None)

yield (urimeta, recv, lambda: connection)

connection = None

try:
await anext(server_generator)
except StopAsyncIteration:
pass
Loading
Loading