Skip to content

Commit 84ac0a9

Browse files
committed
Fix review comments
1 parent c7c1f9f commit 84ac0a9

File tree

5 files changed

+169
-48
lines changed

5 files changed

+169
-48
lines changed
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import asyncio
2+
import argparse
3+
import logging
4+
import ydb
5+
6+
7+
async def connect(endpoint: str, database: str) -> ydb.aio.Driver:
    """Create an asyncio YDB driver and wait until it reports ready.

    Credentials are obtained from ``ydb.credentials_from_env_variables()``.

    Args:
        endpoint: Endpoint URL passed to ``ydb.DriverConfig``.
        database: Database name passed to ``ydb.DriverConfig``.

    Returns:
        The driver after ``driver.wait(5, fail_fast=True)`` has completed.
        The caller owns the driver and is responsible for closing it.

    Raises:
        Exception: whatever ``driver.wait`` raises; the driver is stopped
            before the error is propagated.
    """
    config = ydb.DriverConfig(endpoint=endpoint, database=database)
    config.credentials = ydb.credentials_from_env_variables()
    driver = ydb.aio.Driver(config)
    try:
        await driver.wait(5, fail_fast=True)
    except Exception:
        # The caller never receives the driver on failure, so release it
        # here instead of leaking it.
        await driver.stop()
        raise
    return driver
13+
14+
15+
async def create_topic(driver: ydb.aio.Driver, topic: str, consumer: str):
    """Drop ``topic`` (ignoring ``ydb.SchemeError``) and create it anew.

    Args:
        driver: Connected asyncio driver whose ``topic_client`` is used.
        topic: Path of the topic to recreate.
        consumer: Name of the single consumer registered on the new topic.
    """
    topic_client = driver.topic_client

    # Best-effort drop: a SchemeError is swallowed
    # (presumably it means the topic does not exist yet -- TODO confirm).
    try:
        await topic_client.drop_topic(topic)
    except ydb.SchemeError:
        pass

    await topic_client.create_topic(topic, consumers=[consumer])
22+
23+
24+
async def write_with_tx_example(driver: ydb.aio.Driver, topic: str, message_count: int = 10):
    """Write query results to ``topic`` through a transactional topic writer.

    For every ``i`` in ``range(message_count)`` the query ``select {i} as res;``
    is executed in the transaction, and each returned value is written to the
    topic via a writer bound to that same transaction.

    Args:
        driver: Connected asyncio driver.
        topic: Path of the topic to write to.
        message_count: Number of queries to execute.
    """

    async def write_messages(tx: ydb.aio.QueryTxContext):
        # The writer is created per transaction attempt and bound to ``tx``.
        tx_writer: ydb.TopicTxWriterAsyncIO = driver.topic_client.tx_writer(tx, topic)

        for index in range(message_count):
            stream = await tx.execute(query=f"select {index} as res;")
            async with stream as result_stream:
                async for result_set in result_stream:
                    value = result_set.rows[0]["res"]
                    await tx_writer.write(ydb.TopicWriterMessage(str(value)))
                    print(f"Message {value} was written with tx.")

    async with ydb.aio.QuerySessionPool(driver) as session_pool:
        await session_pool.retry_tx_async(write_messages)
38+
39+
40+
async def read_with_tx_example(driver: ydb.aio.Driver, topic: str, consumer: str, message_count: int = 10):
    """Read messages from ``topic`` one at a time, each in its own transaction.

    Args:
        driver: Connected asyncio driver.
        topic: Path of the topic to read from.
        consumer: Consumer name used by the reader.
        message_count: Number of transactions to run; each receives a batch
            of at most one message and prints it.
    """
    async with driver.topic_client.reader(topic, consumer) as reader:

        # Defined once: the callback only depends on ``reader``, not on the
        # loop iteration.
        async def read_one(tx: ydb.aio.QueryTxContext):
            batch = await reader.receive_batch_with_tx(tx, max_messages=1)
            first_message = batch.messages[0]
            print(f"Message {first_message.data.decode()} was read with tx.")

        async with ydb.aio.QuerySessionPool(driver) as session_pool:
            for _ in range(message_count):
                await session_pool.retry_tx_async(read_one)
