Skip to content

Commit 9cd5695

Browse files
committed
sync
1 parent 7ac9adc commit 9cd5695

File tree

5 files changed

+178
-16
lines changed

5 files changed

+178
-16
lines changed

ydb/_topic_reader/topic_reader.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818

1919
from ydb._topic_wrapper.common import OffsetsRange, TokenGetterFuncType
20+
from ydb._topic_wrapper.reader import StreamReadMessage
2021

2122

2223
class Selector:
@@ -270,6 +271,16 @@ class PublicReaderSettings:
270271
# connection_timeout: Union[float, None] = None
271272
# retry_policy: Union["RetryPolicy", None] = None
272273

274+
def _init_message(self) -> StreamReadMessage.InitRequest:
    """Build the ``StreamReadMessage.InitRequest`` described by these settings.

    The request carries exactly one ``TopicReadSettings`` entry, whose path
    is ``self.topic``, and names ``self.consumer`` as the consumer.
    """
    topic_settings = StreamReadMessage.InitRequest.TopicReadSettings(
        path=self.topic,
    )
    return StreamReadMessage.InitRequest(
        topics_read_settings=[topic_settings],
        consumer=self.consumer,
    )
283+
273284

274285
class Events:
275286
class OnCommit:

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22

33
import asyncio
44
import typing
5+
from asyncio import Task
56
from collections import deque
67
from typing import Optional, Set, Dict
78

9+
from .. import _apis
10+
from ..aio import Driver
811
from ..issues import Error as YdbError
912
from .datatypes import PartitionSession, PublicMessage, PublicBatch
1013
from .topic_reader import PublicReaderSettings
11-
from .._topic_wrapper.common import TokenGetterFuncType, IGrpcWrapperAsyncIO
14+
from .._topic_wrapper.common import TokenGetterFuncType, IGrpcWrapperAsyncIO, SupportedDriverType, GrpcWrapperAsyncIO
1215
from .._topic_wrapper.reader import StreamReadMessage
1316

1417

@@ -22,11 +25,52 @@ def __init__(self):
2225

2326

2427
class PublicAsyncIOReader:
    """Asyncio-facing topic reader; delegates the work to a ReaderReconnector."""

    # Event loop that was running when the reader was constructed.
    _loop: asyncio.AbstractEventLoop
    # Quoted on purpose: ReaderReconnector is declared further down in this
    # module, so an unquoted name would be looked up before it exists when
    # annotations are evaluated eagerly.
    _reconnector: "ReaderReconnector"

    def __init__(self, driver: Driver, settings: PublicReaderSettings):
        """Bind to the running event loop and create the reconnector.

        Args:
            driver: driver handed to the reconnector.
            settings: reader settings handed to the reconnector.

        Raises:
            RuntimeError: if called without a running event loop
                (raised by ``asyncio.get_running_loop``).
        """
        self._loop = asyncio.get_running_loop()
        self._reconnector = ReaderReconnector(driver, settings)
2634

2735

