2020import grpc
2121import grpc .aio
2222import pytest
23+ from typing_extensions import override
2324
2425import pytest_userver .client
26+ import pytest_userver .grpc
2527from . import _hookspec
2628
2729DEFAULT_TIMEOUT = 15.0
@@ -194,41 +196,37 @@ def pytest_addhooks(pluginmanager: pytest.PytestPluginManager):
194196 pluginmanager .add_hookspecs (_hookspec )
195197
196198
197- class _UpdateServerStateInterceptor (
199+ class _AsyncExcCheckInterceptor (
198200 grpc .aio .UnaryUnaryClientInterceptor ,
199201 grpc .aio .UnaryStreamClientInterceptor ,
200202 grpc .aio .StreamUnaryClientInterceptor ,
201203 grpc .aio .StreamStreamClientInterceptor ,
202204):
203- def __init__ (self , service_client : pytest_userver .client .Client , asyncexc_check : _AsyncExcCheck ):
204- self ._service_client = service_client
205+ def __init__ (self , asyncexc_check : _AsyncExcCheck ):
205206 self ._asyncexc_check = asyncexc_check
206207
207- async def _before_call_hook (self ) -> None :
208- self ._asyncexc_check ()
209- if hasattr (self ._service_client , 'update_server_state' ):
210- await self ._service_client .update_server_state ()
211-
208+ @override
212209 async def intercept_unary_unary (
213210 self ,
214211 continuation : Callable [[grpc .aio .ClientCallDetails , Any ], Awaitable [grpc .aio .UnaryUnaryCall ]],
215212 client_call_details : grpc .aio .ClientCallDetails ,
216213 request : Any ,
217214 ) -> grpc .aio .UnaryUnaryCall :
218- await self ._before_call_hook ()
215+ self ._asyncexc_check ()
219216 try :
220217 return await continuation (client_call_details , request )
221218 finally :
222219 self ._asyncexc_check ()
223220
224221 # Note: full type of this function is Callable[[...], Awaitable[AsyncIterator[Any]]]
222+ @override
225223 async def intercept_unary_stream (
226224 self ,
227225 continuation : Callable [[grpc .aio .ClientCallDetails , Any ], grpc .aio .UnaryStreamCall ],
228226 client_call_details : grpc .aio .ClientCallDetails ,
229227 request : Any ,
230228 ) -> AsyncIterator [Any ]:
231- await self ._before_call_hook ()
229+ self ._asyncexc_check ()
232230 call = await continuation (client_call_details , request )
233231
234232 async def response_stream () -> AsyncIterator [Any ]:
@@ -240,27 +238,28 @@ async def response_stream() -> AsyncIterator[Any]:
240238
241239 return response_stream ()
242240
241+ @override
243242 async def intercept_stream_unary (
244243 self ,
245244 continuation : Callable [[grpc .aio .ClientCallDetails , AsyncIterator [Any ]], Awaitable [grpc .aio .StreamUnaryCall ]],
246245 client_call_details : grpc .aio .ClientCallDetails ,
247246 request_iterator : AsyncIterator [Any ],
248247 ) -> grpc .aio .StreamUnaryCall :
249- await self ._before_call_hook ()
248+ self ._asyncexc_check ()
250249 try :
251250 return await continuation (client_call_details , request_iterator )
252251 finally :
253252 self ._asyncexc_check ()
254253
255254 # Note: full type of this function is Callable[[...], Awaitable[AsyncIterator[Any]]]
255+ @override
256256 async def intercept_stream_stream (
257257 self ,
258258 continuation : Callable [[grpc .aio .ClientCallDetails , AsyncIterator [Any ]], grpc .aio .StreamStreamCall ],
259259 client_call_details : grpc .aio .ClientCallDetails ,
260260 request_iterator : AsyncIterator [Any ],
261261 ) -> AsyncIterator [Any ]:
262262 self ._asyncexc_check ()
263- await self ._before_call_hook ()
264263 call = await continuation (client_call_details , request_iterator )
265264
266265 async def response_stream () -> AsyncIterator [Any ]:
@@ -273,12 +272,20 @@ async def response_stream() -> AsyncIterator[Any]:
273272 return response_stream ()
274273
275274
275+ class _UpdateServerStateInterceptor (pytest_userver .grpc .PreCallClientInterceptor ):
276+ def __init__ (self , service_client : pytest_userver .client .Client ):
277+ self ._service_client = service_client
278+
279+ @override
280+ async def pre_call_hook (self , client_call_details : grpc .aio .ClientCallDetails ) -> None :
281+ if hasattr (self ._service_client , 'update_server_state' ):
282+ await self ._service_client .update_server_state ()
283+
284+
276285def pytest_grpc_client_interceptors (request : pytest .FixtureRequest ) -> Sequence [grpc .aio .ClientInterceptor ]:
277286 return [
278- _UpdateServerStateInterceptor (
279- request .getfixturevalue ('service_client' ),
280- request .getfixturevalue ('asyncexc_check' ),
281- ),
287+ _AsyncExcCheckInterceptor (request .getfixturevalue ('asyncexc_check' )),
288+ _UpdateServerStateInterceptor (request .getfixturevalue ('service_client' )),
282289 ]
283290
284291
0 commit comments