Skip to content

Commit 3579d4a

Browse files
committed
async topic tx with listener pattern
1 parent e1629aa commit 3579d4a

File tree

10 files changed

+175
-17
lines changed

10 files changed

+175
-17
lines changed

tests/topics/test_topic_transactions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ async def callee(tx: ydb.aio.QueryTxContext):
3838
assert msg.data.decode() == "123"
3939

4040

41-
@pytest.mark.skip("Not implemented yet.")
41+
# @pytest.mark.skip("Not implemented yet.")
4242
class TestTopicTransactionalWriter:
4343
async def test_commit(self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO):
4444
async with ydb.aio.QuerySessionPool(driver) as pool:
4545

4646
async def callee(tx: ydb.aio.QueryTxContext):
4747
tx_writer = driver.topic_client.tx_writer(tx, topic_path)
48-
tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
48+
await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
4949

5050
await pool.retry_tx_async(callee)
5151

@@ -57,7 +57,7 @@ async def test_rollback(self, driver: ydb.aio.Driver, topic_path, topic_reader:
5757

5858
async def callee(tx: ydb.aio.QueryTxContext):
5959
tx_writer = driver.topic_client.tx_writer(tx, topic_path)
60-
tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
60+
await tx_writer.write(ydb.TopicWriterMessage(data="123".encode()))
6161

6262
await tx.rollback()
6363

ydb/_grpc/grpcwrapper/ydb_topic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def to_proto(self) -> ydb_topic_pb2.StreamWriteMessage.WriteRequest:
252252
proto.codec = self.codec
253253

254254
if self.tx_identity is not None:
255-
proto.tx = self.tx_identity.to_proto()
255+
proto.tx.CopyFrom(self.tx_identity.to_proto())
256256

257257
for message in self.messages:
258258
proto_mess = proto.messages.add()
@@ -314,6 +314,8 @@ def from_proto(cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.Wr
314314
)
315315
except ValueError:
316316
message_write_status = reason
317+
elif proto_ack.HasField("written_in_tx"):
318+
message_write_status = StreamWriteMessage.WriteResponse.WriteAck.StatusWrittenInTx()
317319
else:
318320
raise NotImplementedError("unexpected ack status")
319321

@@ -326,6 +328,9 @@ def from_proto(cls, proto_ack: ydb_topic_pb2.StreamWriteMessage.WriteResponse.Wr
326328
class StatusWritten:
327329
offset: int
328330

331+
class StatusWrittenInTx:
332+
pass
333+
329334
@dataclass
330335
class StatusSkipped:
331336
reason: "StreamWriteMessage.WriteResponse.WriteAck.StatusSkipped.Reason"

ydb/_topic_writer/topic_writer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,12 @@ class Written:
5454
class Skipped:
5555
pass
5656

57+
@dataclass(eq=True)
58+
class WrittenInTx:
59+
pass
60+
5761

58-
PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped]
62+
PublicWriteResultTypes = Union[PublicWriteResult.Written, PublicWriteResult.Skipped, PublicWriteResult.WrittenInTx]
5963

6064

6165
class WriterSettings(PublicWriterSettings):

ydb/_topic_writer/topic_writer_asyncio.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
GrpcWrapperAsyncIO,
4545
)
4646

47+
from ..query.base import TxListenerAsyncIO
48+
4749
if typing.TYPE_CHECKING:
4850
from ..query.transaction import BaseQueryTxContext
4951

@@ -168,12 +170,12 @@ async def wait_init(self) -> PublicWriterInitInfo:
168170
return await self._reconnector.wait_init()
169171

170172

171-
class TxWriterAsyncIO(WriterAsyncIO):
172-
_tx: object
173+
class TxWriterAsyncIO(WriterAsyncIO, TxListenerAsyncIO):
174+
_tx: "BaseQueryTxContext"
173175

