Skip to content

Commit db110ff

Browse files
committed
Fix empty result sets from stream
1 parent 129cea2 commit db110ff

File tree

10 files changed

+52
-4
lines changed

10 files changed

+52
-4
lines changed

tests/aio/query/test_query_session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,13 @@ async def test_basic_execute(self, session: QuerySession):
103103
async def test_two_results(self, session: QuerySession):
104104
await session.create()
105105
res = []
106+
counter = 0
106107

107108
async with await session.execute("select 1; select 2") as results:
108109
async for result_set in results:
110+
counter += 1
109111
if len(result_set.rows) > 0:
110112
res.append(list(result_set.rows[0].values()))
111113

112114
assert res == [[1], [2]]
115+
assert counter == 2

tests/aio/query/test_query_transaction.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,18 @@ async def test_execute_as_context_manager(self, tx: QueryTxContext):
9292
res = [result_set async for result_set in results]
9393

9494
assert len(res) == 1
95+
96+
@pytest.mark.asyncio
97+
async def test_execute_two_results(self, tx: QueryTxContext):
98+
await tx.begin()
99+
counter = 0
100+
res = []
101+
102+
async with await tx.execute("select 1; select 2") as results:
103+
async for result_set in results:
104+
counter += 1
105+
if len(result_set.rows) > 0:
106+
res.append(list(result_set.rows[0].values()))
107+
108+
assert res == [[1], [2]]
109+
assert counter == 2

tests/query/test_query_session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,16 @@ def test_basic_execute(self, session: QuerySession):
9898
def test_two_results(self, session: QuerySession):
9999
session.create()
100100
res = []
101+
counter = 0
101102

102103
with session.execute("select 1; select 2") as results:
103104
for result_set in results:
105+
counter += 1
104106
if len(result_set.rows) > 0:
105107
res.append(list(result_set.rows[0].values()))
106108

107109
assert res == [[1], [2]]
110+
assert counter == 2
108111

109112
def test_thread_leaks(self, session: QuerySession):
110113
session.create()

tests/query/test_query_transaction.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,16 @@ def test_execute_as_context_manager(self, tx: QueryTxContext):
7979
res = [result_set for result_set in results]
8080

8181
assert len(res) == 1
82+
83+
def test_execute_two_results(self, tx: QueryTxContext):
84+
tx.begin()
85+
counter = 0
86+
res = []
87+
88+
with tx.execute("select 1; select 2") as results:
89+
for result_set in results:
90+
counter += 1
91+
res.append(list(result_set.rows[0].values()))
92+
93+
assert res == [[1], [2]]
94+
assert counter == 2

ydb/_utilities.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,10 @@ def __next__(self):
149149

150150

151151
class SyncResponseIterator(object):
152-
def __init__(self, it, wrapper):
152+
def __init__(self, it, wrapper, _filter=None):
153153
self.it = it
154154
self.wrapper = wrapper
155+
self.filter = _filter
155156

156157
def cancel(self):
157158
self.it.cancel()
@@ -161,7 +162,11 @@ def __iter__(self):
161162
return self
162163

163164
def _next(self):
164-
return self.wrapper(next(self.it))
165+
res = next(self.it)
166+
wrapped = self.wrapper(res)
167+
if self.filter is None or self.filter(res):
168+
return wrapped
169+
return self._next()
165170

166171
def next(self):
167172
return self._next()

ydb/aio/_utilities.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33

44
class AsyncResponseIterator(object):
5-
def __init__(self, it, wrapper):
5+
def __init__(self, it, wrapper, _filter=None):
66
self.it = it.__aiter__()
77
self.wrapper = wrapper
8+
self.filter = _filter
89

910
def cancel(self):
1011
self.it.cancel()
@@ -17,7 +18,11 @@ def __aiter__(self):
1718
return self
1819

1920
async def _next(self):
20-
return self.wrapper(await self.it.__anext__())
21+
res = await self.it.__anext__()
22+
wrapped = self.wrapper(res)
23+
if self.filter is None or self.filter(res):
24+
return wrapped
25+
return await self._next()
2126

2227
async def next(self):
2328
return await self._next()

ydb/aio/query/session.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,4 +149,5 @@ async def execute(
149149
session_state=self._state,
150150
settings=self._settings,
151151
),
152+
lambda resp: resp.HasField("result_set"),
152153
)

ydb/aio/query/transaction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,5 +144,6 @@ async def execute(
144144
commit_tx=commit_tx,
145145
settings=self.session._settings,
146146
),
147+
lambda resp: resp.HasField("result_set"),
147148
)
148149
return self._prev_stream

ydb/query/session.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,4 +353,5 @@ def execute(
353353
session_state=self._state,
354354
settings=self._settings,
355355
),
356+
lambda resp: resp.HasField("result_set"),
356357
)

ydb/query/transaction.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,5 +425,6 @@ def execute(
425425
commit_tx=commit_tx,
426426
settings=self.session._settings,
427427
),
428+
lambda resp: resp.HasField("result_set"),
428429
)
429430
return self._prev_stream

0 commit comments

Comments
 (0)