2836
class ReaderReconnector:
    """Owns a ReaderStream and the background task that creates it.

    The stream is created asynchronously by ``start``, which ``__init__``
    schedules as a task; until that task finishes ``_stream_reader`` is None.
    """

    _settings: PublicReaderSettings
    _driver: Driver
    # Tasks spawned by this object; cancelled and awaited in ``close``.
    _background_tasks: Set[Task]

    # Set by ``start`` once ``_stream_reader`` has been assigned.
    _state_changed: asyncio.Event
    _stream_reader: Optional["ReaderStream"]

    def __init__(self, driver: Driver, settings: PublicReaderSettings):
        """Store the arguments and schedule ``start`` as a background task.

        Must be called while an event loop is running, because it calls
        ``asyncio.create_task``.
        """
        self._settings = settings
        self._driver = driver
        self._background_tasks = set()

        self._state_changed = asyncio.Event()
        self._stream_reader = None
        self._background_tasks.add(asyncio.create_task(self.start()))

    async def start(self):
        """Create the ReaderStream and signal waiters that it is available."""
        self._stream_reader = await ReaderStream.create(self._driver, self._settings)
        self._state_changed.set()

    async def wait_message(self):
        """Return once the current stream reports that messages are available.

        If the stream has not been created yet, waits for ``start`` to
        publish it and then waits on that stream.
        """
        while True:
            if self._stream_reader is not None:
                await self._stream_reader.wait_messages()
                # Without this return the loop would fall through and block
                # forever on an event that is never set again.
                return

            await self._state_changed.wait()
            self._state_changed.clear()

    def receive_batch_nowait(self):
        """Delegate to the stream's ``receive_batch_nowait``.

        NOTE(review): raises AttributeError if called before ``start`` has
        assigned the stream; callers are presumably expected to await
        ``wait_message`` first - confirm.
        """
        return self._stream_reader.receive_batch_nowait()

    async def close(self):
        """Close the stream (if one was created) and stop background tasks."""
        # ``start`` may not have finished yet, in which case there is no
        # stream to close and only the pending task has to be cancelled.
        if self._stream_reader is not None:
            await self._stream_reader.close()

        for task in self._background_tasks:
            task.cancel()

        # asyncio.wait rejects an empty collection, so only wait when there
        # is something to wait for.
        if self._background_tasks:
            await asyncio.wait(self._background_tasks)
3074

3175

3276
class ReaderStream:
@@ -57,7 +101,22 @@ def __init__(self, settings: PublicReaderSettings):
57101
self._first_error = None
58102
self._message_batches = deque()
59103

60-
async def start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest):
104+
@staticmethod
async def create(
    driver: SupportedDriverType,
    settings: PublicReaderSettings,
) -> "ReaderStream":
    """Open a StreamRead gRPC stream and return a ReaderStream started on it.

    Args:
        driver: driver passed to ``GrpcWrapperAsyncIO.start``.
        settings: reader settings; also the source of the init message.

    Returns:
        The ReaderStream after its ``_start`` coroutine has completed.
    """
    # Server messages are decoded with StreamReadMessage.FromServer.from_proto.
    grpc_stream = GrpcWrapperAsyncIO(StreamReadMessage.FromServer.from_proto)
    await grpc_stream.start(
        driver,
        _apis.TopicService.Stub,
        _apis.TopicService.StreamRead,
    )

    stream_reader = ReaderStream(settings)
    init_request = settings._init_message()
    await stream_reader._start(grpc_stream, init_request)
    return stream_reader
118+
119+
async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMessage.InitRequest):
61120
if self._started:
62121
raise TopicReaderError("Double start ReaderStream")
63122

ydb/_topic_reader/topic_reader_asyncio_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def stream_reader(self, stream, default_reader_settings, partition_session
6262
init_message = object()
6363

6464
# noinspection PyTypeChecker
65-
start = asyncio.create_task(reader.start(stream, init_message))
65+
start = asyncio.create_task(reader._start(stream, init_message))
6666

6767
stream.from_server.put_nowait(StreamReadMessage.FromServer(
6868
StreamReadMessage.InitResponse(session_id="test-session")
@@ -178,7 +178,7 @@ async def test_init_reader(self, stream, default_reader_settings):
178178
read_from=None,
179179
)]
180180
)
181-
start_task = asyncio.create_task(reader.start(stream, init_message))
181+
start_task = asyncio.create_task(reader._start(stream, init_message))
182182

183183
sent_message = await wait_for_fast(stream.from_client.get())
184184
expected_sent_init_message = StreamReadMessage.FromClient(client_message=init_message)

ydb/_topic_wrapper/common.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,17 @@ class Codec(IntEnum):
3333

3434

3535
@dataclass
class OffsetsRange(IFromProto):
    """A pair of offsets, ``start`` and ``end``, mirrored from the proto message."""

    start: int
    end: int

    @staticmethod
    def from_proto(msg: ydb_topic_pb2.OffsetsRange) -> "OffsetsRange":
        """Copy ``start`` and ``end`` from the protobuf message into a new instance."""
        first, last = msg.start, msg.end
        return OffsetsRange(start=first, end=last)
46+
4047

