Skip to content
Merged
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/_internal/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def convert_sf_to_sp_type(
return ArrayType(semi_structured_fill)
if column_type_name == "VARIANT":
return VariantType()
if context._should_use_structured_type_semantics() and column_type_name == "OBJECT":
return StructType()
if column_type_name in {"OBJECT", "MAP"}:
return MapType(semi_structured_fill, semi_structured_fill)
if column_type_name == "GEOGRAPHY":
Expand Down Expand Up @@ -690,6 +692,10 @@ def python_type_to_snow_type(
if tp_args
else None
)
if (
key_type is None or value_type is None
) and context._should_use_structured_type_semantics():
return StructType(), False
return MapType(key_type, value_type), False

if installed_pandas:
Expand Down
16 changes: 15 additions & 1 deletion src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import typing
import zipfile
from copy import deepcopy
from enum import Enum
from logging import getLogger
from types import ModuleType
from typing import (
Expand Down Expand Up @@ -112,6 +113,13 @@ class UDFColumn(NamedTuple):
name: str


class RegistrationType(Enum):
UDF = "UDF"
UDAF = "UDAF"
UDTF = "UDTF"
SPROC = "SPROC"


class ExtensionFunctionProperties:
"""
This is a data class to hold all information, resolved or otherwise, about a UDF/UDTF/UDAF/Sproc object
Expand Down Expand Up @@ -1266,6 +1274,7 @@ def create_python_udf_or_sp(
replace: bool,
if_not_exists: bool,
raw_imports: Optional[List[Union[str, Tuple[str, str]]]],
registration_type: RegistrationType,
inline_python_code: Optional[str] = None,
execute_as: Optional[typing.Literal["caller", "owner", "restricted caller"]] = None,
api_call_source: Optional[str] = None,
Expand All @@ -1288,7 +1297,12 @@ def create_python_udf_or_sp(

if replace and if_not_exists:
raise ValueError("options replace and if_not_exists are incompatible")
if isinstance(return_type, StructType) and not return_type.structured:

if (
isinstance(return_type, StructType)
and not return_type.structured
and registration_type in {RegistrationType.UDTF, RegistrationType.SPROC}
):
return_sql = f'RETURNS TABLE ({",".join(f"{field.name} {convert_sp_to_sf_type(field.datatype)}" for field in return_type.fields)})'
elif installed_pandas and isinstance(return_type, PandasDataFrameType):
return_sql = f'RETURNS TABLE ({",".join(f"{name} {convert_sp_to_sf_type(datatype)}" for name, datatype in zip(return_type.col_names, return_type.col_types))})'
Expand Down
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
from snowflake.snowpark._internal.udf_utils import (
UDFColumn,
RegistrationType,
add_snowpark_package_to_sproc_packages,
check_execute_as_arg,
check_python_runtime_version,
Expand Down Expand Up @@ -1003,6 +1004,7 @@ def _do_register_sp(
all_imports=all_imports,
all_packages=all_packages,
raw_imports=imports,
registration_type=RegistrationType.SPROC,
is_permanent=is_permanent,
replace=replace,
if_not_exists=if_not_exists,
Expand Down
12 changes: 3 additions & 9 deletions src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,15 +403,9 @@ def __init__(
value_contains_null: bool = True,
) -> None:
if context._should_use_structured_type_semantics():
if (key_type is None and value_type is not None) or (
key_type is not None and value_type is None
):
raise ValueError(
"Must either set both key_type and value_type or leave both unset."
)
self.structured = (
structured if structured is not None else key_type is not None
)
if key_type is None or value_type is None:
raise ValueError("MapType requires key and value type be set.")
self.structured = True
self.key_type = key_type
self.value_type = value_type
else:
Expand Down
6 changes: 4 additions & 2 deletions src/snowflake/snowpark/udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from snowflake.snowpark._internal.type_utils import ColumnOrName, convert_sp_to_sf_type
from snowflake.snowpark._internal.udf_utils import (
UDFColumn,
RegistrationType,
check_python_runtime_version,
check_register_args,
cleanup_failed_permanent_registration,
Expand All @@ -40,7 +41,7 @@
warning,
)
from snowflake.snowpark.column import Column
from snowflake.snowpark.types import DataType, MapType
from snowflake.snowpark.types import DataType, MapType, StructType

# Python 3.8 needs to use typing.Iterable because collections.abc.Iterable is not subscriptable
# Python 3.9 can use both
Expand Down Expand Up @@ -750,7 +751,7 @@ def _do_register_udaf(
"_do_register_udaf",
"Snowflake does not support structured maps as return type for UDAFs. Downcasting to semi-structured object.",
)
return_type = MapType()
return_type = StructType()

# Capture original parameters.
if _emit_ast:
Expand Down Expand Up @@ -831,6 +832,7 @@ def _do_register_udaf(
all_imports=all_imports,
all_packages=all_packages,
raw_imports=imports,
registration_type=RegistrationType.UDAF,
is_permanent=is_permanent,
replace=replace,
if_not_exists=if_not_exists,
Expand Down
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from snowflake.snowpark._internal.type_utils import ColumnOrName, convert_sp_to_sf_type
from snowflake.snowpark._internal.udf_utils import (
UDFColumn,
RegistrationType,
check_python_runtime_version,
check_register_args,
cleanup_failed_permanent_registration,
Expand Down Expand Up @@ -1026,6 +1027,7 @@ def _do_register_udf(
all_imports=all_imports,
all_packages=all_packages,
raw_imports=imports,
registration_type=RegistrationType.UDF,
is_permanent=is_permanent,
replace=replace,
if_not_exists=if_not_exists,
Expand Down
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from snowflake.snowpark._internal.type_utils import ColumnOrName
from snowflake.snowpark._internal.udf_utils import (
UDFColumn,
RegistrationType,
check_python_runtime_version,
check_register_args,
cleanup_failed_permanent_registration,
Expand Down Expand Up @@ -1088,6 +1089,7 @@ def _do_register_udtf(
all_imports=all_imports,
all_packages=all_packages,
raw_imports=imports,
registration_type=RegistrationType.UDTF,
is_permanent=is_permanent,
replace=replace,
if_not_exists=if_not_exists,
Expand Down
19 changes: 9 additions & 10 deletions tests/integ/scala/test_datatype_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,11 +587,10 @@ def test_structured_dtypes_negative(structured_type_session, structured_type_sup
if not structured_type_support:
pytest.skip("Test requires structured type support.")

# Maptype requires both key and value type be set if either is set
with pytest.raises(
ValueError,
match="Must either set both key_type and value_type or leave both unset.",
):
with pytest.raises(ValueError, match="MapType requires key and value type be set."):
MapType()

with pytest.raises(ValueError, match="MapType requires key and value type be set."):
MapType(StringType())


Expand Down Expand Up @@ -633,7 +632,7 @@ def finish(self) -> dict:
"Snowflake does not support structured maps as return type for UDAFs. Downcasting to semi-structured object."
in caplog.text
)
assert MapCollector._return_type == MapType()
assert MapCollector._return_type == StructType()


@pytest.mark.skipif(
Expand Down Expand Up @@ -1054,8 +1053,8 @@ def test_structured_dtypes_cast(structured_type_session, structured_type_support
expected_semi_schema = StructType(
[
StructField("ARR", ArrayType(), nullable=True),
StructField("MAP", MapType(), nullable=True),
StructField("OBJ", MapType(), nullable=True),
StructField("MAP", StructType(), nullable=True),
StructField("OBJ", StructType(), nullable=True),
]
)
expected_structured_schema = StructType(
Expand Down Expand Up @@ -1084,8 +1083,8 @@ def test_structured_dtypes_cast(structured_type_session, structured_type_support
schema=StructType(
[
StructField("arr", ArrayType()),
StructField("map", MapType()),
StructField("obj", MapType()),
StructField("map", StructType()),
StructField("obj", StructType()),
]
),
)
Expand Down
11 changes: 11 additions & 0 deletions tests/integ/scala/test_udf_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
NullType,
ShortType,
StringType,
StructType,
TimestampType,
TimeType,
VariantType,
Expand Down Expand Up @@ -1179,3 +1180,13 @@ def norm_udf(my_val: float, my_max: float, my_min: float) -> float:
Row(1.0),
],
)


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="sql not supported in local testing.",
)
def test_object_return(session):
udf1 = udf(lambda: {"foo": "bar"}, return_type=StructType())
desc_df = session.sql(f"SELECT GET_DDL('function', '{udf1.name}()')")
assert "\nRETURNS OBJECT\n" in desc_df.collect()[0][0]
2 changes: 1 addition & 1 deletion tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1997,7 +1997,7 @@ def test_show_dataframe_spark(session):
),
),
StructField("col_17", ArrayType()),
StructField("col_18", MapType()),
StructField("col_18", StructType()),
]
)
df = session.create_dataframe([data], schema=schema)
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/test_udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from snowflake.snowpark import Session
from snowflake.snowpark._internal.udf_utils import (
RegistrationType,
add_snowpark_package_to_sproc_packages,
cleanup_failed_permanent_registration,
create_python_udf_or_sp,
Expand Down Expand Up @@ -294,6 +295,7 @@ def test_copy_grant_for_udf_or_sp_registration(
all_imports="",
all_packages="",
raw_imports=None,
registration_type=RegistrationType.UDF,
is_permanent=True,
replace=False,
if_not_exists=False,
Expand All @@ -302,7 +304,6 @@ def test_copy_grant_for_udf_or_sp_registration(
if copy_grants:
mock_run_query.assert_called_once()
assert "COPY GRANTS" in mock_run_query.call_args[0][0]
pass


def test_create_python_udf_or_sp_with_none_session():
Expand All @@ -324,6 +325,7 @@ def test_create_python_udf_or_sp_with_none_session():
all_imports="",
all_packages="",
raw_imports=None,
registration_type=RegistrationType.UDF,
is_permanent=True,
replace=False,
if_not_exists=False,
Expand Down
Loading