Skip to content

Commit 8895a30

Browse files
authored
SNOW-2436917: fix udxf/sproc registration regression (#3927)
1 parent fb08011 commit 8895a30

File tree

9 files changed

+89
-11
lines changed

9 files changed

+89
-11
lines changed

src/snowflake/snowpark/_internal/udf_utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,7 +1134,7 @@ def resolve_imports_and_packages(
11341134
skip_upload_on_content_match: bool = False,
11351135
is_permanent: bool = False,
11361136
force_inline_code: bool = False,
1137-
**kwargs,
1137+
_suppress_local_package_warnings: bool = False,
11381138
) -> Tuple[
11391139
Optional[str],
11401140
Optional[str],
@@ -1168,9 +1168,7 @@ def resolve_imports_and_packages(
11681168
packages,
11691169
include_pandas=is_pandas_udf,
11701170
statement_params=statement_params,
1171-
_suppress_local_package_warnings=kwargs.get(
1172-
"_suppress_local_package_warnings", False
1173-
),
1171+
_suppress_local_package_warnings=_suppress_local_package_warnings,
11741172
)
11751173
if packages is not None
11761174
else session._resolve_packages(
@@ -1179,9 +1177,7 @@ def resolve_imports_and_packages(
11791177
validate_package=False,
11801178
include_pandas=is_pandas_udf,
11811179
statement_params=statement_params,
1182-
_suppress_local_package_warnings=kwargs.get(
1183-
"_suppress_local_package_warnings", False
1184-
),
1180+
_suppress_local_package_warnings=_suppress_local_package_warnings,
11851181
)
11861182
)
11871183

src/snowflake/snowpark/stored_procedure.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,10 @@ def _do_register_sp(
955955
is_permanent=is_permanent,
956956
force_inline_code=force_inline_code,
957957
artifact_repository=artifact_repository,
958-
**kwargs,
958+
_suppress_local_package_warnings=kwargs.get(
959+
"_suppress_local_package_warnings", False
960+
),
961+
# DO NOT pass **kwargs here, as it can lead to TypeError: multiple values for the same argument
959962
)
960963

961964
runtime_version_from_requirement = None

src/snowflake/snowpark/udaf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,10 @@ def _do_register_udaf(
799799
skip_upload_on_content_match=skip_upload_on_content_match,
800800
is_permanent=is_permanent,
801801
artifact_repository=artifact_repository,
802-
**kwargs,
802+
_suppress_local_package_warnings=kwargs.get(
803+
"_suppress_local_package_warnings", False
804+
),
805+
# DO NOT pass **kwargs here, as it can lead to TypeError: multiple values for the same argument
803806
)
804807

805808
runtime_version_from_requirement = None

src/snowflake/snowpark/udf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,10 @@ def _do_register_udf(
994994
skip_upload_on_content_match=skip_upload_on_content_match,
995995
is_permanent=is_permanent,
996996
artifact_repository=artifact_repository,
997-
**kwargs,
997+
_suppress_local_package_warnings=kwargs.get(
998+
"_suppress_local_package_warnings", False
999+
),
1000+
# DO NOT pass **kwargs here, as it can lead to TypeError: multiple values for the same argument
9981001
)
9991002

10001003
runtime_version_from_requirement = None

src/snowflake/snowpark/udtf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1056,7 +1056,10 @@ def _do_register_udtf(
10561056
skip_upload_on_content_match=skip_upload_on_content_match,
10571057
is_permanent=is_permanent,
10581058
artifact_repository=artifact_repository,
1059-
**kwargs,
1059+
_suppress_local_package_warnings=kwargs.get(
1060+
"_suppress_local_package_warnings", False
1061+
)
1062+
# DO NOT pass **kwargs here, as it can lead to TypeError: multiple values for the same argument
10601063
)
10611064

10621065
runtime_version_from_requirement = None

tests/integ/test_stored_procedure.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,18 @@ def test_session_register_sp(session, local_testing_mode):
561561
)
562562
assert add_sp(1, 2) == 3
563563

564+
# testing SNOW-2436917
565+
add_sp_passing_session = session.sproc.register(
566+
lambda session_, x, y: session_.create_dataframe([(x, y)])
567+
.to_df("a", "b")
568+
.select(col("a") + col("b"))
569+
.collect()[0][0],
570+
session=session,
571+
return_type=IntegerType(),
572+
input_types=[IntegerType(), IntegerType()],
573+
)
574+
assert add_sp_passing_session(1, 2) == 3
575+
564576
query_tag = f"QUERY_TAG_{Utils.random_alphanumeric_str(10)}"
565577
add_sp = session.sproc.register(
566578
lambda session_, x, y: session_.create_dataframe([(x, y)])

tests/integ/test_udaf.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,36 @@ def finish(self):
6363
Utils.assert_executed_with_query_tag(session, query_tag)
6464

6565

66+
def test_basic_udaf_snow_2436917(session):
67+
class PythonCountUDAFHandler:
68+
def __init__(self) -> None:
69+
self._count = 0
70+
71+
@property
72+
def aggregate_state(self):
73+
return self._count
74+
75+
def accumulate(self, input_value):
76+
if input_value is not None:
77+
self._count += 1
78+
79+
def merge(self, other_count):
80+
self._count += other_count
81+
82+
def finish(self):
83+
return self._count
84+
85+
count_udaf = session.udaf.register(
86+
PythonCountUDAFHandler,
87+
return_type=IntegerType(),
88+
input_types=[IntegerType()],
89+
immutable=True,
90+
session=session,
91+
)
92+
df = session.create_dataframe([[1], [2], [None], [3]]).to_df("a")
93+
Utils.check_answer(df.agg(count_udaf("a")), [Row(3)])
94+
95+
6696
# TODO: use data class as state. This triggers a bug in UDF server during pickling/unpickling of a state.
6797
def test_int(session):
6898
@udaf

tests/integ/test_udf.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,23 @@ def test_session_register_udf(session, local_testing_mode):
396396
Row(7),
397397
],
398398
)
399+
400+
# testing SNOW-SNOW-2436917
401+
add_udf_passing_session = session.udf.register(
402+
lambda x, y: x + y,
403+
session=session,
404+
return_type=IntegerType(),
405+
input_types=[IntegerType(), IntegerType()],
406+
)
407+
assert isinstance(add_udf_passing_session.func, Callable)
408+
Utils.check_answer(
409+
df.select(add_udf_passing_session("a", "b")).collect(),
410+
[
411+
Row(3),
412+
Row(7),
413+
],
414+
)
415+
399416
# Query tags not supported in local testing.
400417
if not local_testing_mode:
401418
query_tag = f"QUERY_TAG_{Utils.random_alphanumeric_str(10)}"

tests/integ/test_udtf.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1335,6 +1335,17 @@ def process(
13351335
ddl_sql = f"select get_ddl('FUNCTION', '{echo_udtf.name}(number)')"
13361336
assert comment in session.sql(ddl_sql).collect()[0][0]
13371337

1338+
# testing SNOW-SNOW-2436917
1339+
echo_udtf_passing_session = session.udtf.register(
1340+
EchoUDTF,
1341+
output_schema=["num"],
1342+
comment=comment,
1343+
session=session,
1344+
)
1345+
1346+
ddl_sql = f"select get_ddl('FUNCTION', '{echo_udtf_passing_session.name}(number)')"
1347+
assert comment in session.sql(ddl_sql).collect()[0][0]
1348+
13381349

13391350
@pytest.mark.parametrize("from_file", [True, False])
13401351
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)