Skip to content

Commit a8c48ab

Browse files
committed
allow save link to parent client - for prevent stop it (and underlay executors) by GC.
1 parent b62d157 commit a8c48ab

File tree

7 files changed

+87
-45
lines changed

7 files changed

+87
-45
lines changed

tests/topics/test_topic_reader.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ async def test_read_batch(
1616

1717
await reader.close()
1818

19+
async def test_link_to_client(self, driver, topic_path, topic_consumer):
20+
reader = driver.topic_client.reader(topic_path, topic_consumer)
21+
assert reader._parent is driver.topic_client
22+
1923
async def test_read_message(
2024
self, driver, topic_path, topic_with_messages, topic_consumer
2125
):
@@ -84,6 +88,10 @@ def test_read_batch(
8488

8589
reader.close()
8690

91+
def test_link_to_client(self, driver_sync, topic_path, topic_consumer):
92+
reader = driver_sync.topic_client.reader(topic_path, topic_consumer)
93+
assert reader._parent is driver_sync.topic_client
94+
8795
def test_read_message(
8896
self, driver_sync, topic_path, topic_with_messages, topic_consumer
8997
):

tests/topics/test_topic_writer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ async def test_wait_last_seqno(self, driver: ydb.aio.Driver, topic_path):
2727
init_info = await writer2.wait_init()
2828
assert init_info.last_seqno == 5
2929

30+
async def test_link_to_client(self, driver, topic_path, topic_consumer):
31+
writer = driver.topic_client.writer(topic_path)
32+
assert writer._parent is driver.topic_client
33+
3034
async def test_random_producer_id(
3135
self, driver: ydb.aio.Driver, topic_path, topic_reader: ydb.TopicReaderAsyncIO
3236
):
@@ -138,6 +142,10 @@ def test_auto_flush_on_close(self, driver_sync: ydb.Driver, topic_path):
138142
init_info = writer.wait_init()
139143
assert init_info.last_seqno == last_seqno
140144

145+
def test_link_to_client(self, driver_sync, topic_path, topic_consumer):
146+
writer = driver_sync.topic_client.writer(topic_path)
147+
assert writer._parent is driver_sync.topic_client
148+
141149
def test_random_producer_id(
142150
self,
143151
driver_sync: ydb.aio.Driver,

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Codec,
2727
)
2828
from .._errors import check_retriable_error
29+
from .. import topic
2930

3031

3132
class TopicReaderError(YdbError):
@@ -61,11 +62,19 @@ class PublicAsyncIOReader:
6162
_loop: asyncio.AbstractEventLoop
6263
_closed: bool
6364
_reconnector: ReaderReconnector
65+
_parent: typing.Any # need for prevent close parent client by GC
6466

65-
def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
67+
def __init__(
68+
self,
69+
driver: Driver,
70+
settings: topic_reader.PublicReaderSettings,
71+
*,
72+
_parent=None,
73+
):
6674
self._loop = asyncio.get_running_loop()
6775
self._closed = False
6876
self._reconnector = ReaderReconnector(driver, settings)
77+
self._parent = _parent
6978

7079
async def __aenter__(self):
7180
return self
@@ -78,7 +87,7 @@ def __del__(self):
7887
self._loop.create_task(self.close(flush=False), name="close reader")
7988

8089
async def receive_batch(
81-
self,
90+
self,
8291
) -> typing.Union[datatypes.PublicBatch, None]:
8392
"""
8493
Get one messages batch from reader.
@@ -99,7 +108,7 @@ async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]:
99108
return self._reconnector.receive_message_nowait()
100109

101110
def commit(
102-
self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]
111+
self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]
103112
):
104113
"""
105114
Write commit message to a buffer.
@@ -110,7 +119,7 @@ def commit(
110119
self._reconnector.commit(batch)
111120

112121
async def commit_with_ack(
113-
self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]
122+
self, batch: typing.Union[datatypes.PublicMessage, datatypes.PublicBatch]
114123
):
115124
"""
116125
write commit message to a buffer and wait ack from the server.
@@ -195,7 +204,7 @@ def receive_message_nowait(self):
195204
return self._stream_reader.receive_message_nowait()
196205

197206
def commit(
198-
self, batch: datatypes.ICommittable
207+
self, batch: datatypes.ICommittable
199208
) -> datatypes.PartitionSession.CommitAckWaiter:
200209
return self._stream_reader.commit(batch)
201210

@@ -254,10 +263,10 @@ class ReaderStream:
254263
_get_token_function: Callable[[], str]
255264

256265
def __init__(
257-
self,
258-
reader_reconnector_id: int,
259-
settings: topic_reader.PublicReaderSettings,
260-
get_token_function: Optional[Callable[[], str]] = None,
266+
self,
267+
reader_reconnector_id: int,
268+
settings: topic_reader.PublicReaderSettings,
269+
get_token_function: Optional[Callable[[], str]] = None,
261270
):
262271
self._loop = asyncio.get_running_loop()
263272
self._id = ReaderStream._static_id_counter.inc_and_get()
@@ -286,9 +295,9 @@ def __init__(
286295

287296
@staticmethod
288297
async def create(
289-
reader_reconnector_id: int,
290-
driver: SupportedDriverType,
291-
settings: topic_reader.PublicReaderSettings,
298+
reader_reconnector_id: int,
299+
driver: SupportedDriverType,
300+
settings: topic_reader.PublicReaderSettings,
292301
) -> "ReaderStream":
293302
stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto)
294303

@@ -306,7 +315,7 @@ async def create(
306315
return reader
307316

308317
async def _start(
309-
self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest
318+
self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest
310319
):
311320
if self._started:
312321
raise TopicReaderError("Double start ReaderStream")
@@ -372,7 +381,7 @@ def receive_message_nowait(self):
372381
return message
373382

374383
def commit(
375-
self, batch: datatypes.ICommittable
384+
self, batch: datatypes.ICommittable
376385
) -> datatypes.PartitionSession.CommitAckWaiter:
377386
partition_session = batch._commit_get_partition_session()
378387

@@ -426,19 +435,19 @@ async def _read_messages_loop(self):
426435
self._on_read_response(message.server_message)
427436

428437
elif isinstance(
429-
message.server_message, StreamReadMessage.CommitOffsetResponse
438+
message.server_message, StreamReadMessage.CommitOffsetResponse
430439
):
431440
self._on_commit_response(message.server_message)
432441

433442
elif isinstance(
434-
message.server_message,
435-
StreamReadMessage.StartPartitionSessionRequest,
443+
message.server_message,
444+
StreamReadMessage.StartPartitionSessionRequest,
436445
):
437446
self._on_start_partition_session(message.server_message)
438447

439448
elif isinstance(
440-
message.server_message,
441-
StreamReadMessage.StopPartitionSessionRequest,
449+
message.server_message,
450+
StreamReadMessage.StopPartitionSessionRequest,
442451
):
443452
self._on_partition_session_stop(message.server_message)
444453

@@ -470,12 +479,12 @@ async def _update_token(self, token: str):
470479
self._update_token_event.clear()
471480

472481
def _on_start_partition_session(
473-
self, message: StreamReadMessage.StartPartitionSessionRequest
482+
self, message: StreamReadMessage.StartPartitionSessionRequest
474483
):
475484
try:
476485
if (
477-
message.partition_session.partition_session_id
478-
in self._partition_sessions
486+
message.partition_session.partition_session_id
487+
in self._partition_sessions
479488
):
480489
raise TopicReaderError(
481490
"Double start partition session: %s"
@@ -506,7 +515,7 @@ def _on_start_partition_session(
506515
self._set_first_error(err)
507516

508517
def _on_partition_session_stop(
509-
self, message: StreamReadMessage.StopPartitionSessionRequest
518+
self, message: StreamReadMessage.StopPartitionSessionRequest
510519
):
511520
if message.partition_session_id not in self._partition_sessions:
512521
# may if receive stop partition with graceful=false after response on stop partition
@@ -554,7 +563,7 @@ def _buffer_release_bytes(self, bytes_size):
554563
)
555564

556565
def _read_response_to_batches(
557-
self, message: StreamReadMessage.ReadResponse
566+
self, message: StreamReadMessage.ReadResponse
558567
) -> typing.List[datatypes.PublicBatch]:
559568
batches = []
560569

@@ -564,7 +573,7 @@ def _read_response_to_batches(
564573

565574
bytes_per_batch = message.bytes_size // batch_count
566575
additional_bytes_to_last_batch = (
567-
message.bytes_size - bytes_per_batch * batch_count
576+
message.bytes_size - bytes_per_batch * batch_count
568577
)
569578

570579
for partition_data in message.partition_data:

ydb/_topic_reader/topic_reader_sync.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@ class TopicReaderSync:
2525
_caller: CallFromSyncToAsync
2626
_async_reader: PublicAsyncIOReader
2727
_closed: bool
28+
_parent: typing.Any # need for prevent stop the client by GC
2829

2930
def __init__(
3031
self,
3132
driver: SupportedDriverType,
3233
settings: PublicReaderSettings,
3334
*,
3435
eventloop: Optional[asyncio.AbstractEventLoop] = None,
36+
_parent=None, # need for prevent stop the client by GC
3537
):
3638
self._closed = False
3739

@@ -49,6 +51,8 @@ async def create_reader():
4951
create_reader(), loop
5052
).result()
5153

54+
self._parent = _parent
55+
5256
def __del__(self):
5357
self.close(flush=False)
5458

ydb/_topic_writer/topic_writer_asyncio.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,29 @@
4040
GrpcWrapperAsyncIO,
4141
)
4242

43+
from .. import topic
44+
4345
logger = logging.getLogger(__name__)
4446

4547

4648
class WriterAsyncIO:
4749
_loop: asyncio.AbstractEventLoop
4850
_reconnector: "WriterAsyncIOReconnector"
4951
_closed: bool
52+
_parent: typing.Any # need for prevent close parent client by GC
5053

51-
def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings):
54+
def __init__(
55+
self,
56+
driver: SupportedDriverType,
57+
settings: PublicWriterSettings,
58+
_client=None,
59+
):
5260
self._loop = asyncio.get_running_loop()
5361
self._closed = False
5462
self._reconnector = WriterAsyncIOReconnector(
5563
driver=driver, settings=WriterSettings(settings)
5664
)
65+
self._parent = _client
5766

5867
async def __aenter__(self) -> "WriterAsyncIO":
5968
return self

ydb/_topic_writer/topic_writer_sync.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import typing
45
from concurrent.futures import Future
56
from typing import Union, List, Optional
67

8+
from .. import topic
79
from .._grpc.grpcwrapper.common_utils import SupportedDriverType
810
from .topic_writer import (
911
PublicWriterSettings,
@@ -25,13 +27,15 @@ class WriterSync:
2527
_caller: CallFromSyncToAsync
2628
_async_writer: WriterAsyncIO
2729
_closed: bool
30+
_parent: typing.Any # need for prevent close parent client by GC
2831

2932
def __init__(
30-
self,
31-
driver: SupportedDriverType,
32-
settings: PublicWriterSettings,
33-
*,
34-
eventloop: Optional[asyncio.AbstractEventLoop] = None,
33+
self,
34+
driver: SupportedDriverType,
35+
settings: PublicWriterSettings,
36+
*,
37+
eventloop: Optional[asyncio.AbstractEventLoop] = None,
38+
_parent=None
3539
):
3640

3741
self._closed = False
@@ -49,6 +53,7 @@ async def create_async_writer():
4953
self._async_writer = self._caller.safe_call_with_result(
5054
create_async_writer(), None
5155
)
56+
self._parent = _parent
5257

5358
def __enter__(self):
5459
return self
@@ -96,17 +101,17 @@ def wait_init(self, *, timeout: TimeoutType = None) -> PublicWriterInitInfo:
96101
)
97102

98103
def write(
99-
self,
100-
messages: Union[Message, List[Message]],
101-
timeout: TimeoutType = None,
104+
self,
105+
messages: Union[Message, List[Message]],
106+
timeout: TimeoutType = None,
102107
):
103108
self._check_closed()
104109

105110
self._caller.safe_call_with_result(self._async_writer.write(messages), timeout)
106111

107112
def async_write_with_ack(
108-
self,
109-
messages: Union[Message, List[Message]],
113+
self,
114+
messages: Union[Message, List[Message]],
110115
) -> Future[Union[PublicWriteResult, List[PublicWriteResult]]]:
111116
self._check_closed()
112117

@@ -115,9 +120,9 @@ def async_write_with_ack(
115120
)
116121

117122
def write_with_ack(
118-
self,
119-
messages: Union[Message, List[Message]],
120-
timeout: Union[float, None] = None,
123+
self,
124+
messages: Union[Message, List[Message]],
125+
timeout: Union[float, None] = None,
121126
) -> Union[PublicWriteResult, List[PublicWriteResult]]:
122127
self._check_closed()
123128

0 commit comments

Comments
 (0)