Skip to content

Commit 2e60174

Browse files
authored
SNOW-2097824: cloudpickle callback function to support locally defined function (#3356)
1 parent 41e4e4f commit 2e60174

File tree

4 files changed

+69
-4
lines changed

4 files changed

+69
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
- Fixed a bug in `DataFrameWriter.dbapi` (PrPr) where Unicode or double-quoted column names in the external database caused errors because they were not quoted correctly.
3131
- Fixed a bug where named fields in nested OBJECT data could cause errors when containing spaces.
32+
- Fixed a bug in `DataFrameReader.dbapi` (PrPr) where a `create_connection` callback defined as a local function was incompatible with multiprocessing.
3233

3334
### Snowpark Local Testing Updates
3435

src/snowflake/snowpark/_internal/data_source/datasource_reader.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
#
22
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
33
#
4-
4+
import pickle
5+
import cloudpickle
56
from enum import Enum
67

78
from typing import List, Any, Iterator, Type, Callable, Optional
@@ -28,14 +29,26 @@ def __init__(
2829
session_init_statement: Optional[List[str]] = None,
2930
fetch_merge_count: Optional[int] = 1,
3031
) -> None:
31-
self.driver = driver_class(create_connection, dbms_type)
32+
# we use cloudpickle to pickle the callback function so that local function and function defined in
33+
# __main__ can be pickled and unpickled in subprocess
34+
self.pickled_create_connection_callback = cloudpickle.dumps(
35+
create_connection, protocol=pickle.HIGHEST_PROTOCOL
36+
)
37+
self.driver = None
38+
self.driver_class = driver_class
39+
self.dbms_type = dbms_type
3240
self.schema = schema
3341
self.fetch_size = fetch_size
3442
self.query_timeout = query_timeout
3543
self.session_init_statement = session_init_statement
3644
self.fetch_merge_count = fetch_merge_count
3745

3846
def read(self, partition: str) -> Iterator[List[Any]]:
47+
self.driver = self.driver_class(
48+
cloudpickle.loads(self.pickled_create_connection_callback),
49+
self.dbms_type,
50+
)
51+
3952
conn = self.driver.prepare_connection(
4053
self.driver.create_connection(), self.query_timeout
4154
)
@@ -74,4 +87,6 @@ def read(self, partition: str) -> Iterator[List[Any]]:
7487
conn.close()
7588

7689
def data_source_data_to_pandas_df(self, data: List[Any]) -> "pd.DataFrame":
90+
# self.driver is guaranteed to be initialized in self.read() which is called prior to this method
91+
assert self.driver is not None
7792
return self.driver.data_source_data_to_pandas_df(data, self.schema)

tests/integ/test_data_source_api.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
import logging
77
import math
88
import os
9+
import subprocess
910
import tempfile
1011
import datetime
12+
from textwrap import dedent
1113
from unittest import mock
1214
from unittest.mock import patch, MagicMock, PropertyMock
1315

@@ -929,3 +931,49 @@ def test_double_quoted_column_name_sql_server(session):
929931
assert df.collect() == [
930932
Row(Id=1, FullName="John Doe", Country="USA", Notes="Fake note")
931933
]
934+
935+
936+
@pytest.mark.skipif(
937+
IS_WINDOWS,
938+
reason="sqlite3 file can not be shared across processes on windows",
939+
)
940+
def test_local_create_connection_function(session, db_parameters):
941+
with tempfile.TemporaryDirectory() as temp_dir:
942+
dbpath = os.path.join(temp_dir, "testsqlite3.db")
943+
table_name, _, _, assert_data = sqlite3_db(dbpath)
944+
945+
# test local function definition
946+
def local_create_connection():
947+
import sqlite3
948+
949+
return sqlite3.connect(dbpath)
950+
951+
df = session.read.dbapi(
952+
local_create_connection,
953+
table=table_name,
954+
custom_schema=SQLITE3_DB_CUSTOM_SCHEMA_STRING,
955+
)
956+
assert df.order_by("ID").collect() == assert_data
957+
958+
# test function is defined in the main
959+
code = dedent(
960+
f"""
961+
if __name__ == "__main__":
962+
import sqlite3
963+
from snowflake.snowpark import Session
964+
965+
def local_create_connection():
966+
return sqlite3.connect("{str(dbpath)}")
967+
968+
session = Session.builder.configs({str(db_parameters)}).create()
969+
df = session.read.dbapi(
970+
local_create_connection,
971+
table='{table_name}',
972+
custom_schema='{SQLITE3_DB_CUSTOM_SCHEMA_STRING}',
973+
)
974+
assert df.collect()
975+
print("successful ingestion")
976+
"""
977+
)
978+
result = subprocess.run(["python", "-c", code], capture_output=True, text=True)
979+
assert "successful ingestion" in result.stdout

tox.ini

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ omit = */src/snowflake/snowpark/modin/config/*
77
*/src/snowflake/snowpark/modin/plugin/docstrings/*
88
*/src/snowflake/snowpark/mock/*
99
*/src/snowflake/snowpark/_internal/data_source/datasource_typing.py
10-
tests/integ/datasource/test_databricks.py
11-
tests/integ/datasource/test_oracledb.py
10+
tests/integ/datasource/*.py
11+
tests/integ/test_data_source_api.py
12+
1213
[coverage:run]
1314
relative_files = true
1415
branch = true

0 commit comments

Comments
 (0)