Commit 6e4fae3

SNOW-1975364: use file for cross process communication (#3149)
1 parent 1c3c8fc

3 files changed: 50 additions, 34 deletions

src/snowflake/snowpark/_internal/data_source_utils.py

Lines changed: 15 additions & 1 deletion
@@ -5,8 +5,10 @@
 import datetime
 import decimal
 import logging
+import os
+import queue
 from enum import Enum
-from typing import List, Any, Tuple, Protocol, Optional
+from typing import List, Any, Tuple, Protocol, Optional, Set
 from snowflake.connector.options import pandas as pd
 
 from snowflake.snowpark._internal.utils import get_sorted_key_for_version

@@ -368,3 +370,15 @@ def output_type_handler(cursor, metadata):
         return cursor.var(oracledb.DB_TYPE_LONG, arraysize=cursor.arraysize)
     elif metadata.type_code == oracledb.DB_TYPE_BLOB:
         return cursor.var(oracledb.DB_TYPE_RAW, arraysize=cursor.arraysize)
+
+
+def add_unseen_files_to_process_queue(
+    work_dir: str, set_of_files_already_added_in_queue: Set[str], queue: queue.Queue
+):
+    """Add unseen files in the work_dir to the queue for processing."""
+    # all files in the work_dir are parquet files, no subdirectory
+    all_files = set(os.listdir(work_dir))
+    unseen = all_files - set_of_files_already_added_in_queue
+    for file in unseen:
+        queue.put(os.path.join(work_dir, file))
+        set_of_files_already_added_in_queue.add(file)
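Note: the new add_unseen_files_to_process_queue helper is the core of the change. Instead of handing parquet paths across processes through a managed queue, worker processes simply write files into a shared work directory and the parent polls it. Below is a minimal usage sketch, not part of the commit; it assumes the helper is importable as snowflake.snowpark._internal.data_source_utils.add_unseen_files_to_process_queue (matching the file path above), and the parquet file names are invented for illustration.

# Minimal sketch, not part of the commit: poll a work directory twice and confirm
# that already-seen files are not enqueued again.
import os
import queue
import tempfile

from snowflake.snowpark._internal.data_source_utils import (
    add_unseen_files_to_process_queue,
)

with tempfile.TemporaryDirectory() as work_dir:
    # simulate two fetches that each wrote a uniquely named parquet file
    for name in ("data_partition0_fetch0.parquet", "data_partition0_fetch1.parquet"):
        open(os.path.join(work_dir, name), "wb").close()

    parquet_file_queue: queue.Queue = queue.Queue()
    already_added: set = set()

    # the first poll enqueues both files; the second poll finds nothing new
    add_unseen_files_to_process_queue(work_dir, already_added, parquet_file_queue)
    add_unseen_files_to_process_queue(work_dir, already_added, parquet_file_queue)
    assert parquet_file_queue.qsize() == 2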

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 30 additions & 20 deletions
@@ -4,9 +4,8 @@
 import datetime
 import decimal
 import functools
-import multiprocessing as mp
+import queue
 import os
-import shutil
 import tempfile
 import time
 import traceback

@@ -22,7 +21,6 @@
 from dateutil import parser
 import sys
 from logging import getLogger
-import queue
 from typing import Any, Dict, List, Literal, Optional, Tuple, Union, Callable
 
 import snowflake.snowpark

@@ -62,6 +60,7 @@
     DATA_SOURCE_SQL_COMMENT,
     generate_sql_with_predicates,
     output_type_handler,
+    add_unseen_files_to_process_queue,
 )
 from snowflake.snowpark._internal.utils import (
     INFER_SCHEMA_FORMAT_TYPES,

@@ -1271,23 +1270,21 @@ def create_oracledb_connection():
         )
 
         try:
-            with mp.Manager() as process_manager, ProcessPoolExecutor(
+            with ProcessPoolExecutor(
                 max_workers=max_workers
             ) as process_executor, ThreadPoolExecutor(
                 max_workers=max_workers
             ) as thread_executor:
                 thread_pool_futures, process_pool_futures = [], []
-                parquet_file_queue = process_manager.Queue()
 
                 def ingestion_thread_cleanup_callback(parquet_file_path, _):
                     # clean the local temp file after ingestion to avoid consuming too much temp disk space
-                    shutil.rmtree(parquet_file_path, ignore_errors=True)
+                    os.remove(parquet_file_path)
 
                 logger.debug("Starting to fetch data from the data source.")
                 for partition_idx, query in enumerate(partitioned_queries):
                     process_future = process_executor.submit(
                         DataFrameReader._task_fetch_from_data_source_with_retry,
-                        parquet_file_queue,
                         create_connection,
                         query,
                         struct_schema,

@@ -1300,8 +1297,22 @@ def ingestion_thread_cleanup_callback(parquet_file_path, _):
                     )
                     process_pool_futures.append(process_future)
                 # Monitor queue while tasks are running
+
+                parquet_file_queue = (
+                    queue.Queue()
+                )  # maintain the queue of parquet files to process
+                set_of_files_already_added_in_queue = (
+                    set()
+                )  # maintain file names we have already put into queue
                 while True:
                     try:
+                        # each process and per fetch will create a parquet with a unique file name
+                        # we add unseen files to process queue
+                        add_unseen_files_to_process_queue(
+                            tmp_dir,
+                            set_of_files_already_added_in_queue,
+                            parquet_file_queue,
+                        )
                         file = parquet_file_queue.get_nowait()
                         logger.debug(f"Retrieved file from parquet queue: {file}")
                         thread_future = thread_executor.submit(

@@ -1336,8 +1347,15 @@ def ingestion_thread_cleanup_callback(parquet_file_path, _):
                             else:
                                 unfinished_process_pool_futures.append(future)
                                 all_job_done = False
-                        if all_job_done and parquet_file_queue.empty():
-                            # all jod is done and parquet file queue is empty, we finished all the fetch work
+                        if (
+                            all_job_done
+                            and parquet_file_queue.empty()
+                            and len(os.listdir(tmp_dir)) == 0
+                        ):
+                            # we finished all the fetch work based on the following 3 conditions:
+                            # 1. all jod is done
+                            # 2. parquet file queue is empty
+                            # 3. no files in the temp work dir as they are all removed in thread future callback
                             # now we just need to wait for all ingestion threads to complete
                             logger.debug(
                                 "All jobs are done, and the parquet file queue is empty. Fetching work is complete."

@@ -1537,7 +1555,6 @@ def _upload_and_copy_into_table_with_retry(
 
     @staticmethod
     def _task_fetch_from_data_source(
-        parquet_file_queue: queue.Queue,
         create_connection: Callable[[], "Connection"],
         query: str,
         schema: StructType,

@@ -1554,12 +1571,11 @@ def convert_to_parquet(fetched_data, fetch_idx):
                 logger.debug(
                     f"The DataFrame is empty, no parquet file is generated for partition {partition_idx} fetch {fetch_idx}."
                 )
-                return None
+                return
             path = os.path.join(
                 tmp_dir, f"data_partition{partition_idx}_fetch{fetch_idx}.parquet"
             )
             df.to_parquet(path)
-            return path
 
         conn = create_connection()
         # this is specified to pyodbc, need other way to manage timeout on other drivers

@@ -1573,26 +1589,21 @@ def convert_to_parquet(fetched_data, fetch_idx):
         if fetch_size == 0:
             cursor.execute(query)
             result = cursor.fetchall()
-            parquet_file_path = convert_to_parquet(result, 0)
-            if parquet_file_path:
-                parquet_file_queue.put(parquet_file_path)
+            convert_to_parquet(result, 0)
         elif fetch_size > 0:
             cursor = cursor.execute(query)
             fetch_idx = 0
             while True:
                 rows = cursor.fetchmany(fetch_size)
                 if not rows:
                     break
-                parquet_file_path = convert_to_parquet(rows, fetch_idx)
-                if parquet_file_path:
-                    parquet_file_queue.put(parquet_file_path)
+                convert_to_parquet(rows, fetch_idx)
                 fetch_idx += 1
         else:
             raise ValueError("fetch size cannot be smaller than 0")
 
     @staticmethod
     def _task_fetch_from_data_source_with_retry(
-        parquet_file_queue: queue.Queue,
         create_connection: Callable[[], "Connection"],
         query: str,
         schema: StructType,

@@ -1605,7 +1616,6 @@ def _task_fetch_from_data_source_with_retry(
     ):
         DataFrameReader._retry_run(
             DataFrameReader._task_fetch_from_data_source,
-            parquet_file_queue,
             create_connection,
             query,
             schema,

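Note: taken together, these dataframe_reader.py changes swap the multiprocessing.Manager queue for the file system itself. Each fetch process writes uniquely named parquet files into the shared temp directory, the parent polls that directory with add_unseen_files_to_process_queue into a local queue.Queue, and a thread callback removes each file once it is ingested, so "all fetch futures done, local queue drained, directory empty" signals completion. The self-contained sketch below mirrors that flow under those assumptions; fetch_partition and ingest are hypothetical stand-ins, not Snowpark code.

# Simplified sketch of the file-based cross-process handoff; fetch_partition and
# ingest are made-up placeholders, not the Snowpark implementation.
import os
import queue
import tempfile
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor


def fetch_partition(work_dir: str, partition_idx: int) -> None:
    """Worker process: write one uniquely named file per fetch into the shared dir."""
    path = os.path.join(work_dir, f"data_partition{partition_idx}_fetch0.parquet")
    with open(path, "wb") as f:
        f.write(b"fake parquet bytes")


def ingest(path: str) -> None:
    """Ingestion thread: pretend to upload the staged file somewhere."""
    time.sleep(0.1)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as work_dir, ProcessPoolExecutor(
        max_workers=4
    ) as procs, ThreadPoolExecutor(max_workers=4) as threads:
        fetch_futures = [procs.submit(fetch_partition, work_dir, i) for i in range(4)]

        file_queue: queue.Queue = queue.Queue()
        seen: set = set()
        while True:
            # poll the shared directory instead of a cross-process queue
            for name in set(os.listdir(work_dir)) - seen:
                seen.add(name)
                file_queue.put(os.path.join(work_dir, name))
            try:
                path = file_queue.get_nowait()
            except queue.Empty:
                # done only when every fetch finished, the queue is drained,
                # and the cleanup callbacks have removed every staged file
                if all(f.done() for f in fetch_futures) and not os.listdir(work_dir):
                    break
                time.sleep(0.1)
                continue
            future = threads.submit(ingest, path)
            # delete the local file after ingestion, mirroring the callback in the diff
            future.add_done_callback(lambda _, p=path: os.remove(p))

Compared with the removed process_manager.Queue(), coordination here needs only ordinary files plus an in-process queue.Queue, so no Manager process or proxy objects cross the process boundary.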
tests/integ/test_data_source_api.py

Lines changed: 5 additions & 13 deletions
@@ -4,7 +4,6 @@
 import functools
 import math
 import os
-import queue
 import tempfile
 import time
 import datetime

@@ -152,7 +151,6 @@ def test_dbapi_retry(session):
         SnowparkDataframeReaderException, match="\\[RuntimeError\\] Test error"
     ):
         DataFrameReader._task_fetch_from_data_source_with_retry(
-            parquet_file_queue=queue.Queue(),
             create_connection=sql_server_create_connection,
             query="SELECT * FROM test_table",
             schema=StructType([StructField("col1", IntegerType(), False)]),

@@ -528,7 +526,6 @@ def test_negative_case(session):
 def test_task_fetch_from_data_source_with_fetch_size(
     fetch_size, partition_idx, expected_error
 ):
-    parquet_file_queue = queue.Queue()
     schema = infer_data_source_schema(
         sql_server_create_connection_small_data(),
         SQL_SERVER_TABLE_NAME,

@@ -544,7 +541,6 @@
     with tempfile.TemporaryDirectory() as tmp_dir:
 
         params = {
-            "parquet_file_queue": parquet_file_queue,
             "create_connection": sql_server_create_connection_small_data,
             "query": "SELECT * FROM test_table",
             "schema": schema,

@@ -562,16 +558,12 @@
                 DataFrameReader._task_fetch_from_data_source(**params)
         else:
             DataFrameReader._task_fetch_from_data_source(**params)
-
-            file_idx = 0
-            while not parquet_file_queue.empty():
-                file_path = parquet_file_queue.get()
+            files = sorted(os.listdir(tmp_dir))
+            for idx, file in enumerate(files):
                 assert (
-                    f"data_partition{partition_idx}_fetch{file_idx}.parquet"
-                    in file_path
-                )
-                file_idx += 1
-            assert file_idx == file_count
+                    f"data_partition{partition_idx}_fetch{idx}.parquet" in file
+                ), f"file: {file} does not match"
+            assert len(files) == file_count
 
 
 def test_database_detector():