Skip to content

Commit 44f0db7

Browse files
authored
SNOW-2060749: add databricks udtf support (#3379)
1 parent 8370e4e commit 44f0db7

File tree

9 files changed

+358
-178
lines changed

9 files changed

+358
-178
lines changed

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@
66

77
#### New Features
88

9+
- Added support for ingestion with Snowflake UDTF to Databricks in `DataFrameReader.dbapi` (PrPr).
910
- Added support for MySQL in `DataFrameWriter.dbapi` (PrPr).
1011

12+
#### Bug Fixes
13+
14+
- Fixed a bug in `DataFrameReader.dbapi` (PrPr) where the `create_connection` defined as local function was incompatible with multiprocessing.
15+
- Fixed a bug in `DataFrameReader.dbapi` (PrPr) where the Databricks `TIMESTAMP` type was converted to the Snowflake `TIMESTAMP_NTZ` type instead of the expected `TIMESTAMP_LTZ` type.
16+
1117
#### Improvements
1218

1319
- Added support for reading XML files with namespaces using `rowTag` and `stripNamespaces` options.
@@ -45,7 +51,6 @@
4551

4652
- Fixed a bug in `DataFrameWriter.dbapi` (PrPr) where Unicode or double-quoted column names in the external database caused errors because they were not quoted correctly.
4753
- Fixed a bug where named fields in nested OBJECT data could cause errors when containing spaces.
48-
- Fixed a bug in `DataFrameReader.dbapi` (PrPr) where the `create_connection` defined as local function was incompatible with multiprocessing.
4954

5055
### Snowpark Local Testing Updates
5156

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,28 @@
11
#
22
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
33
#
4+
from typing import List
5+
46
from snowflake.snowpark._internal.data_source.dbms_dialects import BaseDialect
7+
from snowflake.snowpark.types import StructType, MapType, BinaryType
58

69

710
class DatabricksDialect(BaseDialect):
    def generate_select_query(
        self,
        table_or_query: str,
        schema: StructType,
        raw_schema: List[tuple],
        is_query: bool,
    ) -> str:
        """Build the Databricks SELECT used to pull data for ingestion.

        Args:
            table_or_query: Source table name or a subquery to read from.
            schema: Inferred Snowpark schema for the source.
            raw_schema: Raw DESCRIBE output; ``raw_field[0]`` is the column name.
            is_query: Whether ``table_or_query`` is a query (unused here).

        Returns:
            A SELECT statement with per-column pushdown conversions applied.
        """
        cols = []
        for field, raw_field in zip(schema.fields, raw_schema):
            # NOTE: aliases are backtick-quoted like the source identifiers;
            # an unquoted alias breaks on names with spaces/special characters.
            if isinstance(field.datatype, MapType):
                # databricks-sql-connector returns a list of tuples for MapType;
                # push the to-JSON conversion down to Databricks instead.
                cols.append(f"""TO_JSON(`{raw_field[0]}`) AS `{raw_field[0]}`""")
            elif isinstance(field.datatype, BinaryType):
                # Hex-encode binary so it survives the text-based transfer.
                cols.append(f"""HEX(`{raw_field[0]}`) AS `{raw_field[0]}`""")
            else:
                cols.append(f"`{raw_field[0]}`")
        return f"""SELECT {" , ".join(cols)} FROM {table_or_query}"""

src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def data_source_data_to_pandas_df(
190190

191191
@staticmethod
def to_result_snowpark_df(
    session: "Session", table_name: str, schema: StructType, _emit_ast: bool = True
) -> "DataFrame":
    """Return a Snowpark DataFrame over the staging table ``table_name``.

    The base driver needs no per-column projection, so this simply opens
    the ingested table; ``schema`` is accepted for subclass overrides.
    """
    result_df = session.table(table_name, _emit_ast=_emit_ast)
    return result_df
196196

Lines changed: 71 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,24 @@
11
#
22
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
33
#
4-
import json
54
import logging
65
from typing import List, Any, TYPE_CHECKING
76

8-
from snowflake.snowpark._internal.data_source.drivers import BaseDriver
9-
from snowflake.snowpark._internal.type_utils import type_string_to_type_object
10-
from snowflake.snowpark._internal.utils import PythonObjJSONEncoder
117
from snowflake.snowpark._internal.data_source.datasource_typing import (
128
Cursor,
139
)
10+
from snowflake.snowpark._internal.data_source.drivers import BaseDriver
11+
from snowflake.snowpark._internal.type_utils import type_string_to_type_object
12+
from snowflake.snowpark.functions import column, to_variant, parse_json
1413
from snowflake.snowpark.types import (
1514
StructType,
1615
MapType,
1716
StructField,
1817
ArrayType,
1918
VariantType,
19+
TimestampType,
20+
TimestampTimeZone,
2021
)
21-
from snowflake.snowpark.functions import column, to_variant
22-
from snowflake.connector.options import pandas as pd
2322

2423
if TYPE_CHECKING:
2524
from snowflake.snowpark.session import Session # pragma: no cover
@@ -38,6 +37,7 @@ def infer_schema_from_description(
3837
query = f"DESCRIBE QUERY SELECT * FROM ({table_or_query})"
3938
logger.debug(f"trying to get schema using query: {query}")
4039
raw_schema = cursor.execute(query).fetchall()
40+
self.raw_schema = raw_schema
4141
return self.to_snow_type(raw_schema)
4242

4343
def to_snow_type(self, schema: List[Any]) -> StructType:
@@ -55,29 +55,55 @@ def to_snow_type(self, schema: List[Any]) -> StructType:
5555
for column_name, column_type, _ in schema:
5656
column_type = convert_map_to_use.get(column_type, column_type)
5757
data_type = type_string_to_type_object(column_type)
58+
if column_type.lower() == "timestamp":
59+
# by default https://docs.databricks.com/aws/en/sql/language-manual/data-types/timestamp-type
60+
data_type = TimestampType(TimestampTimeZone.LTZ)
5861
all_columns.append(StructField(column_name, data_type, True))
5962
return StructType(all_columns)
6063

61-
@staticmethod
62-
def data_source_data_to_pandas_df(
63-
data: List[Any], schema: StructType
64-
) -> "pd.DataFrame":
65-
df = BaseDriver.data_source_data_to_pandas_df(data, schema)
66-
# 1. Regular snowflake table (compared to Iceberg Table) does not support structured data
67-
# type (array, map, struct), thus we store structured data as variant in regular table
68-
# 2. map type needs special handling because:
69-
# i. databricks sql returned it as a list of tuples, which needs to be converted to a dict
70-
# ii. pandas parquet conversion does not support dict having int as key, we convert it to json string
71-
map_type_indexes = [
72-
i
73-
for i, field in enumerate(schema.fields)
74-
if isinstance(field.datatype, MapType)
75-
]
76-
col_names = df.columns[map_type_indexes]
77-
df[col_names] = BaseDriver.df_map_method(df[col_names])(
78-
lambda x: json.dumps(dict(x), cls=PythonObjJSONEncoder)
79-
)
80-
return df
64+
def udtf_class_builder(self, fetch_size: int = 1000) -> type:
    """Build a UDTF handler class that streams query results from Databricks.

    The returned class opens a fresh connection per invocation via the
    driver's ``create_connection``, runs ``DESCRIBE QUERY`` to locate
    ``array<...>`` columns (databricks-sql-connector returns those as
    ndarrays, which must become plain lists for Snowflake ingestion),
    then fetches rows in batches of ``fetch_size`` and yields them.

    Args:
        fetch_size: Number of rows pulled per ``fetchmany`` call.

    Returns:
        The UDTF handler class (not an instance).
    """
    create_connection = self.create_connection

    class UDTFIngestion:
        def process(self, query: str):
            conn = create_connection()
            cursor = conn.cursor()
            try:
                # First get schema information to find array-typed columns.
                # databricks-sql-connector provides no built-in output handler
                # and Databricks has no simple built-in to do the ndarray->list
                # transformation required by the Snowflake table.
                cursor.execute(f"DESCRIBE QUERY SELECT * FROM ({query})")
                schema_info = cursor.fetchall()
                array_column_indices = [
                    idx
                    for idx, (_, column_type, _) in enumerate(schema_info)
                    if column_type.startswith("array<")
                ]

                # Execute the actual query and stream rows batch by batch.
                cursor.execute(query)
                while True:
                    rows = cursor.fetchmany(fetch_size)
                    if not rows:
                        break
                    for row in rows:
                        processed_row = list(row)
                        # Convert ndarray-valued array columns to plain lists.
                        for idx in array_column_indices:
                            if (
                                idx < len(processed_row)
                                and processed_row[idx] is not None
                            ):
                                processed_row[idx] = processed_row[idx].tolist()
                        yield tuple(processed_row)
            finally:
                # The original implementation leaked the connection; release
                # resources once the generator is exhausted or closed.
                cursor.close()
                conn.close()

    return UDTFIngestion
81107

82108
@staticmethod
83109
def to_result_snowpark_df(
@@ -90,7 +116,25 @@ def to_result_snowpark_df(
90116
):
91117
project_columns.append(to_variant(column(field.name)).as_(field.name))
92118
else:
93-
project_columns.append(column(field.name))
119+
project_columns.append(
120+
column(field.name).cast(field.datatype).alias(field.name)
121+
)
94122
return session.table(table_name, _emit_ast=_emit_ast).select(
95123
project_columns, _emit_ast=_emit_ast
96124
)
125+
126+
@staticmethod
def to_result_snowpark_df_udtf(
    res_df: "DataFrame",
    schema: StructType,
    _emit_ast: bool = True,
):
    """Project UDTF output columns back to their declared Snowflake types.

    Semi-structured columns (map/array/struct/variant) come out of the UDTF
    as JSON strings and are parsed into VARIANT; every other column is cast
    to the type declared in ``schema``.
    """
    semi_structured_types = (MapType, ArrayType, StructType, VariantType)
    projected = []
    for field in schema.fields:
        if isinstance(field.datatype, semi_structured_types):
            projected.append(
                to_variant(parse_json(column(field.name))).as_(field.name)
            )
        else:
            projected.append(
                res_df[field.name].cast(field.datatype).alias(field.name)
            )
    return res_df.select(projected, _emit_ast=_emit_ast)

src/snowflake/snowpark/_internal/data_source/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ class DRIVER_TYPE(str, Enum):
8080
"msodbcsql",
8181
"snowflake-snowpark-python",
8282
],
83+
DBMS_TYPE.DATABRICKS_DB: [
84+
"snowflake-snowpark-python",
85+
"databricks-sql-connector>=4.0.0,<5.0.0",
86+
],
8387
DBMS_TYPE.MYSQL_DB: ["pymysql>=1.0.0,<2.0.0", "snowflake-snowpark-python"],
8488
}
8589

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,6 +1355,7 @@ def create_oracledb_connection():
13551355
fetch_size=fetch_size,
13561356
imports=udtf_configs.get("imports", None),
13571357
packages=udtf_configs.get("packages", None),
1358+
_emit_ast=_emit_ast,
13581359
)
13591360
set_api_call_source(df, DATA_SOURCE_DBAPI_SIGNATURE)
13601361
return df

0 commit comments

Comments
 (0)