Skip to content

Commit 8a0bda1

Browse files
committed
handle grpc errors and errors in server status
1 parent cae886a commit 8a0bda1

File tree

5 files changed

+161
-43
lines changed

5 files changed

+161
-43
lines changed

ydb/_topic_reader/topic_reader_asyncio.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import grpc
1010

11-
from .. import _apis
11+
from .. import _apis, issues
1212
from ..aio import Driver
1313
from ..issues import Error as YdbError
1414
from .datatypes import PartitionSession, PublicMessage, PublicBatch
@@ -143,19 +143,19 @@ async def _start(self, stream: IGrpcWrapperAsyncIO, init_message: StreamReadMess
143143
self._background_tasks.add(read_messages_task)
144144

145145
async def wait_messages(self):
146-
if self._closed:
147-
raise TopicReaderStreamClosedError()
148-
149-
while len(self._message_batches) == 0:
146+
while True:
150147
if self._first_error is not None:
151148
raise self._first_error
152149

150+
if len(self._message_batches) > 0:
151+
return
152+
153153
await self._state_changed.wait()
154154
self._state_changed.clear()
155155

156156
def receive_batch_nowait(self):
157-
if self._closed:
158-
raise TopicReaderStreamClosedError()
157+
if self._first_error is not None:
158+
raise self._first_error
159159

160160
try:
161161
batch = self._message_batches.popleft()
@@ -185,8 +185,6 @@ async def _read_messages_loop(self, stream: IGrpcWrapperAsyncIO):
185185
)
186186

187187
self._state_changed.set()
188-
except grpc.RpcError as e:
189-
190188
except Exception as e:
191189
self._set_first_error(e)
192190
raise e
@@ -296,7 +294,7 @@ async def close(self):
296294
raise TopicReaderError(message="Double closed ReaderStream")
297295

298296
self._closed = True
299-
self._set_first_error(TopicReaderError("Reader closed"))
297+
self._set_first_error(TopicReaderStreamClosedError())
300298
self._state_changed.set()
301299

302300
for task in self._background_tasks:

ydb/_topic_reader/topic_reader_asyncio_test.py

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,23 @@
55
import grpc
66
import pytest
77

8+
import ydb
89
from ydb import aio
910
from .datatypes import PublicBatch, PublicMessage
1011
from .topic_reader import PublicReaderSettings
1112
from .topic_reader_asyncio import ReaderStream, PartitionSession
12-
from .._topic_wrapper.common import OffsetsRange, Codec
13+
from .._topic_wrapper.common import OffsetsRange, Codec, ServerStatus
1314
from .._topic_wrapper.reader import StreamReadMessage
1415
from .._topic_wrapper.test_helpers import StreamMock, wait_condition, wait_for_fast
1516
from ..issues import Unavailable
1617

18+
# Workaround for good autocomplete in IDE and universal import at runtime
19+
# noinspection PyUnreachableCode
20+
if False:
21+
from .._grpc.v4.protos import ydb_status_codes_pb2
22+
else:
23+
from .._grpc.common.protos import ydb_status_codes_pb2
24+
1725

1826
@pytest.fixture()
1927
def default_reader_settings():
@@ -58,7 +66,7 @@ def second_partition_session(self, default_reader_settings):
5866
)
5967

