@@ -84,7 +84,7 @@ def __init__(
8484 ):
8585 self ._loop = asyncio .get_running_loop ()
8686 self ._closed = False
87- self ._reconnector = ReaderReconnector (driver , settings )
87+ self ._reconnector = ReaderReconnector (driver , settings , self . _loop )
8888 self ._parent = _parent
8989
9090 async def __aenter__ (self ):
@@ -190,18 +190,24 @@ class ReaderReconnector:
190190 _first_error : asyncio .Future [YdbError ]
191191 _tx_to_batches_map : Dict [str , typing .List [datatypes .PublicBatch ]]
192192
193- def __init__ (self , driver : Driver , settings : topic_reader .PublicReaderSettings ):
193+ def __init__ (
194+ self ,
195+ driver : Driver ,
196+ settings : topic_reader .PublicReaderSettings ,
197+ loop : Optional [asyncio .AbstractEventLoop ] = None ,
198+ ):
194199 self ._id = self ._static_reader_reconnector_counter .inc_and_get ()
195200 self ._settings = settings
196201 self ._driver = driver
202+ self ._loop = loop if loop is not None else asyncio .get_running_loop ()
197203 self ._background_tasks = set ()
198204
199205 self ._state_changed = asyncio .Event ()
200206 self ._stream_reader = None
201207 self ._background_tasks .add (asyncio .create_task (self ._connection_loop ()))
202208 self ._first_error = asyncio .get_running_loop ().create_future ()
203209
204- self ._tx_to_batches_map = defaultdict ( list )
210+ self ._tx_to_batches_map = dict ( )
205211
206212 async def _connection_loop (self ):
207213 attempt = 0
@@ -254,22 +260,23 @@ def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: O
254260 max_messages = max_messages ,
255261 )
256262
257- self ._init_tx_if_needed (tx )
263+ self ._init_tx (tx )
258264
259265 self ._tx_to_batches_map [tx .tx_id ].append (batch )
260266
261- tx ._add_callback (TxEvent .AFTER_COMMIT , batch ._update_partition_offsets , None ) # probably should be current loop
267+ tx ._add_callback (TxEvent .AFTER_COMMIT , batch ._update_partition_offsets , self . _loop )
262268
263269 return batch
264270
265271 def receive_message_nowait (self ):
266272 return self ._stream_reader .receive_message_nowait ()
267273
268- def _init_tx_if_needed (self , tx : "BaseQueryTxContext" ):
274+ def _init_tx (self , tx : "BaseQueryTxContext" ):
269275 if tx .tx_id not in self ._tx_to_batches_map : # Init tx callbacks
270- tx ._add_callback (TxEvent .BEFORE_COMMIT , self ._commit_batches_with_tx , None )
271- tx ._add_callback (TxEvent .AFTER_COMMIT , self ._handle_after_tx_commit , None )
272- tx ._add_callback (TxEvent .AFTER_ROLLBACK , self ._handle_after_tx_rollback , None )
276+ self ._tx_to_batches_map [tx .tx_id ] = []
277+ tx ._add_callback (TxEvent .BEFORE_COMMIT , self ._commit_batches_with_tx , self ._loop )
278+ tx ._add_callback (TxEvent .AFTER_COMMIT , self ._handle_after_tx_commit , self ._loop )
279+ tx ._add_callback (TxEvent .AFTER_ROLLBACK , self ._handle_after_tx_rollback , self ._loop )
273280
274281 async def _commit_batches_with_tx (self , tx : "BaseQueryTxContext" ):
275282 grouped_batches = defaultdict (lambda : defaultdict (list ))
0 commit comments