diff --git a/.github/workflows/parameters/parameters_dbapi.py.gpg b/.github/workflows/parameters/parameters_dbapi.py.gpg index 0fb874c74f..6a792a00e5 100644 Binary files a/.github/workflows/parameters/parameters_dbapi.py.gpg and b/.github/workflows/parameters/parameters_dbapi.py.gpg differ diff --git a/.github/workflows/precommit.yml b/.github/workflows/precommit.yml index b17d12f750..2f3126410d 100644 --- a/.github/workflows/precommit.yml +++ b/.github/workflows/precommit.yml @@ -210,6 +210,8 @@ jobs: TOX_PARALLEL_NO_SPINNER: 1 shell: bash - name: Run data source tests + # psycopg2 is not supported on macos 3.9 + if: ${{ !(matrix.os == 'macos-latest' && matrix.python-version == '3.9') }} run: python -m tox -e datasource env: PYTHON_VERSION: ${{ matrix.python-version }} diff --git a/CHANGELOG.md b/CHANGELOG.md index bd808a2cd5..3ad6937409 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,9 @@ #### New Features -- Added support for ingestion with Snowflake UDTF to databricks in `DataFrameReader.dbapi` (PrPr). -- Added support for Mysql in `DataFrameWriter.dbapi` (PrPr). +- Added support for MySQL in `DataFrameReader.dbapi` (PrPr) for both Parquet and UDTF-based ingestion. +- Added support for PostgreSQL in `DataFrameReader.dbapi` (PrPr) for both Parquet and UDTF-based ingestion. +- Added support for Databricks in `DataFrameReader.dbapi` (PrPr) for UDTF-based ingestion. 
#### Bug Fixes diff --git a/scripts/parameters.py.gpg b/scripts/parameters.py.gpg index 1e24f00ba2..c757093b60 100644 Binary files a/scripts/parameters.py.gpg and b/scripts/parameters.py.gpg differ diff --git a/src/snowflake/snowpark/_internal/data_source/dbms_dialects/__init__.py b/src/snowflake/snowpark/_internal/data_source/dbms_dialects/__init__.py index 8fd8287bc3..57304599aa 100644 --- a/src/snowflake/snowpark/_internal/data_source/dbms_dialects/__init__.py +++ b/src/snowflake/snowpark/_internal/data_source/dbms_dialects/__init__.py @@ -7,6 +7,7 @@ "SqlServerDialect", "OracledbDialect", "DatabricksDialect", + "PostgresDialect", "MysqlDialect", ] @@ -25,6 +26,9 @@ from snowflake.snowpark._internal.data_source.dbms_dialects.databricks_dialect import ( DatabricksDialect, ) +from snowflake.snowpark._internal.data_source.dbms_dialects.postgresql_dialect import ( + PostgresDialect, +) from snowflake.snowpark._internal.data_source.dbms_dialects.mysql_dialect import ( MysqlDialect, ) diff --git a/src/snowflake/snowpark/_internal/data_source/dbms_dialects/postgresql_dialect.py b/src/snowflake/snowpark/_internal/data_source/dbms_dialects/postgresql_dialect.py new file mode 100644 index 0000000000..d24c8185bd --- /dev/null +++ b/src/snowflake/snowpark/_internal/data_source/dbms_dialects/postgresql_dialect.py @@ -0,0 +1,40 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
+# +from typing import List + +from snowflake.snowpark._internal.data_source.dbms_dialects import BaseDialect +from snowflake.snowpark._internal.data_source.drivers.psycopg2_driver import ( + Psycopg2TypeCode, +) +from snowflake.snowpark.types import StructType + + +class PostgresDialect(BaseDialect): + @staticmethod + def generate_select_query( + table_or_query: str, schema: StructType, raw_schema: List[tuple], is_query: bool + ) -> str: + cols = [] + for _field, raw_field in zip(schema.fields, raw_schema): + # Some PostgreSQL types (JSON/JSONB, MONEY, BYTEA, TIMETZ, INTERVAL) are not + # directly ingestible; push the conversion down to PostgreSQL in the SELECT. + type_code = raw_field[1] + if type_code in ( + Psycopg2TypeCode.JSONB.value, + Psycopg2TypeCode.JSON.value, + ): + cols.append(f"""TO_JSON("{raw_field[0]}")::TEXT AS {raw_field[0]}""") + elif type_code == Psycopg2TypeCode.CASHOID.value: + cols.append( + f"""CASE WHEN "{raw_field[0]}" IS NULL THEN NULL ELSE FORMAT('"%s"', "{raw_field[0]}"::TEXT) END AS {raw_field[0]}""" + ) + elif type_code == Psycopg2TypeCode.BYTEAOID.value: + cols.append(f"""ENCODE("{raw_field[0]}", 'HEX') AS {raw_field[0]}""") + elif type_code == Psycopg2TypeCode.TIMETZOID.value: + cols.append(f""""{raw_field[0]}"::TIME AS {raw_field[0]}""") + elif type_code == Psycopg2TypeCode.INTERVALOID.value: + cols.append(f""""{raw_field[0]}"::TEXT AS {raw_field[0]}""") + else: + cols.append(f'"{raw_field[0]}"') + return f"""SELECT {", ".join(cols)} FROM {table_or_query}""" diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/__init__.py b/src/snowflake/snowpark/_internal/data_source/drivers/__init__.py index ea967e9deb..1fe51f6811 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/__init__.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/__init__.py @@ -8,6 +8,7 @@ "SqliteDriver", "PyodbcDriver", "DatabricksDriver", + "Psycopg2Driver", "PymysqlDriver", ] @@ -20,4 +21,7 @@ from snowflake.snowpark._internal.data_source.drivers.databricks_driver 
import ( DatabricksDriver, ) +from snowflake.snowpark._internal.data_source.drivers.psycopg2_driver import ( + Psycopg2Driver, +) from snowflake.snowpark._internal.data_source.drivers.pymsql_driver import PymysqlDriver diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py index e053a6c28f..a8fc124820 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py @@ -51,8 +51,8 @@ def to_snow_type(self, schema: List[Any]) -> StructType: f"{self.__class__.__name__} has not implemented to_snow_type function" ) + @staticmethod def prepare_connection( - self, conn: "Connection", query_timeout: int = 0, ) -> "Connection": @@ -100,7 +100,7 @@ def udtf_ingestion( udtf_name = f"data_source_udtf_{generate_random_alphanumeric(5)}" start = time.time() session.udtf.register( - self.udtf_class_builder(fetch_size=fetch_size), + self.udtf_class_builder(fetch_size=fetch_size, schema=schema), name=udtf_name, output_schema=StructType( [ @@ -119,7 +119,9 @@ def udtf_ingestion( res = session.sql(call_udtf_sql, _emit_ast=_emit_ast) return self.to_result_snowpark_df_udtf(res, schema, _emit_ast=_emit_ast) - def udtf_class_builder(self, fetch_size: int = 1000) -> type: + def udtf_class_builder( + self, fetch_size: int = 1000, schema: StructType = None + ) -> type: create_connection = self.create_connection class UDTFIngestion: diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py index 0fab5a4b15..6c17daf6b2 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/databricks_driver.py @@ -61,7 +61,9 @@ def to_snow_type(self, schema: List[Any]) -> StructType: all_columns.append(StructField(column_name, 
data_type, True)) return StructType(all_columns) - def udtf_class_builder(self, fetch_size: int = 1000) -> type: + def udtf_class_builder( + self, fetch_size: int = 1000, schema: StructType = None + ) -> type: create_connection = self.create_connection class UDTFIngestion: diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py index 06cf1f2adc..f3c544a477 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/oracledb_driver.py @@ -103,8 +103,8 @@ def to_snow_type(self, schema: List[Any]) -> StructType: return StructType(fields) + @staticmethod def prepare_connection( - self, conn: "Connection", query_timeout: int = 0, ) -> "Connection": @@ -113,7 +113,9 @@ def prepare_connection( conn.outputtypehandler = output_type_handler return conn - def udtf_class_builder(self, fetch_size: int = 1000) -> type: + def udtf_class_builder( + self, fetch_size: int = 1000, schema: StructType = None + ) -> type: create_connection = self.create_connection def oracledb_output_type_handler(cursor, metadata): diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py new file mode 100644 index 0000000000..8b9bef506e --- /dev/null +++ b/src/snowflake/snowpark/_internal/data_source/drivers/psycopg2_driver.py @@ -0,0 +1,288 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
+# +import logging +from enum import Enum +from typing import Callable, List, Any, TYPE_CHECKING + +from snowflake.snowpark._internal.data_source.datasource_typing import Connection +from snowflake.snowpark._internal.data_source.drivers import BaseDriver +from snowflake.snowpark.functions import to_variant, parse_json, column +from snowflake.snowpark.types import ( + StructType, + IntegerType, + StringType, + DecimalType, + BooleanType, + DateType, + DoubleType, + TimestampType, + VariantType, + FloatType, + BinaryType, + TimeType, + TimestampTimeZone, + StructField, +) + +if TYPE_CHECKING: + from snowflake.snowpark.session import Session # pragma: no cover + from snowflake.snowpark.dataframe import DataFrame # pragma: no cover + + +logger = logging.getLogger(__name__) + + +# The following Enum Class is generated from the following two docs: +# 1. https://github.com/psycopg/psycopg2/blob/master/psycopg/pgtypes.h +# 2. https://www.postgresql.org/docs/current/datatype.html +# pgtypes.h includes a broad range of type codes, but some newer type codes are missing. +# We will focus on the overlapping types that appear in both the documentation and the results from our Postgres tests. +class Psycopg2TypeCode(Enum): + BOOLOID = 16 + BYTEAOID = 17 + CHAROID = 18 + # NAMEOID = 19 # Not listed in the Postgres doc. + INT8OID = 20 + INT2OID = 21 + # INT2VECTOROID = 22 # Not listed in the Postgres doc. + INT4OID = 23 + # REGPROCOID = 24 # Not listed in the Postgres doc. + TEXTOID = 25 + # OIDOID = 26 # Not listed in the Postgres doc. + # TIDOID = 27 # Not listed in the Postgres doc. + # XIDOID = 28 # Not listed in the Postgres doc. + # CIDOID = 29 # Not listed in the Postgres doc. + # OIDVECTOROID = 30 # Not listed in the Postgres doc. + # PG_TYPE_RELTYPE_OID = 71 # Not listed in the Postgres doc. + # PG_ATTRIBUTE_RELTYPE_OID = 75 # Not listed in the Postgres doc. + # PG_PROC_RELTYPE_OID = 81 # Not listed in the Postgres doc. 
+ # PG_CLASS_RELTYPE_OID = 83 # Not listed in the Postgres doc. + JSON = 114 # Not listed in the pgtypes.h + XML = 142 # Not listed in the pgtypes.h + POINTOID = 600 + LSEGOID = 601 + PATHOID = 602 + BOXOID = 603 + POLYGONOID = 604 + LINEOID = 628 + FLOAT4OID = 700 + FLOAT8OID = 701 + # ABSTIMEOID = 702 # Not listed in the Postgres doc. + # RELTIMEOID = 703 # Not listed in the Postgres doc. + # TINTERVALOID = 704 # Not listed in the Postgres doc. + # UNKNOWNOID = 705 # Not listed in the Postgres doc. + CIRCLEOID = 718 + MACADDR8 = 774 # Not listed in the pgtypes.h + CASHOID = 790 # MONEY + MACADDROID = 829 + CIDROID = 650 + INETOID = 869 + INT4ARRAYOID = 1007 # Not listed in the Postgres doc. + ACLITEMOID = 1033 # Not listed in the Postgres doc. + BPCHAROID = 1042 + VARCHAROID = 1043 + DATEOID = 1082 + TIMEOID = 1083 + TIMESTAMPOID = 1114 + TIMESTAMPTZOID = 1184 + INTERVALOID = 1186 + TIMETZOID = 1266 + BITOID = 1560 + VARBITOID = 1562 + NUMERICOID = 1700 + # REFCURSOROID = 1790 # Not listed in the Postgres doc. + # REGPROCEDUREOID = 2202 # Not listed in the Postgres doc. + # REGOPEROID = 2203 # Not listed in the Postgres doc. + # REGOPERATOROID = 2204 # Not listed in the Postgres doc. + # REGCLASSOID = 2205 # Not listed in the Postgres doc. + # REGTYPEOID = 2206 # Not listed in the Postgres doc. + # RECORDOID = 2249 # Not listed in the Postgres doc. + # CSTRINGOID = 2275 # Not listed in the Postgres doc. + # ANYOID = 2276 # Not listed in the Postgres doc. + # ANYARRAYOID = 2277 # Not listed in the Postgres doc. + # VOIDOID = 2278 # Not listed in the Postgres doc. + # TRIGGEROID = 2279 # Not listed in the Postgres doc. + # LANGUAGE_HANDLEROID = 2280 # Not listed in the Postgres doc. + # INTERNALOID = 2281 # Not listed in the Postgres doc. + # OPAQUEOID = 2282 # Not listed in the Postgres doc. + # ANYELEMENTOID = 2283 # Not listed in the Postgres doc. 
+ UUID = 2950 # Not listed in the pgtypes.h + TXID_SNAPSHOT = 2970 # Not listed in the pgtypes.h + PG_LSN = 3220 # Not listed in the pgtypes.h + TSVECTOR = 3614 # Not listed in the pgtypes.h + TSQUERY = 3615 # Not listed in the pgtypes.h + JSONB = 3802 # Not listed in the pgtypes.h + PG_SNAPSHOT = 5038 # Not listed in the pgtypes.h + + +# https://other-docs.snowflake.com/en/connectors/postgres6/view-data#postgresql-to-snowflake-data-type-mapping +BASE_POSTGRES_TYPE_TO_SNOW_TYPE = { + Psycopg2TypeCode.BOOLOID: BooleanType, + Psycopg2TypeCode.BYTEAOID: BinaryType, + Psycopg2TypeCode.CHAROID: StringType, + Psycopg2TypeCode.INT8OID: IntegerType, + Psycopg2TypeCode.INT2OID: IntegerType, + Psycopg2TypeCode.INT4OID: IntegerType, + Psycopg2TypeCode.TEXTOID: StringType, + Psycopg2TypeCode.POINTOID: StringType, + Psycopg2TypeCode.LSEGOID: StringType, + Psycopg2TypeCode.PATHOID: StringType, + Psycopg2TypeCode.BOXOID: StringType, + Psycopg2TypeCode.POLYGONOID: StringType, + Psycopg2TypeCode.LINEOID: StringType, + Psycopg2TypeCode.FLOAT4OID: FloatType, + Psycopg2TypeCode.FLOAT8OID: DoubleType, + Psycopg2TypeCode.CIRCLEOID: StringType, + Psycopg2TypeCode.CASHOID: VariantType, + Psycopg2TypeCode.MACADDROID: StringType, + Psycopg2TypeCode.CIDROID: StringType, + Psycopg2TypeCode.INETOID: StringType, + Psycopg2TypeCode.BPCHAROID: StringType, + Psycopg2TypeCode.VARCHAROID: StringType, + Psycopg2TypeCode.DATEOID: DateType, + Psycopg2TypeCode.TIMEOID: TimeType, + Psycopg2TypeCode.TIMESTAMPOID: TimestampType, + Psycopg2TypeCode.TIMESTAMPTZOID: TimestampType, + Psycopg2TypeCode.INTERVALOID: StringType, + Psycopg2TypeCode.TIMETZOID: TimeType, + Psycopg2TypeCode.BITOID: StringType, + Psycopg2TypeCode.VARBITOID: StringType, + Psycopg2TypeCode.NUMERICOID: DecimalType, + Psycopg2TypeCode.JSON: VariantType, + Psycopg2TypeCode.JSONB: VariantType, + Psycopg2TypeCode.MACADDR8: StringType, + Psycopg2TypeCode.UUID: StringType, + Psycopg2TypeCode.XML: StringType, + Psycopg2TypeCode.TSVECTOR: 
StringType, + Psycopg2TypeCode.TSQUERY: StringType, + Psycopg2TypeCode.TXID_SNAPSHOT: StringType, + Psycopg2TypeCode.PG_LSN: StringType, + Psycopg2TypeCode.PG_SNAPSHOT: StringType, +} + + +class Psycopg2Driver(BaseDriver): + def __init__( + self, create_connection: Callable[[], "Connection"], dbms_type: Enum + ) -> None: + super().__init__(create_connection, dbms_type) + + def to_snow_type(self, schema: List[Any]) -> StructType: + # The psycopg2 spec is defined in the following links: + # https://www.psycopg.org/docs/cursor.html#cursor.description + # https://www.psycopg.org/docs/extensions.html#psycopg2.extensions.Column + fields = [] + for ( + name, + type_code, + _display_size, + _internal_size, + precision, + scale, + _null_ok, + ) in schema: + try: + type_code = Psycopg2TypeCode(type_code) + except ValueError: + raise NotImplementedError( + f"Postgres type not supported: {type_code} for column: {name}" + ) + snow_type = BASE_POSTGRES_TYPE_TO_SNOW_TYPE.get(type_code) + if snow_type is None: + raise NotImplementedError( + f"Postgres type not supported: {type_code} for column: {name}" + ) + if Psycopg2TypeCode(type_code) == Psycopg2TypeCode.NUMERICOID: + if not self.validate_numeric_precision_scale(precision, scale): + logger.debug( + f"Snowpark does not support column" + f" {name} of type {type_code} with precision {precision} and scale {scale}. " + "The default Numeric precision and scale will be used." 
+ ) + precision, scale = None, None + data_type = snow_type( + precision if precision is not None else 38, + scale if scale is not None else 0, + ) + elif type_code == Psycopg2TypeCode.TIMESTAMPTZOID: + data_type = snow_type(TimestampTimeZone.TZ) + else: + data_type = snow_type() + fields.append(StructField(name, data_type, True)) + return StructType(fields) + + @staticmethod + def to_result_snowpark_df( + session: "Session", table_name, schema, _emit_ast: bool = True + ) -> "DataFrame": + project_columns = [] + for field in schema.fields: + if isinstance(field.datatype, VariantType): + project_columns.append( + to_variant(parse_json(column(field.name))).as_(field.name) + ) + else: + project_columns.append(column(field.name)) + return session.table(table_name, _emit_ast=_emit_ast).select( + project_columns, _emit_ast=_emit_ast + ) + + @staticmethod + def to_result_snowpark_df_udtf( + res_df: "DataFrame", + schema: StructType, + _emit_ast: bool = True, + ): + cols = [] + for field in schema.fields: + if isinstance(field.datatype, VariantType): + cols.append(to_variant(parse_json(column(field.name))).as_(field.name)) + else: + cols.append(res_df[field.name].cast(field.datatype).alias(field.name)) + return res_df.select(cols, _emit_ast=_emit_ast) + + @staticmethod + def prepare_connection( + conn: "Connection", + query_timeout: int = 0, + ) -> "Connection": + if query_timeout: + # https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT + # postgres default uses milliseconds + conn.cursor().execute(f"SET STATEMENT_TIMEOUT = {query_timeout * 1000}") + return conn + + def udtf_class_builder( + self, fetch_size: int = 1000, schema: StructType = None + ) -> type: + create_connection = self.create_connection + + # TODO: SNOW-2101485 use class method to prepare connection + # ideally we should use the same function as prepare_connection + # however, since we introduce new module for new driver support and initially the new module is not 
available in the backend + # so if registering UDTF which uses the class method, cloudpickle will pickle the class method along with + # the new module -- this leads to not being able to find the new module when unpickling on the backend. + # once the new module is available in the backend, we can use the class method. + def prepare_connection_in_udtf( + conn: "Connection", + query_timeout: int = 0, + ) -> "Connection": + if query_timeout: + # https://www.postgresql.org/docs/current/runtime-config-client.html#GUC-STATEMENT-TIMEOUT + # postgres default uses milliseconds + conn.cursor().execute(f"SET STATEMENT_TIMEOUT = {query_timeout * 1000}") + return conn + + class UDTFIngestion: + def process(self, query: str): + conn = prepare_connection_in_udtf(create_connection()) + cursor = conn.cursor() + cursor.execute(query) + while True: + rows = cursor.fetchmany(fetch_size) + if not rows: + break + yield from rows + + return UDTFIngestion diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py index 095a506819..a3e403e5ea 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py +++ b/src/snowflake/snowpark/_internal/data_source/drivers/pymsql_driver.py @@ -182,7 +182,9 @@ def to_snow_type(self, schema: List[Any]) -> StructType: fields.append(StructField(name, data_type, null_ok)) return StructType(fields) - def udtf_class_builder(self, fetch_size: int = 1000) -> type: + def udtf_class_builder( + self, fetch_size: int = 1000, schema: StructType = None + ) -> type: create_connection = self.create_connection class UDTFIngestion: diff --git a/src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py b/src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py index 1d561a903a..ba01ad3072 100644 --- a/src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py +++ 
b/src/snowflake/snowpark/_internal/data_source/drivers/pyodbc_driver.py @@ -79,7 +79,9 @@ def to_snow_type(self, schema: List[Any]) -> StructType: fields.append(StructField(name, data_type, null_ok)) return StructType(fields) - def udtf_class_builder(self, fetch_size: int = 1000) -> type: + def udtf_class_builder( + self, fetch_size: int = 1000, schema: StructType = None + ) -> type: create_connection = self.create_connection def binary_converter(value): @@ -110,8 +112,8 @@ def process(self, query: str): return UDTFIngestion + @staticmethod def prepare_connection( - self, conn: "Connection", query_timeout: int = 0, ) -> "Connection": diff --git a/src/snowflake/snowpark/_internal/data_source/utils.py b/src/snowflake/snowpark/_internal/data_source/utils.py index adf23f1d99..55d2dad25b 100644 --- a/src/snowflake/snowpark/_internal/data_source/utils.py +++ b/src/snowflake/snowpark/_internal/data_source/utils.py @@ -13,6 +13,7 @@ OracledbDialect, SqlServerDialect, DatabricksDialect, + PostgresDialect, MysqlDialect, ) from snowflake.snowpark._internal.data_source.drivers import ( @@ -20,6 +21,7 @@ OracledbDriver, PyodbcDriver, DatabricksDriver, + Psycopg2Driver, PymysqlDriver, ) import snowflake @@ -43,6 +45,7 @@ class DBMS_TYPE(Enum): ORACLE_DB = "ORACLE_DB" SQLITE_DB = "SQLITE3_DB" DATABRICKS_DB = "DATABRICKS_DB" + POSTGRES_DB = "POSTGRES_DB" MYSQL_DB = "MYSQL_DB" UNKNOWN = "UNKNOWN" @@ -52,6 +55,7 @@ class DRIVER_TYPE(str, Enum): ORACLEDB = "oracledb" SQLITE3 = "sqlite3" DATABRICKS = "databricks.sql.client" + PSYCOPG2 = "psycopg2.extensions" PYMYSQL = "pymysql.connections" UNKNOWN = "unknown" @@ -61,6 +65,7 @@ class DRIVER_TYPE(str, Enum): DBMS_TYPE.ORACLE_DB: OracledbDialect, DBMS_TYPE.SQLITE_DB: Sqlite3Dialect, DBMS_TYPE.DATABRICKS_DB: DatabricksDialect, + DBMS_TYPE.POSTGRES_DB: PostgresDialect, DBMS_TYPE.MYSQL_DB: MysqlDialect, } @@ -69,17 +74,19 @@ class DRIVER_TYPE(str, Enum): DRIVER_TYPE.ORACLEDB: OracledbDriver, DRIVER_TYPE.SQLITE3: SqliteDriver, 
DRIVER_TYPE.DATABRICKS: DatabricksDriver, + DRIVER_TYPE.PSYCOPG2: Psycopg2Driver, DRIVER_TYPE.PYMYSQL: PymysqlDriver, } UDTF_PACKAGE_MAP = { - DBMS_TYPE.ORACLE_DB: ["oracledb", "snowflake-snowpark-python"], + DBMS_TYPE.ORACLE_DB: ["oracledb>=2.0.0,<4.0.0", "snowflake-snowpark-python"], DBMS_TYPE.SQLITE_DB: ["snowflake-snowpark-python"], DBMS_TYPE.SQL_SERVER_DB: [ - "pyodbc>=4.0.26", + "pyodbc>=4.0.26,<6.0.0", "msodbcsql", "snowflake-snowpark-python", ], + DBMS_TYPE.POSTGRES_DB: ["psycopg2>=2.0.0,<3.0.0", "snowflake-snowpark-python"], DBMS_TYPE.DATABRICKS_DB: [ "snowflake-snowpark-python", "databricks-sql-connector>=4.0.0,<5.0.0", @@ -129,6 +136,7 @@ def detect_dbms_pyodbc(dbapi2_conn): DRIVER_TYPE.ORACLEDB: lambda conn: DBMS_TYPE.ORACLE_DB, DRIVER_TYPE.SQLITE3: lambda conn: DBMS_TYPE.SQLITE_DB, DRIVER_TYPE.DATABRICKS: lambda conn: DBMS_TYPE.DATABRICKS_DB, + DRIVER_TYPE.PSYCOPG2: lambda conn: DBMS_TYPE.POSTGRES_DB, DRIVER_TYPE.PYMYSQL: lambda conn: DBMS_TYPE.MYSQL_DB, } diff --git a/tests/integ/datasource/test_postgres.py b/tests/integ/datasource/test_postgres.py new file mode 100644 index 0000000000..ae66497030 --- /dev/null +++ b/tests/integ/datasource/test_postgres.py @@ -0,0 +1,433 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. 
+# + +import pytest + +from snowflake.snowpark import Row +from snowflake.snowpark._internal.data_source.drivers import Psycopg2Driver +from snowflake.snowpark._internal.data_source.drivers.psycopg2_driver import ( + Psycopg2TypeCode, +) +from snowflake.snowpark._internal.data_source.utils import DBMS_TYPE +from snowflake.snowpark.exceptions import SnowparkDataframeReaderException +from snowflake.snowpark.types import ( + DecimalType, + BinaryType, + VariantType, + StructType, + StructField, + StringType, + TimeType, + BooleanType, + IntegerType, + FloatType, + DoubleType, + DateType, + TimestampType, + TimestampTimeZone, +) +from snowflake.snowpark._internal.data_source.dbms_dialects.postgresql_dialect import ( + PostgresDialect, +) +from tests.parameters import POSTGRES_CONNECTION_PARAMETERS +from tests.resources.test_data_source_dir.test_postgres_data import ( + POSTGRES_TABLE_NAME, + EXPECTED_TEST_DATA, + EXPECTED_TYPE, + POSTGRES_TEST_EXTERNAL_ACCESS_INTEGRATION, +) +from tests.utils import IS_IN_STORED_PROC + +DEPENDENCIES_PACKAGE_UNAVAILABLE = True + +try: + import psycopg2 # noqa: F401 + import pandas # noqa: F401 + + DEPENDENCIES_PACKAGE_UNAVAILABLE = False +except ImportError: + pass + + +pytestmark = [ + pytest.mark.skipif(DEPENDENCIES_PACKAGE_UNAVAILABLE, reason="Missing 'psycopg2'"), + pytest.mark.skipif(IS_IN_STORED_PROC, reason="Need External Access Integration"), +] + + +def create_postgres_connection(): + return psycopg2.connect(**POSTGRES_CONNECTION_PARAMETERS) + + +@pytest.mark.parametrize( + "input_type, input_value", + [ + ("table", POSTGRES_TABLE_NAME), + ("query", f"(SELECT * FROM {POSTGRES_TABLE_NAME})"), + ], +) +def test_basic_postgres(session, input_type, input_value): + input_dict = { + input_type: input_value, + } + df = session.read.dbapi(create_postgres_connection, **input_dict) + assert df.collect() == EXPECTED_TEST_DATA and df.schema == EXPECTED_TYPE + + +@pytest.mark.parametrize( + "input_type, input_value, error_message", + [ + 
("table", "NONEXISTTABLE", "does not exist"), + ("query", "SELEC ** FORM TABLE", "syntax error at or near"), + ], +) +def test_error_case(session, input_type, input_value, error_message): + input_dict = { + input_type: input_value, + } + with pytest.raises(SnowparkDataframeReaderException, match=error_message): + session.read.dbapi(create_postgres_connection, **input_dict) + + +def test_query_timeout(session): + with pytest.raises( + SnowparkDataframeReaderException, + match=r"due to exception 'QueryCanceled\('canceling statement due to statement timeout", + ): + session.read.dbapi( + create_postgres_connection, + table=POSTGRES_TABLE_NAME, + query_timeout=1, + session_init_statement=["SELECT pg_sleep(5)"], + ) + + +def test_external_access_integration_not_set(session): + with pytest.raises( + ValueError, + match="external_access_integration cannot be None when udtf ingestion is used.", + ): + session.read.dbapi( + create_postgres_connection, table=POSTGRES_TABLE_NAME, udtf_configs={} + ) + + +def test_unicode_column_name_postgres(session): + df = session.read.dbapi( + create_postgres_connection, table='test_schema."用户資料"' + ).order_by("編號") + assert df.collect() == [Row(編號=1, 姓名="山田太郎", 國家="日本", 備註="これはUnicodeテストです")] + assert df.columns == ['"編號"', '"姓名"', '"國家"', '"備註"'] + + +@pytest.mark.parametrize( + "input_type, input_value", + [ + ("table", POSTGRES_TABLE_NAME), + ("query", f"(SELECT * FROM {POSTGRES_TABLE_NAME})"), + ], +) +def test_udtf_ingestion_postgres(session, input_type, input_value, caplog): + from tests.parameters import POSTGRES_CONNECTION_PARAMETERS + + def create_connection_postgres(): + import psycopg2 + + return psycopg2.connect(**POSTGRES_CONNECTION_PARAMETERS) + + input_dict = { + input_type: input_value, + } + df = session.read.dbapi( + create_connection_postgres, + **input_dict, + udtf_configs={ + "external_access_integration": POSTGRES_TEST_EXTERNAL_ACCESS_INTEGRATION + }, + ).order_by("BIGSERIAL_COL") + + assert df.collect() == 
EXPECTED_TEST_DATA + # assert UDTF creation and UDTF call + assert ( + "TEMPORARY FUNCTION data_source_udtf_" "" in caplog.text + and "table(data_source_udtf" in caplog.text + ) + + +def test_psycopg2_driver_udtf_class_builder(): + """Test the UDTF class builder in Psycopg2Driver using a real PostgreSQL connection""" + # Create the driver with the real connection function + driver = Psycopg2Driver(create_postgres_connection, DBMS_TYPE.POSTGRES_DB) + + # Get the UDTF class with a small fetch size to test batching + UDTFClass = driver.udtf_class_builder(fetch_size=2) + + # Instantiate the UDTF class + udtf_instance = UDTFClass() + + # Test with a simple query that should return a few rows + test_query = f"SELECT * FROM {POSTGRES_TABLE_NAME} LIMIT 5" + result_rows = list(udtf_instance.process(test_query)) + + # Verify we got some data back (we know the test table has data from other tests) + assert len(result_rows) > 0 + + # Test with a query that returns specific columns + test_columns_query = ( + f"SELECT TEXT_COL, BIGINT_COL FROM {POSTGRES_TABLE_NAME} LIMIT 3" + ) + column_result_rows = list(udtf_instance.process(test_columns_query)) + + # Verify we got data with the right structure (2 columns) + assert len(column_result_rows) > 0 + assert len(column_result_rows[0]) == 2 # Two columns + + +def test_unit_psycopg2_driver_to_snow_type_mapping(): + """Test the mapping of PostgreSQL types to Snowflake types in Psycopg2Driver.to_snow_type""" + driver = Psycopg2Driver(create_postgres_connection, DBMS_TYPE.POSTGRES_DB) + + # Test basic types + basic_schema = [ + ("bool_col", Psycopg2TypeCode.BOOLOID.value, None, None, None, None, True), + ("int2_col", Psycopg2TypeCode.INT2OID.value, None, None, None, None, True), + ("int4_col", Psycopg2TypeCode.INT4OID.value, None, None, None, None, True), + ("int8_col", Psycopg2TypeCode.INT8OID.value, None, None, None, None, True), + ("text_col", Psycopg2TypeCode.TEXTOID.value, None, None, None, None, True), + ( + "varchar_col", + 
Psycopg2TypeCode.VARCHAROID.value, + None, + None, + None, + None, + True, + ), + ("char_col", Psycopg2TypeCode.CHAROID.value, None, None, None, None, True), + ] + + result = driver.to_snow_type(basic_schema) + + assert len(result.fields) == 7 + assert isinstance(result.fields[0].datatype, BooleanType) + assert isinstance(result.fields[1].datatype, IntegerType) + assert isinstance(result.fields[2].datatype, IntegerType) + assert isinstance(result.fields[3].datatype, IntegerType) + assert isinstance(result.fields[4].datatype, StringType) + assert isinstance(result.fields[5].datatype, StringType) + assert isinstance(result.fields[6].datatype, StringType) + + # Test float types + float_schema = [ + ("float4_col", Psycopg2TypeCode.FLOAT4OID.value, None, None, None, None, True), + ("float8_col", Psycopg2TypeCode.FLOAT8OID.value, None, None, None, None, True), + ] + + result = driver.to_snow_type(float_schema) + + assert len(result.fields) == 2 + assert isinstance(result.fields[0].datatype, FloatType) + assert isinstance(result.fields[1].datatype, DoubleType) + + # Test date and time types + datetime_schema = [ + ("date_col", Psycopg2TypeCode.DATEOID.value, None, None, None, None, True), + ("time_col", Psycopg2TypeCode.TIMEOID.value, None, None, None, None, True), + ("timetz_col", Psycopg2TypeCode.TIMETZOID.value, None, None, None, None, True), + ( + "timestamp_col", + Psycopg2TypeCode.TIMESTAMPOID.value, + None, + None, + None, + None, + True, + ), + ( + "timestamptz_col", + Psycopg2TypeCode.TIMESTAMPTZOID.value, + None, + None, + None, + None, + True, + ), + ( + "interval_col", + Psycopg2TypeCode.INTERVALOID.value, + None, + None, + None, + None, + True, + ), + ] + + result = driver.to_snow_type(datetime_schema) + + assert len(result.fields) == 6 + assert isinstance(result.fields[0].datatype, DateType) + assert isinstance(result.fields[1].datatype, TimeType) + assert isinstance(result.fields[2].datatype, TimeType) + assert isinstance(result.fields[3].datatype, 
TimestampType) + assert isinstance(result.fields[4].datatype, TimestampType) + # Check timezone-aware timestamp + assert result.fields[4].datatype.tz == TimestampTimeZone.TZ + assert isinstance(result.fields[5].datatype, StringType) + + # Test binary and complex types + complex_schema = [ + ("bytea_col", Psycopg2TypeCode.BYTEAOID.value, None, None, None, None, True), + ("json_col", Psycopg2TypeCode.JSON.value, None, None, None, None, True), + ("jsonb_col", Psycopg2TypeCode.JSONB.value, None, None, None, None, True), + ("uuid_col", Psycopg2TypeCode.UUID.value, None, None, None, None, True), + ("cash_col", Psycopg2TypeCode.CASHOID.value, None, None, None, None, True), + ("inet_col", Psycopg2TypeCode.INETOID.value, None, None, None, None, True), + ] + + result = driver.to_snow_type(complex_schema) + + assert len(result.fields) == 6 + assert isinstance(result.fields[0].datatype, BinaryType) + assert isinstance(result.fields[1].datatype, VariantType) + assert isinstance(result.fields[2].datatype, VariantType) + assert isinstance(result.fields[3].datatype, StringType) + assert isinstance(result.fields[4].datatype, VariantType) + assert isinstance(result.fields[5].datatype, StringType) + + # Test numeric with various precision and scale + numeric_schema = [ + ( + "numeric_default", + Psycopg2TypeCode.NUMERICOID.value, + None, + None, + None, + None, + True, + ), + ("numeric_valid", Psycopg2TypeCode.NUMERICOID.value, None, None, 10, 2, True), + ("numeric_max", Psycopg2TypeCode.NUMERICOID.value, None, None, 38, 37, True), + ( + "numeric_invalid", + Psycopg2TypeCode.NUMERICOID.value, + None, + None, + 1000, + 1000, + True, + ), + ] + + result = driver.to_snow_type(numeric_schema) + + assert len(result.fields) == 4 + # Default precision/scale + assert isinstance(result.fields[0].datatype, DecimalType) + assert result.fields[0].datatype.precision == 38 + assert result.fields[0].datatype.scale == 0 + + # Valid precision/scale + assert isinstance(result.fields[1].datatype, 
DecimalType) + assert result.fields[1].datatype.precision == 10 + assert result.fields[1].datatype.scale == 2 + + # Max valid precision/scale + assert isinstance(result.fields[2].datatype, DecimalType) + assert result.fields[2].datatype.precision == 38 + assert result.fields[2].datatype.scale == 37 + + # Invalid precision/scale - should be defaulted + assert isinstance(result.fields[3].datatype, DecimalType) + assert result.fields[3].datatype.precision == 38 + assert result.fields[3].datatype.scale == 0 + + # Test unsupported type code + with pytest.raises(NotImplementedError, match="Postgres type not supported"): + nonexisting_type_code = -1 + Psycopg2Driver(create_postgres_connection, DBMS_TYPE.POSTGRES_DB).to_snow_type( + [("UNSUPPORTED_COL", nonexisting_type_code, None, None, None, None, True)] + ) + + # Test unsupported type code + with pytest.raises(NotImplementedError, match="Postgres type not supported"): + unimplemented_code = Psycopg2TypeCode.ACLITEMOID + Psycopg2Driver(create_postgres_connection, DBMS_TYPE.POSTGRES_DB).to_snow_type( + [("UNSUPPORTED_COL", unimplemented_code, None, None, None, None, True)] + ) + + +def test_unit_generate_select_query(): + # Create a mock schema with different field types + schema = StructType( + [ + StructField("json_col", VariantType()), + StructField("cash_col", VariantType()), + StructField("bytea_col", BinaryType()), + StructField("timetz_col", TimeType()), + StructField("interval_col", StringType()), + StructField("regular_col", StringType()), + ] + ) + + # Create mock raw schema - each tuple represents (name, type_code, display_size, internal_size, precision, scale, null_ok) + raw_schema = [ + ("json_col", Psycopg2TypeCode.JSON.value, None, None, None, None, True), + ("cash_col", Psycopg2TypeCode.CASHOID.value, None, None, None, None, True), + ("bytea_col", Psycopg2TypeCode.BYTEAOID.value, None, None, None, None, True), + ("timetz_col", Psycopg2TypeCode.TIMETZOID.value, None, None, None, None, True), + ( + 
"interval_col", + Psycopg2TypeCode.INTERVALOID.value, + None, + None, + None, + None, + True, + ), + ("regular_col", Psycopg2TypeCode.TEXTOID.value, None, None, None, None, True), + ] + + # Test with table name + table_query = PostgresDialect.generate_select_query( + "test_table", schema, raw_schema, is_query=False + ) + expected_table_query = ( + 'SELECT TO_JSON("json_col")::TEXT AS json_col, ' + 'CASE WHEN "cash_col" IS NULL THEN NULL ELSE FORMAT(\'"%s"\', "cash_col"::TEXT) END AS cash_col, ' + "ENCODE(\"bytea_col\", 'HEX') AS bytea_col, " + '"timetz_col"::TIME AS timetz_col, ' + '"interval_col"::TEXT AS interval_col, ' + '"regular_col" ' + "FROM test_table" + ) + assert table_query == expected_table_query + + # Test with subquery + subquery_query = PostgresDialect.generate_select_query( + "(SELECT * FROM test_table)", schema, raw_schema, is_query=True + ) + expected_subquery_query = ( + 'SELECT TO_JSON("json_col")::TEXT AS json_col, ' + 'CASE WHEN "cash_col" IS NULL THEN NULL ELSE FORMAT(\'"%s"\', "cash_col"::TEXT) END AS cash_col, ' + "ENCODE(\"bytea_col\", 'HEX') AS bytea_col, " + '"timetz_col"::TIME AS timetz_col, ' + '"interval_col"::TEXT AS interval_col, ' + '"regular_col" ' + "FROM (SELECT * FROM test_table)" + ) + assert subquery_query == expected_subquery_query + + # Test with JSONB type + jsonb_raw_schema = [ + ("jsonb_col", Psycopg2TypeCode.JSONB.value, None, None, None, None, True) + ] + jsonb_schema = StructType([StructField("jsonb_col", VariantType())]) + jsonb_query = PostgresDialect.generate_select_query( + "test_table", jsonb_schema, jsonb_raw_schema, is_query=False + ) + expected_jsonb_query = ( + 'SELECT TO_JSON("jsonb_col")::TEXT AS jsonb_col FROM test_table' + ) + assert jsonb_query == expected_jsonb_query diff --git a/tests/resources/test_data_source_dir/test_postgres_data.py b/tests/resources/test_data_source_dir/test_postgres_data.py new file mode 100644 index 0000000000..5f4db2dd54 --- /dev/null +++ 
b/tests/resources/test_data_source_dir/test_postgres_data.py @@ -0,0 +1,348 @@ +# +# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. +# + +import datetime +from decimal import Decimal + +from snowflake.snowpark.types import ( + StructType, + StructField, + StringType, + DecimalType, + DoubleType, + LongType, + BinaryType, + BooleanType, + VariantType, + DateType, + TimestampType, + TimeType, + TimestampTimeZone, +) + +POSTGRES_TABLE_NAME = "test_schema.ALL_TYPE_TABLE" +EXPECTED_TEST_DATA = [ + ( + -6645531000000000123, + 1, + "0000", + "10100101", + False, + "(13.68,50.87),(-14.5,-36.82)", + b"b35feb1d048b61ac6c3", + "6U0hXrbbRm", + "almMRlCPh3onp9celUXb", + "63.9.184.0/24", + "<(-29.47,-12.75),49.48>", + datetime.date(2004, 5, 9), + 1858055000.0, + "5.188.71.132", + 1865101000, + "4 years 10 mons 15 days 18:38:52", + '{\n "key": 123\n}', + '{\n "jsonb_key": 83\n}', + "{-2.18,8.69,3.09}", + "[(5.12,-83.91),(41.89,62.49)]", + "e8:16:cd:a9:9f:e6", + "e3:da:aa:fc:fb:51:86:f5", + '"$5,452.35"', + Decimal("113414.83"), + "((43.79,36.77),(-64.49,-34.68))", + "85D2538C/FFC30C2E", + "96:207:", + "(61.05,18.47)", + "((48.52,53.43),(89.46,69.09),(89.54,10.13))", + 64374.96, + -10428, + 1, + 1, + "OaVsansivU5I1BLQdUbRaYyzbYDmK6e", + datetime.time(19, 28, 51), + datetime.time(8, 26, 45), + datetime.datetime(2002, 2, 16, 2, 6), + datetime.datetime(2020, 1, 24, 20, 0, tzinfo=datetime.timezone.utc), + "'word3' & 'word1'", + "'lex1':1 'lex2':2 'lex4':3", + "10:20:10,14,15", + "69ad9235-6c5e-4f95-b179-9730ed771aa8", + "34", + ), + ( + -8065880000000000456, + 2, + "1010", + "11100010", + False, + "(29.78,3.39),(26.92,-5.57)", + b"\xc3\xac4977ddf59e03da6c1e", + "yhrAXej1DO", + "CnwnTp8SLJKTSeQAi8oW", + "205.202.89.0/24", + "<(13.94,18.73),4.52>", + datetime.date(2016, 7, 9), + 9561914000.0, + "136.74.101.171", + -2129554000, + "4 years 4 mons 18:37:09", + '{\n "key": 123\n}', + '{\n "jsonb_key": 79\n}', + "{-9.8,-0.07,0.36}", + "[(21.72,55.15),(-25.05,94.64)]", 
+ "00:28:95:21:77:65", + "58:3a:c1:f7:a1:f4:d3:31", + '"$4,342.72"', + Decimal("-636436.38"), + "((-0.45,-24.74),(1.41,-82.18))", + "D4997CEE/F6A7D891", + "214:931:", + "(-3.91,82.16)", + "((21.36,22.34),(-4.94,90.88),(30.78,18.14),(25.06,60.1),(-98.04,74.43))", + 65115.97, + 32062, + 2, + 2, + "M6Z9jw QXfvErDkIj3xYeAg0IVCTrTnWx5hS4kSu", + datetime.time(14, 41, 19), + datetime.time(5, 47, 38), + datetime.datetime(2024, 12, 25, 7, 53, 11), + datetime.datetime(2000, 8, 20, 21, 26, 46, tzinfo=datetime.timezone.utc), + "'word2' & 'word3'", + "'lex1':1 'lex2':2 'lex3':3", + "701:924:720,800", + "721f1281-9f97-4b0a-b41a-bd789ad318ea", + "29", + ), + ( + 9083626000000000789, + 3, + "1100", + "11000110", + False, + "(9.1,86.79),(-46.22,-5.63)", + b"Rd745b97541083b56b9", + "wTTtppE3ND", + "mKb7K1PTNCoAdRGQHb2C", + "145.149.241.0/24", + "<(-91.43,-25.5),2.19>", + datetime.date(2019, 6, 9), + -3365294000.0, + "11.106.208.59", + -1889577000, + "1 year 10 mons 23 days 21:40:12", + '{\n "key": null\n}', + '{\n "jsonb_key": 39\n}', + "{6.62,9.28,1.6}", + "[(-13.37,-20.05),(-96.04,-48.7)]", + "4a:a0:71:cc:05:06", + "dc:c8:be:f4:0c:62:69:9c", + '"$7,161.12"', + Decimal("261715.46"), + "((-66.03,-52.76),(-53.48,82.21),(-99.03,7.58),(-63.51,97.87))", + "3F5DA3B6/515C58E5", + "86:763:", + "(-74.55,7.85)", + "((-6.65,69.89),(30.98,-66.15),(-96.04,-9.29))", + -81603.21, + -32235, + 3, + 3, + "SWPw6w6zINbqWooVf wC90ixy0Huv9tAIbaNkTrKolY48E", + datetime.time(7, 20, 38), + datetime.time(20, 4), + datetime.datetime(2017, 1, 28, 6, 29, 17), + datetime.datetime(2009, 10, 1, 8, 16, 17, tzinfo=datetime.timezone.utc), + "'word3' & 'word2'", + "'lex1':2 'lex3':1 'lex4':3", + "691:960:712,892", + "d0164090-8196-42f0-88f6-979a4039214d", + "50", + ), + ( + 8517274000000000123, + 4, + "1010", + "10100111", + False, + "(-26.08,-37.14),(-51.59,-94.61)", + b"n0457e0312dd983dc89", + "8jCXO6fI1J", + "2UENUrNFXEeMArKotkwP", + "222.181.247.0/24", + "<(67.48,8.88),20.47>", + datetime.date(2012, 8, 5), + 
-3428823000.0, + "173.30.26.53", + -495502045, + "5 mons 25 days 21:24:37", + '{\n "key": null\n}', + '{\n "jsonb_key": 54\n}', + "{7.34,8.92,-4.74}", + "[(-33.09,-91.78),(-46.63,-75.51)]", + "aa:26:f0:fb:6e:3a", + "e0:1a:46:50:7d:b0:9f:0b", + '"$745.17"', + Decimal("-786274.55"), + "((-68.79,-68.06),(-86.31,79.21),(-84.42,-97),(-80.85,-31.05),(84.56,12.15))", + "457D6A04/B65BEF56", + "343:454:", + "(-27.87,-96.35)", + "((-13.3,49.21),(88.92,57.41),(-11.51,71.87),(61.62,72.2),(57.03,-84.98),(73.6,-11.46))", + 62962.21, + 32364, + 4, + 4, + "0wMQO5LsnVsn80kMaI 9bWKZnJJcZh VWV2cAG14Q", + datetime.time(9, 52, 2), + datetime.time(12, 40, 50), + datetime.datetime(2015, 5, 26, 19, 39, 51), + datetime.datetime(2015, 11, 14, 16, 36, 6, tzinfo=datetime.timezone.utc), + "'word3' & 'word4'", + "'lex2':2 'lex3':1 'lex4':3", + "667:828:699,799", + "45befd8f-992d-4df6-93f0-b879792d16fe", + "95", + ), + ( + -5569903000000000456, + 5, + "1100", + "11110110", + False, + "(49.79,76.36),(-92.6,65.41)", + b"\xc2\xb9b83bbf671ff5e2441c", + "wOR7ebsM5d", + "eludr6sfjh0rI1LLq73q", + "170.29.227.0/24", + "<(-10.95,-27.76),47.88>", + datetime.date(2018, 4, 29), + -6298310000.0, + "75.135.255.85", + 310008582, + "5 years 8 mons 5 days 00:18:31", + '{\n "key": "value1"\n}', + '{\n "jsonb_key": 37\n}', + "{-1.37,5.98,-6.45}", + "[(54.17,21.95),(47.55,82.92)]", + "41:a3:38:5a:23:13", + "4c:9e:ea:48:2b:8d:5b:39", + '"$1,604.15"', + Decimal("525117.11"), + "((-55.68,-40.24),(7.22,-65.82),(88.94,-5.42))", + "BCE30B1/AE3C727", + "147:739:", + "(-86.34,-55.89)", + "((38.51,61.05),(84.17,17.95),(73.11,-73.06),(-66.17,-78.89),(21.63,-71.6),(-40.84,-70.27))", + 25392.21, + -30587, + 5, + 5, + "KeS36QdtwKnI9vFZ1GakjEBnNLlRpTrLEyzB", + datetime.time(19, 35, 6), + datetime.time(11, 21, 14), + datetime.datetime(2005, 6, 25, 22, 2, 58), + datetime.datetime(2017, 5, 24, 10, 3, 4, tzinfo=datetime.timezone.utc), + "'word2' & 'word3'", + "'lex2':2 'lex3':1 'lex4':3", + "853:868:854,855", + 
"960b86a9-a8dd-4634-bc1f-956ae6589726", + "47", + ), + ( + None, + 6, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + 6, + 6, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ), +] +EXPECTED_TYPE = StructType( + [ + StructField("BIGINT_COL", LongType(), nullable=True), + StructField("BIGSERIAL_COL", LongType(), nullable=True), + StructField("BIT_COL", StringType(16777216), nullable=True), + StructField("BIT_VARYING_COL", StringType(16777216), nullable=True), + StructField("BOOLEAN_COL", BooleanType(), nullable=True), + StructField("BOX_COL", StringType(16777216), nullable=True), + StructField("BYTEA_COL", BinaryType(), nullable=True), + StructField("CHAR_COL", StringType(16777216), nullable=True), + StructField("VARCHAR_COL", StringType(16777216), nullable=True), + StructField("CIDR_COL", StringType(16777216), nullable=True), + StructField("CIRCLE_COL", StringType(16777216), nullable=True), + StructField("DATE_COL", DateType(), nullable=True), + StructField("DOUBLE_PRECISION_COL", DoubleType(), nullable=True), + StructField("INET_COL", StringType(16777216), nullable=True), + StructField("INTEGER_COL", LongType(), nullable=True), + StructField("INTERVAL_COL", StringType(16777216), nullable=True), + StructField("JSON_COL", VariantType(), nullable=True), + StructField("JSONB_COL", VariantType(), nullable=True), + StructField("LINE_COL", StringType(16777216), nullable=True), + StructField("LSEG_COL", StringType(16777216), nullable=True), + StructField("MACADDR_COL", StringType(16777216), nullable=True), + StructField("MACADDR8_COL", StringType(16777216), nullable=True), + StructField("MONEY_COL", VariantType(), nullable=True), + StructField("NUMERIC_COL", DecimalType(10, 2), nullable=True), + StructField("PATH_COL", StringType(16777216), nullable=True), 
+ StructField("PG_LSN_COL", StringType(16777216), nullable=True), + StructField("PG_SNAPSHOT_COL", StringType(16777216), nullable=True), + StructField("POINT_COL", StringType(16777216), nullable=True), + StructField("POLYGON_COL", StringType(16777216), nullable=True), + StructField("REAL_COL", DoubleType(), nullable=True), + StructField("SMALLINT_COL", LongType(), nullable=True), + StructField("SMALLSERIAL_COL", LongType(), nullable=True), + StructField("SERIAL_COL", LongType(), nullable=True), + StructField("TEXT_COL", StringType(16777216), nullable=True), + StructField("TIME_COL", TimeType(), nullable=True), + StructField("TIME_TZ_COL", TimeType(), nullable=True), + StructField( + "TIMESTAMP_COL", TimestampType(TimestampTimeZone.NTZ), nullable=True + ), + StructField( + "TIMESTAMPTZ_COL", TimestampType(TimestampTimeZone.TZ), nullable=True + ), + StructField("TSQUERY_COL", StringType(16777216), nullable=True), + StructField("TSVECTOR_COL", StringType(16777216), nullable=True), + StructField("TXID_SNAPSHOT_COL", StringType(16777216), nullable=True), + StructField("UUID_COL", StringType(16777216), nullable=True), + StructField("XML_COL", StringType(16777216), nullable=True), + ] +) +POSTGRES_TEST_EXTERNAL_ACCESS_INTEGRATION = "snowpark_dbapi_postgres_test_integration" diff --git a/tests/unit/scala/test_utils_suite.py b/tests/unit/scala/test_utils_suite.py index 3911ce9d3e..be54de6bc9 100644 --- a/tests/unit/scala/test_utils_suite.py +++ b/tests/unit/scala/test_utils_suite.py @@ -318,6 +318,7 @@ def check_zip_files_and_close_stream(input_stream, expected_files): "resources/test_sas.xpt", "resources/test_data_source_dir/", "resources/test_data_source_dir/test_data_source_data.py", + "resources/test_data_source_dir/test_postgres_data.py", "resources/test_data_source_dir/test_databricks_data.py", "resources/test_data_source_dir/test_mysql_data.py", "resources/test_sp_dir/", diff --git a/tox.ini b/tox.ini index 3ed7786245..31975aff89 100644 --- a/tox.ini +++ b/tox.ini @@ 
-225,8 +225,10 @@ deps = {[testenv]deps} databricks-sql-connector oracledb + psycopg2-binary pymysql -commands = {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE}" {posargs:} tests/integ/datasource +commands = + {env:SNOWFLAKE_PYTEST_CMD} -m "{env:SNOWFLAKE_TEST_TYPE}" {posargs:} tests/integ/datasource [pytest] log_cli = True