174176
def __init__(
175177
self,
176-
tx,
178+
tx: "BaseQueryTxContext",
177179
driver: SupportedDriverType,
178180
settings: PublicWriterSettings,
179181
_client=None,
@@ -183,6 +185,13 @@ def __init__(
183185
self._closed = False
184186
self._reconnector = WriterAsyncIOReconnector(driver=driver, settings=WriterSettings(settings), tx=self._tx)
185187
self._parent = _client
188+
self._tx._add_listener(self)
189+
190+
async def _on_before_commit(self):
191+
await self.close()
192+
193+
async def _on_before_rollback(self):
194+
await self.close()
186195

187196

188197
class WriterAsyncIOReconnector:
@@ -560,6 +569,8 @@ def _handle_receive_ack(self, ack):
560569
result = PublicWriteResult.Skipped()
561570
elif isinstance(status, write_ack_msg.StatusWritten):
562571
result = PublicWriteResult.Written(offset=status.offset)
572+
elif isinstance(status, write_ack_msg.StatusWrittenInTx):
573+
result = PublicWriteResult.WrittenInTx()
563574
else:
564575
raise TopicWriterError("internal error - receive unexpected ack message.")
565576
message_future.set_result(result)
@@ -575,6 +586,7 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"):
575586

576587
while True:
577588
m = await self._new_messages.get() # type: InternalMessage
589+
print("NEW MESSAGE")
578590
if m.seq_no > last_seq_no:
579591
writer.write([m])
580592
except asyncio.CancelledError:
@@ -606,6 +618,7 @@ async def flush(self):
606618

607619
# wait last message
608620
await asyncio.wait(self._messages_future)
621+
print("ALL MESSAGES WERE SENT TO SERVER")
609622

610623

611624
class WriterAsyncIOStream:

ydb/aio/query/pool.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ async def retry_tx_async(
158158
async def wrapped_callee():
159159
async with self.checkout() as session:
160160
async with session.transaction(tx_mode=tx_mode) as tx:
161+
if tx_mode.name in ["serializable_read_write", "snapshot_read_only"]:
162+
await tx.begin()
161163
result = await callee(tx, *args, **kwargs)
162164
await tx.commit()
163165
return result

ydb/aio/query/transaction.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ async def begin(self, settings: Optional[BaseRequestSettings] = None) -> "QueryT
5757
await self._begin_call(settings)
5858
return self
5959

60+
@base.with_async_transaction_events
6061
async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None:
6162
"""Calls commit on a transaction if it is open otherwise is no-op. If transaction execution
6263
failed then this method raises PreconditionFailed.
@@ -65,7 +66,7 @@ async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None:
6566
6667
:return: A committed transaction or exception if commit is failed
6768
"""
68-
if self._tx_state._already_in(QueryTxStateEnum.COMMITTED):
69+
if self._tx_state._should_skip(QueryTxStateEnum.COMMITTED):
6970
return
7071

7172
if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:
@@ -76,6 +77,7 @@ async def commit(self, settings: Optional[BaseRequestSettings] = None) -> None:
7677

7778
await self._commit_call(settings)
7879

80+
@base.with_async_transaction_events
7981
async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None:
8082
"""Calls rollback on a transaction if it is open otherwise is no-op. If transaction execution
8183
failed then this method raises PreconditionFailed.
@@ -84,7 +86,7 @@ async def rollback(self, settings: Optional[BaseRequestSettings] = None) -> None
8486
8587
:return: A committed transaction or exception if commit is failed
8688
"""
87-
if self._tx_state._already_in(QueryTxStateEnum.ROLLBACKED):
89+
if self._tx_state._should_skip(QueryTxStateEnum.ROLLBACKED):
8890
return
8991

9092
if self._tx_state._state == QueryTxStateEnum.NOT_INITIALIZED:

ydb/query/base.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import abc
2+
import asyncio
23
import enum
34
import functools
45

@@ -196,3 +197,116 @@ def wrap_execute_query_response(
196197
return convert.ResultSet.from_message(response_pb.result_set, settings)
197198

198199
return None
200+
201+
202+
class TxListener:
203+
def _on_before_commit(self):
204+
pass
205+
206+
def _on_after_commit(self, exc: typing.Optional[BaseException]):
207+
pass
208+
209+
def _on_before_rollback(self):
210+
pass
211+
212+
def _on_after_rollback(self, exc: typing.Optional[BaseException]):
213+
pass
214+
215+
216+
class TxListenerAsyncIO:
217+
async def _on_before_commit(self):
218+
pass
219+
220+
async def _on_after_commit(self, exc: typing.Optional[BaseException]):
221+
pass
222+
223+
async def _on_before_rollback(self):
224+
pass
225+
226+
async def _on_after_rollback(self, exc: typing.Optional[BaseException]):
227+
pass
228+
229+
230+
def with_transaction_events(method):
231+
@functools.wraps(method)
232+
def wrapper(self, *args, **kwargs):
233+
method_name = method.__name__
234+
before_event = f"_on_before_{method_name}"
235+
after_event = f"_on_after_{method_name}"
236+
237+
self._notify_listeners_sync(before_event)
238+
239+
try:
240+
result = method(self, *args, **kwargs)
241+
242+
self._notify_listeners_sync(after_event, exc=None)
243+
244+
return result
245+
except BaseException as e:
246+
self._notify_listeners_sync(after_event, exc=e)
247+
raise
248+
249+
return wrapper
250+
251+
252+
def with_async_transaction_events(method):
253+
@functools.wraps(method)
254+
async def wrapper(self, *args, **kwargs):
255+
method_name = method.__name__
256+
before_event = f"_on_before_{method_name}"
257+
after_event = f"_on_after_{method_name}"
258+
259+
await self._notify_listeners_async(before_event)
260+
261+
try:
262+
result = await method(self, *args, **kwargs)
263+
264+
await self._notify_listeners_async(after_event, exc=None)
265+
266+
return result
267+
except BaseException as e:
268+
await self._notify_listeners_async(after_event, exc=e)
269+
raise
270+
271+
return wrapper
272+
273+
274+
class ListenerHandlerMixin:
275+
def _init_listener_handler(self):
276+
self.listeners = []
277+
278+
def _add_listener(self, listener):
279+
if listener not in self.listeners:
280+
self.listeners.append(listener)
281+
return self
282+
283+
def _remove_listener(self, listener):
284+
if listener in self.listeners:
285+
self.listeners.remove(listener)
286+
return self
287+
288+
def _clear_listeners(self):
289+
self.listeners.clear()
290+
return self
291+
292+
def _notify_sync_listeners(self, event_name: str, **kwargs) -> None:
293+
for listener in self.listeners:
294+
if isinstance(listener, TxListener) and hasattr(listener, event_name):
295+
getattr(listener, event_name)(**kwargs)
296+
297+
async def _notify_async_listeners(self, event_name: str, **kwargs) -> None:
298+
coros = []
299+
for listener in self.listeners:
300+
if isinstance(listener, TxListenerAsyncIO) and hasattr(listener, event_name):
301+
coros.append(getattr(listener, event_name)(**kwargs))
302+
303+
if coros:
304+
await asyncio.gather(*coros)
305+
306+
def _notify_listeners_sync(self, event_name: str, **kwargs) -> None:
307+
self._notify_sync_listeners(event_name, **kwargs)
308+
309+
async def _notify_listeners_async(self, event_name: str, **kwargs) -> None:
310+
# self._notify_sync_listeners(event_name, **kwargs)
311+
312+
await self._notify_async_listeners(event_name, **kwargs)

ydb/query/pool.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ def retry_tx_sync(
167167
def wrapped_callee():
168168
with self.checkout(timeout=retry_settings.max_session_acquire_timeout) as session:
169169
with session.transaction(tx_mode=tx_mode) as tx:
170+
if tx_mode.name in ["serializable_read_write", "snapshot_read_only"]:
171+
tx.begin()
170172
result = callee(tx, *args, **kwargs)
171173
tx.commit()
172174
return result

0 commit comments

Comments
 (0)