Commit 9ddfa23

SNOW-2062305 process pool batch fetcher (#2365)
1 parent 2583452 commit 9ddfa23

File tree

7 files changed: +101 -22 lines changed

- DESCRIPTION.md
- src/snowflake/connector/connection.py
- src/snowflake/connector/cursor.py
- src/snowflake/connector/result_batch.py
- src/snowflake/connector/result_set.py
- test/integ/pandas_it/test_arrow_pandas.py
- test/integ/test_cursor.py

DESCRIPTION.md

Lines changed: 2 additions & 1 deletion

@@ -11,7 +11,8 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
   - Bumped numpy dependency from <2.1.0 to <=2.2.4
   - Added Windows support for Python 3.13.
   - Add `bulk_upload_chunks` parameter to `write_pandas` function. Setting this parameter to True changes the behaviour of write_pandas function to first write all the data chunks to the local disk and then perform the wildcard upload of the chunks folder to the stage. In default behaviour the chunks are being saved, uploaded and deleted one by one.
-  - Added support for new authentication mechanism PAT with external session ID
+  - Added support for new authentication mechanism PAT with external session ID.
+  - Added `client_fetch_use_mp` parameter that enables multiprocessed fetching of result batches.

 - v3.15.1(May 20, 2025)
   - Added basic arrow support for Interval types.
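
The new flag is opt-in at connect time. A minimal usage sketch, assuming placeholder credentials (account/user/password below are hypothetical); client_fetch_threads is optional and, per the connection.py change below, also sizes the worker pool:

    import snowflake.connector

    # account/user/password are hypothetical placeholders.
    with snowflake.connector.connect(
        account="my_account",
        user="my_user",
        password="...",
        client_fetch_use_mp=True,  # download result batches in worker processes
        client_fetch_threads=8,    # optional; otherwise client_prefetch_threads is used
    ) as cnx:
        with cnx.cursor() as cur:
            cur.execute("select seq8() from table(generator(rowcount => 1000000))")
            rows = cur.fetchall()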

src/snowflake/connector/connection.py

Lines changed: 8 additions & 1 deletion

@@ -229,6 +229,7 @@ def _get_private_bytes_from_file(
     ),  # snowflake
     "client_prefetch_threads": (4, int),  # snowflake
     "client_fetch_threads": (None, (type(None), int)),
+    "client_fetch_use_mp": (False, bool),
     "numpy": (False, bool),  # snowflake
     "ocsp_response_cache_filename": (None, (type(None), str)),  # snowflake internal
     "converter_class": (DefaultConverterClass(), SnowflakeConverter),
@@ -428,7 +429,9 @@ class SnowflakeConnection:
         See the backoff_policies module for details and implementation examples.
         client_session_keep_alive_heartbeat_frequency: Heartbeat frequency to keep connection alive in seconds.
         client_prefetch_threads: Number of threads to download the result set.
-        client_fetch_threads: Number of threads to fetch staged query results.
+        client_fetch_threads: Number of threads (or processes) to fetch staged query results.
+            If not specified, reuses client_prefetch_threads value.
+        client_fetch_use_mp: Enables multiprocessing for fetching query results in parallel.
         rest: Snowflake REST API object. Internal use only. Maybe removed in a later release.
         application: Application name to communicate with Snowflake as. By default, this is "PythonConnector".
         errorhandler: Handler used with errors. By default, an exception will be raised on error.
@@ -701,6 +704,10 @@ def client_fetch_threads(self, value: None | int) -> None:
         value = min(max(1, value), MAX_CLIENT_FETCH_THREADS)
         self._client_fetch_threads = value

+    @property
+    def client_fetch_use_mp(self) -> bool:
+        return self._client_fetch_use_mp
+
     @property
     def rest(self) -> SnowflakeRestful | None:
         return self._rest
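
Each entry in the configuration table above pairs a default value with the accepted type(s), which is why registering the new parameter takes one line. A simplified sketch of how such a table can drive validation (illustrative only, not the connector's actual validation code):

    # Illustrative subset of the (default, accepted_types) table.
    DEFAULT_CONFIGURATION = {
        "client_fetch_threads": (None, (type(None), int)),
        "client_fetch_use_mp": (False, bool),
    }

    def apply_config(overrides: dict) -> dict:
        config = {}
        for name, (default, accepted) in DEFAULT_CONFIGURATION.items():
            value = overrides.get(name, default)
            if not isinstance(value, accepted):
                raise TypeError(f"{name} expects {accepted}, got {type(value).__name__}")
            config[name] = value
        return config

    print(apply_config({"client_fetch_use_mp": True}))
    # {'client_fetch_threads': None, 'client_fetch_use_mp': True}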

src/snowflake/connector/cursor.py

Lines changed: 1 addition & 0 deletions

@@ -1212,6 +1212,7 @@ def _init_result_and_meta(self, data: dict[Any, Any]) -> None:
             result_chunks,
             self._connection.client_fetch_threads
             or self._connection.client_prefetch_threads,
+            self._connection.client_fetch_use_mp,
         )
         self._rownumber = -1
         self._result_state = ResultState.VALID
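
The second argument leans on Python's `or` fallback: client_fetch_threads defaults to None, so the connection's prefetch thread count (default 4) applies unless it was set explicitly. A toy illustration of that fallback:

    def effective_fetch_workers(client_fetch_threads=None, client_prefetch_threads=4):
        # None is falsy, so `or` yields the prefetch thread count.
        return client_fetch_threads or client_prefetch_threads

    assert effective_fetch_workers() == 4
    assert effective_fetch_workers(8) == 8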

src/snowflake/connector/result_batch.py

Lines changed: 40 additions & 7 deletions

@@ -8,6 +8,8 @@
 from logging import getLogger
 from typing import TYPE_CHECKING, Any, Callable, Iterator, NamedTuple, Sequence

+from typing_extensions import Self
+
 from .arrow_context import ArrowConverterContext
 from .backoff_policies import exponential_backoff
 from .compat import OK, UNAUTHORIZED, urlparse
@@ -413,6 +415,14 @@ def to_pandas(self) -> DataFrame:
     def to_arrow(self) -> Table:
         raise NotImplementedError()

+    @abc.abstractmethod
+    def populate_data(
+        self, connection: SnowflakeConnection | None = None, **kwargs
+    ) -> Self:
+        """Downloads the data that the ``ResultBatch`` is pointing at and populates it into self._data.
+        Returns the instance itself."""
+        raise NotImplementedError()
+

 class JSONResultBatch(ResultBatch):
     def __init__(
@@ -538,11 +548,9 @@ def _parse(
     def __repr__(self) -> str:
         return f"JSONResultChunk({self.id})"

-    def create_iter(
+    def _fetch_data(
         self, connection: SnowflakeConnection | None = None, **kwargs
-    ) -> Iterator[dict | Exception] | Iterator[tuple | Exception]:
-        if self._local:
-            return iter(self._data)
+    ) -> list[dict | Exception] | list[tuple | Exception]:
         response = self._download(connection=connection)
         # Load data to a intermediate form
         logger.debug(f"started loading result batch id: {self.id}")
@@ -554,7 +562,20 @@ def create_iter(
         with TimerContextManager() as parse_metric:
             parsed_data = self._parse(downloaded_data)
         self._metrics[DownloadMetrics.parse.value] = parse_metric.get_timing_millis()
-        return iter(parsed_data)
+        return parsed_data
+
+    def populate_data(
+        self, connection: SnowflakeConnection | None = None, **kwargs
+    ) -> Self:
+        self._data = self._fetch_data(connection=connection, **kwargs)
+        return self
+
+    def create_iter(
+        self, connection: SnowflakeConnection | None = None, **kwargs
+    ) -> Iterator[dict | Exception] | Iterator[tuple | Exception]:
+        if self._local:
+            return iter(self._data)
+        return iter(self._fetch_data(connection=connection, **kwargs))

     def _arrow_fetching_error(self):
         return NotSupportedError(
@@ -613,7 +634,10 @@ def _load(
         )

     def _from_data(
-        self, data: str, iter_unit: IterUnit, check_error_on_every_column: bool = True
+        self,
+        data: str | bytes,
+        iter_unit: IterUnit,
+        check_error_on_every_column: bool = True,
     ) -> Iterator[dict | Exception] | Iterator[tuple | Exception]:
         """Creates a ``PyArrowIterator`` files from a str.

@@ -623,8 +647,11 @@ def _from_data(
         if len(data) == 0:
             return iter([])

+        if isinstance(data, str):
+            data = b64decode(data)
+
         return _create_nanoarrow_iterator(
-            b64decode(data),
+            data,
             self._context,
             self._use_dict_result,
             self._numpy,
@@ -751,3 +778,9 @@ def create_iter(
             return self._get_arrow_iter(connection=connection)
         else:
             return self._create_iter(iter_unit=iter_unit, connection=connection)
+
+    def populate_data(
+        self, connection: SnowflakeConnection | None = None, **kwargs
+    ) -> Self:
+        self._data = self._download(connection=connection).content
+        return self
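
populate_data returns the instance so a batch can round-trip through a process pool: the pickled batch downloads its data inside a worker process, the populated copy is pickled back to the parent, and create_iter then reads from self._data without a network call. A self-contained sketch of that pattern, using a toy Batch class rather than the connector's real ones:

    from concurrent.futures import ProcessPoolExecutor

    class Batch:
        def __init__(self, url: str) -> None:
            self.url = url
            self._data = None

        def populate_data(self) -> "Batch":
            # Stand-in for the real HTTP download; runs in the worker process.
            self._data = [f"row from {self.url}"]
            return self  # pickled back to the parent with _data filled in

        def create_iter(self):
            return iter(self._data)

    if __name__ == "__main__":
        batches = [Batch(f"https://stage/chunk{i}") for i in range(4)]
        with ProcessPoolExecutor(2) as pool:
            for fetched in pool.map(Batch.populate_data, batches):
                print(list(fetched.create_iter()))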

src/snowflake/connector/result_set.py

Lines changed: 40 additions & 8 deletions

@@ -2,7 +2,7 @@

 import inspect
 from collections import deque
-from concurrent.futures import ALL_COMPLETED, Future, wait
+from concurrent.futures import ALL_COMPLETED, Future, ProcessPoolExecutor, wait
 from concurrent.futures.thread import ThreadPoolExecutor
 from logging import getLogger
 from typing import (
@@ -44,6 +44,7 @@ def result_set_iterator(
     unfetched_batches: Deque[ResultBatch],
     final: Callable[[], None],
     prefetch_thread_num: int,
+    use_mp: bool,
     **kw: Any,
 ) -> Iterator[dict | Exception] | Iterator[tuple | Exception] | Iterator[Table]:
     """Creates an iterator over some other iterators.
@@ -58,26 +59,52 @@ def result_set_iterator(
     to continue iterating through the rest of the ``ResultBatch``.
     """
     is_fetch_all = kw.pop("is_fetch_all", False)
+
+    if use_mp:
+
+        def create_pool_executor() -> ProcessPoolExecutor:
+            return ProcessPoolExecutor(prefetch_thread_num)
+
+        def create_fetch_task(batch: ResultBatch):
+            return batch.populate_data
+
+        def get_fetch_result(future_result: ResultBatch):
+            return future_result.create_iter(**kw)
+
+        kw["connection"] = None
+    else:
+
+        def create_pool_executor() -> ThreadPoolExecutor:
+            return ThreadPoolExecutor(prefetch_thread_num)
+
+        def create_fetch_task(batch: ResultBatch):
+            return batch.create_iter
+
+        def get_fetch_result(future_result: Iterator):
+            return future_result
+
     if is_fetch_all:
-        with ThreadPoolExecutor(prefetch_thread_num) as pool:
+        with create_pool_executor() as pool:
             logger.debug("beginning to schedule result batch downloads")
             yield from first_batch_iter
             while unfetched_batches:
                 logger.debug(
                     f"queuing download of result batch id: {unfetched_batches[0].id}"
                 )
-                future = pool.submit(unfetched_batches.popleft().create_iter, **kw)
+                future = pool.submit(
+                    create_fetch_task(unfetched_batches.popleft()), **kw
+                )
                 unconsumed_batches.append(future)
             _, _ = wait(unconsumed_batches, return_when=ALL_COMPLETED)
             i = 1
             while unconsumed_batches:
                 logger.debug(f"user began consuming result batch {i}")
-                yield from unconsumed_batches.popleft().result()
+                yield from get_fetch_result(unconsumed_batches.popleft().result())
                 logger.debug(f"user began consuming result batch {i}")
                 i += 1
         final()
     else:
-        with ThreadPoolExecutor(prefetch_thread_num) as pool:
+        with create_pool_executor() as pool:
             # Fill up window

             logger.debug("beginning to schedule result batch downloads")
@@ -87,7 +114,7 @@ def result_set_iterator(
                     f"queuing download of result batch id: {unfetched_batches[0].id}"
                 )
                 unconsumed_batches.append(
-                    pool.submit(unfetched_batches.popleft().create_iter, **kw)
+                    pool.submit(create_fetch_task(unfetched_batches.popleft()), **kw)
                 )

             yield from first_batch_iter
@@ -101,13 +128,15 @@ def result_set_iterator(
                     logger.debug(
                         f"queuing download of result batch id: {unfetched_batches[0].id}"
                     )
-                    future = pool.submit(unfetched_batches.popleft().create_iter, **kw)
+                    future = pool.submit(
+                        create_fetch_task(unfetched_batches.popleft()), **kw
+                    )
                     unconsumed_batches.append(future)

                 future = unconsumed_batches.popleft()

                 # this will raise an exception if one has occurred
-                batch_iterator = future.result()
+                batch_iterator = get_fetch_result(future.result())

                 logger.debug(f"user began consuming result batch {i}")
                 yield from batch_iterator
@@ -136,10 +165,12 @@ def __init__(
         cursor: SnowflakeCursor,
         result_chunks: list[JSONResultBatch] | list[ArrowResultBatch],
         prefetch_thread_num: int,
+        use_mp: bool,
     ) -> None:
         self.batches = result_chunks
         self._cursor = cursor
         self.prefetch_thread_num = prefetch_thread_num
+        self._use_mp = use_mp

     def _report_metrics(self) -> None:
         """Report all metrics totalled up.
@@ -276,6 +307,7 @@ def _create_iter(
             self._finish_iterating,
             self.prefetch_thread_num,
             is_fetch_all=is_fetch_all,
+            use_mp=self._use_mp,
             **kwargs,
         )
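
Both branches above expose the same three hooks (create_pool_executor, create_fetch_task, get_fetch_result), so the scheduling loops never test use_mp themselves. Note kw["connection"] = None in the multiprocessing branch: a live connection cannot be pickled into a worker, so each batch downloads without one. A condensed sketch of the dispatch, with hook names mirroring the diff and the loops omitted:

    from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

    def make_hooks(use_mp: bool, workers: int):
        """Return (pool factory, task selector, result adapter) for the fetch loop."""
        if use_mp:
            return (
                lambda: ProcessPoolExecutor(workers),
                lambda batch: batch.populate_data,          # runs in a worker process
                lambda populated: populated.create_iter(),  # iterate in the parent
            )
        return (
            lambda: ThreadPoolExecutor(workers),
            lambda batch: batch.create_iter,  # thread downloads and builds the iterator
            lambda iterator: iterator,        # the future's result is already an iterator
        )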

test/integ/pandas_it/test_arrow_pandas.py

Lines changed: 6 additions & 4 deletions

@@ -1285,9 +1285,10 @@ def test_to_arrow_datatypes(enable_structured_types, conn_cnx):
             cur.execute(f"alter session unset {param}")


-def test_simple_arrow_fetch(conn_cnx):
+@pytest.mark.parametrize("client_fetch_use_mp", [False, True])
+def test_simple_arrow_fetch(conn_cnx, client_fetch_use_mp):
     rowcount = 250_000
-    with conn_cnx() as cnx:
+    with conn_cnx(client_fetch_use_mp=client_fetch_use_mp) as cnx:
         with cnx.cursor() as cur:
             cur.execute(SQL_ENABLE_ARROW)
             cur.execute(
@@ -1316,8 +1317,9 @@ def test_simple_arrow_fetch(conn_cnx):
             assert lo == rowcount


-def test_arrow_zero_rows(conn_cnx):
-    with conn_cnx() as cnx:
+@pytest.mark.parametrize("client_fetch_use_mp", [False, True])
+def test_arrow_zero_rows(conn_cnx, client_fetch_use_mp):
+    with conn_cnx(client_fetch_use_mp=client_fetch_use_mp) as cnx:
         with cnx.cursor() as cur:
             cur.execute(SQL_ENABLE_ARROW)
             cur.execute("select 1::NUMBER(38,0) limit 0")

test/integ/test_cursor.py

Lines changed: 4 additions & 1 deletion

@@ -1581,11 +1581,13 @@ def test__log_telemetry_job_data(conn_cnx, caplog):
         ("arrow", ArrowResultBatch),
     ),
 )
+@pytest.mark.parametrize("client_fetch_use_mp", [False, True])
 def test_resultbatch(
     conn_cnx,
     result_format,
     expected_chunk_type,
     capture_sf_telemetry,
+    client_fetch_use_mp,
 ):
     """This test checks the following things:
     1. After executing a query can we pickle the result batches
@@ -1598,7 +1600,8 @@ def test_resultbatch(
     with conn_cnx(
         session_parameters={
             "python_connector_query_result_format": result_format,
-        }
+        },
+        client_fetch_use_mp=client_fetch_use_mp,
     ) as con:
         with capture_sf_telemetry.patch_connection(con) as telemetry_data:
             with con.cursor() as cur:
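
Both test files assume the conn_cnx fixture forwards extra keyword arguments to connect(). A minimal sketch of such a fixture, assuming a hypothetical db_parameters credentials dict (the suite's real fixture lives in its conftest):

    import contextlib

    import pytest
    import snowflake.connector

    @pytest.fixture
    def conn_cnx(db_parameters):
        @contextlib.contextmanager
        def _cnx(**kwargs):
            # Extra kwargs such as client_fetch_use_mp flow into connect().
            with snowflake.connector.connect(**db_parameters, **kwargs) as conn:
                yield conn

        return _cnx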
