Skip to content

Commit 990df44

Browse files
authored
SNOW-966003: Fix Arrow return value for zero-row queries (#1832)
1 parent aa5e50b commit 990df44

File tree

4 files changed

+54
-5
lines changed

4 files changed

+54
-5
lines changed

DESCRIPTION.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
88

99
# Release Notes
1010

11+
- v3.7.0(TBD)
12+
13+
- Added a new boolean parameter `force_return_table` to `SnowflakeCursor.fetch_arrow_all` so that an empty `pyarrow.Table` is returned instead of `None` when the query yields zero rows.
1114

1215
- v3.6.0(December 09,2023)
1316

src/snowflake/connector/cursor.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1333,14 +1333,29 @@ def fetch_arrow_batches(self) -> Iterator[Table]:
13331333
)
13341334
return self._result_set._fetch_arrow_batches()
13351335

1336-
def fetch_arrow_all(self) -> Table | None:
1336+
@overload
1337+
def fetch_arrow_all(self, force_return_table: Literal[False]) -> Table | None:
1338+
...
1339+
1340+
@overload
1341+
def fetch_arrow_all(self, force_return_table: Literal[True]) -> Table:
1342+
...
1343+
1344+
def fetch_arrow_all(self, force_return_table: bool = False) -> Table | None:
1345+
"""
1346+
Args:
1347+
force_return_table: Set to True so that when the query returns zero rows,
1348+
an empty pyarrow table will be returned, with a schema that uses the highest bit length for each column.
1349+
Defaults to False, in which case None is returned when the query returns zero rows.
1350+
"""
13371351
self.check_can_use_arrow_resultset()
1352+
13381353
if self._prefetch_hook is not None:
13391354
self._prefetch_hook()
13401355
if self._query_result_format != "arrow":
13411356
raise NotSupportedError
13421357
self._log_telemetry_job_data(TelemetryField.ARROW_FETCH_ALL, TelemetryData.TRUE)
1343-
return self._result_set._fetch_arrow_all()
1358+
return self._result_set._fetch_arrow_all(force_return_table=force_return_table)
13441359

13451360
def fetch_pandas_batches(self, **kwargs: Any) -> Iterator[DataFrame]:
13461361
"""Fetches a single Arrow Table."""

src/snowflake/connector/result_set.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,16 @@
99
from concurrent.futures import Future
1010
from concurrent.futures.thread import ThreadPoolExecutor
1111
from logging import getLogger
12-
from typing import TYPE_CHECKING, Any, Callable, Deque, Iterable, Iterator
12+
from typing import (
13+
TYPE_CHECKING,
14+
Any,
15+
Callable,
16+
Deque,
17+
Iterable,
18+
Iterator,
19+
Literal,
20+
overload,
21+
)
1322

1423
from .constants import IterUnit
1524
from .errors import NotSupportedError
@@ -164,13 +173,21 @@ def _fetch_arrow_batches(
164173
self._can_create_arrow_iter()
165174
return self._create_iter(iter_unit=IterUnit.TABLE_UNIT, structure="arrow")
166175

167-
def _fetch_arrow_all(self) -> Table | None:
176+
@overload
177+
def _fetch_arrow_all(self, force_return_table: Literal[False]) -> Table | None:
178+
...
179+
180+
@overload
181+
def _fetch_arrow_all(self, force_return_table: Literal[True]) -> Table:
182+
...
183+
184+
def _fetch_arrow_all(self, force_return_table: bool = False) -> Table | None:
168185
"""Fetches a single Arrow Table from all of the ``ResultBatch``."""
169186
tables = list(self._fetch_arrow_batches())
170187
if tables:
171188
return pa.concat_tables(tables)
172189
else:
173-
return None
190+
return self.batches[0].to_arrow() if force_return_table else None
174191

175192
def _fetch_pandas_batches(self, **kwargs) -> Iterator[DataFrame]:
176193
"""Fetches Pandas dataframes in batches, where batch refers to Snowflake Chunk.

test/integ/pandas/test_arrow_pandas.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,6 +1233,20 @@ def test_simple_arrow_fetch(conn_cnx):
12331233
assert lo == rowcount
12341234

12351235

1236+
def test_arrow_zero_rows(conn_cnx):
1237+
with conn_cnx() as cnx:
1238+
with cnx.cursor() as cur:
1239+
cur.execute(SQL_ENABLE_ARROW)
1240+
cur.execute("select 1::NUMBER(38,0) limit 0")
1241+
table = cur.fetch_arrow_all(force_return_table=True)
1242+
# Snowflake will return an integer dtype with maximum bit-length if
1243+
# no rows are returned
1244+
assert table.schema[0].type == pyarrow.int64()
1245+
cur.execute("select 1::NUMBER(38,0) limit 0")
1246+
# test default behavior
1247+
assert cur.fetch_arrow_all(force_return_table=False) is None
1248+
1249+
12361250
@pytest.mark.parametrize("fetch_fn_name", ["to_arrow", "to_pandas", "create_iter"])
12371251
@pytest.mark.parametrize("pass_connection", [True, False])
12381252
def test_sessions_used(conn_cnx, fetch_fn_name, pass_connection):

0 commit comments

Comments
 (0)