88from collections import deque
99from typing import Optional , Set , Dict , Union , Callable
1010
11+ import ydb
1112from .. import _apis , issues
1213from .._utilities import AtomicCounter
1314from ..aio import Driver
@@ -35,7 +36,7 @@ class TopicReaderError(YdbError):
3536 pass
3637
3738
38- class TopicReaderUnexpectedCodec (YdbError ):
39+ class PublicTopicReaderUnexpectedCodecError (YdbError ):
3940 pass
4041
4142
@@ -222,9 +223,7 @@ def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.Co
222223
223224 async def close (self , flush : bool ):
224225 if self ._stream_reader :
225- if flush :
226- await self .flush ()
227- await self ._stream_reader .close ()
226+ await self ._stream_reader .close (flush )
228227 for task in self ._background_tasks :
229228 task .cancel ()
230229
@@ -339,9 +338,12 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess
339338 self ._update_token_event .set ()
340339
341340 self ._background_tasks .add (asyncio .create_task (self ._read_messages_loop (), name = "read_messages_loop" ))
342- self ._background_tasks .add (asyncio .create_task (self ._decode_batches_loop ()))
341+ self ._background_tasks .add (asyncio .create_task (self ._decode_batches_loop (), name = "decode_batches" ))
343342 if self ._get_token_function :
344343 self ._background_tasks .add (asyncio .create_task (self ._update_token_loop (), name = "update_token_loop" ))
344+ self ._background_tasks .add (
345+ asyncio .create_task (self ._handle_background_errors (), name = "handle_background_errors" )
346+ )
345347
346348 async def wait_error (self ):
347349 raise await self ._first_error
@@ -411,6 +413,17 @@ def commit(self, batch: datatypes.ICommittable) -> datatypes.PartitionSession.Co
411413
412414 return waiter
413415
416+ async def _handle_background_errors (self ):
417+ done , _ = await asyncio .wait (self ._background_tasks , return_when = asyncio .FIRST_EXCEPTION )
418+ for f in done :
419+ f = f # type: asyncio.Future
420+ err = f .exception ()
421+ if not isinstance (err , ydb .Error ):
422+ old_err = err
423+ err = ydb .Error ("Background process failed unexpected" )
424+ err .__cause__ = old_err
425+ self ._set_first_error (err )
426+
414427 async def _read_messages_loop (self ):
415428 try :
416429 self ._stream .write (
@@ -602,7 +615,7 @@ async def _decode_batch_inplace(self, batch):
602615 try :
603616 decode_func = self ._decoders [batch ._codec ]
604617 except KeyError :
605- raise TopicReaderUnexpectedCodec ("Receive message with unexpected codec: %s" % batch ._codec )
618+ raise PublicTopicReaderUnexpectedCodecError ("Receive message with unexpected codec: %s" % batch ._codec )
606619
607620 decode_data_futures = []
608621 for message in batch .messages :
@@ -628,22 +641,22 @@ def _get_first_error(self) -> Optional[YdbError]:
628641 return self ._first_error .result ()
629642
630643 async def flush (self ):
631- if self ._closed :
632- raise RuntimeError ("Flush on closed Stream" )
633-
634644 futures = []
635645 for session in self ._partition_sessions .values ():
636646 futures .extend (w .future for w in session ._ack_waiters )
637647
638648 if futures :
639649 await asyncio .wait (futures )
640650
641- async def close (self ):
651+ async def close (self , flush : bool ):
642652 if self ._closed :
643653 return
644654
645655 self ._closed = True
646656
657+ if flush :
658+ await self .flush ()
659+
647660 self ._set_first_error (TopicReaderStreamClosedError ())
648661 self ._state_changed .set ()
649662 self ._stream .close ()
0 commit comments