Skip to content

Commit 8aa988e

Browse files
committed
impl
1 parent fd46eb2 commit 8aa988e

File tree

7 files changed

+215
-25
lines changed

7 files changed

+215
-25
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
### Snowpark Python API Updates
66

7+
#### New Features
8+
9+
- Added PostgreSQL support to `DataFrameReader.dbapi` (PrPr) for both Parquet and UDTF-based ingestion.
10+
711
#### Improvements
812

913
- Invoking snowflake system procedures does not invoke an additional `describe procedure` call to check the return type of the procedure.

src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
4040
f"{self.__class__.__name__} has not implemented to_snow_type function"
4141
)
4242

43+
@staticmethod
4344
def prepare_connection(
44-
self,
4545
conn: "Connection",
4646
query_timeout: int = 0,
4747
) -> "Connection":
@@ -81,7 +81,7 @@ def udtf_ingestion(
8181
udtf_name = f"data_source_udtf_{generate_random_alphanumeric(5)}"
8282
start = time.time()
8383
session.udtf.register(
84-
self.udtf_class_builder(fetch_size=fetch_size),
84+
self.udtf_class_builder(fetch_size=fetch_size, schema=schema),
8585
name=udtf_name,
8686
output_schema=StructType(
8787
[
@@ -104,7 +104,9 @@ def udtf_ingestion(
104104
]
105105
return res.select(cols)
106106

107-
def udtf_class_builder(self, fetch_size: int = 1000) -> type:
107+
def udtf_class_builder(
108+
self, fetch_size: int = 1000, schema: StructType = None
109+
) -> type:
108110
create_connection = self.create_connection
109111

110112
class UDTFIngestion:

src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
109109

110110
return StructType(fields)
111111

112+
@staticmethod
112113
def prepare_connection(
113-
self,
114114
conn: "Connection",
115115
query_timeout: int = 0,
116116
) -> "Connection":

src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py

Lines changed: 96 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -173,16 +173,10 @@ def __init__(
173173
super().__init__(create_connection, dbms_type)
174174

175175
def to_snow_type(self, schema: List[Any]) -> StructType:
176-
# TODO: Implement this method to convert PostgreSQL types to Snowflake types.
177-
# https://other-docs.snowflake.com/en/connectors/postgres6/view-data#postgresql-to-snowflake-data-type-mapping
178-
# psycopg2 type code: https://github.com/psycopg/psycopg2/blob/master/psycopg/pgtypes.h
179-
# https://www.postgresql.org/docs/current/datatype.html
176+
# The psycopg2 spec is defined in the following links:
180177
# https://www.psycopg.org/docs/cursor.html#cursor.description
181-
# https://www.psycopg.org/docs/extensions.html#psycopg2.extensions.Column.type_code
182-
# https://www.postgresql.org/docs/current/catalog-pg-type.html
183-
# https://www.psycopg.org/docs/advanced.html#type-casting-from-sql-to-python
184-
fields = []
185178
# https://www.psycopg.org/docs/extensions.html#psycopg2.extensions.Column
179+
fields = []
186180
for (
187181
name,
188182
type_code,
@@ -222,14 +216,6 @@ def data_source_data_to_pandas_df(
222216
data: List[Any], schema: StructType
223217
) -> "pd.DataFrame":
224218
df = BaseDriver.data_source_data_to_pandas_df(data, schema)
225-
# psycopg2 returns binary data as memoryview, we need to convert it to bytes
226-
binary_type_indexes = [
227-
i
228-
for i, field in enumerate(schema.fields)
229-
if isinstance(field.datatype, BinaryType)
230-
]
231-
col_names = df.columns[binary_type_indexes]
232-
df[col_names] = BaseDriver.df_map_method(df[col_names])(lambda x: bytes(x))
233219

234220
variant_type_indexes = [
235221
i
@@ -259,8 +245,8 @@ def to_result_snowpark_df(
259245
project_columns, _emit_ast=_emit_ast
260246
)
261247

248+
@staticmethod
262249
def prepare_connection(
263-
self,
264250
conn: "Connection",
265251
query_timeout: int = 0,
266252
) -> "Connection":
@@ -275,4 +261,97 @@ def prepare_connection(
275261
lambda data, cursor: data,
276262
)
277263
register_type(SNOWPARK_INTERVAL_STR, conn)
264+
265+
# by default psycopg2 returns binary data as memoryview
266+
# to avoid using pandas to convert memoryview to bytes, we use the following native psycopg2 type conversion
267+
# psycopg2.extensions.new_type() only works for text format data, it returns bytes as hex string
268+
# we reconstruct the bytes from hex string
269+
SNOWPARK_BYTE = new_type(
270+
(Psycopg2TypeCode.BYTEAOID.value,),
271+
"SNOWPARK_BYTE_BYTES",
272+
lambda data, cursor: bytes.fromhex(data[2:])
273+
if data is not None
274+
else None, # [2:] to skip the '\\x' prefix
275+
)
276+
register_type(SNOWPARK_BYTE, conn)
277+
278+
if query_timeout:
279+
# https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
280+
# postgres default uses milliseconds
281+
conn.cursor().execute(f"SET STATEMENT_TIMEOUT = {query_timeout * 1000}")
278282
return conn
283+
284+
def udtf_class_builder(
285+
self, fetch_size: int = 1000, schema: StructType = None
286+
) -> type:
287+
create_connection = self.create_connection
288+
289+
# TODO: SNOW-2101485 use class method to prepare connection
290+
# ideally we should use the same function as prepare_connection
291+
# however, since we introduce new module for new driver support and initially the new module is not available in the backend
292+
# so if registering UDTF which uses the class method, cloudpickle will pickle the class method along with
293+
# the new module -- this leads to not being able to find the new module when unpickling on the backend.
294+
# once the new module is available in the backend, we can use the class method.
295+
def prepare_connection_in_udtf(
296+
conn: "Connection",
297+
query_timeout: int = 0,
298+
) -> "Connection":
299+
# The following is to align with Snowflake Connector behavior which get Interval as string
300+
# the default behavior of psycopg2 is to get Interval as datetime.timedelta
301+
# https://other-docs.snowflake.com/en/connectors/postgres6/view-data#postgresql-to-snowflake-data-type-mapping
302+
from psycopg2.extensions import new_type, register_type
303+
304+
# we do not use Psycopg2TypeCode.INTERVALOID.value because the UDTF pickles the psycopg2_driver module
305+
# unpickling in the UDTF would result in a module-not-found error if the package is not available in the backend
306+
SNOWPARK_INTERVAL_STR = new_type(
307+
(1186,),
308+
"SNOWPARK_INTERVAL_STR",
309+
lambda data, cursor: data,
310+
)
311+
register_type(SNOWPARK_INTERVAL_STR, conn)
312+
313+
if query_timeout:
314+
# https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT
315+
# postgres default uses milliseconds
316+
conn.cursor().execute(f"SET STATEMENT_TIMEOUT = {query_timeout * 1000}")
317+
return conn
318+
319+
binary_column_indexes = [
320+
i
321+
for i, field in enumerate(schema.fields)
322+
if isinstance(field.datatype, BinaryType)
323+
]
324+
time_column_indexes = [
325+
i
326+
for i, field in enumerate(schema.fields)
327+
if isinstance(field.datatype, TimeType)
328+
]
329+
330+
# postgres returns binary data as memoryview, we need to convert it to bytes
331+
def convert_rows(rows_to_update):
332+
ret = []
333+
for row in rows_to_update:
334+
# convert tuple to list to make it mutable
335+
new_row = list(row)
336+
# convert bytes to hexstring so that variant column can be cast to bytes
337+
for idx in binary_column_indexes:
338+
new_row[idx] = bytes(row[idx]).hex() if row[idx] else None
339+
# remove timezone info from time columns
340+
for idx in time_column_indexes:
341+
new_row[idx] = row[idx].replace(tzinfo=None) if row[idx] else None
342+
# convert list back to tuple as UDTF requires tuple
343+
ret.append(tuple(new_row))
344+
return ret
345+
346+
class UDTFIngestion:
347+
def process(self, query: str):
348+
conn = prepare_connection_in_udtf(create_connection())
349+
cursor = conn.cursor()
350+
cursor.execute(query)
351+
while True:
352+
rows = cursor.fetchmany(fetch_size)
353+
if not rows:
354+
break
355+
yield from convert_rows(rows)
356+
357+
return UDTFIngestion

src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ def process(self, query: str):
116116

117117
return UDTFIngestion
118118

119+
@staticmethod
119120
def prepare_connection(
120-
self,
121121
conn: "Connection",
122122
query_timeout: int = 0,
123123
) -> "Connection":

tests/integ/datasource/test_postgres.py

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import datetime
66
from decimal import Decimal
77

8+
from snowflake.snowpark import Row
89
from snowflake.snowpark.exceptions import SnowparkDataframeReaderException
910
from snowflake.snowpark.types import (
1011
StructType,
@@ -40,7 +41,7 @@
4041
]
4142

4243

43-
TEST_TABLE_NAME = "test_schema.ALL_TYPE_TABLE"
44+
POSTGRES_TABLE_NAME = "test_schema.ALL_TYPE_TABLE"
4445
EXPECTED_TEST_DATA = [
4546
(
4647
-6645531000000000000,
@@ -267,6 +268,51 @@
267268
"960b86a9-a8dd-4634-bc1f-956ae6589726",
268269
"<root><element>47</element></root>",
269270
),
271+
(
272+
None,
273+
6,
274+
None,
275+
None,
276+
None,
277+
None,
278+
None,
279+
None,
280+
None,
281+
None,
282+
None,
283+
None,
284+
None,
285+
None,
286+
None,
287+
None,
288+
"null",
289+
"null",
290+
None,
291+
None,
292+
None,
293+
None,
294+
"null",
295+
None,
296+
None,
297+
None,
298+
None,
299+
None,
300+
None,
301+
None,
302+
None,
303+
6,
304+
6,
305+
None,
306+
None,
307+
None,
308+
None,
309+
None,
310+
None,
311+
None,
312+
None,
313+
None,
314+
None,
315+
),
270316
]
271317
EXPECTED_TYPE = StructType(
272318
[
@@ -319,6 +365,7 @@
319365
StructField("XML_COL", StringType(16777216), nullable=True),
320366
]
321367
)
368+
POSTGRES_TEST_EXTERNAL_ACCESS_INTEGRATION = "snowpark_dbapi_postgres_test_integration"
322369

323370

324371
def create_postgres_connection():
@@ -327,7 +374,10 @@ def create_postgres_connection():
327374

328375
@pytest.mark.parametrize(
329376
"input_type, input_value",
330-
[("table", TEST_TABLE_NAME), ("query", f"(SELECT * FROM {TEST_TABLE_NAME})")],
377+
[
378+
("table", POSTGRES_TABLE_NAME),
379+
("query", f"(SELECT * FROM {POSTGRES_TABLE_NAME})"),
380+
],
331381
)
332382
def test_basic_postgres(session, input_type, input_value):
333383
input_dict = {
@@ -350,3 +400,58 @@ def test_error_case(session, input_type, input_value, error_message):
350400
}
351401
with pytest.raises(SnowparkDataframeReaderException, match=error_message):
352402
session.read.dbapi(create_postgres_connection, **input_dict)
403+
404+
405+
def test_query_timeout(session):
406+
with pytest.raises(
407+
SnowparkDataframeReaderException,
408+
match=r"due to exception 'QueryCanceled\('canceling statement due to statement timeout",
409+
):
410+
session.read.dbapi(
411+
create_postgres_connection,
412+
table=POSTGRES_TABLE_NAME,
413+
query_timeout=1,
414+
session_init_statement=["SELECT pg_sleep(5)"],
415+
)
416+
417+
418+
def test_external_access_integration_not_set(session):
419+
with pytest.raises(
420+
ValueError,
421+
match="external_access_integration cannot be None when udtf ingestion is used.",
422+
):
423+
session.read.dbapi(
424+
create_postgres_connection, table=POSTGRES_TABLE_NAME, udtf_configs={}
425+
)
426+
427+
428+
def test_unicode_column_name_postgres(session):
429+
df = session.read.dbapi(
430+
create_postgres_connection, table='test_schema."用户資料"'
431+
).order_by("編號")
432+
assert df.collect() == [Row(編號=1, 姓名="山田太郎", 國家="日本", 備註="これはUnicodeテストです")]
433+
assert df.columns == ['"編號"', '"姓名"', '"國家"', '"備註"']
434+
435+
436+
def test_udtf_ingestion_postgres(session, caplog):
437+
from tests.parameters import POSTGRES_CONNECTION_PARAMETERS
438+
439+
def create_connection_postgres():
440+
import psycopg2
441+
442+
return psycopg2.connect(**POSTGRES_CONNECTION_PARAMETERS)
443+
444+
df = session.read.dbapi(
445+
create_connection_postgres,
446+
table=POSTGRES_TABLE_NAME,
447+
udtf_configs={
448+
"external_access_integration": POSTGRES_TEST_EXTERNAL_ACCESS_INTEGRATION
449+
},
450+
).order_by("BIGSERIAL_COL")
451+
452+
assert df.collect() == EXPECTED_TEST_DATA
453+
# assert UDTF creation and UDTF call
454+
assert (
455+
"TEMPORARY FUNCTION data_source_udtf_" "" in caplog.text
456+
and "table(data_source_udtf" in caplog.text
457+
)

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ deps =
223223
{[testenv]deps}
224224
databricks-sql-connector
225225
oracledb
226-
psycopg2
226+
psycopg2-binary
227227
commands = {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE}" {posargs:} tests/integ/datasource
228228

229229
[pytest]

0 commit comments

Comments
 (0)