66import typing
77from asyncio import Task
88from collections import deque
9- from typing import Optional , Set , Dict
9+ from typing import Optional , Set , Dict , Union , Callable
1010
11- from .. import _apis , issues , RetrySettings
11+ from .. import _apis , issues
1212from .._utilities import AtomicCounter
1313from ..aio import Driver
1414from ..issues import Error as YdbError , _process_response
1919 SupportedDriverType ,
2020 GrpcWrapperAsyncIO ,
2121)
22- from .._grpc .grpcwrapper .ydb_topic import StreamReadMessage , Codec
22+ from .._grpc .grpcwrapper .ydb_topic import (
23+ StreamReadMessage ,
24+ UpdateTokenRequest ,
25+ UpdateTokenResponse ,
26+ Codec ,
27+ )
2328from .._errors import check_retriable_error
2429
2530
@@ -194,7 +199,6 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
194199 self ._settings = settings
195200 self ._driver = driver
196201 self ._background_tasks = set ()
197- self ._retry_settins = RetrySettings (idempotent = True ) # get from settings
198202
199203 self ._state_changed = asyncio .Event ()
200204 self ._stream_reader = None
@@ -227,7 +231,7 @@ async def wait_message(self):
227231 if self ._first_error .done ():
228232 raise self ._first_error .result ()
229233
230- if self ._stream_reader is not None :
234+ if self ._stream_reader :
231235 try :
232236 await self ._stream_reader .wait_messages ()
233237 return
@@ -289,8 +293,15 @@ class ReaderStream:
289293 _message_batches : typing .Deque [datatypes .PublicBatch ]
290294 _first_error : asyncio .Future [YdbError ]
291295
296+ _update_token_interval : Union [int , float ]
297+ _update_token_event : asyncio .Event
298+ _get_token_function : Callable [[], str ]
299+
292300 def __init__ (
293- self , reader_reconnector_id : int , settings : topic_reader .PublicReaderSettings
301+ self ,
302+ reader_reconnector_id : int ,
303+ settings : topic_reader .PublicReaderSettings ,
304+ get_token_function : Optional [Callable [[], str ]] = None ,
294305 ):
295306 self ._loop = asyncio .get_running_loop ()
296307 self ._id = ReaderStream ._static_id_counter .inc_and_get ()
@@ -313,6 +324,10 @@ def __init__(
313324 self ._batches_to_decode = asyncio .Queue ()
314325 self ._message_batches = deque ()
315326
327+ self ._update_token_interval = settings .update_token_interval
328+ self ._get_token_function = get_token_function
329+ self ._update_token_event = asyncio .Event ()
330+
316331 @staticmethod
317332 async def create (
318333 reader_reconnector_id : int ,
@@ -325,7 +340,12 @@ async def create(
325340 driver , _apis .TopicService .Stub , _apis .TopicService .StreamRead
326341 )
327342
328- reader = ReaderStream (reader_reconnector_id , settings )
343+ creds = driver ._credentials
344+ reader = ReaderStream (
345+ reader_reconnector_id ,
346+ settings ,
347+ get_token_function = creds .get_auth_token if creds else None ,
348+ )
329349 await reader ._start (stream , settings ._init_message ())
330350 return reader
331351
@@ -347,35 +367,41 @@ async def _start(
347367 "Unexpected message after InitRequest: %s" , init_response
348368 )
349369
370+ self ._update_token_event .set ()
371+
350372 self ._background_tasks .add (
351- asyncio .create_task (self ._read_messages_loop (stream ) )
373+ asyncio .create_task (self ._read_messages_loop (), name = "read_messages_loop" )
352374 )
353375 self ._background_tasks .add (asyncio .create_task (self ._decode_batches_loop ()))
376+ if self ._get_token_function :
377+ self ._background_tasks .add (
378+ asyncio .create_task (self ._update_token_loop (), name = "update_token_loop" )
379+ )
354380
355381 async def wait_error (self ):
356382 raise await self ._first_error
357383
358384 async def wait_messages (self ):
359385 while True :
360- if self ._get_first_error () is not None :
386+ if self ._get_first_error ():
361387 raise self ._get_first_error ()
362388
363- if len ( self ._message_batches ) > 0 :
389+ if self ._message_batches :
364390 return
365391
366392 await self ._state_changed .wait ()
367393 self ._state_changed .clear ()
368394
369395 def receive_batch_nowait (self ):
370- if self ._get_first_error () is not None :
396+ if self ._get_first_error ():
371397 raise self ._get_first_error ()
372398
373- try :
374- batch = self . _message_batches . popleft ()
375- self . _buffer_release_bytes ( batch . _bytes_size )
376- return batch
377- except IndexError :
378- return None
399+ if not self . _message_batches :
400+ return
401+
402+ batch = self . _message_batches . popleft ()
403+ self . _buffer_release_bytes ( batch . _bytes_size )
404+ return batch
379405
380406 def commit (
381407 self , batch : datatypes .ICommittable
@@ -413,7 +439,7 @@ def commit(
413439
414440 return waiter
415441
416- async def _read_messages_loop (self , stream : IGrpcWrapperAsyncIO ):
442+ async def _read_messages_loop (self ):
417443 try :
418444 self ._stream .write (
419445 StreamReadMessage .FromClient (
@@ -423,24 +449,34 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO):
423449 )
424450 )
425451 while True :
426- message = await stream .receive () # type: StreamReadMessage.FromServer
452+ message = (
453+ await self ._stream .receive ()
454+ ) # type: StreamReadMessage.FromServer
427455 _process_response (message .server_status )
456+
428457 if isinstance (message .server_message , StreamReadMessage .ReadResponse ):
429458 self ._on_read_response (message .server_message )
459+
430460 elif isinstance (
431461 message .server_message , StreamReadMessage .CommitOffsetResponse
432462 ):
433463 self ._on_commit_response (message .server_message )
464+
434465 elif isinstance (
435466 message .server_message ,
436467 StreamReadMessage .StartPartitionSessionRequest ,
437468 ):
438469 self ._on_start_partition_session (message .server_message )
470+
439471 elif isinstance (
440472 message .server_message ,
441473 StreamReadMessage .StopPartitionSessionRequest ,
442474 ):
443475 self ._on_partition_session_stop (message .server_message )
476+
477+ elif isinstance (message .server_message , UpdateTokenResponse ):
478+ self ._update_token_event .set ()
479+
444480 else :
445481 raise NotImplementedError (
446482 "Unexpected type of StreamReadMessage.FromServer message: %s"
@@ -450,7 +486,20 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO):
450486 self ._state_changed .set ()
451487 except Exception as e :
452488 self ._set_first_error (e )
453- raise e
489+ raise
490+
491+ async def _update_token_loop (self ):
492+ while True :
493+ await asyncio .sleep (self ._update_token_interval )
494+ await self ._update_token (token = self ._get_token_function ())
495+
496+ async def _update_token (self , token : str ):
497+ await self ._update_token_event .wait ()
498+ try :
499+ msg = StreamReadMessage .FromClient (UpdateTokenRequest (token ))
500+ self ._stream .write (msg )
501+ finally :
502+ self ._update_token_event .clear ()
454503
455504 def _on_start_partition_session (
456505 self , message : StreamReadMessage .StartPartitionSessionRequest
@@ -491,14 +540,12 @@ def _on_start_partition_session(
491540 def _on_partition_session_stop (
492541 self , message : StreamReadMessage .StopPartitionSessionRequest
493542 ):
494- try :
495- partition = self ._partition_sessions [message .partition_session_id ]
496- except KeyError :
543+ if message .partition_session_id not in self ._partition_sessions :
497544 # may if receive stop partition with graceful=false after response on stop partition
498545 # with graceful=true and remove partition from internal dictionary
499546 return
500547
501- del self ._partition_sessions [ message .partition_session_id ]
548+ partition = self ._partition_sessions . pop ( message .partition_session_id )
502549 partition .close ()
503550
504551 if message .graceful :
@@ -519,11 +566,10 @@ def _on_read_response(self, message: StreamReadMessage.ReadResponse):
519566
520567 def _on_commit_response (self , message : StreamReadMessage .CommitOffsetResponse ):
521568 for partition_offset in message .partitions_committed_offsets :
522- session = self ._partition_sessions .get (
523- partition_offset .partition_session_id
524- )
525- if session is None :
569+ if partition_offset .partition_session_id not in self ._partition_sessions :
526570 continue
571+
572+ session = self ._partition_sessions [partition_offset .partition_session_id ]
527573 session .ack_notify (partition_offset .committed_offset )
528574
529575 def _buffer_consume_bytes (self , bytes_size ):
@@ -544,12 +590,9 @@ def _read_response_to_batches(
544590 ) -> typing .List [datatypes .PublicBatch ]:
545591 batches = []
546592
547- batch_count = 0
548- for partition_data in message .partition_data :
549- batch_count += len (partition_data .batches )
550-
593+ batch_count = sum (len (p .batches ) for p in message .partition_data )
551594 if batch_count == 0 :
552- return []
595+ return batches
553596
554597 bytes_per_batch = message .bytes_size // batch_count
555598 additional_bytes_to_last_batch = (
@@ -577,12 +620,11 @@ def _read_response_to_batches(
577620 _commit_end_offset = message_data .offset + 1 ,
578621 )
579622 messages .append (mess )
580-
581623 partition_session ._next_message_start_commit_offset = (
582624 mess ._commit_end_offset
583625 )
584626
585- if len ( messages ) > 0 :
627+ if messages :
586628 batch = datatypes .PublicBatch (
587629 session_metadata = server_batch .write_session_meta ,
588630 messages = messages ,
@@ -637,14 +679,12 @@ def _set_first_error(self, err: YdbError):
637679 def _get_first_error (self ) -> Optional [YdbError ]:
638680 if self ._first_error .done ():
639681 return self ._first_error .result ()
640- else :
641- return None
642682
643683 async def close (self ):
644684 if self ._closed :
645- raise TopicReaderError (message = "Double closed ReaderStream" )
646-
685+ return
647686 self ._closed = True
687+
648688 self ._set_first_error (TopicReaderStreamClosedError ())
649689 self ._state_changed .set ()
650690 self ._stream .close ()
@@ -654,5 +694,4 @@ async def close(self):
654694
655695 for task in self ._background_tasks :
656696 task .cancel ()
657-
658697 await asyncio .wait (self ._background_tasks )
0 commit comments