Commit e095d5a

SNOW-2201361: fix dbapi does not work in python stored proc (#3541)

1 parent: 6ef7e58

4 files changed: +71 -11 lines changed

CHANGELOG.md
Lines changed: 4 additions & 0 deletions

```diff
@@ -4,6 +4,10 @@
 
 ### Snowpark Python API Updates
 
+#### Bug Fixes
+
+- Fixed a bug in `DataFrameReader.dbapi` (PrPr) where `dbapi` would fail inside a Python stored procedure with the process exiting with code 1.
+
 #### New Features
 
 - Added support for the following AI-powered functions in `functions.py`:
```

src/snowflake/snowpark/_internal/data_source/utils.py
Lines changed: 41 additions & 8 deletions

```diff
@@ -1,7 +1,7 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-
+import os
 import queue
 import time
 import traceback
@@ -11,7 +11,7 @@
 from threading import BoundedSemaphore
 from io import BytesIO
 from enum import Enum
-from typing import Any, Tuple, Optional, Callable, Dict, Union
+from typing import Any, Tuple, Optional, Callable, Dict, Union, Set
 import logging
 from snowflake.snowpark._internal.data_source.dbms_dialects import (
     Sqlite3Dialect,
@@ -309,6 +309,7 @@ def _retry_run(func: Callable, *args, **kwargs) -> Any:
 def worker_process(
     partition_queue: Union[mp.Queue, queue.Queue],
     parquet_queue: Union[mp.Queue, queue.Queue],
+    process_or_thread_error_indicator: Union[mp.Queue, queue.Queue],
     reader,
     stop_event: threading.Event = None,
 ):
@@ -326,6 +327,8 @@ def worker_process(
                 stop_event,
             )
         except queue.Empty:
+            # signal that this process exited gracefully
+            process_or_thread_error_indicator.put(os.getpid())
             # No more work available, exit gracefully
             break
         except Exception as e:
@@ -358,9 +361,22 @@ def process_completed_futures(thread_futures):
         raise
 
 
+def _drain_process_status_queue(
+    process_or_thread_error_indicator: Union[mp.Queue, queue.Queue],
+) -> Set:
+    result = set()
+    while True:
+        try:
+            result.add(process_or_thread_error_indicator.get(block=False))
+        except queue.Empty:
+            break
+    return result
+
+
 def process_parquet_queue_with_threads(
     session: "snowflake.snowpark.Session",
     parquet_queue: Union[mp.Queue, queue.Queue],
+    process_or_thread_error_indicator: Union[mp.Queue, queue.Queue],
     workers: list,
     total_partitions: int,
     snowflake_stage_name: str,
@@ -382,6 +398,7 @@
     Args:
         session: Snowflake session for database operations
         parquet_queue: Multiprocessing queue containing parquet data
+        process_or_thread_error_indicator: Multiprocessing queue containing process exit information
         workers: List of worker processes or thread futures to monitor
         total_partitions: Total number of partitions expected
         snowflake_stage_name: Name of the Snowflake stage for uploads
@@ -395,6 +412,7 @@
     """
 
     completed_partitions = set()
+    gracefully_exited_processes = set()
     # process parquet_queue may produce more data than the threads can handle,
     # so we use semaphore to limit the number of threads
     backpressure_semaphore = BoundedSemaphore(value=_MAX_WORKER_SCALE * max_workers)
@@ -448,10 +466,18 @@
         if fetch_with_process:
             # Check if any processes have failed
             for i, process in enumerate(workers):
-                if not process.is_alive() and process.exitcode != 0:
-                    raise SnowparkDataframeReaderException(
-                        f"Partition {i} data fetching process failed with exit code {process.exitcode}"
+                if not process.is_alive():
+                    gracefully_exited_processes = (
+                        gracefully_exited_processes.union(
+                            _drain_process_status_queue(
+                                process_or_thread_error_indicator
+                            )
+                        )
                     )
+                    if process.pid not in gracefully_exited_processes:
+                        raise SnowparkDataframeReaderException(
+                            f"Partition {i} data fetching process failed with exit code {process.exitcode} or failed silently"
+                        )
         else:
             # Check if any threads have failed
             for i, future in enumerate(workers):
@@ -469,11 +495,18 @@
 
     if fetch_with_process:
         # Wait for all processes to complete
-        for idx, process in enumerate(workers):
+        for process in workers:
             process.join()
-            if process.exitcode != 0:
+            # drain the status queue to collect graceful-exit signals after each process ends
+            gracefully_exited_processes = gracefully_exited_processes.union(
+                _drain_process_status_queue(process_or_thread_error_indicator)
+            )
+
+        # check if any process failed
+        for idx, process in enumerate(workers):
+            if process.pid not in gracefully_exited_processes:
                 raise SnowparkDataframeReaderException(
-                    f"Partition {idx} data fetching process failed with exit code {process.exitcode}"
+                    f"Partition {idx} data fetching process failed with exit code {process.exitcode} or failed silently"
                 )
     else:
        # Wait for all threads to complete
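```

The pattern above is the heart of the fix: instead of trusting `process.exitcode`, which appears to be nonzero in sandboxed environments such as Python stored procedures even when a worker finished its work, each worker now puts its own PID onto a dedicated status queue just before exiting gracefully, and the parent drains that queue and treats the set of signaled PIDs as the source of truth. Here is a minimal, self-contained sketch of the same pattern; the names `worker`, `run_workers`, and `status_queue` are illustrative, not from the codebase:

```python
import multiprocessing as mp
import os
import queue


def worker(task_queue: mp.Queue, status_queue: mp.Queue) -> None:
    """Drain tasks; report our PID as a graceful-exit signal when done."""
    while True:
        try:
            _task = task_queue.get(block=False)
            # ... fetch and process the task here ...
        except queue.Empty:
            status_queue.put(os.getpid())  # graceful-exit signal
            break


def run_workers(num_workers: int = 2) -> None:
    task_queue: mp.Queue = mp.Queue()
    status_queue: mp.Queue = mp.Queue()
    for i in range(4):
        task_queue.put(i)

    procs = [
        mp.Process(target=worker, args=(task_queue, status_queue))
        for _ in range(num_workers)
    ]
    for p in procs:
        p.start()

    graceful = set()
    for p in procs:
        p.join()
        # Drain whatever signals have arrived so far (non-blocking).
        while True:
            try:
                graceful.add(status_queue.get(block=False))
            except queue.Empty:
                break

    # A dead worker whose PID never arrived either crashed or died silently.
    for idx, p in enumerate(procs):
        if p.pid not in graceful:
            raise RuntimeError(
                f"Worker {idx} failed with exit code {p.exitcode} or failed silently"
            )


if __name__ == "__main__":
    run_workers()
```

Draining with `get(block=False)` until `queue.Empty` keeps the check non-blocking; because the drain runs after `join()`, any signal the worker sent has, in the common case, already been flushed through the queue's pipe.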

src/snowflake/snowpark/dataframe_reader.py
Lines changed: 9 additions & 1 deletion

```diff
@@ -1491,6 +1491,7 @@ def create_oracledb_connection():
         # Determine the number of processes or threads to use
         max_workers = max_workers or os.cpu_count()
         queue_class = mp.Queue if fetch_with_process else queue.Queue
+        process_or_thread_error_indicator = queue_class()
         # a queue of partitions to be processed, this is filled by the partitioner before starting the workers
         partition_queue = queue_class()
         # a queue of parquet BytesIO objects to be uploaded
@@ -1509,7 +1510,12 @@ def create_oracledb_connection():
             for _worker_id in range(max_workers):
                 process = mp.Process(
                     target=worker_process,
-                    args=(partition_queue, parquet_queue, partitioner.reader()),
+                    args=(
+                        partition_queue,
+                        parquet_queue,
+                        process_or_thread_error_indicator,
+                        partitioner.reader(),
+                    ),
                 )
                 process.start()
                 workers.append(process)
@@ -1523,6 +1529,7 @@ def create_oracledb_connection():
                     worker_process,
                     partition_queue,
                     parquet_queue,
+                    process_or_thread_error_indicator,
                     partitioner.reader(),
                     data_fetching_thread_stop_event,
                 )
@@ -1533,6 +1540,7 @@ def create_oracledb_connection():
         process_parquet_queue_with_threads(
             session=self._session,
             parquet_queue=parquet_queue,
+            process_or_thread_error_indicator=process_or_thread_error_indicator,
            workers=workers,
            total_partitions=len(partitioned_queries),
            snowflake_stage_name=snowflake_stage_name,
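```

A small detail that makes this wiring work: `process_or_thread_error_indicator` is created from `queue_class`, so it is a `multiprocessing.Queue` in process mode and a plain `queue.Queue` in thread mode. Both raise the standard library's `queue.Empty` on a non-blocking `get` from an empty queue, which is why `worker_process` and `_drain_process_status_queue` can share one code path across both modes. A standalone demonstration (not part of the patch):

```python
import multiprocessing as mp
import queue

# Both queue flavors raise the same stdlib queue.Empty on a non-blocking
# get from an empty queue, so one drain loop covers both execution modes.
for q in (queue.Queue(), mp.Queue()):
    try:
        q.get(block=False)
    except queue.Empty:
        print(f"{type(q).__module__}.{type(q).__name__} raised queue.Empty")
```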

tests/integ/test_data_source_api.py
Lines changed: 17 additions & 2 deletions

```diff
@@ -1160,13 +1160,18 @@ def test_worker_process_unit(fetch_with_process):
         multiprocessing.Queue() if fetch_with_process else queue.Queue()
     )
     parquet_queue = multiprocessing.Queue() if fetch_with_process else queue.Queue()
+    process_or_thread_error_indicator = (
+        multiprocessing.Queue() if fetch_with_process else queue.Queue()
+    )
 
     # Set up partition_queue to return test data, then raise queue.Empty
     partition_queue.put((0, f"SELECT * FROM {table_name} WHERE id <= 3"))
     partition_queue.put((1, f"SELECT * FROM {table_name} WHERE id > 3"))
 
     # Call the worker_process function directly (using real sqlite3 operations)
-    worker_process(partition_queue, parquet_queue, reader)
+    worker_process(
+        partition_queue, parquet_queue, process_or_thread_error_indicator, reader
+    )
 
     expected_order = [
         "data_partition0_fetch0.parquet",
@@ -1191,7 +1196,9 @@ def test_worker_process_unit(fetch_with_process):
 
     # check error handling
     partition_queue.put((0, "SELECT * FROM NON_EXISTING_TABLE"))
-    worker_process(partition_queue, parquet_queue, reader)
+    worker_process(
+        partition_queue, parquet_queue, process_or_thread_error_indicator, reader
+    )
     error_signal, error_instance = parquet_queue.get()
     assert error_signal == PARTITION_TASK_ERROR_SIGNAL
     assert isinstance(
@@ -1520,13 +1527,15 @@ def test_thread_worker_exception(exception, match_message):
 
     # Create test parameters
     parquet_queue = queue.Queue()
+    process_or_thread_error_indicator = queue.Queue()
     workers = [mock_future]  # Single worker that will fail
     total_partitions = 0  # No partitions to complete
 
     with pytest.raises(SnowparkDataframeReaderException, match=match_message):
         process_parquet_queue_with_threads(
             session=mock_session,
             parquet_queue=parquet_queue,
+            process_or_thread_error_indicator=process_or_thread_error_indicator,
             workers=workers,
             total_partitions=total_partitions,
             snowflake_stage_name="test_stage",
@@ -1553,6 +1562,7 @@ def test_process_worker_non_zero_exitcode():
 
     # Create test parameters
     parquet_queue = queue.Queue()
+    process_or_thread_error_indicator = queue.Queue()
     workers = [mock_process]  # Single worker that will fail
     total_partitions = 0  # No partitions to complete
 
@@ -1563,6 +1573,7 @@ def test_process_worker_non_zero_exitcode():
         process_parquet_queue_with_threads(
             session=mock_session,
             parquet_queue=parquet_queue,
+            process_or_thread_error_indicator=process_or_thread_error_indicator,
             workers=workers,
             total_partitions=total_partitions,
             snowflake_stage_name="test_stage",
@@ -1589,6 +1600,7 @@ def test_queue_empty_process_failure():
 
     # Create empty queue to trigger queue.Empty exception
     parquet_queue = queue.Queue()
+    process_or_thread_error_indicator = queue.Queue()
     workers = [mock_process]
     total_partitions = 1  # Set to 1 so the loop continues
 
@@ -1599,6 +1611,7 @@ def test_queue_empty_process_failure():
         process_parquet_queue_with_threads(
             session=mock_session,
             parquet_queue=parquet_queue,
+            process_or_thread_error_indicator=process_or_thread_error_indicator,
             workers=workers,
             total_partitions=total_partitions,
             snowflake_stage_name="test_stage",
@@ -1636,13 +1649,15 @@ def test_queue_empty_thread_failure(exception, match_message):
 
     # Create empty queue to trigger queue.Empty exception
     parquet_queue = queue.Queue()
+    process_or_thread_error_indicator = queue.Queue()
     workers = [mock_future]
     total_partitions = 1  # Set to 1 so the loop continues
 
     with pytest.raises(SnowparkDataframeReaderException, match=match_message):
         process_parquet_queue_with_threads(
             session=mock_session,
             parquet_queue=parquet_queue,
+            process_or_thread_error_indicator=process_or_thread_error_indicator,
            workers=workers,
            total_partitions=total_partitions,
            snowflake_stage_name="test_stage",
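```

For completeness, here is a sketch of the scenario the new "or failed silently" branch is meant to catch, written in the same `unittest.mock` style as the tests above; the names and values are illustrative, not from the patch:

```python
import queue
from unittest import mock

# Illustrative mock: the process is dead, its exit code looks clean,
# but it never put its PID on the indicator queue.
mock_process = mock.MagicMock()
mock_process.is_alive.return_value = False
mock_process.exitcode = 0  # would have passed the old `exitcode != 0` check
mock_process.pid = 12345

indicator = queue.Queue()  # deliberately empty: no graceful-exit signal

# Drain the indicator queue the same way _drain_process_status_queue does.
gracefully_exited = set()
while True:
    try:
        gracefully_exited.add(indicator.get(block=False))
    except queue.Empty:
        break

# The patched check treats a missing PID signal as a failure, even with exit code 0.
if mock_process.pid not in gracefully_exited:
    print("silent failure detected")  # the real code raises SnowparkDataframeReaderException
```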
