Skip to content

Commit 2ee064b

Browse files
authored
SNOW-1955847: DBAPI Postgres Support (#3351)
1 parent 5fb210d commit 2ee064b

File tree

18 files changed

+1155
-14
lines changed

18 files changed

+1155
-14
lines changed
58 Bytes
Binary file not shown.

.github/workflows/precommit.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@ jobs:
210210
TOX_PARALLEL_NO_SPINNER: 1
211211
shell: bash
212212
- name: Run data source tests
213+
# psycopg2 wheels are not available for Python 3.9 on macOS, so skip the data source tests on that combination
214+
if: ${{ !(matrix.os == 'macos-latest' && matrix.python-version == '3.9') }}
213215
run: python -m tox -e datasource
214216
env:
215217
PYTHON_VERSION: ${{ matrix.python-version }}

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66

77
#### New Features
88

9-
- Added support for ingestion with Snowflake UDTF to databricks in `DataFrameReader.dbapi` (PrPr).
10-
- Added support for Mysql in `DataFrameWriter.dbapi` (PrPr).
9+
- Added support for MySQL in `DataFrameReader.dbapi` (PrPr) for both Parquet and UDTF-based ingestion.
10+
- Added support for PostgreSQL in `DataFrameReader.dbapi` (PrPr) for both Parquet and UDTF-based ingestion.
11+
- Added support for Databricks in `DataFrameReader.dbapi` (PrPr) for UDTF-based ingestion.
1112

1213
#### Bug Fixes
1314

scripts/parameters.py.gpg

63 Bytes
Binary file not shown.

src/snowflake/snowpark/_internal/data_source/dbms_dialects/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"SqlServerDialect",
88
"OracledbDialect",
99
"DatabricksDialect",
10+
"PostgresDialect",
1011
"MysqlDialect",
1112
]
1213

@@ -25,6 +26,9 @@
2526
from snowflake.snowpark._internal.data_source.dbms_dialects.databricks_dialect import (
2627
DatabricksDialect,
2728
)
29+
from snowflake.snowpark._internal.data_source.dbms_dialects.postgresql_dialect import (
30+
PostgresDialect,
31+
)
2832
from snowflake.snowpark._internal.data_source.dbms_dialects.mysql_dialect import (
2933
MysqlDialect,
3034
)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#
2+
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
3+
#
4+
from typing import List
5+
6+
from snowflake.snowpark._internal.data_source.dbms_dialects import BaseDialect
7+
from snowflake.snowpark._internal.data_source.drivers.psycopg2_driver import (
8+
Psycopg2TypeCode,
9+
)
10+
from snowflake.snowpark.types import StructType
11+
12+
13+
class PostgresDialect(BaseDialect):
    """SQL dialect for reading from PostgreSQL via DBAPI (psycopg2).

    Several PostgreSQL types do not transfer cleanly to Snowflake, so
    ``generate_select_query`` rewrites those columns into portable textual
    representations on the PostgreSQL side before the data leaves the source.
    """

    @staticmethod
    def generate_select_query(
        table_or_query: str, schema: StructType, raw_schema: List[tuple], is_query: bool
    ) -> str:
        """Build the SELECT statement used to pull data out of PostgreSQL.

        Args:
            table_or_query: Table name or parenthesized subquery to read from.
            schema: Snowpark schema inferred for the source. Kept for
                interface parity with the other dialects; the per-column
                casts below are driven solely by the raw psycopg2 type codes.
            raw_schema: psycopg2 ``cursor.description`` tuples, where
                ``raw_field[0]`` is the column name and ``raw_field[1]`` is
                the PostgreSQL type OID.
            is_query: True when ``table_or_query`` is a query rather than a
                table name. Unused here; part of the shared dialect API.

        Returns:
            A SELECT statement in which problematic PostgreSQL types are
            cast to forms Snowflake can ingest.
        """
        projection = []
        for raw_field in raw_schema:
            name, type_code = raw_field[0], raw_field[1]
            if type_code in (
                Psycopg2TypeCode.JSONB.value,
                Psycopg2TypeCode.JSON.value,
            ):
                # Serialize JSON/JSONB to text on the PostgreSQL side so the
                # value arrives as a plain string.
                projection.append(f"""TO_JSON("{name}")::TEXT AS {name}""")
            elif type_code == Psycopg2TypeCode.CASHOID.value:
                # MONEY: emit the textual form wrapped in double quotes,
                # preserving NULLs explicitly (FORMAT would stringify them).
                projection.append(
                    f"""CASE WHEN "{name}" IS NULL THEN NULL ELSE FORMAT('"%s"', "{name}"::TEXT) END AS {name}"""
                )
            elif type_code == Psycopg2TypeCode.BYTEAOID.value:
                # BYTEA: hex-encode binary data for transport.
                projection.append(f"""ENCODE("{name}", 'HEX') AS {name}""")
            elif type_code == Psycopg2TypeCode.TIMETZOID.value:
                # TIMETZ: drop the zone offset — Snowflake TIME is zoneless.
                projection.append(f""""{name}"::TIME AS {name}""")
            elif type_code == Psycopg2TypeCode.INTERVALOID.value:
                # INTERVAL: ship as text; no direct Snowflake equivalent.
                projection.append(f""""{name}"::TEXT AS {name}""")
            else:
                projection.append(f'"{name}"')
        return f"""SELECT {", ".join(projection)} FROM {table_or_query}"""

