Skip to content

Commit e984327

Browse files
committed
SNOW-2360274: Fix schema query sql generation for structured types (#3804)
1 parent 2f8342e commit e984327

File tree

5 files changed

+133
-7
lines changed

5 files changed

+133
-7
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
### Snowpark Python API Updates
66

7-
#### New Features
7+
#### Bug Fixes
8+
9+
- Added an experimental fix for a bug in schema query generation that could cause invalid sql to be generated when using nested structured types.
810

911
## 1.39.0 (YYYY-MM-DD)
1012

src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from decimal import Decimal
1313
from typing import Any
1414

15+
import snowflake.snowpark.context as context
1516
import snowflake.snowpark._internal.analyzer.analyzer_utils as analyzer_utils
1617
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
1718
from snowflake.snowpark._internal.utils import (
@@ -518,18 +519,26 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
518519
if isinstance(data_type, ArrayType):
519520
if data_type.structured:
520521
assert data_type.element_type is not None
521-
element = schema_expression(data_type.element_type, data_type.contains_null)
522+
if context._enable_fix_2360274:
523+
element = "NULL"
524+
else:
525+
element = schema_expression(
526+
data_type.element_type, data_type.contains_null
527+
)
522528
return f"to_array({element}) :: {convert_sp_to_sf_type(data_type)}"
523529
return "to_array(0)"
524530
if isinstance(data_type, MapType):
525531
if data_type.structured:
526532
assert data_type.key_type is not None and data_type.value_type is not None
527533
# Key values can never be null
528534
key = schema_expression(data_type.key_type, False)
529-
# Value nullability is variable. Defaults to True
530-
value = schema_expression(
531-
data_type.value_type, data_type.value_contains_null
532-
)
535+
if context._enable_fix_2360274:
536+
value = "NULL"
537+
else:
538+
# Value nullability is variable. Defaults to True
539+
value = schema_expression(
540+
data_type.value_type, data_type.value_contains_null
541+
)
533542
return f"object_construct_keep_null({key}, {value}) :: {convert_sp_to_sf_type(data_type)}"
534543
return "to_object(parse_json('0'))"
535544
if isinstance(data_type, StructType):
@@ -539,7 +548,9 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
539548
# Even if nulls are allowed the cast will fail due to schema mismatch when passed a null field.
540549
schema_strings += [
541550
f"'{field.name}'",
542-
schema_expression(field.datatype, is_nullable=False),
551+
"NULL"
552+
if context._enable_fix_2360274
553+
else schema_expression(field.datatype, is_nullable=False),
543554
]
544555
return f"object_construct_keep_null({', '.join(schema_strings)}) :: {convert_sp_to_sf_type(data_type)}"
545556
return "to_object(parse_json('{}'))"

src/snowflake/snowpark/context.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@
3939
# This is an internal-only global flag, used to determine whether to enable query line tracking for tracing sql compilation errors.
4040
_enable_trace_sql_errors_to_dataframe = False
4141

42+
# SNOW-2362050: Enable this fix by default.
43+
# Global flag for fix 2360274. When enabled, schema queries will use NULL as a placeholder for any values inside structured objects.
44+
_enable_fix_2360274 = False
45+
4246

4347
def configure_development_features(
4448
*,

tests/integ/scala/test_datatype_suite.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import logging
1111
import pytest
12+
from unittest import mock
1213

1314
import snowflake.snowpark.context as context
1415
from snowflake.connector.options import installed_pandas
@@ -1763,3 +1764,110 @@ def test_lob_collect_max_size(session, server_side_max_string, type_string, data
17631764
)
17641765
assert df.schema == StructType([StructField("DATA", datatype, nullable=False)])
17651766
assert len(df.collect()[0][0]) >= server_side_max_string - 16
1767+
1768+
1769+
@pytest.mark.skipif(
1770+
"config.getoption('local_testing_mode', default=False)",
1771+
reason="Structured types are not supported in Local Testing",
1772+
)
1773+
@pytest.mark.parametrize("fix_enabled", [True, False])
1774+
def test_snow_2360274_repro(
1775+
structured_type_session, structured_type_support, fix_enabled
1776+
):
1777+
if not structured_type_support:
1778+
pytest.skip("Test requires structured type support.")
1779+
1780+
agg_table_name = f"snowpark_2360274_repro_agg_{uuid.uuid4().hex[:5]}".upper()
1781+
1782+
nested_field_name = (
1783+
"value" if context._should_use_structured_type_semantics() else '"value"'
1784+
)
1785+
expected_schema = StructType(
1786+
[
1787+
StructField("ID", LongType(), nullable=False),
1788+
StructField(
1789+
"VALS_ARR",
1790+
ArrayType(
1791+
StructType(
1792+
[StructField(nested_field_name, StringType(10), nullable=True)]
1793+
)
1794+
),
1795+
nullable=True,
1796+
),
1797+
StructField(
1798+
"VALS_MAP",
1799+
MapType(StringType(10), StringType(10)),
1800+
nullable=True,
1801+
),
1802+
StructField(
1803+
"VALS_OBJ",
1804+
StructType(
1805+
[StructField(nested_field_name, StringType(10), nullable=True)]
1806+
),
1807+
nullable=True,
1808+
),
1809+
StructField("TAG", StringType(2), nullable=False),
1810+
]
1811+
)
1812+
1813+
def inner():
1814+
structured_type_session.sql(
1815+
f"""
1816+
CREATE
1817+
OR REPLACE TABLE {agg_table_name} (
1818+
ID INT NOT NULL,
1819+
VALS_ARR ARRAY(OBJECT({nested_field_name} STRING(10))) NOT NULL,
1820+
VALS_MAP MAP(STRING(10), STRING(10)) NOT NULL,
1821+
VALS_OBJ OBJECT({nested_field_name} STRING(10)) NOT NULL
1822+
) AS WITH SRC(ID, VALUE) AS (
1823+
SELECT
1824+
$1,
1825+
$2
1826+
FROM
1827+
VALUES
1828+
(1, 'A'),
1829+
(1, 'B'),
1830+
(2, 'A')
1831+
)
1832+
SELECT
1833+
ID,
1834+
CAST(
1835+
ARRAY_AGG(OBJECT_CONSTRUCT('value', VALUE)) AS ARRAY(OBJECT({nested_field_name} STRING))
1836+
) AS VALS_ARR,
1837+
CAST(
1838+
OBJECT_CONSTRUCT('value', VALUE) AS MAP(STRING, STRING)
1839+
) AS VALS_MAP,
1840+
CAST(
1841+
OBJECT_CONSTRUCT('value', VALUE) AS OBJECT({nested_field_name} STRING)
1842+
) AS VALS_OBJ,
1843+
FROM
1844+
SRC
1845+
GROUP BY
1846+
ID, VALS_MAP, VALS_OBJ"""
1847+
).collect()
1848+
1849+
agged = structured_type_session.table(agg_table_name)
1850+
1851+
reference = structured_type_session.sql(
1852+
"""
1853+
SELECT $1 AS ID, $2 AS TAG FROM VALUES (1, 'AB'), (2, 'B')
1854+
"""
1855+
)
1856+
1857+
joined = agged.join(reference, on=agged.id == reference.id, how="inner").select(
1858+
agged.id.alias("ID"), "VALS_ARR", "VALS_MAP", "VALS_OBJ", "TAG"
1859+
)
1860+
Utils.is_schema_same(joined.schema, expected_schema, case_sensitive=False)
1861+
1862+
try:
1863+
with mock.patch.object(context, "_enable_fix_2360274", fix_enabled):
1864+
if fix_enabled:
1865+
inner()
1866+
else:
1867+
with pytest.raises(
1868+
SnowparkSQLException,
1869+
match="Unsupported data type 'STRUCTURED_OBJECT'",
1870+
):
1871+
inner()
1872+
finally:
1873+
Utils.drop_table(structured_type_session, agg_table_name)

tests/integ/test_stored_procedure.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2291,6 +2291,7 @@ def artifact_repo_test(_):
22912291
@pytest.mark.skipif(
22922292
sys.version_info < (3, 9), reason="artifact repository requires Python 3.9+"
22932293
)
2294+
@pytest.mark.skip("SNOW-2362946: Skip until root cause is found.")
22942295
def test_sproc_artifact_repository_from_file(session, tmpdir):
22952296
source = dedent(
22962297
"""

0 commit comments

Comments
 (0)