Skip to content

Commit ca487fe

Browse files
Respect required "init" vs required "input"
1 parent 974c59f commit ca487fe

File tree

3 files changed

+73
-100
lines changed

3 files changed

+73
-100
lines changed

replit_river/client.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,10 @@ async def send_upload(
6969
self,
7070
service_name: str,
7171
procedure_name: str,
72-
init: Optional[InitType],
73-
request: AsyncIterable[RequestType],
74-
init_serializer: Optional[Callable[[InitType], Any]],
75-
request_serializer: Callable[[RequestType], Any],
72+
init: InitType,
73+
request: Optional[AsyncIterable[RequestType]],
74+
init_serializer: Callable[[InitType], Any],
75+
request_serializer: Optional[Callable[[RequestType], Any]],
7676
response_deserializer: Callable[[Any], ResponseType],
7777
error_deserializer: Callable[[Any], ErrorType],
7878
) -> ResponseType:
@@ -111,10 +111,10 @@ async def send_stream(
111111
self,
112112
service_name: str,
113113
procedure_name: str,
114-
init: Optional[InitType],
115-
request: AsyncIterable[RequestType],
116-
init_serializer: Optional[Callable[[InitType], Any]],
117-
request_serializer: Callable[[RequestType], Any],
114+
init: InitType,
115+
request: Optional[AsyncIterable[RequestType]],
116+
init_serializer: Callable[[InitType], Any],
117+
request_serializer: Optional[Callable[[RequestType], Any]],
118118
response_deserializer: Callable[[Any], ResponseType],
119119
error_deserializer: Callable[[Any], ErrorType],
120120
) -> AsyncIterator[Union[ResponseType, ErrorType]]:

replit_river/client_session.py

