Commit d5b232d

SNOW-2367027: address behavioral gap between udtf and parquet approach (#3848)
1 parent 2e84f16 commit d5b232d

File tree: 14 files changed, +206 -194 lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -67,6 +67,10 @@
 - Fixed a bug where writing Snowpark pandas dataframes on the pandas backend with a column multiindex to Snowflake with `to_snowflake` would raise `KeyError`.
 - Fixed a bug that `DataFrameReader.dbapi` (PuPr) is not compatible with oracledb 3.4.0.
 
+#### Improvements
+
+- The default maximum length for inferred StringType columns during schema inference in `DataFrameReader.dbapi` is now increased from 16MB to 128MB in parquet file based ingestion.
+
 #### Dependency Updates
 
 - Updated dependency of `snowflake-connector-python>=3.17,<5.0.0`.
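Illustration of the changelog entry above (a minimal sketch, not part of the commit: it assumes an active Snowpark `session`, plus a hypothetical `create_connection` factory and `dbapi_kwargs` for the external source):

```python
from snowflake.snowpark.types import StringType

# With parquet file based ingestion, string columns whose length cannot be
# inferred from the source now come back as StringType(134217728) (128MB)
# instead of StringType(16777216) (16MB), matching UDTF based ingestion.
df = session.read.dbapi(create_connection, **dbapi_kwargs)
for field in df.schema.fields:
    if isinstance(field.datatype, StringType):
        print(field.name, field.datatype.length)  # 134217728 for columns without an explicit length
```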

src/snowflake/snowpark/_internal/data_source/drivers/base_driver.py

Lines changed: 25 additions & 6 deletions
@@ -11,6 +11,7 @@
     Connection,
     Cursor,
 )
+from snowflake.snowpark._internal.server_connection import MAX_STRING_SIZE
 from snowflake.snowpark._internal.utils import (
     get_sorted_key_for_version,
     measure_time,
@@ -27,6 +28,7 @@
     BinaryType,
     DateType,
     BooleanType,
+    StringType,
 )
 import snowflake.snowpark
 import logging
@@ -103,7 +105,16 @@ def infer_schema_from_description(
         query_input_alias: str,
     ) -> StructType:
         self.get_raw_schema(table_or_query, cursor, is_query, query_input_alias)
-        return self.to_snow_type(self.raw_schema)
+        generated_schema = self.to_snow_type(self.raw_schema)
+        # snowflake will default string length to 128MB in the bundle which will be enabled in 2026-01
+        # https://docs.snowflake.com/en/release-notes/bcr-bundles/2025_07_bundle
+        # here we prematurely make the change to default string to
+        # 1. align the string length with UDTF based ingestion
+        # 2. avoid the BCR impact to dbapi feature
+        for field in generated_schema.fields:
+            if isinstance(field.datatype, StringType) and field.datatype.length is None:
+                field.datatype.length = MAX_STRING_SIZE
+        return generated_schema
 
     def infer_schema_from_description_with_error_control(
         self, table_or_query: str, is_query: bool, query_input_alias: str
@@ -184,7 +195,10 @@ def udtf_ingestion(
         select * from {partition_table}, table({udtf_name}({PARTITION_TABLE_COLUMN_NAME}))
         """
         res = session.sql(call_udtf_sql, _emit_ast=_emit_ast)
-        return self.to_result_snowpark_df_udtf(res, schema, _emit_ast=_emit_ast)
+        return BaseDriver.keep_nullable_attributes(
+            self.to_result_snowpark_df_udtf(res, schema, _emit_ast=_emit_ast),
+            schema,
+        )
 
     def udtf_class_builder(
         self,
@@ -284,6 +298,14 @@ def to_result_snowpark_df(
     ) -> "DataFrame":
         return session.table(table_name, _emit_ast=_emit_ast)
 
+    @staticmethod
+    def keep_nullable_attributes(
+        selected_df: "DataFrame", schema: StructType
+    ) -> "DataFrame":
+        for attr, source_field in zip(selected_df._plan.attributes, schema.fields):
+            attr.nullable = source_field.nullable
+        return selected_df
+
     @staticmethod
     def to_result_snowpark_df_udtf(
         res_df: "DataFrame",
@@ -294,10 +316,7 @@ def to_result_snowpark_df_udtf(
             res_df[field.name].cast(field.datatype).alias(field.name)
             for field in schema.fields
         ]
-        selected_df = res_df.select(cols, _emit_ast=_emit_ast)
-        for attr, source_field in zip(selected_df._plan.attributes, schema.fields):
-            attr.nullable = source_field.nullable
-        return selected_df
+        return res_df.select(cols, _emit_ast=_emit_ast)
 
     def get_server_cursor_if_supported(self, conn: "Connection") -> "Cursor":
         """

src/snowflake/snowpark/_internal/server_connection.py

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@
 PARAM_INTERNAL_APPLICATION_NAME = "internal_application_name"
 PARAM_INTERNAL_APPLICATION_VERSION = "internal_application_version"
 DEFAULT_STRING_SIZE = 16777216
+MAX_STRING_SIZE = 134217728
 
 
 def _build_target_path(stage_location: str, dest_prefix: str = "") -> str:
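For reference, both constants are mebibyte counts expressed in bytes, so the "16MB to 128MB" wording in the changelog lines up with the raw values:

```python
assert 16777216 == 16 * 1024**2     # DEFAULT_STRING_SIZE, 16MB
assert 134217728 == 128 * 1024**2   # MAX_STRING_SIZE, 128MB
```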

src/snowflake/snowpark/dataframe_reader.py

Lines changed: 18 additions & 12 deletions
@@ -1707,18 +1707,24 @@ def dbapi(
         Reads data from a database table or query into a DataFrame using a DBAPI connection,
         with support for optional partitioning, parallel processing, and query customization.
 
-        There are multiple methods to partition data and accelerate ingestion.
-        These methods can be combined to achieve optimal performance:
-
-        1.Use column, lower_bound, upper_bound and num_partitions at the same time when you need to split large tables into smaller partitions for parallel processing.
-        These must all be specified together, otherwise error will be raised.
-        2.Set max_workers to a proper positive integer.
-        This defines the maximum number of processes and threads used for parallel execution.
-        3.Adjusting fetch_size can optimize performance by reducing the number of round trips to the database.
-        4.Use predicates to defining WHERE conditions for partitions,
-        predicates will be ignored if column is specified to generate partition.
-        5.Set custom_schema to avoid snowpark infer schema, custom_schema must have a matched
-        column name with table in external data source.
+        Usage Notes:
+            - Ingestion performance tuning:
+                - **Partitioning**: Use ``column``, ``lower_bound``, ``upper_bound``, and ``num_partitions``
+                  together to split large tables into smaller partitions for parallel processing.
+                  All four parameters must be specified together, otherwise an error will be raised.
+                - **Parallel execution**: Set ``max_workers`` to control the maximum number of processes
+                  and threads used for parallel execution.
+                - **Fetch optimization**: Adjust ``fetch_size`` to optimize performance by reducing
+                  the number of round trips to the database.
+                - **Partition filtering**: Use ``predicates`` to define WHERE conditions for partitions.
+                  Note that ``predicates`` will be ignored if ``column`` is specified for partitioning.
+                - **Schema specification**: Set ``custom_schema`` to skip schema inference. The custom schema
+                  must have matching column names with the table in the external data source.
+            - Execution timing and error handling:
+                - **UDTF Ingestion**: Uses lazy evaluation. Errors are reported as ``SnowparkSQLException``
+                  during DataFrame actions (e.g., ``DataFrame.collect()``).
+                - **Local Ingestion**: Uses eager execution. Errors are reported immediately as
+                  ``SnowparkDataFrameReaderException`` when this method is called.
 
         Args:
             create_connection: A callable that returns a DB-API compatible database connection.
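A usage sketch of the tuning knobs described in the updated docstring (everything concrete here is an assumption for illustration: the `oracledb` driver, the credentials, the `table` keyword, and the bounds are not taken from the commit; `session` is assumed to be an active Snowpark session):

```python
import oracledb  # any DB-API compatible driver; oracledb is only an example

def create_connection():
    # Hypothetical credentials; replace with your own connection parameters.
    return oracledb.connect(user="scott", password="tiger", dsn="dbhost/orclpdb")

# Partitioned, parallel ingestion: column/lower_bound/upper_bound/num_partitions
# are specified together, so any `predicates` would be ignored.
df = session.read.dbapi(
    create_connection,
    table="ORDERS",        # assumed name of the source table
    column="ORDER_ID",     # numeric column used to split partitions
    lower_bound=0,
    upper_bound=1_000_000,
    num_partitions=8,
    max_workers=4,         # max processes/threads for parallel execution
    fetch_size=10_000,     # rows fetched per round trip to the source database
)
df.collect()  # with UDTF ingestion, source-side errors surface here as SnowparkSQLException
```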

tests/integ/datasource/test_databricks.py

Lines changed: 0 additions & 3 deletions
@@ -177,9 +177,6 @@ def test_double_quoted_column_databricks(session, custom_schema):
     [("table", TEST_TABLE_NAME), ("query", f"(SELECT * FROM {TEST_TABLE_NAME})")],
 )
 @pytest.mark.udf
-@pytest.mark.skipif(
-    sys.version_info[:2] == (3, 13), reason="driver not supported in python 3.13"
-)
 def test_udtf_ingestion_databricks(session, input_type, input_value, caplog):
     # we define here to avoid test_databricks.py to be pickled and unpickled in UDTF
     def local_create_databricks_connection():

tests/integ/datasource/test_mysql.py

Lines changed: 1 addition & 4 deletions
@@ -3,7 +3,6 @@
 #
 import logging
 import math
-import sys
 from decimal import Decimal
 
 import pytest
@@ -225,9 +224,6 @@ def test_infer_type_from_data(data, number_of_columns, expected_result):
 
 
 @pytest.mark.udf
-@pytest.mark.skipif(
-    sys.version_info[:2] == (3, 13), reason="driver not supported in python 3.13"
-)
 def test_udtf_ingestion_mysql(session, caplog):
     from tests.parameters import MYSQL_CONNECTION_PARAMETERS
 
@@ -251,6 +247,7 @@ def create_connection_mysql():
     ).order_by("ID")
 
     Utils.check_answer(df, mysql_real_data)
+    assert df.schema == mysql_schema
 
     # check that udtf is used
     assert (

tests/integ/datasource/test_oracledb.py

Lines changed: 1 addition & 4 deletions
@@ -4,7 +4,6 @@
 
 import logging
 import math
-import sys
 from collections import namedtuple
 from unittest.mock import patch
 
@@ -154,9 +153,6 @@ def test_oracledb_driver_coverage(caplog):
 
 
 @pytest.mark.udf
-@pytest.mark.skipif(
-    sys.version_info[:2] == (3, 13), reason="driver not supported in python 3.13"
-)
 def test_udtf_ingestion_oracledb(session):
     from tests.parameters import ORACLEDB_CONNECTION_PARAMETERS
 
@@ -183,6 +179,7 @@ def create_connection_oracledb():
     ).order_by("ID")
 
     Utils.check_answer(df, oracledb_real_data)
+    assert df.schema == oracledb_real_schema
 
     # check that udtf is used
     flag = False

tests/integ/datasource/test_postgres.py

Lines changed: 1 addition & 6 deletions
@@ -1,8 +1,6 @@
 #
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
-import sys
-
 import pytest
 
 from snowflake.snowpark import Row
@@ -174,9 +172,6 @@ def test_unicode_column_name_postgres(session, custom_schema):
     ],
 )
 @pytest.mark.udf
-@pytest.mark.skipif(
-    sys.version_info[:2] == (3, 13), reason="driver not supported in python 3.13"
-)
 def test_udtf_ingestion_postgres(session, input_type, input_value, caplog):
     from tests.parameters import POSTGRES_CONNECTION_PARAMETERS
 
@@ -196,7 +191,7 @@ def create_connection_postgres():
         },
     ).order_by("BIGSERIAL_COL")
 
-    assert df.collect() == EXPECTED_TEST_DATA
+    assert df.collect() == EXPECTED_TEST_DATA and df.schema == postgres_schema
     # assert UDTF creation and UDTF call
     assert (
         "TEMPORARY FUNCTION SNOWPARK_TEMP_FUNCTION" "" in caplog.text

tests/integ/datasource/test_sql_server.py

Lines changed: 2 additions & 31 deletions
@@ -5,7 +5,6 @@
 import pytest
 
 from snowflake.snowpark._internal.data_source.utils import DBMS_TYPE
-from snowflake.snowpark.types import StringType
 
 from tests.parameters import SQL_SERVER_CONNECTION_PARAMETERS
 from tests.utils import IS_IN_STORED_PROC, Utils, IS_WINDOWS, IS_MACOS, RUNNING_ON_GH
@@ -63,30 +62,7 @@ def verify_save_table_result(
     df = df.order_by("ID")
 
     Utils.check_answer(df, expected_data)
-
-    def verify_schemas(df, expected_schema, ignore_string_size):
-        # TODO: SNOW-2362041
-        # - UDTF ingestion returning StringType 128 MB (due to variant default to 128MB)
-        # - parquet based ingestion returning StringType 16 MB
-        # we should align the two
-        for field, expected_field in zip(df.schema.fields, expected_schema.fields):
-            if isinstance(field.datatype, StringType):
-                assert isinstance(field.datatype, type(expected_field.datatype))
-                if ignore_string_size:
-                    assert (
-                        field.datatype.length == expected_field.datatype.length
-                        or field.datatype.length == 134217728
-                    )
-                else:
-                    assert field.datatype.length == expected_field.datatype.length
-            else:
-                assert field.datatype == expected_field.datatype
-            assert field.name == expected_field.name
-            assert field.nullable == expected_field.nullable
-
-    verify_schemas(df, expected_schema, ignore_string_size)
-    # after the fix SNOW-2362041, we should be able to enable this assertion
-    # assert df.schema == expected_schema
+    assert df.schema == expected_schema
 
     table_name = Utils.random_table_name()
     # save and read
@@ -97,9 +73,7 @@ def verify_schemas(df, expected_schema, ignore_string_size):
     read_table = read_table.order_by("ID")
 
     Utils.check_answer(read_table, expected_data)
-    verify_schemas(read_table, expected_schema, ignore_string_size)
-    # after the fix SNOW-2362041, we should be able to enable this assertion
-    # assert read_table.schema == expected_schema
+    assert read_table.schema == expected_schema
 
 
 def create_connection_sql_server():
@@ -365,9 +339,6 @@ def connection_func():
     with pytest.raises(
         SnowparkClientException, match="Must declare the scalar variable"
     ):
-        # TODO: 2362041, UDTF error experience is different from parquet ingestion
-        # 1. UDTF needs .collect() to trigger the error while parquet ingestion triggers on .dbapi()
-        # 2. error exception is different
         session.read.dbapi(connection_func, **dbapi_kwargs).collect()
tests/resources/test_data_source_dir/test_data_source_data.py

Lines changed: 19 additions & 18 deletions
@@ -31,6 +31,7 @@
     NullType,
     TimestampTimeZone,
 )
+from snowflake.snowpark._internal.server_connection import MAX_STRING_SIZE
 
 
 # we manually mock these objects because mock object cannot be used in multi-process as they are not pickleable
@@ -95,12 +96,12 @@ def execute(self, sql: str):
         StructField("NUMBER_COL", DecimalType(10, 2), nullable=True),
         StructField("BINARY_FLOAT_COL", DoubleType(), nullable=True),
         StructField("BINARY_DOUBLE_COL", DoubleType(), nullable=True),
-        StructField("VARCHAR2_COL", StringType(16777216), nullable=True),
-        StructField("CHAR_COL", StringType(16777216), nullable=True),
-        StructField("CLOB_COL", StringType(16777216), nullable=True),
-        StructField("NCHAR_COL", StringType(16777216), nullable=True),
-        StructField("NVARCHAR2_COL", StringType(16777216), nullable=True),
-        StructField("NCLOB_COL", StringType(16777216), nullable=True),
+        StructField("VARCHAR2_COL", StringType(MAX_STRING_SIZE), nullable=True),
+        StructField("CHAR_COL", StringType(MAX_STRING_SIZE), nullable=True),
+        StructField("CLOB_COL", StringType(MAX_STRING_SIZE), nullable=True),
+        StructField("NCHAR_COL", StringType(MAX_STRING_SIZE), nullable=True),
+        StructField("NVARCHAR2_COL", StringType(MAX_STRING_SIZE), nullable=True),
+        StructField("NCLOB_COL", StringType(MAX_STRING_SIZE), nullable=True),
         StructField("DATE_COL", DateType(), nullable=True),
         StructField(
             "TIMESTAMP_COL", TimestampType(TimestampTimeZone.NTZ), nullable=True
@@ -131,12 +132,12 @@ def execute(self, sql: str):
         StructField("NUMBER_COL", DecimalType(10, 2), nullable=True),
         StructField("BINARY_FLOAT_COL", DoubleType(), nullable=True),
        StructField("BINARY_DOUBLE_COL", DoubleType(), nullable=True),
-        StructField("VARCHAR2_COL", StringType(16777216), nullable=True),
-        StructField("CHAR_COL", StringType(16777216), nullable=True),
-        StructField("CLOB_COL", StringType(16777216), nullable=True),
-        StructField("NCHAR_COL", StringType(16777216), nullable=True),
-        StructField("NVARCHAR2_COL", StringType(16777216), nullable=True),
-        StructField("NCLOB_COL", StringType(16777216), nullable=True),
+        StructField("VARCHAR2_COL", StringType(MAX_STRING_SIZE), nullable=True),
+        StructField("CHAR_COL", StringType(MAX_STRING_SIZE), nullable=True),
+        StructField("CLOB_COL", StringType(MAX_STRING_SIZE), nullable=True),
+        StructField("NCHAR_COL", StringType(MAX_STRING_SIZE), nullable=True),
+        StructField("NVARCHAR2_COL", StringType(MAX_STRING_SIZE), nullable=True),
+        StructField("NCLOB_COL", StringType(MAX_STRING_SIZE), nullable=True),
         StructField("DATE_COL", DateType(), nullable=True),
         StructField(
             "TIMESTAMP_COL", TimestampType(TimestampTimeZone.NTZ), nullable=True
@@ -156,18 +157,18 @@ def execute(self, sql: str):
 oracledb_unicode_schema = StructType(
     [
         StructField('"編號"', LongType(), nullable=False),
-        StructField('"姓名"', StringType(16777216), nullable=True),
-        StructField('"國家"', StringType(16777216), nullable=True),
-        StructField('"備註"', StringType(16777216), nullable=True),
+        StructField('"姓名"', StringType(MAX_STRING_SIZE), nullable=True),
+        StructField('"國家"', StringType(MAX_STRING_SIZE), nullable=True),
+        StructField('"備註"', StringType(MAX_STRING_SIZE), nullable=True),
     ]
 )
 
 oracledb_double_quoted_schema = StructType(
     [
         StructField("ID", LongType(), nullable=False),
-        StructField("FULLNAME", StringType(16777216), nullable=True),
-        StructField("COUNTRY", StringType(16777216), nullable=True),
-        StructField("NOTES", StringType(16777216), nullable=True),
+        StructField("FULLNAME", StringType(MAX_STRING_SIZE), nullable=True),
+        StructField("COUNTRY", StringType(MAX_STRING_SIZE), nullable=True),
+        StructField("NOTES", StringType(MAX_STRING_SIZE), nullable=True),
     ]
 )
0 commit comments

Comments
 (0)