Skip to content

Commit cd93e4a

Browse files
authored
Merge pull request #306 Fix handle stop partition request
2 parents 759ccfc + 0a9d7fa commit cd93e4a

File tree

7 files changed

+184
-34
lines changed

7 files changed

+184
-34
lines changed

tests/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,26 @@ async def topic2_path(driver, topic_consumer, database) -> str:
176176
return topic_path
177177

178178

179+
@pytest.fixture()
180+
@pytest.mark.asyncio()
181+
async def topic_with_two_partitions_path(driver, topic_consumer, database) -> str:
182+
topic_path = database + "/test-topic-two-partitions"
183+
184+
try:
185+
await driver.topic_client.drop_topic(topic_path)
186+
except issues.SchemeError:
187+
pass
188+
189+
await driver.topic_client.create_topic(
190+
path=topic_path,
191+
consumers=[topic_consumer],
192+
min_active_partitions=2,
193+
partition_count_limit=2,
194+
)
195+
196+
return topic_path
197+
198+
179199
@pytest.fixture()
180200
@pytest.mark.asyncio()
181201
async def topic_with_messages(driver, topic_consumer, database):

tests/topics/test_topic_reader.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
13
import pytest
24

35
import ydb
@@ -161,3 +163,45 @@ def decode(b: bytes):
161163
with driver_sync.topic_client.reader(topic_path, topic_consumer, decoders={codec: decode}) as reader:
162164
batch = reader.receive_batch()
163165
assert batch.messages[0].data.decode() == "123"
166+
167+
168+
@pytest.mark.asyncio
169+
class TestBugFixesAsync:
170+
async def test_issue_297_bad_handle_stop_partition(
171+
self, driver, topic_consumer, topic_with_two_partitions_path: str
172+
):
173+
async def wait(fut):
174+
return await asyncio.wait_for(fut, timeout=10)
175+
176+
topic = topic_with_two_partitions_path # type: str
177+
178+
async with driver.topic_client.writer(topic, partition_id=0) as writer:
179+
await writer.write_with_ack("00")
180+
181+
async with driver.topic_client.writer(topic, partition_id=1) as writer:
182+
await writer.write_with_ack("01")
183+
184+
# Start first reader and receive messages from both partitions
185+
reader0 = driver.topic_client.reader(topic, consumer=topic_consumer)
186+
await wait(reader0.receive_message())
187+
await wait(reader0.receive_message())
188+
189+
# Start second reader for same topic, same consumer, partition 1
190+
reader1 = driver.topic_client.reader(topic, consumer=topic_consumer)
191+
192+
# receive uncommited message
193+
await reader1.receive_message()
194+
195+
# write one message for every partition
196+
async with driver.topic_client.writer(topic, partition_id=0) as writer:
197+
await writer.write_with_ack("10")
198+
async with driver.topic_client.writer(topic, partition_id=1) as writer:
199+
await writer.write_with_ack("11")
200+
201+
msg0 = await wait(reader0.receive_message())
202+
msg1 = await wait(reader1.receive_message())
203+
204+
datas = [msg0.data.decode(), msg1.data.decode()]
205+
datas.sort()
206+
207+
assert datas == ["10", "11"]

ydb/_grpc/grpcwrapper/ydb_topic.py

Lines changed: 61 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import datetime
24
import enum
35
import typing
@@ -8,6 +10,7 @@
810

911
from . import ydb_topic_public_types
1012
from ... import scheme
13+
from ... import issues
1114

1215
# Workaround for good IDE and universal for runtime
1316
if typing.TYPE_CHECKING:
@@ -588,16 +591,32 @@ def from_proto(
588591
)
589592

590593
@dataclass
591-
class PartitionSessionStatusRequest:
594+
class PartitionSessionStatusRequest(IToProto):
592595
partition_session_id: int
593596

597+
def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.PartitionSessionStatusRequest:
598+
return ydb_topic_pb2.StreamReadMessage.PartitionSessionStatusRequest(
599+
partition_session_id=self.partition_session_id
600+
)
601+
594602
@dataclass
595-
class PartitionSessionStatusResponse:
603+
class PartitionSessionStatusResponse(IFromProto):
596604
partition_session_id: int
597605
partition_offsets: "OffsetsRange"
598606
committed_offset: int
599607
write_time_high_watermark: float
600608

