
Commit fb11f03

SNOW-2675533: extract local ingestion into Utils (#3987)
1 parent c76802f commit fb11f03

File tree: 3 files changed (+149, -122 lines)

3 files changed

+149
-122
lines changed

src/snowflake/snowpark/_internal/data_source/utils.py

Lines changed: 133 additions & 2 deletions
@@ -12,7 +12,7 @@
 from threading import BoundedSemaphore
 from io import BytesIO
 from enum import Enum
-from typing import Any, Tuple, Optional, Callable, Dict, Union, Set
+from typing import Any, Tuple, Optional, Callable, Dict, Union, Set, List
 import logging
 from snowflake.snowpark._internal.data_source.dbms_dialects import (
     Sqlite3Dialect,
@@ -30,13 +30,17 @@
     Psycopg2Driver,
     PymysqlDriver,
 )
-import snowflake
 from snowflake.snowpark._internal.data_source import DataSourceReader
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.snowpark._internal.utils import get_temp_type_for_object
 from snowflake.snowpark.exceptions import SnowparkDataframeReaderException
 from snowflake.snowpark.types import StructType
 
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    import snowflake.snowpark
+
 logger = logging.getLogger(__name__)
 
 _MAX_RETRY_TIME = 3
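
The hunk above replaces the unconditional "import snowflake" with a TYPE_CHECKING-guarded import: the string annotations on the new helper (for example "snowflake.snowpark.Session") stay resolvable for static type checkers while nothing extra is imported at runtime, which is the usual way to avoid a circular import. A minimal, self-contained sketch of the pattern (the describe function is illustrative only, not part of this commit):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright), never at runtime,
    # so no circular-import or import-time cost is paid.
    import snowflake.snowpark


def describe(session: "snowflake.snowpark.Session") -> str:
    # The annotation is a plain string, so it is not resolved at runtime either.
    return f"got a {type(session).__name__} instance"
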
@@ -600,3 +604,130 @@ def track_data_source_statement_params(
     statement_params[STATEMENT_PARAMS_DATA_SOURCE] = "1"
 
     return statement_params if statement_params else None
+
+
+def local_ingestion(
+    session: "snowflake.snowpark.Session",
+    partitioner: "snowflake.snowpark._internal.data_source.datasource_partitioner.DataSourcePartitioner",
+    partitioned_queries: List[str],
+    max_workers: int,
+    snowflake_stage_name: str,
+    snowflake_table_name: str,
+    statements_params_for_telemetry: Dict,
+    telemetry_json_string: Dict,
+    fetch_with_process: bool = False,
+    _emit_ast: bool = True,
+) -> None:
+    data_fetching_thread_pool_executor = None
+    data_fetching_thread_stop_event = None
+    workers = []
+    try:
+        # Determine the number of processes or threads to use
+        max_workers = max_workers or os.cpu_count()
+        queue_class = mp.Queue if fetch_with_process else queue.Queue
+        process_or_thread_error_indicator = queue_class()
+        # a queue of partitions to be processed, this is filled by the partitioner before starting the workers
+        partition_queue = queue_class()
+        # a queue of parquet BytesIO objects to be uploaded
+        # Set max size for parquet_queue to prevent overfilling when thread consumers are slower than process producers
+        # process workers will block on this queue if it's full until the upload threads consume the BytesIO objects
+        parquet_queue = queue_class(_MAX_WORKER_SCALE * max_workers)
+        for partition_idx, query in enumerate(partitioned_queries):
+            partition_queue.put((partition_idx, query))
+
+        # Start worker processes
+        logger.debug(
+            f"Starting {max_workers} worker processes to fetch data from the data source."
+        )
+
+        fetch_to_local_start_time = time.perf_counter()
+        logger.debug(f"fetch to local start at: {fetch_to_local_start_time}")
+
+        if fetch_with_process:
+            for _worker_id in range(max_workers):
+                process = mp.Process(
+                    target=worker_process,
+                    args=(
+                        partition_queue,
+                        parquet_queue,
+                        process_or_thread_error_indicator,
+                        partitioner.reader(),
+                    ),
+                )
+                process.start()
+                workers.append(process)
+        else:
+            data_fetching_thread_pool_executor = ThreadPoolExecutor(
+                max_workers=max_workers
+            )
+            data_fetching_thread_stop_event = threading.Event()
+            workers = [
+                data_fetching_thread_pool_executor.submit(
+                    worker_process,
+                    partition_queue,
+                    parquet_queue,
+                    process_or_thread_error_indicator,
+                    partitioner.reader(),
+                    data_fetching_thread_stop_event,
+                )
+                for _worker_id in range(max_workers)
+            ]
+
+        # Process BytesIO objects from queue and upload them using utility method
+        (
+            fetch_to_local_end_time,
+            upload_to_sf_start_time,
+            upload_to_sf_end_time,
+        ) = process_parquet_queue_with_threads(
+            session=session,
+            parquet_queue=parquet_queue,
+            process_or_thread_error_indicator=process_or_thread_error_indicator,
+            workers=workers,
+            total_partitions=len(partitioned_queries),
+            snowflake_stage_name=snowflake_stage_name,
+            snowflake_table_name=snowflake_table_name,
+            max_workers=max_workers,
+            statements_params=statements_params_for_telemetry,
+            on_error="abort_statement",
+            fetch_with_process=fetch_with_process,
+        )
+        logger.debug(f"upload and copy into start at: {upload_to_sf_start_time}")
+        logger.debug(
+            f"fetch to local total time: {fetch_to_local_end_time - fetch_to_local_start_time}"
+        )
+
+        telemetry_json_string["fetch_to_local_duration"] = (
+            fetch_to_local_end_time - fetch_to_local_start_time
+        )
+        telemetry_json_string["upload_and_copy_into_sf_table_duration"] = (
+            upload_to_sf_end_time - upload_to_sf_start_time
+        )
+
+    except BaseException as exc:
+        if fetch_with_process:
+            # Graceful shutdown - terminate all processes
+            for process in workers:
+                if process.is_alive():
+                    process.terminate()
+                process.join(timeout=5)
+        else:
+            if data_fetching_thread_stop_event:
+                data_fetching_thread_stop_event.set()
+            for future in workers:
+                if not future.done():
+                    future.cancel()
+                    logger.debug(
+                        f"Cancelled a remaining data fetching future {future} due to error in another thread."
+                    )
+
+        if isinstance(exc, SnowparkDataframeReaderException):
+            raise exc
+
+        raise SnowparkDataframeReaderException(
+            f"Error occurred while ingesting data from the data source: {exc!r}"
+        )
+    finally:
+        if data_fetching_thread_pool_executor:
+            data_fetching_thread_pool_executor.shutdown(wait=True)
+
+    logger.debug("All data has been successfully loaded into the Snowflake table.")

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 15 additions & 119 deletions
@@ -2,17 +2,13 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 import json
-import multiprocessing as mp
 import os
 import re
 import sys
 import time
-import queue
 from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor
 from logging import getLogger
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable
-import threading
 from datetime import datetime
 
 import snowflake.snowpark
@@ -39,12 +35,10 @@
 )
 from snowflake.snowpark._internal.data_source.datasource_typing import Connection
 from snowflake.snowpark._internal.data_source.utils import (
-    worker_process,
-    process_parquet_queue_with_threads,
     STATEMENT_PARAMS_DATA_SOURCE,
     DATA_SOURCE_DBAPI_SIGNATURE,
-    _MAX_WORKER_SCALE,
     create_data_source_table_and_stage,
+    local_ingestion,
 )
 from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
 from snowflake.snowpark._internal.telemetry import set_api_call_source
@@ -2016,7 +2010,7 @@ def create_oracledb_connection():
             set_api_call_source(df, DATA_SOURCE_DBAPI_SIGNATURE)
             return df
 
-        # parquet ingestion
+        # create table and stage for data source
        snowflake_table_name = random_name_for_temp_object(TempObjectType.TABLE)
        snowflake_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
        create_data_source_table_and_stage(
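
The temporary table and stage created here feed the upload step: each parquet buffer produced by the fetch workers is PUT to the stage and then copied into the table. A hedged sketch of that flow using the public Snowpark session API (MY_TEMP_STAGE, MY_TEMP_TABLE, and upload_one_buffer are placeholder names for illustration; the real logic lives in process_parquet_queue_with_threads in utils.py):

from io import BytesIO


def upload_one_buffer(session, buf: BytesIO, partition_idx: int) -> None:
    # Placeholder stage/table names; the real code uses random temp object names.
    stage_path = f"@MY_TEMP_STAGE/partition_{partition_idx}.parquet"
    # Upload the in-memory parquet bytes to the stage without gzip renaming.
    session.file.put_stream(buf, stage_path, auto_compress=False, overwrite=True)
    # Load the staged file into the target table.
    session.sql(
        f"COPY INTO MY_TEMP_TABLE FROM {stage_path} "
        "FILE_FORMAT = (TYPE = PARQUET) MATCH_BY_COLUMN_NAME = CASE_INSENSITIVE"
    ).collect()
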
@@ -2027,119 +2021,21 @@ def create_oracledb_connection():
             statements_params_for_telemetry=statements_params_for_telemetry,
         )
 
-        data_fetching_thread_pool_executor = None
-        data_fetching_thread_stop_event = None
-        workers = []
-        try:
-            # Determine the number of processes or threads to use
-            max_workers = max_workers or os.cpu_count()
-            queue_class = mp.Queue if fetch_with_process else queue.Queue
-            process_or_thread_error_indicator = queue_class()
-            # a queue of partitions to be processed, this is filled by the partitioner before starting the workers
-            partition_queue = queue_class()
-            # a queue of parquet BytesIO objects to be uploaded
-            # Set max size for parquet_queue to prevent overfilling when thread consumers are slower than process producers
-            # process workers will block on this queue if it's full until the upload threads consume the BytesIO objects
-            parquet_queue = queue_class(_MAX_WORKER_SCALE * max_workers)
-            for partition_idx, query in enumerate(partitioned_queries):
-                partition_queue.put((partition_idx, query))
-
-            # Start worker processes
-            logger.debug(
-                f"Starting {max_workers} worker processes to fetch data from the data source."
-            )
-
-            fetch_to_local_start_time = time.perf_counter()
-            logger.debug(f"fetch to local start at: {fetch_to_local_start_time}")
-
-            if fetch_with_process:
-                for _worker_id in range(max_workers):
-                    process = mp.Process(
-                        target=worker_process,
-                        args=(
-                            partition_queue,
-                            parquet_queue,
-                            process_or_thread_error_indicator,
-                            partitioner.reader(),
-                        ),
-                    )
-                    process.start()
-                    workers.append(process)
-            else:
-                data_fetching_thread_pool_executor = ThreadPoolExecutor(
-                    max_workers=max_workers
-                )
-                data_fetching_thread_stop_event = threading.Event()
-                workers = [
-                    data_fetching_thread_pool_executor.submit(
-                        worker_process,
-                        partition_queue,
-                        parquet_queue,
-                        process_or_thread_error_indicator,
-                        partitioner.reader(),
-                        data_fetching_thread_stop_event,
-                    )
-                    for _worker_id in range(max_workers)
-                ]
-
-            # Process BytesIO objects from queue and upload them using utility method
-            (
-                fetch_to_local_end_time,
-                upload_to_sf_start_time,
-                upload_to_sf_end_time,
-            ) = process_parquet_queue_with_threads(
-                session=self._session,
-                parquet_queue=parquet_queue,
-                process_or_thread_error_indicator=process_or_thread_error_indicator,
-                workers=workers,
-                total_partitions=len(partitioned_queries),
-                snowflake_stage_name=snowflake_stage_name,
-                snowflake_table_name=snowflake_table_name,
-                max_workers=max_workers,
-                statements_params=statements_params_for_telemetry,
-                on_error="abort_statement",
-                fetch_with_process=fetch_with_process,
-            )
-            logger.debug(f"upload and copy into start at: {upload_to_sf_start_time}")
-            logger.debug(
-                f"fetch to local total time: {fetch_to_local_end_time - fetch_to_local_start_time}"
-            )
+        # parquet ingestion, ingestion external data source into temporary table
 
-            telemetry_json_string["fetch_to_local_duration"] = (
-                fetch_to_local_end_time - fetch_to_local_start_time
-            )
-            telemetry_json_string["upload_and_copy_into_sf_table_duration"] = (
-                upload_to_sf_end_time - upload_to_sf_start_time
-            )
-
-        except BaseException as exc:
-            if fetch_with_process:
-                # Graceful shutdown - terminate all processes
-                for process in workers:
-                    if process.is_alive():
-                        process.terminate()
-                    process.join(timeout=5)
-            else:
-                if data_fetching_thread_stop_event:
-                    data_fetching_thread_stop_event.set()
-                for future in workers:
-                    if not future.done():
-                        future.cancel()
-                        logger.debug(
-                            f"Cancelled a remaining data fetching future {future} due to error in another thread."
-                        )
-
-            if isinstance(exc, SnowparkDataframeReaderException):
-                raise exc
-
-            raise SnowparkDataframeReaderException(
-                f"Error occurred while ingesting data from the data source: {exc!r}"
-            )
-        finally:
-            if data_fetching_thread_pool_executor:
-                data_fetching_thread_pool_executor.shutdown(wait=True)
+        local_ingestion(
+            session=self._session,
+            partitioner=partitioner,
+            partitioned_queries=partitioned_queries,
+            max_workers=max_workers,
+            fetch_with_process=fetch_with_process,
+            snowflake_stage_name=snowflake_stage_name,
+            snowflake_table_name=snowflake_table_name,
+            statements_params_for_telemetry=statements_params_for_telemetry,
+            telemetry_json_string=telemetry_json_string,
+            _emit_ast=_emit_ast,
+        )
 
-        logger.debug("All data has been successfully loaded into the Snowflake table.")
         end_time = time.perf_counter()
         telemetry_json_string["end_to_end_duration"] = end_time - start_time
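
The error handling removed here (and re-homed in local_ingestion) shuts workers down in two different ways: mp.Process fetchers are terminated and joined, while thread fetchers are signalled through an Event and any futures that have not started yet are cancelled. A minimal, self-contained sketch of those two shutdown paths (the work function and its timings are illustrative, not the Snowpark worker):

import multiprocessing as mp
import threading
import time
from concurrent.futures import ThreadPoolExecutor


def work(stop_event=None):
    # Cooperative loop: a thread worker polls the stop event; a process worker
    # is simply terminated from the outside instead.
    for _ in range(100):
        if stop_event is not None and stop_event.is_set():
            return
        time.sleep(0.05)


def shutdown_processes(workers):
    for process in workers:
        if process.is_alive():
            process.terminate()
        process.join(timeout=5)


def shutdown_threads(workers, stop_event):
    stop_event.set()          # running workers notice the flag and return
    for future in workers:
        if not future.done():
            future.cancel()   # only affects futures that have not started yet


if __name__ == "__main__":
    procs = [mp.Process(target=work) for _ in range(2)]
    for p in procs:
        p.start()
    shutdown_processes(procs)

    stop = threading.Event()
    with ThreadPoolExecutor(max_workers=2) as pool:
        futures = [pool.submit(work, stop) for _ in range(4)]
        shutdown_threads(futures, stop)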

tests/integ/test_data_source_api.py

Lines changed: 1 addition & 1 deletion
@@ -1348,7 +1348,7 @@ def track_process_init(self, *args, **kwargs):
     with mock.patch.object(
         multiprocessing.Process, "__init__", track_process_init
     ), mock.patch(
-        "snowflake.snowpark.dataframe_reader.process_parquet_queue_with_threads",
+        "snowflake.snowpark._internal.data_source.utils.process_parquet_queue_with_threads",
         side_effect=RuntimeError("Simulated error in queue processing"),
     ):
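
The one-line change above follows directly from the refactor: mock.patch must target the module where a name is looked up at call time, and process_parquet_queue_with_threads is now called from utils.py rather than dataframe_reader.py. A small self-contained illustration of that rule (utils_mod and reader_mod are made-up modules for the example):

import sys
import types
from unittest import mock

# Two tiny in-memory modules mirroring the situation above: utils_mod defines the
# worker function, reader_mod imported it by value at import time.
utils_mod = types.ModuleType("utils_mod")
reader_mod = types.ModuleType("reader_mod")
sys.modules["utils_mod"] = utils_mod
sys.modules["reader_mod"] = reader_mod


def process_fn():
    return "real"


utils_mod.process_fn = process_fn
reader_mod.process_fn = process_fn  # what "from utils_mod import process_fn" leaves behind
reader_mod.run = lambda: reader_mod.process_fn()

# Patching the defining module leaves reader_mod's copy untouched ...
with mock.patch("utils_mod.process_fn", return_value="patched"):
    print(reader_mod.run())  # -> real

# ... while patching the module that performs the lookup takes effect.
with mock.patch("reader_mod.process_fn", return_value="patched"):
    print(reader_mod.run())  # -> patched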
