@@ -2,17 +2,13 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 import json
-import multiprocessing as mp
 import os
 import re
 import sys
 import time
-import queue
 from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor
 from logging import getLogger
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable
-import threading
 from datetime import datetime
 
 import snowflake.snowpark
@@ -39,12 +35,10 @@
 )
 from snowflake.snowpark._internal.data_source.datasource_typing import Connection
 from snowflake.snowpark._internal.data_source.utils import (
-    worker_process,
-    process_parquet_queue_with_threads,
     STATEMENT_PARAMS_DATA_SOURCE,
     DATA_SOURCE_DBAPI_SIGNATURE,
-    _MAX_WORKER_SCALE,
     create_data_source_table_and_stage,
+    local_ingestion,
 )
 from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
 from snowflake.snowpark._internal.telemetry import set_api_call_source
@@ -2018,7 +2012,7 @@ def create_oracledb_connection():
             set_api_call_source(df, DATA_SOURCE_DBAPI_SIGNATURE)
             return df
 
-        # parquet ingestion
+        # create table and stage for data source
         snowflake_table_name = random_name_for_temp_object(TempObjectType.TABLE)
         snowflake_stage_name = random_name_for_temp_object(TempObjectType.STAGE)
         create_data_source_table_and_stage(
@@ -2029,119 +2023,21 @@ def create_oracledb_connection():
             statements_params_for_telemetry=statements_params_for_telemetry,
         )
 
-        data_fetching_thread_pool_executor = None
-        data_fetching_thread_stop_event = None
-        workers = []
-        try:
-            # Determine the number of processes or threads to use
-            max_workers = max_workers or os.cpu_count()
-            queue_class = mp.Queue if fetch_with_process else queue.Queue
-            process_or_thread_error_indicator = queue_class()
-            # a queue of partitions to be processed, this is filled by the partitioner before starting the workers
-            partition_queue = queue_class()
-            # a queue of parquet BytesIO objects to be uploaded
-            # Set max size for parquet_queue to prevent overfilling when thread consumers are slower than process producers
-            # process workers will block on this queue if it's full until the upload threads consume the BytesIO objects
-            parquet_queue = queue_class(_MAX_WORKER_SCALE * max_workers)
-            for partition_idx, query in enumerate(partitioned_queries):
-                partition_queue.put((partition_idx, query))
-
-            # Start worker processes
-            logger.debug(
-                f"Starting {max_workers} worker processes to fetch data from the data source."
-            )
-
-            fetch_to_local_start_time = time.perf_counter()
-            logger.debug(f"fetch to local start at: {fetch_to_local_start_time}")
-
-            if fetch_with_process:
-                for _worker_id in range(max_workers):
-                    process = mp.Process(
-                        target=worker_process,
-                        args=(
-                            partition_queue,
-                            parquet_queue,
-                            process_or_thread_error_indicator,
-                            partitioner.reader(),
-                        ),
-                    )
-                    process.start()
-                    workers.append(process)
-            else:
-                data_fetching_thread_pool_executor = ThreadPoolExecutor(
-                    max_workers=max_workers
-                )
-                data_fetching_thread_stop_event = threading.Event()
-                workers = [
-                    data_fetching_thread_pool_executor.submit(
-                        worker_process,
-                        partition_queue,
-                        parquet_queue,
-                        process_or_thread_error_indicator,
-                        partitioner.reader(),
-                        data_fetching_thread_stop_event,
-                    )
-                    for _worker_id in range(max_workers)
-                ]
-
-            # Process BytesIO objects from queue and upload them using utility method
-            (
-                fetch_to_local_end_time,
-                upload_to_sf_start_time,
-                upload_to_sf_end_time,
-            ) = process_parquet_queue_with_threads(
-                session=self._session,
-                parquet_queue=parquet_queue,
-                process_or_thread_error_indicator=process_or_thread_error_indicator,
-                workers=workers,
-                total_partitions=len(partitioned_queries),
-                snowflake_stage_name=snowflake_stage_name,
-                snowflake_table_name=snowflake_table_name,
-                max_workers=max_workers,
-                statements_params=statements_params_for_telemetry,
-                on_error="abort_statement",
-                fetch_with_process=fetch_with_process,
-            )
-            logger.debug(f"upload and copy into start at: {upload_to_sf_start_time}")
-            logger.debug(
-                f"fetch to local total time: {fetch_to_local_end_time - fetch_to_local_start_time}"
-            )
+        # parquet ingestion: ingest the external data source into a temporary table
 
-            telemetry_json_string["fetch_to_local_duration"] = (
-                fetch_to_local_end_time - fetch_to_local_start_time
-            )
-            telemetry_json_string["upload_and_copy_into_sf_table_duration"] = (
-                upload_to_sf_end_time - upload_to_sf_start_time
-            )
-
-        except BaseException as exc:
-            if fetch_with_process:
-                # Graceful shutdown - terminate all processes
-                for process in workers:
-                    if process.is_alive():
-                        process.terminate()
-                        process.join(timeout=5)
-            else:
-                if data_fetching_thread_stop_event:
-                    data_fetching_thread_stop_event.set()
-                for future in workers:
-                    if not future.done():
-                        future.cancel()
-                        logger.debug(
-                            f"Cancelled a remaining data fetching future {future} due to error in another thread."
-                        )
-
-            if isinstance(exc, SnowparkDataframeReaderException):
-                raise exc
-
-            raise SnowparkDataframeReaderException(
-                f"Error occurred while ingesting data from the data source: {exc!r}"
-            )
-        finally:
-            if data_fetching_thread_pool_executor:
-                data_fetching_thread_pool_executor.shutdown(wait=True)
+        local_ingestion(
+            session=self._session,
+            partitioner=partitioner,
+            partitioned_queries=partitioned_queries,
+            max_workers=max_workers,
+            fetch_with_process=fetch_with_process,
+            snowflake_stage_name=snowflake_stage_name,
+            snowflake_table_name=snowflake_table_name,
+            statements_params_for_telemetry=statements_params_for_telemetry,
+            telemetry_json_string=telemetry_json_string,
+            _emit_ast=_emit_ast,
+        )
 
-        logger.debug("All data has been successfully loaded into the Snowflake table.")
         end_time = time.perf_counter()
         telemetry_json_string["end_to_end_duration"] = end_time - start_time
 
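The block removed above ran the fetch-and-upload pipeline inline: fetch workers (processes or threads) serialize each partition to an in-memory parquet buffer and push it onto a size-capped queue, while uploader threads drain that queue into the Snowflake stage and table; after this change that orchestration, plus error propagation and telemetry timing, lives behind the single `local_ingestion` call. For reference, here is a minimal, self-contained sketch of the same bounded producer-consumer pattern; `fetch_partition`, `upload_buffer`, `ingest`, and `MAX_WORKER_SCALE` are illustrative placeholders, not Snowpark APIs.

```python
import io
import queue
import threading

# Assumed scale factor for the bounded queue; the real constant
# (_MAX_WORKER_SCALE) lives in snowpark's data_source utilities.
MAX_WORKER_SCALE = 2


def fetch_partition(partition_id: int) -> io.BytesIO:
    # Placeholder for "run the partition's query and serialize the rows to parquet".
    return io.BytesIO(f"parquet bytes for partition {partition_id}".encode())


def upload_buffer(partition_id: int, buf: io.BytesIO) -> None:
    # Placeholder for "PUT the buffer to the stage and COPY INTO the table".
    print(f"uploaded partition {partition_id}: {len(buf.getvalue())} bytes")


def ingest(num_partitions: int, max_workers: int = 4) -> None:
    partition_queue: queue.Queue = queue.Queue()
    # Bounded queue: fetchers block on put() when uploads fall behind,
    # which caps how many parquet buffers sit in memory at once.
    parquet_queue: queue.Queue = queue.Queue(MAX_WORKER_SCALE * max_workers)
    for partition_id in range(num_partitions):
        partition_queue.put(partition_id)

    def fetch_worker() -> None:
        # Producer: drain partitions and emit one parquet buffer per partition.
        while True:
            try:
                partition_id = partition_queue.get_nowait()
            except queue.Empty:
                return
            parquet_queue.put((partition_id, fetch_partition(partition_id)))

    def upload_worker() -> None:
        # Consumer: upload exactly num_partitions buffers, then stop.
        for _ in range(num_partitions):
            partition_id, buf = parquet_queue.get()
            upload_buffer(partition_id, buf)

    fetchers = [threading.Thread(target=fetch_worker) for _ in range(max_workers)]
    uploader = threading.Thread(target=upload_worker)
    for t in fetchers:
        t.start()
    uploader.start()
    for t in fetchers:
        t.join()
    uploader.join()


if __name__ == "__main__":
    ingest(num_partitions=8, max_workers=4)
```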