
Commit 92e7162

SNOW-2097818: use multithread as the default implementation for dbapi (#3491)
1 parent 8839590 commit 92e7162

4 files changed, +567 -107 lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,8 @@
 - Added debuggability improvements to eagerly validate dataframe schema metadata. Enable it using `snowflake.snowpark.context.configure_development_features()`.
 - Added a new function `snowflake.snowpark.dataframe.map_in_pandas` that allows users map a function across a dataframe. The mapping function takes an iterator of pandas dataframes as input and provides one as output.
 - Added a ttl cache to describe queries. Repeated queries in a 15 second interval will use the cached value rather than requery Snowflake.
+- Added a parameter `fetch_with_process` to `DataFrameReader.dbapi` (PrPr) to enable multiprocessing for parallel data fetching in
+  local ingestion. By default, local ingestion uses multithreading. Multiprocessing may improve performance for CPU-bound tasks like Parquet file generation.
 
 #### Improvements
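For reference, a minimal usage sketch of the behavior described in the changelog entry above; the `sqlite3` connection factory and table name are placeholders (standing in for whatever DB-API connection factory you use), and only the `fetch_with_process` parameter comes from this commit.

```python
# Hypothetical usage sketch; create_connection and "MY_TABLE" are placeholders,
# only the fetch_with_process flag is introduced by this commit.
import multiprocessing
import sqlite3


def create_connection():
    # Stand-in for any DB-API connection factory (oracledb, pyodbc, ...).
    return sqlite3.connect("example.db")


def load(session):
    # Default path: partitions are fetched in parallel worker threads.
    df_threads = session.read.dbapi(create_connection, table="MY_TABLE")

    # Opt-in path: separate worker processes, which may help CPU-bound
    # Parquet file generation.
    df_processes = session.read.dbapi(
        create_connection, table="MY_TABLE", fetch_with_process=True
    )
    return df_threads, df_processes


if __name__ == "__main__":
    # Guarding the entry point matters when fetch_with_process=True spawns
    # processes; freeze_support() is only needed for frozen Windows executables.
    multiprocessing.freeze_support()
```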

src/snowflake/snowpark/_internal/data_source/utils.py

Lines changed: 66 additions & 20 deletions
@@ -5,12 +5,13 @@
 import queue
 import time
 import traceback
+import threading
 import multiprocessing as mp
 from concurrent.futures import ThreadPoolExecutor
 from threading import BoundedSemaphore
 from io import BytesIO
 from enum import Enum
-from typing import Any, Tuple, Optional, Callable, Dict
+from typing import Any, Tuple, Optional, Callable, Dict, Union
 import logging
 from snowflake.snowpark._internal.data_source.dbms_dialects import (
     Sqlite3Dialect,
@@ -151,7 +152,8 @@ def _task_fetch_data_from_source(
     worker: DataSourceReader,
     partition: str,
     partition_idx: int,
-    parquet_queue: mp.Queue,
+    parquet_queue: Union[mp.Queue, queue.Queue],
+    stop_event: threading.Event = None,
 ):
     """
     Fetch data from source and convert to parquet BytesIO objects.
@@ -179,6 +181,16 @@ def convert_to_parquet_bytesio(fetched_data, fetch_idx):
         logger.debug(f"Added parquet BytesIO to queue: {parquet_id}")

     for i, result in enumerate(worker.read(partition)):
+        if stop_event and stop_event.is_set():
+            parquet_queue.put(
+                (
+                    PARTITION_TASK_ERROR_SIGNAL,
+                    SnowparkDataframeReaderException(
+                        "Data fetching stopped by thread failure"
+                    ),
+                )
+            )
+            break
         convert_to_parquet_bytesio(result, i)

     parquet_queue.put((f"{PARTITION_TASK_COMPLETE_SIGNAL_PREFIX}{partition_idx}", None))
@@ -188,14 +200,16 @@ def _task_fetch_data_from_source_with_retry(
     worker: DataSourceReader,
     partition: str,
     partition_idx: int,
-    parquet_queue: mp.Queue,
+    parquet_queue: Union[mp.Queue, queue.Queue],
+    stop_event: threading.Event = None,
 ):
     _retry_run(
         _task_fetch_data_from_source,
         worker,
         partition,
         partition_idx,
         parquet_queue,
+        stop_event,
     )


@@ -292,7 +306,12 @@ def _retry_run(func: Callable, *args, **kwargs) -> Any:


 # DBAPI worker function that processes multiple partitions
-def worker_process(partition_queue: mp.Queue, parquet_queue: mp.Queue, reader):
+def worker_process(
+    partition_queue: Union[mp.Queue, queue.Queue],
+    parquet_queue: Union[mp.Queue, queue.Queue],
+    reader,
+    stop_event: threading.Event = None,
+):
     """Worker process that fetches data from multiple partitions"""
     while True:
         try:
@@ -304,6 +323,7 @@ def worker_process(partition_queue: mp.Queue, parquet_queue: mp.Queue, reader):
                 query,
                 partition_idx,
                 parquet_queue,
+                stop_event,
             )
         except queue.Empty:
             # No more work available, exit gracefully
@@ -340,14 +360,15 @@ def process_completed_futures(thread_futures):

 def process_parquet_queue_with_threads(
     session: "snowflake.snowpark.Session",
-    parquet_queue: mp.Queue,
-    processes: list,
+    parquet_queue: Union[mp.Queue, queue.Queue],
+    workers: list,
     total_partitions: int,
     snowflake_stage_name: str,
     snowflake_table_name: str,
     max_workers: int,
     statements_params: Optional[Dict[str, str]] = None,
     on_error: str = "abort_statement",
+    fetch_with_process: bool = False,
 ) -> None:
     """
     Process parquet data from a multiprocessing queue using a thread pool.
@@ -361,7 +382,7 @@
     Args:
         session: Snowflake session for database operations
        parquet_queue: Multiprocessing queue containing parquet data
-        processes: List of worker processes to monitor
+        workers: List of worker processes or thread futures to monitor
        total_partitions: Total number of partitions expected
        snowflake_stage_name: Name of the Snowflake stage for uploads
        snowflake_table_name: Name of the target Snowflake table
@@ -424,19 +445,44 @@

         except queue.Empty:
             backpressure_semaphore.release()  # Release semaphore if no data was fetched
-            # Check if any processes have failed
-            for i, process in enumerate(processes):
-                if not process.is_alive() and process.exitcode != 0:
-                    raise SnowparkDataframeReaderException(
-                        f"Partition {i} data fetching process failed with exit code {process.exitcode}"
-                    )
+            if fetch_with_process:
+                # Check if any processes have failed
+                for i, process in enumerate(workers):
+                    if not process.is_alive() and process.exitcode != 0:
+                        raise SnowparkDataframeReaderException(
+                            f"Partition {i} data fetching process failed with exit code {process.exitcode}"
+                        )
+            else:
+                # Check if any threads have failed
+                for i, future in enumerate(workers):
+                    if future.done():
+                        try:
+                            future.result()
+                        except BaseException as e:
+                            if isinstance(e, SnowparkDataframeReaderException):
+                                raise e
+                            raise SnowparkDataframeReaderException(
+                                f"Partition {i} data fetching thread failed with error: {e}"
+                            )
             time.sleep(0.1)
             continue

-    # Wait for all processes to complete
-    for idx, process in enumerate(processes):
-        process.join()
-        if process.exitcode != 0:
-            raise SnowparkDataframeReaderException(
-                f"Partition {idx} data fetching process failed with exit code {process.exitcode}"
-            )
+    if fetch_with_process:
+        # Wait for all processes to complete
+        for idx, process in enumerate(workers):
+            process.join()
+            if process.exitcode != 0:
+                raise SnowparkDataframeReaderException(
+                    f"Partition {idx} data fetching process failed with exit code {process.exitcode}"
+                )
+    else:
+        # Wait for all threads to complete
+        for idx, future in enumerate(workers):
+            try:
+                future.result()
+            except BaseException as e:
+                if isinstance(e, SnowparkDataframeReaderException):
+                    raise e
+                raise SnowparkDataframeReaderException(
+                    f"Partition {idx} data fetching thread failed with error: {e}"
+                )
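The hunks above thread a `threading.Event` through the fetch workers so the consumer side can stop producers once one of them fails. As a rough illustration (not the Snowpark code), the following standalone sketch shows the same coordination pattern under assumed names (`worker`, `run`, `COMPLETE` are made up for the example): worker threads drain a partition queue, push results into a bounded result queue, and exit early when the stop event is set, while the consumer surfaces worker failures instead of waiting forever.

```python
# Standalone sketch of the producer/consumer pattern used above; names are illustrative.
import queue
import threading
from concurrent.futures import ThreadPoolExecutor

COMPLETE = "PARTITION_COMPLETE"


def worker(partitions: queue.Queue, results: queue.Queue, stop: threading.Event):
    while not stop.is_set():
        try:
            idx, part = partitions.get_nowait()
        except queue.Empty:
            return  # no more work available, exit gracefully
        for chunk in range(3):  # stand-in for reading a partition in batches
            if stop.is_set():
                return  # cooperative early exit when another worker failed
            results.put((idx, f"{part}-chunk{chunk}"))
        results.put((COMPLETE, idx))  # completion signal for this partition


def run(num_partitions: int = 4, max_workers: int = 2) -> int:
    partitions, results = queue.Queue(), queue.Queue(maxsize=8)  # bounded for backpressure
    stop = threading.Event()
    for i in range(num_partitions):
        partitions.put((i, f"partition-{i}"))
    done = 0
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(worker, partitions, results, stop) for _ in range(max_workers)]
        while done < num_partitions:
            try:
                signal, payload = results.get(timeout=0.1)
            except queue.Empty:
                # Mirror the failure check: surface a crashed worker instead of hanging.
                for f in futures:
                    if f.done() and f.exception():
                        stop.set()
                        raise f.exception()
                continue
            if signal == COMPLETE:
                done += 1
            else:
                pass  # consume/upload the chunk here
    return done


if __name__ == "__main__":
    print(run())
```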

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 72 additions & 20 deletions
@@ -5,9 +5,11 @@
 import os
 import sys
 import time
-
+import queue
+from concurrent.futures import ThreadPoolExecutor
 from logging import getLogger
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable
+import threading

 import snowflake.snowpark
 import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
@@ -1283,6 +1285,7 @@ def dbapi(
         session_init_statement: Optional[Union[str, List[str]]] = None,
         udtf_configs: Optional[dict] = None,
         fetch_merge_count: int = 1,
+        fetch_with_process: bool = False,
         _emit_ast: bool = True,
     ) -> DataFrame:
         """
@@ -1356,6 +1359,12 @@
                 before uploading it. This improves performance by reducing the number of
                 small Parquet files. Defaults to 1, meaning each `fetch_size` batch is written to its own
                 Parquet file and uploaded separately.
+            fetch_with_process: Whether to use multiprocessing for data fetching and Parquet file generation in local ingestion.
+                Defaults to `False`, which means multithreading is used to fetch data in parallel.
+                Setting this to `True` enables multiprocessing, which may improve performance for CPU-bound tasks
+                like Parquet file generation. When using multiprocessing, guard your script with
+                `if __name__ == "__main__":` and call `multiprocessing.freeze_support()` on Windows if needed.
+                This parameter has no effect in UDTF ingestion.

         Example::
             .. code-block:: python
@@ -1366,6 +1375,17 @@ def create_oracledb_connection():
                     return connection

                 df = session.read.dbapi(create_oracledb_connection, table=...)
+
+        Example::
+            .. code-block:: python
+
+                import oracledb
+                def create_oracledb_connection():
+                    connection = oracledb.connect(...)
+                    return connection
+
+                if __name__ == "__main__":
+                    df = session.read.dbapi(create_oracledb_connection, table=..., fetch_with_process=True)
         """
         if (not table and not query) or (table and query):
             raise SnowparkDataframeReaderException(
@@ -1444,59 +1464,91 @@ def create_oracledb_connection():
             statement_params=statements_params_for_telemetry, _emit_ast=False
         )

+        data_fetching_thread_pool_executor = None
+        data_fetching_thread_stop_event = None
+        workers = []
         try:
-            processes = []
-
-            # Determine the number of processes to use
-            max_workers = max_workers or mp.cpu_count()
-
+            # Determine the number of processes or threads to use
+            max_workers = max_workers or os.cpu_count()
+            queue_class = mp.Queue if fetch_with_process else queue.Queue
             # a queue of partitions to be processed, this is filled by the partitioner before starting the workers
-            partition_queue = mp.Queue()
+            partition_queue = queue_class()
             # a queue of parquet BytesIO objects to be uploaded
             # Set max size for parquet_queue to prevent overfilling when thread consumers are slower than process producers
             # process workers will block on this queue if it's full until the upload threads consume the BytesIO objects
-            parquet_queue = mp.Queue(_MAX_WORKER_SCALE * max_workers)
+            parquet_queue = queue_class(_MAX_WORKER_SCALE * max_workers)
             for partition_idx, query in enumerate(partitioned_queries):
                 partition_queue.put((partition_idx, query))

             # Start worker processes
             logger.debug(
                 f"Starting {max_workers} worker processes to fetch data from the data source."
             )
-            for _worker_id in range(max_workers):
-                process = mp.Process(
-                    target=worker_process,
-                    args=(partition_queue, parquet_queue, partitioner.reader()),
+
+            if fetch_with_process:
+                for _worker_id in range(max_workers):
+                    process = mp.Process(
+                        target=worker_process,
+                        args=(partition_queue, parquet_queue, partitioner.reader()),
+                    )
+                    process.start()
+                    workers.append(process)
+            else:
+                data_fetching_thread_pool_executor = ThreadPoolExecutor(
+                    max_workers=max_workers
                 )
-                process.start()
-                processes.append(process)
+                data_fetching_thread_stop_event = threading.Event()
+                workers = [
+                    data_fetching_thread_pool_executor.submit(
+                        worker_process,
+                        partition_queue,
+                        parquet_queue,
+                        partitioner.reader(),
+                        data_fetching_thread_stop_event,
+                    )
+                    for _worker_id in range(max_workers)
+                ]

             # Process BytesIO objects from queue and upload them using utility method
             process_parquet_queue_with_threads(
                 session=self._session,
                 parquet_queue=parquet_queue,
-                processes=processes,
+                workers=workers,
                 total_partitions=len(partitioned_queries),
                 snowflake_stage_name=snowflake_stage_name,
                 snowflake_table_name=snowflake_table_name,
                 max_workers=max_workers,
                 statements_params=statements_params_for_telemetry,
                 on_error="abort_statement",
+                fetch_with_process=fetch_with_process,
             )

         except BaseException as exc:
-            # Graceful shutdown - terminate all processes
-            for process in processes:
-                if process.is_alive():
-                    process.terminate()
-                    process.join(timeout=5)
+            if fetch_with_process:
+                # Graceful shutdown - terminate all processes
+                for process in workers:
+                    if process.is_alive():
+                        process.terminate()
+                        process.join(timeout=5)
+            else:
+                if data_fetching_thread_stop_event:
+                    data_fetching_thread_stop_event.set()
+                for future in workers:
+                    if not future.done():
+                        future.cancel()
+                        logger.debug(
+                            f"Cancelled a remaining data fetching future {future} due to error in another thread."
+                        )

             if isinstance(exc, SnowparkDataframeReaderException):
                 raise exc

             raise SnowparkDataframeReaderException(
                 f"Error occurred while ingesting data from the data source: {exc!r}"
             )
+        finally:
+            if data_fetching_thread_pool_executor:
+                data_fetching_thread_pool_executor.shutdown(wait=True)

         logger.debug("All data has been successfully loaded into the Snowflake table.")
         self._session._conn._telemetry_client.send_data_source_perf_telemetry(