Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
edde572
updates
airportyh Mar 13, 2025
f5190e3
refactor
airportyh Mar 13, 2025
766304e
updates to generated test output
airportyh Mar 13, 2025
9b0f693
tests pass
airportyh Mar 13, 2025
29b9353
lint
airportyh Mar 13, 2025
2f545bf
cleanup
airportyh Mar 13, 2025
6b57ec4
use Any
airportyh Mar 13, 2025
7be8c26
allow Any for error type so the existing error types can match them; …
airportyh Mar 13, 2025
c77679e
resnapshotted the tests
airportyh Mar 13, 2025
8dcd439
regened the code
airportyh Mar 13, 2025
cb3aac9
lint
airportyh Mar 14, 2025
cda628f
non-abreviated names
airportyh Mar 14, 2025
70c5ab8
test snapshot
airportyh Mar 14, 2025
1e3aa57
add UnknownRiverError and translate_unknown_error
airportyh Mar 14, 2025
693b174
lint
airportyh Mar 14, 2025
7203fef
fixes
airportyh Mar 14, 2025
26f3867
reverts
airportyh Mar 14, 2025
9b5740a
Merge remote-tracking branch 'origin/main' into th-stricter-error-types
airportyh Mar 14, 2025
86b36a1
fixed test
airportyh Mar 14, 2025
c46e64a
Update src/replit_river/client.py
airportyh Mar 15, 2025
672492f
make error type and translate function dynamic
airportyh Mar 21, 2025
3d4cbf0
Merge branch 'main' into th-stricter-error-types
airportyh Mar 21, 2025
3f959e7
updated snapshots
airportyh Mar 21, 2025
b16f181
lint
airportyh Mar 21, 2025
ada5db6
Make the code more readable
blast-hardcheese Mar 22, 2025
e6edc73
Test for unknown error values
blast-hardcheese Mar 22, 2025
ca6c0fe
Make it possible to request a client that can emit errors
blast-hardcheese Mar 22, 2025
a0db2fe
Adding a streaming method that can emit known and unknown errors
blast-hardcheese Mar 22, 2025
75f3431
Regenerating code
blast-hardcheese Mar 22, 2025
7f98b0c
Only generate the snapshot code once
blast-hardcheese Mar 22, 2025
8f5fce4
Actually write a test for stream errors
blast-hardcheese Mar 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion src/replit_river/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)