4148
class IToProto(abc.ABC):
4249
@abc.abstractmethod

ydb/_topic_wrapper/reader.py

Lines changed: 94 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import datetime
2+
import typing
23
from dataclasses import dataclass, field
34
from typing import List, Union, Dict
45

6+
from google.protobuf.message import Message
7+
58
from ydb._topic_wrapper.common import OffsetsRange, IToProto, UpdateTokenRequest, UpdateTokenResponse, IFromProto
69
from google.protobuf.duration_pb2 import Duration as ProtoDuration
710

@@ -14,11 +17,19 @@
1417

1518
class StreamReadMessage:
1619
@dataclass
class PartitionSession(IFromProto):
    """Wrapper for the protobuf ``StreamReadMessage.PartitionSession`` message."""

    partition_session_id: int
    path: str
    partition_id: int

    @staticmethod
    def from_proto(msg: ydb_topic_pb2.StreamReadMessage.PartitionSession) -> "StreamReadMessage.PartitionSession":
        """Build the wrapper by copying the three fields from ``msg``."""
        session_cls = StreamReadMessage.PartitionSession
        return session_cls(
            partition_session_id=msg.partition_session_id,
            path=msg.path,
            partition_id=msg.partition_id,
        )
32+
2233
@dataclass
2334
class InitRequest(IToProto):
2435
topics_read_settings: List["StreamReadMessage.InitRequest.TopicReadSettings"]
@@ -56,36 +67,90 @@ def from_proto(msg: ydb_topic_pb2.StreamReadMessage.InitResponse) -> "StreamRead
5667
return StreamReadMessage.InitResponse(session_id=msg.session_id)
5768

5869
@dataclass
class ReadRequest(IToProto):
    """Wrapper for the protobuf ``StreamReadMessage.ReadRequest`` message."""

    bytes_size: int

    def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.ReadRequest:
        """Return a protobuf ReadRequest whose ``bytes_size`` equals this one's."""
        return ydb_topic_pb2.StreamReadMessage.ReadRequest(
            bytes_size=self.bytes_size,
        )
77+
6278
@dataclass
class ReadResponse(IFromProto):
    """Wrapper for the protobuf ``StreamReadMessage.ReadResponse`` message."""

    partition_data: List["StreamReadMessage.ReadResponse.PartitionData"]
    bytes_size: int

    @staticmethod
    def from_proto(msg: ydb_topic_pb2.StreamReadMessage.ReadResponse) -> "StreamReadMessage.ReadResponse":
        """Convert ``msg`` and every nested partition-data entry."""
        decode_partition = StreamReadMessage.ReadResponse.PartitionData.from_proto
        return StreamReadMessage.ReadResponse(
            partition_data=[decode_partition(item) for item in msg.partition_data],
            bytes_size=msg.bytes_size,
        )

    @dataclass
    class MessageData(IFromProto):
        """One message of a batch."""

        offset: int
        seq_no: int
        created_at: datetime.datetime
        data: bytes
        # Spelling kept as is: the field name is part of the public interface.
        uncompresed_size: int
        message_group_id: str

        @staticmethod
        def from_proto(
            msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.MessageData,
        ) -> "StreamReadMessage.ReadResponse.MessageData":
            """Copy the message fields; ``created_at`` is converted via ``ToDatetime()``."""
            message_cls = StreamReadMessage.ReadResponse.MessageData
            return message_cls(
                offset=msg.offset,
                seq_no=msg.seq_no,
                created_at=msg.created_at.ToDatetime(),
                data=msg.data,
                uncompresed_size=msg.uncompressed_size,
                message_group_id=msg.message_group_id,
            )

    @dataclass
    class Batch(IFromProto):
        """A group of messages together with producer and write metadata."""

        message_data: List["StreamReadMessage.ReadResponse.MessageData"]
        producer_id: str
        write_session_meta: Dict[str, str]
        codec: int
        written_at: datetime.datetime

        @staticmethod
        def from_proto(
            msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.Batch,
        ) -> "StreamReadMessage.ReadResponse.Batch":
            """Convert ``msg`` and each contained message.

            ``write_session_meta`` is copied into a plain ``dict`` and
            ``written_at`` is converted via ``ToDatetime()``.
            """
            decode_message = StreamReadMessage.ReadResponse.MessageData.from_proto
            return StreamReadMessage.ReadResponse.Batch(
                message_data=[decode_message(item) for item in msg.message_data],
                producer_id=msg.producer_id,
                write_session_meta=dict(msg.write_session_meta),
                codec=msg.codec,
                written_at=msg.written_at.ToDatetime(),
            )

    @dataclass
    class PartitionData(IFromProto):
        """Batches that belong to one partition session."""

        partition_session_id: int
        batches: List["StreamReadMessage.ReadResponse.Batch"]

        @staticmethod
        def from_proto(
            msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.PartitionData,
        ) -> "StreamReadMessage.ReadResponse.PartitionData":
            """Convert ``msg`` and each contained batch."""
            decode_batch = StreamReadMessage.ReadResponse.Batch.from_proto
            return StreamReadMessage.ReadResponse.PartitionData(
                partition_session_id=msg.partition_session_id,
                batches=[decode_batch(item) for item in msg.batches],
            )
