Skip to content

Commit da04bbb

Browse files
v2 send_stream codegen
1 parent 317db2a commit da04bbb

File tree

1 file changed

+33
-5
lines changed

1 file changed

+33
-5
lines changed

src/replit_river/codegen/client.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -978,10 +978,8 @@ async def {name}(
978978
]
979979
)
980980
elif procedure.type == "stream":
981-
assert input_meta
982981
assert output_meta
983982
assert error_meta
984-
_, input_type, render_input_method = input_meta
985983
_, output_type, parse_output_method = output_meta
986984
_, error_type, parse_error_method = error_meta
987985
error_type_name = extract_inner_type(error_type)
@@ -994,8 +992,9 @@ async def {name}(
994992
TypeName("RiverError"),
995993
]
996994
)
997-
if init_meta:
995+
if init_meta and input_meta:
998996
_, init_type, render_init_method = init_meta
997+
_, input_type, render_input_method = input_meta
999998
current_chunks.extend(
1000999
[
10011000
reindent(
@@ -1020,8 +1019,9 @@ async def {name}(
10201019
)
10211020
]
10221021
)
1023-
else:
1024-
assert protocol_version == "v1.1", "Protocol v2 requires init to be defined"
1022+
elif protocol_version == "v1.1":
1023+
assert input_meta, "Protocol v1 requires input to be defined"
1024+
_, input_type, render_input_method = input_meta
10251025
current_chunks.extend(
10261026
[
10271027
reindent(
@@ -1045,6 +1045,34 @@ async def {name}(
10451045
)
10461046
]
10471047
)
1048+
elif protocol_version == "v2.0":
1049+
assert init_meta, "Protocol v2 requires init to be defined"
1050+
_, init_type, render_init_method = init_meta
1051+
current_chunks.extend(
1052+
[
1053+
reindent(
1054+
" ",
1055+
f"""\
1056+
async def {name}(
1057+
self,
1058+
init: {render_type_expr(init_type)},
1059+
) -> AsyncIterator[{render_type_expr(output_or_error_type)}]:
1060+
return self.client.send_stream(
1061+
{repr(schema_name)},
1062+
{repr(name)},
1063+
init,
1064+
None,
1065+
{reindent(" ", render_init_method)},
1066+
None,
1067+
{reindent(" ", parse_output_method)},
1068+
{reindent(" ", parse_error_method)},
1069+
)
1070+
""",
1071+
)
1072+
]
1073+
)
1074+
else:
1075+
raise ValueError("Precondition failed")
10481076

10491077
current_chunks.append("")
10501078
return current_chunks

0 commit comments

Comments
 (0)