Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d9bf2cb
SNOW-1829870: Allow structured types to be enabled by default
sfc-gh-jrose Dec 5, 2024
ec43e1a
type checking
sfc-gh-jrose Dec 6, 2024
7f3a5fd
lint
sfc-gh-jrose Dec 6, 2024
2e0dce9
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 16, 2024
ed232de
Move flag to context
sfc-gh-jrose Dec 16, 2024
0dd7b91
typo
sfc-gh-jrose Dec 16, 2024
13c1424
SNOW-1852779 Fix AST encoding for Column `in_`, `asc`, and `desc` (#2…
sfc-gh-vbudati Dec 16, 2024
a787e74
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 16, 2024
b32806f
merge main and fix test
sfc-gh-jrose Dec 17, 2024
c3db223
make feature flag thread safe
sfc-gh-jrose Dec 17, 2024
1c262d7
typo
sfc-gh-jrose Dec 17, 2024
869931f
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 17, 2024
0caef58
Fix ast test
sfc-gh-jrose Dec 17, 2024
2380040
move lock
sfc-gh-jrose Dec 18, 2024
995e519
test coverage
sfc-gh-jrose Dec 18, 2024
1b89027
remove context manager
sfc-gh-jrose Dec 18, 2024
4fc61d4
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 19, 2024
26fd29e
switch to using patch
sfc-gh-jrose Dec 19, 2024
9295e11
move test to other module
sfc-gh-jrose Dec 19, 2024
fcd16d7
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 19, 2024
77a57a6
fix broken import
sfc-gh-jrose Dec 19, 2024
4769169
another broken import
sfc-gh-jrose Dec 19, 2024
af5af87
another test fix
sfc-gh-jrose Dec 19, 2024
dea741b
Merge branch 'main' into jrose_snow_1829870_structured_by_default
sfc-gh-jrose Dec 20, 2024
ee22980
SNOW-1865926: Infer schema for StructType columns from nested Rows
sfc-gh-jrose Dec 20, 2024
087b238
SNOW-1843881: Change StructType columns to return Row objects
sfc-gh-jrose Jan 2, 2025
831e4c7
Review Feedback
sfc-gh-jrose Jan 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,16 @@ def to_sql(
return f"'{binascii.hexlify(bytes(value)).decode()}' :: BINARY"

if isinstance(value, (list, tuple, array)) and isinstance(datatype, ArrayType):
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: ARRAY"
type_str = "ARRAY"
if datatype.structured:
type_str = convert_sp_to_sf_type(datatype)
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: {type_str}"

if isinstance(value, dict) and isinstance(datatype, MapType):
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: OBJECT"
type_str = "OBJECT"
if datatype.structured:
type_str = convert_sp_to_sf_type(datatype)
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: {type_str}"

if isinstance(datatype, VariantType):
# PARSE_JSON returns VARIANT, so no need to append :: VARIANT here explicitly.
Expand Down Expand Up @@ -260,11 +266,14 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
return "to_timestamp('2020-09-16 06:30:00')"
if isinstance(data_type, ArrayType):
if data_type.structured:
assert isinstance(data_type.element_type, DataType)
element = schema_expression(data_type.element_type, is_nullable)
return f"to_array({element}) :: {convert_sp_to_sf_type(data_type)}"
return "to_array(0)"
if isinstance(data_type, MapType):
if data_type.structured:
assert isinstance(data_type.key_type, DataType)
assert isinstance(data_type.value_type, DataType)
key = schema_expression(data_type.key_type, is_nullable)
value = schema_expression(data_type.value_type, is_nullable)
return f"object_construct_keep_null({key}, {value}) :: {convert_sp_to_sf_type(data_type)}"
Expand Down
23 changes: 16 additions & 7 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from snowflake.connector.cursor import ResultMetadata
from snowflake.connector.options import installed_pandas, pandas
from snowflake.snowpark._internal.utils import quote_name
from snowflake.snowpark.row import Row
from snowflake.snowpark.types import (
LTZ,
NTZ,
Expand Down Expand Up @@ -159,7 +160,7 @@ def convert_metadata_to_sp_type(
[
StructField(
field.name
if context._should_use_structured_type_semantics
if context._should_use_structured_type_semantics()
else quote_name(field.name, keep_case=True),
convert_metadata_to_sp_type(field, max_string_size),
nullable=field.is_nullable,
Expand Down Expand Up @@ -187,12 +188,15 @@ def convert_sf_to_sp_type(
max_string_size: int,
) -> DataType:
"""Convert the Snowflake logical type to the Snowpark type."""
semi_structured_fill = (
None if context._should_use_structured_type_semantics() else StringType()
)
if column_type_name == "ARRAY":
return ArrayType(StringType())
return ArrayType(semi_structured_fill)
if column_type_name == "VARIANT":
return VariantType()
if column_type_name in {"OBJECT", "MAP"}:
return MapType(StringType(), StringType())
return MapType(semi_structured_fill, semi_structured_fill)
if column_type_name == "GEOGRAPHY":
return GeographyType()
if column_type_name == "GEOMETRY":
Expand Down Expand Up @@ -438,6 +442,8 @@ def infer_type(obj: Any) -> DataType:
if key is not None and value is not None:
return MapType(infer_type(key), infer_type(value))
return MapType(NullType(), NullType())
elif isinstance(obj, Row) and context._should_use_structured_type_semantics():
return infer_schema(obj)
elif isinstance(obj, (list, tuple)):
for v in obj:
if v is not None:
Expand Down Expand Up @@ -534,7 +540,10 @@ def merge_type(a: DataType, b: DataType, name: Optional[str] = None) -> DataType
return a


def python_value_str_to_object(value, tp: DataType) -> Any:
def python_value_str_to_object(value, tp: Optional[DataType]) -> Any:
if tp is None:
return None

if isinstance(tp, StringType):
return value

Expand Down Expand Up @@ -643,7 +652,7 @@ def python_type_to_snow_type(
element_type = (
python_type_to_snow_type(tp_args[0], is_return_type_of_sproc)[0]
if tp_args
else StringType()
else None
)
return ArrayType(element_type), False

Expand All @@ -653,12 +662,12 @@ def python_type_to_snow_type(
key_type = (
python_type_to_snow_type(tp_args[0], is_return_type_of_sproc)[0]
if tp_args
else StringType()
else None
)
value_type = (
python_type_to_snow_type(tp_args[1], is_return_type_of_sproc)[0]
if tp_args
else StringType()
else None
)
return MapType(key_type, value_type), False

Expand Down
39 changes: 37 additions & 2 deletions src/snowflake/snowpark/_internal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@
)

import snowflake.snowpark
from snowflake.connector.constants import FIELD_ID_TO_NAME
from snowflake.connector.cursor import ResultMetadata, SnowflakeCursor
from snowflake.connector.description import OPERATING_SYSTEM, PLATFORM
from snowflake.connector.options import MissingOptionalDependency, ModuleLikeObject
from snowflake.connector.version import VERSION as connector_version
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark.context import _should_use_structured_type_semantics
from snowflake.snowpark.row import Row
from snowflake.snowpark.version import VERSION as snowpark_version

Expand Down Expand Up @@ -698,19 +700,50 @@ def column_to_bool(col_):
return bool(col_)


def _parse_result_meta(
result_meta: Union[List[ResultMetadata], List["ResultMetadataV2"]]
) -> Tuple[Optional[List[str]], Optional[List[Callable]]]:
"""
Takes a list of result metadata objects and returns a list containing the names of all fields as
well as a list of functions that wrap specific columns.

A column type may need to be wrapped if the connector is unable to provide the columns data in
an expected format. For example StructType columns are returned as dict objects, but are better
represented as Row objects.
"""
if not result_meta:
return None, None
col_names = []
wrappers = []
for col in result_meta:
col_names.append(col.name)
if (
_should_use_structured_type_semantics()
and FIELD_ID_TO_NAME[col.type_code] == "OBJECT"
and col.fields is not None
):
wrappers.append(lambda x: Row(**x))
else:
wrappers.append(None)
return col_names, wrappers


def result_set_to_rows(
result_set: List[Any],
result_meta: Optional[Union[List[ResultMetadata], List["ResultMetadataV2"]]] = None,
case_sensitive: bool = True,
) -> List[Row]:
col_names = [col.name for col in result_meta] if result_meta else None
col_names, wrappers = _parse_result_meta(result_meta or [])
rows = []
row_struct = Row
if col_names:
row_struct = (
Row._builder.build(*col_names).set_case_sensitive(case_sensitive).to_row()
)
for data in result_set:
if wrappers:
data = [wrap(d) if wrap else d for wrap, d in zip(wrappers, data)]

if data is None:
raise ValueError("Result returned from Python connector is None")
row = row_struct(*data)
Expand All @@ -723,7 +756,7 @@ def result_set_to_iter(
result_meta: Optional[List[ResultMetadata]] = None,
case_sensitive: bool = True,
) -> Iterator[Row]:
col_names = [col.name for col in result_meta] if result_meta else None
col_names, wrappers = _parse_result_meta(result_meta)
row_struct = Row
if col_names:
row_struct = (
Expand All @@ -732,6 +765,8 @@ def result_set_to_iter(
for data in result_set:
if data is None:
raise ValueError("Result returned from Python connector is None")
if wrappers:
data = [wrap(d) if wrap else d for wrap, d in zip(wrappers, data)]
row = row_struct(*data)
yield row

Expand Down
13 changes: 11 additions & 2 deletions src/snowflake/snowpark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Callable, Optional

import snowflake.snowpark
import threading

_use_scoped_temp_objects = True

Expand All @@ -21,8 +22,16 @@
_should_continue_registration: Optional[Callable[..., bool]] = None


# Global flag that determines if structured type semantics should be used
_should_use_structured_type_semantics = False
# Internal-only global flag that determines if structured type semantics should be used
_use_structured_type_semantics = False
_use_structured_type_semantics_lock = threading.RLock()


def _should_use_structured_type_semantics():
global _use_structured_type_semantics
global _use_structured_type_semantics_lock
with _use_structured_type_semantics_lock:
return _use_structured_type_semantics


def get_active_session() -> "snowflake.snowpark.Session":
Expand Down
9 changes: 9 additions & 0 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import pkg_resources

import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
import snowflake.snowpark.context as context
from snowflake.connector import ProgrammingError, SnowflakeConnection
from snowflake.connector.options import installed_pandas, pandas
from snowflake.connector.pandas_tools import write_pandas
Expand Down Expand Up @@ -3294,6 +3295,14 @@ def convert_row_to_list(
data_type, (MapType, StructType)
):
converted_row.append(json.dumps(value, cls=PythonObjJSONEncoder))
elif (
isinstance(value, Row)
and isinstance(data_type, StructType)
and context._should_use_structured_type_semantics()
):
converted_row.append(
json.dumps(value.as_dict(), cls=PythonObjJSONEncoder)
)
elif isinstance(data_type, VariantType):
converted_row.append(json.dumps(value, cls=PythonObjJSONEncoder))
elif isinstance(data_type, GeographyType):
Expand Down
Loading
Loading