55import queue
66import time
77import traceback
8+ import threading
89import multiprocessing as mp
910from concurrent .futures import ThreadPoolExecutor
1011from threading import BoundedSemaphore
1112from io import BytesIO
1213from enum import Enum
13- from typing import Any , Tuple , Optional , Callable , Dict
14+ from typing import Any , Tuple , Optional , Callable , Dict , Union
1415import logging
1516from snowflake .snowpark ._internal .data_source .dbms_dialects import (
1617 Sqlite3Dialect ,
@@ -151,7 +152,8 @@ def _task_fetch_data_from_source(
151152 worker : DataSourceReader ,
152153 partition : str ,
153154 partition_idx : int ,
154- parquet_queue : mp .Queue ,
155+ parquet_queue : Union [mp .Queue , queue .Queue ],
156+ stop_event : threading .Event = None ,
155157):
156158 """
157159 Fetch data from source and convert to parquet BytesIO objects.
@@ -179,6 +181,16 @@ def convert_to_parquet_bytesio(fetched_data, fetch_idx):
179181 logger .debug (f"Added parquet BytesIO to queue: { parquet_id } " )
180182
181183 for i , result in enumerate (worker .read (partition )):
184+ if stop_event and stop_event .is_set ():
185+ parquet_queue .put (
186+ (
187+ PARTITION_TASK_ERROR_SIGNAL ,
188+ SnowparkDataframeReaderException (
189+ "Data fetching stopped by thread failure"
190+ ),
191+ )
192+ )
193+ break
182194 convert_to_parquet_bytesio (result , i )
183195
184196 parquet_queue .put ((f"{ PARTITION_TASK_COMPLETE_SIGNAL_PREFIX } { partition_idx } " , None ))
def _task_fetch_data_from_source_with_retry(
    worker: DataSourceReader,
    partition: str,
    partition_idx: int,
    parquet_queue: Union[mp.Queue, queue.Queue],
    stop_event: Optional[threading.Event] = None,
):
    """Fetch data for one partition, retrying on transient failure.

    Thin wrapper that runs ``_task_fetch_data_from_source`` through
    ``_retry_run`` so transient read errors are retried instead of
    failing the partition outright.

    Args:
        worker: Reader used to pull rows for this partition.
        partition: Partition query/identifier handed to the reader.
        partition_idx: Index of the partition; used by the underlying task
            to emit its per-partition completion signal.
        parquet_queue: Queue (multiprocessing or in-process ``queue.Queue``,
            depending on the fetch mode) that receives parquet BytesIO
            payloads produced from the fetched data.
        stop_event: Optional cooperative-cancellation event; when set, the
            underlying task stops fetching early.
    """
    # Fix: the default of None requires an Optional annotation (PEP 484);
    # the previous `stop_event: threading.Event = None` was mistyped.
    _retry_run(
        _task_fetch_data_from_source,
        worker,
        partition,
        partition_idx,
        parquet_queue,
        stop_event,
    )
200214
201215
@@ -292,7 +306,12 @@ def _retry_run(func: Callable, *args, **kwargs) -> Any:
292306
293307
294308# DBAPI worker function that processes multiple partitions
295- def worker_process (partition_queue : mp .Queue , parquet_queue : mp .Queue , reader ):
309+ def worker_process (
310+ partition_queue : Union [mp .Queue , queue .Queue ],
311+ parquet_queue : Union [mp .Queue , queue .Queue ],
312+ reader ,
313+ stop_event : threading .Event = None ,
314+ ):
296315 """Worker process that fetches data from multiple partitions"""
297316 while True :
298317 try :
@@ -304,6 +323,7 @@ def worker_process(partition_queue: mp.Queue, parquet_queue: mp.Queue, reader):
304323 query ,
305324 partition_idx ,
306325 parquet_queue ,
326+ stop_event ,
307327 )
308328 except queue .Empty :
309329 # No more work available, exit gracefully
@@ -340,14 +360,15 @@ def process_completed_futures(thread_futures):
340360
341361def process_parquet_queue_with_threads (
342362 session : "snowflake.snowpark.Session" ,
343- parquet_queue : mp .Queue ,
344- processes : list ,
363+ parquet_queue : Union [ mp .Queue , queue . Queue ] ,
364+ workers : list ,
345365 total_partitions : int ,
346366 snowflake_stage_name : str ,
347367 snowflake_table_name : str ,
348368 max_workers : int ,
349369 statements_params : Optional [Dict [str , str ]] = None ,
350370 on_error : str = "abort_statement" ,
371+ fetch_with_process : bool = False ,
351372) -> None :
352373 """
353374 Process parquet data from a multiprocessing queue using a thread pool.
@@ -361,7 +382,7 @@ def process_parquet_queue_with_threads(
361382 Args:
362383 session: Snowflake session for database operations
363384 parquet_queue: Multiprocessing queue containing parquet data
364- processes : List of worker processes to monitor
385+ workers : List of worker processes or thread futures to monitor
365386 total_partitions: Total number of partitions expected
366387 snowflake_stage_name: Name of the Snowflake stage for uploads
367388 snowflake_table_name: Name of the target Snowflake table
@@ -424,19 +445,44 @@ def process_parquet_queue_with_threads(
424445
425446 except queue .Empty :
426447 backpressure_semaphore .release () # Release semaphore if no data was fetched
427- # Check if any processes have failed
428- for i , process in enumerate (processes ):
429- if not process .is_alive () and process .exitcode != 0 :
430- raise SnowparkDataframeReaderException (
431- f"Partition { i } data fetching process failed with exit code { process .exitcode } "
432- )
448+ if fetch_with_process :
449+ # Check if any processes have failed
450+ for i , process in enumerate (workers ):
451+ if not process .is_alive () and process .exitcode != 0 :
452+ raise SnowparkDataframeReaderException (
453+ f"Partition { i } data fetching process failed with exit code { process .exitcode } "
454+ )
455+ else :
456+ # Check if any threads have failed
457+ for i , future in enumerate (workers ):
458+ if future .done ():
459+ try :
460+ future .result ()
461+ except BaseException as e :
462+ if isinstance (e , SnowparkDataframeReaderException ):
463+ raise e
464+ raise SnowparkDataframeReaderException (
465+ f"Partition { i } data fetching thread failed with error: { e } "
466+ )
433467 time .sleep (0.1 )
434468 continue
435469
436- # Wait for all processes to complete
437- for idx , process in enumerate (processes ):
438- process .join ()
439- if process .exitcode != 0 :
440- raise SnowparkDataframeReaderException (
441- f"Partition { idx } data fetching process failed with exit code { process .exitcode } "
442- )
470+ if fetch_with_process :
471+ # Wait for all processes to complete
472+ for idx , process in enumerate (workers ):
473+ process .join ()
474+ if process .exitcode != 0 :
475+ raise SnowparkDataframeReaderException (
476+ f"Partition { idx } data fetching process failed with exit code { process .exitcode } "
477+ )
478+ else :
479+ # Wait for all threads to complete
480+ for idx , future in enumerate (workers ):
481+ try :
482+ future .result ()
483+ except BaseException as e :
484+ if isinstance (e , SnowparkDataframeReaderException ):
485+ raise e
486+ raise SnowparkDataframeReaderException (
487+ f"Partition { idx } data fetching thread failed with error: { e } "
488+ )
0 commit comments