44
55import logging
66import re
7- from typing import Any , Optional , Iterable , List , Union , Dict , Tuple , Callable
7+ from typing import Any , Optional , Iterable , List , Union , Dict , Tuple , Callable , Literal
88from datetime import date , datetime , time , timedelta , timezone
99from decimal import Decimal
1010
1818from snowflake .snowpark .relational_grouped_dataframe import GroupingSets
1919from snowflake .snowpark import Session , Column , DataFrameAnalyticsFunctions , Row
2020import snowflake .snowpark .functions
21- from snowflake .snowpark .functions import udf , udtf , when , sproc , call_table_function
21+ from snowflake .snowpark .functions import (
22+ udaf ,
23+ udf ,
24+ udtf ,
25+ when ,
26+ sproc ,
27+ call_table_function ,
28+ )
2229from snowflake .snowpark .types import (
2330 DataType ,
2431 ArrayType ,
@@ -111,7 +118,9 @@ def get_dataframe_analytics_function_column_formatter(
111118 return DataFrameAnalyticsFunctions ._default_col_formatter
112119
113120 def decode_callable_expr (
114- self , callable_expr : proto .SpCallable
121+ self ,
122+ callable_expr : proto .SpCallable ,
123+ callable_type : Optional [Literal ["udaf" , "udtf" ]] = None ,
115124 ) -> Tuple [Callable , str ]:
116125 """
117126 Decode a callable expression to get the callable.
@@ -120,6 +129,10 @@ def decode_callable_expr(
120129 ----------
121130 callable_expr : proto.SpCallable
122131 The callable expression to decode.
132+ callable_type : Optional[Literal["udaf", "udtf"]]
133+ The type of callable.
134+ If None, it will be treated as a regular function; an empty function will be created and renamed based on
135+ the recorded function's name.
123136
124137 Returns
125138 -------
@@ -133,16 +146,17 @@ def decode_callable_expr(
133146 if callable_expr .HasField ("object_name" )
134147 else None
135148 )
136- try :
149+ if callable_type == "udtf" :
137150 handler = self .session ._udtf_registration .get_udtf (object_name ).handler
138- except KeyError :
151+ elif callable_type == "udaf" :
152+ handler = self .session ._udaf_registration .get_udaf (object_name ).handler
153+ else :
139154
140155 def __temp_handler_func ():
141156 pass
142157
143- __temp_handler_func .__name__ = (
144- name # Set the name of the function to whatever it was originally.
145- )
158+ # Set the name of the function to whatever it was originally.
159+ __temp_handler_func .__name__ = name
146160 handler , object_name = __temp_handler_func , name
147161 return handler , object_name
148162
@@ -1895,6 +1909,68 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
18951909 df .to_snowpark_pandas (index_col , columns )
18961910 return None
18971911
1912+ case "udaf" :
1913+ comment = (
1914+ expr .udaf .comment .value if expr .udaf .HasField ("comment" ) else None
1915+ )
1916+ external_access_integrations = [
1917+ eai for eai in expr .udaf .external_access_integrations
1918+ ]
1919+ handler , handler_name = self .decode_callable_expr (
1920+ expr .udaf .handler , "udaf"
1921+ )
1922+ if_not_exists = expr .udaf .if_not_exists
1923+ immutable = expr .udaf .immutable
1924+ imports = [
1925+ self .decode_name_expr (import_ ) for import_ in expr .udaf .imports
1926+ ]
1927+ input_types = [
1928+ self .decode_data_type_expr (input_type )
1929+ for input_type in expr .udaf .input_types .list
1930+ ]
1931+ is_permanent = expr .udaf .is_permanent
1932+ kwargs = self .decode_dsl_map_expr (expr .udaf .kwargs )
1933+ if "copy_grants" in kwargs :
1934+ kwargs .pop ("copy_grants" )
1935+ name = (
1936+ self .decode_name_expr (expr .udaf .name )
1937+ if expr .udaf .HasField ("name" )
1938+ else None
1939+ )
1940+ packages = [package for package in expr .udaf .packages ]
1941+ parallel = expr .udaf .parallel
1942+ replace = expr .udaf .replace
1943+ return_type = self .decode_data_type_expr (expr .udaf .return_type )
1944+ secrets = self .decode_dsl_map_expr (expr .udaf .secrets )
1945+ stage_location = (
1946+ expr .udaf .stage_location .value
1947+ if expr .udaf .HasField ("stage_location" )
1948+ else None
1949+ )
1950+ statement_params = self .decode_dsl_map_expr (expr .udaf .statement_params )
1951+ # Run udaf to create the required AST but return the first registered version of the UDAF.
1952+ _ = udaf (
1953+ handler ,
1954+ return_type = return_type ,
1955+ input_types = input_types ,
1956+ name = name ,
1957+ is_permanent = is_permanent ,
1958+ stage_location = stage_location ,
1959+ imports = imports ,
1960+ packages = packages ,
1961+ replace = replace ,
1962+ if_not_exists = if_not_exists ,
1963+ session = self .session ,
1964+ parallel = parallel ,
1965+ statement_params = statement_params ,
1966+ immutable = immutable ,
1967+ external_access_integrations = external_access_integrations ,
1968+ secrets = secrets ,
1969+ comment = comment ,
1970+ ** kwargs ,
1971+ )
1972+ return self .session ._udaf_registration .get_udaf (handler_name )
1973+
18981974 case "udf" :
18991975 return_type = self .decode_data_type_expr (expr .udf .return_type )
19001976 input_types = [
@@ -1912,7 +1988,9 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
19121988 external_access_integrations = [
19131989 eai for eai in expr .udtf .external_access_integrations
19141990 ]
1915- handler , handler_name = self .decode_callable_expr (expr .udtf .handler )
1991+ handler , handler_name = self .decode_callable_expr (
1992+ expr .udtf .handler , "udtf"
1993+ )
19161994 if_not_exists = expr .udtf .if_not_exists
19171995 immutable = expr .udtf .immutable
19181996 imports = [
0 commit comments