609+
@staticmethod
610+
def from_proto(
611+
msg: ydb_topic_pb2.StreamReadMessage.PartitionSessionStatusResponse,
612+
) -> "StreamReadMessage.PartitionSessionStatusResponse":
613+
return StreamReadMessage.PartitionSessionStatusResponse(
614+
partition_session_id=msg.partition_session_id,
615+
partition_offsets=OffsetsRange.from_proto(msg.partition_offsets),
616+
committed_offset=msg.committed_offset,
617+
write_time_high_watermark=msg.write_time_high_watermark,
618+
)
619+
601620
@dataclass
602621
class StartPartitionSessionRequest(IFromProto):
603622
partition_session: "StreamReadMessage.PartitionSession"
@@ -632,15 +651,30 @@ def to_proto(
632651
return res
633652

634653
@dataclass
635-
class StopPartitionSessionRequest:
654+
class StopPartitionSessionRequest(IFromProto):
636655
partition_session_id: int
637656
graceful: bool
638657
committed_offset: int
639658

659+
@staticmethod
660+
def from_proto(
661+
msg: ydb_topic_pb2.StreamReadMessage.StopPartitionSessionRequest,
662+
) -> StreamReadMessage.StopPartitionSessionRequest:
663+
return StreamReadMessage.StopPartitionSessionRequest(
664+
partition_session_id=msg.partition_session_id,
665+
graceful=msg.graceful,
666+
committed_offset=msg.committed_offset,
667+
)
668+
640669
@dataclass
641-
class StopPartitionSessionResponse:
670+
class StopPartitionSessionResponse(IToProto):
642671
partition_session_id: int
643672

673+
def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.StopPartitionSessionResponse:
674+
return ydb_topic_pb2.StreamReadMessage.StopPartitionSessionResponse(
675+
partition_session_id=self.partition_session_id,
676+
)
677+
644678
@dataclass
645679
class FromClient(IToProto):
646680
client_message: "ReaderMessagesFromClientToServer"
@@ -660,6 +694,10 @@ def to_proto(self) -> ydb_topic_pb2.StreamReadMessage.FromClient:
660694
res.update_token_request.CopyFrom(self.client_message.to_proto())
661695
elif isinstance(self.client_message, StreamReadMessage.StartPartitionSessionResponse):
662696
res.start_partition_session_response.CopyFrom(self.client_message.to_proto())
697+
elif isinstance(self.client_message, StreamReadMessage.StopPartitionSessionResponse):
698+
res.stop_partition_session_response.CopyFrom(self.client_message.to_proto())
699+
elif isinstance(self.client_message, StreamReadMessage.PartitionSessionStatusRequest):
700+
res.start_partition_session_response.CopyFrom(self.client_message.to_proto())
663701
else:
664702
raise NotImplementedError("Unknown message type: %s" % type(self.client_message))
665703
return res
@@ -694,17 +732,32 @@ def from_proto(
694732
return StreamReadMessage.FromServer(
695733
server_status=server_status,
696734
server_message=StreamReadMessage.StartPartitionSessionRequest.from_proto(
697-
msg.start_partition_session_request
735+
msg.start_partition_session_request,
736+
),
737+
)
738+
elif mess_type == "stop_partition_session_request":
739+
return StreamReadMessage.FromServer(
740+
server_status=server_status,
741+
server_message=StreamReadMessage.StopPartitionSessionRequest.from_proto(
742+
msg.stop_partition_session_request
698743
),
699744
)
700745
elif mess_type == "update_token_response":
701746
return StreamReadMessage.FromServer(
702747
server_status=server_status,
703748
server_message=UpdateTokenResponse.from_proto(msg.update_token_response),
704749
)
705-
706-
# todo replace exception to log
707-
raise NotImplementedError()
750+
elif mess_type == "partition_session_status_response":
751+
return StreamReadMessage.FromServer(
752+
server_status=server_status,
753+
server_message=StreamReadMessage.PartitionSessionStatusResponse.from_proto(
754+
msg.partition_session_status_response
755+
),
756+
)
757+
else:
758+
raise issues.UnexpectedGrpcMessage(
759+
"Unexpected message while parse ReaderMessagesFromServerToClient: '%s'" % mess_type
760+
)
708761

709762

710763
ReaderMessagesFromClientToServer = Union[

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
Codec,
2727
)
2828
from .._errors import check_retriable_error
29+
import logging
30+
31+
logger = logging.getLogger(__name__)
2932

3033

3134
class TopicReaderError(YdbError):
@@ -146,7 +149,6 @@ class ReaderReconnector:
146149

147150
def __init__(self, driver: Driver, settings: topic_reader.PublicReaderSettings):
148151
self._id = self._static_reader_reconnector_counter.inc_and_get()
149-
150152
self._settings = settings
151153
self._driver = driver
152154
self._background_tasks = set()
@@ -395,39 +397,42 @@ async def _read_messages_loop(self):
395397
)
396398
)
397399
while True:
398-
message = await self._stream.receive() # type: StreamReadMessage.FromServer
399-
_process_response(message.server_status)
400+
try:
401+
message = await self._stream.receive() # type: StreamReadMessage.FromServer
402+
_process_response(message.server_status)
400403

