Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@

#### New Features

- Added support for ingestion from Databricks with a Snowflake UDTF in `DataFrameReader.dbapi` (PrPr).
- Added support for MySQL in `DataFrameWriter.dbapi` (PrPr).

#### Bug Fixes

- Fixed a bug in `DataFrameReader.dbapi` (PrPr) where the `create_connection` defined as local function was incompatible with multiprocessing.
- Fixed a bug in `DataFrameReader.dbapi` (PrPr) where the Databricks `TIMESTAMP` type was converted to the Snowflake `TIMESTAMP_NTZ` type instead of the correct `TIMESTAMP_LTZ` type.

#### Improvements

- Added support for reading XML files with namespaces using `rowTag` and `stripNamespaces` options.
Expand Down Expand Up @@ -45,7 +51,6 @@

- Fixed a bug in `DataFrameWriter.dbapi` (PrPr) where Unicode or double-quoted column names in the external database caused errors because they were not quoted correctly.
- Fixed a bug where named fields in nested OBJECT data could cause errors when containing spaces.
- Fixed a bug in `DataFrameReader.dbapi` (PrPr) where the `create_connection` defined as local function was incompatible with multiprocessing.

### Snowpark Local Testing Updates

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,28 @@
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
from typing import List

from snowflake.snowpark._internal.data_source.dbms_dialects import BaseDialect
from snowflake.snowpark.types import StructType, MapType, BinaryType


