Skip to content

Commit 317db2a

Browse files
Reflowing v2 send_stream to have modern semantics
1 parent 7b4dd0e commit 317db2a

File tree

2 files changed

+19
-35
lines changed

2 files changed

+19
-35
lines changed

src/replit_river/v2/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,10 @@ async def send_stream[I, R, E, A](
163163
self,
164164
service_name: str,
165165
procedure_name: str,
166-
init: I | None,
167-
request: AsyncIterable[R],
168-
init_serializer: Callable[[I], Any] | None,
169-
request_serializer: Callable[[R], Any],
166+
init: I,
167+
request: AsyncIterable[R] | None,
168+
init_serializer: Callable[[I], Any],
169+
request_serializer: Callable[[R], Any] | None,
170170
response_deserializer: Callable[[Any], A],
171171
error_deserializer: Callable[[Any], E],
172172
) -> AsyncGenerator[A | E, None]:

src/replit_river/v2/client_session.py

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -383,10 +383,10 @@ async def send_stream[I, R, E, A](
383383
self,
384384
service_name: str,
385385
procedure_name: str,
386-
init: I | None,
387-
request: AsyncIterable[R],
388-
init_serializer: Callable[[I], Any] | None,
389-
request_serializer: Callable[[R], Any],
386+
init: I,
387+
request: AsyncIterable[R] | None,
388+
init_serializer: Callable[[I], Any],
389+
request_serializer: Callable[[R], Any] | None,
390390
response_deserializer: Callable[[Any], A],
391391
error_deserializer: Callable[[Any], E],
392392
span: Span,
@@ -399,41 +399,23 @@ async def send_stream[I, R, E, A](
399399
stream_id = nanoid.generate()
400400
output: Channel[Any] = Channel(MAX_MESSAGE_BUFFER_SIZE)
401401
self._streams[stream_id] = output
402-
empty_stream = False
403402
try:
404-
if init and init_serializer:
405-
await self.send_message(
406-
service_name=service_name,
407-
procedure_name=procedure_name,
408-
stream_id=stream_id,
409-
control_flags=STREAM_OPEN_BIT,
410-
payload=init_serializer(init),
411-
span=span,
412-
)
413-
else:
414-
# Get the very first message to open the stream
415-
request_iter = aiter(request)
416-
first = await anext(request_iter)
417-
await self.send_message(
418-
service_name=service_name,
419-
procedure_name=procedure_name,
420-
stream_id=stream_id,
421-
control_flags=STREAM_OPEN_BIT,
422-
payload=request_serializer(first),
423-
span=span,
424-
)
425-
426-
except StopAsyncIteration:
427-
empty_stream = True
428-
403+
await self.send_message(
404+
service_name=service_name,
405+
procedure_name=procedure_name,
406+
stream_id=stream_id,
407+
control_flags=STREAM_OPEN_BIT,
408+
payload=init_serializer(init),
409+
span=span,
410+
)
429411
except Exception as e:
430412
raise StreamClosedRiverServiceException(
431413
ERROR_CODE_STREAM_CLOSED, str(e), service_name, procedure_name
432414
) from e
433415

434416
# Create the encoder task
435417
async def _encode_stream() -> None:
436-
if empty_stream:
418+
if not request:
437419
await self.send_close_stream(
438420
service_name,
439421
procedure_name,
@@ -442,6 +424,8 @@ async def _encode_stream() -> None:
442424
)
443425
return
444426

427+
assert request_serializer, "send_stream missing request_serializer"
428+
445429
async for item in request:
446430
if item is None:
447431
continue

0 commit comments

Comments
 (0)