6068
@pytest.fixture()
61-
async def stream_reader(self, stream, default_reader_settings, partition_session,
69+
async def stream_reader_started(self, stream, default_reader_settings, partition_session,
6270
second_partition_session) -> ReaderStream:
6371
reader = ReaderStream(default_reader_settings)
6472
init_message = object()
@@ -67,7 +75,8 @@ async def stream_reader(self, stream, default_reader_settings, partition_session
6775
start = asyncio.create_task(reader._start(stream, init_message))
6876

6977
stream.from_server.put_nowait(StreamReadMessage.FromServer(
70-
StreamReadMessage.InitResponse(session_id="test-session")
78+
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []),
79+
server_message=StreamReadMessage.InitResponse(session_id="test-session"),
7180
))
7281

7382
init_request = await wait_for_fast(stream.from_client.get())
@@ -77,7 +86,9 @@ async def stream_reader(self, stream, default_reader_settings, partition_session
7786
assert isinstance(read_request.client_message, StreamReadMessage.ReadRequest)
7887

7988
stream.from_server.put_nowait(
80-
StreamReadMessage.FromServer(server_message=StreamReadMessage.StartPartitionSessionRequest(
89+
StreamReadMessage.FromServer(
90+
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []),
91+
server_message=StreamReadMessage.StartPartitionSessionRequest(
8192
partition_session=StreamReadMessage.PartitionSession(
8293
partition_session_id=partition_session.id,
8394
path=partition_session.topic_path,
@@ -96,7 +107,9 @@ async def stream_reader(self, stream, default_reader_settings, partition_session
96107
assert isinstance(start_partition_resp.client_message, StreamReadMessage.StartPartitionSessionResponse)
97108

98109
stream.from_server.put_nowait(
99-
StreamReadMessage.FromServer(server_message=StreamReadMessage.StartPartitionSessionRequest(
110+
StreamReadMessage.FromServer(
111+
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []),
112+
server_message=StreamReadMessage.StartPartitionSessionRequest(
100113
partition_session=StreamReadMessage.PartitionSession(
101114
partition_session_id=second_partition_session.id,
102115
path=second_partition_session.topic_path,
@@ -116,11 +129,22 @@ async def stream_reader(self, stream, default_reader_settings, partition_session
116129
with pytest.raises(asyncio.QueueEmpty):
117130
stream.from_client.get_nowait()
118131

119-
yield reader
132+
return reader
120133

121-
assert reader._first_error is None
134+
@pytest.fixture()
135+
async def stream_reader(self, stream_reader_started: ReaderStream):
136+
yield stream_reader_started
137+
138+
assert stream_reader_started._first_error is None
139+
await stream_reader_started.close()
140+
141+
@pytest.fixture()
142+
async def stream_reader_finish_with_error(self, stream_reader_started: ReaderStream):
143+
yield stream_reader_started
144+
145+
assert stream_reader_started._first_error is not None
146+
await stream_reader_started.close()
122147

123-
await reader.close()
124148

125149
@staticmethod
126150
def create_message(partition_session: PartitionSession, seqno: int):
@@ -143,7 +167,9 @@ def batch_count():
143167
initial_batches = batch_count()
144168

145169
stream = stream_reader._stream # type: StreamMock
146-
stream.from_server.put_nowait(StreamReadMessage.FromServer(server_message=StreamReadMessage.ReadResponse(
170+
stream.from_server.put_nowait(StreamReadMessage.FromServer(
171+
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []),
172+
server_message=StreamReadMessage.ReadResponse(
147173
partition_data=[StreamReadMessage.ReadResponse.PartitionData(
148174
partition_session_id=message._partition_session.id,
149175
batches=[
@@ -169,20 +195,25 @@ def batch_count():
169195
)))
170196
await wait_condition(lambda: batch_count() > initial_batches)
171197

172-
async def test_convert_errors_to_ydb(self, stream, stream_reader):
173-
class TestError(grpc.RpcError):
174-
_code: grpc.StatusCode
175-
176-
def __init__(self, code: grpc.StatusCode):
177-
self._code = code
198+
async def test_first_error(self, stream, stream_reader_finish_with_error):
199+
class TestError(grpc.RpcError, grpc.Call):
200+
def __init__(self):
201+
pass
178202

179203
def code(self):
180-
return self._code
204+
return grpc.StatusCode.UNAUTHENTICATED
181205

182-
stream.from_server.put_nowait(TestError(grpc.StatusCode.UNAVAILABLE))
206+
def details(self):
207+
return "test error"
183208

184-
with pytest.raises(Unavailable):
185-
await wait_for_fast(stream_reader.wait_messages())
209+
test_err = TestError()
210+
stream.from_server.put_nowait(test_err)
211+
212+
with pytest.raises(TestError):
213+
await wait_for_fast(stream_reader_finish_with_error.wait_messages())
214+
215+
with pytest.raises(TestError):
216+
stream_reader_finish_with_error.receive_batch_nowait()
186217

187218
async def test_init_reader(self, stream, default_reader_settings):
188219
reader = ReaderStream(default_reader_settings)
@@ -202,6 +233,7 @@ async def test_init_reader(self, stream, default_reader_settings):
202233
assert sent_message == expected_sent_init_message
203234

204235
stream.from_server.put_nowait(StreamReadMessage.FromServer(
236+
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []),
205237
server_message=StreamReadMessage.InitResponse(session_id="test"))
206238
)
207239

@@ -231,6 +263,7 @@ def session_count():
231263
test_topic_path = default_reader_settings.topic + "-asd"
232264

233265
stream.from_server.put_nowait(StreamReadMessage.FromServer(
266+
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []),
234267
server_message=StreamReadMessage.StartPartitionSessionRequest(
235268
partition_session=StreamReadMessage.PartitionSession(
236269
partition_session_id=test_partition_session_id,
@@ -266,6 +299,7 @@ def session_count():
266299
initial_session_count = session_count()
267300

268301
stream.from_server.put_nowait(StreamReadMessage.FromServer(
302+
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []),
269303
server_message=StreamReadMessage.StopPartitionSessionRequest(
270304
partition_session_id=partition_session.id,
271305
graceful=False,
@@ -287,6 +321,7 @@ def session_count():
287321
initial_session_count = session_count()
288322

289323
stream.from_server.put_nowait(StreamReadMessage.FromServer(
324+
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []),
290325
server_message=StreamReadMessage.StopPartitionSessionRequest(
291326
partition_session_id=partition_session.id,
292327
graceful=True,
@@ -303,6 +338,7 @@ def session_count():
303338
)
304339

305340
stream.from_server.put_nowait(StreamReadMessage.FromServer(
341+
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []),
306342
server_message=StreamReadMessage.StopPartitionSessionRequest(
307343
partition_session_id=partition_session.id,
308344
graceful=False,
@@ -330,7 +366,9 @@ def reader_batch_count():
330366
session_meta = {"a": "b"}
331367
message_group_id = "test-message-group-id"
332368

333-
stream.from_server.put_nowait(StreamReadMessage.FromServer(server_message=StreamReadMessage.ReadResponse(
369+
stream.from_server.put_nowait(StreamReadMessage.FromServer(
370+
server_status=ServerStatus(ydb_status_codes_pb2.StatusIds.SUCCESS, []),
371+
server_message=StreamReadMessage.ReadResponse(
334372
bytes_size=bytes_size,
335373
partition_data=[
336374
StreamReadMessage.ReadResponse.PartitionData(

ydb/_topic_wrapper/common.py

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@
44
from dataclasses import dataclass
55
from enum import IntEnum
66

7+
import grpc
78
from google.protobuf.message import Message
89

910
import ydb.aio
1011

12+
from .. import issues, connection
13+
1114
# Workaround for good autocomplete in IDE and universal import at runtime
15+
# noinspection PyUnreachableCode
1216
if False:
1317
from ydb._grpc.v4.protos import (
1418
ydb_status_codes_pb2,
@@ -147,16 +151,19 @@ class GrpcWrapperAsyncIO(IGrpcWrapperAsyncIO):
147151
from_client_grpc: asyncio.Queue
148152
from_server_grpc: typing.AsyncIterator
149153
convert_server_grpc_to_wrapper: typing.Callable[[typing.Any], typing.Any]
154+
_connection_state: str
150155

151156
def __init__(self, convert_server_grpc_to_wrapper):
152157
self.from_client_grpc = asyncio.Queue()
153158
self.convert_server_grpc_to_wrapper = convert_server_grpc_to_wrapper
159+
self._connection_state = "new"
154160

155161
async def start(self, driver: SupportedDriverType, stub, method):
156162
if asyncio.iscoroutinefunction(driver.__call__):
157163
await self._start_asyncio_driver(driver, stub, method)
158164
else:
159165
await self._start_sync_driver(driver, stub, method)
166+
self._connection_state = "started"
160167

161168
async def _start_asyncio_driver(self, driver: ydb.aio.Driver, stub, method):
162169
requests_iterator = QueueToIteratorAsyncIO(self.from_client_grpc)
@@ -179,37 +186,49 @@ async def _start_sync_driver(self, driver: ydb.Driver, stub, method):
179186

180187
async def receive(self) -> typing.Any:
181188
# todo handle grpc exceptions and convert it to internal exceptions
182-
grpc_message = await self.from_server_grpc.__anext__()
189+
try:
190+
grpc_message = await self.from_server_grpc.__anext__()
191+
except grpc.RpcError as e:
192+
raise connection._rpc_error_handler(self._connection_state, e)
193+
194+
issues._process_response(grpc_message)
195+
196+
if self._connection_state != "has_received_messages":
197+
self._connection_state = "has_received_messages"
198+
183199
# print("rekby, grpc, received", grpc_message)
184200
return self.convert_server_grpc_to_wrapper(grpc_message)
185201

186202
def write(self, wrap_message: IToProto):
187-
grpc_message=wrap_message.to_proto()
203+
grpc_message = wrap_message.to_proto()
188204
# print("rekby, grpc, send", grpc_message)
189205
self.from_client_grpc.put_nowait(grpc_message)
190206

191207

192208
@dataclass(init=False)
193209
class ServerStatus(IFromProto):
194-
__slots__ = ("status", "_issues")
210+
__slots__ = ("_grpc_status_code", "_issues")
195211

196212
def __init__(
197-
self,
198-
status: ydb_status_codes_pb2.StatusIds.StatusCode,
199-
issues: typing.Iterable[ydb_issue_message_pb2.IssueMessage],
213+
self,
214+
status_code: ydb_status_codes_pb2.StatusIds.StatusCode,
215+
grpc_issues: typing.Iterable[ydb_issue_message_pb2.IssueMessage],
200216
):
201-
self.status = status
202-
self._issues = issues
217+
self._grpc_status_code = status_code
218+
self._issues = grpc_issues
203219

204220
def __str__(self):
205221
return self.__repr__()
206222

207223
@staticmethod
208-
def from_proto(msg: Message) -> "ServerStatus":
209-
return ServerStatus(msg.status)
224+
def from_proto(msg: typing.Union[
225+
ydb_topic_pb2.StreamReadMessage.FromServer,
226+
ydb_topic_pb2.StreamWriteMessage.FromServer,
227+
]) -> "ServerStatus":
228+
return ServerStatus(msg.status, msg.issues)
210229

211230
def is_success(self) -> bool:
212-
return self.status == ydb_status_codes_pb2.StatusIds.SUCCESS
231+
return self._grpc_status_code == ydb_status_codes_pb2.StatusIds.SUCCESS
213232

214233
@classmethod
215234
def issue_to_str(cls, issue: ydb_issue_message_pb2.IssueMessage):
@@ -248,3 +267,6 @@ def callback_from_asyncio(callback: typing.Union[typing.Callable, typing.Corouti
248267
else:
249268
return loop.run_in_executor(None, callback)
250269

270+
271+
def ensure_success_or_raise_error(server_status: ServerStatus):
272+
error = issues._process_response(server_status._grpc_status_code, server_status._issues)

0 commit comments

Comments
 (0)