Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@ classifiers = [
]
dependencies = [
"pydantic==2.9.2",
"aiochannel>=1.2.1",
"black>=23.11,<25.0",
"grpcio-tools>=1.59.3",
"grpcio>=1.59.3",
"msgpack-types>=0.3.0",
"msgpack>=1.0.7",
"nanoid>=2.0.0",
"protobuf>=4.24.4",
"pydantic-core>=2.20.1",
"websockets>=12.0",
"aiochannel>=1.2.1",
"black>=23.11,<25.0",
"grpcio-tools>=1.59.3",
"grpcio>=1.59.3",
"msgpack-types>=0.3.0",
"msgpack>=1.0.7",
"nanoid>=2.0.0",
"protobuf>=4.24.4",
"pydantic-core>=2.20.1",
"websockets>=12.0",
"opentelemetry-sdk>=1.27.0,<1.28.0",
"opentelemetry-api>=1.27.0,<1.28.0",
]

[tool.uv]
Expand Down
154 changes: 122 additions & 32 deletions replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
from typing import Any, Generic, Optional, Union

from replit_river.client_interceptor import (
ClientInterceptor,
ClientRpcDetails,
ClientStreamDetails,
ClientSubscriptionDetails,
ClientUploadDetails,
)
from replit_river.client_transport import ClientTransport
from replit_river.transport_options import (
HandshakeMetadataType,
Expand All @@ -28,6 +35,7 @@ def __init__(
client_id: str,
server_id: str,
transport_options: TransportOptions,
interceptors: list[ClientInterceptor] = [],
) -> None:
self._client_id = client_id
self._server_id = server_id
Expand All @@ -37,6 +45,7 @@ def __init__(
server_id=server_id,
transport_options=transport_options,
)
self._interceptors = interceptors

async def close(self) -> None:
logger.info(f"river client {self._client_id} start closing")
Expand All @@ -56,13 +65,34 @@ async def send_rpc(
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
session = await self._transport.get_or_create_session()
return await session.send_rpc(
service_name,
procedure_name,
request,
request_serializer,
response_deserializer,
error_deserializer,

async def _run_interceptor(
details: ClientRpcDetails,
interceptors: list[ClientInterceptor],
) -> ResponseType:
if interceptors:
head, tail = interceptors[0], interceptors[1:]
return await head.intercept_rpc( # type: ignore
details,
lambda details: _run_interceptor(details, tail),
)
else:
return await session.send_rpc(
details.service_name,
details.procedure_name,
details.request,
request_serializer,
response_deserializer,
error_deserializer,
)

return await _run_interceptor(
ClientRpcDetails(
service_name=service_name,
procedure_name=procedure_name,
request=request,
),
self._interceptors,
)

async def send_upload(
Expand All @@ -77,15 +107,35 @@ async def send_upload(
error_deserializer: Callable[[Any], ErrorType],
) -> ResponseType:
session = await self._transport.get_or_create_session()
return await session.send_upload(
service_name,
procedure_name,
init,
request,
init_serializer,
request_serializer,
response_deserializer,
error_deserializer,

async def _run_interceptor(
details: ClientUploadDetails,
interceptors: list[ClientInterceptor],
) -> ResponseType:
if interceptors:
head, tail = interceptors[0], interceptors[1:]
return await head.intercept_upload( # type: ignore
details, lambda details: _run_interceptor(details, tail)
)
else:
return await session.send_upload(
service_name,
procedure_name,
init,
request,
init_serializer,
request_serializer,
response_deserializer,
error_deserializer,
)

return await _run_interceptor(
ClientUploadDetails(
service_name=service_name,
procedure_name=procedure_name,
init=init,
),
self._interceptors,
)

async def send_subscription(
Expand All @@ -98,13 +148,33 @@ async def send_subscription(
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
session = await self._transport.get_or_create_session()
return session.send_subscription(
service_name,
procedure_name,
request,
request_serializer,
response_deserializer,
error_deserializer,

async def _run_interceptor(
details: ClientSubscriptionDetails,
interceptors: list[ClientInterceptor],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
if interceptors:
head, tail = interceptors[0], interceptors[1:]
return await head.intercept_subscription( # type: ignore
details, lambda details: _run_interceptor(details, tail)
)
else:
return session.send_subscription(
service_name,
procedure_name,
request,
request_serializer,
response_deserializer,
error_deserializer,
)

return await _run_interceptor(
ClientSubscriptionDetails(
service_name=service_name,
procedure_name=procedure_name,
request=request,
),
self._interceptors,
)

async def send_stream(
Expand All @@ -119,13 +189,33 @@ async def send_stream(
error_deserializer: Callable[[Any], ErrorType],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
session = await self._transport.get_or_create_session()
return session.send_stream(
service_name,
procedure_name,
init,
request,
init_serializer,
request_serializer,
response_deserializer,
error_deserializer,

async def _run_interceptor(
details: ClientStreamDetails,
interceptors: list[ClientInterceptor],
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
if interceptors:
head, tail = interceptors[0], interceptors[1:]
return await head.intercept_stream( # type: ignore
details, lambda details: _run_interceptor(details, tail)
)
else:
return session.send_stream(
service_name,
procedure_name,
init,
request,
init_serializer,
request_serializer,
response_deserializer,
error_deserializer,
)

return await _run_interceptor(
ClientStreamDetails(
service_name=service_name,
procedure_name=procedure_name,
init=init,
),
self._interceptors,
)
74 changes: 74 additions & 0 deletions replit_river/client_interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Awaitable, Callable, NamedTuple, Optional


class ClientRpcDetails(NamedTuple):
service_name: str
procedure_name: str
request: Any


class ClientUploadDetails(NamedTuple):
service_name: str
procedure_name: str
init: Optional[Any]


class ClientSubscriptionDetails(NamedTuple):
service_name: str
procedure_name: str
request: Any


class ClientStreamDetails(NamedTuple):
service_name: str
procedure_name: str
init: Optional[Any]


class ClientInterceptor(ABC):
@abstractmethod
async def intercept_rpc(
self,
details: ClientRpcDetails,
continuation: Callable[[ClientRpcDetails], Awaitable[Any]],
) -> Any:
"""
TODO: docs
"""

@abstractmethod
async def intercept_upload(
self,
details: ClientUploadDetails,
continuation: Callable[[ClientUploadDetails], Awaitable[Any]],
) -> Any:
"""
TODO: docs
"""

@abstractmethod
async def intercept_subscription(
self,
details: ClientSubscriptionDetails,
continuation: Callable[
[ClientSubscriptionDetails],
Awaitable[AsyncIterator[Any]],
],
) -> AsyncIterator[Any]:
"""
TODO: docs
"""

@abstractmethod
async def intercept_stream(
self,
details: ClientStreamDetails,
continuation: Callable[
[ClientStreamDetails],
Awaitable[AsyncIterator[Any]],
],
) -> AsyncIterator[Any]:
"""
TODO: docs
"""
86 changes: 86 additions & 0 deletions replit_river/contrib/opentelemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Any, AsyncIterator, Awaitable, Callable

from opentelemetry import trace

from replit_river.client_interceptor import (
ClientInterceptor,
ClientRpcDetails,
ClientStreamDetails,
ClientSubscriptionDetails,
ClientUploadDetails,
)
from replit_river.error_schema import RiverException

tracer = trace.get_tracer(__name__)


class OpenTelemetryClientInterceptor(ClientInterceptor):
async def intercept_rpc(
self,
details: ClientRpcDetails,
continuation: Callable[[ClientRpcDetails], Awaitable[Any]],
) -> Any:
with tracer.start_as_current_span(
f"river.rpc.{details.service_name}.{details.procedure_name}",
kind=trace.SpanKind.CLIENT,
) as span:
try:
return await continuation(details)
except RiverException as e:
span.set_attribute("river.error_code", e.code)
span.set_attribute("river.error_message", e.message)
return e

async def intercept_upload(
self,
details: ClientUploadDetails,
continuation: Callable[[ClientUploadDetails], Awaitable[Any]],
) -> Any:
with tracer.start_as_current_span(
f"river.upload.{details.service_name}.{details.procedure_name}",
kind=trace.SpanKind.CLIENT,
) as span:
try:
return await continuation(details)
except RiverException as e:
span.set_attribute("river.error_code", e.code)
span.set_attribute("river.error_message", e.message)
return e

async def intercept_subscription(
self,
details: ClientSubscriptionDetails,
continuation: Callable[
[ClientSubscriptionDetails],
Awaitable[AsyncIterator[Any]],
],
) -> Any:
with tracer.start_as_current_span(
f"river.subscription.{details.service_name}.{details.procedure_name}",
kind=trace.SpanKind.CLIENT,
) as span:
try:
return await continuation(details)
except RiverException as e:
span.set_attribute("river.error_code", e.code)
span.set_attribute("river.error_message", e.message)
return e

async def intercept_stream(
self,
details: ClientStreamDetails,
continuation: Callable[
[ClientStreamDetails],
Awaitable[AsyncIterator[Any]],
],
) -> Any:
with tracer.start_as_current_span(
f"river.stream.{details.service_name}.{details.procedure_name}",
kind=trace.SpanKind.CLIENT,
) as span:
try:
return await continuation(details)
except RiverException as e:
span.set_attribute("river.error_code", e.code)
span.set_attribute("river.error_message", e.message)
return e
Loading