Skip to content

Commit 89d8f7b

Browse files
committed
fix testsuite: auto-call update_server_state before streaming grpc requests
commit_hash:3f052a98a4988d1864ef796e0129ea61c5080f8f
1 parent a3ee428 commit 89d8f7b

File tree

8 files changed

+152
-85
lines changed

8 files changed

+152
-85
lines changed

.mapping.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4755,6 +4755,7 @@
47554755
"testsuite/pytest_plugins/pytest_userver/plugins/dumps.py":"taxi/uservices/userver/testsuite/pytest_plugins/pytest_userver/plugins/dumps.py",
47564756
"testsuite/pytest_plugins/pytest_userver/plugins/dynamic_config.py":"taxi/uservices/userver/testsuite/pytest_plugins/pytest_userver/plugins/dynamic_config.py",
47574757
"testsuite/pytest_plugins/pytest_userver/plugins/grpc/__init__.py":"taxi/uservices/userver/testsuite/pytest_plugins/pytest_userver/plugins/grpc/__init__.py",
4758+
"testsuite/pytest_plugins/pytest_userver/plugins/grpc/_hookspec.py":"taxi/uservices/userver/testsuite/pytest_plugins/pytest_userver/plugins/grpc/_hookspec.py",
47584759
"testsuite/pytest_plugins/pytest_userver/plugins/grpc/client.py":"taxi/uservices/userver/testsuite/pytest_plugins/pytest_userver/plugins/grpc/client.py",
47594760
"testsuite/pytest_plugins/pytest_userver/plugins/grpc/mockserver.py":"taxi/uservices/userver/testsuite/pytest_plugins/pytest_userver/plugins/grpc/mockserver.py",
47604761
"testsuite/pytest_plugins/pytest_userver/plugins/kafka.py":"taxi/uservices/userver/testsuite/pytest_plugins/pytest_userver/plugins/kafka.py",

testsuite/pytest_plugins/pytest_userver/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class PeriodicTaskFailed(BaseError):
8080

8181

8282
class PeriodicTasksState:
83-
def __init__(self):
83+
def __init__(self) -> None:
8484
self.suspended_tasks: set[str] = set()
8585
self.tasks_to_suspend: set[str] = set()
8686

testsuite/pytest_plugins/pytest_userver/plugins/caches.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def pytest_plugin_registered(self, plugin, manager):
4141

4242

4343
class InvalidationState:
44-
def __init__(self):
44+
def __init__(self) -> None:
4545
# None means that we should update all caches.
4646
# We invalidate all caches at the start of each test.
4747
self._invalidated_caches: set[str] | None = None

