Skip to content

Commit f83ae56

Browse files
Merge branch 'main' into helmeleegy-SNOW-1819523
2 parents 2b668f2 + 98330fa commit f83ae56

File tree

9 files changed

+262
-31
lines changed

9 files changed

+262
-31
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
- Added documentation for `DataFrame.map`.
6262
- Improved performance of `DataFrame.apply` by mapping numpy functions to snowpark functions if possible.
6363
- Added documentation on the extent of Snowpark pandas interoperability with scikit-learn
64+
- Infer return type of functions in `Series.map`, `Series.apply` and `DataFrame.map` if type-hint is not provided.
6465

6566
## 1.26.0 (2024-12-05)
6667

src/snowflake/snowpark/_internal/utils.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@
4545
)
4646

4747
import snowflake.snowpark
48+
from snowflake.connector.constants import FIELD_ID_TO_NAME
4849
from snowflake.connector.cursor import ResultMetadata, SnowflakeCursor
4950
from snowflake.connector.description import OPERATING_SYSTEM, PLATFORM
5051
from snowflake.connector.options import MissingOptionalDependency, ModuleLikeObject
5152
from snowflake.connector.version import VERSION as connector_version
5253
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
54+
from snowflake.snowpark.context import _should_use_structured_type_semantics
5355
from snowflake.snowpark.row import Row
5456
from snowflake.snowpark.version import VERSION as snowpark_version
5557

@@ -698,19 +700,50 @@ def column_to_bool(col_):
698700
return bool(col_)
699701

700702

