
Commit 184d5c7

SNOW-806291: Fix Snowpark being much slower than the Spark Connector on Databricks (for collect() and toPandas()) (#1946)
1 parent 0d00519 commit 184d5c7

File tree

- DESCRIPTION.md
- src/snowflake/connector/result_set.py

2 files changed (+59 / -32 lines)


DESCRIPTION.md

Lines changed: 4 additions & 0 deletions

```diff
@@ -8,6 +8,10 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne
 
 # Release Notes
 
+- v3.12.0(TBD)
+  - Optimized `to_pandas()` performance by fully parallel downloading logic.
+
+
 - v3.11.0(June 17,2024)
   - Added support for `token_file_path` connection parameter to read an OAuth token from a file when connecting to Snowflake.
   - Added support for `debug_arrow_chunk` connection parameter to allow debugging raw arrow data in case of arrow data parsing failure.
```
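
For context, this optimization is exercised through the cursor's pandas APIs. A minimal usage sketch, assuming placeholder credentials and an illustrative multi-chunk query:

```python
import snowflake.connector

# Placeholder connection parameters -- substitute real account details.
with snowflake.connector.connect(
    account="<account>",
    user="<user>",
    password="<password>",
    warehouse="<warehouse>",
) as conn:
    with conn.cursor() as cur:
        # A large result set that spans many chunks; each chunk becomes
        # a ResultBatch that can now be downloaded in parallel.
        cur.execute(
            "SELECT seq4() AS n FROM TABLE(GENERATOR(ROWCOUNT => 5000000))"
        )
        # fetch_pandas_all() downloads all remaining batches concurrently
        # before concatenating them into a single DataFrame.
        df = cur.fetch_pandas_all()

print(df.shape)
```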

src/snowflake/connector/result_set.py

Lines changed: 55 additions & 32 deletions

```diff
@@ -6,7 +6,7 @@
 
 import inspect
 from collections import deque
-from concurrent.futures import Future
+from concurrent.futures import ALL_COMPLETED, Future, wait
 from concurrent.futures.thread import ThreadPoolExecutor
 from logging import getLogger
 from typing import (
@@ -61,45 +61,64 @@ def result_set_iterator(
     Just like ``ResultBatch`` iterator, this might yield an ``Exception`` to allow users
     to continue iterating through the rest of the ``ResultBatch``.
     """
-
-    with ThreadPoolExecutor(prefetch_thread_num) as pool:
-        # Fill up window
-
-        logger.debug("beginning to schedule result batch downloads")
-
-        for _ in range(min(prefetch_thread_num, len(unfetched_batches))):
-            logger.debug(
-                f"queuing download of result batch id: {unfetched_batches[0].id}"
-            )
-            unconsumed_batches.append(
-                pool.submit(unfetched_batches.popleft().create_iter, **kw)
-            )
-
-        yield from first_batch_iter
-
-        i = 1
-        while unconsumed_batches:
-            logger.debug(f"user requesting to consume result batch {i}")
-
-            # Submit the next un-fetched batch to the pool
-            if unfetched_batches:
+    is_fetch_all = kw.pop("is_fetch_all", False)
+    if is_fetch_all:
+        with ThreadPoolExecutor(prefetch_thread_num) as pool:
+            logger.debug("beginning to schedule result batch downloads")
+            yield from first_batch_iter
+            while unfetched_batches:
                 logger.debug(
                     f"queuing download of result batch id: {unfetched_batches[0].id}"
                 )
                 future = pool.submit(unfetched_batches.popleft().create_iter, **kw)
                 unconsumed_batches.append(future)
+            _, _ = wait(unconsumed_batches, return_when=ALL_COMPLETED)
+            i = 1
+            while unconsumed_batches:
+                logger.debug(f"user began consuming result batch {i}")
+                yield from unconsumed_batches.popleft().result()
+                logger.debug(f"user finished consuming result batch {i}")
+                i += 1
+        final()
+    else:
+        with ThreadPoolExecutor(prefetch_thread_num) as pool:
+            # Fill up window
+
+            logger.debug("beginning to schedule result batch downloads")
+
+            for _ in range(min(prefetch_thread_num, len(unfetched_batches))):
+                logger.debug(
+                    f"queuing download of result batch id: {unfetched_batches[0].id}"
+                )
+                unconsumed_batches.append(
+                    pool.submit(unfetched_batches.popleft().create_iter, **kw)
+                )
 
-            future = unconsumed_batches.popleft()
+            yield from first_batch_iter
 
-            # this will raise an exception if one has occurred
-            batch_iterator = future.result()
+            i = 1
+            while unconsumed_batches:
+                logger.debug(f"user requesting to consume result batch {i}")
 
-            logger.debug(f"user began consuming result batch {i}")
-            yield from batch_iterator
-            logger.debug(f"user finished consuming result batch {i}")
+                # Submit the next un-fetched batch to the pool
+                if unfetched_batches:
+                    logger.debug(
+                        f"queuing download of result batch id: {unfetched_batches[0].id}"
+                    )
+                    future = pool.submit(unfetched_batches.popleft().create_iter, **kw)
+                    unconsumed_batches.append(future)
 
-            i += 1
-    final()
+                future = unconsumed_batches.popleft()
+
+                # this will raise an exception if one has occurred
+                batch_iterator = future.result()
+
+                logger.debug(f"user began consuming result batch {i}")
+                yield from batch_iterator
+                logger.debug(f"user finished consuming result batch {i}")
+
+                i += 1
+        final()
 
 
 class ResultSet(Iterable[list]):
```
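
The heart of the change is the scheduling strategy in `result_set_iterator`: the pre-existing branch keeps a sliding window of at most `prefetch_thread_num` in-flight downloads, while the new `is_fetch_all` branch submits every batch immediately and blocks once on `wait(..., return_when=ALL_COMPLETED)` before yielding. A self-contained sketch of both strategies; `download` is a stand-in for `ResultBatch.create_iter`, not a connector API:

```python
from collections import deque
from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor, wait
import time


def download(batch_id: int) -> list[int]:
    """Stand-in for ResultBatch.create_iter(): pretends to fetch one batch."""
    time.sleep(0.05)
    return [batch_id] * 3


def fetch_all_rows(batch_ids: list[int], workers: int = 4) -> list[int]:
    """New is_fetch_all strategy: submit every batch up front, wait for all
    downloads to finish, then drain the results in order."""
    rows: list[int] = []
    with ThreadPoolExecutor(workers) as pool:
        futures = deque(pool.submit(download, b) for b in batch_ids)
        wait(futures, return_when=ALL_COMPLETED)  # downloads overlap fully
        while futures:
            rows.extend(futures.popleft().result())
    return rows


def fetch_windowed_rows(batch_ids: list[int], workers: int = 4) -> list[int]:
    """Pre-existing streaming strategy: keep at most `workers` downloads in
    flight and top up the window each time a batch is consumed."""
    unfetched = deque(batch_ids)
    in_flight: deque = deque()
    rows: list[int] = []
    with ThreadPoolExecutor(workers) as pool:
        for _ in range(min(workers, len(unfetched))):
            in_flight.append(pool.submit(download, unfetched.popleft()))
        while in_flight:
            if unfetched:  # schedule the next batch before consuming one
                in_flight.append(pool.submit(download, unfetched.popleft()))
            rows.extend(in_flight.popleft().result())
    return rows


batches = list(range(8))
assert fetch_all_rows(batches) == fetch_windowed_rows(batches)
```

The fetch-all branch trades memory (every batch is held at once) for throughput, which suits `fetch_pandas_all()` since it materializes the whole result anyway. The remaining hunks, below, wire the new flag into that pandas fetch path.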
```diff
@@ -202,7 +221,7 @@ def _fetch_pandas_all(self, **kwargs) -> DataFrame:
         """Fetches a single Pandas dataframe."""
         concat_args = list(inspect.signature(pandas.concat).parameters)
         concat_kwargs = {k: kwargs.pop(k) for k in dict(kwargs) if k in concat_args}
-        dataframes = list(self._fetch_pandas_batches(**kwargs))
+        dataframes = list(self._fetch_pandas_batches(is_fetch_all=True, **kwargs))
         if dataframes:
             return pandas.concat(
                 dataframes,
```
```diff
@@ -238,6 +257,9 @@ def _create_iter(
         This function is a helper function to ``__iter__`` and it was introduced for the
         cases where we need to propagate some values to later ``_download`` calls.
         """
+        # pop is_fetch_all and pass it to result_set_iterator
+        is_fetch_all = kwargs.pop("is_fetch_all", False)
+
         # add connection so that result batches can use sessions
         kwargs["connection"] = self._cursor.connection
 
@@ -257,6 +279,7 @@
             unfetched_batches,
             self._finish_iterating,
             self.prefetch_thread_num,
+            is_fetch_all=is_fetch_all,
             **kwargs,
         )
 
```
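
These last three hunks plumb the flag through the call chain: `_fetch_pandas_all` sets `is_fetch_all=True`, and `_create_iter` pops it out of `kwargs` before forwarding the rest, so downstream calls never receive an unexpected keyword. A toy sketch of that relay pattern (the function names echo the connector's, but the bodies are purely illustrative):

```python
def result_set_iterator(**kw) -> str:
    # End of the chain: consume the flag with a safe default.
    is_fetch_all = kw.pop("is_fetch_all", False)
    return "fetch-all" if is_fetch_all else "windowed"


def _create_iter(**kwargs) -> str:
    # Pop the flag so other consumers of kwargs never see it,
    # then forward it explicitly as a keyword argument.
    is_fetch_all = kwargs.pop("is_fetch_all", False)
    return result_set_iterator(is_fetch_all=is_fetch_all, **kwargs)


def _fetch_pandas_batches(**kwargs) -> str:
    # Middle of the chain: passes the flag through untouched.
    return _create_iter(**kwargs)


assert _fetch_pandas_batches(is_fetch_all=True) == "fetch-all"
assert _fetch_pandas_batches() == "windowed"
```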
