Skip to content

Commit 672492f

Browse files
committed
make error type and translate function dynamic
1 parent c46e64a commit 672492f

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

src/replit_river/codegen/client.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@
9393
9494
from pydantic import BaseModel, Field, TypeAdapter, WrapValidator
9595
from replit_river.error_schema import RiverError
96-
from replit_river.client import RiverUnknownError, translate_unknown_error
96+
from replit_river.client import RiverUnknownError, translate_unknown_error, \
97+
RiverUnknownValue, translate_unknown_value
9798
9899
import replit_river as river
99100
@@ -168,6 +169,17 @@ def encode_type(
168169
in_module: list[ModuleName],
169170
permit_unknown_members: bool,
170171
) -> Tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]:
172+
def _make_open_union_type_expr(one_of: list[TypeExpression]) -> OpenUnionTypeExpr:
173+
return OpenUnionTypeExpr(
174+
UnionTypeExpr(one_of),
175+
fallback_type="RiverUnknownError"
176+
if base_model == "RiverError"
177+
else "RiverUnknownValue",
178+
validator_function="translate_unknown_error"
179+
if base_model == "RiverError"
180+
else "translate_unknown_value",
181+
)
182+
171183
encoder_name: TypeName | None = None # defining this up here to placate mypy
172184
chunks: List[FileContents] = []
173185
if isinstance(type, RiverNotType):
@@ -318,7 +330,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
318330
)
319331
union: TypeExpression
320332
if permit_unknown_members:
321-
union = OpenUnionTypeExpr(UnionTypeExpr(one_of))
333+
union = _make_open_union_type_expr(one_of)
322334
else:
323335
union = UnionTypeExpr(one_of)
324336
chunks.append(
@@ -392,7 +404,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
392404
)
393405
raise ValueError(f"What does it mean to have {_o2} here?")
394406
if permit_unknown_members:
395-
union = OpenUnionTypeExpr(UnionTypeExpr(any_of))
407+
union = _make_open_union_type_expr(any_of)
396408
else:
397409
union = UnionTypeExpr(any_of)
398410
if is_literal(type):

src/replit_river/codegen/typing.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def __str__(self) -> str:
5858
@dataclass(frozen=True)
5959
class OpenUnionTypeExpr:
6060
union: UnionTypeExpr
61+
fallback_type: str
62+
validator_function: str
6163

6264
def __str__(self) -> str:
6365
raise Exception("Complex type must be put through render_type_expr!")
@@ -87,8 +89,8 @@ def render_type_expr(value: TypeExpression) -> str:
8789
case OpenUnionTypeExpr(inner):
8890
return (
8991
"Annotated["
90-
f"{render_type_expr(inner)} | RiverUnknownError,"
91-
"WrapValidator(translate_unknown_error)"
92+
f"{render_type_expr(inner)} | {value.fallback_type},"
93+
f"WrapValidator({value.validator_function})"
9294
"]"
9395
)
9496
case TypeName(name):

0 commit comments

Comments
 (0)