401-
if isinstance(message.server_message, StreamReadMessage.ReadResponse):
402-
self._on_read_response(message.server_message)
404+
if isinstance(message.server_message, StreamReadMessage.ReadResponse):
405+
self._on_read_response(message.server_message)
403406

404-
elif isinstance(message.server_message, StreamReadMessage.CommitOffsetResponse):
405-
self._on_commit_response(message.server_message)
407+
elif isinstance(message.server_message, StreamReadMessage.CommitOffsetResponse):
408+
self._on_commit_response(message.server_message)
406409

407-
elif isinstance(
408-
message.server_message,
409-
StreamReadMessage.StartPartitionSessionRequest,
410-
):
411-
self._on_start_partition_session(message.server_message)
410+
elif isinstance(
411+
message.server_message,
412+
StreamReadMessage.StartPartitionSessionRequest,
413+
):
414+
self._on_start_partition_session(message.server_message)
412415

413-
elif isinstance(
414-
message.server_message,
415-
StreamReadMessage.StopPartitionSessionRequest,
416-
):
417-
self._on_partition_session_stop(message.server_message)
416+
elif isinstance(
417+
message.server_message,
418+
StreamReadMessage.StopPartitionSessionRequest,
419+
):
420+
self._on_partition_session_stop(message.server_message)
418421

419-
elif isinstance(message.server_message, UpdateTokenResponse):
420-
self._update_token_event.set()
422+
elif isinstance(message.server_message, UpdateTokenResponse):
423+
self._update_token_event.set()
421424

422-
else:
423-
raise NotImplementedError(
424-
"Unexpected type of StreamReadMessage.FromServer message: %s" % message.server_message
425-
)
425+
else:
426+
raise issues.UnexpectedGrpcMessage(
427+
"Unexpected message in _read_messages_loop: %s" % type(message.server_message)
428+
)
429+
except issues.UnexpectedGrpcMessage as e:
430+
logger.exception("unexpected message in stream reader: %s" % e)
426431

427432
self._state_changed.set()
428433
except Exception as e:
429434
self._set_first_error(e)
430-
raise
435+
return
431436

432437
async def _update_token_loop(self):
433438
while True:

ydb/_topic_reader/topic_reader_asyncio_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,6 +1127,29 @@ async def test_update_token(self, stream):
11271127

11281128
await reader.close()
11291129

1130+
async def test_read_unknown_message(self, stream, stream_reader, caplog):
1131+
class TestMessage:
1132+
pass
1133+
1134+
# noinspection PyTypeChecker
1135+
stream.from_server.put_nowait(
1136+
StreamReadMessage.FromServer(
1137+
server_status=ServerStatus(
1138+
status=issues.StatusCode.SUCCESS,
1139+
issues=[],
1140+
),
1141+
server_message=TestMessage(),
1142+
)
1143+
)
1144+
1145+
def logged():
1146+
for rec in caplog.records:
1147+
if TestMessage.__name__ in rec.message:
1148+
return True
1149+
return False
1150+
1151+
await wait_condition(logged)
1152+
11301153

11311154
@pytest.mark.asyncio
11321155
class TestReaderReconnector:

ydb/issues.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,11 @@ class SessionPoolEmpty(Error, queue.Empty):
156156
status = StatusCode.SESSION_POOL_EMPTY
157157

158158

159+
class UnexpectedGrpcMessage(Error):
160+
def __init__(self, message: str):
161+
super().__init__(message)
162+
163+
159164
def _format_issues(issues):
160165
if not issues:
161166
return ""

ydb/topic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def reader(
168168
if not decoder_executor:
169169
decoder_executor = self._executor
170170

171-
args = locals()
171+
args = locals().copy()
172172
del args["self"]
173173

174174
settings = TopicReaderSettings(**args)
@@ -188,7 +188,7 @@ def writer(
188188
encoders: Optional[Mapping[_ydb_topic_public_types.PublicCodec, Callable[[bytes], bytes]]] = None,
189189
encoder_executor: Optional[concurrent.futures.Executor] = None, # default shared client executor pool
190190
) -> TopicWriterAsyncIO:
191-
args = locals()
191+
args = locals().copy()
192192
del args["self"]
193193

194194
settings = TopicWriterSettings(**args)

0 commit comments

Comments
 (0)