Skip to content

Commit 0cadc40

Browse files
Swap required "input" to required "init"
1 parent 9d17da2 commit 0cadc40

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

replit_river/codegen/client.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ class RiverNotType(BaseModel):
5151

5252

5353
class RiverProcedure(BaseModel):
54-
init: Optional[RiverType] = Field(default=None)
55-
input: RiverType
54+
init: RiverType
55+
input: Optional[RiverType] = Field(default=None)
5656
output: RiverType
5757
errors: Optional[RiverType] = Field(default=None)
5858
type: (
@@ -532,20 +532,20 @@ def __init__(self, client: river.Client[{handshake_type}]):
532532
),
533533
]
534534
for name, procedure in schema.procedures.items():
535-
init_type: Optional[str] = None
536-
if procedure.init:
537-
init_type, init_chunks = encode_type(
538-
procedure.init,
539-
f"{schema_name.title()}{name.title()}Init",
540-
base_model=input_base_class,
541-
)
542-
chunks.extend(init_chunks)
543-
input_type, input_chunks = encode_type(
544-
procedure.input,
545-
f"{schema_name.title()}{name.title()}Input",
535+
init_type, init_chunks = encode_type(
536+
procedure.init,
537+
f"{schema_name.title()}{name.title()}Init",
546538
base_model=input_base_class,
547539
)
548-
chunks.extend(input_chunks)
540+
chunks.extend(init_chunks)
541+
input_type: Optional[str] = None
542+
if procedure.input:
543+
input_type, input_chunks = encode_type(
544+
procedure.input,
545+
f"{schema_name.title()}{name.title()}Input",
546+
base_model=input_base_class,
547+
)
548+
chunks.extend(input_chunks)
549549
output_type, output_chunks = encode_type(
550550
procedure.output,
551551
f"{schema_name.title()}{name.title()}Output",
@@ -584,7 +584,7 @@ def __init__(self, client: river.Client[{handshake_type}]):
584584
""".rstrip()
585585

586586
# Init renderer
587-
if typed_dict_inputs and init_type:
587+
if typed_dict_inputs:
588588
if is_literal(procedure.init):
589589
render_init_method = "lambda x: x"
590590
elif isinstance(
@@ -606,7 +606,7 @@ def __init__(self, client: river.Client[{handshake_type}]):
606606
""".rstrip()
607607

608608
# Input renderer
609-
if typed_dict_inputs:
609+
if typed_dict_inputs and input_type and procedure.input:
610610
if is_literal(procedure.input):
611611
render_input_method = "lambda x: x"
612612
elif isinstance(

0 commit comments

Comments
 (0)