Skip to content

Commit 6c928f1

Browse files
Add v2 stream tests
1 parent 986c0ce commit 6c928f1

File tree

11 files changed

+892
-6
lines changed

11 files changed

+892
-6
lines changed

src/replit_river/v2/session.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@ def commit(msg: TransportMessage) -> None:
464464
# Wake up backpressured writer
465465
backpressure_waiter, _ = self._streams[pending.streamId]
466466
backpressure_waiter.set()
467+
467468
def get_next_pending() -> TransportMessage | None:
468469
if self._send_buffer:
469470
return self._send_buffer[0]
@@ -923,8 +924,7 @@ async def _encode_stream() -> None:
923924
yield error_deserializer(result["payload"])
924925
except Exception:
925926
logger.exception(
926-
"Error during stream "
927-
f"error deserialization: {result}"
927+
f"Error during stream error deserialization: {result}"
928928
)
929929
continue
930930
yield response_deserializer(result["payload"])
@@ -1228,10 +1228,14 @@ async def _serve(
12281228
logger.debug(f"_serve loop count={idx}")
12291229
idx += 1
12301230
ws = None
1231-
while (
1232-
state := get_state()
1233-
) in ConnectingStates or (ws := get_ws()) is None:
1234-
logger.debug("_handle_messages_from_ws spinning while connecting, %r %r", ws, state)
1231+
while (state := get_state()) in ConnectingStates or (
1232+
ws := get_ws()
1233+
) is None:
1234+
logger.debug(
1235+
"_handle_messages_from_ws spinning while connecting, %r %r",
1236+
ws,
1237+
state,
1238+
)
12351239
await block_until_connected()
12361240
if state in TerminalStates:
12371241
break

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
pytest_plugins = [
1818
"tests.v1.river_fixtures.logging",
1919
"tests.v1.river_fixtures.clientserver",
20+
"tests.v2.fixtures",
2021
]
2122

2223
HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from pydantic import BaseModel
3+
from typing import Literal
4+
5+
import replit_river as river
6+
7+
8+
from .test_service import Test_ServiceService
9+
10+
11+
class StreamClient:
12+
def __init__(self, client: river.v2.Client[Literal[None]]):
13+
self.test_service = Test_ServiceService(client)
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
from typing import Any
4+
import datetime
5+
6+
from pydantic import TypeAdapter
7+
8+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
9+
import replit_river as river
10+
11+
12+
from .stream_method import (
13+
Stream_MethodInit,
14+
Stream_MethodInput,
15+
Stream_MethodOutput,
16+
Stream_MethodOutputTypeAdapter,
17+
encode_Stream_MethodInit,
18+
encode_Stream_MethodInput,
19+
)
20+
21+
22+
class Test_ServiceService:
23+
def __init__(self, client: river.v2.Client[Any]):
24+
self.client = client
25+
26+
async def stream_method(
27+
self,
28+
init: Stream_MethodInit,
29+
inputStream: AsyncIterable[Stream_MethodInput],
30+
) -> AsyncIterator[Stream_MethodOutput | RiverError | RiverError]:
31+
return self.client.send_stream(
32+
"test_service",
33+
"stream_method",
34+
init,
35+
inputStream,
36+
encode_Stream_MethodInit,
37+
encode_Stream_MethodInput,
38+
lambda x: Stream_MethodOutputTypeAdapter.validate_python(
39+
x # type: ignore[arg-type]
40+
),
41+
lambda x: RiverErrorTypeAdapter.validate_python(
42+
x # type: ignore[arg-type]
43+
),
44+
)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
import datetime
4+
from typing import (
5+
Any,
6+
Literal,
7+
Mapping,
8+
NotRequired,
9+
TypedDict,
10+
)
11+
from typing_extensions import Annotated
12+
13+
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
14+
from replit_river.error_schema import RiverError
15+
from replit_river.client import (
16+
RiverUnknownError,
17+
translate_unknown_error,
18+
RiverUnknownValue,
19+
translate_unknown_value,
20+
)
21+
22+
import replit_river as river
23+
24+
25+
class Emit_ErrorErrorsOneOf_DATA_LOSS(RiverError):
26+
code: Literal["DATA_LOSS"]
27+
message: str
28+
29+
30+
class Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT(RiverError):
31+
code: Literal["UNEXPECTED_DISCONNECT"]
32+
message: str
33+
34+
35+
Emit_ErrorErrors = Annotated[
36+
Emit_ErrorErrorsOneOf_DATA_LOSS
37+
| Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT
38+
| RiverUnknownError,
39+
WrapValidator(translate_unknown_error),
40+
]
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
import datetime
4+
from typing import (
5+
Any,
6+
Literal,
7+
Mapping,
8+
NotRequired,
9+
TypedDict,
10+
)
11+
from typing_extensions import Annotated
12+
13+
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
14+
from replit_river.error_schema import RiverError
15+
from replit_river.client import (
16+
RiverUnknownError,
17+
translate_unknown_error,
18+
RiverUnknownValue,
19+
translate_unknown_value,
20+
)
21+
22+
import replit_river as river
23+
24+
25+
def encode_Stream_MethodInit(
26+
_: "Stream_MethodInit",
27+
) -> Any:
28+
return {}
29+
30+
31+
class Stream_MethodInit(TypedDict):
32+
pass
33+
34+
35+
def encode_Stream_MethodInput(
36+
x: "Stream_MethodInput",
37+
) -> Any:
38+
return {
39+
k: v
40+
for (k, v) in (
41+
{
42+
"data": x.get("data"),
43+
}
44+
).items()
45+
if v is not None
46+
}
47+
48+
49+
class Stream_MethodInput(TypedDict):
50+
data: str
51+
52+
53+
class Stream_MethodOutput(BaseModel):
54+
data: str
55+
56+
57+
Stream_MethodOutputTypeAdapter: TypeAdapter[Stream_MethodOutput] = TypeAdapter(
58+
Stream_MethodOutput
59+
)

tests/v2/datagrams.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
from dataclasses import dataclass
2+
from typing import (
3+
Any,
4+
NewType,
5+
TypeAlias,
6+
)
7+
8+
Datagram = dict[str, Any]
9+
TestTransport: TypeAlias = "FromClient | ToClient | WaitForClosed"
10+
11+
StreamId = NewType("StreamId", str)
12+
ClientId = NewType("ClientId", str)
13+
ServerId = NewType("ServerId", str)
14+
SessionId = NewType("SessionId", str)
15+
16+
17+
@dataclass(frozen=True)
18+
class StreamAlias:
19+
alias_id: int
20+
21+
22+
@dataclass(frozen=True)
23+
class ValueSet:
24+
seq: int
25+
ack: int
26+
from_: ServerId | None = None
27+
to: ClientId | None = None
28+
procedureName: str | None = None
29+
serviceName: str | None = None
30+
create_alias: StreamAlias | None = None
31+
stream_alias: StreamAlias | None = None
32+
payload: Datagram | None = None
33+
34+
35+
@dataclass(frozen=True)
36+
class FromClient:
37+
handshake_request: tuple[ClientId, ServerId, SessionId] | ValueSet | None = None
38+
stream_open: tuple[ClientId, ServerId, str, str, StreamId] | ValueSet | None = None
39+
stream_frame: tuple[ClientId, ServerId, int, int, Datagram] | ValueSet | None = None
40+
41+
42+
@dataclass(frozen=True)
43+
class ToClient:
44+
seq: int
45+
ack: int
46+
control_flags: int = 0
47+
handshake_response: bool | None = None
48+
stream_frame: tuple[StreamAlias, Datagram] | None = None
49+
stream_close: StreamAlias | None = None
50+
51+
52+
@dataclass(frozen=True)
53+
class WaitForClosed:
54+
pass
55+
56+
57+
def decode_FromClient(datagram: dict[str, Any]) -> FromClient:
58+
assert "from" in datagram
59+
assert "to" in datagram
60+
if datagram.get("payload", {}).get("type") == "HANDSHAKE_REQ":
61+
assert "payload" in datagram
62+
assert "sessionId" in datagram["payload"]
63+
return FromClient(
64+
handshake_request=(
65+
ClientId(datagram["from"]),
66+
ServerId(datagram["to"]),
67+
SessionId(datagram["payload"]["sessionId"]),
68+
)
69+
)
70+
elif datagram.get("controlFlags", 0) & 0b00010 > 0: # STREAM_OPEN_BIT
71+
return FromClient(
72+
stream_open=(
73+
ClientId(datagram["from"]),
74+
ServerId(datagram["to"]),
75+
datagram["serviceName"],
76+
datagram["procedureName"],
77+
StreamId(datagram["streamId"]),
78+
)
79+
)
80+
elif datagram:
81+
return FromClient(
82+
stream_frame=(
83+
ClientId(datagram["from"]),
84+
ServerId(datagram["to"]),
85+
datagram["seq"],
86+
datagram["ack"],
87+
datagram["payload"],
88+
)
89+
)
90+
raise ValueError("Unexpected datagram: %r", datagram)
91+
92+
93+
def parser(datagram: dict[str, Any]) -> FromClient:
94+
return decode_FromClient(datagram)

0 commit comments

Comments
 (0)