Skip to content

Commit 7c1ed3e

Browse files
SNOW-1852925: Add type inference for Series.apply/map and Dataframe.map (#2821)
SNOW-1852925 & SNOW-1852928 <!--- Please answer these questions before creating your pull request. Thanks! ---> 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. <!--- In this section, please add a Snowflake Jira issue number. Note that if a corresponding GitHub issue exists, you should still include the Snowflake Jira issue number. For example, for GitHub issue #1400, you should add "SNOW-1335071" here. ---> Fixes SNOW-NNNNNNN 2. Fill out the following pre-review checklist: - [ ] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. - [ ] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development) 3. Please describe how your code solves the related issue. Add type inference for DataFrame.map, Series.apply and Series.map
1 parent d92dee9 commit 7c1ed3e

File tree

7 files changed

+222
-26
lines changed

7 files changed

+222
-26
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
- Added documentation for `DataFrame.map`.
6161
- Improve performance of `DataFrame.apply` by mapping numpy functions to snowpark functions if possible.
6262
- Added documentation on the extent of Snowpark pandas interoperability with scikit-learn
63+
- Infer return type of functions in `Series.map`, `Series.apply` and `DataFrame.map` if type-hint is not provided.
6364

6465
## 1.26.0 (2024-12-05)
6566

src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py

Lines changed: 135 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections.abc import Hashable
99
from enum import Enum, auto
1010
from typing import Any, Callable, Literal, Optional, Union
11+
from datetime import datetime
1112

1213
import cloudpickle
1314
import numpy as np
@@ -21,6 +22,9 @@
2122
from snowflake.snowpark._internal.udf_utils import get_types_from_type_hints
2223
import functools
2324
from snowflake.snowpark.column import Column as SnowparkColumn
25+
from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
26+
TimedeltaType,
27+
)
2428
from snowflake.snowpark.modin.plugin._internal.type_utils import (
2529
infer_object_type,
2630
pandas_lit,
@@ -45,13 +49,19 @@
4549
from snowflake.snowpark.session import Session
4650
from snowflake.snowpark.types import (
4751
ArrayType,
52+
BinaryType,
53+
BooleanType,
4854
DataType,
55+
_IntegralType,
56+
_FractionalType,
4957
IntegerType,
5058
LongType,
5159
MapType,
60+
NullType,
5261
PandasDataFrameType,
5362
PandasSeriesType,
5463
StringType,
64+
TimestampType,
5565
VariantType,
5666
)
5767
from snowflake.snowpark.udf import UserDefinedFunction
@@ -113,7 +123,7 @@ class GroupbyApplySortMethod(Enum):
113123

114124
def check_return_variant_and_get_return_type(func: Callable) -> tuple[bool, DataType]:
115125
"""Check whether the function returns a variant in Snowflake, and get its return type."""
116-
return_type = deduce_return_type_from_function(func)
126+
return_type = deduce_return_type_from_function(func, None)
117127
if return_type is None or isinstance(
118128
return_type, (VariantType, PandasSeriesType, PandasDataFrameType)
119129
):
@@ -756,6 +766,8 @@ def apply_func(x): # type: ignore[no-untyped-def] # pragma: no cover
756766
else:
757767

758768
def apply_func(x): # type: ignore[no-untyped-def] # pragma: no cover
769+
# TODO SNOW-1874779: Add verification here to ensure inferred type matches
770+
# actual type.
759771
return x.apply(func, args=args, **kwargs)
760772

761773
func_udf = sp_func.udf(
@@ -829,14 +841,128 @@ def convert_numpy_int_result_to_int(value: Any) -> Any:
829841
)
830842

831843

844+
DUMMY_BOOL_INPUT = native_pd.Series([False, True])
845+
DUMMY_INT_INPUT = native_pd.Series(
846+
[-37, -9, -2, -1, 0, 2, 3, 5, 7, 9, 13, 16, 20]
847+
+ np.power(10, np.arange(19)).tolist()
848+
+ np.multiply(-1, np.power(10, np.arange(19))).tolist()
849+
)
850+
DUMMY_FLOAT_INPUT = native_pd.Series(
851+
[-9.9, -2.2, -1.0, 0.0, 0.5, 0.33, None, 0.99, 2.0, 3.0, 5.0, 7.7, 9.898989]
852+
+ np.power(10.1, np.arange(19)).tolist()
853+
+ np.multiply(-1.0, np.power(10.1, np.arange(19))).tolist()
854+
)
855+
DUMMY_STRING_INPUT = native_pd.Series(
856+
["", "a", "A", "0", "1", "01", "123", "-1", "-12", "true", "True", "false", "False"]
857+
+ [None, "null", "Jane Smith", "[email protected]", "[email protected]"]
858+
+ ["650-592-4563", "Jane Smith, 123 Main St., Anytown, CA 12345"]
859+
+ ["2020-12-23", "2020-12-23 12:34:56", "08/08/2024", "07-08-2022", "12:34:56"]
860+
+ ["ABC", "bat-man", "super_man", "1@#$%^&*()_+", "<>?:{}|[]\\;'/.,", "<tag>"]
861+
)
862+
DUMMY_BINARY_INPUT = native_pd.Series(
863+
[bytes("snow", "utf-8"), bytes("flake", "utf-8"), bytes("12", "utf-8"), None]
864+
)
865+
DUMMY_TIMESTAMP_INPUT = native_pd.to_datetime(
866+
["2020-12-31 00:00:00", "2020-01-01 00:00:00", native_pd.Timestamp.min] # past
867+
+ ["2090-01-01 00:00:00", "2090-12-31 00:00:00", native_pd.Timestamp.max] # future
868+
+ [datetime.today(), None], # current
869+
format="mixed",
870+
)
871+
872+
873+
def infer_return_type_using_dummy_data(
874+
func: Callable, input_type: DataType, **kwargs: Any
875+
) -> Optional[DataType]:
876+
"""
877+
Infer the return type of given function by applying it to a dummy input.
878+
This method only supports the following input types: _IntegralType, _FractionalType,
879+
StringType, BooleanType, TimestampType, BinaryType.
880+
Args:
881+
func: The function to infer the return type from.
882+
input_type: The input type of the function.
883+
**kwargs : Additional keyword arguments to pass as keyword arguments to func.
884+
Returns:
885+
The inferred return type of the function. If the return type cannot be inferred,
886+
return None.
887+
"""
888+
if input_type is None:
889+
return None
890+
input_data = None
891+
if isinstance(input_type, _IntegralType):
892+
input_data = DUMMY_INT_INPUT
893+
elif isinstance(input_type, _FractionalType):
894+
input_data = DUMMY_FLOAT_INPUT
895+
elif isinstance(input_type, StringType):
896+
input_data = DUMMY_STRING_INPUT
897+
elif isinstance(input_type, BooleanType):
898+
input_data = DUMMY_BOOL_INPUT
899+
elif isinstance(input_type, TimestampType):
900+
input_data = DUMMY_TIMESTAMP_INPUT
901+
elif isinstance(input_type, BinaryType):
902+
input_data = DUMMY_BINARY_INPUT
903+
else:
904+
return None
905+
906+
def merge_types(t1: DataType, t2: DataType) -> DataType:
907+
"""
908+
Merge two types into one as per the following rules:
909+
- Null + T = T
910+
- T + Null = T
911+
- T1 + T2 = T1 where T1 == T2
912+
- T1 + T2 = Variant where T1 != T2
913+
Args:
914+
t1: first type to merge.
915+
t2: second type to merge.
916+
917+
Returns:
918+
Merged type of t1 and t2.
919+
"""
920+
# treat NullType as None
921+
t1 = None if t1 == NullType() else t1
922+
t2 = None if t2 == NullType() else t2
923+
924+
if t1 is None:
925+
return t2
926+
if t2 is None:
927+
return t1
928+
if t1 == t2:
929+
return t1
930+
if isinstance(t1, MapType) and isinstance(t2, MapType):
931+
return MapType(
932+
merge_types(t1.key_type, t2.key_type),
933+
merge_types(t1.value_type, t2.value_type),
934+
)
935+
if isinstance(t1, ArrayType) and isinstance(t2, ArrayType):
936+
return ArrayType(merge_types(t1.element_type, t2.element_type))
937+
return VariantType()
938+
939+
inferred_type = None
940+
for x in input_data:
941+
try:
942+
inferred_type = merge_types(
943+
inferred_type, infer_object_type(func(x, **kwargs))
944+
)
945+
except Exception:
946+
pass
947+
948+
if isinstance(inferred_type, TimedeltaType):
949+
# TODO: SNOW-1619940: pd.Timedelta is encoded as string.
950+
return StringType()
951+
return inferred_type
952+
953+
832954
def deduce_return_type_from_function(
833-
func: Union[AggFuncType, UserDefinedFunction]
955+
func: Union[AggFuncType, UserDefinedFunction],
956+
input_type: Optional[DataType],
957+
**kwargs: Any,
834958
) -> Optional[DataType]:
835959
"""
836960
Deduce return type if possible from a function, list, dict or type object. List will be mapped to ArrayType(),
837961
dict to MapType(), and if a type object (e.g., str) is given a mapping will be consulted.
838962
Args:
839963
func: callable function, object or Snowpark UserDefinedFunction that can be passed in pandas to reference a function.
964+
input_type: input data type this function is applied to.
965+
**kwargs : Additional keyword arguments to pass as keyword arguments to func.
840966
841967
Returns:
842968
Snowpark Datatype or None if no return type could be deduced.
@@ -860,13 +986,17 @@ def deduce_return_type_from_function(
860986
else:
861987
# handle special case 'object' type, in this case use Variant Type.
862988
# Catch potential TypeError exception here from python_type_to_snow_type.
863-
# If it is not the object type, return None to indicate that type hint could not be extracted successfully.
989+
# If it is not the object type, return None to indicate that type hint could not
990+
# be extracted successfully.
864991
try:
865-
return get_types_from_type_hints(func, TempObjectType.FUNCTION)[0]
992+
return_type = get_types_from_type_hints(func, TempObjectType.FUNCTION)[0]
993+
if return_type is not None:
994+
return return_type
866995
except TypeError as te:
867996
if str(te) == "invalid type <class 'object'>":
868997
return VariantType()
869-
return None
998+
# infer return type using dummy data.
999+
return infer_return_type_using_dummy_data(func, input_type, **kwargs)
8701000

8711001

8721002
def sort_apply_udtf_result_columns_by_pandas_positions(

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8671,8 +8671,8 @@ def wrapped_func(*args, **kwargs): # type: ignore[no-untyped-def] # pragma: no
86718671
)
86728672

86738673
# Extract return type from annotations (or lookup for known pandas functions) for func object,
8674-
# if not return type could be extracted the variable will hold None.
8675-
return_type = deduce_return_type_from_function(func)
8674+
# if no return type could be extracted the variable will hold None.
8675+
return_type = deduce_return_type_from_function(func, None)
86768676

86778677
# Check whether return_type has been extracted. If return type is not
86788678
# a Series, tuple or list object, compute df.apply using a vUDF. In this case no column expansion needs to
@@ -8765,7 +8765,9 @@ def applymap(
87658765
Function to apply to each element of the QueryCompiler.
87668766
na_action: If 'ignore', propagate NULL values
87678767
*args : iterable
8768+
Positional arguments passed to func after the input data.
87688769
**kwargs : dict
8770+
Additional keyword arguments to pass as keyword arguments to func.
87698771
"""
87708772
self._raise_not_implemented_error_for_timedelta()
87718773

@@ -8798,15 +8800,17 @@ def applymap(
87988800
ErrorMessage.not_implemented(
87998801
"Snowpark pandas applymap API doesn't yet support na_action == 'ignore'"
88008802
)
8801-
return_type = deduce_return_type_from_function(func)
8802-
if not return_type:
8803-
return_type = VariantType()
88048803

88058804
# create and apply udfs on all data columns
88068805
replace_mapping = {}
88078806
for f in self._modin_frame.ordered_dataframe.schema.fields:
88088807
identifier = f.column_identifier.quoted_name
88098808
if identifier in self._modin_frame.data_column_snowflake_quoted_identifiers:
8809+
return_type = deduce_return_type_from_function(
8810+
func, f.datatype, **kwargs
8811+
)
8812+
if not return_type:
8813+
return_type = VariantType()
88108814
func_udf = create_udf_for_series_apply(
88118815
func,
88128816
return_type,

src/snowflake/snowpark/modin/plugin/docstrings/series.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,8 @@ def apply():
687687
Notes
688688
-----
689689
1. When ``func`` has a type annotation for its return value, the result will be cast
690-
to the corresponding dtype. When no type annotation is provided, data will be converted
690+
to the corresponding dtype. When no type annotation is provided, we try to infer
691+
return type using dummy data. If return type inference is not successful, data will be converted
691692
to VARIANT type in Snowflake, and the result will have ``dtype=object``. In this case, the return value must
692693
be JSON-serializable, which can be a valid input to ``json.dumps`` (e.g., ``dict`` and
693694
``list`` objects are JSON-serializable, but ``bytes`` and ``datetime.datetime`` objects

tests/integ/modin/frame/test_applymap.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,12 @@ def test_preserve_order():
166166
eval_snowpark_pandas_result(df, native_df, lambda x: x.applymap(lambda y: -y))
167167

168168

169+
@sql_count_checker(
170+
query_count=10,
171+
udf_count=1,
172+
high_count_expected=True,
173+
high_count_reason="udf creation",
174+
)
169175
def test_applymap_variant_json_null():
170176
def f(x):
171177
if native_pd.isna(x):
@@ -182,11 +188,5 @@ def f(x):
182188
# the last column is a variant column [None, pd.NA], where both None and pd.NA
183189
# are mapped to SQL null by Python UDF in the input
184190
df = pd.DataFrame([[1, 2, None], [3, 4, pd.NA]])
185-
with SqlCounter(query_count=9):
186-
df = df.applymap(f)
187-
188-
with SqlCounter(query_count=1, udf_count=1):
189-
assert df.isna().to_numpy().tolist() == [
190-
[False, True, True],
191-
[True, False, True],
192-
]
191+
native_df = native_pd.DataFrame([[1, 2, None], [3, 4, pd.NA]])
192+
eval_snowpark_pandas_result(df, native_df, lambda x: x.applymap(f).isna())

tests/integ/modin/series/test_apply_and_map.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -435,16 +435,16 @@ def f(x) -> float:
435435
)
436436

437437
@pytest.mark.parametrize(
438-
"input, expected_output",
438+
"input",
439439
[
440440
# The last element in this numeric column becomes np.nan in Python UDF -> SQL null
441-
([1, 2, 3, 4, None], [False, True, True, False, True]),
441+
[1, 2, 3, 4, None],
442442
# The last element in this string column becomes pd.NA in Python UDF -> SQL null
443-
(["s", "t", "null", None], [False, False, False, True]),
443+
["s", "t", "null", None],
444444
],
445445
)
446446
@sql_count_checker(query_count=4, udf_count=1)
447-
def test_variant_json_null(self, method, input, expected_output):
447+
def test_variant_json_null(self, method, input):
448448
def f(x):
449449
if native_pd.isna(x):
450450
return x
@@ -457,7 +457,11 @@ def f(x):
457457
else:
458458
return x
459459

460-
assert getattr(pd.Series(input), method)(f).isna().tolist() == expected_output
460+
snow_series = pd.Series(input)
461+
native_series = native_pd.Series(input)
462+
eval_snowpark_pandas_result(
463+
snow_series, native_series, lambda x: getattr(x, method)(f).isna()
464+
)
461465

462466
# This import is related to the test below. Do not remove.
463467
import scipy # noqa: E402
@@ -811,6 +815,15 @@ def test_invalid_arg_type(self):
811815
expect_exception_type=TypeError,
812816
)
813817

818+
@sql_count_checker(query_count=3)
819+
def test_incorrect_inferred_type(self):
820+
s = pd.Series([1, 2, 17])
821+
# The return type of the lambda is inferred as int, but the actual return type is a
822+
# mix of int and string.
823+
# Attempt to convert "abc" to int will raise an exception.
824+
with pytest.raises(SnowparkSQLException):
825+
s.map(lambda x: "abc" if x == 17 else x).to_pandas()
826+
814827

815828
# NOTE: Please add test cases to one of TestApplyOrMapCallable, TestApplyOnly,
816829
# or TestMapOnly, instead of adding separate test functions here.

0 commit comments

Comments
 (0)