diff --git a/tests/aio/query/conftest.py b/tests/aio/query/conftest.py index 27d96343..fb7f75fa 100644 --- a/tests/aio/query/conftest.py +++ b/tests/aio/query/conftest.py @@ -1,4 +1,9 @@ +from unittest import mock + +import grpc import pytest +from grpc._cython import cygrpc + from ydb.aio.query.session import QuerySession from ydb.aio.query.pool import QuerySessionPool @@ -32,3 +37,22 @@ async def tx(session): async def pool(driver): async with QuerySessionPool(driver) as pool: yield pool + + +@pytest.fixture +async def ydb_terminates_streams_with_unavailable(): + async def _patch(self): + message = await self._read() # Read the first message + while message is not cygrpc.EOF: # While the message is not empty, continue reading the stream + yield message + message = await self._read() + + # Emulate stream termination + raise grpc.aio.AioRpcError( + code=grpc.StatusCode.UNAVAILABLE, + initial_metadata=await self.initial_metadata(), + trailing_metadata=await self.trailing_metadata(), + ) + + with mock.patch.object(grpc.aio._call._StreamResponseMixin, "_fetch_stream_responses", _patch): + yield diff --git a/tests/aio/query/test_query_session.py b/tests/aio/query/test_query_session.py index 67db045a..2602770a 100644 --- a/tests/aio/query/test_query_session.py +++ b/tests/aio/query/test_query_session.py @@ -1,4 +1,6 @@ import pytest + +import ydb from ydb.aio.query.session import QuerySession @@ -113,3 +115,13 @@ async def test_two_results(self, session: QuerySession): assert res == [[1], [2]] assert counter == 2 + + @pytest.mark.asyncio + @pytest.mark.usefixtures("ydb_terminates_streams_with_unavailable") + async def test_terminated_stream_raises_ydb_error(self, session: QuerySession): + await session.create() + + with pytest.raises(ydb.Unavailable): + async with await session.execute("select 1") as results: + async for _ in results: + pass diff --git a/tests/aio/query/test_query_transaction.py b/tests/aio/query/test_query_transaction.py index aa59abb3..b2a8ef32 100644 --- a/tests/aio/query/test_query_transaction.py +++ b/tests/aio/query/test_query_transaction.py @@ -1,5 +1,6 @@ import pytest +import ydb from ydb.aio.query.transaction import QueryTxContext from ydb.query.transaction import QueryTxStateEnum @@ -107,3 +108,13 @@ async def test_execute_two_results(self, tx: QueryTxContext): assert res == [[1], [2]] assert counter == 2 + + @pytest.mark.asyncio + @pytest.mark.usefixtures("ydb_terminates_streams_with_unavailable") + async def test_terminated_stream_raises_ydb_error(self, tx: QueryTxContext): + await tx.begin() + + with pytest.raises(ydb.Unavailable): + async with await tx.execute("select 1") as results: + async for _ in results: + pass diff --git a/ydb/_errors.py b/ydb/_errors.py index 1e2308ef..b19de749 100644 --- a/ydb/_errors.py +++ b/ydb/_errors.py @@ -1,5 +1,7 @@ from dataclasses import dataclass -from typing import Optional +from typing import Optional, Union + +import grpc from . import issues @@ -52,3 +54,26 @@ def check_retriable_error(err, retry_settings, attempt): class ErrorRetryInfo: is_retriable: bool sleep_timeout_seconds: Optional[float] + + +def stream_error_converter(exc: BaseException) -> Union[issues.Error, BaseException]: + """Converts gRPC stream errors to appropriate YDB exception types. + + This function takes a base exception and converts specific gRPC aio stream errors + to their corresponding YDB exception types for better error handling and semantic + clarity. + + Args: + exc (BaseException): The original exception to potentially convert. + + Returns: + BaseException: Either a converted YDB exception or the original exception + if no specific conversion rule applies. + """ + if isinstance(exc, (grpc.RpcError, grpc.aio.AioRpcError)): + if exc.code() == grpc.StatusCode.UNAVAILABLE: + return issues.Unavailable(exc.details() or "") + if exc.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + return issues.DeadlineExceed("Deadline exceeded on request") + return issues.Error("Stream has been terminated. Original exception: {}".format(str(exc.details()))) + return exc diff --git a/ydb/aio/_utilities.py b/ydb/aio/_utilities.py index 296cd256..53a7d412 100644 --- a/ydb/aio/_utilities.py +++ b/ydb/aio/_utilities.py @@ -2,9 +2,10 @@ class AsyncResponseIterator(object): - def __init__(self, it, wrapper): + def __init__(self, it, wrapper, error_converter=None): self.it = it.__aiter__() self.wrapper = wrapper + self.error_converter = error_converter def cancel(self): self.it.cancel() @@ -17,7 +18,13 @@ def __aiter__(self): return self async def _next(self): - res = self.wrapper(await self.it.__anext__()) + try: + res = self.wrapper(await self.it.__anext__()) + except BaseException as e: + if self.error_converter: + raise self.error_converter(e) from e + raise e + if res is not None: return res return await self._next() diff --git a/ydb/aio/query/session.py b/ydb/aio/query/session.py index 7a7ba5ba..83d527b7 100644 --- a/ydb/aio/query/session.py +++ b/ydb/aio/query/session.py @@ -19,6 +19,7 @@ ) from ..._constants import DEFAULT_INITIAL_RESPONSE_TIMEOUT +from ..._errors import stream_error_converter class QuerySession(BaseQuerySession): @@ -151,12 +152,13 @@ async def execute( ) return AsyncResponseContextIterator( - stream_it, - lambda resp: base.wrap_execute_query_response( + it=stream_it, + wrapper=lambda resp: base.wrap_execute_query_response( rpc_state=None, response_pb=resp, session_state=self._state, session=self, settings=self._settings, ), + error_converter=stream_error_converter, ) diff --git a/ydb/aio/query/transaction.py b/ydb/aio/query/transaction.py index c9a6e445..2c313a4a 100644 --- a/ydb/aio/query/transaction.py +++ b/ydb/aio/query/transaction.py @@ -11,6 +11,7 @@ BaseQueryTxContext, QueryTxStateEnum, ) +from ..._errors import stream_error_converter logger = logging.getLogger(__name__) @@ -181,8 +182,8 @@ async def execute( ) self._prev_stream = AsyncResponseContextIterator( - stream_it, - lambda resp: base.wrap_execute_query_response( + it=stream_it, + wrapper=lambda resp: base.wrap_execute_query_response( rpc_state=None, response_pb=resp, session_state=self._session_state, @@ -190,5 +191,6 @@ async def execute( commit_tx=commit_tx, settings=self.session._settings, ), + error_converter=stream_error_converter, ) return self._prev_stream