66from collections import deque
77from typing import Deque , AsyncIterator , Union , List , Optional , Dict , Callable
88
9+ import logging
10+
911import ydb
1012from .topic_writer import (
1113 PublicWriterSettings ,
3840 GrpcWrapperAsyncIO ,
3941)
4042
43+ logger = logging .getLogger (__name__ )
44+
4145
4246class WriterAsyncIO :
4347 _loop : asyncio .AbstractEventLoop
@@ -158,7 +162,6 @@ class WriterAsyncIOReconnector:
158162 _credentials : Union [ydb .credentials .Credentials , None ]
159163 _driver : ydb .aio .Driver
160164 _init_message : StreamWriteMessage .InitRequest
161- _init_info : asyncio .Future
162165 _stream_connected : asyncio .Event
163166 _settings : WriterSettings
164167 _codec : PublicCodec
@@ -168,25 +171,30 @@ class WriterAsyncIOReconnector:
168171 _codec_selector_last_codec : Optional [PublicCodec ]
169172 _codec_selector_check_batches_interval : int
170173
171- _last_known_seq_no : int
172174 if typing .TYPE_CHECKING :
173175 _messages_for_encode : asyncio .Queue [List [InternalMessage ]]
174176 else :
175177 _messages_for_encode : asyncio .Queue
176178 _messages : Deque [InternalMessage ]
177179 _messages_future : Deque [asyncio .Future ]
178180 _new_messages : asyncio .Queue
179- _stop_reason : asyncio .Future
180181 _background_tasks : List [asyncio .Task ]
181182
183+ _state_changed : asyncio .Event
184+ if typing .TYPE_CHECKING :
185+ _stop_reason : asyncio .Future [BaseException ]
186+ else :
187+ _stop_reason : asyncio .Future
188+ _init_info : Optional [PublicWriterInitInfo ]
189+
182190 def __init__ (self , driver : SupportedDriverType , settings : WriterSettings ):
183191 self ._closed = False
184192 self ._loop = asyncio .get_running_loop ()
185193 self ._driver = driver
186194 self ._credentials = driver ._credentials
187195 self ._init_message = settings .create_init_request ()
188196 self ._new_messages = asyncio .Queue ()
189- self ._init_info = self . _loop . create_future ()
197+ self ._init_info = None
190198 self ._stream_connected = asyncio .Event ()
191199 self ._settings = settings
192200
@@ -223,14 +231,17 @@ def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
223231 asyncio .create_task (self ._encode_loop (), name = "encode_loop" ),
224232 ]
225233
234+ self ._state_changed = asyncio .Event ()
235+
226236 async def close (self , flush : bool ):
227237 if self ._closed :
228238 return
239+ self ._closed = True
240+ logger .debug ("Close writer reconnector" )
229241
230242 if flush :
231243 await self .flush ()
232244
233- self ._closed = True
234245 self ._stop (TopicWriterStopped ())
235246
236247 for task in self ._background_tasks :
@@ -244,19 +255,20 @@ async def close(self, flush: bool):
244255 pass
245256
246257 async def wait_init (self ) -> PublicWriterInitInfo :
247- done , _ = await asyncio .wait (
248- [self ._init_info , self ._stop_reason ], return_when = asyncio .FIRST_COMPLETED
249- )
250- res = done .pop () # type: asyncio.Future
251- res_val = res .result ()
258+ while True :
259+ if self ._stop_reason .done ():
260+ raise self ._stop_reason .exception ()
252261
253- if isinstance ( res_val , BaseException ) :
254- raise res_val
262+ if self . _init_info :
263+ return self . _init_info
255264
256- return res_val
265+ await self . _state_changed . wait ()
257266
258- async def wait_stop (self ) -> Exception :
259- return await self ._stop_reason
267+ async def wait_stop (self ) -> BaseException :
268+ try :
269+ await self ._stop_reason
270+ except BaseException as stop_reason :
271+ return stop_reason
260272
261273 async def write_with_ack_future (
262274 self , messages : List [PublicMessage ]
@@ -347,13 +359,14 @@ async def _connection_loop(self):
347359 self ._settings .update_token_interval ,
348360 )
349361 try :
350- self ._last_known_seq_no = stream_writer . last_seqno
351- self ._init_info . set_result (
352- PublicWriterInitInfo (
362+ if self ._init_info is None :
363+ self ._last_known_seq_no = stream_writer . last_seqno
364+ self . _init_info = PublicWriterInitInfo (
353365 last_seqno = stream_writer .last_seqno ,
354366 supported_codecs = stream_writer .supported_codecs ,
355367 )
356- )
368+ self ._state_changed .set ()
369+
357370 except asyncio .InvalidStateError :
358371 pass
359372
@@ -373,9 +386,6 @@ async def _connection_loop(self):
373386 await stream_writer .close ()
374387 done .pop ().result ()
375388 except issues .Error as err :
376- # todo log error
377- print (err )
378-
379389 err_info = check_retriable_error (err , retry_settings , attempt )
380390 if not err_info .is_retriable :
381391 self ._stop (err )
@@ -554,8 +564,13 @@ def _stop(self, reason: Exception):
554564
555565 self ._stop_reason .set_result (reason )
556566
567+ for f in self ._messages_future :
568+ f .set_exception (reason )
569+
570+ self ._state_changed .set ()
571+ logger .info ("Stop topic writer: %s" % reason )
572+
557573 async def flush (self ):
558- self ._check_stop ()
559574 if not self ._messages_future :
560575 return
561576
0 commit comments