Lines changed: 50 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ async def send_upload(
8383
self,
8484
service_name: str,
8585
procedure_name: str,
86-
init: Optional[InitType],
87-
request: AsyncIterable[RequestType],
88-
init_serializer: Optional[Callable[[InitType], Any]],
89-
request_serializer: Callable[[RequestType], Any],
86+
init: InitType,
87+
request: Optional[AsyncIterable[RequestType]],
88+
init_serializer: Callable[[InitType], Any],
89+
request_serializer: Optional[Callable[[RequestType], Any]],
9090
response_deserializer: Callable[[Any], ResponseType],
9191
error_deserializer: Callable[[Any], ErrorType],
9292
) -> ResponseType:
@@ -100,29 +100,26 @@ async def send_upload(
100100
self._streams[stream_id] = output
101101
first_message = True
102102
try:
103-
if init and init_serializer:
104-
await self.send_message(
105-
stream_id=stream_id,
106-
control_flags=STREAM_OPEN_BIT,
107-
service_name=service_name,
108-
procedure_name=procedure_name,
109-
payload=init_serializer(init),
110-
)
111-
first_message = False
112-
# If this request is not closed and the session is killed, we should
113-
# throw exception here
114-
async for item in request:
115-
control_flags = 0
116-
if first_message:
117-
control_flags = STREAM_OPEN_BIT
118-
first_message = False
119-
await self.send_message(
120-
stream_id=stream_id,
121-
service_name=service_name,
122-
procedure_name=procedure_name,
123-
control_flags=control_flags,
124-
payload=request_serializer(item),
125-
)
103+
await self.send_message(
104+
stream_id=stream_id,
105+
control_flags=STREAM_OPEN_BIT,
106+
service_name=service_name,
107+
procedure_name=procedure_name,
108+
payload=init_serializer(init),
109+
)
110+
first_message = False
111+
if request is not None and request_serializer is not None:
112+
# If this request is not closed and the session is killed, we should
113+
# throw exception here
114+
async for item in request:
115+
control_flags = 0
116+
await self.send_message(
117+
stream_id=stream_id,
118+
service_name=service_name,
119+
procedure_name=procedure_name,
120+
control_flags=control_flags,
121+
payload=request_serializer(item),
122+
)
126123
except Exception as e:
127124
raise RiverServiceException(
128125
ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name
@@ -215,10 +212,10 @@ async def send_stream(
215212
self,
216213
service_name: str,
217214
procedure_name: str,
218-
init: Optional[InitType],
219-
request: AsyncIterable[RequestType],
220-
init_serializer: Optional[Callable[[InitType], Any]],
221-
request_serializer: Callable[[RequestType], Any],
215+
init: InitType,
216+
request: Optional[AsyncIterable[RequestType]],
217+
init_serializer: Callable[[InitType], Any],
218+
request_serializer: Optional[Callable[[RequestType], Any]],
222219
response_deserializer: Callable[[Any], ResponseType],
223220
error_deserializer: Callable[[Any], ErrorType],
224221
) -> AsyncIterator[Union[ResponseType, ErrorType]]:
@@ -230,60 +227,36 @@ async def send_stream(
230227
stream_id = nanoid.generate()
231228
output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE)
232229
self._streams[stream_id] = output
233-
empty_stream = False
234230
try:
235-
if init and init_serializer:
236-
await self.send_message(
237-
service_name=service_name,
238-
procedure_name=procedure_name,
239-
stream_id=stream_id,
240-
control_flags=STREAM_OPEN_BIT,
241-
payload=init_serializer(init),
242-
)
243-
else:
244-
# Get the very first message to open the stream
245-
request_iter = aiter(request)
246-
first = await anext(request_iter)
247-
await self.send_message(
248-
service_name=service_name,
249-
procedure_name=procedure_name,
250-
stream_id=stream_id,
251-
control_flags=STREAM_OPEN_BIT,
252-
payload=request_serializer(first),
253-
)
254-
255-
except StopAsyncIteration:
256-
empty_stream = True
231+
await self.send_message(
232+
service_name=service_name,
233+
procedure_name=procedure_name,
234+
stream_id=stream_id,
235+
control_flags=STREAM_OPEN_BIT,
236+
payload=init_serializer(init),
237+
)
257238

258239
except Exception as e:
259240
raise StreamClosedRiverServiceException(
260241
ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name
261242
) from e
262243

263-
# Create the encoder task
264-
async def _encode_stream() -> None:
265-
if empty_stream:
266-
await self.send_close_stream(
267-
service_name,
268-
procedure_name,
269-
stream_id,
270-
extra_control_flags=STREAM_OPEN_BIT,
271-
)
272-
return
273-
274-
async for item in request:
275-
if item is None:
276-
continue
277-
await self.send_message(
278-
service_name=service_name,
279-
procedure_name=procedure_name,
280-
stream_id=stream_id,
281-
control_flags=0,
282-
payload=request_serializer(item),
283-
)
284-
await self.send_close_stream(service_name, procedure_name, stream_id)
244+
if request is not None and request_serializer is not None:
245+
# Create the encoder task
246+
async def _encode_stream() -> None:
247+
async for item in request:
248+
if item is None:
249+
continue
250+
await self.send_message(
251+
service_name=service_name,
252+
procedure_name=procedure_name,
253+
stream_id=stream_id,
254+
control_flags=0,
255+
payload=request_serializer(item),
256+
)
257+
await self.send_close_stream(service_name, procedure_name, stream_id)
285258

286-
self._task_manager.create_task(_encode_stream())
259+
self._task_manager.create_task(_encode_stream())
287260

288261
# Handle potential errors during communication
289262
try:

replit_river/codegen/client.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -646,13 +646,13 @@ def __init__(self, client: river.Client[{handshake_type}]):
646646
f"""\
647647
async def {name}(
648648
self,
649-
input: {input_type},
649+
init: {init_type},
650650
) -> {output_type}:
651651
return await self.client.send_rpc(
652652
{repr(schema_name)},
653653
{repr(name)},
654-
input,
655-
{render_input_method},
654+
init,
655+
{render_init_method},
656656
{parse_output_method},
657657
{parse_error_method},
658658
)
@@ -668,13 +668,13 @@ async def {name}(
668668
f"""\
669669
async def {name}(
670670
self,
671-
input: {input_type},
671+
init: {init_type},
672672
) -> AsyncIterator[{output_or_error_type}]:
673673
return await self.client.send_subscription(
674674
{repr(schema_name)},
675675
{repr(name)},
676-
input,
677-
{render_input_method},
676+
init,
677+
{render_init_method},
678678
{parse_output_method},
679679
{parse_error_method},
680680
)
@@ -683,7 +683,7 @@ async def {name}(
683683
]
684684
)
685685
elif procedure.type == "upload":
686-
if init_type:
686+
if input_type is not None:
687687
current_chunks.extend(
688688
[
689689
reindent(
@@ -716,15 +716,15 @@ async def {name}(
716716
f"""\
717717
async def {name}(
718718
self,
719-
inputStream: AsyncIterable[{input_type}],
720-
) -> {output_or_error_type}:
719+
init: {init_type},
720+
) -> {output_type}:
721721
return await self.client.send_upload(
722722
{repr(schema_name)},
723723
{repr(name)},
724+
init,
724725
None,
725-
inputStream,
726+
{render_init_method},
726727
None,
727-
{render_input_method},
728728
{parse_output_method},
729729
{parse_error_method},
730730
)
@@ -733,7 +733,7 @@ async def {name}(
733733
]
734734
)
735735
elif procedure.type == "stream":
736-
if init_type:
736+
if input_type is not None:
737737
current_chunks.extend(
738738
[
739739
reindent(
@@ -766,15 +766,15 @@ async def {name}(
766766
f"""\
767767
async def {name}(
768768
self,
769-
inputStream: AsyncIterable[{input_type}],
769+
init: {init_type},
770770
) -> AsyncIterator[{output_or_error_type}]:
771771
return await self.client.send_stream(
772772
{repr(schema_name)},
773773
{repr(name)},
774+
init,
774775
None,
775-
inputStream,
776+
{render_init_method},
776777
None,
777-
{render_input_method},
778778
{parse_output_method},
779779
{parse_error_method},
780780
)

0 commit comments

Comments
 (0)