Skip to content

Commit 9b44275

Browse files
v2 client
1 parent 8150fa9 commit 9b44275

File tree

4 files changed

+1124
-0
lines changed

4 files changed

+1124
-0
lines changed

src/replit_river/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
upload_method_handler,
1010
)
1111
from .server import Server
12+
import v2
1213

1314
__all__ = [
1415
"Client",
@@ -20,4 +21,5 @@
2021
"subscription_method_handler",
2122
"upload_method_handler",
2223
"stream_method_handler",
24+
"v2",
2325
]

src/replit_river/v2/client.py

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
import logging
2+
from collections.abc import AsyncIterable, Awaitable, Callable
3+
from contextlib import contextmanager
4+
from dataclasses import dataclass
5+
from datetime import timedelta
6+
from typing import Any, AsyncGenerator, Generator, Generic, Literal
7+
8+
from opentelemetry import trace
9+
from opentelemetry.trace import Span, SpanKind, Status, StatusCode
10+
from pydantic import (
11+
BaseModel,
12+
ValidationInfo,
13+
)
14+
15+
from replit_river.client_transport import ClientTransport
16+
from replit_river.error_schema import ERROR_CODE_UNKNOWN, RiverError, RiverException
17+
from replit_river.rpc import (
18+
ErrorType,
19+
InitType,
20+
RequestType,
21+
ResponseType,
22+
)
23+
from replit_river.transport_options import (
24+
HandshakeMetadataType,
25+
TransportOptions,
26+
UriAndMetadata,
27+
)
28+
29+
logger = logging.getLogger(__name__)
30+
tracer = trace.get_tracer(__name__)
31+
32+
33+
@dataclass(frozen=True)
34+
class RiverUnknownValue(BaseModel):
35+
tag: Literal["RiverUnknownValue"]
36+
value: Any
37+
38+
39+
class RiverUnknownError(RiverError):
40+
pass
41+
42+
43+
def translate_unknown_value(
44+
value: Any, handler: Callable[[Any], Any], info: ValidationInfo
45+
) -> Any | RiverUnknownValue:
46+
try:
47+
return handler(value)
48+
except Exception:
49+
return RiverUnknownValue(tag="RiverUnknownValue", value=value)
50+
51+
52+
def translate_unknown_error(
53+
value: Any, handler: Callable[[Any], Any], info: ValidationInfo
54+
) -> Any | RiverUnknownError:
55+
try:
56+
return handler(value)
57+
except Exception:
58+
if isinstance(value, dict) and "code" in value and "message" in value:
59+
return RiverUnknownError(
60+
code=value["code"],
61+
message=value["message"],
62+
)
63+
else:
64+
return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error")
65+
66+
67+
class Client(Generic[HandshakeMetadataType]):
68+
def __init__(
69+
self,
70+
uri_and_metadata_factory: Callable[
71+
[], Awaitable[UriAndMetadata[HandshakeMetadataType]]
72+
],
73+
client_id: str,
74+
server_id: str,
75+
transport_options: TransportOptions,
76+
) -> None:
77+
self._client_id = client_id
78+
self._server_id = server_id
79+
self._transport = ClientTransport[HandshakeMetadataType](
80+
uri_and_metadata_factory=uri_and_metadata_factory,
81+
client_id=client_id,
82+
server_id=server_id,
83+
transport_options=transport_options,
84+
)
85+
86+
async def close(self) -> None:
87+
logger.info(f"river client {self._client_id} start closing")
88+
await self._transport.close()
89+
logger.info(f"river client {self._client_id} closed")
90+
91+
async def ensure_connected(self) -> None:
92+
await self._transport.get_or_create_session()
93+
94+
async def send_rpc(
95+
self,
96+
service_name: str,
97+
procedure_name: str,
98+
request: RequestType,
99+
request_serializer: Callable[[RequestType], Any],
100+
response_deserializer: Callable[[Any], ResponseType],
101+
error_deserializer: Callable[[Any], ErrorType],
102+
timeout: timedelta,
103+
) -> ResponseType:
104+
with _trace_procedure("rpc", service_name, procedure_name) as span_handle:
105+
session = await self._transport.get_or_create_session()
106+
return await session.send_rpc(
107+
service_name,
108+
procedure_name,
109+
request,
110+
request_serializer,
111+
response_deserializer,
112+
error_deserializer,
113+
span_handle.span,
114+
timeout,
115+
)
116+
117+
async def send_upload(
118+
self,
119+
service_name: str,
120+
procedure_name: str,
121+
init: InitType | None,
122+
request: AsyncIterable[RequestType],
123+
init_serializer: Callable[[InitType], Any] | None,
124+
request_serializer: Callable[[RequestType], Any],
125+
response_deserializer: Callable[[Any], ResponseType],
126+
error_deserializer: Callable[[Any], ErrorType],
127+
) -> ResponseType:
128+
with _trace_procedure("upload", service_name, procedure_name) as span_handle:
129+
session = await self._transport.get_or_create_session()
130+
return await session.send_upload(
131+
service_name,
132+
procedure_name,
133+
init,
134+
request,
135+
init_serializer,
136+
request_serializer,
137+
response_deserializer,
138+
error_deserializer,
139+
span_handle.span,
140+
)
141+
142+
async def send_subscription(
143+
self,
144+
service_name: str,
145+
procedure_name: str,
146+
request: RequestType,
147+
request_serializer: Callable[[RequestType], Any],
148+
response_deserializer: Callable[[Any], ResponseType],
149+
error_deserializer: Callable[[Any], ErrorType],
150+
) -> AsyncGenerator[ResponseType | RiverError, None]:
151+
with _trace_procedure(
152+
"subscription", service_name, procedure_name
153+
) as span_handle:
154+
session = await self._transport.get_or_create_session()
155+
async for msg in session.send_subscription(
156+
service_name,
157+
procedure_name,
158+
request,
159+
request_serializer,
160+
response_deserializer,
161+
error_deserializer,
162+
span_handle.span,
163+
):
164+
if isinstance(msg, RiverError):
165+
_record_river_error(span_handle, msg)
166+
yield msg # type: ignore # https://github.com/python/mypy/issues/10817
167+
168+
async def send_stream(
169+
self,
170+
service_name: str,
171+
procedure_name: str,
172+
init: InitType | None,
173+
request: AsyncIterable[RequestType],
174+
init_serializer: Callable[[InitType], Any] | None,
175+
request_serializer: Callable[[RequestType], Any],
176+
response_deserializer: Callable[[Any], ResponseType],
177+
error_deserializer: Callable[[Any], ErrorType],
178+
) -> AsyncGenerator[ResponseType | RiverError, None]:
179+
with _trace_procedure("stream", service_name, procedure_name) as span_handle:
180+
session = await self._transport.get_or_create_session()
181+
async for msg in session.send_stream(
182+
service_name,
183+
procedure_name,
184+
init,
185+
request,
186+
init_serializer,
187+
request_serializer,
188+
response_deserializer,
189+
error_deserializer,
190+
span_handle.span,
191+
):
192+
if isinstance(msg, RiverError):
193+
_record_river_error(span_handle, msg)
194+
yield msg # type: ignore # https://github.com/python/mypy/issues/10817
195+
196+
197+
@dataclass
198+
class _SpanHandle:
199+
"""Wraps a span and keeps track of whether or not a status has been recorded yet."""
200+
201+
span: Span
202+
did_set_status: bool = False
203+
204+
def set_status(
205+
self,
206+
status: Status | StatusCode,
207+
description: str | None = None,
208+
) -> None:
209+
if self.did_set_status:
210+
return
211+
self.did_set_status = True
212+
self.span.set_status(status, description)
213+
214+
215+
@contextmanager
216+
def _trace_procedure(
217+
procedure_type: Literal["rpc", "upload", "subscription", "stream"],
218+
service_name: str,
219+
procedure_name: str,
220+
) -> Generator[_SpanHandle, None, None]:
221+
span = tracer.start_span(
222+
f"river.client.{procedure_type}.{service_name}.{procedure_name}",
223+
kind=SpanKind.CLIENT,
224+
)
225+
span_handle = _SpanHandle(span)
226+
try:
227+
yield span_handle
228+
except GeneratorExit:
229+
# This error indicates the caller is done with the async generator
230+
# but messages are still left. This is okay, we do not consider it an error.
231+
raise
232+
except RiverException as e:
233+
span.record_exception(e, escaped=True)
234+
_record_river_error(span_handle, RiverError(code=e.code, message=e.message))
235+
raise e
236+
except BaseException as e:
237+
span.record_exception(e, escaped=True)
238+
span_handle.set_status(StatusCode.ERROR, f"{type(e).__name__}: {e}")
239+
raise e
240+
finally:
241+
span_handle.set_status(StatusCode.OK)
242+
span.end()
243+
244+
245+
def _record_river_error(span_handle: _SpanHandle, error: RiverError) -> None:
246+
span_handle.set_status(StatusCode.ERROR, error.message)
247+
span_handle.span.record_exception(RiverException(error.code, error.message))
248+
span_handle.span.set_attribute("river.error_code", error.code)
249+
span_handle.span.set_attribute("river.error_message", error.message)

0 commit comments

Comments
 (0)