@@ -4,9 +4,8 @@
 import datetime
 import decimal
 import functools
-import multiprocessing as mp
+import queue
 import os
-import shutil
 import tempfile
 import time
 import traceback
@@ -22,7 +21,6 @@
 from dateutil import parser
 import sys
 from logging import getLogger
-import queue
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable

 import snowflake.snowpark
@@ -62,6 +60,7 @@
     DATA_SOURCE_SQL_COMMENT,
     generate_sql_with_predicates,
     output_type_handler,
+    add_unseen_files_to_process_queue,
 )
 from snowflake.snowpark._internal.utils import (
     INFER_SCHEMA_FORMAT_TYPES,
@@ -1271,23 +1270,21 @@ def create_oracledb_connection():
         )

         try:
-            with mp.Manager() as process_manager, ProcessPoolExecutor(
+            with ProcessPoolExecutor(
                 max_workers=max_workers
             ) as process_executor, ThreadPoolExecutor(
                 max_workers=max_workers
             ) as thread_executor:
                 thread_pool_futures, process_pool_futures = [], []
-                parquet_file_queue = process_manager.Queue()

                 def ingestion_thread_cleanup_callback(parquet_file_path, _):
                     # clean the local temp file after ingestion to avoid consuming too much temp disk space
-                    shutil.rmtree(parquet_file_path, ignore_errors=True)
+                    os.remove(parquet_file_path)

                 logger.debug("Starting to fetch data from the data source.")
                 for partition_idx, query in enumerate(partitioned_queries):
                     process_future = process_executor.submit(
                         DataFrameReader._task_fetch_from_data_source_with_retry,
-                        parquet_file_queue,
                         create_connection,
                         query,
                         struct_schema,
@@ -1300,8 +1297,22 @@ def ingestion_thread_cleanup_callback(parquet_file_path, _):
                     )
                     process_pool_futures.append(process_future)
                 # Monitor queue while tasks are running
+
+                parquet_file_queue = (
+                    queue.Queue()
+                )  # maintain the queue of parquet files to process
+                set_of_files_already_added_in_queue = (
+                    set()
+                )  # maintain the file names we have already put into the queue
                 while True:
                     try:
+                        # each process and each fetch creates a parquet file with a unique name,
+                        # so we add only the unseen files to the process queue
+                        add_unseen_files_to_process_queue(
+                            tmp_dir,
+                            set_of_files_already_added_in_queue,
+                            parquet_file_queue,
+                        )
                         file = parquet_file_queue.get_nowait()
                         logger.debug(f"Retrieved file from parquet queue: {file}")
                         thread_future = thread_executor.submit(
@@ -1336,8 +1347,15 @@ def ingestion_thread_cleanup_callback(parquet_file_path, _):
                             else:
                                 unfinished_process_pool_futures.append(future)
                                 all_job_done = False
-                        if all_job_done and parquet_file_queue.empty():
-                            # all jobs are done and the parquet file queue is empty, we finished all the fetch work
+                        if (
+                            all_job_done
+                            and parquet_file_queue.empty()
+                            and len(os.listdir(tmp_dir)) == 0
+                        ):
+                            # we have finished all the fetch work when the following 3 conditions hold:
+                            # 1. all jobs are done
+                            # 2. the parquet file queue is empty
+                            # 3. there are no files left in the temp work dir, as they are all removed in the thread future callback
                             # now we just need to wait for all ingestion threads to complete
                             logger.debug(
                                 "All jobs are done, and the parquet file queue is empty. Fetching work is complete."
                             )
@@ -1537,7 +1555,6 @@ def _upload_and_copy_into_table_with_retry(

     @staticmethod
     def _task_fetch_from_data_source(
-        parquet_file_queue: queue.Queue,
         create_connection: Callable[[], "Connection"],
         query: str,
         schema: StructType,
@@ -1554,12 +1571,11 @@ def convert_to_parquet(fetched_data, fetch_idx):
                 logger.debug(
                     f"The DataFrame is empty, no parquet file is generated for partition {partition_idx} fetch {fetch_idx}."
                 )
-                return None
+                return
             path = os.path.join(
                 tmp_dir, f"data_partition{partition_idx}_fetch{fetch_idx}.parquet"
             )
             df.to_parquet(path)
-            return path

         conn = create_connection()
         # this is specific to pyodbc; other drivers need a different way to manage timeouts
@@ -1573,26 +1589,21 @@ def convert_to_parquet(fetched_data, fetch_idx):
         if fetch_size == 0:
             cursor.execute(query)
             result = cursor.fetchall()
-            parquet_file_path = convert_to_parquet(result, 0)
-            if parquet_file_path:
-                parquet_file_queue.put(parquet_file_path)
+            convert_to_parquet(result, 0)
         elif fetch_size > 0:
             cursor = cursor.execute(query)
             fetch_idx = 0
             while True:
                 rows = cursor.fetchmany(fetch_size)
                 if not rows:
                     break
-                parquet_file_path = convert_to_parquet(rows, fetch_idx)
-                if parquet_file_path:
-                    parquet_file_queue.put(parquet_file_path)
+                convert_to_parquet(rows, fetch_idx)
                 fetch_idx += 1
         else:
             raise ValueError("fetch size cannot be smaller than 0")

     @staticmethod
     def _task_fetch_from_data_source_with_retry(
-        parquet_file_queue: queue.Queue,
         create_connection: Callable[[], "Connection"],
         query: str,
         schema: StructType,
@@ -1605,7 +1616,6 @@ def _task_fetch_from_data_source_with_retry(
     ):
         DataFrameReader._retry_run(
             DataFrameReader._task_fetch_from_data_source,
-            parquet_file_queue,
             create_connection,
             query,
             schema,
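
# Note: the body of add_unseen_files_to_process_queue is not part of this
# diff; only its import and call site appear above. A minimal sketch that is
# consistent with that call site (the body below is an assumption, not the
# actual implementation):

import os
import queue
from typing import Set


def add_unseen_files_to_process_queue(
    tmp_dir: str,
    set_of_files_already_added_in_queue: Set[str],
    parquet_file_queue: queue.Queue,
) -> None:
    # Each worker process writes uniquely named parquet files into tmp_dir
    # (data_partition{i}_fetch{j}.parquet), so the file name alone is enough
    # to deduplicate across polling rounds.
    for file_name in os.listdir(tmp_dir):
        if file_name not in set_of_files_already_added_in_queue:
            set_of_files_already_added_in_queue.add(file_name)
            parquet_file_queue.put(os.path.join(tmp_dir, file_name))
    # Assumption: the real helper may additionally guard against files that
    # are still being written; this sketch enqueues anything it sees.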
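
# Note: the thread_executor.submit call is truncated in the hunk above, so
# the exact wiring of ingestion_thread_cleanup_callback is not visible. The
# pattern implied by its (parquet_file_path, _) signature is a done-callback
# with the path pre-bound via functools.partial; a standalone sketch of that
# pattern (the ingest stand-in and file name are illustrative only):

import functools
import os
from concurrent.futures import ThreadPoolExecutor


def ingestion_thread_cleanup_callback(parquet_file_path, _future):
    # removing the file after ingestion is what lets the termination check
    # above rely on len(os.listdir(tmp_dir)) == 0
    os.remove(parquet_file_path)


def ingest(path):
    print(f"ingesting {path}")  # stand-in for the real upload / copy-into-table


with ThreadPoolExecutor(max_workers=2) as pool:
    path = "demo.parquet"
    open(path, "w").close()  # placeholder file so the callback has something to remove
    future = pool.submit(ingest, path)
    # partial binds the path as the first argument; the completed future is
    # passed as the second argument, which the callback ignores
    future.add_done_callback(functools.partial(ingestion_thread_cleanup_callback, path))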