Commit 0a9d780
SNOW-2301201: use server cursor to fetch data (#3726)
1 parent: 50771a4

8 files changed: +58 −3 lines

CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -48,6 +48,7 @@
 
 - Enhanced error handling in `DataFrameReader.dbapi` thread-based ingestion to prevent unnecessary operations, which improves resource efficiency.
 - Bumped cloudpickle dependency to also support `cloudpickle==3.1.1` in addition to previous versions.
+- Improved `DataFrameReader.dbapi` (PuPr) ingestion performance for PostgreSQL and MySQL by using a server-side cursor to fetch data.
 
 ### Snowpark pandas API Updates

src/snowflake/snowpark/_internal/data_source/datasource_reader.py

Lines changed: 4 additions & 1 deletion

@@ -52,8 +52,8 @@ def read(self, partition: str) -> Iterator[List[Any]]:
         conn = self.driver.prepare_connection(
             self.driver.create_connection(), self.query_timeout
         )
-        cursor = conn.cursor()
         try:
+            cursor = conn.cursor()
             if self.session_init_statement:
                 for statement in self.session_init_statement:
                     try:
@@ -62,6 +62,9 @@ def read(self, partition: str) -> Iterator[List[Any]]:
                        raise SnowparkDataframeReaderException(
                            f"Failed to execute session init statement: '{statement}' due to exception '{exc!r}'"
                        )
+            # use a server-side cursor to fetch data if supported by the driver;
+            # some drivers do not support executing twice on a server-side cursor (e.g. psycopg2)
+            cursor = self.driver.get_server_cursor_if_supported(conn)
             if self.fetch_size == 0:
                 cursor.execute(partition)
                 result = cursor.fetchall()
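
For context on the ordering above (a plain cursor for the session-init statements, a server-side cursor only for the partition query), here is a minimal standalone sketch using psycopg2 directly; the connection parameters, the SET statement, and big_table are hypothetical. A psycopg2 named (server-side) cursor accepts only a single query, so setup statements have to go through a regular client-side cursor first:

import psycopg2

# Hypothetical connection parameters for illustration only.
conn = psycopg2.connect(host="localhost", dbname="demo", user="demo", password="demo")
try:
    # Client-side cursor: fine for small setup statements, can be reused.
    setup_cursor = conn.cursor()
    setup_cursor.execute("SET statement_timeout = 30000")

    # Server-side (named) cursor: rows stay on the server and stream to the
    # client in fetchmany()-sized batches, but it accepts exactly one execute().
    data_cursor = conn.cursor(name="demo_server_cursor")
    data_cursor.execute("SELECT * FROM big_table")
    while True:
        rows = data_cursor.fetchmany(1000)
        if not rows:
            break
        # process rows ...
finally:
    conn.close()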

src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py

Lines changed: 17 additions & 0 deletions

@@ -260,3 +260,20 @@ def to_result_snowpark_df_udtf(
            for field in schema.fields
        ]
        return res_df.select(cols, _emit_ast=_emit_ast)
+
+    def get_server_cursor_if_supported(self, conn: "Connection") -> "Cursor":
+        """
+        Return a server-side cursor if the driver and the DBMS support it.
+        Drivers can override this method to return a server-side cursor when supported;
+        otherwise it returns the default cursor provided by the driver and the DBMS.
+
+        - databricks-sql-connector: no concept of client/server cursor, no need to override
+        - python-oracledb: defaults to a server cursor, no need to override
+        - psycopg2: defaults to a client cursor, must be overridden to return a server cursor
+        - pymysql: defaults to a client cursor, must be overridden to return a server cursor
+
+        TODO:
+        - pyodbc: a Python wrapper on top of ODBC drivers; the ODBC driver and the DBMS may or may not support server cursors,
+          and if they do, the way to obtain one may vary across DBMSs. We need to document pyodbc.
+        """
+        return conn.cursor()
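
As a rough illustration of the hook above, a subclass only needs to override this one method to opt into server-side fetching. The ExampleDriver below is hypothetical (the real overrides are in psycopg2_driver.py and pymsql_driver.py later in this commit) and assumes a psycopg2-style conn.cursor(name=...) API:

from snowflake.snowpark._internal.data_source.drivers import BaseDriver


# Hypothetical driver, for illustration only.
class ExampleDriver(BaseDriver):
    def get_server_cursor_if_supported(self, conn: "Connection") -> "Cursor":
        # Assumes a psycopg2-style API where a named cursor is declared on the
        # server and streamed to the client; other DBAPI drivers expose this differently.
        return conn.cursor(name="snowpark_example_cursor")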

src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py

Lines changed: 10 additions & 1 deletion

@@ -7,6 +7,7 @@
 
 from snowflake.snowpark._internal.data_source.datasource_typing import Connection
 from snowflake.snowpark._internal.data_source.drivers import BaseDriver
+from snowflake.snowpark._internal.utils import generate_random_alphanumeric
 from snowflake.snowpark.functions import to_variant, parse_json, column
 from snowflake.snowpark.types import (
     StructType,
@@ -28,6 +29,9 @@
 if TYPE_CHECKING:
     from snowflake.snowpark.session import Session  # pragma: no cover
     from snowflake.snowpark.dataframe import DataFrame  # pragma: no cover
+    from snowflake.snowpark._internal.data_source.datasource_typing import (
+        Cursor,
+    )  # pragma: no cover
 
 
 logger = logging.getLogger(__name__)
@@ -277,7 +281,9 @@ def prepare_connection_in_udtf(
        class UDTFIngestion:
            def process(self, query: str):
                conn = prepare_connection_in_udtf(create_connection())
-               cursor = conn.cursor()
+               cursor = conn.cursor(
+                   f"SNOWPARK_CURSOR_{generate_random_alphanumeric(5)}"
+               )
                cursor.execute(query)
                while True:
                    rows = cursor.fetchmany(fetch_size)
@@ -286,3 +292,6 @@ def process(self, query: str):
                    yield from rows
 
        return UDTFIngestion
+
+    def get_server_cursor_if_supported(self, conn: "Connection") -> "Cursor":
+        return conn.cursor(f"SNOWPARK_CURSOR_{generate_random_alphanumeric(5)}")
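
A plausible reason for the random SNOWPARK_CURSOR_ suffix above: psycopg2 named cursors map to PostgreSQL DECLARE cursors, and two open cursors with the same name on one connection collide. A small sketch of that failure mode (connection parameters are hypothetical):

import psycopg2

# Hypothetical connection parameters for illustration only.
conn = psycopg2.connect(host="localhost", dbname="demo", user="demo", password="demo")
c1 = conn.cursor(name="snowpark_cursor")
c1.execute("SELECT 1")  # DECLAREs the server-side cursor "snowpark_cursor"

c2 = conn.cursor(name="snowpark_cursor")
try:
    c2.execute("SELECT 2")  # fails: cursor "snowpark_cursor" already exists
except psycopg2.ProgrammingError as exc:
    print(exc)
finally:
    conn.close()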

src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py

Lines changed: 7 additions & 1 deletion

@@ -192,9 +192,10 @@ def udtf_class_builder(
 
        class UDTFIngestion:
            def process(self, query: str):
+               import pymysql
 
                conn = create_connection()
-               cursor = conn.cursor()
+               cursor = pymysql.cursors.SSCursor(conn)
                cursor.execute(query)
                while True:
                    rows = cursor.fetchmany(fetch_size)
@@ -256,3 +257,8 @@ def to_result_snowpark_df_udtf(
            else:
                cols.append(res_df[field.name].cast(field.datatype).alias(field.name))
        return res_df.select(cols, _emit_ast=_emit_ast)
+
+    def get_server_cursor_if_supported(self, conn: "Connection") -> "Cursor":
+        import pymysql
+
+        return pymysql.cursors.SSCursor(conn)
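
For reference, a standalone sketch of what switching to SSCursor changes (connection parameters and big_table are hypothetical): the default pymysql Cursor buffers the entire result set in client memory at execute() time, while SSCursor is unbuffered and streams rows from the server as they are fetched:

import pymysql

# Hypothetical connection parameters for illustration only.
conn = pymysql.connect(host="localhost", user="demo", password="demo", database="demo")
try:
    # Equivalent to conn.cursor(pymysql.cursors.SSCursor); rows are not
    # buffered client-side, so fetchmany() pulls them in batches.
    cursor = pymysql.cursors.SSCursor(conn)
    cursor.execute("SELECT * FROM big_table")
    while True:
        rows = cursor.fetchmany(1000)
        if not rows:
            break
        # process rows ...
    cursor.close()
finally:
    conn.close()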

tests/integ/datasource/test_mysql.py

Lines changed: 9 additions & 0 deletions

@@ -287,3 +287,12 @@ def test_pymysql_driver_udtf_class_builder():
     # Verify we got data with the right structure (2 columns)
     assert len(column_result_rows) > 0
     assert len(column_result_rows[0]) == 2  # Two columns
+
+
+def test_server_side_cursor():
+    conn = create_connection_mysql()
+    driver = PymysqlDriver(create_connection_mysql, DBMS_TYPE.MYSQL_DB)
+    cursor = driver.get_server_cursor_if_supported(conn)
+    assert isinstance(cursor, pymysql.cursors.SSCursor)
+    cursor.close()
+    conn.close()

tests/integ/datasource/test_postgres.py

Lines changed: 9 additions & 0 deletions

@@ -464,3 +464,12 @@ def test_unit_generate_select_query():
         'SELECT TO_JSON("jsonb_col")::TEXT AS jsonb_col FROM test_table'
     )
     assert jsonb_query == expected_jsonb_query
+
+
+def test_server_side_cursor(session):
+    conn = create_postgres_connection()
+    driver = Psycopg2Driver(create_postgres_connection, DBMS_TYPE.POSTGRES_DB)
+    cursor = driver.get_server_cursor_if_supported(conn)
+    assert cursor.name is not None  # Server-side cursor should have a name
+    cursor.close()
+    conn.close()

tests/unit/test_data_source.py

Lines changed: 1 addition & 0 deletions

@@ -81,6 +81,7 @@ def test_datasource_reader_close_error_handling(cursor_fails, conn_fails):
     mock_driver.prepare_connection.return_value = mock_conn
     mock_driver.create_connection.return_value = mock_conn
     mock_conn.cursor.return_value = mock_cursor
+    mock_driver.get_server_cursor_if_supported.return_value = mock_cursor
     mock_cursor.fetchall.return_value = [("test_data",)]
 
     # Configure failures
