22from collections .abc import AsyncIterable , AsyncIterator , Awaitable , Callable
33from typing import Any , Generic , Optional , Union
44
5+ from replit_river .client_interceptor import (
6+ ClientInterceptor ,
7+ ClientRpcDetails ,
8+ ClientStreamDetails ,
9+ ClientSubscriptionDetails ,
10+ ClientUploadDetails ,
11+ )
512from replit_river .client_transport import ClientTransport
613from replit_river .transport_options import (
714 HandshakeMetadataType ,
@@ -28,6 +35,7 @@ def __init__(
2835 client_id : str ,
2936 server_id : str ,
3037 transport_options : TransportOptions ,
38+ interceptors : list [ClientInterceptor ] = [],
3139 ) -> None :
3240 self ._client_id = client_id
3341 self ._server_id = server_id
@@ -37,6 +45,7 @@ def __init__(
3745 server_id = server_id ,
3846 transport_options = transport_options ,
3947 )
48+ self ._interceptors = interceptors
4049
4150 async def close (self ) -> None :
4251 logger .info (f"river client { self ._client_id } start closing" )
@@ -56,13 +65,32 @@ async def send_rpc(
5665 error_deserializer : Callable [[Any ], ErrorType ],
5766 ) -> ResponseType :
5867 session = await self ._transport .get_or_create_session ()
59- return await session .send_rpc (
60- service_name ,
61- procedure_name ,
62- request ,
63- request_serializer ,
64- response_deserializer ,
65- error_deserializer ,
68+
69+ async def _run_interceptor (
70+ details : ClientRpcDetails ,
71+ interceptors : list [ClientInterceptor ],
72+ ) -> ResponseType :
73+ if interceptors :
74+ return await interceptors [0 ].intercept_rpc (
75+ details , lambda details : _run_interceptor (details , interceptors [1 :])
76+ )
77+ else :
78+ return await session .send_rpc (
79+ details .service_name ,
80+ details .procedure_name ,
81+ details .request ,
82+ request_serializer ,
83+ response_deserializer ,
84+ error_deserializer ,
85+ )
86+
87+ return await _run_interceptor (
88+ ClientRpcDetails (
89+ service_name = service_name ,
90+ procedure_name = procedure_name ,
91+ request = request ,
92+ ),
93+ self ._interceptors ,
6694 )
6795
6896 async def send_upload (
@@ -77,15 +105,34 @@ async def send_upload(
77105 error_deserializer : Callable [[Any ], ErrorType ],
78106 ) -> ResponseType :
79107 session = await self ._transport .get_or_create_session ()
80- return await session .send_upload (
81- service_name ,
82- procedure_name ,
83- init ,
84- request ,
85- init_serializer ,
86- request_serializer ,
87- response_deserializer ,
88- error_deserializer ,
108+
109+ async def _run_interceptor (
110+ details : ClientUploadDetails ,
111+ interceptors : list [ClientInterceptor ],
112+ ) -> ResponseType :
113+ if interceptors :
114+ return await interceptors [0 ].intercept_upload (
115+ details , lambda details : _run_interceptor (details , interceptors [1 :])
116+ )
117+ else :
118+ return await session .send_upload (
119+ service_name ,
120+ procedure_name ,
121+ init ,
122+ request ,
123+ init_serializer ,
124+ request_serializer ,
125+ response_deserializer ,
126+ error_deserializer ,
127+ )
128+
129+ return await _run_interceptor (
130+ ClientUploadDetails (
131+ service_name = service_name ,
132+ procedure_name = procedure_name ,
133+ init = init ,
134+ ),
135+ self ._interceptors ,
89136 )
90137
91138 async def send_subscription (
@@ -98,13 +145,32 @@ async def send_subscription(
98145 error_deserializer : Callable [[Any ], ErrorType ],
99146 ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
100147 session = await self ._transport .get_or_create_session ()
101- return session .send_subscription (
102- service_name ,
103- procedure_name ,
104- request ,
105- request_serializer ,
106- response_deserializer ,
107- error_deserializer ,
148+
149+ async def _run_interceptor (
150+ details : ClientSubscriptionDetails ,
151+ interceptors : list [ClientInterceptor ],
152+ ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
153+ if interceptors :
154+ return await interceptors [0 ].intercept_subscription (
155+ details , lambda details : _run_interceptor (details , interceptors [1 :])
156+ )
157+ else :
158+ return session .send_subscription (
159+ service_name ,
160+ procedure_name ,
161+ request ,
162+ request_serializer ,
163+ response_deserializer ,
164+ error_deserializer ,
165+ )
166+
167+ return await _run_interceptor (
168+ ClientSubscriptionDetails (
169+ service_name = service_name ,
170+ procedure_name = procedure_name ,
171+ request = request ,
172+ ),
173+ self ._interceptors ,
108174 )
109175
110176 async def send_stream (
@@ -119,13 +185,32 @@ async def send_stream(
119185 error_deserializer : Callable [[Any ], ErrorType ],
120186 ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
121187 session = await self ._transport .get_or_create_session ()
122- return session .send_stream (
123- service_name ,
124- procedure_name ,
125- init ,
126- request ,
127- init_serializer ,
128- request_serializer ,
129- response_deserializer ,
130- error_deserializer ,
188+
189+ async def _run_interceptor (
190+ details : ClientStreamDetails ,
191+ interceptors : list [ClientInterceptor ],
192+ ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
193+ if interceptors :
194+ return await interceptors [0 ].intercept_stream (
195+ details , lambda details : _run_interceptor (details , interceptors [1 :])
196+ )
197+ else :
198+ return session .send_stream (
199+ service_name ,
200+ procedure_name ,
201+ init ,
202+ request ,
203+ init_serializer ,
204+ request_serializer ,
205+ response_deserializer ,
206+ error_deserializer ,
207+ )
208+
209+ return await _run_interceptor (
210+ ClientStreamDetails (
211+ service_name = service_name ,
212+ procedure_name = procedure_name ,
213+ init = init ,
214+ ),
215+ self ._interceptors ,
131216 )
0 commit comments