Skip to content

Commit 897d2e0

Browse files
Moving type ascriptions down to the method level
1 parent b7fe178 commit 897d2e0

File tree

2 files changed

+57
-66
lines changed

2 files changed

+57
-66
lines changed

src/replit_river/v2/client.py

Lines changed: 28 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,6 @@
1414

1515
from replit_river.v2.client_transport import ClientTransport
1616
from replit_river.error_schema import ERROR_CODE_UNKNOWN, RiverError, RiverException
17-
from replit_river.rpc import (
18-
ErrorType,
19-
InitType,
20-
RequestType,
21-
ResponseType,
22-
)
2317
from replit_river.transport_options import (
2418
HandshakeMetadataType,
2519
TransportOptions,
@@ -91,16 +85,16 @@ async def close(self) -> None:
9185
async def ensure_connected(self) -> None:
9286
await self._transport.get_or_create_session()
9387

94-
async def send_rpc(
88+
async def send_rpc[R, A](
9589
self,
9690
service_name: str,
9791
procedure_name: str,
98-
request: RequestType,
99-
request_serializer: Callable[[RequestType], Any],
100-
response_deserializer: Callable[[Any], ResponseType],
101-
error_deserializer: Callable[[Any], ErrorType],
92+
request: R,
93+
request_serializer: Callable[[R], Any],
94+
response_deserializer: Callable[[Any], A],
95+
error_deserializer: Callable[[Any], RiverError],
10296
timeout: timedelta,
103-
) -> ResponseType:
97+
) -> A:
10498
with _trace_procedure("rpc", service_name, procedure_name) as span_handle:
10599
session = await self._transport.get_or_create_session()
106100
return await session.send_rpc(
@@ -114,17 +108,17 @@ async def send_rpc(
114108
timeout,
115109
)
116110

117-
async def send_upload(
111+
async def send_upload[I, R, A](
118112
self,
119113
service_name: str,
120114
procedure_name: str,
121-
init: InitType | None,
122-
request: AsyncIterable[RequestType],
123-
init_serializer: Callable[[InitType], Any] | None,
124-
request_serializer: Callable[[RequestType], Any],
125-
response_deserializer: Callable[[Any], ResponseType],
126-
error_deserializer: Callable[[Any], ErrorType],
127-
) -> ResponseType:
115+
init: I | None,
116+
request: AsyncIterable[R],
117+
init_serializer: Callable[[I], Any] | None,
118+
request_serializer: Callable[[R], Any],
119+
response_deserializer: Callable[[Any], A],
120+
error_deserializer: Callable[[Any], RiverError],
121+
) -> A:
128122
with _trace_procedure("upload", service_name, procedure_name) as span_handle:
129123
session = await self._transport.get_or_create_session()
130124
return await session.send_upload(
@@ -139,15 +133,15 @@ async def send_upload(
139133
span_handle.span,
140134
)
141135

142-
async def send_subscription(
136+
async def send_subscription[R, E, A](
143137
self,
144138
service_name: str,
145139
procedure_name: str,
146-
request: RequestType,
147-
request_serializer: Callable[[RequestType], Any],
148-
response_deserializer: Callable[[Any], ResponseType],
149-
error_deserializer: Callable[[Any], ErrorType],
150-
) -> AsyncGenerator[ResponseType | RiverError, None]:
140+
request: R,
141+
request_serializer: Callable[[R], Any],
142+
response_deserializer: Callable[[Any], A],
143+
error_deserializer: Callable[[Any], E],
144+
) -> AsyncGenerator[A | E, None]:
151145
with _trace_procedure(
152146
"subscription", service_name, procedure_name
153147
) as span_handle:
@@ -165,17 +159,17 @@ async def send_subscription(
165159
_record_river_error(span_handle, msg)
166160
yield msg # type: ignore # https://github.com/python/mypy/issues/10817
167161

168-
async def send_stream(
162+
async def send_stream[I, R, E, A](
169163
self,
170164
service_name: str,
171165
procedure_name: str,
172-
init: InitType | None,
173-
request: AsyncIterable[RequestType],
174-
init_serializer: Callable[[InitType], Any] | None,
175-
request_serializer: Callable[[RequestType], Any],
176-
response_deserializer: Callable[[Any], ResponseType],
177-
error_deserializer: Callable[[Any], ErrorType],
178-
) -> AsyncGenerator[ResponseType | RiverError, None]:
166+
init: I | None,
167+
request: AsyncIterable[R],
168+
init_serializer: Callable[[I], Any] | None,
169+
request_serializer: Callable[[R], Any],
170+
response_deserializer: Callable[[Any], A],
171+
error_deserializer: Callable[[Any], E],
172+
) -> AsyncGenerator[A | E, None]:
179173
with _trace_procedure("stream", service_name, procedure_name) as span_handle:
180174
session = await self._transport.get_or_create_session()
181175
async for msg in session.send_stream(

src/replit_river/v2/client_session.py

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from replit_river.error_schema import (
1616
ERROR_CODE_CANCEL,
1717
ERROR_CODE_STREAM_CLOSED,
18+
RiverError,
1819
RiverException,
1920
RiverServiceException,
2021
StreamClosedRiverServiceException,
@@ -28,10 +29,6 @@
2829
ACK_BIT,
2930
STREAM_CLOSED_BIT,
3031
STREAM_OPEN_BIT,
31-
ErrorType,
32-
InitType,
33-
RequestType,
34-
ResponseType,
3532
)
3633
from replit_river.seq_manager import (
3734
IgnoreMessageException,
@@ -168,17 +165,17 @@ async def _handle_messages_from_ws(self) -> None:
168165
except ConnectionClosed as e:
169166
raise e
170167

171-
async def send_rpc(
168+
async def send_rpc[R, A](
172169
self,
173170
service_name: str,
174171
procedure_name: str,
175-
request: RequestType,
176-
request_serializer: Callable[[RequestType], Any],
177-
response_deserializer: Callable[[Any], ResponseType],
178-
error_deserializer: Callable[[Any], ErrorType],
172+
request: R,
173+
request_serializer: Callable[[R], Any],
174+
response_deserializer: Callable[[Any], A],
175+
error_deserializer: Callable[[Any], RiverError],
179176
span: Span,
180177
timeout: timedelta,
181-
) -> ResponseType:
178+
) -> A:
182179
"""Sends a single RPC request to the server.
183180
184181
Expects the input and output be messages that will be msgpacked.
@@ -233,18 +230,18 @@ async def send_rpc(
233230
except Exception as e:
234231
raise e
235232

236-
async def send_upload(
233+
async def send_upload[I, R, A](
237234
self,
238235
service_name: str,
239236
procedure_name: str,
240-
init: InitType | None,
241-
request: AsyncIterable[RequestType],
242-
init_serializer: Callable[[InitType], Any] | None,
243-
request_serializer: Callable[[RequestType], Any],
244-
response_deserializer: Callable[[Any], ResponseType],
245-
error_deserializer: Callable[[Any], ErrorType],
237+
init: I | None,
238+
request: AsyncIterable[R],
239+
init_serializer: Callable[[I], Any] | None,
240+
request_serializer: Callable[[R], Any],
241+
response_deserializer: Callable[[Any], A],
242+
error_deserializer: Callable[[Any], RiverError],
246243
span: Span,
247-
) -> ResponseType:
244+
) -> A:
248245
"""Sends an upload request to the server.
249246
250247
Expects the input and output be messages that will be msgpacked.
@@ -320,16 +317,16 @@ async def send_upload(
320317
except Exception as e:
321318
raise e
322319

323-
async def send_subscription(
320+
async def send_subscription[R, E, A](
324321
self,
325322
service_name: str,
326323
procedure_name: str,
327-
request: RequestType,
328-
request_serializer: Callable[[RequestType], Any],
329-
response_deserializer: Callable[[Any], ResponseType],
330-
error_deserializer: Callable[[Any], ErrorType],
324+
request: R,
325+
request_serializer: Callable[[R], Any],
326+
response_deserializer: Callable[[Any], A],
327+
error_deserializer: Callable[[Any], E],
331328
span: Span,
332-
) -> AsyncGenerator[ResponseType | ErrorType, None]:
329+
) -> AsyncGenerator[A | E, None]:
333330
"""Sends a subscription request to the server.
334331
335332
Expects the input and output be messages that will be msgpacked.
@@ -372,18 +369,18 @@ async def send_subscription(
372369
finally:
373370
output.close()
374371

375-
async def send_stream(
372+
async def send_stream[I, R, E, A](
376373
self,
377374
service_name: str,
378375
procedure_name: str,
379-
init: InitType | None,
380-
request: AsyncIterable[RequestType],
381-
init_serializer: Callable[[InitType], Any] | None,
382-
request_serializer: Callable[[RequestType], Any],
383-
response_deserializer: Callable[[Any], ResponseType],
384-
error_deserializer: Callable[[Any], ErrorType],
376+
init: I | None,
377+
request: AsyncIterable[R],
378+
init_serializer: Callable[[I], Any] | None,
379+
request_serializer: Callable[[R], Any],
380+
response_deserializer: Callable[[Any], A],
381+
error_deserializer: Callable[[Any], E],
385382
span: Span,
386-
) -> AsyncGenerator[ResponseType | ErrorType, None]:
383+
) -> AsyncGenerator[A | E, None]:
387384
"""Sends a subscription request to the server.
388385
389386
Expects the input and output be messages that will be msgpacked.

0 commit comments

Comments
 (0)