66from collections import deque
77from typing import Deque , AsyncIterator , Union , List , Optional , Dict , Callable
88
9+ import logging
10+
911import ydb
1012from .topic_writer import (
1113 PublicWriterSettings ,
3840 GrpcWrapperAsyncIO ,
3941)
4042
43+ logger = logging .getLogger (__name__ )
44+
4145
4246class WriterAsyncIO :
4347 _loop : asyncio .AbstractEventLoop
@@ -61,7 +65,7 @@ def __del__(self):
6165 if self ._closed or self ._loop .is_closed ():
6266 return
6367
64- self ._loop .call_soon (self .close )
68+ self ._loop .call_soon (self .close , False )
6569
6670 async def close (self , * , flush : bool = True ):
6771 if self ._closed :
@@ -154,7 +158,6 @@ class WriterAsyncIOReconnector:
154158 _credentials : Union [ydb .credentials .Credentials , None ]
155159 _driver : ydb .aio .Driver
156160 _init_message : StreamWriteMessage .InitRequest
157- _init_info : asyncio .Future
158161 _stream_connected : asyncio .Event
159162 _settings : WriterSettings
160163 _codec : PublicCodec
@@ -164,25 +167,30 @@ class WriterAsyncIOReconnector:
164167 _codec_selector_last_codec : Optional [PublicCodec ]
165168 _codec_selector_check_batches_interval : int
166169
167- _last_known_seq_no : int
168170 if typing .TYPE_CHECKING :
169171 _messages_for_encode : asyncio .Queue [List [InternalMessage ]]
170172 else :
171173 _messages_for_encode : asyncio .Queue
172174 _messages : Deque [InternalMessage ]
173175 _messages_future : Deque [asyncio .Future ]
174176 _new_messages : asyncio .Queue
175- _stop_reason : asyncio .Future
176177 _background_tasks : List [asyncio .Task ]
177178
179+ _state_changed : asyncio .Event
180+ if typing .TYPE_CHECKING :
181+ _stop_reason : asyncio .Future [BaseException ]
182+ else :
183+ _stop_reason : asyncio .Future
184+ _init_info : Optional [PublicWriterInitInfo ]
185+
178186 def __init__ (self , driver : SupportedDriverType , settings : WriterSettings ):
179187 self ._closed = False
180188 self ._loop = asyncio .get_running_loop ()
181189 self ._driver = driver
182190 self ._credentials = driver ._credentials
183191 self ._init_message = settings .create_init_request ()
184192 self ._new_messages = asyncio .Queue ()
185- self ._init_info = self . _loop . create_future ()
193+ self ._init_info = None
186194 self ._stream_connected = asyncio .Event ()
187195 self ._settings = settings
188196
@@ -219,14 +227,17 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
219227 asyncio .create_task (self ._encode_loop (), name = "encode_loop" ),
220228 ]
221229
230+ self ._state_changed = asyncio .Event ()
231+
222232 async def close (self , flush : bool ):
223233 if self ._closed :
224234 return
235+ self ._closed = True
236+ logger .debug ("Close writer reconnector" )
225237
226238 if flush :
227239 await self .flush ()
228240
229- self ._closed = True
230241 self ._stop (TopicWriterStopped ())
231242
232243 for task in self ._background_tasks :
@@ -240,19 +251,20 @@ async def close(self, flush: bool):
240251 pass
241252
242253 async def wait_init (self ) -> PublicWriterInitInfo :
243- done , _ = await asyncio .wait (
244- [self ._init_info , self ._stop_reason ], return_when = asyncio .FIRST_COMPLETED
245- )
246- res = done .pop () # type: asyncio.Future
247- res_val = res .result ()
254+ while True :
255+ if self ._stop_reason .done ():
256+ raise self ._stop_reason .exception ()
248257
249- if isinstance ( res_val , BaseException ) :
250- raise res_val
258+ if self . _init_info :
259+ return self . _init_info
251260
252- return res_val
261+ await self . _state_changed . wait ()
253262
254- async def wait_stop (self ) -> Exception :
255- return await self ._stop_reason
263+ async def wait_stop (self ) -> BaseException :
264+ try :
265+ await self ._stop_reason
266+ except BaseException as stop_reason :
267+ return stop_reason
256268
257269 async def write_with_ack_future (
258270 self , messages : List [PublicMessage ]
@@ -343,13 +355,14 @@ async def _connection_loop(self):
343355 self ._settings .update_token_interval ,
344356 )
345357 try :
346- self ._last_known_seq_no = stream_writer . last_seqno
347- self ._init_info . set_result (
348- PublicWriterInitInfo (
358+ if self ._init_info is None :
359+ self ._last_known_seq_no = stream_writer . last_seqno
360+ self . _init_info = PublicWriterInitInfo (
349361 last_seqno = stream_writer .last_seqno ,
350362 supported_codecs = stream_writer .supported_codecs ,
351363 )
352- )
364+ self ._state_changed .set ()
365+
353366 except asyncio .InvalidStateError :
354367 pass
355368
@@ -369,9 +382,6 @@ async def _connection_loop(self):
369382 await stream_writer .close ()
370383 done .pop ().result ()
371384 except issues .Error as err :
372- # todo log error
373- print (err )
374-
375385 err_info = check_retriable_error (err , retry_settings , attempt )
376386 if not err_info .is_retriable :
377387 self ._stop (err )
@@ -550,8 +560,13 @@ def _stop(self, reason: Exception):
550560
551561 self ._stop_reason .set_result (reason )
552562
563+ for f in self ._messages_future :
564+ f .set_exception (reason )
565+
566+ self ._state_changed .set ()
567+ logger .info ("Stop topic writer: %s" % reason )
568+
553569 async def flush (self ):
554- self ._check_stop ()
555570 if not self ._messages_future :
556571 return
557572
0 commit comments