class DatabricksDialect(BaseDialect):
    """SQL generation tailored to Databricks for ``DataFrameReader.dbapi``."""

    def generate_select_query(
        self,
        table_or_query: str,
        schema: StructType,
        raw_schema: List[tuple],
        is_query: bool,
    ) -> str:
        """Build the SELECT statement pushed down to Databricks.

        MAP columns are serialized to JSON and BINARY columns are hex-encoded
        on the Databricks side, because the connector's native representations
        (list of tuples, raw bytes) cannot be ingested by Snowflake directly.

        Args:
            table_or_query: Table name or subquery to select from.
            schema: Snowpark schema inferred for the source.
            raw_schema: Raw driver description tuples; element 0 of each tuple
                is the column name.
            is_query: Whether ``table_or_query`` is a query (unused here, part
                of the dialect interface).

        Returns:
            The SELECT SQL text to execute on Databricks.
        """
        cols = []
        for field, raw_field in zip(schema.fields, raw_schema):
            col_name = raw_field[0]
            # databricks-sql-connector returns list of tuples for MapType;
            # here we push the to-dict conversion down to Databricks.
            # NOTE: the alias is backtick-quoted as well — an unquoted alias
            # breaks for column names that require quoting (unicode, spaces).
            if isinstance(field.datatype, MapType):
                cols.append(f"TO_JSON(`{col_name}`) AS `{col_name}`")
            elif isinstance(field.datatype, BinaryType):
                cols.append(f"HEX(`{col_name}`) AS `{col_name}`")
            else:
                cols.append(f"`{col_name}`")
        return f"""SELECT {" , ".join(cols)} FROM {table_or_query}"""
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def data_source_data_to_pandas_df(

@staticmethod
def to_result_snowpark_df(
session: "Session", table_name, schema, _emit_ast: bool = True
session: "Session", table_name: str, schema: StructType, _emit_ast: bool = True
) -> "DataFrame":
return session.table(table_name, _emit_ast=_emit_ast)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@
#
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#
import json
import logging
from typing import List, Any, TYPE_CHECKING

from snowflake.snowpark._internal.data_source.drivers import BaseDriver
from snowflake.snowpark._internal.type_utils import type_string_to_type_object
from snowflake.snowpark._internal.utils import PythonObjJSONEncoder
from snowflake.snowpark._internal.data_source.datasource_typing import (
Cursor,
)
from snowflake.snowpark._internal.data_source.drivers import BaseDriver
from snowflake.snowpark._internal.type_utils import type_string_to_type_object
from snowflake.snowpark.functions import column, to_variant, parse_json
from snowflake.snowpark.types import (
StructType,
MapType,
StructField,
ArrayType,
VariantType,
TimestampType,
TimestampTimeZone,
)
from snowflake.snowpark.functions import column, to_variant
from snowflake.connector.options import pandas as pd

if TYPE_CHECKING:
from snowflake.snowpark.session import Session # pragma: no cover
Expand All @@ -38,6 +37,7 @@ def infer_schema_from_description(
query = f"DESCRIBE QUERY SELECT * FROM ({table_or_query})"
logger.debug(f"trying to get schema using query: {query}")
raw_schema = cursor.execute(query).fetchall()
self.raw_schema = raw_schema
return self.to_snow_type(raw_schema)

def to_snow_type(self, schema: List[Any]) -> StructType:
Expand All @@ -55,29 +55,55 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
for column_name, column_type, _ in schema:
column_type = convert_map_to_use.get(column_type, column_type)
data_type = type_string_to_type_object(column_type)
if column_type.lower() == "timestamp":
# by default https://docs.databricks.com/aws/en/sql/language-manual/data-types/timestamp-type
data_type = TimestampType(TimestampTimeZone.LTZ)
all_columns.append(StructField(column_name, data_type, True))
return StructType(all_columns)

@staticmethod
def data_source_data_to_pandas_df(
    data: List[Any], schema: StructType
) -> "pd.DataFrame":
    """Convert fetched Databricks rows into a pandas DataFrame, JSON-encoding
    MAP-typed columns so they survive the parquet upload path.

    Args:
        data: Rows fetched from the Databricks cursor.
        schema: Snowpark schema describing the columns of ``data``.

    Returns:
        A pandas DataFrame with MAP columns rendered as JSON strings.
    """
    df = BaseDriver.data_source_data_to_pandas_df(data, schema)
    # 1. Regular snowflake table (compared to Iceberg Table) does not support structured data
    #    type (array, map, struct), thus we store structured data as variant in regular table
    # 2. map type needs special handling because:
    #    i. databricks sql returned it as a list of tuples, which needs to be converted to a dict
    #    ii. pandas parquet conversion does not support dict having int as key, we convert it to json string
    map_type_indexes = [
        i
        for i, field in enumerate(schema.fields)
        if isinstance(field.datatype, MapType)
    ]
    col_names = df.columns[map_type_indexes]
    df[col_names] = BaseDriver.df_map_method(df[col_names])(
        lambda x: json.dumps(dict(x), cls=PythonObjJSONEncoder)
    )
    return df
def udtf_class_builder(self, fetch_size: int = 1000) -> type:
    """Build the UDTF class used for server-side ingestion from Databricks.

    The returned class opens a fresh connection per partition, discovers which
    result columns are arrays, and streams rows in batches of ``fetch_size``.

    Args:
        fetch_size: Number of rows pulled per ``fetchmany`` call.

    Returns:
        The UDTF ingestion class (not an instance).
    """
    make_connection = self.create_connection

    class UDTFIngestion:
        def process(self, query: str):
            cursor = make_connection().cursor()

            # Discover which columns are arrays from the schema description.
            # databricks-sql-connector does not provide a built-in output
            # handler, nor does Databricks offer a simple built-in function,
            # to convert nd.array values to plain lists as Snowflake requires.
            describe_query = f"DESCRIBE QUERY SELECT * FROM ({query})"
            cursor.execute(describe_query)
            array_cols = [
                pos
                for pos, (_, col_type, _) in enumerate(cursor.fetchall())
                if col_type.startswith("array<")
            ]

            # Stream the actual query results batch by batch.
            cursor.execute(query)
            while True:
                batch = cursor.fetchmany(fetch_size)
                if not batch:
                    break
                for row in batch:
                    values = list(row)
                    # Array columns: convert ndarray to a plain list.
                    for pos in array_cols:
                        if pos < len(values) and values[pos] is not None:
                            values[pos] = values[pos].tolist()
                    yield tuple(values)

    return UDTFIngestion

@staticmethod
def to_result_snowpark_df(
Expand All @@ -90,7 +116,25 @@ def to_result_snowpark_df(
):
project_columns.append(to_variant(column(field.name)).as_(field.name))
else:
project_columns.append(column(field.name))
project_columns.append(
column(field.name).cast(field.datatype).alias(field.name)
)
return session.table(table_name, _emit_ast=_emit_ast).select(
project_columns, _emit_ast=_emit_ast
)

@staticmethod
def to_result_snowpark_df_udtf(
    res_df: "DataFrame",
    schema: StructType,
    _emit_ast: bool = True,
):
    """Project UDTF result columns back to their declared Snowflake types.

    Semi-structured columns come out of the UDTF as JSON strings, so they are
    parsed and converted to VARIANT; all other columns are cast to the type
    recorded in ``schema``.
    """
    semi_structured_types = (MapType, ArrayType, StructType, VariantType)
    projection = []
    for field in schema.fields:
        if isinstance(field.datatype, semi_structured_types):
            projection.append(
                to_variant(parse_json(column(field.name))).as_(field.name)
            )
        else:
            projection.append(
                res_df[field.name].cast(field.datatype).alias(field.name)
            )
    return res_df.select(projection, _emit_ast=_emit_ast)
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/_internal/data_source/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class DRIVER_TYPE(str, Enum):
"msodbcsql",
"snowflake-snowpark-python",
],
DBMS_TYPE.DATABRICKS_DB: [
"snowflake-snowpark-python",
"databricks-sql-connector>=4.0.0,<5.0.0",
],
DBMS_TYPE.MYSQL_DB: ["pymysql>=1.0.0,<2.0.0", "snowflake-snowpark-python"],
}

Expand Down
1 change: 1 addition & 0 deletions src/snowflake/snowpark/dataframe_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,6 +1355,7 @@ def create_oracledb_connection():
fetch_size=fetch_size,
imports=udtf_configs.get("imports", None),
packages=udtf_configs.get("packages", None),
_emit_ast=_emit_ast,
)
set_api_call_source(df, DATA_SOURCE_DBAPI_SIGNATURE)
return df
Expand Down
Loading
Loading