 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 import json
-import multiprocessing as mp
 import os
 import re
 import sys
 import time
-import queue
 from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor
 from logging import getLogger
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable
-import threading
 from datetime import datetime
 
 import snowflake.snowpark
@@ -39,12 +35,10 @@
 )
 from snowflake.snowpark._internal.data_source.datasource_typing import Connection
 from snowflake.snowpark._internal.data_source.utils import (
-    worker_process,
-    process_parquet_queue_with_threads,
     STATEMENT_PARAMS_DATA_SOURCE,
     DATA_SOURCE_DBAPI_SIGNATURE,
-    _MAX_WORKER_SCALE,
     create_data_source_table_and_stage,
+    local_ingestion,
 )
 from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
 from snowflake.snowpark._internal.telemetry import set_api_call_source
@@ -2016,7 +2010,7 @@ def create_oracledb_connection():
             set_api_call_source(df, DATA_SOURCE_DBAPI_SIGNATURE)
             return df
 
-        # parquet ingestion
+        # create table and stage for data source
         snowflake_table_name = random_name_for_temp_object(TempObjectType.TABLE)
         snowflake_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
         create_data_source_table_and_stage(
@@ -2027,119 +2021,21 @@ def create_oracledb_connection():
             statements_params_for_telemetry=statements_params_for_telemetry,
         )
 
-        data_fetching_thread_pool_executor = None
-        data_fetching_thread_stop_event = None
-        workers = []
-        try:
-            # Determine the number of processes or threads to use
-            max_workers = max_workers or os.cpu_count()
-            queue_class = mp.Queue if fetch_with_process else queue.Queue
-            process_or_thread_error_indicator = queue_class()
-            # a queue of partitions to be processed, this is filled by the partitioner before starting the workers
-            partition_queue = queue_class()
-            # a queue of parquet BytesIO objects to be uploaded
-            # Set max size for parquet_queue to prevent overfilling when thread consumers are slower than process producers
-            # process workers will block on this queue if it's full until the upload threads consume the BytesIO objects
-            parquet_queue = queue_class(_MAX_WORKER_SCALE * max_workers)
-            for partition_idx, query in enumerate(partitioned_queries):
-                partition_queue.put((partition_idx, query))
-
-            # Start worker processes
-            logger.debug(
-                f"Starting {max_workers} worker processes to fetch data from the data source."
-            )
-
-            fetch_to_local_start_time = time.perf_counter()
-            logger.debug(f"fetch to local start at: {fetch_to_local_start_time}")
-
-            if fetch_with_process:
-                for _worker_id in range(max_workers):
-                    process = mp.Process(
-                        target=worker_process,
-                        args=(
-                            partition_queue,
-                            parquet_queue,
-                            process_or_thread_error_indicator,
-                            partitioner.reader(),
-                        ),
-                    )
-                    process.start()
-                    workers.append(process)
-            else:
-                data_fetching_thread_pool_executor = ThreadPoolExecutor(
-                    max_workers=max_workers
-                )
-                data_fetching_thread_stop_event = threading.Event()
-                workers = [
-                    data_fetching_thread_pool_executor.submit(
-                        worker_process,
-                        partition_queue,
-                        parquet_queue,
-                        process_or_thread_error_indicator,
-                        partitioner.reader(),
-                        data_fetching_thread_stop_event,
-                    )
-                    for _worker_id in range(max_workers)
-                ]
-
-            # Process BytesIO objects from queue and upload them using utility method
-            (
-                fetch_to_local_end_time,
-                upload_to_sf_start_time,
-                upload_to_sf_end_time,
-            ) = process_parquet_queue_with_threads(
-                session=self._session,
-                parquet_queue=parquet_queue,
-                process_or_thread_error_indicator=process_or_thread_error_indicator,
-                workers=workers,
-                total_partitions=len(partitioned_queries),
-                snowflake_stage_name=snowflake_stage_name,
-                snowflake_table_name=snowflake_table_name,
-                max_workers=max_workers,
-                statements_params=statements_params_for_telemetry,
-                on_error="abort_statement",
-                fetch_with_process=fetch_with_process,
-            )
-            logger.debug(f"upload and copy into start at: {upload_to_sf_start_time}")
-            logger.debug(
-                f"fetch to local total time: {fetch_to_local_end_time - fetch_to_local_start_time}"
-            )
+        # parquet ingestion: ingest the external data source into a temporary table
 
-            telemetry_json_string["fetch_to_local_duration"] = (
-                fetch_to_local_end_time - fetch_to_local_start_time
-            )
-            telemetry_json_string["upload_and_copy_into_sf_table_duration"] = (
-                upload_to_sf_end_time - upload_to_sf_start_time
-            )
-
-        except BaseException as exc:
-            if fetch_with_process:
-                # Graceful shutdown - terminate all processes
-                for process in workers:
-                    if process.is_alive():
-                        process.terminate()
-                        process.join(timeout=5)
-            else:
-                if data_fetching_thread_stop_event:
-                    data_fetching_thread_stop_event.set()
-                for future in workers:
-                    if not future.done():
-                        future.cancel()
-                        logger.debug(
-                            f"Cancelled a remaining data fetching future {future} due to error in another thread."
-                        )
-
-            if isinstance(exc, SnowparkDataframeReaderException):
-                raise exc
-
-            raise SnowparkDataframeReaderException(
-                f"Error occurred while ingesting data from the data source: {exc!r}"
-            )
-        finally:
-            if data_fetching_thread_pool_executor:
-                data_fetching_thread_pool_executor.shutdown(wait=True)
+        local_ingestion(
+            session=self._session,
+            partitioner=partitioner,
+            partitioned_queries=partitioned_queries,
+            max_workers=max_workers,
+            fetch_with_process=fetch_with_process,
+            snowflake_stage_name=snowflake_stage_name,
+            snowflake_table_name=snowflake_table_name,
+            statements_params_for_telemetry=statements_params_for_telemetry,
+            telemetry_json_string=telemetry_json_string,
+            _emit_ast=_emit_ast,
+        )
 
-        logger.debug("All data has been successfully loaded into the Snowflake table.")
         end_time = time.perf_counter()
         telemetry_json_string["end_to_end_duration"] = end_time - start_time
 
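
Note: the removed block above is a bounded-queue producer/consumer pipeline: worker processes (or threads) fetch partitions and push parquet payloads into a size-capped queue, and upload threads drain that queue, so fetching gets backpressure whenever uploads lag. The following is a minimal, hypothetical sketch of that pattern only, not the Snowpark implementation; the names produce, consume, and payload_queue are illustrative, and the real logic now lives in local_ingestion.

# Sketch of the bounded-queue producer/consumer pattern (assumed, simplified).
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor


def produce(partition_queue: "mp.Queue", payload_queue: "mp.Queue") -> None:
    # Fetch partitions until a sentinel (None) appears; emit fake "parquet" bytes.
    while True:
        item = partition_queue.get()
        if item is None:
            return
        partition_idx, query = item
        payload_queue.put((partition_idx, query.encode()))  # blocks when the queue is full


def consume(payload_queue: "mp.Queue") -> None:
    # Drain payloads until a sentinel; the real code uploads to a stage and runs
    # COPY INTO, here we only report the payload size.
    while True:
        item = payload_queue.get()
        if item is None:
            return
        partition_idx, payload = item
        print(f"uploaded partition {partition_idx}: {len(payload)} bytes")


if __name__ == "__main__":
    max_workers = 2
    partitions = [(i, f"SELECT * FROM src WHERE part = {i}") for i in range(6)]

    partition_queue: "mp.Queue" = mp.Queue()
    for part in partitions:
        partition_queue.put(part)
    for _ in range(max_workers):
        partition_queue.put(None)  # one stop sentinel per producer

    # Cap the payload queue so producers block instead of buffering everything in memory.
    payload_queue: "mp.Queue" = mp.Queue(maxsize=2 * max_workers)

    producers = [
        mp.Process(target=produce, args=(partition_queue, payload_queue))
        for _ in range(max_workers)
    ]
    for proc in producers:
        proc.start()

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        consumers = [pool.submit(consume, payload_queue) for _ in range(max_workers)]
        for proc in producers:
            proc.join()
        for _ in range(max_workers):
            payload_queue.put(None)  # stop the consumer threads
        for fut in consumers:
            fut.result()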