Skip to content

Commit 888cec5

Browse files
authored
SNOW-1829870: Allow structured types to be enabled by default (#2727)
1 parent a79fe9f commit 888cec5

File tree

14 files changed

+286
-120
lines changed

14 files changed

+286
-120
lines changed

src/snowflake/snowpark/_internal/analyzer/datatype_mapper.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,16 @@ def to_sql(
202202
return f"'{binascii.hexlify(bytes(value)).decode()}' :: BINARY"
203203

204204
if isinstance(value, (list, tuple, array)) and isinstance(datatype, ArrayType):
205-
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: ARRAY"
205+
type_str = "ARRAY"
206+
if datatype.structured:
207+
type_str = convert_sp_to_sf_type(datatype)
208+
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: {type_str}"
206209

207210
if isinstance(value, dict) and isinstance(datatype, MapType):
208-
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: OBJECT"
211+
type_str = "OBJECT"
212+
if datatype.structured:
213+
type_str = convert_sp_to_sf_type(datatype)
214+
return f"PARSE_JSON({str_to_sql(json.dumps(value, cls=PythonObjJSONEncoder))}) :: {type_str}"
209215

210216
if isinstance(datatype, VariantType):
211217
# PARSE_JSON returns VARIANT, so no need to append :: VARIANT here explicitly.
@@ -260,11 +266,14 @@ def schema_expression(data_type: DataType, is_nullable: bool) -> str:
260266
return "to_timestamp('2020-09-16 06:30:00')"
261267
if isinstance(data_type, ArrayType):
262268
if data_type.structured:
269+
assert isinstance(data_type.element_type, DataType)
263270
element = schema_expression(data_type.element_type, is_nullable)
264271
return f"to_array({element}) :: {convert_sp_to_sf_type(data_type)}"
265272
return "to_array(0)"
266273
if isinstance(data_type, MapType):
267274
if data_type.structured:
275+
assert isinstance(data_type.key_type, DataType)
276+
assert isinstance(data_type.value_type, DataType)
268277
key = schema_expression(data_type.key_type, is_nullable)
269278
value = schema_expression(data_type.value_type, is_nullable)
270279
return f"object_construct_keep_null({key}, {value}) :: {convert_sp_to_sf_type(data_type)}"

src/snowflake/snowpark/_internal/type_utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def convert_metadata_to_sp_type(
159159
[
160160
StructField(
161161
field.name
162-
if context._should_use_structured_type_semantics
162+
if context._should_use_structured_type_semantics()
163163
else quote_name(field.name, keep_case=True),
164164
convert_metadata_to_sp_type(field, max_string_size),
165165
nullable=field.is_nullable,
@@ -187,12 +187,15 @@ def convert_sf_to_sp_type(
187187
max_string_size: int,
188188
) -> DataType:
189189
"""Convert the Snowflake logical type to the Snowpark type."""
190+
semi_structured_fill = (
191+
None if context._should_use_structured_type_semantics() else StringType()
192+
)
190193
if column_type_name == "ARRAY":
191-
return ArrayType(StringType())
194+
return ArrayType(semi_structured_fill)
192195
if column_type_name == "VARIANT":
193196
return VariantType()
194197
if column_type_name in {"OBJECT", "MAP"}:
195-
return MapType(StringType(), StringType())
198+
return MapType(semi_structured_fill, semi_structured_fill)
196199
if column_type_name == "GEOGRAPHY":
197200
return GeographyType()
198201
if column_type_name == "GEOMETRY":
@@ -534,7 +537,10 @@ def merge_type(a: DataType, b: DataType, name: Optional[str] = None) -> DataType
534537
return a
535538

536539

537-
def python_value_str_to_object(value, tp: DataType) -> Any:
540+
def python_value_str_to_object(value, tp: Optional[DataType]) -> Any:
541+
if tp is None:
542+
return None
543+
538544
if isinstance(tp, StringType):
539545
return value
540546

@@ -643,7 +649,7 @@ def python_type_to_snow_type(
643649
element_type = (
644650
python_type_to_snow_type(tp_args[0], is_return_type_of_sproc)[0]
645651
if tp_args
646-
else StringType()
652+
else None
647653
)
648654
return ArrayType(element_type), False
649655

@@ -653,12 +659,12 @@ def python_type_to_snow_type(
653659
key_type = (
654660
python_type_to_snow_type(tp_args[0], is_return_type_of_sproc)[0]
655661
if tp_args
656-
else StringType()
662+
else None
657663
)
658664
value_type = (
659665
python_type_to_snow_type(tp_args[1], is_return_type_of_sproc)[0]
660666
if tp_args
661-
else StringType()
667+
else None
662668
)
663669
return MapType(key_type, value_type), False
664670

src/snowflake/snowpark/context.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Callable, Optional
88

99
import snowflake.snowpark
10+
import threading
1011

1112
_use_scoped_temp_objects = True
1213

@@ -21,8 +22,16 @@
2122
_should_continue_registration: Optional[Callable[..., bool]] = None
2223

2324

24-
# Global flag that determines if structured type semantics should be used
25-
_should_use_structured_type_semantics = False
25+
# Internal-only global flag that determines if structured type semantics should be used
26+
_use_structured_type_semantics = False
27+
_use_structured_type_semantics_lock = threading.RLock()
28+
29+
30+
def _should_use_structured_type_semantics():
31+
global _use_structured_type_semantics
32+
global _use_structured_type_semantics_lock
33+
with _use_structured_type_semantics_lock:
34+
return _use_structured_type_semantics
2635

2736

2837
def get_active_session() -> "snowflake.snowpark.Session":

src/snowflake/snowpark/types.py

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@
1111
from enum import Enum
1212
from typing import Generic, List, Optional, Type, TypeVar, Union, Dict, Any
1313

14+
import snowflake.snowpark.context as context
1415
import snowflake.snowpark._internal.analyzer.expression as expression
1516
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
1617

1718
# Use correct version from here:
1819
from snowflake.snowpark._internal.utils import installed_pandas, pandas, quote_name
19-
import snowflake.snowpark.context as context
2020

2121
# TODO: connector installed_pandas is broken. If pyarrow is not installed, but pandas is this function returns the wrong answer.
2222
# The core issue is that in the connector detection of both pandas/arrow are mixed, which is wrong.
@@ -334,16 +334,22 @@ class ArrayType(DataType):
334334
def __init__(
335335
self,
336336
element_type: Optional[DataType] = None,
337-
structured: bool = False,
337+
structured: Optional[bool] = None,
338338
) -> None:
339-
self.structured = structured
340-
self.element_type = element_type if element_type else StringType()
339+
if context._should_use_structured_type_semantics():
340+
self.structured = (
341+
structured if structured is not None else element_type is not None
342+
)
343+
self.element_type = element_type
344+
else:
345+
self.structured = structured or False
346+
self.element_type = element_type if element_type else StringType()
341347

342348
def __repr__(self) -> str:
343349
return f"ArrayType({repr(self.element_type) if self.element_type else ''})"
344350

345351
def _as_nested(self) -> "ArrayType":
346-
if not context._should_use_structured_type_semantics:
352+
if not context._should_use_structured_type_semantics():
347353
return self
348354
element_type = self.element_type
349355
if isinstance(element_type, (ArrayType, MapType, StructType)):
@@ -378,6 +384,10 @@ def json_value(self) -> Dict[str, Any]:
378384

379385
def _fill_ast(self, ast: proto.SpDataType) -> None:
380386
ast.sp_array_type.structured = self.structured
387+
if self.element_type is None:
388+
raise NotImplementedError(
389+
"SNOW-1862700: AST does not support empty element_type."
390+
)
381391
self.element_type._fill_ast(ast.sp_array_type.ty)
382392

383393

@@ -388,20 +398,36 @@ def __init__(
388398
self,
389399
key_type: Optional[DataType] = None,
390400
value_type: Optional[DataType] = None,
391-
structured: bool = False,
401+
structured: Optional[bool] = None,
392402
) -> None:
393-
self.structured = structured
394-
self.key_type = key_type if key_type else StringType()
395-
self.value_type = value_type if value_type else StringType()
403+
if context._should_use_structured_type_semantics():
404+
if (key_type is None and value_type is not None) or (
405+
key_type is not None and value_type is None
406+
):
407+
raise ValueError(
408+
"Must either set both key_type and value_type or leave both unset."
409+
)
410+
self.structured = (
411+
structured if structured is not None else key_type is not None
412+
)
413+
self.key_type = key_type
414+
self.value_type = value_type
415+
else:
416+
self.structured = structured or False
417+
self.key_type = key_type if key_type else StringType()
418+
self.value_type = value_type if value_type else StringType()
396419

397420
def __repr__(self) -> str:
398-
return f"MapType({repr(self.key_type) if self.key_type else ''}, {repr(self.value_type) if self.value_type else ''})"
421+
type_str = ""
422+
if self.key_type and self.value_type:
423+
type_str = f"{repr(self.key_type)}, {repr(self.value_type)}"
424+
return f"MapType({type_str})"
399425

400426
def is_primitive(self):
401427
return False
402428

403429
def _as_nested(self) -> "MapType":
404-
if not context._should_use_structured_type_semantics:
430+
if not context._should_use_structured_type_semantics():
405431
return self
406432
value_type = self.value_type
407433
if isinstance(value_type, (ArrayType, MapType, StructType)):
@@ -447,6 +473,10 @@ def valueType(self):
447473

448474
def _fill_ast(self, ast: proto.SpDataType) -> None:
449475
ast.sp_map_type.structured = self.structured
476+
if self.key_type is None or self.value_type is None:
477+
raise NotImplementedError(
478+
"SNOW-1862700: AST does not support empty key or value type."
479+
)
450480
self.key_type._fill_ast(ast.sp_map_type.key_ty)
451481
self.value_type._fill_ast(ast.sp_map_type.value_ty)
452482

@@ -578,7 +608,7 @@ def __init__(
578608

579609
@property
580610
def name(self) -> str:
581-
if self._is_column or not context._should_use_structured_type_semantics:
611+
if self._is_column or not context._should_use_structured_type_semantics():
582612
return self.column_identifier.name
583613
else:
584614
return self._name
@@ -593,7 +623,7 @@ def name(self, n: Union[ColumnIdentifier, str]) -> None:
593623
self.column_identifier = ColumnIdentifier(n)
594624

595625
def _as_nested(self) -> "StructField":
596-
if not context._should_use_structured_type_semantics:
626+
if not context._should_use_structured_type_semantics():
597627
return self
598628
datatype = self.datatype
599629
if isinstance(datatype, (ArrayType, MapType, StructType)):
@@ -651,9 +681,17 @@ class StructType(DataType):
651681
"""Represents a table schema or structured column. Contains :class:`StructField` for each field."""
652682

653683
def __init__(
654-
self, fields: Optional[List["StructField"]] = None, structured=False
684+
self,
685+
fields: Optional[List["StructField"]] = None,
686+
structured: Optional[bool] = None,
655687
) -> None:
656-
self.structured = structured
688+
if context._should_use_structured_type_semantics():
689+
self.structured = (
690+
structured if structured is not None else fields is not None
691+
)
692+
else:
693+
self.structured = structured or False
694+
657695
self.fields = []
658696
for field in fields or []:
659697
self.add(field)
@@ -683,7 +721,7 @@ def add(
683721
return self
684722

685723
def _as_nested(self) -> "StructType":
686-
if not context._should_use_structured_type_semantics:
724+
if not context._should_use_structured_type_semantics():
687725
return self
688726
return StructType(
689727
[field._as_nested() for field in self.fields], self.structured

src/snowflake/snowpark/udaf.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,10 @@
3737
TempObjectType,
3838
parse_positional_args_to_list,
3939
publicapi,
40+
warning,
4041
)
4142
from snowflake.snowpark.column import Column
42-
from snowflake.snowpark.types import DataType
43+
from snowflake.snowpark.types import DataType, MapType
4344

4445
# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
4546
# Python 3.9 can use both
@@ -710,6 +711,14 @@ def _do_register_udaf(
710711
name,
711712
)
712713

714+
if isinstance(return_type, MapType):
715+
if return_type.structured:
716+
warning(
717+
"_do_register_udaf",
718+
"Snowflake does not support structured maps as return type for UDAFs. Downcasting to semi-structured object.",
719+
)
720+
return_type = MapType()
721+
713722
# Capture original parameters.
714723
if _emit_ast:
715724
stmt = self._session._ast_batch.assign()

src/snowflake/snowpark/udtf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,10 @@ def _do_register_udtf(
969969
output_schema=output_schema,
970970
)
971971

972+
# Structured Struct is interpreted as Object by function registration
973+
# Force unstructured to ensure Table return type.
974+
output_schema.structured = False
975+
972976
# Capture original parameters.
973977
if _emit_ast:
974978
stmt = self._session._ast_batch.assign()

0 commit comments

Comments
 (0)