src/snowflake/snowpark/_internal/data_source/drivers/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"SqliteDriver",
99
"PyodbcDriver",
1010
"DatabricksDriver",
11+
"Psycopg2Driver",
1112
"PymysqlDriver",
1213
]
1314

@@ -20,4 +21,7 @@
2021
from snowflake.snowpark._internal.data_source.drivers.databricks_driver import (
2122
DatabricksDriver,
2223
)
24+
from snowflake.snowpark._internal.data_source.drivers.psycopg2_driver import (
25+
Psycopg2Driver,
26+
)
2327
from snowflake.snowpark._internal.data_source.drivers.pymsql_driver import PymysqlDriver

src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
5151
f"{self.__class__.__name__} has not implemented to_snow_type function"
5252
)
5353

54+
@staticmethod
5455
def prepare_connection(
55-
self,
5656
conn: "Connection",
5757
query_timeout: int = 0,
5858
) -> "Connection":
@@ -100,7 +100,7 @@ def udtf_ingestion(
100100
udtf_name = f"data_source_udtf_{generate_random_alphanumeric(5)}"
101101
start = time.time()
102102
session.udtf.register(
103-
self.udtf_class_builder(fetch_size=fetch_size),
103+
self.udtf_class_builder(fetch_size=fetch_size, schema=schema),
104104
name=udtf_name,
105105
output_schema=StructType(
106106
[
@@ -119,7 +119,9 @@ def udtf_ingestion(
119119
res = session.sql(call_udtf_sql, _emit_ast=_emit_ast)
120120
return self.to_result_snowpark_df_udtf(res, schema, _emit_ast=_emit_ast)
121121

122-
def udtf_class_builder(self, fetch_size: int = 1000) -> type:
122+
def udtf_class_builder(
123+
self, fetch_size: int = 1000, schema: StructType = None
124+
) -> type:
123125
create_connection = self.create_connection
124126

125127
class UDTFIngestion:

src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
6161
all_columns.append(StructField(column_name, data_type, True))
6262
return StructType(all_columns)
6363

64-
def udtf_class_builder(self, fetch_size: int = 1000) -> type:
64+
def udtf_class_builder(
65+
self, fetch_size: int = 1000, schema: StructType = None
66+
) -> type:
6567
create_connection = self.create_connection
6668

6769
class UDTFIngestion:

src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
103103

104104
return StructType(fields)
105105

106+
@staticmethod
106107
def prepare_connection(
107-
self,
108108
conn: "Connection",
109109
query_timeout: int = 0,
110110
) -> "Connection":
@@ -113,7 +113,9 @@ def prepare_connection(
113113
conn.outputtypehandler = output_type_handler
114114
return conn
115115

116-
def udtf_class_builder(self, fetch_size: int = 1000) -> type:
116+
def udtf_class_builder(
117+
self, fetch_size: int = 1000, schema: StructType = None
118+
) -> type:
117119
create_connection = self.create_connection
118120

119121
def oracledb_output_type_handler(cursor, metadata):

0 commit comments

Comments
 (0)