Skip to content

Commit 92a5c3b

Browse files
Disambiguate slips between input and init
1 parent cb85439 commit 92a5c3b

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

replit_river/codegen/client.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,7 @@ def __init__(self, client: river.Client[Any]):
784784
module_names = [ModuleName(name)]
785785
init_type: Optional[TypeExpression] = None
786786
if procedure.init:
787-
init_type, module_info, input_chunks, encoder_names = encode_type(
787+
init_type, module_info, init_chunks, encoder_names = encode_type(
788788
procedure.init,
789789
TypeName(f"{name.title()}Init"),
790790
input_base_class,
@@ -794,7 +794,7 @@ def __init__(self, client: river.Client[Any]):
794794
(
795795
[extract_inner_type(init_type), *encoder_names],
796796
module_info,
797-
input_chunks,
797+
init_chunks,
798798
)
799799
)
800800
input_type, module_info, input_chunks, encoder_names = encode_type(
@@ -859,27 +859,28 @@ def __init__(self, client: river.Client[Any]):
859859

860860
# Init renderer
861861
render_init_method: Optional[str] = None
862-
if input_base_class == "TypedDict" and init_type:
863-
if is_literal(procedure.input):
864-
render_init_method = "lambda x: x"
865-
elif isinstance(
866-
procedure.input, RiverConcreteType
867-
) and procedure.input.type in ["array"]:
868-
match init_type:
869-
case ListTypeExpr(init_type_name):
870-
render_init_method = (
871-
f"lambda xs: [encode_{init_type_name}(x) for x in xs]"
872-
)
862+
if init_type and procedure.init is not None:
863+
if input_base_class == "TypedDict":
864+
if is_literal(procedure.init):
865+
render_init_method = "lambda x: x"
866+
elif isinstance(
867+
procedure.init, RiverConcreteType
868+
) and procedure.init.type in ["array"]:
869+
match init_type:
870+
case ListTypeExpr(init_type_name):
871+
render_init_method = (
872+
f"lambda xs: [encode_{init_type_name}(x) for x in xs]"
873+
)
874+
else:
875+
render_init_method = f"encode_{ensure_literal_type(init_type)}"
873876
else:
874-
render_init_method = f"encode_{ensure_literal_type(init_type)}"
875-
else:
876-
render_init_method = f"""\
877-
lambda x: TypeAdapter({render_type_expr(input_type)})
878-
.validate_python
879-
"""
877+
render_init_method = f"""\
878+
lambda x: TypeAdapter({render_type_expr(init_type)})
879+
.validate_python
880+
"""
880881

881882
assert (
882-
render_init_method
883+
init_type is None or render_init_method
883884
), f"Unable to derive the init encoder from: {input_type}"
884885

885886
# Input renderer
@@ -973,6 +974,7 @@ async def {name}(
973974
if output_type == "None":
974975
control_flow_keyword = ""
975976
if init_type:
977+
assert render_init_method, "Expected an init renderer!"
976978
current_chunks.extend(
977979
[
978980
reindent(
@@ -1023,6 +1025,7 @@ async def {name}(
10231025
)
10241026
elif procedure.type == "stream":
10251027
if init_type:
1028+
assert render_init_method, "Expected an init renderer!"
10261029
current_chunks.extend(
10271030
[
10281031
reindent(

0 commit comments

Comments
 (0)