66
77# pylint: disable=redefined-outer-name
88import asyncio
9- from collections .abc import AsyncIterable
9+ from collections .abc import AsyncIterator
1010from collections .abc import Awaitable
1111from collections .abc import Callable
12+ from collections .abc import Generator
13+ from collections .abc import Sequence
14+ import itertools
1215import pathlib
1316import tempfile
17+ from typing import Any
1418from typing import TypeAlias
1519
16- import google .protobuf .message
1720import grpc
21+ import grpc .aio
1822import pytest
1923
24+ import pytest_userver .client
25+ from . import _hookspec
26+
2027DEFAULT_TIMEOUT = 15.0
2128USERVER_CONFIG_HOOKS = ['userver_config_grpc_endpoint' ]
2229
23- MessageOrStream : TypeAlias = google .protobuf .message .Message | AsyncIterable [google .protobuf .message .Message ]
2430_AsyncExcCheck : TypeAlias = Callable [[], None ]
2531
2632
@@ -72,40 +78,9 @@ def grpc_service_timeout(pytestconfig) -> float:
7278 return float (pytestconfig .option .service_timeout ) or DEFAULT_TIMEOUT
7379
7480
75- @pytest .fixture
76- def grpc_client_prepare (
77- service_client ,
78- asyncexc_check ,
79- ) -> Callable [[grpc .aio .ClientCallDetails , MessageOrStream ], Awaitable [None ]]:
80- """
81- Returns the function that will be called in before each gRPC request,
82- client-side.
83-
84- @ingroup userver_testsuite_fixtures
85- """
86-
87- async def prepare (
88- _client_call_details : grpc .aio .ClientCallDetails ,
89- _request_or_stream : MessageOrStream ,
90- / ,
91- ) -> None :
92- asyncexc_check ()
93- if hasattr (service_client , 'update_server_state' ):
94- await service_client .update_server_state ()
95-
96- return prepare
97-
98-
9981@pytest .fixture (scope = 'session' )
100- async def grpc_session_channel (
101- grpc_service_endpoint ,
102- _grpc_channel_interceptor ,
103- _grpc_channel_interceptor_asyncexc ,
104- ):
105- async with grpc .aio .insecure_channel (
106- grpc_service_endpoint ,
107- interceptors = [_grpc_channel_interceptor , _grpc_channel_interceptor_asyncexc ],
108- ) as channel :
82+ async def grpc_session_channel (grpc_service_endpoint ):
83+ async with grpc .aio .insecure_channel (grpc_service_endpoint ) as channel :
10984 yield channel
11085
11186
@@ -115,19 +90,23 @@ async def grpc_channel(
11590 grpc_service_endpoint ,
11691 grpc_service_timeout ,
11792 grpc_session_channel ,
118- _grpc_channel_interceptor ,
119- grpc_client_prepare ,
120- _grpc_channel_interceptor_asyncexc ,
121- asyncexc_check ,
93+ request ,
12294):
12395 """
12496 Returns the gRPC channel configured by the parameters from the
12597 @ref pytest_userver.plugins.grpc.client.grpc_service_endpoint "grpc_service_endpoint" fixture.
12698
99+ You can add interceptors to the channel by implementing the
100+ @ref pytest_userver.plugins.grpc._hookspec.pytest_grpc_client_interceptors "pytest_grpc_client_interceptors"
101+ hook in your pytest plugin or initial (root) conftest.
102+
127103 @ingroup userver_testsuite_fixtures
128104 """
129- _grpc_channel_interceptor .prepare_func = grpc_client_prepare
130- _grpc_channel_interceptor_asyncexc .asyncexc_check = asyncexc_check
105+ interceptors = request .config .hook .pytest_grpc_client_interceptors (request = request )
106+ interceptors_list = list (itertools .chain .from_iterable (interceptors ))
107+ # Sanity check: we have at least one "builtin" interceptor.
108+ assert len (interceptors_list ) != 0
109+
131110 try :
132111 await asyncio .wait_for (
133112 grpc_session_channel .channel_ready (),
@@ -137,11 +116,16 @@ async def grpc_channel(
137116 raise RuntimeError (
138117 f'Failed to connect to remote gRPC server by address { grpc_service_endpoint } ' ,
139118 )
140- return grpc_session_channel
119+
120+ _set_client_interceptors (grpc_session_channel , interceptors_list )
121+ try :
122+ yield grpc_session_channel
123+ finally :
124+ _set_client_interceptors (grpc_session_channel , [])
141125
142126
143127@pytest .fixture (scope = 'session' )
144- def grpc_socket_path () -> pathlib .Path | None :
128+ def grpc_socket_path () -> Generator [ pathlib .Path , None , None ] :
145129 """
146130 Path for the UNIX socket over which testsuite will talk to the gRPC service, if it chooses to use a UNIX socket.
147131
@@ -203,56 +187,119 @@ def patch_config(config_yaml, config_vars):
203187 return patch_config
204188
205189
206- # Taken from
207- # https://github.com/grpc/grpc/blob/master/examples/python/interceptors/headers/generic_client_interceptor.py
208- class _GenericClientInterceptor (
190+ # @cond
191+
192+
193+ def pytest_addhooks (pluginmanager : pytest .PytestPluginManager ):
194+ pluginmanager .add_hookspecs (_hookspec )
195+
196+
197+ class _UpdateServerStateInterceptor (
209198 grpc .aio .UnaryUnaryClientInterceptor ,
210199 grpc .aio .UnaryStreamClientInterceptor ,
211200 grpc .aio .StreamUnaryClientInterceptor ,
212201 grpc .aio .StreamStreamClientInterceptor ,
213202):
214- def __init__ (self ):
215- self .prepare_func : Callable [[grpc .aio .ClientCallDetails , MessageOrStream ], Awaitable [None ]] | None = None
203+ def __init__ (self , service_client : pytest_userver .client .Client , asyncexc_check : _AsyncExcCheck ):
204+ self ._service_client = service_client
205+ self ._asyncexc_check = asyncexc_check
206+
207+ async def _before_call_hook (self ) -> None :
208+ self ._asyncexc_check ()
209+ if hasattr (self ._service_client , 'update_server_state' ):
210+ await self ._service_client .update_server_state ()
211+
212+ async def intercept_unary_unary (
213+ self ,
214+ continuation : Callable [[grpc .aio .ClientCallDetails , Any ], Awaitable [grpc .aio .UnaryUnaryCall ]],
215+ client_call_details : grpc .aio .ClientCallDetails ,
216+ request : Any ,
217+ ) -> grpc .aio .UnaryUnaryCall :
218+ await self ._before_call_hook ()
219+ try :
220+ return await continuation (client_call_details , request )
221+ finally :
222+ self ._asyncexc_check ()
223+
224+ # Note: full type of this function is Callable[[...], Awaitable[AsyncIterator[Any]]]
225+ async def intercept_unary_stream (
226+ self ,
227+ continuation : Callable [[grpc .aio .ClientCallDetails , Any ], grpc .aio .UnaryStreamCall ],
228+ client_call_details : grpc .aio .ClientCallDetails ,
229+ request : Any ,
230+ ) -> AsyncIterator [Any ]:
231+ await self ._before_call_hook ()
232+ call = await continuation (client_call_details , request )
233+
234+ async def response_stream () -> AsyncIterator [Any ]:
235+ try :
236+ async for response in call :
237+ yield response
238+ finally :
239+ self ._asyncexc_check ()
240+
241+ return response_stream ()
242+
243+ async def intercept_stream_unary (
244+ self ,
245+ continuation : Callable [[grpc .aio .ClientCallDetails , AsyncIterator [Any ]], Awaitable [grpc .aio .StreamUnaryCall ]],
246+ client_call_details : grpc .aio .ClientCallDetails ,
247+ request_iterator : AsyncIterator [Any ],
248+ ) -> grpc .aio .StreamUnaryCall :
249+ await self ._before_call_hook ()
250+ try :
251+ return await continuation (client_call_details , request_iterator )
252+ finally :
253+ self ._asyncexc_check ()
216254
217- async def intercept_unary_unary (self , continuation , client_call_details , request ):
218- await self .prepare_func (client_call_details , request )
219- return await continuation (client_call_details , request )
255+ # Note: full type of this function is Callable[[...], Awaitable[AsyncIterator[Any]]]
256+ async def intercept_stream_stream (
257+ self ,
258+ continuation : Callable [[grpc .aio .ClientCallDetails , AsyncIterator [Any ]], grpc .aio .StreamStreamCall ],
259+ client_call_details : grpc .aio .ClientCallDetails ,
260+ request_iterator : AsyncIterator [Any ],
261+ ) -> AsyncIterator [Any ]:
262+ self ._asyncexc_check ()
263+ await self ._before_call_hook ()
264+ call = await continuation (client_call_details , request_iterator )
220265
221- async def intercept_unary_stream (self , continuation , client_call_details , request ):
222- await self .prepare_func (client_call_details , request )
223- return continuation (client_call_details , next (request ))
266+ async def response_stream () -> AsyncIterator [Any ]:
267+ try :
268+ async for response in call :
269+ yield response
270+ finally :
271+ self ._asyncexc_check ()
224272
225- async def intercept_stream_unary (self , continuation , client_call_details , request_iterator ):
226- await self .prepare_func (client_call_details , request_iterator )
227- return await continuation (client_call_details , request_iterator )
273+ return response_stream ()
228274
229- async def intercept_stream_stream (self , continuation , client_call_details , request_iterator ):
230- await self .prepare_func (client_call_details , request_iterator )
231- return continuation (client_call_details , request_iterator )
232275
276+ def pytest_grpc_client_interceptors (request : pytest .FixtureRequest ) -> Sequence [grpc .aio .ClientInterceptor ]:
277+ return [
278+ _UpdateServerStateInterceptor (
279+ request .getfixturevalue ('service_client' ),
280+ request .getfixturevalue ('asyncexc_check' ),
281+ ),
282+ ]
233283
234- @pytest .fixture (scope = 'session' )
235- def _grpc_channel_interceptor (daemon_scoped_mark ) -> _GenericClientInterceptor :
236- return _GenericClientInterceptor ()
237284
285+ def _filter_interceptors (
286+ interceptors : Sequence [grpc .aio .ClientInterceptor ], desired_type : type [grpc .aio .ClientInterceptor ]
287+ ) -> list [grpc .aio .ClientInterceptor ]:
288+ return [interceptor for interceptor in interceptors if isinstance (interceptor , desired_type )]
238289
239- class _AsyncExcClientInterceptor (grpc .aio .UnaryUnaryClientInterceptor , grpc .aio .StreamUnaryClientInterceptor ):
240- def __init__ (self ):
241- self .asyncexc_check : _AsyncExcCheck | None = None
242290
243- async def intercept_unary_unary (self , continuation , client_call_details , request ):
244- try :
245- return await continuation (client_call_details , request )
246- finally :
247- self .asyncexc_check ()
291+ def _set_client_interceptors (channel : grpc .aio .Channel , interceptors : Sequence [grpc .aio .ClientInterceptor ]) -> None :
292+ """
293+ Allows to set interceptors dynamically while reusing the same underlying channel,
294+ which is something grpc-io currently doesn't support.
248295
249- async def intercept_stream_unary (self , continuation , client_call_details , request_iterator ):
250- try :
251- return await continuation (client_call_details , request_iterator )
252- finally :
253- self .asyncexc_check ()
296+ Also fixes the bug: multi-inheritance interceptors are only registered for first matching type
297+ https://github.com/grpc/grpc/issues/31442
298+ """
299+ channel ._unary_unary_interceptors = _filter_interceptors (interceptors , grpc .aio .UnaryUnaryClientInterceptor )
300+ channel ._unary_stream_interceptors = _filter_interceptors (interceptors , grpc .aio .UnaryStreamClientInterceptor )
301+ channel ._stream_unary_interceptors = _filter_interceptors (interceptors , grpc .aio .StreamUnaryClientInterceptor )
302+ channel ._stream_stream_interceptors = _filter_interceptors (interceptors , grpc .aio .StreamStreamClientInterceptor )
254303
255304
256- @pytest .fixture (scope = 'session' )
257- def _grpc_channel_interceptor_asyncexc () -> _AsyncExcClientInterceptor :
258- return _AsyncExcClientInterceptor ()
305+ # @endcond
0 commit comments