50+
51+
52+
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser for the example and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(
        description="""YDB topic basic example.\n""",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("-d", "--database", default="/local", help="Name of the database to use")
    parser.add_argument("-e", "--endpoint", default="grpc://localhost:2136", help="Endpoint url to use")
    parser.add_argument("-p", "--path", default="test-topic", help="Topic name")
    parser.add_argument("-c", "--consumer", default="consumer", help="Consumer name")
    # ``store_true`` flags default to False.
    parser.add_argument("-v", "--verbose", action="store_true")
    parser.add_argument(
        "-s",
        "--skip-drop-and-create-topic",
        action="store_true",
        help="Use existed topic, skip remove it and re-create",
    )
    return parser.parse_args()


async def main():
    """Entry point: parse arguments, connect, then run the write and read examples.

    Unless ``--skip-drop-and-create-topic`` is given, the topic is dropped and
    re-created before the examples run.
    """
    args = _parse_args()

    if args.verbose:
        # Send DEBUG records of the "topicexample" logger to stderr.
        example_logger = logging.getLogger("topicexample")
        example_logger.setLevel(logging.DEBUG)
        example_logger.addHandler(logging.StreamHandler())

    driver_cm = await connect(args.endpoint, args.database)
    async with driver_cm as driver:
        recreate_topic = not args.skip_drop_and_create_topic
        if recreate_topic:
            await create_topic(driver, args.path, args.consumer)

        await write_with_tx_example(driver, args.path)
        await read_with_tx_example(driver, args.path, args.consumer)
83+
84+
85+
if __name__ == "__main__":
86+
asyncio.run(main())
Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,54 @@
1-
import asyncio
21
import argparse
32
import logging
43
import ydb
54

65

7-
async def connect(endpoint: str, database: str) -> ydb.aio.Driver:
6+
def connect(endpoint: str, database: str) -> ydb.Driver:
87
config = ydb.DriverConfig(endpoint=endpoint, database=database)
98
config.credentials = ydb.credentials_from_env_variables()
10-
driver = ydb.aio.Driver(config)
11-
await driver.wait(5, fail_fast=True)
9+
driver = ydb.Driver(config)
10+
driver.wait(5, fail_fast=True)
1211
return driver
1312

1413

15-
async def create_topic(driver: ydb.aio.Driver, topic: str, consumer: str):
14+
def create_topic(driver: ydb.Driver, topic: str, consumer: str):
1615
try:
17-
await driver.topic_client.drop_topic(topic)
16+
driver.topic_client.drop_topic(topic)
1817
except ydb.SchemeError:
1918
pass
2019

21-
await driver.topic_client.create_topic(topic, consumers=[consumer])
20+
driver.topic_client.create_topic(topic, consumers=[consumer])
2221

2322

24-
async def write_with_tx_example(driver: ydb.aio.Driver, topic: str, message_count: int = 10):
25-
async with ydb.aio.QuerySessionPool(driver) as session_pool:
23+
def write_with_tx_example(driver: ydb.Driver, topic: str, message_count: int = 10):
24+
with ydb.QuerySessionPool(driver) as session_pool:
2625

27-
async def callee(tx: ydb.aio.QueryTxContext):
28-
print(f"TX ID: {tx.tx_id}")
29-
print(f"TX STATE: {tx._tx_state._state.value}")
30-
tx_writer: ydb.TopicTxWriterAsyncIO = driver.topic_client.tx_writer(tx, topic)
31-
print(f"TX ID: {tx.tx_id}")
32-
print(f"TX STATE: {tx._tx_state._state.value}")
33-
for i in range(message_count):
34-
result_stream = await tx.execute(query=f"select {i} as res")
35-
messages = [result_set.rows[0]["res"] async for result_set in result_stream]
36-
37-
await tx_writer.write([ydb.TopicWriterMessage(data=str(message)) for message in messages])
26+
def callee(tx: ydb.QueryTxContext):
27+
tx_writer: ydb.TopicTxWriter = driver.topic_client.tx_writer(tx, topic)
3828

39-
print(f"Messages {messages} were written with tx.")
29+
for i in range(message_count):
30+
result_stream = tx.execute(query=f"select {i} as res;")
31+
for result_set in result_stream:
32+
message = str(result_set.rows[0]["res"])
33+
tx_writer.write(ydb.TopicWriterMessage(message))
34+
print(f"Message {message} was written with tx.")
4035

41-
await session_pool.retry_tx_async(callee)
36+
session_pool.retry_tx_sync(callee)
4237

4338

44-
async def read_with_tx_example(driver: ydb.aio.Driver, topic: str, consumer: str, message_count: int = 10):
45-
async with driver.topic_client.reader(topic, consumer) as reader:
46-
async with ydb.aio.QuerySessionPool(driver) as session_pool:
39+
def read_with_tx_example(driver: ydb.Driver, topic: str, consumer: str, message_count: int = 10):
40+
with driver.topic_client.reader(topic, consumer) as reader:
41+
with ydb.QuerySessionPool(driver) as session_pool:
4742
for _ in range(message_count):
4843

49-
async def callee(tx: ydb.aio.QueryTxContext):
50-
batch = await reader.receive_batch_with_tx(tx, max_messages=1)
51-
print(f"Messages {batch.messages[0].data} were read with tx.")
44+
def callee(tx: ydb.QueryTxContext):
45+
batch = reader.receive_batch_with_tx(tx, max_messages=1)
46+
print(f"Message {batch.messages[0].data.decode()} was read with tx.")
5247

53-
await session_pool.retry_tx_async(callee)
48+
session_pool.retry_tx_sync(callee)
5449

5550

56-
async def main():
51+
def main():
5752
parser = argparse.ArgumentParser(
5853
formatter_class=argparse.RawDescriptionHelpFormatter,
5954
description="""YDB topic basic example.\n""",
@@ -78,13 +73,13 @@ async def main():
7873
logger.setLevel(logging.DEBUG)
7974
logger.addHandler(logging.StreamHandler())
8075

81-
driver = await connect(args.endpoint, args.database)
82-
if not args.skip_drop_and_create_topic:
83-
await create_topic(driver, args.path, args.consumer)
76+
with connect(args.endpoint, args.database) as driver:
77+
if not args.skip_drop_and_create_topic:
78+
create_topic(driver, args.path, args.consumer)
8479

85-
await write_with_tx_example(driver, args.path)
86-
await read_with_tx_example(driver, args.path, args.consumer)
80+
write_with_tx_example(driver, args.path)
81+
read_with_tx_example(driver, args.path, args.consumer)
8782

8883

8984
if __name__ == "__main__":
90-
asyncio.run(main())
85+
main()

tests/topics/test_topic_transactions.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,19 @@ async def callee(tx: ydb.aio.QueryTxContext):
357357
with pytest.raises(asyncio.TimeoutError):
358358
await wait_for(topic_reader.receive_message(), 0.1)
359359

360+
async def test_writes_do_not_conflict_with_executes(self, driver: ydb.aio.Driver, topic_path):
    """Interleave ``tx.execute`` calls with transactional topic writes.

    Inside one transaction, three times over: run ``select 1`` and, while
    its result stream is still open, write a message through the tx writer.
    The test passes if the retried transaction completes without raising.
    """

    async def write_between_executes(tx: ydb.aio.QueryTxContext):
        writer = driver.topic_client.tx_writer(tx, topic_path)
        for _ in range(3):
            result_stream = await tx.execute("select 1")
            async with result_stream:
                await writer.write(ydb.TopicWriterMessage(data="123".encode()))

    async with ydb.aio.QuerySessionPool(driver) as pool:
        await pool.retry_tx_async(write_between_executes, retry_settings=DEFAULT_RETRY_SETTINGS)
372+
360373

361374
class TestTopicTransactionalWriterSync:
362375
def test_commit(self, driver_sync: ydb.Driver, topic_path, topic_reader_sync: ydb.TopicReader):
@@ -445,3 +458,16 @@ def callee(tx: ydb.QueryTxContext):
445458

446459
with pytest.raises(TimeoutError):
447460
topic_reader_sync.receive_message(timeout=DEFAULT_TIMEOUT)
461+
462+
def test_writes_do_not_conflict_with_executes(self, driver_sync: ydb.Driver, topic_path):
    """Interleave ``tx.execute`` calls with transactional topic writes (sync API).

    Inside one transaction, three times over: run ``select 1`` and, while
    its result stream is still open, write a message through the tx writer.
    The test passes if the retried transaction completes without raising.
    """

    def write_between_executes(tx: ydb.QueryTxContext):
        writer = driver_sync.topic_client.tx_writer(tx, topic_path)
        for _ in range(3):
            result_stream = tx.execute("select 1")
            with result_stream:
                writer.write(ydb.TopicWriterMessage(data="123".encode()))

    with ydb.QuerySessionPool(driver_sync) as pool:
        pool.retry_tx_sync(write_between_executes, retry_settings=DEFAULT_RETRY_SETTINGS)

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(
8484
):
8585
self._loop = asyncio.get_running_loop()
8686
self._closed = False
87-
self._reconnector = ReaderReconnector(driver, settings)
87+
self._reconnector = ReaderReconnector(driver, settings, self._loop)
8888
self._parent = _parent
8989

9090
async def __aenter__(self):
@@ -190,18 +190,24 @@ class ReaderReconnector:
190190
_first_error: asyncio.Future[YdbError]
191191
_tx_to_batches_map: Dict[str, typing.List[datatypes.PublicBatch]]
192192

193-
def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
193+
def __init__(
194+
self,
195+
driver: Driver,
196+
settings: topic_reader.PublicReaderSettings,
197+
loop: Optional[asyncio.AbstractEventLoop] = None,
198+
):
194199
self._id = self._static_reader_reconnector_counter.inc_and_get()
195200
self._settings = settings
196201
self._driver = driver
202+
self._loop = loop if loop is not None else asyncio.get_running_loop()
197203
self._background_tasks = set()
198204

199205
self._state_changed = asyncio.Event()
200206
self._stream_reader = None
201207
self._background_tasks.add(asyncio.create_task(self._connection_loop()))
202208
self._first_error = asyncio.get_running_loop().create_future()
203209

204-
self._tx_to_batches_map = defaultdict(list)
210+
self._tx_to_batches_map = dict()
205211

206212
async def _connection_loop(self):
207213
attempt = 0
@@ -254,22 +260,23 @@ def receive_batch_with_tx_nowait(self, tx: "BaseQueryTxContext", max_messages: O
254260
max_messages=max_messages,
255261
)
256262

257-
self._init_tx_if_needed(tx)
263+
self._init_tx(tx)
258264

259265
self._tx_to_batches_map[tx.tx_id].append(batch)
260266

261-
tx._add_callback(TxEvent.AFTER_COMMIT, batch._update_partition_offsets, None) # probably should be current loop
267+
tx._add_callback(TxEvent.AFTER_COMMIT, batch._update_partition_offsets, self._loop)
262268

263269
return batch
264270

265271
def receive_message_nowait(self):
266272
return self._stream_reader.receive_message_nowait()
267273

268-
def _init_tx_if_needed(self, tx: "BaseQueryTxContext"):
274+
def _init_tx(self, tx: "BaseQueryTxContext"):
269275
if tx.tx_id not in self._tx_to_batches_map: # Init tx callbacks
270-
tx._add_callback(TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx, None)
271-
tx._add_callback(TxEvent.AFTER_COMMIT, self._handle_after_tx_commit, None)
272-
tx._add_callback(TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback, None)
276+
self._tx_to_batches_map[tx.tx_id] = []
277+
tx._add_callback(TxEvent.BEFORE_COMMIT, self._commit_batches_with_tx, self._loop)
278+
tx._add_callback(TxEvent.AFTER_COMMIT, self._handle_after_tx_commit, self._loop)
279+
tx._add_callback(TxEvent.AFTER_ROLLBACK, self._handle_after_tx_rollback, self._loop)
273280

274281
async def _commit_batches_with_tx(self, tx: "BaseQueryTxContext"):
275282
grouped_batches = defaultdict(lambda: defaultdict(list))

ydb/_topic_writer/topic_writer_asyncio.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import concurrent.futures
33
import datetime
4-
import functools
54
import gzip
65
import typing
76
from collections import deque
@@ -186,6 +185,10 @@ def __init__(
186185
self._parent = _client
187186
self._is_implicit = _is_implicit
188187

188+
# Workaround: for reasons not yet understood, creating a partition can conflict with
189+
# other session operations, so the first write is tracked separately. May be removed later.
190+
self._first_write = True
191+
189192
tx._add_callback(TxEvent.BEFORE_COMMIT, self._on_before_commit, self._loop)
190193
tx._add_callback(TxEvent.BEFORE_ROLLBACK, self._on_before_rollback, self._loop)
191194

@@ -199,14 +202,18 @@ async def write(
199202
200203
For wait with timeout use asyncio.wait_for.
201204
"""
202-
await self.write_with_ack(messages)
205+
if self._first_write:
206+
self._first_write = False
207+
return await super().write_with_ack(messages)
208+
return await super().write(messages)
209+
203210

204211
async def _on_before_commit(self, tx: "BaseQueryTxContext"):
205212
if self._is_implicit:
206213
return
207-
await self.flush()
208214
await self.close()
209215

216+
210217
async def _on_before_rollback(self, tx: "BaseQueryTxContext"):
211218
if self._is_implicit:
212219
return

0 commit comments

Comments
 (0)