Skip to content

Commit b9283cd

Browse files
Disambiguate slips between input and init
1 parent 2a0b18e commit b9283cd

File tree

1 file changed

+23
-20
lines changed

1 file changed

+23
-20
lines changed

replit_river/codegen/client.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ def __init__(self, client: river.Client[Any]):
781781
module_names = [ModuleName(name)]
782782
init_type: Optional[TypeExpression] = None
783783
if procedure.init:
784-
init_type, module_info, input_chunks, encoder_names = encode_type(
784+
init_type, module_info, init_chunks, encoder_names = encode_type(
785785
procedure.init,
786786
TypeName(f"{name.title()}Init"),
787787
input_base_class,
@@ -791,7 +791,7 @@ def __init__(self, client: river.Client[Any]):
791791
(
792792
[extract_inner_type(init_type), *encoder_names],
793793
module_info,
794-
input_chunks,
794+
init_chunks,
795795
)
796796
)
797797
input_type, module_info, input_chunks, encoder_names = encode_type(
@@ -856,27 +856,28 @@ def __init__(self, client: river.Client[Any]):
856856

857857
# Init renderer
858858
render_init_method: Optional[str] = None
859-
if input_base_class == "TypedDict" and init_type:
860-
if is_literal(procedure.input):
861-
render_init_method = "lambda x: x"
862-
elif isinstance(
863-
procedure.input, RiverConcreteType
864-
) and procedure.input.type in ["array"]:
865-
match init_type:
866-
case ListTypeExpr(init_type_name):
867-
render_init_method = (
868-
f"lambda xs: [encode_{init_type_name}(x) for x in xs]"
869-
)
859+
if init_type and procedure.init is not None:
860+
if input_base_class == "TypedDict":
861+
if is_literal(procedure.init):
862+
render_init_method = "lambda x: x"
863+
elif isinstance(
864+
procedure.init, RiverConcreteType
865+
) and procedure.init.type in ["array"]:
866+
match init_type:
867+
case ListTypeExpr(init_type_name):
868+
render_init_method = (
869+
f"lambda xs: [encode_{init_type_name}(x) for x in xs]"
870+
)
871+
else:
872+
render_init_method = f"encode_{ensure_literal_type(init_type)}"
870873
else:
871-
render_init_method = f"encode_{ensure_literal_type(init_type)}"
872-
else:
873-
render_init_method = f"""\
874-
lambda x: TypeAdapter({render_type_expr(input_type)})
875-
.validate_python
876-
"""
874+
render_init_method = f"""\
875+
lambda x: TypeAdapter({render_type_expr(init_type)})
876+
.validate_python
877+
"""
877878

878879
assert (
879-
render_init_method
880+
init_type is None or render_init_method
880881
), f"Unable to derive the init encoder from: {input_type}"
881882

882883
# Input renderer
@@ -970,6 +971,7 @@ async def {name}(
970971
if output_type == "None":
971972
control_flow_keyword = ""
972973
if init_type:
974+
assert render_init_method, "Expected an init renderer!"
973975
current_chunks.extend(
974976
[
975977
reindent(
@@ -1020,6 +1022,7 @@ async def {name}(
10201022
)
10211023
elif procedure.type == "stream":
10221024
if init_type:
1025+
assert render_init_method, "Expected an init renderer!"
10231026
current_chunks.extend(
10241027
[
10251028
reindent(

0 commit comments

Comments
 (0)