Skip to content

Commit 15d333e

Browse files
Add v2 rpc tests
1 parent b323a06 commit 15d333e

File tree

5 files changed

+232
-0
lines changed

5 files changed

+232
-0
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from pydantic import BaseModel
3+
from typing import Literal
4+
5+
import replit_river as river
6+
7+
8+
from .test_service import Test_ServiceService
9+
10+
11+
class StreamClient:
12+
def __init__(self, client: river.v2.Client[Literal[None]]):
13+
self.test_service = Test_ServiceService(client)
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
from typing import Any
4+
import datetime
5+
6+
from pydantic import TypeAdapter
7+
8+
from replit_river.error_schema import RiverError, RiverErrorTypeAdapter
9+
import replit_river as river
10+
11+
12+
from .rpc_method import (
13+
Rpc_MethodInit,
14+
Rpc_MethodOutput,
15+
Rpc_MethodOutputTypeAdapter,
16+
encode_Rpc_MethodInit,
17+
)
18+
19+
20+
class Test_ServiceService:
21+
def __init__(self, client: river.v2.Client[Any]):
22+
self.client = client
23+
24+
async def rpc_method(
25+
self,
26+
init: Rpc_MethodInit,
27+
timeout: datetime.timedelta,
28+
) -> Rpc_MethodOutput:
29+
return await self.client.send_rpc(
30+
"test_service",
31+
"rpc_method",
32+
init,
33+
encode_Rpc_MethodInit,
34+
lambda x: Rpc_MethodOutputTypeAdapter.validate_python(
35+
x # type: ignore[arg-type]
36+
),
37+
lambda x: RiverErrorTypeAdapter.validate_python(
38+
x # type: ignore[arg-type]
39+
),
40+
timeout,
41+
)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
import datetime
4+
from typing import (
5+
Any,
6+
Literal,
7+
Mapping,
8+
NotRequired,
9+
TypedDict,
10+
)
11+
from typing_extensions import Annotated
12+
13+
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
14+
from replit_river.error_schema import RiverError
15+
from replit_river.client import (
16+
RiverUnknownError,
17+
translate_unknown_error,
18+
RiverUnknownValue,
19+
translate_unknown_value,
20+
)
21+
22+
import replit_river as river
23+
24+
25+
def encode_Rpc_MethodInit(
26+
x: "Rpc_MethodInit",
27+
) -> Any:
28+
return {
29+
k: v
30+
for (k, v) in (
31+
{
32+
"data": x.get("data"),
33+
}
34+
).items()
35+
if v is not None
36+
}
37+
38+
39+
class Rpc_MethodInit(TypedDict):
40+
data: str
41+
42+
43+
class Rpc_MethodOutput(BaseModel):
44+
data: str
45+
46+
47+
Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter(
48+
Rpc_MethodOutput
49+
)

tests/v2/test_rpc.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import importlib
2+
import logging
3+
from collections import deque
4+
from datetime import timedelta
5+
from typing import (
6+
Literal,
7+
)
8+
9+
import pytest
10+
from pytest_snapshot.plugin import Snapshot
11+
12+
from replit_river.v2.client import Client
13+
from tests.fixtures.codegen_snapshot_fixtures import validate_codegen
14+
from tests.v2.datagrams import (
15+
ClientId,
16+
FromClient,
17+
ServerId,
18+
StreamAlias,
19+
TestTransport,
20+
ToClient,
21+
ValueSet,
22+
WaitForClosed,
23+
)
24+
25+
logger = logging.getLogger(__name__)
26+
27+
_AlreadyGenerated = False
28+
29+
30+
@pytest.fixture(scope="function", autouse=True)
31+
def stream_client_codegen(snapshot: Snapshot) -> Literal[True]:
32+
global _AlreadyGenerated
33+
if not _AlreadyGenerated:
34+
validate_codegen(
35+
snapshot=snapshot,
36+
snapshot_dir="tests/v2/codegen/snapshot/snapshots",
37+
read_schema=lambda: open("tests/v2/test_rpc.schema.json"),
38+
target_path="test_basic_rpc",
39+
client_name="StreamClient",
40+
protocol_version="v2.0",
41+
)
42+
_AlreadyGenerated = True
43+
44+
import tests.v2.codegen.snapshot.snapshots.test_basic_stream
45+
46+
importlib.reload(tests.v2.codegen.snapshot.snapshots.test_basic_stream)
47+
return True
48+
49+
50+
rpc_expected: deque[TestTransport] = deque(
51+
[
52+
FromClient(
53+
handshake_request=ValueSet(
54+
seq=0, # These don't count due to being during a handshake
55+
ack=0,
56+
from_=ServerId("server-001"),
57+
to=ClientId("client-001"),
58+
)
59+
),
60+
FromClient(
61+
stream_open=ValueSet(
62+
seq=0,
63+
ack=0,
64+
from_=ServerId("server-001"),
65+
to=ClientId("client-001"),
66+
serviceName="test_service",
67+
procedureName="rpc_method",
68+
create_alias=StreamAlias(1),
69+
payload={"data": "foo"},
70+
stream_closed=True,
71+
)
72+
),
73+
ToClient(
74+
seq=0,
75+
ack=1,
76+
stream_frame=(
77+
StreamAlias(1),
78+
{"ok": True, "payload": {"data": "Hello, foo!"}},
79+
),
80+
),
81+
WaitForClosed(),
82+
]
83+
)
84+
85+
86+
@pytest.mark.parametrize("expected", [rpc_expected])
87+
async def test_rpc(bound_client: Client) -> None:
88+
from tests.v2.codegen.snapshot.snapshots.test_basic_rpc import (
89+
StreamClient,
90+
)
91+
92+
res = await StreamClient(bound_client).test_service.rpc_method(
93+
init={"data": "foo"},
94+
timeout=timedelta(seconds=5),
95+
)
96+
97+
assert res.data == "Hello, foo!"

tests/v2/test_rpc.schema.json

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
{
2+
"services": {
3+
"test_service": {
4+
"procedures": {
5+
"rpc_method": {
6+
"init": {
7+
"type": "object",
8+
"properties": {
9+
"data": {
10+
"type": "string"
11+
}
12+
},
13+
"required": ["data"]
14+
},
15+
"output": {
16+
"type": "object",
17+
"properties": {
18+
"data": {
19+
"type": "string"
20+
}
21+
},
22+
"required": ["data"]
23+
},
24+
"errors": {
25+
"not": {}
26+
},
27+
"type": "rpc"
28+
}
29+
}
30+
}
31+
}
32+
}

0 commit comments

Comments
 (0)