testsuite/pytest_plugins/pytest_userver/plugins/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,7 @@ def userver_config_http_client(
529529
@ingroup userver_testsuite_fixtures
530530
"""
531531

532-
def patch_config(config, config_vars):
532+
def patch_config(config, config_vars) -> None:
533533
components: dict = config['components_manager']['components']
534534
if not {'http-client-core', 'testsuite-support'}.issubset(
535535
components.keys(),
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""userver pytest hooks for testing gRPC clients and services."""
2+
3+
from collections.abc import Sequence
4+
5+
import grpc.aio
6+
import pytest
7+
8+
9+
@pytest.hookspec
10+
def pytest_grpc_client_interceptors(request: pytest.FixtureRequest) -> Sequence[grpc.aio.ClientInterceptor]:
11+
"""
12+
A pytest hook that returns gRPC client interceptors to use when making requests to the service.
13+
14+
Interceptors are accomulated over all implementations of this hook from all plugins.
15+
16+
@see @ref scripts/docs/en/userver/grpc/grpc.md
17+
@ingroup userver_testsuite_fixtures
18+
"""
19+
raise NotImplementedError()

testsuite/pytest_plugins/pytest_userver/plugins/grpc/client.py

Lines changed: 127 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,27 @@
66

77
# pylint: disable=redefined-outer-name
88
import asyncio
9-
from collections.abc import AsyncIterable
9+
from collections.abc import AsyncIterator
1010
from collections.abc import Awaitable
1111
from collections.abc import Callable
12+
from collections.abc import Generator
13+
from collections.abc import Sequence
14+
import itertools
1215
import pathlib
1316
import tempfile
17+
from typing import Any
1418
from typing import TypeAlias
1519

16-
import google.protobuf.message
1720
import grpc
21+
import grpc.aio
1822
import pytest
1923

24+
import pytest_userver.client
25+
from . import _hookspec
26+
2027
DEFAULT_TIMEOUT = 15.0
2128
USERVER_CONFIG_HOOKS = ['userver_config_grpc_endpoint']
2229

23-
MessageOrStream: TypeAlias = google.protobuf.message.Message | AsyncIterable[google.protobuf.message.Message]
2430
_AsyncExcCheck: TypeAlias = Callable[[], None]
2531

2632

@@ -72,40 +78,9 @@ def grpc_service_timeout(pytestconfig) -> float:
7278
return float(pytestconfig.option.service_timeout) or DEFAULT_TIMEOUT
7379

7480

75-
@pytest.fixture
76-
def grpc_client_prepare(
77-
service_client,
78-
asyncexc_check,
79-
) -> Callable[[grpc.aio.ClientCallDetails, MessageOrStream], Awaitable[None]]:
80-
"""
81-
Returns the function that will be called in before each gRPC request,
82-
client-side.
83-
84-
@ingroup userver_testsuite_fixtures
85-
"""
86-
87-
async def prepare(
88-
_client_call_details: grpc.aio.ClientCallDetails,
89-
_request_or_stream: MessageOrStream,
90-
/,
91-
) -> None:
92-
asyncexc_check()
93-
if hasattr(service_client, 'update_server_state'):
94-
await service_client.update_server_state()
95-
96-
return prepare
97-
98-
9981
@pytest.fixture(scope='session')
100-
async def grpc_session_channel(
101-
grpc_service_endpoint,
102-
_grpc_channel_interceptor,
103-
_grpc_channel_interceptor_asyncexc,
104-
):
105-
async with grpc.aio.insecure_channel(
106-
grpc_service_endpoint,
107-
interceptors=[_grpc_channel_interceptor, _grpc_channel_interceptor_asyncexc],
108-
) as channel:
82+
async def grpc_session_channel(grpc_service_endpoint):
83+
async with grpc.aio.insecure_channel(grpc_service_endpoint) as channel:
10984
yield channel
11085

11186

@@ -115,19 +90,23 @@ async def grpc_channel(
11590
grpc_service_endpoint,
11691
grpc_service_timeout,
11792
grpc_session_channel,
118-
_grpc_channel_interceptor,
119-
grpc_client_prepare,
120-
_grpc_channel_interceptor_asyncexc,
121-
asyncexc_check,
93+
request,
12294
):
12395
"""
12496
Returns the gRPC channel configured by the parameters from the
12597
@ref pytest_userver.plugins.grpc.client.grpc_service_endpoint "grpc_service_endpoint" fixture.
12698
99+
You can add interceptors to the channel by implementing the
100+
@ref pytest_userver.plugins.grpc._hookspec.pytest_grpc_client_interceptors "pytest_grpc_client_interceptors"
101+
hook in your pytest plugin or initial (root) conftest.
102+
127103
@ingroup userver_testsuite_fixtures
128104
"""
129-
_grpc_channel_interceptor.prepare_func = grpc_client_prepare
130-
_grpc_channel_interceptor_asyncexc.asyncexc_check = asyncexc_check
105+
interceptors = request.config.hook.pytest_grpc_client_interceptors(request=request)
106+
interceptors_list = list(itertools.chain.from_iterable(interceptors))
107+
# Sanity check: we have at least one "builtin" interceptor.
108+
assert len(interceptors_list) != 0
109+
131110
try:
132111
await asyncio.wait_for(
133112
grpc_session_channel.channel_ready(),
@@ -137,11 +116,16 @@ async def grpc_channel(
137116
raise RuntimeError(
138117
f'Failed to connect to remote gRPC server by address {grpc_service_endpoint}',
139118
)
140-
return grpc_session_channel
119+
120+
_set_client_interceptors(grpc_session_channel, interceptors_list)
121+
try:
122+
yield grpc_session_channel
123+
finally:
124+
_set_client_interceptors(grpc_session_channel, [])
141125

142126

143127
@pytest.fixture(scope='session')
144-
def grpc_socket_path() -> pathlib.Path | None:
128+
def grpc_socket_path() -> Generator[pathlib.Path, None, None]:
145129
"""
146130
Path for the UNIX socket over which testsuite will talk to the gRPC service, if it chooses to use a UNIX socket.
147131
@@ -203,56 +187,119 @@ def patch_config(config_yaml, config_vars):
203187
return patch_config
204188

205189

206-
# Taken from
207-
# https://github.com/grpc/grpc/blob/master/examples/python/interceptors/headers/generic_client_interceptor.py
208-
class _GenericClientInterceptor(
190+
# @cond
191+
192+
193+
def pytest_addhooks(pluginmanager: pytest.PytestPluginManager):
194+
pluginmanager.add_hookspecs(_hookspec)
195+
196+
197+
class _UpdateServerStateInterceptor(
209198
grpc.aio.UnaryUnaryClientInterceptor,
210199
grpc.aio.UnaryStreamClientInterceptor,
211200
grpc.aio.StreamUnaryClientInterceptor,
212201
grpc.aio.StreamStreamClientInterceptor,
213202
):
214-
def __init__(self):
215-
self.prepare_func: Callable[[grpc.aio.ClientCallDetails, MessageOrStream], Awaitable[None]] | None = None
203+
def __init__(self, service_client: pytest_userver.client.Client, asyncexc_check: _AsyncExcCheck):
204+
self._service_client = service_client
205+
self._asyncexc_check = asyncexc_check
206+
207+
async def _before_call_hook(self) -> None:
208+
self._asyncexc_check()
209+
if hasattr(self._service_client, 'update_server_state'):
210+
await self._service_client.update_server_state()
211+
212+
async def intercept_unary_unary(
213+
self,
214+
continuation: Callable[[grpc.aio.ClientCallDetails, Any], Awaitable[grpc.aio.UnaryUnaryCall]],
215+
client_call_details: grpc.aio.ClientCallDetails,
216+
request: Any,
217+
) -> grpc.aio.UnaryUnaryCall:
218+
await self._before_call_hook()
219+
try:
220+
return await continuation(client_call_details, request)
221+
finally:
222+
self._asyncexc_check()
223+
224+
# Note: full type of this function is Callable[[...], Awaitable[AsyncIterator[Any]]]
225+
async def intercept_unary_stream(
226+
self,
227+
continuation: Callable[[grpc.aio.ClientCallDetails, Any], grpc.aio.UnaryStreamCall],
228+
client_call_details: grpc.aio.ClientCallDetails,
229+
request: Any,
230+
) -> AsyncIterator[Any]:
231+
await self._before_call_hook()
232+
call = await continuation(client_call_details, request)
233+
234+
async def response_stream() -> AsyncIterator[Any]:
235+
try:
236+
async for response in call:
237+
yield response
238+
finally:
239+
self._asyncexc_check()
240+
241+
return response_stream()
242+
243+
async def intercept_stream_unary(
244+
self,
245+
continuation: Callable[[grpc.aio.ClientCallDetails, AsyncIterator[Any]], Awaitable[grpc.aio.StreamUnaryCall]],
246+
client_call_details: grpc.aio.ClientCallDetails,
247+
request_iterator: AsyncIterator[Any],
248+
) -> grpc.aio.StreamUnaryCall:
249+
await self._before_call_hook()
250+
try:
251+
return await continuation(client_call_details, request_iterator)
252+
finally:
253+
self._asyncexc_check()
216254

217-
async def intercept_unary_unary(self, continuation, client_call_details, request):
218-
await self.prepare_func(client_call_details, request)
219-
return await continuation(client_call_details, request)
255+
# Note: full type of this function is Callable[[...], Awaitable[AsyncIterator[Any]]]
256+
async def intercept_stream_stream(
257+
self,
258+
continuation: Callable[[grpc.aio.ClientCallDetails, AsyncIterator[Any]], grpc.aio.StreamStreamCall],
259+
client_call_details: grpc.aio.ClientCallDetails,
260+
request_iterator: AsyncIterator[Any],
261+
) -> AsyncIterator[Any]:
262+
self._asyncexc_check()
263+
await self._before_call_hook()
264+
call = await continuation(client_call_details, request_iterator)
220265

221-
async def intercept_unary_stream(self, continuation, client_call_details, request):
222-
await self.prepare_func(client_call_details, request)
223-
return continuation(client_call_details, next(request))
266+
async def response_stream() -> AsyncIterator[Any]:
267+
try:
268+
async for response in call:
269+
yield response
270+
finally:
271+
self._asyncexc_check()
224272

225-
async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
226-
await self.prepare_func(client_call_details, request_iterator)
227-
return await continuation(client_call_details, request_iterator)
273+
return response_stream()
228274

229-
async def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
230-
await self.prepare_func(client_call_details, request_iterator)
231-
return continuation(client_call_details, request_iterator)
232275

276+
def pytest_grpc_client_interceptors(request: pytest.FixtureRequest) -> Sequence[grpc.aio.ClientInterceptor]:
277+
return [
278+
_UpdateServerStateInterceptor(
279+
request.getfixturevalue('service_client'),
280+
request.getfixturevalue('asyncexc_check'),
281+
),
282+
]
233283

234-
@pytest.fixture(scope='session')
235-
def _grpc_channel_interceptor(daemon_scoped_mark) -> _GenericClientInterceptor:
236-
return _GenericClientInterceptor()
237284

285+
def _filter_interceptors(
286+
interceptors: Sequence[grpc.aio.ClientInterceptor], desired_type: type[grpc.aio.ClientInterceptor]
287+
) -> list[grpc.aio.ClientInterceptor]:
288+
return [interceptor for interceptor in interceptors if isinstance(interceptor, desired_type)]
238289

239-
class _AsyncExcClientInterceptor(grpc.aio.UnaryUnaryClientInterceptor, grpc.aio.StreamUnaryClientInterceptor):
240-
def __init__(self):
241-
self.asyncexc_check: _AsyncExcCheck | None = None
242290

243-
async def intercept_unary_unary(self, continuation, client_call_details, request):
244-
try:
245-
return await continuation(client_call_details, request)
246-
finally:
247-
self.asyncexc_check()
291+
def _set_client_interceptors(channel: grpc.aio.Channel, interceptors: Sequence[grpc.aio.ClientInterceptor]) -> None:
292+
"""
293+
Allows to set interceptors dynamically while reusing the same underlying channel,
294+
which is something grpc-io currently doesn't support.
248295
249-
async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
250-
try:
251-
return await continuation(client_call_details, request_iterator)
252-
finally:
253-
self.asyncexc_check()
296+
Also fixes the bug: multi-inheritance interceptors are only registered for first matching type
297+
https://github.com/grpc/grpc/issues/31442
298+
"""
299+
channel._unary_unary_interceptors = _filter_interceptors(interceptors, grpc.aio.UnaryUnaryClientInterceptor)
300+
channel._unary_stream_interceptors = _filter_interceptors(interceptors, grpc.aio.UnaryStreamClientInterceptor)
301+
channel._stream_unary_interceptors = _filter_interceptors(interceptors, grpc.aio.StreamUnaryClientInterceptor)
302+
channel._stream_stream_interceptors = _filter_interceptors(interceptors, grpc.aio.StreamStreamClientInterceptor)
254303

255304

256-
@pytest.fixture(scope='session')
257-
def _grpc_channel_interceptor_asyncexc() -> _AsyncExcClientInterceptor:
258-
return _AsyncExcClientInterceptor()
305+
# @endcond

testsuite/pytest_plugins/pytest_userver/plugins/testpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class UnregisteredTestpointError(BaseError):
1717

1818

1919
class TestpointControl:
20-
def __init__(self):
20+
def __init__(self) -> None:
2121
self.enabled_testpoints: frozenset[str] = frozenset()
2222

2323

testsuite/pytest_plugins/pytest_userver/s3api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class S3Object:
1515

1616

1717
class S3MockBucketStorage:
18-
def __init__(self):
18+
def __init__(self) -> None:
1919
# use Path to normalize keys (e.g. /a//file.json)
2020
self._storage: dict[pathlib.Path, S3Object] = {}
2121

0 commit comments

Comments
 (0)