Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 87 additions & 9 deletions tests/ast/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import logging
import re
from typing import Any, Optional, Iterable, List, Union, Dict, Tuple, Callable
from typing import Any, Optional, Iterable, List, Union, Dict, Tuple, Callable, Literal
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal

Expand All @@ -18,7 +18,14 @@
from snowflake.snowpark.relational_grouped_dataframe import GroupingSets
from snowflake.snowpark import Session, Column, DataFrameAnalyticsFunctions, Row
import snowflake.snowpark.functions
from snowflake.snowpark.functions import udf, udtf, when, sproc, call_table_function
from snowflake.snowpark.functions import (
udaf,
udf,
udtf,
when,
sproc,
call_table_function,
)
from snowflake.snowpark.types import (
DataType,
ArrayType,
Expand Down Expand Up @@ -111,7 +118,9 @@ def get_dataframe_analytics_function_column_formatter(
return DataFrameAnalyticsFunctions._default_col_formatter

def decode_callable_expr(
self, callable_expr: proto.SpCallable
self,
callable_expr: proto.SpCallable,
callable_type: Optional[Literal["udaf", "udtf"]] = None,
) -> Tuple[Callable, str]:
"""
Decode a callable expression to get the callable.
Expand All @@ -120,6 +129,10 @@ def decode_callable_expr(
----------
callable_expr : proto.SpCallable
The callable expression to decode.
callable_type : Optional[Literal["udaf", "udtf"]]
The type of callable.
If None, it will be treated as a regular function; an empty function will be created and renamed based on
the recorded function's name.

Returns
-------
Expand All @@ -133,16 +146,17 @@ def decode_callable_expr(
if callable_expr.HasField("object_name")
else None
)
try:
if callable_type == "udtf":
handler = self.session._udtf_registration.get_udtf(object_name).handler
except KeyError:
elif callable_type == "udaf":
handler = self.session._udaf_registration.get_udaf(object_name).handler
else:

def __temp_handler_func():
pass

__temp_handler_func.__name__ = (
name # Set the name of the function to whatever it was originally.
)
# Set the name of the function to whatever it was originally.
__temp_handler_func.__name__ = name
handler, object_name = __temp_handler_func, name
return handler, object_name

Expand Down Expand Up @@ -1895,6 +1909,68 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
df.to_snowpark_pandas(index_col, columns)
return None

case "udaf":
comment = (
expr.udaf.comment.value if expr.udaf.HasField("comment") else None
)
external_access_integrations = [
eai for eai in expr.udaf.external_access_integrations
]
handler, handler_name = self.decode_callable_expr(
expr.udaf.handler, "udaf"
)
if_not_exists = expr.udaf.if_not_exists
immutable = expr.udaf.immutable
imports = [
self.decode_name_expr(import_) for import_ in expr.udaf.imports
]
input_types = [
self.decode_data_type_expr(input_type)
for input_type in expr.udaf.input_types.list
]
is_permanent = expr.udaf.is_permanent
kwargs = self.decode_dsl_map_expr(expr.udaf.kwargs)
if "copy_grants" in kwargs:
kwargs.pop("copy_grants")
name = (
self.decode_name_expr(expr.udaf.name)
if expr.udaf.HasField("name")
else None
)
packages = [package for package in expr.udaf.packages]
parallel = expr.udaf.parallel
replace = expr.udaf.replace
return_type = self.decode_data_type_expr(expr.udaf.return_type)
secrets = self.decode_dsl_map_expr(expr.udaf.secrets)
stage_location = (
expr.udaf.stage_location.value
if expr.udaf.HasField("stage_location")
else None
)
statement_params = self.decode_dsl_map_expr(expr.udaf.statement_params)
# Run udaf to create the required AST but return the first registered version of the UDAF.
_ = udaf(
handler,
return_type=return_type,
input_types=input_types,
name=name,
is_permanent=is_permanent,
stage_location=stage_location,
imports=imports,
packages=packages,
replace=replace,
if_not_exists=if_not_exists,
session=self.session,
parallel=parallel,
statement_params=statement_params,
immutable=immutable,
external_access_integrations=external_access_integrations,
secrets=secrets,
comment=comment,
**kwargs,
)
return self.session._udaf_registration.get_udaf(handler_name)

case "udf":
return_type = self.decode_data_type_expr(expr.udf.return_type)
input_types = [
Expand All @@ -1912,7 +1988,9 @@ def decode_expr(self, expr: proto.Expr, **kwargs) -> Any:
external_access_integrations = [
eai for eai in expr.udtf.external_access_integrations
]
handler, handler_name = self.decode_callable_expr(expr.udtf.handler)
handler, handler_name = self.decode_callable_expr(
expr.udtf.handler, "udtf"
)
if_not_exists = expr.udtf.if_not_exists
immutable = expr.udtf.immutable
imports = [
Expand Down
Loading