from replit_river.client_transport import ClientTransport
from replit_river.error_schema import RiverError, RiverException
from replit_river.error_schema import ERROR_CODE_UNKNOWN, RiverError, RiverException
from replit_river.transport_options import (
HandshakeMetadataType,
TransportOptions,
Expand All @@ -37,6 +37,10 @@ class RiverUnknownValue(BaseModel):
value: Any


class RiverUnknownError(RiverError):
pass


def translate_unknown_value(
value: Any, handler: Callable[[Any], Any], info: ValidationInfo
) -> Any | RiverUnknownValue:
Expand All @@ -46,6 +50,21 @@ def translate_unknown_value(
return RiverUnknownValue(tag="RiverUnknownValue", value=value)


def translate_unknown_error(
value: Any, handler: Callable[[Any], Any], info: ValidationInfo
) -> Any | RiverUnknownError:
try:
return handler(value)
except Exception:
if isinstance(value, dict) and "code" in value and "message" in value:
return RiverUnknownError(
code=value["code"],
message=value["message"],
)
else:
return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error")


class Client(Generic[HandshakeMetadataType]):
def __init__(
self,
Expand Down
30 changes: 23 additions & 7 deletions src/replit_river/codegen/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@

from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
from replit_river.error_schema import RiverError
from replit_river.client import RiverUnknownValue, translate_unknown_value
from replit_river.client import RiverUnknownError, translate_unknown_error, \
RiverUnknownValue, translate_unknown_value

import replit_river as river

Expand Down Expand Up @@ -154,6 +155,20 @@ def encode_type(
in_module: list[ModuleName],
permit_unknown_members: bool,
) -> tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]:
def _make_open_union_type_expr(one_of: list[TypeExpression]) -> OpenUnionTypeExpr:
if base_model == "RiverError":
return OpenUnionTypeExpr(
UnionTypeExpr(one_of),
fallback_type="RiverUnknownError",
validator_function="translate_unknown_error",
)
else:
return OpenUnionTypeExpr(
UnionTypeExpr(one_of),
fallback_type="RiverUnknownValue",
validator_function="translate_unknown_value",
)

encoder_name: TypeName | None = None # defining this up here to placate mypy
chunks: list[FileContents] = []
if isinstance(type, RiverNotType):
Expand Down Expand Up @@ -304,7 +319,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
)
union: TypeExpression
if permit_unknown_members:
union = OpenUnionTypeExpr(UnionTypeExpr(one_of))
union = _make_open_union_type_expr(one_of)
else:
union = UnionTypeExpr(one_of)
chunks.append(
Expand Down Expand Up @@ -383,7 +398,7 @@ def {_field_name}(
)
raise ValueError(f"What does it mean to have {_o2} here?")
if permit_unknown_members:
union = OpenUnionTypeExpr(UnionTypeExpr(any_of))
union = _make_open_union_type_expr(any_of)
else:
union = UnionTypeExpr(any_of)
if is_literal(type):
Expand Down Expand Up @@ -795,17 +810,18 @@ def _type_adapter_definition(
_type: TypeExpression,
module_info: list[ModuleName],
) -> tuple[list[TypeName], list[ModuleName], list[FileContents]]:
varname = render_type_expr(type_adapter_name)
rendered_type_expr = render_type_expr(_type)
return (
[type_adapter_name],
module_info,
[
FileContents(
dedent(f"""
{render_type_expr(type_adapter_name)}: TypeAdapter[Any] = (
TypeAdapter({rendered_type_expr})
)
""")
{varname}: TypeAdapter[{rendered_type_expr}] = (
TypeAdapter({rendered_type_expr})
)
""")
)
],
)
Expand Down
9 changes: 6 additions & 3 deletions src/replit_river/codegen/typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import NewType, assert_never
from typing import NewType, assert_never, cast

ModuleName = NewType("ModuleName", str)
ClassName = NewType("ClassName", str)
Expand Down Expand Up @@ -96,6 +96,8 @@ def __lt__(self, other: object) -> bool:
@dataclass(frozen=True)
class OpenUnionTypeExpr:
union: UnionTypeExpr
fallback_type: str
validator_function: str

def __str__(self) -> str:
raise Exception("Complex type must be put through render_type_expr!")
Expand Down Expand Up @@ -182,10 +184,11 @@ def render_type_expr(value: TypeExpression) -> str:
retval = "None"
return retval
case OpenUnionTypeExpr(inner):
open_union = cast(OpenUnionTypeExpr, value)
return (
"Annotated["
f"{render_type_expr(inner)} | RiverUnknownValue,"
"WrapValidator(translate_unknown_value)"
f"{render_type_expr(inner)} | {open_union.fallback_type},"
f"WrapValidator({open_union.validator_function})"
"]"
)
case TypeName(name):
Expand Down
3 changes: 3 additions & 0 deletions src/replit_river/error_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# ERROR_CODE_CANCEL is the code used when either server or client cancels the stream.
ERROR_CODE_CANCEL = "CANCEL"

# ERROR_CODE_UNKNOWN is the code for the RiverUnknownError
ERROR_CODE_UNKNOWN = "UNKNOWN"


class RiverError(BaseModel):
"""Error message from the server."""
Expand Down
13 changes: 10 additions & 3 deletions tests/codegen/rpc/generated/test_service/rpc_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
from replit_river.error_schema import RiverError
from replit_river.client import RiverUnknownValue, translate_unknown_value
from replit_river.client import (
RiverUnknownError,
translate_unknown_error,
RiverUnknownValue,
translate_unknown_value,
)

import replit_river as river

Expand All @@ -35,11 +40,13 @@ class Rpc_MethodInput(TypedDict):
data: str


Rpc_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodInput)
Rpc_MethodInputTypeAdapter: TypeAdapter[Rpc_MethodInput] = TypeAdapter(Rpc_MethodInput)


class Rpc_MethodOutput(BaseModel):
data: str


Rpc_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodOutput)
Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter(
Rpc_MethodOutput
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
Stream_MethodOutputTypeAdapter,
encode_Stream_MethodInput,
)
from .emit_error import Emit_ErrorErrors, Emit_ErrorErrorsTypeAdapter

intTypeAdapter: TypeAdapter[int] = TypeAdapter(int)


boolTypeAdapter: TypeAdapter[bool] = TypeAdapter(bool)


class Test_ServiceService:
Expand All @@ -40,3 +46,22 @@ async def stream_method(
x # type: ignore[arg-type]
),
)

async def emit_error(
self,
inputStream: AsyncIterable[int],
) -> AsyncIterator[bool | Emit_ErrorErrors | RiverError]:
return self.client.send_stream(
"test_service",
"emit_error",
None,
inputStream,
None,
lambda x: x,
lambda x: boolTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
lambda x: Emit_ErrorErrorsTypeAdapter.validate_python(
x # type: ignore[arg-type]
),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Code generated by river.codegen. DO NOT EDIT.
from collections.abc import AsyncIterable, AsyncIterator
import datetime
from typing import (
Any,
Literal,
Mapping,
NotRequired,
TypedDict,
)
from typing_extensions import Annotated

from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
from replit_river.error_schema import RiverError
from replit_river.client import (
RiverUnknownError,
translate_unknown_error,
RiverUnknownValue,
translate_unknown_value,
)

import replit_river as river


class Emit_ErrorErrorsOneOf_DATA_LOSS(RiverError):
code: Literal["DATA_LOSS"]
message: str


class Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT(RiverError):
code: Literal["UNEXPECTED_DISCONNECT"]
message: str


Emit_ErrorErrors = Annotated[
Emit_ErrorErrorsOneOf_DATA_LOSS
| Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT
| RiverUnknownError,
WrapValidator(translate_unknown_error),
]


Emit_ErrorErrorsTypeAdapter: TypeAdapter[Emit_ErrorErrors] = TypeAdapter(
Emit_ErrorErrors
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
from replit_river.error_schema import RiverError
from replit_river.client import RiverUnknownValue, translate_unknown_value
from replit_river.client import (
RiverUnknownError,
translate_unknown_error,
RiverUnknownValue,
translate_unknown_value,
)

import replit_river as river

Expand All @@ -35,11 +40,15 @@ class Stream_MethodInput(TypedDict):
data: str


Stream_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodInput)
Stream_MethodInputTypeAdapter: TypeAdapter[Stream_MethodInput] = TypeAdapter(
Stream_MethodInput
)


class Stream_MethodOutput(BaseModel):
data: str


Stream_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodOutput)
Stream_MethodOutputTypeAdapter: TypeAdapter[Stream_MethodOutput] = TypeAdapter(
Stream_MethodOutput
)
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
encode_Pathological_MethodInputReq_Obj_Undefined,
)

boolTypeAdapter: TypeAdapter[Any] = TypeAdapter(bool)
boolTypeAdapter: TypeAdapter[bool] = TypeAdapter(bool)


class Test_ServiceService:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
from replit_river.error_schema import RiverError
from replit_river.client import RiverUnknownValue, translate_unknown_value
from replit_river.client import (
RiverUnknownError,
translate_unknown_error,
RiverUnknownValue,
translate_unknown_value,
)

import replit_river as river

Expand Down Expand Up @@ -473,6 +478,6 @@ class Pathological_MethodInput(TypedDict):
undefined: NotRequired[None]


Pathological_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(
Pathological_MethodInput
Pathological_MethodInputTypeAdapter: TypeAdapter[Pathological_MethodInput] = (
TypeAdapter(Pathological_MethodInput)
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
from replit_river.error_schema import RiverError
from replit_river.client import RiverUnknownValue, translate_unknown_value
from replit_river.client import (
RiverUnknownError,
translate_unknown_error,
RiverUnknownValue,
translate_unknown_value,
)

import replit_river as river

Expand All @@ -24,18 +29,32 @@ def encode_NeedsenumInput(x: "NeedsenumInput") -> Any:
return x


NeedsenumInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumInput)
NeedsenumInputTypeAdapter: TypeAdapter[NeedsenumInput] = TypeAdapter(NeedsenumInput)

NeedsenumOutput = Annotated[
Literal["out_first", "out_second"] | RiverUnknownValue,
WrapValidator(translate_unknown_value),
]

NeedsenumOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumOutput)
NeedsenumOutputTypeAdapter: TypeAdapter[NeedsenumOutput] = TypeAdapter(NeedsenumOutput)


class NeedsenumErrorsOneOf_err_first(RiverError):
code: Literal["err_first"]
message: str


class NeedsenumErrorsOneOf_err_second(RiverError):
code: Literal["err_second"]
message: str


NeedsenumErrors = Annotated[
Literal["err_first", "err_second"] | RiverUnknownValue,
WrapValidator(translate_unknown_value),
NeedsenumErrorsOneOf_err_first
| NeedsenumErrorsOneOf_err_second
| RiverUnknownError,
WrapValidator(translate_unknown_error),
]

NeedsenumErrorsTypeAdapter: TypeAdapter[Any] = TypeAdapter(NeedsenumErrors)

NeedsenumErrorsTypeAdapter: TypeAdapter[NeedsenumErrors] = TypeAdapter(NeedsenumErrors)
Loading
Loading