152+
153+
89154
@dataclass
90155
class CommitOffsetRequest:
91156
commit_offsets: List["PartitionCommitOffset"]
@@ -116,17 +181,33 @@ class PartitionSessionStatusResponse:
116181
write_time_high_watermark: float
117182

118183
@dataclass
class StartPartitionSessionRequest(IFromProto):
    """Wrapper for the protobuf ``StartPartitionSessionRequest`` message."""

    partition_session: "StreamReadMessage.PartitionSession"
    committed_offset: int
    partition_offsets: OffsetsRange

    @staticmethod
    def from_proto(
        msg: ydb_topic_pb2.StreamReadMessage.StartPartitionSessionRequest,
    ) -> "StreamReadMessage.StartPartitionSessionRequest":
        """Convert ``msg``, including its nested session and offsets range."""
        session = StreamReadMessage.PartitionSession.from_proto(msg.partition_session)
        committed = msg.committed_offset
        offsets = OffsetsRange.from_proto(msg.partition_offsets)
        return StreamReadMessage.StartPartitionSessionRequest(
            partition_session=session,
            committed_offset=committed,
            partition_offsets=offsets,
        )
197+
124198
@dataclass
class StartPartitionSessionResponse(IToProto):
    """Wrapper for the protobuf ``StartPartitionSessionResponse`` message."""

    partition_session_id: int
    read_offset: int
    commit_offset: int

    def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse:
        """Return a protobuf message carrying the same three field values."""
        return ydb_topic_pb2.StreamReadMessage.StartPartitionSessionResponse(
            partition_session_id=self.partition_session_id,
            read_offset=self.read_offset,
            commit_offset=self.commit_offset,
        )
210+
130211
@dataclass
131212
class StopPartitionSessionRequest:
132213
partition_session_id: int
@@ -159,7 +240,11 @@ class FromServer(IFromProto):
159240
@staticmethod
def from_proto(msg: ydb_topic_pb2.StreamReadMessage.FromServer) -> "StreamReadMessage.FromServer":
    """Wrap whichever ``server_message`` oneof field is set on ``msg``.

    The populated field is found with ``WhichOneof("server_message")`` and
    decoded by the matching wrapper's ``from_proto``. Only the
    ``read_response`` and ``init_response`` branches appear in this
    excerpt; any handling of other message kinds lies outside it.
    """
    mess_type = msg.WhichOneof("server_message")
    if mess_type == "read_response":
        return StreamReadMessage.FromServer(
            # Decode the field that is actually populated. Reading
            # msg.init_response here would hand ReadResponse.from_proto an
            # unset message of the wrong type.
            server_message=StreamReadMessage.ReadResponse.from_proto(msg.read_response),
        )
    elif mess_type == "init_response":
        return StreamReadMessage.FromServer(
            server_message=StreamReadMessage.InitResponse.from_proto(msg.init_response),
        )

0 commit comments

Comments
 (0)