Skip to content

Commit 4fb9f32

Browse files
udafs for decoder now work
1 parent 811937f commit 4fb9f32

File tree

1 file changed

+87
-9
lines changed

1 file changed

+87
-9
lines changed

tests/ast/decoder.py

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import logging
66
import re
7-
from typing import Any, Optional, Iterable, List, Union, Dict, Tuple, Callable
7+
from typing import Any, Optional, Iterable, List, Union, Dict, Tuple, Callable, Literal
88
from datetime import date, datetime, time, timedelta, timezone
99
from decimal import Decimal
1010

@@ -18,7 +18,14 @@
1818
from snowflake.snowpark.relational_grouped_dataframe import GroupingSets
1919
from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row
2020
import snowflake.snowpark.functions
21-
from snowflake.snowpark.functions import udf, udtf, when, sproc, call_table_function
21+
from snowflake.snowpark.functions import (
22+
udaf,
23+
udf,
24+
udtf,
25+
when,
26+
sproc,
27+
call_table_function,
28+
)
2229
from snowflake.snowpark.types import (
2330
DataType,
2431
ArrayType,
@@ -111,7 +118,9 @@ def get_dataframe_analytics_function_column_formatter(
111118
return DataFrameAnalyticsFunctions._default_col_formatter
112119

113120
def decode_callable_expr(
114-
self, callable_expr: proto.SpCallable
121+
self,
122+
callable_expr: proto.SpCallable,
123+
callable_type: Optional[Literal["udaf", "udtf"]] = None,
115124
) -> Tuple[Callable, str]:
116125
"""
117126
Decode a callable expression to get the callable.
@@ -120,6 +129,10 @@ def decode_callable_expr(
120129
----------
121130
callable_expr : proto.SpCallable
122131
The callable expression to decode.
132+
callable_type : Optional[Literal["udaf", "udtf"]]
133+
The type of callable.
134+
If None, it will be treated as a regular function; an empty function will be created and renamed based on
135+
the recorded function's name.
123136
124137
Returns
125138
-------
@@ -133,16 +146,17 @@ def decode_callable_expr(
133146
if callable_expr.HasField("object_name")
134147
else None
135148
)
136-
try:
149+
if callable_type == "udtf":
137150
handler = self.session._udtf_registration.get_udtf(object_name).handler
138-
except KeyError:
151+
elif callable_type == "udaf":
152+
handler = self.session._udaf_registration.get_udaf(object_name).handler
153+
else:
139154

140155
def __temp_handler_func():
141156
pass
142157

143-
__temp_handler_func.__name__ = (
144-
name # Set the name of the function to whatever it was originally.
145-
)
158+
# Set the name of the function to whatever it was originally.
159+
__temp_handler_func.__name__ = name
146160
handler, object_name = __temp_handler_func, name
147161
return handler, object_name
148162

@@ -1895,6 +1909,68 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
18951909
df.to_snowpark_pandas(index_col, columns)
18961910
return None
18971911

1912+
case "udaf":
1913+
comment = (
1914+
expr.udaf.comment.value if expr.udaf.HasField("comment") else None
1915+
)
1916+
external_access_integrations = [
1917+
eai for eai in expr.udaf.external_access_integrations
1918+
]
1919+
handler, handler_name = self.decode_callable_expr(
1920+
expr.udaf.handler, "udaf"
1921+
)
1922+
if_not_exists = expr.udaf.if_not_exists
1923+
immutable = expr.udaf.immutable
1924+
imports = [
1925+
self.decode_name_expr(import_) for import_ in expr.udaf.imports
1926+
]
1927+
input_types = [
1928+
self.decode_data_type_expr(input_type)
1929+
for input_type in expr.udaf.input_types.list
1930+
]
1931+
is_permanent = expr.udaf.is_permanent
1932+
kwargs = self.decode_dsl_map_expr(expr.udaf.kwargs)
1933+
if "copy_grants" in kwargs:
1934+
kwargs.pop("copy_grants")
1935+
name = (
1936+
self.decode_name_expr(expr.udaf.name)
1937+
if expr.udaf.HasField("name")
1938+
else None
1939+
)
1940+
packages = [package for package in expr.udaf.packages]
1941+
parallel = expr.udaf.parallel
1942+
replace = expr.udaf.replace
1943+
return_type = self.decode_data_type_expr(expr.udaf.return_type)
1944+
secrets = self.decode_dsl_map_expr(expr.udaf.secrets)
1945+
stage_location = (
1946+
expr.udaf.stage_location.value
1947+
if expr.udaf.HasField("stage_location")
1948+
else None
1949+
)
1950+
statement_params = self.decode_dsl_map_expr(expr.udaf.statement_params)
1951+
# Run udaf to create the required AST but return the first registered version of the UDAF.
1952+
_ = udaf(
1953+
handler,
1954+
return_type=return_type,
1955+
input_types=input_types,
1956+
name=name,
1957+
is_permanent=is_permanent,
1958+
stage_location=stage_location,
1959+
imports=imports,
1960+
packages=packages,
1961+
replace=replace,
1962+
if_not_exists=if_not_exists,
1963+
session=self.session,
1964+
parallel=parallel,
1965+
statement_params=statement_params,
1966+
immutable=immutable,
1967+
external_access_integrations=external_access_integrations,
1968+
secrets=secrets,
1969+
comment=comment,
1970+
**kwargs,
1971+
)
1972+
return self.session._udaf_registration.get_udaf(handler_name)
1973+
18981974
case "udf":
18991975
return_type = self.decode_data_type_expr(expr.udf.return_type)
19001976
input_types = [
@@ -1912,7 +1988,9 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
19121988
external_access_integrations = [
19131989
eai for eai in expr.udtf.external_access_integrations
19141990
]
1915-
handler, handler_name = self.decode_callable_expr(expr.udtf.handler)
1991+
handler, handler_name = self.decode_callable_expr(
1992+
expr.udtf.handler, "udtf"
1993+
)
19161994
if_not_exists = expr.udtf.if_not_exists
19171995
immutable = expr.udtf.immutable
19181996
imports = [

0 commit comments

Comments
 (0)