703+
def _parse_result_meta(
704+
result_meta: Union[List[ResultMetadata], List["ResultMetadataV2"]]
705+
) -> Tuple[Optional[List[str]], Optional[List[Callable]]]:
706+
"""
707+
Takes a list of result metadata objects and returns a list containing the names of all fields as
708+
well as a list of functions that wrap specific columns.
709+
710+
A column type may need to be wrapped if the connector is unable to provide the column's data in
711+
an expected format. For example, StructType columns are returned as dict objects, but are better
712+
represented as Row objects.
713+
"""
714+
if not result_meta:
715+
return None, None
716+
col_names = []
717+
wrappers = []
718+
for col in result_meta:
719+
col_names.append(col.name)
720+
if (
721+
_should_use_structured_type_semantics()
722+
and FIELD_ID_TO_NAME[col.type_code] == "OBJECT"
723+
and col.fields is not None
724+
):
725+
wrappers.append(lambda x: Row(**x))
726+
else:
727+
wrappers.append(None)
728+
return col_names, wrappers
729+
730+
701731
def result_set_to_rows(
702732
result_set: List[Any],
703733
result_meta: Optional[Union[List[ResultMetadata], List["ResultMetadataV2"]]] = None,
704734
case_sensitive: bool = True,
705735
) -> List[Row]:
706-
col_names = [col.name for col in result_meta] if result_meta else None
736+
col_names, wrappers = _parse_result_meta(result_meta or [])
707737
rows = []
708738
row_struct = Row
709739
if col_names:
710740
row_struct = (
711741
Row._builder.build(*col_names).set_case_sensitive(case_sensitive).to_row()
712742
)
713743
for data in result_set:
744+
if wrappers:
745+
data = [wrap(d) if wrap else d for wrap, d in zip(wrappers, data)]
746+
714747
if data is None:
715748
raise ValueError("Result returned from Python connector is None")
716749
row = row_struct(*data)
@@ -723,7 +756,7 @@ def result_set_to_iter(
723756
result_meta: Optional[List[ResultMetadata]] = None,
724757
case_sensitive: bool = True,
725758
) -> Iterator[Row]:
726-
col_names = [col.name for col in result_meta] if result_meta else None
759+
col_names, wrappers = _parse_result_meta(result_meta)
727760
row_struct = Row
728761
if col_names:
729762
row_struct = (
@@ -732,6 +765,8 @@ def result_set_to_iter(
732765
for data in result_set:
733766
if data is None:
734767
raise ValueError("Result returned from Python connector is None")
768+
if wrappers:
769+
data = [wrap(d) if wrap else d for wrap, d in zip(wrappers, data)]
735770
row = row_struct(*data)
736771
yield row
737772

src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py

Lines changed: 135 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Hashable
99
from enum import Enum, auto
1010
from typing import Any, Callable, Literal, Optional, Union
11+
from datetime import datetime
1112

1213
import cloudpickle
1314
import numpy as np
@@ -21,6 +22,9 @@
2122
from snowflake.snowpark._internal.udf_utils import get_types_from_type_hints
2223
import functools
2324
from snowflake.snowpark.column import Column as SnowparkColumn
25+
from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
26+
TimedeltaType,
27+
)
2428
from snowflake.snowpark.modin.plugin._internal.type_utils import (
2529
infer_object_type,
2630
pandas_lit,
@@ -45,13 +49,19 @@
4549
from snowflake.snowpark.session import Session
4650
from snowflake.snowpark.types import (
4751
ArrayType,
52+
BinaryType,
53+
BooleanType,
4854
DataType,
55+
_IntegralType,
56+
_FractionalType,
4957
IntegerType,
5058
LongType,
5159
MapType,
60+
NullType,
5261
PandasDataFrameType,
5362
PandasSeriesType,
5463
StringType,
64+
TimestampType,
5565
VariantType,
5666
)
5767
from snowflake.snowpark.udf import UserDefinedFunction
@@ -113,7 +123,7 @@ class GroupbyApplySortMethod(Enum):
113123

114124
def check_return_variant_and_get_return_type(func: Callable) -> tuple[bool, DataType]:
115125
"""Check whether the function returns a variant in Snowflake, and get its return type."""
116-
return_type = deduce_return_type_from_function(func)
126+
return_type = deduce_return_type_from_function(func, None)
117127
if return_type is None or isinstance(
118128
return_type, (VariantType, PandasSeriesType, PandasDataFrameType)
119129
):
@@ -756,6 +766,8 @@ def apply_func(x): # type: ignore[no-untyped-def] # pragma: no cover
756766
else:
757767

758768
def apply_func(x): # type: ignore[no-untyped-def] # pragma: no cover
769+
# TODO SNOW-1874779: Add verification here to ensure inferred type matches
770+
# actual type.
759771
return x.apply(func, args=args, **kwargs)
760772

761773
func_udf = sp_func.udf(
@@ -829,14 +841,128 @@ def convert_numpy_int_result_to_int(value: Any) -> Any:
829841
)
830842

831843

844+
DUMMY_BOOL_INPUT = native_pd.Series([False, True])
845+
DUMMY_INT_INPUT = native_pd.Series(
846+
[-37, -9, -2, -1, 0, 2, 3, 5, 7, 9, 13, 16, 20]
847+
+ np.power(10, np.arange(19)).tolist()
848+
+ np.multiply(-1, np.power(10, np.arange(19))).tolist()
849+
)
850+
DUMMY_FLOAT_INPUT = native_pd.Series(
851+
[-9.9, -2.2, -1.0, 0.0, 0.5, 0.33, None, 0.99, 2.0, 3.0, 5.0, 7.7, 9.898989]
852+
+ np.power(10.1, np.arange(19)).tolist()
853+
+ np.multiply(-1.0, np.power(10.1, np.arange(19))).tolist()
854+
)
855+
DUMMY_STRING_INPUT = native_pd.Series(
856+
["", "a", "A", "0", "1", "01", "123", "-1", "-12", "true", "True", "false", "False"]
857+
+ [None, "null", "Jane Smith", "[email protected]", "[email protected]"]
858+
+ ["650-592-4563", "Jane Smith, 123 Main St., Anytown, CA 12345"]
859+
+ ["2020-12-23", "2020-12-23 12:34:56", "08/08/2024", "07-08-2022", "12:34:56"]
860+
+ ["ABC", "bat-man", "super_man", "1@#$%^&*()_+", "<>?:{}|[]\\;'/.,", "<tag>"]
861+
)
862+
DUMMY_BINARY_INPUT = native_pd.Series(
863+
[bytes("snow", "utf-8"), bytes("flake", "utf-8"), bytes("12", "utf-8"), None]
864+
)
865+
DUMMY_TIMESTAMP_INPUT = native_pd.to_datetime(
866+
["2020-12-31 00:00:00", "2020-01-01 00:00:00", native_pd.Timestamp.min] # past
867+
+ ["2090-01-01 00:00:00", "2090-12-31 00:00:00", native_pd.Timestamp.max] # future
868+
+ [datetime.today(), None], # current
869+
format="mixed",
870+
)
871+
872+
873+
def infer_return_type_using_dummy_data(
874+
func: Callable, input_type: DataType, **kwargs: Any
875+
) -> Optional[DataType]:
876+
"""
877+
Infer the return type of given function by applying it to a dummy input.
878+
This method only supports the following input types: _IntegralType, _FractionalType,
879+
StringType, BooleanType, TimestampType, BinaryType.
880+
Args:
881+
func: The function to infer the return type from.
882+
input_type: The input type of the function.
883+
**kwargs : Additional keyword arguments to pass as keyword arguments to func.
884+
Returns:
885+
The inferred return type of the function. If the return type cannot be inferred,
886+
return None.
887+
"""
888+
if input_type is None:
889+
return None
890+
input_data = None
891+
if isinstance(input_type, _IntegralType):
892+
input_data = DUMMY_INT_INPUT
893+
elif isinstance(input_type, _FractionalType):
894+
input_data = DUMMY_FLOAT_INPUT
895+
elif isinstance(input_type, StringType):
896+
input_data = DUMMY_STRING_INPUT
897+
elif isinstance(input_type, BooleanType):
898+
input_data = DUMMY_BOOL_INPUT
899+
elif isinstance(input_type, TimestampType):
900+
input_data = DUMMY_TIMESTAMP_INPUT
901+
elif isinstance(input_type, BinaryType):
902+
input_data = DUMMY_BINARY_INPUT
903+
else:
904+
return None
905+
906+
def merge_types(t1: DataType, t2: DataType) -> DataType:
907+
"""
908+
Merge two types into one as per the following rules:
909+
- Null + T = T
910+
- T + Null = T
911+
- T1 + T2 = T1 where T1 == T2
912+
- T1 + T2 = Variant where T1 != T2
913+
Args:
914+
t1: first type to merge.
915+
t2: second type to merge.
916+
917+
Returns:
918+
Merged type of t1 and t2.
919+
"""
920+
# treat NullType as None
921+
t1 = None if t1 == NullType() else t1
922+
t2 = None if t2 == NullType() else t2
923+
924+
if t1 is None:
925+
return t2
926+
if t2 is None:
927+
return t1
928+
if t1 == t2:
929+
return t1
930+
if isinstance(t1, MapType) and isinstance(t2, MapType):
931+
return MapType(
932+
merge_types(t1.key_type, t2.key_type),
933+
merge_types(t1.value_type, t2.value_type),
934+
)
935+
if isinstance(t1, ArrayType) and isinstance(t2, ArrayType):
936+
return ArrayType(merge_types(t1.element_type, t2.element_type))
937+
return VariantType()
938+
939+
inferred_type = None
940+
for x in input_data:
941+
try:
942+
inferred_type = merge_types(
943+
inferred_type, infer_object_type(func(x, **kwargs))
944+
)
945+
except Exception:
946+
pass
947+
948+
if isinstance(inferred_type, TimedeltaType):
949+
# TODO: SNOW-1619940: pd.Timedelta is encoded as string.
950+
return StringType()
951+
return inferred_type
952+
953+
832954
def deduce_return_type_from_function(
833-
func: Union[AggFuncType, UserDefinedFunction]
955+
func: Union[AggFuncType, UserDefinedFunction],
956+
input_type: Optional[DataType],
957+
**kwargs: Any,
834958
) -> Optional[DataType]:
835959
"""
836960
Deduce return type if possible from a function, list, dict or type object. List will be mapped to ArrayType(),
837961
dict to MapType(), and if a type object (e.g., str) is given a mapping will be consulted.
838962
Args:
839963
func: callable function, object or Snowpark UserDefinedFunction that can be passed in pandas to reference a function.
964+
input_type: input data type this function is applied to.
965+
**kwargs : Additional keyword arguments to pass as keywords arguments to func.
840966
841967
Returns:
842968
Snowpark Datatype or None if no return type could be deduced.
@@ -860,13 +986,17 @@ def deduce_return_type_from_function(
860986
else:
861987
# handle special case 'object' type, in this case use Variant Type.
862988
# Catch potential TypeError exception here from python_type_to_snow_type.
863-
# If it is not the object type, return None to indicate that type hint could not be extracted successfully.
989+
# If it is not the object type, return None to indicate that type hint could not
990+
# be extracted successfully.
864991
try:
865-
return get_types_from_type_hints(func, TempObjectType.FUNCTION)[0]
992+
return_type = get_types_from_type_hints(func, TempObjectType.FUNCTION)[0]
993+
if return_type is not None:
994+
return return_type
866995
except TypeError as te:
867996
if str(te) == "invalid type <class 'object'>":
868997
return VariantType()
869-
return None
998+
# infer return type using dummy data.
999+
return infer_return_type_using_dummy_data(func, input_type, **kwargs)
8701000

8711001

8721002
def sort_apply_udtf_result_columns_by_pandas_positions(

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8672,8 +8672,8 @@ def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no
86728672
)
86738673

86748674
# Extract return type from annotations (or lookup for known pandas functions) for func object,
8675-
# if not return type could be extracted the variable will hold None.
8676-
return_type = deduce_return_type_from_function(func)
8675+
# if no return type could be extracted the variable will hold None.
8676+
return_type = deduce_return_type_from_function(func, None)
86778677

86788678
# Check whether return_type has been extracted. If return type is not
86798679
# a Series, tuple or list object, compute df.apply using a vUDF. In this case no column expansion needs to
@@ -8766,7 +8766,9 @@ def applymap(
87668766
Function to apply to each element of the QueryCompiler.
87678767
na_action: If 'ignore', propagate NULL values
87688768
*args : iterable
8769+
Positional arguments passed to func after the input data.
87698770
**kwargs : dict
8771+
Additional keyword arguments to pass as keyword arguments to func.
87708772
"""
87718773
self._raise_not_implemented_error_for_timedelta()
87728774

@@ -8799,15 +8801,17 @@ def applymap(
87998801
ErrorMessage.not_implemented(
88008802
"Snowpark pandas applymap API doesn't yet support na_action == 'ignore'"
88018803
)
8802-
return_type = deduce_return_type_from_function(func)
8803-
if not return_type:
8804-
return_type = VariantType()
88058804

88068805
# create and apply udfs on all data columns
88078806
replace_mapping = {}
88088807
for f in self._modin_frame.ordered_dataframe.schema.fields:
88098808
identifier = f.column_identifier.quoted_name
88108809
if identifier in self._modin_frame.data_column_snowflake_quoted_identifiers:
8810+
return_type = deduce_return_type_from_function(
8811+
func, f.datatype, **kwargs
8812+
)
8813+
if not return_type:
8814+
return_type = VariantType()
88118815
func_udf = create_udf_for_series_apply(
88128816
func,
88138817
return_type,

src/snowflake/snowpark/modin/plugin/docstrings/series.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,8 @@ def apply():
687687
Notes
688688
-----
689689
1. When ``func`` has a type annotation for its return value, the result will be cast
690-
to the corresponding dtype. When no type annotation is provided, data will be converted
690+
to the corresponding dtype. When no type annotation is provided, we try to infer
691+
return type using dummy data. If return type inference is not successful, data will be converted
691692
to VARIANT type in Snowflake, and the result will have ``dtype=object``. In this case, the return value must
692693
be JSON-serializable, which can be a valid input to ``json.dumps`` (e.g., ``dict`` and
693694
``list`` objects are JSON-serializable, but ``bytes`` and ``datetime.datetime`` objects

tests/integ/modin/frame/test_applymap.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ def test_preserve_order():
166166
eval_snowpark_pandas_result(df, native_df, lambda x: x.applymap(lambda y: -y))
167167

168168

169+
@sql_count_checker(
170+
query_count=10,
171+
udf_count=1,
172+
high_count_expected=True,
173+
high_count_reason="udf creation",
174+
)
169175
def test_applymap_variant_json_null():
170176
def f(x):
171177
if native_pd.isna(x):
@@ -182,11 +188,5 @@ def f(x):
182188
# the last column is a variant column [None, pd.NA], where both None and pd.NA
183189
# are mapped to SQL null by Python UDF in the input
184190
df = pd.DataFrame([[1, 2, None], [3, 4, pd.NA]])
185-
with SqlCounter(query_count=9):
186-
df = df.applymap(f)
187-
188-
with SqlCounter(query_count=1, udf_count=1):
189-
assert df.isna().to_numpy().tolist() == [
190-
[False, True, True],
191-
[True, False, True],
192-
]
191+
native_df = native_pd.DataFrame([[1, 2, None], [3, 4, pd.NA]])
192+
eval_snowpark_pandas_result(df, native_df, lambda x: x.applymap(f).isna())

0 commit comments

Comments
 (0)