|
93 | 93 |
|
94 | 94 | from pydantic import BaseModel, Field, TypeAdapter, WrapValidator |
95 | 95 | from replit_river.error_schema import RiverError |
96 | | -from replit_river.client import RiverUnknownError, translate_unknown_error |
| 96 | +from replit_river.client import RiverUnknownError, translate_unknown_error, \ |
| 97 | + RiverUnknownValue, translate_unknown_value |
97 | 98 |
|
98 | 99 | import replit_river as river |
99 | 100 |
|
@@ -168,6 +169,17 @@ def encode_type( |
168 | 169 | in_module: list[ModuleName], |
169 | 170 | permit_unknown_members: bool, |
170 | 171 | ) -> Tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]: |
| 172 | + def _make_open_union_type_expr(one_of: list[TypeExpression]) -> OpenUnionTypeExpr: |
| 173 | + return OpenUnionTypeExpr( |
| 174 | + UnionTypeExpr(one_of), |
| 175 | + fallback_type="RiverUnknownError" |
| 176 | + if base_model == "RiverError" |
| 177 | + else "RiverUnknownValue", |
| 178 | + validator_function="translate_unknown_error" |
| 179 | + if base_model == "RiverError" |
| 180 | + else "translate_unknown_value", |
| 181 | + ) |
| 182 | + |
171 | 183 | encoder_name: TypeName | None = None # defining this up here to placate mypy |
172 | 184 | chunks: List[FileContents] = [] |
173 | 185 | if isinstance(type, RiverNotType): |
@@ -318,7 +330,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: |
318 | 330 | ) |
319 | 331 | union: TypeExpression |
320 | 332 | if permit_unknown_members: |
321 | | - union = OpenUnionTypeExpr(UnionTypeExpr(one_of)) |
| 333 | + union = _make_open_union_type_expr(one_of) |
322 | 334 | else: |
323 | 335 | union = UnionTypeExpr(one_of) |
324 | 336 | chunks.append( |
@@ -392,7 +404,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: |
392 | 404 | ) |
393 | 405 | raise ValueError(f"What does it mean to have {_o2} here?") |
394 | 406 | if permit_unknown_members: |
395 | | - union = OpenUnionTypeExpr(UnionTypeExpr(any_of)) |
| 407 | + union = _make_open_union_type_expr(any_of) |
396 | 408 | else: |
397 | 409 | union = UnionTypeExpr(any_of) |
398 | 410 | if is_literal(type): |
|
0 commit comments