Skip to content

Commit f36e0be

Browse files
Making input and init types optional, gating on protocol version
1 parent 86009f9 commit f36e0be

File tree

1 file changed

+60
-66
lines changed

1 file changed

+60
-66
lines changed

src/replit_river/codegen/client.py

Lines changed: 60 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -839,8 +839,9 @@ def __init__(self, client: river.Client[Any]):
839839
for name, procedure in schema.procedures.items():
840840
module_names = [ModuleName(name)]
841841
init_type: TypeExpression | None = None
842+
init_module_info: list[ModuleName] = []
842843
if procedure.init:
843-
init_type, module_info, init_chunks, encoder_names = encode_type(
844+
init_type, init_module_info, init_chunks, encoder_names = encode_type(
844845
procedure.init,
845846
TypeName(f"{name.title()}Init"),
846847
input_base_class,
@@ -850,34 +851,29 @@ def __init__(self, client: river.Client[Any]):
850851
serdes.append(
851852
(
852853
[extract_inner_type(init_type), *encoder_names],
853-
module_info,
854+
init_module_info,
854855
init_chunks,
855856
)
856857
)
857-
input_type, module_info, input_chunks, encoder_names = encode_type(
858-
procedure.input,
859-
TypeName(f"{name.title()}Input"),
860-
input_base_class,
861-
module_names,
862-
permit_unknown_members=False,
863-
)
864-
input_type_name = extract_inner_type(input_type)
865-
input_type_type_adapter_name = TypeName(
866-
f"{render_literal_type(input_type_name)}TypeAdapter"
867-
)
868-
serdes.append(
869-
(
870-
[extract_inner_type(input_type), *encoder_names],
871-
module_info,
872-
input_chunks,
858+
input_type: TypeExpression | None = None
859+
input_module_info: list[ModuleName] = []
860+
if procedure.input:
861+
input_type, input_module_info, input_chunks, encoder_names = encode_type(
862+
procedure.input,
863+
TypeName(f"{name.title()}Input"),
864+
input_base_class,
865+
module_names,
866+
permit_unknown_members=False,
873867
)
874-
)
875-
serdes.append(
876-
_type_adapter_definition(
877-
input_type_type_adapter_name, input_type, module_info
868+
serdes.append(
869+
(
870+
[extract_inner_type(input_type), *encoder_names],
871+
input_module_info,
872+
input_chunks,
873+
)
878874
)
879-
)
880-
output_type, module_info, output_chunks, encoder_names = encode_type(
875+
876+
output_type, output_module_info, output_chunks, encoder_names = encode_type(
881877
procedure.output,
882878
TypeName(f"{name.title()}Output"),
883879
"BaseModel",
@@ -888,7 +884,7 @@ def __init__(self, client: river.Client[Any]):
888884
serdes.append(
889885
(
890886
[output_type_name, *encoder_names],
891-
module_info,
887+
output_module_info,
892888
output_chunks,
893889
)
894890
)
@@ -897,12 +893,12 @@ def __init__(self, client: river.Client[Any]):
897893
)
898894
serdes.append(
899895
_type_adapter_definition(
900-
output_type_type_adapter_name, output_type, module_info
896+
output_type_type_adapter_name, output_type, output_module_info
901897
)
902898
)
903-
output_module_info = module_info
899+
904900
if procedure.errors:
905-
error_type, module_info, errors_chunks, encoder_names = encode_type(
901+
error_type, error_module_info, errors_chunks, encoder_names = encode_type(
906902
procedure.errors,
907903
TypeName(f"{name.title()}Errors"),
908904
"RiverError",
@@ -914,7 +910,7 @@ def __init__(self, client: river.Client[Any]):
914910
error_type = error_type_name
915911
else:
916912
error_type_name = extract_inner_type(error_type)
917-
serdes.append(([error_type_name], module_info, errors_chunks))
913+
serdes.append(([error_type_name], error_module_info, errors_chunks))
918914

919915
else:
920916
error_type_name = TypeName("RiverError")
@@ -924,11 +920,9 @@ def __init__(self, client: river.Client[Any]):
924920
f"{render_literal_type(error_type_name)}TypeAdapter"
925921
)
926922
if error_type_type_adapter_name.value != "RiverErrorTypeAdapter":
927-
if len(module_info) == 0:
928-
module_info = output_module_info
929923
serdes.append(
930924
_type_adapter_definition(
931-
error_type_type_adapter_name, error_type, module_info
925+
error_type_type_adapter_name, error_type, output_module_info
932926
)
933927
)
934928
output_or_error_type = UnionTypeExpr([output_type, error_type_name])
@@ -975,48 +969,50 @@ def __init__(self, client: river.Client[Any]):
975969
)
976970
serdes.append(
977971
_type_adapter_definition(
978-
init_type_type_adapter_name, init_type, module_info
972+
init_type_type_adapter_name, init_type, init_module_info
979973
)
980974
)
981975
render_init_method = f"""\
982976
lambda x: {render_type_expr(init_type_type_adapter_name)}
983977
.validate_python
984978
"""
985979

986-
assert init_type is None or render_init_method, (
987-
f"Unable to derive the init encoder from: {init_type}"
988-
)
989-
990980
# Input renderer
991981
render_input_method: str | None = None
992-
if input_base_class == "TypedDict":
993-
if is_literal(procedure.input):
994-
render_input_method = "lambda x: x"
995-
elif isinstance(
996-
procedure.input, RiverConcreteType
997-
) and procedure.input.type in ["array"]:
998-
match input_type:
999-
case ListTypeExpr(list_type):
1000-
render_input_method = f"""\
1001-
lambda xs: [
1002-
encode_{render_literal_type(list_type)}(x) for x in xs
1003-
]
1004-
"""
1005-
else:
1006-
render_input_method = f"encode_{render_literal_type(input_type)}"
1007-
else:
1008-
render_input_method = f"""\
1009-
lambda x: {render_type_expr(input_type_type_adapter_name)}
1010-
.dump_python(
1011-
x, # type: ignore[arg-type]
1012-
by_alias=True,
1013-
exclude_none=True,
1014-
)
982+
if input_type and procedure.input is not None:
983+
if input_base_class == "TypedDict":
984+
if is_literal(procedure.input):
985+
render_input_method = "lambda x: x"
986+
elif isinstance(
987+
procedure.input, RiverConcreteType
988+
) and procedure.input.type in ["array"]:
989+
match input_type:
990+
case ListTypeExpr(list_type):
991+
render_input_method = f"""\
992+
lambda xs: [
993+
encode_{render_literal_type(list_type)}(x) for x in xs
994+
]
1015995
"""
1016-
1017-
assert render_input_method, (
1018-
f"Unable to derive the input encoder from: {input_type}"
1019-
)
996+
else:
997+
render_input_method = f"encode_{render_literal_type(input_type)}"
998+
else:
999+
input_type_name = extract_inner_type(input_type)
1000+
input_type_type_adapter_name = TypeName(
1001+
f"{render_literal_type(input_type_name)}TypeAdapter"
1002+
)
1003+
serdes.append(
1004+
_type_adapter_definition(
1005+
input_type_type_adapter_name, input_type, input_module_info
1006+
)
1007+
)
1008+
render_input_method = f"""\
1009+
lambda x: {render_type_expr(input_type_type_adapter_name)}
1010+
.dump_python(
1011+
x, # type: ignore[arg-type]
1012+
by_alias=True,
1013+
exclude_none=True,
1014+
)
1015+
"""
10201016

10211017
if isinstance(output_type, NoneTypeExpr):
10221018
parse_output_method = "lambda x: None"
@@ -1075,7 +1071,6 @@ async def {name}(
10751071
)
10761072
elif procedure.type == "upload":
10771073
if init_type:
1078-
assert render_init_method, "Expected an init renderer!"
10791074
current_chunks.extend(
10801075
[
10811076
reindent(
@@ -1132,7 +1127,6 @@ async def {name}(
11321127
]
11331128
)
11341129
if init_type:
1135-
assert render_init_method, "Expected an init renderer!"
11361130
current_chunks.extend(
11371131
[
11381132
reindent(

0 commit comments

Comments
 (0)