Skip to content

Commit 38554fc

Browse files
Stricter error types (#141)
Why === Have the generated code return error types that actually satisfy the `ErrorType`. Previously the union type for error types generated can be values of `RiverUnknownValue` which is not an error type. This introduces a `RiverUnknownError` and uses that in the error unions instead. Also type adapter types are now more strictly typed. # Testing When switching the client code over, we should get errors about `RiverUnknownError` not being handled and force you to fix them. --------- Co-authored-by: Devon Stewart <[email protected]>
1 parent 2983732 commit 38554fc

File tree

17 files changed

+397
-51
lines changed

17 files changed

+397
-51
lines changed

src/replit_river/client.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414

1515
from replit_river.client_transport import ClientTransport
16-
from replit_river.error_schema import RiverError, RiverException
16+
from replit_river.error_schema import ERROR_CODE_UNKNOWN, RiverError, RiverException
1717
from replit_river.transport_options import (
1818
HandshakeMetadataType,
1919
TransportOptions,
@@ -37,6 +37,10 @@ class RiverUnknownValue(BaseModel):
3737
value: Any
3838

3939

40+
class RiverUnknownError(RiverError):
41+
pass
42+
43+
4044
def translate_unknown_value(
4145
value: Any, handler: Callable[[Any], Any], info: ValidationInfo
4246
) -> Any | RiverUnknownValue:
@@ -46,6 +50,21 @@ def translate_unknown_value(
4650
return RiverUnknownValue(tag="RiverUnknownValue", value=value)
4751

4852

53+
def translate_unknown_error(
54+
value: Any, handler: Callable[[Any], Any], info: ValidationInfo
55+
) -> Any | RiverUnknownError:
56+
try:
57+
return handler(value)
58+
except Exception:
59+
if isinstance(value, dict) and "code" in value and "message" in value:
60+
return RiverUnknownError(
61+
code=value["code"],
62+
message=value["message"],
63+
)
64+
else:
65+
return RiverUnknownError(code=ERROR_CODE_UNKNOWN, message="Unknown error")
66+
67+
4968
class Client(Generic[HandshakeMetadataType]):
5069
def __init__(
5170
self,

src/replit_river/codegen/client.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@
8181
8282
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
8383
from replit_river.error_schema import RiverError
84-
from replit_river.client import RiverUnknownValue, translate_unknown_value
84+
from replit_river.client import RiverUnknownError, translate_unknown_error, \
85+
RiverUnknownValue, translate_unknown_value
8586
8687
import replit_river as river
8788
@@ -154,6 +155,20 @@ def encode_type(
154155
in_module: list[ModuleName],
155156
permit_unknown_members: bool,
156157
) -> tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]:
158+
def _make_open_union_type_expr(one_of: list[TypeExpression]) -> OpenUnionTypeExpr:
159+
if base_model == "RiverError":
160+
return OpenUnionTypeExpr(
161+
UnionTypeExpr(one_of),
162+
fallback_type="RiverUnknownError",
163+
validator_function="translate_unknown_error",
164+
)
165+
else:
166+
return OpenUnionTypeExpr(
167+
UnionTypeExpr(one_of),
168+
fallback_type="RiverUnknownValue",
169+
validator_function="translate_unknown_value",
170+
)
171+
157172
encoder_name: TypeName | None = None # defining this up here to placate mypy
158173
chunks: list[FileContents] = []
159174
if isinstance(type, RiverNotType):
@@ -304,7 +319,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
304319
)
305320
union: TypeExpression
306321
if permit_unknown_members:
307-
union = OpenUnionTypeExpr(UnionTypeExpr(one_of))
322+
union = _make_open_union_type_expr(one_of)
308323
else:
309324
union = UnionTypeExpr(one_of)
310325
chunks.append(
@@ -383,7 +398,7 @@ def {_field_name}(
383398
)
384399
raise ValueError(f"What does it mean to have {_o2} here?")
385400
if permit_unknown_members:
386-
union = OpenUnionTypeExpr(UnionTypeExpr(any_of))
401+
union = _make_open_union_type_expr(any_of)
387402
else:
388403
union = UnionTypeExpr(any_of)
389404
if is_literal(type):
@@ -795,17 +810,18 @@ def _type_adapter_definition(
795810
_type: TypeExpression,
796811
module_info: list[ModuleName],
797812
) -> tuple[list[TypeName], list[ModuleName], list[FileContents]]:
813+
varname = render_type_expr(type_adapter_name)
798814
rendered_type_expr = render_type_expr(_type)
799815
return (
800816
[type_adapter_name],
801817
module_info,
802818
[
803819
FileContents(
804820
dedent(f"""
805-
{render_type_expr(type_adapter_name)}: TypeAdapter[Any] = (
806-
TypeAdapter({rendered_type_expr})
807-
)
808-
""")
821+
{varname}: TypeAdapter[{rendered_type_expr}] = (
822+
TypeAdapter({rendered_type_expr})
823+
)
824+
""")
809825
)
810826
],
811827
)

src/replit_river/codegen/typing.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass
2-
from typing import NewType, assert_never
2+
from typing import NewType, assert_never, cast
33

44
ModuleName = NewType("ModuleName", str)
55
ClassName = NewType("ClassName", str)
@@ -96,6 +96,8 @@ def __lt__(self, other: object) -> bool:
9696
@dataclass(frozen=True)
9797
class OpenUnionTypeExpr:
9898
union: UnionTypeExpr
99+
fallback_type: str
100+
validator_function: str
99101

100102
def __str__(self) -> str:
101103
raise Exception("Complex type must be put through render_type_expr!")
@@ -182,10 +184,11 @@ def render_type_expr(value: TypeExpression) -> str:
182184
retval = "None"
183185
return retval
184186
case OpenUnionTypeExpr(inner):
187+
open_union = cast(OpenUnionTypeExpr, value)
185188
return (
186189
"Annotated["
187-
f"{render_type_expr(inner)} | RiverUnknownValue,"
188-
"WrapValidator(translate_unknown_value)"
190+
f"{render_type_expr(inner)} | {open_union.fallback_type},"
191+
f"WrapValidator({open_union.validator_function})"
189192
"]"
190193
)
191194
case TypeName(name):

src/replit_river/error_schema.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
# ERROR_CODE_CANCEL is the code used when either server or client cancels the stream.
1818
ERROR_CODE_CANCEL = "CANCEL"
1919

20+
# ERROR_CODE_UNKNOWN is the code for the RiverUnknownError
21+
ERROR_CODE_UNKNOWN = "UNKNOWN"
22+
2023

2124
class RiverError(BaseModel):
2225
"""Error message from the server."""

tests/codegen/rpc/generated/test_service/rpc_method.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212

1313
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
1414
from replit_river.error_schema import RiverError
15-
from replit_river.client import RiverUnknownValue, translate_unknown_value
15+
from replit_river.client import (
16+
RiverUnknownError,
17+
translate_unknown_error,
18+
RiverUnknownValue,
19+
translate_unknown_value,
20+
)
1621

1722
import replit_river as river
1823

@@ -35,11 +40,13 @@ class Rpc_MethodInput(TypedDict):
3540
data: str
3641

3742

38-
Rpc_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodInput)
43+
Rpc_MethodInputTypeAdapter: TypeAdapter[Rpc_MethodInput] = TypeAdapter(Rpc_MethodInput)
3944

4045

4146
class Rpc_MethodOutput(BaseModel):
4247
data: str
4348

4449

45-
Rpc_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Rpc_MethodOutput)
50+
Rpc_MethodOutputTypeAdapter: TypeAdapter[Rpc_MethodOutput] = TypeAdapter(
51+
Rpc_MethodOutput
52+
)

tests/codegen/snapshot/snapshots/test_basic_stream/test_service/__init__.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@
1616
Stream_MethodOutputTypeAdapter,
1717
encode_Stream_MethodInput,
1818
)
19+
from .emit_error import Emit_ErrorErrors, Emit_ErrorErrorsTypeAdapter
20+
21+
intTypeAdapter: TypeAdapter[int] = TypeAdapter(int)
22+
23+
24+
boolTypeAdapter: TypeAdapter[bool] = TypeAdapter(bool)
1925

2026

2127
class Test_ServiceService:
@@ -40,3 +46,22 @@ async def stream_method(
4046
x # type: ignore[arg-type]
4147
),
4248
)
49+
50+
async def emit_error(
51+
self,
52+
inputStream: AsyncIterable[int],
53+
) -> AsyncIterator[bool | Emit_ErrorErrors | RiverError]:
54+
return self.client.send_stream(
55+
"test_service",
56+
"emit_error",
57+
None,
58+
inputStream,
59+
None,
60+
lambda x: x,
61+
lambda x: boolTypeAdapter.validate_python(
62+
x # type: ignore[arg-type]
63+
),
64+
lambda x: Emit_ErrorErrorsTypeAdapter.validate_python(
65+
x # type: ignore[arg-type]
66+
),
67+
)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Code generated by river.codegen. DO NOT EDIT.
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
import datetime
4+
from typing import (
5+
Any,
6+
Literal,
7+
Mapping,
8+
NotRequired,
9+
TypedDict,
10+
)
11+
from typing_extensions import Annotated
12+
13+
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
14+
from replit_river.error_schema import RiverError
15+
from replit_river.client import (
16+
RiverUnknownError,
17+
translate_unknown_error,
18+
RiverUnknownValue,
19+
translate_unknown_value,
20+
)
21+
22+
import replit_river as river
23+
24+
25+
class Emit_ErrorErrorsOneOf_DATA_LOSS(RiverError):
26+
code: Literal["DATA_LOSS"]
27+
message: str
28+
29+
30+
class Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT(RiverError):
31+
code: Literal["UNEXPECTED_DISCONNECT"]
32+
message: str
33+
34+
35+
Emit_ErrorErrors = Annotated[
36+
Emit_ErrorErrorsOneOf_DATA_LOSS
37+
| Emit_ErrorErrorsOneOf_UNEXPECTED_DISCONNECT
38+
| RiverUnknownError,
39+
WrapValidator(translate_unknown_error),
40+
]
41+
42+
43+
Emit_ErrorErrorsTypeAdapter: TypeAdapter[Emit_ErrorErrors] = TypeAdapter(
44+
Emit_ErrorErrors
45+
)

tests/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212

1313
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
1414
from replit_river.error_schema import RiverError
15-
from replit_river.client import RiverUnknownValue, translate_unknown_value
15+
from replit_river.client import (
16+
RiverUnknownError,
17+
translate_unknown_error,
18+
RiverUnknownValue,
19+
translate_unknown_value,
20+
)
1621

1722
import replit_river as river
1823

@@ -35,11 +40,15 @@ class Stream_MethodInput(TypedDict):
3540
data: str
3641

3742

38-
Stream_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodInput)
43+
Stream_MethodInputTypeAdapter: TypeAdapter[Stream_MethodInput] = TypeAdapter(
44+
Stream_MethodInput
45+
)
3946

4047

4148
class Stream_MethodOutput(BaseModel):
4249
data: str
4350

4451

45-
Stream_MethodOutputTypeAdapter: TypeAdapter[Any] = TypeAdapter(Stream_MethodOutput)
52+
Stream_MethodOutputTypeAdapter: TypeAdapter[Stream_MethodOutput] = TypeAdapter(
53+
Stream_MethodOutput
54+
)

tests/codegen/snapshot/snapshots/test_pathological_types/test_service/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
encode_Pathological_MethodInputReq_Obj_Undefined,
3232
)
3333

34-
boolTypeAdapter: TypeAdapter[Any] = TypeAdapter(bool)
34+
boolTypeAdapter: TypeAdapter[bool] = TypeAdapter(bool)
3535

3636

3737
class Test_ServiceService:

tests/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212

1313
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
1414
from replit_river.error_schema import RiverError
15-
from replit_river.client import RiverUnknownValue, translate_unknown_value
15+
from replit_river.client import (
16+
RiverUnknownError,
17+
translate_unknown_error,
18+
RiverUnknownValue,
19+
translate_unknown_value,
20+
)
1621

1722
import replit_river as river
1823

@@ -473,6 +478,6 @@ class Pathological_MethodInput(TypedDict):
473478
undefined: NotRequired[None]
474479

475480

476-
Pathological_MethodInputTypeAdapter: TypeAdapter[Any] = TypeAdapter(
477-
Pathological_MethodInput
481+
Pathological_MethodInputTypeAdapter: TypeAdapter[Pathological_MethodInput] = (
482+
TypeAdapter(Pathological_MethodInput)
478483
)

0 commit comments

Comments
 (0)