Skip to content

Commit 376b46c

Browse files
Removing SQL NULL values from UDF's arguments
1 parent 0cec38c commit 376b46c

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

src/snowflake/snowpark/_internal/udf_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,7 @@ def generate_python_code(
870870
if object_type == TempObjectType.PROCEDURE:
871871
func_code = f"""
872872
def {_DEFAULT_HANDLER_NAME}({args}):
873+
{get_removing_sql_nulls_code(args, 4)}
873874
return func({args})
874875
"""
875876
else:
@@ -920,6 +921,7 @@ def __init__(self):
920921
if hasattr(func, TABLE_FUNCTION_PROCESS_METHOD):
921922
func_code = f"""{func_code}
922923
def process(self, {wrapper_params}):
924+
{get_removing_sql_nulls_code(wrapper_params, 8)}
923925
return lock_function_once(super().process, process_invoked)({func_args})
924926
"""
925927
if hasattr(func, TABLE_FUNCTION_END_PARTITION_METHOD):
@@ -942,6 +944,7 @@ def __init__(self):
942944
lock_function_once(super().__init__, init_invoked)()
943945
944946
def accumulate(self, {args}):
947+
{get_removing_sql_nulls_code(args, 8)}
945948
return lock_function_once(super().accumulate, accumulate_invoked)({args})
946949
947950
def merge(self, other_agg_state):
@@ -955,6 +958,7 @@ def finish(self):
955958
invoked = InvokedFlag()
956959
957960
def {_DEFAULT_HANDLER_NAME}({wrapper_params}):
961+
{get_removing_sql_nulls_code(wrapper_params, 4)}
958962
return lock_function_once(func, invoked)({func_args})
959963
""".rstrip()
960964

@@ -993,6 +997,26 @@ def {_DEFAULT_HANDLER_NAME}({wrapper_params}):
993997
""".strip()
994998

995999

1000+
def get_removing_sql_nulls_code(args: str, indentation_size: int = 0) -> str:
1001+
"""
1002+
Generates Python code lines for removing SQL null values for a list of arguments.
1003+
1004+
Each generated line assigns the result of remove_sql_null(arg_name) back to arg_name.
1005+
Arguments are extracted from a comma-separated string.
1006+
"""
1007+
indentation_size = max(0, indentation_size)
1008+
arg_names = [arg_name.strip() for arg_name in args.split(",") if arg_name.strip()]
1009+
if not arg_names:
1010+
return ""
1011+
indentation = " " * indentation_size
1012+
removing_sql_null_lines = [
1013+
f"""{indentation}{arg_name} = None if getattr({arg_name}, "is_sql_null", False) else {arg_name}"""
1014+
for arg_name in arg_names
1015+
]
1016+
removing_sql_nulls_code = "\n".join(removing_sql_null_lines)
1017+
return removing_sql_nulls_code
1018+
1019+
9961020
def add_snowpark_package_to_sproc_packages(
9971021
session: Optional["snowflake.snowpark.Session"],
9981022
packages: Optional[List[Union[str, ModuleType]]],

0 commit comments

Comments
 (0)