55import gzip
66import typing
77from asyncio import Task
8- from collections import OrderedDict
8+ from collections import OrderedDict , defaultdict
99from typing import Optional , Set , Dict , Union , Callable
1010
1111import ydb
@@ -140,7 +140,8 @@ async def receive_batch_with_tx(
140140 use asyncio.wait_for for wait with timeout.
141141 """
142142 await self ._reconnector .wait_message ()
143- return await self ._reconnector .receive_batch_with_tx_nowait (
143+ tx ._add_listener (self )
144+ return self ._reconnector .receive_batch_with_tx_nowait (
144145 tx ,
145146 max_messages = max_messages ,
146147 )
@@ -177,11 +178,14 @@ async def close(self, flush: bool = True):
177178 self ._closed = True
178179 await self ._reconnector .close (flush )
179180
180- def _on_after_commit (self , exc ):
181- return super (). _on_after_commit ( exc )
181+ async def _on_before_commit (self , tx ):
182+ await self . _reconnector . _on_before_commit ( tx )
182183
183- def _on_after_rollback (self , exc ):
184- return super ()._on_after_rollback (exc )
184+ async def _on_after_commit (self , tx , exc ):
185+ await self ._reconnector ._on_after_commit (tx , exc )
186+
187+ async def _on_after_rollback (self , tx , exc ):
188+ await self ._reconnector ._on_after_rollback (tx , exc )
185189
186190
187191class ReaderReconnector :
@@ -195,8 +199,10 @@ class ReaderReconnector:
195199 _state_changed : asyncio .Event
196200 _stream_reader : Optional ["ReaderStream" ]
197201 _first_error : asyncio .Future [YdbError ]
198- _batches_to_commit : asyncio .Queue
199- _wait_executor : Optional [concurrent .futures .ThreadPoolExecutor ]
202+
203+ _batches_to_commit_with_tx : asyncio .Queue
204+ _tx_to_batches : typing .Dict [str , typing .List [datatypes .PublicBatch ]]
205+ _wait_executor : concurrent .futures .ThreadPoolExecutor
200206
201207 def __init__ (self , driver : Driver , settings : topic_reader .PublicReaderSettings ):
202208 self ._id = self ._static_reader_reconnector_counter .inc_and_get ()
@@ -210,7 +216,9 @@ def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
210216 self ._first_error = asyncio .get_running_loop ().create_future ()
211217 self ._wait_executor = concurrent .futures .ThreadPoolExecutor (max_workers = 1 )
212218
213- self ._batches_to_commit = asyncio .Queue ()
219+ self ._batches_to_commit_with_tx = asyncio .Queue ()
220+ self ._tx_to_batches = defaultdict (list )
221+ self ._background_tasks .add (asyncio .create_task (self ._update_offsets_in_tx_loop ()))
214222
215223 async def _connection_loop (self ):
216224 attempt = 0
@@ -263,19 +271,21 @@ def receive_message_nowait(self):
263271 def commit (self , batch : datatypes .ICommittable ) -> datatypes .PartitionSession .CommitAckWaiter :
264272 return self ._stream_reader .commit (batch )
265273
266- async def _commit_with_tx (self , tx : "BaseQueryTxContext" , batch : datatypes .ICommittable ) -> None :
267- pass
268-
269- async def receive_batch_with_tx_nowait (self , tx : "BaseQueryTxContext" , max_messages : Optional [int ] = None ):
274+ def receive_batch_with_tx_nowait (self , tx : "BaseQueryTxContext" , max_messages : Optional [int ] = None ):
270275 batch = self .receive_batch_nowait (max_messages = max_messages )
271- tx ._add_listener (batch )
272- await self ._update_offsets_in_tx_call (self ._driver , tx , batch )
276+ self ._tx_to_batches [tx .tx_id ].append (batch )
277+
278+ self ._batches_to_commit_with_tx .put_nowait ((tx , batch ))
279+
280+ print ("batch recieved" )
281+
273282 return batch
274- # self._batches_to_commit.put_nowait((tx, batch))
275283
276284 async def _update_offsets_in_tx_loop (self ):
277285 while True :
278- tx , batch = self ._batches_to_commit .get ()
286+ print ("_update_offsets_in_tx_loop" )
287+
288+ tx , batch = await self ._batches_to_commit_with_tx .get ()
279289 await self ._update_offsets_in_tx_call (self ._driver , tx , batch )
280290
281291 async def _update_offsets_in_tx_call (self , driver : SupportedDriverType , tx : "BaseQueryTxContext" , batch : datatypes .ICommittable ) -> None :
@@ -309,9 +319,25 @@ async def _update_offsets_in_tx_call(self, driver: SupportedDriverType, tx: "Bas
309319 else :
310320 res = await to_thread (driver , * args , executor = self ._wait_executor )
311321
322+ batch ._commited_with_tx = True
323+
312324 return res
313325 except BaseException as e :
314- self ._set_first_error (e )
326+ self ._stream_reader ._set_first_error (e )
327+
328+ async def _ensure_all_batches_commited_with_tx (self , tx : "BaseQueryTxContext" ):
329+ while True :
330+ print ("_ensure_all_batches_commited_with_tx" )
331+ if tx .tx_id not in self ._tx_to_batches :
332+ # we should not be here
333+ return True
334+ batches = self ._tx_to_batches .get (tx .tx_id )
335+ everything_commited = True
336+ for batch in batches :
337+ everything_commited = everything_commited and batch ._commited_with_tx
338+ if everything_commited :
339+ return True
340+ await asyncio .sleep (0.001 )
315341
316342 async def close (self , flush : bool ):
317343 if self ._stream_reader :
@@ -336,6 +362,26 @@ def _set_first_error(self, err: issues.Error):
336362 # skip if already has result
337363 pass
338364
365+ async def _on_before_commit (self , tx ):
366+ print ("on before commit" )
367+ await asyncio .wait_for (self ._ensure_all_batches_commited_with_tx (tx ), 1 )
368+ pass
369+
370+ async def _on_after_commit (self , tx , exc ):
371+ print (f"on after commit, exc = { exc is not None } " )
372+
373+ if exc :
374+ self ._stream_reader ._set_first_error (exc )
375+ for batch in self ._tx_to_batches [tx .tx_id ]:
376+ batch ._partition_session .committed_offset = max (batch ._partition_session .committed_offset , batch ._commit_get_offsets_range ().end )
377+ del self ._tx_to_batches [tx .tx_id ]
378+
379+ async def _on_after_rollback (self , tx , exc ):
380+ print (f"on after rollback, exc = { exc is not None } " )
381+ print (exc )
382+ exc = exc if exc is not None else issues .InternalError ("tx rollback failed" )
383+ self ._stream_reader ._set_first_error (exc )
384+
339385
340386class ReaderStream :
341387 _static_id_counter = AtomicCounter ()
0 commit comments