88from collections .abc import Hashable
99from enum import Enum , auto
1010from typing import Any , Callable , Literal , Optional , Union
11+ from datetime import datetime
1112
1213import cloudpickle
1314import numpy as np
2122from snowflake .snowpark ._internal .udf_utils import get_types_from_type_hints
2223import functools
2324from snowflake .snowpark .column import Column as SnowparkColumn
25+ from snowflake .snowpark .modin .plugin ._internal .snowpark_pandas_types import (
26+ TimedeltaType ,
27+ )
2428from snowflake .snowpark .modin .plugin ._internal .type_utils import (
2529 infer_object_type ,
2630 pandas_lit ,
4549from snowflake .snowpark .session import Session
4650from snowflake .snowpark .types import (
4751 ArrayType ,
52+ BinaryType ,
53+ BooleanType ,
4854 DataType ,
55+ _IntegralType ,
56+ _FractionalType ,
4957 IntegerType ,
5058 LongType ,
5159 MapType ,
60+ NullType ,
5261 PandasDataFrameType ,
5362 PandasSeriesType ,
5463 StringType ,
64+ TimestampType ,
5565 VariantType ,
5666)
5767from snowflake .snowpark .udf import UserDefinedFunction
@@ -113,7 +123,7 @@ class GroupbyApplySortMethod(Enum):
113123
114124def check_return_variant_and_get_return_type (func : Callable ) -> tuple [bool , DataType ]:
115125 """Check whether the function returns a variant in Snowflake, and get its return type."""
116- return_type = deduce_return_type_from_function (func )
126+ return_type = deduce_return_type_from_function (func , None )
117127 if return_type is None or isinstance (
118128 return_type , (VariantType , PandasSeriesType , PandasDataFrameType )
119129 ):
@@ -756,6 +766,8 @@ def apply_func(x): # type: ignore[no-untyped-def] # pragma: no cover
756766 else :
757767
758768 def apply_func (x ): # type: ignore[no-untyped-def] # pragma: no cover
769+ # TODO SNOW-1874779: Add verification here to ensure inferred type matches
770+ # actual type.
759771 return x .apply (func , args = args , ** kwargs )
760772
761773 func_udf = sp_func .udf (
@@ -829,14 +841,128 @@ def convert_numpy_int_result_to_int(value: Any) -> Any:
829841 )
830842
831843
# Probe values that are fed one by one to a user function when its return
# type has to be inferred by executing it instead of reading type hints.

# Both boolean values.
DUMMY_BOOL_INPUT = native_pd.Series([False, True])

# A handful of small integers of both signs followed by every power of ten
# from 10**0 to 10**18, positive and negative.
# The exponent range is pinned to int64: NumPy's default integer is 32-bit on
# some platforms (e.g. Windows with NumPy 1.x), where 10**10 and larger would
# silently wrap around instead of producing the intended magnitudes.
DUMMY_INT_INPUT = native_pd.Series(
    [-37, -9, -2, -1, 0, 2, 3, 5, 7, 9, 13, 16, 20]
    + np.power(10, np.arange(19, dtype=np.int64)).tolist()
    + np.multiply(-1, np.power(10, np.arange(19, dtype=np.int64))).tolist()
)

# Fractional and whole-valued floats of both signs, one missing value (None),
# and the powers 10.1**0 .. 10.1**18 with both signs.
DUMMY_FLOAT_INPUT = native_pd.Series(
    [-9.9, -2.2, -1.0, 0.0, 0.5, 0.33, None, 0.99, 2.0, 3.0, 5.0, 7.7, 9.898989]
    + np.power(10.1, np.arange(19)).tolist()
    + np.multiply(-1.0, np.power(10.1, np.arange(19))).tolist()
)
# String probes: empty and single-character values, digit strings (with
# leading zero and sign), boolean-looking words, a missing value, names,
# e-mail addresses, a phone number, a postal address, several date/time
# spellings, and strings with separators and punctuation.
# NOTE(review): the two e-mail samples are well-formed placeholders on
# reserved example domains so that code splitting on "@" is exercised; swap in
# project-preferred addresses if specific ones are required.
DUMMY_STRING_INPUT = native_pd.Series(
    ["", "a", "A", "0", "1", "01", "123", "-1", "-12", "true", "True", "false", "False"]
    + [
        None,
        "null",
        "Jane Smith",
        "jane.smith@example.com",
        "jane_smith@example.org",
    ]
    + ["650-592-4563", "Jane Smith, 123 Main St., Anytown, CA 12345"]
    + ["2020-12-23", "2020-12-23 12:34:56", "08/08/2024", "07-08-2022", "12:34:56"]
    + ["ABC", "bat-man", "super_man", "1@#$%^&*()_+", "<>?:{}|[]\\ ;'/.,", "<tag>"]
)

# Binary probes: three UTF-8 encoded byte strings and a missing value.
DUMMY_BINARY_INPUT = native_pd.Series([b"snow", b"flake", b"12", None])
# Timestamp probes: past dates (including the smallest representable
# Timestamp), future dates (including the largest), the moment this module is
# imported, and a missing value. format="mixed" lets pandas parse each entry
# on its own since strings, Timestamps, a datetime and None are combined.
DUMMY_TIMESTAMP_INPUT = native_pd.to_datetime(
    [
        # past
        "2020-12-31 00:00:00",
        "2020-01-01 00:00:00",
        native_pd.Timestamp.min,
        # future
        "2090-01-01 00:00:00",
        "2090-12-31 00:00:00",
        native_pd.Timestamp.max,
        # current
        datetime.today(),
        None,
    ],
    format="mixed",
)
871+
872+
def infer_return_type_using_dummy_data(
    func: Callable, input_type: DataType, **kwargs: Any
) -> Optional[DataType]:
    """
    Guess the return type of ``func`` by calling it on canned sample values.

    Samples exist only for these input types: _IntegralType, _FractionalType,
    StringType, BooleanType, TimestampType and BinaryType. The function is
    invoked once per sample; the type of every successful result is merged
    into a single answer, and samples on which anything raises are ignored.

    Args:
        func: The function whose return type should be inferred.
        input_type: The type of the data the function will be applied to.
        **kwargs: Extra keyword arguments forwarded to every call of func.

    Returns:
        The merged result type, StringType() if that type is a TimedeltaType,
        or None when input_type is None or unsupported, or when no sample
        yielded a non-null type.
    """
    if input_type is None:
        return None

    # Same order as an if/elif chain would use: the first matching entry wins.
    samples_by_type = (
        (_IntegralType, DUMMY_INT_INPUT),
        (_FractionalType, DUMMY_FLOAT_INPUT),
        (StringType, DUMMY_STRING_INPUT),
        (BooleanType, DUMMY_BOOL_INPUT),
        (TimestampType, DUMMY_TIMESTAMP_INPUT),
        (BinaryType, DUMMY_BINARY_INPUT),
    )
    samples = next(
        (data for kind, data in samples_by_type if isinstance(input_type, kind)),
        None,
    )
    if samples is None:
        return None

    def merge_types(
        t1: Optional[DataType], t2: Optional[DataType]
    ) -> Optional[DataType]:
        """
        Combine two observed types into one.

        Rules:
          - Null + T = T and T + Null = T (None is treated like Null)
          - T + T = T
          - Map + Map and Array + Array merge their component types
          - any other pair of different types = Variant

        Args:
            t1: first type to merge.
            t2: second type to merge.

        Returns:
            Merged type of t1 and t2.
        """
        # NullType carries no information, so handle it exactly like None.
        left = None if t1 == NullType() else t1
        right = None if t2 == NullType() else t2

        if left is None:
            return right
        if right is None or left == right:
            return left
        if isinstance(left, MapType) and isinstance(right, MapType):
            merged_key = merge_types(left.key_type, right.key_type)
            merged_value = merge_types(left.value_type, right.value_type)
            return MapType(merged_key, merged_value)
        if isinstance(left, ArrayType) and isinstance(right, ArrayType):
            return ArrayType(merge_types(left.element_type, right.element_type))
        return VariantType()

    merged_type = None
    for sample in samples:
        try:
            sample_type = infer_object_type(func(sample, **kwargs))
            merged_type = merge_types(merged_type, sample_type)
        except Exception:
            # Best effort: a sample the function cannot handle is skipped.
            continue

    # TODO: SNOW-1619940: pd.Timedelta is encoded as string.
    return StringType() if isinstance(merged_type, TimedeltaType) else merged_type
952+
953+
832954def deduce_return_type_from_function (
833- func : Union [AggFuncType , UserDefinedFunction ]
955+ func : Union [AggFuncType , UserDefinedFunction ],
956+ input_type : Optional [DataType ],
957+ ** kwargs : Any ,
834958) -> Optional [DataType ]:
835959 """
836960 Deduce return type if possible from a function, list, dict or type object. List will be mapped to ArrayType(),
837961 dict to MapType(), and if a type object (e.g., str) is given a mapping will be consulted.
838962 Args:
839963 func: callable function, object or Snowpark UserDefinedFunction that can be passed in pandas to reference a function.
964+ input_type: input data type this function is applied to.
965+ **kwargs : Additional keyword arguments to pass as keywords arguments to func.
840966
841967 Returns:
842968 Snowpark Datatype or None if no return type could be deduced.
@@ -860,13 +986,17 @@ def deduce_return_type_from_function(
860986 else :
861987 # handle special case 'object' type, in this case use Variant Type.
862988 # Catch potential TypeError exception here from python_type_to_snow_type.
863- # If it is not the object type, return None to indicate that type hint could not be extracted successfully.
989+ # If it is not the object type, return None to indicate that type hint could not
990+ # be extracted successfully.
864991 try :
865- return get_types_from_type_hints (func , TempObjectType .FUNCTION )[0 ]
992+ return_type = get_types_from_type_hints (func , TempObjectType .FUNCTION )[0 ]
993+ if return_type is not None :
994+ return return_type
866995 except TypeError as te :
867996 if str (te ) == "invalid type <class 'object'>" :
868997 return VariantType ()
869- return None
998+ # infer return type using dummy data.
999+ return infer_return_type_using_dummy_data (func , input_type , ** kwargs )
8701000
8711001
8721002def sort_apply_udtf_result_columns_by_pandas_positions (
0 commit comments