Skip to content

Commit ee0df28

Browse files
authored
SNOW-2372025: support passing connection parameters to dbapi (#3845)
1 parent 69c400a commit ee0df28

File tree

15 files changed

+352
-24
lines changed

15 files changed

+352
-24
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#### New Features
88

99
- Added a new function `service` in `snowflake.snowpark.functions` that allows users to create a callable representing a Snowpark Container Services (SPCS) service.
10+
- Added `connection_parameters` parameter to `DataFrameReader.dbapi()` (PuPr) method to allow passing keyword arguments to the `create_connection` callable.
1011
- Added support for `Session.begin_transaction`, `Session.commit` and `Session.rollback`.
1112
- Added support for the following functions in `functions.py`:
1213
- Geospatial functions:

src/snowflake/snowpark/_internal/data_source/datasource_partitioner.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
class DataSourcePartitioner:
3838
def __init__(
3939
self,
40-
create_connection: Callable[[], "Connection"],
40+
create_connection: Callable[..., "Connection"],
4141
table_or_query: str,
4242
is_query: bool,
4343
column: Optional[str] = None,
@@ -50,6 +50,7 @@ def __init__(
5050
predicates: Optional[List[str]] = None,
5151
session_init_statement: Optional[List[str]] = None,
5252
fetch_merge_count: Optional[int] = 1,
53+
connection_parameters: Optional[dict] = None,
5354
) -> None:
5455
self.create_connection = create_connection
5556
self.table_or_query = table_or_query
@@ -64,14 +65,21 @@ def __init__(
6465
self.predicates = predicates
6566
self.session_init_statement = session_init_statement
6667
self.fetch_merge_count = fetch_merge_count
67-
conn = create_connection()
68+
self.connection_parameters = connection_parameters
69+
conn = (
70+
create_connection(**connection_parameters)
71+
if connection_parameters
72+
else create_connection()
73+
)
6874
dbms_type, driver_type = detect_dbms(conn)
6975
self.driver_type = driver_type
7076
self.dbms_type = dbms_type
7177
self.dialect_class = DBMS_MAP.get(dbms_type, BaseDialect)
7278
self.driver_class = DRIVER_MAP.get(driver_type, BaseDriver)
7379
self.dialect = self.dialect_class()
74-
self.driver = self.driver_class(create_connection, dbms_type)
80+
self.driver = self.driver_class(
81+
create_connection, dbms_type, connection_parameters
82+
)
7583

7684
self._query_input_alias = (
7785
f"SNOWPARK_DBAPI_QUERY_INPUT_ALIAS_{generate_random_alphanumeric(5).upper()}"
@@ -89,6 +97,7 @@ def reader(self) -> DataSourceReader:
8997
self.query_timeout,
9098
self.session_init_statement,
9199
self.fetch_merge_count,
100+
self.connection_parameters,
92101
)
93102

94103
@cached_property

src/snowflake/snowpark/_internal/data_source/datasource_reader.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,23 @@ class DataSourceReader:
2121
def __init__(
2222
self,
2323
driver_class: Type[BaseDriver],
24-
create_connection: Callable[[], "Connection"],
24+
create_connection: Callable[..., "Connection"],
2525
schema: StructType,
2626
dbms_type: Enum,
2727
fetch_size: Optional[int] = 0,
2828
query_timeout: Optional[int] = 0,
2929
session_init_statement: Optional[List[str]] = None,
3030
fetch_merge_count: Optional[int] = 1,
31+
connection_parameters: Optional[dict] = None,
3132
) -> None:
3233
# we use cloudpickle to pickle the callback function so that local function and function defined in
3334
# __main__ can be pickled and unpickled in subprocess
3435
self.pickled_create_connection_callback = cloudpickle.dumps(
3536
create_connection, protocol=pickle.HIGHEST_PROTOCOL
3637
)
38+
self.pickled_connection_parameters = cloudpickle.dumps(
39+
connection_parameters, protocol=pickle.HIGHEST_PROTOCOL
40+
)
3741
self.driver = None
3842
self.driver_class = driver_class
3943
self.dbms_type = dbms_type
@@ -44,14 +48,19 @@ def __init__(
4448
self.fetch_merge_count = fetch_merge_count
4549

4650
def read(self, partition: str) -> Iterator[List[Any]]:
51+
connection_parameters = cloudpickle.loads(self.pickled_connection_parameters)
4752
self.driver = self.driver_class(
4853
cloudpickle.loads(self.pickled_create_connection_callback),
4954
self.dbms_type,
55+
connection_parameters,
5056
)
5157

52-
conn = self.driver.prepare_connection(
53-
self.driver.create_connection(), self.query_timeout
58+
create_conn_result = (
59+
self.driver.create_connection(**connection_parameters)
60+
if connection_parameters
61+
else self.driver.create_connection()
5462
)
63+
conn = self.driver.prepare_connection(create_conn_result, self.query_timeout)
5564
try:
5665
cursor = conn.cursor()
5766
if self.session_init_statement:

src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,21 @@
4343
class BaseDriver:
4444
def __init__(
4545
self,
46-
create_connection: Callable[[], "Connection"],
46+
create_connection: Callable[..., "Connection"],
4747
dbms_type: Enum,
48+
connection_parameters: Optional[dict] = None,
4849
) -> None:
4950
self.create_connection = create_connection
5051
self.dbms_type = dbms_type
52+
self.connection_parameters = connection_parameters
5153
self.raw_schema = None
5254

55+
def _call_create_connection(self) -> "Connection":
56+
"""Call create_connection with connection_parameters if provided."""
57+
if self.connection_parameters:
58+
return self.create_connection(**self.connection_parameters)
59+
return self.create_connection()
60+
5361
def to_snow_type(self, schema: List[Any]) -> StructType:
5462
raise NotImplementedError(
5563
f"{self.__class__.__name__} has not implemented to_snow_type function"
@@ -100,7 +108,7 @@ def infer_schema_from_description(
100108
def infer_schema_from_description_with_error_control(
101109
self, table_or_query: str, is_query: bool, query_input_alias: str
102110
) -> StructType:
103-
conn = self.create_connection()
111+
conn = self._call_create_connection()
104112
cursor = conn.cursor()
105113
try:
106114
return self.infer_schema_from_description(
@@ -184,10 +192,16 @@ def udtf_class_builder(
184192
) -> type:
185193
create_connection = self.create_connection
186194
prepare_connection = self.prepare_connection
195+
connection_parameters = self.connection_parameters
187196

188197
class UDTFIngestion:
189198
def process(self, query: str):
190-
conn = prepare_connection(create_connection(), query_timeout)
199+
conn_result = (
200+
create_connection(**connection_parameters)
201+
if connection_parameters
202+
else create_connection()
203+
)
204+
conn = prepare_connection(conn_result, query_timeout)
191205
cursor = conn.cursor()
192206
if session_init_statement is not None:
193207
for statement in session_init_statement:

src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,15 @@ def udtf_class_builder(
8888
query_timeout: int = 0,
8989
) -> type:
9090
create_connection = self.create_connection
91+
connection_parameters = self.connection_parameters
9192

9293
class UDTFIngestion:
9394
def process(self, query: str):
94-
conn = create_connection()
95+
conn = (
96+
create_connection(**connection_parameters)
97+
if connection_parameters
98+
else create_connection()
99+
)
95100
cursor = conn.cursor()
96101
if session_init_statement is not None:
97102
for statement in session_init_statement:

src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def udtf_class_builder(
143143
query_timeout: int = 0,
144144
) -> type:
145145
create_connection = self.create_connection
146+
connection_parameters = self.connection_parameters
146147

147148
def oracledb_output_type_handler(cursor, metadata):
148149
from oracledb import (
@@ -166,7 +167,11 @@ def convert_to_hex(value):
166167

167168
class UDTFIngestion:
168169
def process(self, query: str):
169-
conn = create_connection()
170+
conn = (
171+
create_connection(**connection_parameters)
172+
if connection_parameters
173+
else create_connection()
174+
)
170175
if query_timeout > 0:
171176
conn.call_timeout = query_timeout * 1000
172177
if conn.outputtypehandler is None:

src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#
44
import logging
55
from enum import Enum
6-
from typing import Callable, List, Any, TYPE_CHECKING
6+
from typing import List, Any, TYPE_CHECKING
77

88
from snowflake.snowpark._internal.data_source.datasource_typing import Connection
99
from snowflake.snowpark._internal.data_source.drivers import BaseDriver
@@ -167,11 +167,6 @@ class Psycopg2TypeCode(Enum):
167167

168168

169169
class Psycopg2Driver(BaseDriver):
170-
def __init__(
171-
self, create_connection: Callable[[], "Connection"], dbms_type: Enum
172-
) -> None:
173-
super().__init__(create_connection, dbms_type)
174-
175170
def to_snow_type(self, schema: List[Any]) -> StructType:
176171
# The psycopg2 spec is defined in the following links:
177172
# https://www.psycopg.org/docs/cursor.html#cursor.description
@@ -272,6 +267,7 @@ def udtf_class_builder(
272267
query_timeout: int = 0,
273268
) -> type:
274269
create_connection = self.create_connection
270+
connection_parameters = self.connection_parameters
275271

276272
# TODO: SNOW-2101485 use class method to prepare connection
277273
# ideally we should use the same function as prepare_connection
@@ -291,7 +287,12 @@ def prepare_connection_in_udtf(
291287

292288
class UDTFIngestion:
293289
def process(self, query: str):
294-
conn = prepare_connection_in_udtf(create_connection(), query_timeout)
290+
conn_result = (
291+
create_connection(**connection_parameters)
292+
if connection_parameters
293+
else create_connection()
294+
)
295+
conn = prepare_connection_in_udtf(conn_result, query_timeout)
295296
cursor = conn.cursor(
296297
f"SNOWPARK_CURSOR_{generate_random_alphanumeric(5)}"
297298
)

src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,17 @@ def udtf_class_builder(
203203
query_timeout: int = 0,
204204
) -> type:
205205
create_connection = self.create_connection
206+
connection_parameters = self.connection_parameters
206207

207208
class UDTFIngestion:
208209
def process(self, query: str):
209210
import pymysql
210211

211-
conn = create_connection()
212+
conn = (
213+
create_connection(**connection_parameters)
214+
if connection_parameters
215+
else create_connection()
216+
)
212217
cursor = pymysql.cursors.SSCursor(conn)
213218
if session_init_statement is not None:
214219
for statement in session_init_statement:

src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def udtf_class_builder(
8686
) -> type:
8787
create_connection = self.create_connection
8888
prepare_connection = self.prepare_connection
89+
connection_parameters = self.connection_parameters
8990

9091
def binary_converter(value):
9192
return value.hex() if value is not None else None
@@ -94,7 +95,12 @@ class UDTFIngestion:
9495
def process(self, query: str):
9596
import pyodbc
9697

97-
conn = prepare_connection(create_connection(), query_timeout)
98+
conn_result = (
99+
create_connection(**connection_parameters)
100+
if connection_parameters
101+
else create_connection()
102+
)
103+
conn = prepare_connection(conn_result, query_timeout)
98104
if (
99105
conn.get_output_converter(pyodbc.SQL_BINARY) is None
100106
and conn.get_output_converter(pyodbc.SQL_VARBINARY) is None

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,7 +1681,7 @@ def jdbc(
16811681
@publicapi
16821682
def dbapi(
16831683
self,
1684-
create_connection: Callable[[], "Connection"],
1684+
create_connection: Callable[..., "Connection"],
16851685
*,
16861686
table: Optional[str] = None,
16871687
query: Optional[str] = None,
@@ -1698,6 +1698,7 @@ def dbapi(
16981698
udtf_configs: Optional[dict] = None,
16991699
fetch_merge_count: int = 1,
17001700
fetch_with_process: bool = False,
1701+
connection_parameters: Optional[dict] = None,
17011702
_emit_ast: bool = True,
17021703
) -> DataFrame:
17031704
"""
@@ -1718,7 +1719,9 @@ def dbapi(
17181719
column name with table in external data source.
17191720
17201721
Args:
1721-
create_connection: A callable that takes no arguments and returns a DB-API compatible database connection.
1722+
create_connection: A callable that returns a DB-API compatible database connection.
1723+
The callable can optionally accept keyword arguments via `**kwargs`.
1724+
If connection_parameters is provided, those will be passed as keyword arguments to this callable.
17221725
The callable must be picklable, as it will be passed to and executed in child processes.
17231726
table: The name of the table in the external data source.
17241727
This parameter cannot be used together with the `query` parameter.
@@ -1777,6 +1780,11 @@ def dbapi(
17771780
like Parquet file generation. When using multiprocessing, guard your script with
17781781
`if __name__ == "__main__":` and call `multiprocessing.freeze_support()` on Windows if needed.
17791782
This parameter has no effect in UDTF ingestion.
1783+
connection_parameters: Optional dictionary of parameters to pass to the create_connection callable.
1784+
If provided, these parameters will be unpacked and passed as keyword arguments
1785+
to `create_connection(**connection_parameters)`.
1786+
This allows for flexible connection configuration without hardcoding values in the callable.
1787+
Example: {"timeout": 30, "isolation_level": "READ_UNCOMMITTED"}
17801788
17811789
Example::
17821790
.. code-block:: python
@@ -1796,8 +1804,27 @@ def create_oracledb_connection():
17961804
connection = oracledb.connect(...)
17971805
return connection
17981806
1799-
if __name__ == "__main__":
1800-
df = session.read.dbapi(create_oracledb_connection, table=..., fetch_with_process=True)
1807+
df = session.read.dbapi(create_oracledb_connection, table=..., fetch_with_process=True)
1808+
1809+
Example::
1810+
.. code-block:: python
1811+
1812+
import sqlite3
1813+
def create_sqlite_connection(timeout=5.0, isolation_level=None, **kwargs):
1814+
connection = sqlite3.connect(
1815+
database=":memory:",
1816+
timeout=timeout,
1817+
isolation_level=isolation_level
1818+
)
1819+
return connection
1820+
1821+
connection_params = {"timeout": 30.0, "isolation_level": "DEFERRED"}
1822+
df = session.read.dbapi(
1823+
create_sqlite_connection,
1824+
table=...,
1825+
connection_parameters=connection_params
1826+
)
1827+
18011828
"""
18021829
if (not table and not query) or (table and query):
18031830
raise SnowparkDataframeReaderException(
@@ -1831,6 +1858,7 @@ def create_oracledb_connection():
18311858
predicates,
18321859
session_init_statement,
18331860
fetch_merge_count,
1861+
connection_parameters,
18341862
)
18351863
struct_schema = partitioner.schema
18361864
partitioned_queries = partitioner.partitions

0 commit comments

Comments
 (0)