1+ import functools
12import logging
23from collections .abc import AsyncIterable , AsyncIterator , Awaitable , Callable
34from typing import Any , Generic , Optional , Union
45
6+ from replit_river .client_interceptor import (
7+ ClientRpcDetails ,
8+ ClientInterceptor ,
9+ ClientStreamDetails ,
10+ ClientSubscriptionDetails ,
11+ ClientUploadDetails ,
12+ )
513from replit_river .client_transport import ClientTransport
614from replit_river .transport_options import (
715 HandshakeMetadataType ,
@@ -28,6 +36,7 @@ def __init__(
2836 client_id : str ,
2937 server_id : str ,
3038 transport_options : TransportOptions ,
39+ interceptors : list [ClientInterceptor ] = [],
3140 ) -> None :
3241 self ._client_id = client_id
3342 self ._server_id = server_id
@@ -37,6 +46,7 @@ def __init__(
3746 server_id = server_id ,
3847 transport_options = transport_options ,
3948 )
49+ self ._interceptors = interceptors
4050
4151 async def close (self ) -> None :
4252 logger .info (f"river client { self ._client_id } start closing" )
@@ -56,13 +66,32 @@ async def send_rpc(
5666 error_deserializer : Callable [[Any ], ErrorType ],
5767 ) -> ResponseType :
5868 session = await self ._transport .get_or_create_session ()
59- return await session .send_rpc (
60- service_name ,
61- procedure_name ,
62- request ,
63- request_serializer ,
64- response_deserializer ,
65- error_deserializer ,
69+
70+ async def _run_interceptor (
71+ details : ClientRpcDetails ,
72+ interceptors : list [ClientInterceptor ],
73+ ) -> ResponseType :
74+ if interceptors :
75+ return await interceptors [0 ].intercept_rpc (
76+ details , lambda details : _run_interceptor (details , interceptors [1 :])
77+ )
78+ else :
79+ return await session .send_rpc (
80+ details .service_name ,
81+ details .procedure_name ,
82+ details .request ,
83+ request_serializer ,
84+ response_deserializer ,
85+ error_deserializer ,
86+ )
87+
88+ return await _run_interceptor (
89+ ClientRpcDetails (
90+ service_name = service_name ,
91+ procedure_name = procedure_name ,
92+ request = request ,
93+ ),
94+ self ._interceptors ,
6695 )
6796
6897 async def send_upload (
@@ -77,15 +106,34 @@ async def send_upload(
77106 error_deserializer : Callable [[Any ], ErrorType ],
78107 ) -> ResponseType :
79108 session = await self ._transport .get_or_create_session ()
80- return await session .send_upload (
81- service_name ,
82- procedure_name ,
83- init ,
84- request ,
85- init_serializer ,
86- request_serializer ,
87- response_deserializer ,
88- error_deserializer ,
109+
110+ async def _run_interceptor (
111+ details : ClientUploadDetails ,
112+ interceptors : list [ClientInterceptor ],
113+ ) -> ResponseType :
114+ if interceptors :
115+ return await interceptors [0 ].intercept_upload (
116+ details , lambda details : _run_interceptor (details , interceptors [1 :])
117+ )
118+ else :
119+ return await session .send_upload (
120+ service_name ,
121+ procedure_name ,
122+ init ,
123+ request ,
124+ init_serializer ,
125+ request_serializer ,
126+ response_deserializer ,
127+ error_deserializer ,
128+ )
129+
130+ return await _run_interceptor (
131+ ClientUploadDetails (
132+ service_name = service_name ,
133+ procedure_name = procedure_name ,
134+ init = init ,
135+ ),
136+ self ._interceptors ,
89137 )
90138
91139 async def send_subscription (
@@ -98,13 +146,32 @@ async def send_subscription(
98146 error_deserializer : Callable [[Any ], ErrorType ],
99147 ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
100148 session = await self ._transport .get_or_create_session ()
101- return session .send_subscription (
102- service_name ,
103- procedure_name ,
104- request ,
105- request_serializer ,
106- response_deserializer ,
107- error_deserializer ,
149+
150+ async def _run_interceptor (
151+ details : ClientSubscriptionDetails ,
152+ interceptors : list [ClientInterceptor ],
153+ ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
154+ if interceptors :
155+ return await interceptors [0 ].intercept_subscription (
156+ details , lambda details : _run_interceptor (details , interceptors [1 :])
157+ )
158+ else :
159+ return session .send_subscription (
160+ service_name ,
161+ procedure_name ,
162+ request ,
163+ request_serializer ,
164+ response_deserializer ,
165+ error_deserializer ,
166+ )
167+
168+ return await _run_interceptor (
169+ ClientSubscriptionDetails (
170+ service_name = service_name ,
171+ procedure_name = procedure_name ,
172+ request = request ,
173+ ),
174+ self ._interceptors ,
108175 )
109176
110177 async def send_stream (
@@ -119,13 +186,32 @@ async def send_stream(
119186 error_deserializer : Callable [[Any ], ErrorType ],
120187 ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
121188 session = await self ._transport .get_or_create_session ()
122- return session .send_stream (
123- service_name ,
124- procedure_name ,
125- init ,
126- request ,
127- init_serializer ,
128- request_serializer ,
129- response_deserializer ,
130- error_deserializer ,
189+
190+ async def _run_interceptor (
191+ details : ClientStreamDetails ,
192+ interceptors : list [ClientInterceptor ],
193+ ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
194+ if interceptors :
195+ return await interceptors [0 ].intercept_stream (
196+ details , lambda details : _run_interceptor (details , interceptors [1 :])
197+ )
198+ else :
199+ return session .send_stream (
200+ service_name ,
201+ procedure_name ,
202+ init ,
203+ request ,
204+ init_serializer ,
205+ request_serializer ,
206+ response_deserializer ,
207+ error_deserializer ,
208+ )
209+
210+ return await _run_interceptor (
211+ ClientStreamDetails (
212+ service_name = service_name ,
213+ procedure_name = procedure_name ,
214+ init = init ,
215+ ),
216+ self ._interceptors ,
131217 )
0 commit comments