Skip to content

Commit 01dd5bf

Browse files
committed
sync
1 parent 0cff446 commit 01dd5bf

File tree

2 files changed

+59
-21
lines changed

2 files changed

+59
-21
lines changed

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from collections import deque
77
from typing import Optional, Set, Dict
88

9-
import grpc
109

11-
import ydb
1210
from .. import _apis, issues, RetrySettings
1311
from ..aio import Driver
1412
from ..issues import (
@@ -53,7 +51,7 @@ class ReaderReconnector:
5351

5452
_state_changed: asyncio.Event
5553
_stream_reader: Optional["ReaderStream"]
56-
_first_error: asyncio.Future[ydb.Error]
54+
_first_error: asyncio.Future[YdbError]
5755

5856
def __init__(self, driver: Driver, settings: PublicReaderSettings):
5957
self._settings = settings
@@ -71,11 +69,10 @@ async def _connection_loop(self):
7169
while True:
7270
try:
7371
self._stream_reader = await ReaderStream.create(self._driver, self._settings)
72+
attempt = 0
7473
self._state_changed.set()
75-
self._stream_reader._state_changed.wait()
76-
except Exception as err:
77-
# todo reset attempts when connection established
78-
74+
await self._stream_reader.wait_error()
75+
except issues.Error as err:
7976
retry_info = check_retriable_error(err, self._settings._retry_settings(), attempt)
8077
if not retry_info.is_retriable:
8178
self._set_first_error(err)
@@ -90,8 +87,11 @@ async def wait_message(self):
9087
raise self._first_error.result()
9188

9289
if self._stream_reader is not None:
93-
await self._stream_reader.wait_messages()
94-
return
90+
try:
91+
await self._stream_reader.wait_messages()
92+
return
93+
except YdbError:
94+
pass # handle errors in reconnection loop
9595

9696
await self._state_changed.wait()
9797
self._state_changed.clear()
@@ -114,6 +114,7 @@ def _set_first_error(self, err: issues.Error):
114114
# skip if already has result
115115
pass
116116

117+
117118
class ReaderStream:
118119
_token_getter: Optional[TokenGetterFuncType]
119120
_session_id: str
@@ -126,7 +127,7 @@ class ReaderStream:
126127
_state_changed: asyncio.Event
127128
_closed: bool
128129
_message_batches: typing.Deque[PublicBatch]
129-
first_error: asyncio.Future[YdbError]
130+
_first_error: asyncio.Future[YdbError]
130131

131132
def __init__(self, settings: PublicReaderSettings):
132133
self._token_getter = settings._token_getter
@@ -139,7 +140,7 @@ def __init__(self, settings: PublicReaderSettings):
139140

140141
self._state_changed = asyncio.Event()
141142
self._closed = False
142-
self.first_error = asyncio.get_running_loop().create_future()
143+
self._first_error = asyncio.get_running_loop().create_future()
143144
self._message_batches = deque()
144145

145146
@staticmethod
@@ -174,6 +175,9 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess
174175
read_messages_task = asyncio.create_task(self._read_messages_loop(stream))
175176
self._background_tasks.add(read_messages_task)
176177

178+
async def wait_error(self):
179+
raise await self._first_error
180+
177181
async def wait_messages(self):
178182
while True:
179183
if self._get_first_error() is not None:
@@ -317,17 +321,17 @@ def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) ->
317321
batches[-1]._bytes_size += additional_bytes_to_last_batch
318322
return batches
319323

320-
def _set_first_error(self, err):
324+
def _set_first_error(self, err: ydb.Error):
321325
try:
322-
self.first_error.set_result(err)
326+
self._first_error.set_result(err)
323327
self._state_changed.set()
324328
except asyncio.InvalidStateError:
325329
# skip later set errors
326330
pass
327331

328-
def _get_first_error(self):
329-
if self.first_error.done():
330-
return self.first_error.result()
332+
def _get_first_error(self) -> Optional[ydb.Error]:
333+
if self._first_error.done():
334+
return self._first_error.result()
331335
else:
332336
return None
333337

ydb/_topic_reader/topic_reader_asyncio_test.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .datatypes import PublicBatch, PublicMessage
1111
from .topic_reader import PublicReaderSettings
1212
from .topic_reader_asyncio import ReaderStream, PartitionSession, ReaderReconnector
13-
from .._topic_wrapper.common import OffsetsRange, Codec, ServerStatus, UpdateTokenResponse
13+
from .._topic_wrapper.common import OffsetsRange, Codec, ServerStatus, UpdateTokenResponse, SupportedDriverType
1414
from .._topic_wrapper.reader import StreamReadMessage
1515
from .._topic_wrapper.test_helpers import StreamMock, wait_condition, wait_for_fast
1616
from ..issues import Unavailable
@@ -23,6 +23,13 @@
2323
from .._grpc.common.protos import ydb_status_codes_pb2
2424

2525

26+
@pytest.fixture(autouse=True)
27+
def handle_exceptions(event_loop):
28+
def handler(loop, context):
29+
print(context)
30+
event_loop.set_exception_handler(handler)
31+
32+
2633
@pytest.fixture()
2734
def default_reader_settings():
2835
return PublicReaderSettings(
@@ -634,11 +641,38 @@ async def test_receive_batch_nowait(self, stream, stream_reader, partition_sessi
634641
@pytest.mark.asyncio
635642
class TestReaderReconnector:
636643
async def test_reconnect_on_repeatable_error(self, monkeypatch):
637-
def stream_create():
638-
pass
644+
test_error = issues.Overloaded("test error")
645+
646+
async def wait_error():
647+
raise test_error
648+
649+
reader_stream_mock_with_error = mock.Mock(ReaderStream)
650+
reader_stream_mock_with_error.wait_error = mock.AsyncMock(side_effect=wait_error)
651+
652+
async def wait_messages():
653+
raise test_error
654+
655+
reader_stream_mock_with_error.wait_messages = mock.AsyncMock(side_effect=wait_messages)
656+
657+
reader_stream_with_messages = mock.Mock(ReaderStream)
658+
reader_stream_with_messages.wait_error.return_value = asyncio.Future()
659+
reader_stream_with_messages.wait_messages.return_value = None
660+
661+
stream_index = 0
662+
663+
async def stream_create(driver: SupportedDriverType, settings: PublicReaderSettings,):
664+
nonlocal stream_index
665+
stream_index += 1
666+
if stream_index == 1:
667+
return reader_stream_mock_with_error
668+
elif stream_index == 2:
669+
return reader_stream_with_messages
670+
else:
671+
raise Exception("unexpected create stream")
639672

640673
with mock.patch.object(ReaderStream, "create", stream_create):
641-
reconnector = ReaderReconnector(None, PublicReaderSettings("", ""))
674+
reconnector = ReaderReconnector(mock.Mock(), PublicReaderSettings("", ""))
642675
await reconnector.wait_message()
643676

644-
raise NotImplementedError()
677+
reader_stream_mock_with_error.wait_error.assert_any_await()
678+
reader_stream_mock_with_error.wait_messages.assert_any_await()

0 commit comments

Comments
 (0)