22from collections .abc import AsyncIterable , AsyncIterator , Awaitable , Callable
33from typing import Any , Generic , Optional , Union
44
5+ from replit_river .client_interceptor import (
6+ ClientInterceptor ,
7+ ClientRpcDetails ,
8+ ClientStreamDetails ,
9+ ClientSubscriptionDetails ,
10+ ClientUploadDetails ,
11+ )
512from replit_river .client_transport import ClientTransport
613from replit_river .transport_options import (
714 HandshakeMetadataType ,
@@ -28,6 +35,7 @@ def __init__(
2835 client_id : str ,
2936 server_id : str ,
3037 transport_options : TransportOptions ,
38+ interceptors : list [ClientInterceptor ] = [],
3139 ) -> None :
3240 self ._client_id = client_id
3341 self ._server_id = server_id
@@ -37,6 +45,7 @@ def __init__(
3745 server_id = server_id ,
3846 transport_options = transport_options ,
3947 )
48+ self ._interceptors = interceptors
4049
4150 async def close (self ) -> None :
4251 logger .info (f"river client { self ._client_id } start closing" )
@@ -56,13 +65,34 @@ async def send_rpc(
5665 error_deserializer : Callable [[Any ], ErrorType ],
5766 ) -> ResponseType :
5867 session = await self ._transport .get_or_create_session ()
59- return await session .send_rpc (
60- service_name ,
61- procedure_name ,
62- request ,
63- request_serializer ,
64- response_deserializer ,
65- error_deserializer ,
68+
69+ async def _run_interceptor (
70+ details : ClientRpcDetails ,
71+ interceptors : list [ClientInterceptor ],
72+ ) -> ResponseType :
73+ if interceptors :
74+ head , tail = interceptors [0 ], interceptors [1 :]
75+ return await head .intercept_rpc ( # type: ignore
76+ details ,
77+ lambda details : _run_interceptor (details , tail ),
78+ )
79+ else :
80+ return await session .send_rpc (
81+ details .service_name ,
82+ details .procedure_name ,
83+ details .request ,
84+ request_serializer ,
85+ response_deserializer ,
86+ error_deserializer ,
87+ )
88+
89+ return await _run_interceptor (
90+ ClientRpcDetails (
91+ service_name = service_name ,
92+ procedure_name = procedure_name ,
93+ request = request ,
94+ ),
95+ self ._interceptors ,
6696 )
6797
6898 async def send_upload (
@@ -77,15 +107,35 @@ async def send_upload(
77107 error_deserializer : Callable [[Any ], ErrorType ],
78108 ) -> ResponseType :
79109 session = await self ._transport .get_or_create_session ()
80- return await session .send_upload (
81- service_name ,
82- procedure_name ,
83- init ,
84- request ,
85- init_serializer ,
86- request_serializer ,
87- response_deserializer ,
88- error_deserializer ,
110+
111+ async def _run_interceptor (
112+ details : ClientUploadDetails ,
113+ interceptors : list [ClientInterceptor ],
114+ ) -> ResponseType :
115+ if interceptors :
116+ head , tail = interceptors [0 ], interceptors [1 :]
117+ return await head .intercept_upload ( # type: ignore
118+ details , lambda details : _run_interceptor (details , tail )
119+ )
120+ else :
121+ return await session .send_upload (
122+ service_name ,
123+ procedure_name ,
124+ init ,
125+ request ,
126+ init_serializer ,
127+ request_serializer ,
128+ response_deserializer ,
129+ error_deserializer ,
130+ )
131+
132+ return await _run_interceptor (
133+ ClientUploadDetails (
134+ service_name = service_name ,
135+ procedure_name = procedure_name ,
136+ init = init ,
137+ ),
138+ self ._interceptors ,
89139 )
90140
91141 async def send_subscription (
@@ -98,13 +148,33 @@ async def send_subscription(
98148 error_deserializer : Callable [[Any ], ErrorType ],
99149 ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
100150 session = await self ._transport .get_or_create_session ()
101- return session .send_subscription (
102- service_name ,
103- procedure_name ,
104- request ,
105- request_serializer ,
106- response_deserializer ,
107- error_deserializer ,
151+
152+ async def _run_interceptor (
153+ details : ClientSubscriptionDetails ,
154+ interceptors : list [ClientInterceptor ],
155+ ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
156+ if interceptors :
157+ head , tail = interceptors [0 ], interceptors [1 :]
158+ return await head .intercept_subscription ( # type: ignore
159+ details , lambda details : _run_interceptor (details , tail )
160+ )
161+ else :
162+ return session .send_subscription (
163+ service_name ,
164+ procedure_name ,
165+ request ,
166+ request_serializer ,
167+ response_deserializer ,
168+ error_deserializer ,
169+ )
170+
171+ return await _run_interceptor (
172+ ClientSubscriptionDetails (
173+ service_name = service_name ,
174+ procedure_name = procedure_name ,
175+ request = request ,
176+ ),
177+ self ._interceptors ,
108178 )
109179
110180 async def send_stream (
@@ -119,13 +189,33 @@ async def send_stream(
119189 error_deserializer : Callable [[Any ], ErrorType ],
120190 ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
121191 session = await self ._transport .get_or_create_session ()
122- return session .send_stream (
123- service_name ,
124- procedure_name ,
125- init ,
126- request ,
127- init_serializer ,
128- request_serializer ,
129- response_deserializer ,
130- error_deserializer ,
192+
193+ async def _run_interceptor (
194+ details : ClientStreamDetails ,
195+ interceptors : list [ClientInterceptor ],
196+ ) -> AsyncIterator [Union [ResponseType , ErrorType ]]:
197+ if interceptors :
198+ head , tail = interceptors [0 ], interceptors [1 :]
199+ return await head .intercept_stream ( # type: ignore
200+ details , lambda details : _run_interceptor (details , tail )
201+ )
202+ else :
203+ return session .send_stream (
204+ service_name ,
205+ procedure_name ,
206+ init ,
207+ request ,
208+ init_serializer ,
209+ request_serializer ,
210+ response_deserializer ,
211+ error_deserializer ,
212+ )
213+
214+ return await _run_interceptor (
215+ ClientStreamDetails (
216+ service_name = service_name ,
217+ procedure_name = procedure_name ,
218+ init = init ,
219+ ),
220+ self ._interceptors ,
131221 )
0 commit comments