@@ -870,6 +870,7 @@ def generate_python_code(
870870 if object_type == TempObjectType .PROCEDURE :
871871 func_code = f"""
872872def { _DEFAULT_HANDLER_NAME } ({ args } ):
873+ { get_removing_sql_nulls_code (args , 4 )}
873874 return func({ args } )
874875"""
875876 else :
@@ -920,6 +921,7 @@ def __init__(self):
920921 if hasattr (func , TABLE_FUNCTION_PROCESS_METHOD ):
921922 func_code = f"""{ func_code }
922923 def process(self, { wrapper_params } ):
924+ { get_removing_sql_nulls_code (wrapper_params , 8 )}
923925 return lock_function_once(super().process, process_invoked)({ func_args } )
924926"""
925927 if hasattr (func , TABLE_FUNCTION_END_PARTITION_METHOD ):
@@ -942,6 +944,7 @@ def __init__(self):
942944 lock_function_once(super().__init__, init_invoked)()
943945
944946 def accumulate(self, { args } ):
947+ { get_removing_sql_nulls_code (args , 8 )}
945948 return lock_function_once(super().accumulate, accumulate_invoked)({ args } )
946949
947950 def merge(self, other_agg_state):
@@ -955,6 +958,7 @@ def finish(self):
955958invoked = InvokedFlag()
956959
957960def { _DEFAULT_HANDLER_NAME } ({ wrapper_params } ):
961+ { get_removing_sql_nulls_code (wrapper_params , 4 )}
958962 return lock_function_once(func, invoked)({ func_args } )
959963""" .rstrip ()
960964
@@ -993,6 +997,26 @@ def {_DEFAULT_HANDLER_NAME}({wrapper_params}):
993997""" .strip ()
994998
995999
1000+ def get_removing_sql_nulls_code (args : str , indentation_size : int = 0 ) -> str :
1001+ """
1002+ Generates Python code lines for removing SQL null values for a list of arguments.
1003+
1004+ Each generated line assigns the result of remove_sql_null(arg_name) back to arg_name.
1005+ Arguments are extracted from a comma-separated string.
1006+ """
1007+ indentation_size = max (0 , indentation_size )
1008+ arg_names = [arg_name .strip () for arg_name in args .split ("," ) if arg_name .strip ()]
1009+ if not arg_names :
1010+ return ""
1011+ indentation = " " * indentation_size
1012+ removing_sql_null_lines = [
1013+ f"""{ indentation } { arg_name } = None if getattr({ arg_name } , "is_sql_null", False) else { arg_name } """
1014+ for arg_name in arg_names
1015+ ]
1016+ removing_sql_nulls_code = "\n " .join (removing_sql_null_lines )
1017+ return removing_sql_nulls_code
1018+
1019+
9961020def add_snowpark_package_to_sproc_packages (
9971021 session : Optional ["snowflake.snowpark.Session" ],
9981022 packages : Optional [List [Union [str , ModuleType ]]],
0 commit comments