Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Release History

## 1.46.0 (TBD)

### Snowpark Python API Updates

#### Bug Fixes

- Fixed a bug where `cloudpickle` was not automatically added to the package list when using `artifact_repository` with custom packages, causing `ModuleNotFoundError` at runtime.

## 1.45.0 (TBD)

### Snowpark Python API Updates
Expand Down
8 changes: 7 additions & 1 deletion src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)

import cloudpickle
from packaging.requirements import Requirement

import snowflake.snowpark
from snowflake.connector.options import installed_pandas, pandas
Expand Down Expand Up @@ -1234,7 +1235,12 @@ def resolve_imports_and_packages(
raise TypeError(
"Artifact repository requires that all packages be passed as str."
)
resolved_packages = packages
resolved_packages = list(packages)
if not any(
Requirement(pkg).name.lower() == "cloudpickle"
for pkg in resolved_packages
):
resolved_packages.append(f"cloudpickle=={cloudpickle.__version__}")
else:
# resolve packages
if session is None: # In case of sandbox
Expand Down
40 changes: 40 additions & 0 deletions tests/integ/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2916,6 +2916,46 @@ def test_urllib() -> str:
session._run_query(f"drop function if exists {temp_func_name}(int)")


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="artifact repository not supported in local testing",
)
@pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources")
@pytest.mark.skipif(
sys.version_info < (3, 9), reason="artifact repository requires Python 3.9+"
)
def test_register_artifact_repository_with_packages_includes_cloudpickle(session):
"""Test that cloudpickle is available when using artifact_repository with packages."""

def test_cloudpickle() -> str:
import cloudpickle

# Test that cloudpickle is available and works
def test_func(x):
return x + 1

pickled = cloudpickle.dumps(test_func)
unpickled = cloudpickle.loads(pickled)
return str(unpickled(5))

temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION)

try:
# Test function registration with packages list
udf(
func=test_cloudpickle,
name=temp_func_name,
artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
packages=["urllib3", "requests"], # cloudpickle should be auto-added
)

# Test UDF call - should return "6" if cloudpickle works
df = session.create_dataframe([1]).to_df(["a"])
Utils.check_answer(df.select(call_udf(temp_func_name)), [Row("6")])
finally:
session._run_query(f"drop function if exists {temp_func_name}()")


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="artifact repository not supported in local testing",
Expand Down
78 changes: 78 additions & 0 deletions tests/unit/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,81 @@ def mock_callback(extension_function_properties):
"schema": "some_schema",
"application_roles": ["app_viewer"],
}


def test_artifact_repository_adds_cloudpickle():
"""Test that cloudpickle is automatically added when using artifact_repository with packages."""
from snowflake.snowpark._internal.udf_utils import resolve_imports_and_packages

# Test case 1: packages provided without cloudpickle
result = resolve_imports_and_packages(
session=None,
object_type=TempObjectType.FUNCTION,
func=lambda: 1,
arg_names=[],
udf_name="test_udf",
stage_location=None,
imports=None,
packages=["urllib3", "requests", "cloudpickle-non-existing"],
artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
)
_, _, _, all_packages, _, _ = result

# Verify cloudpickle was added
assert all_packages is not None
package_list = all_packages.split(",") if all_packages else []
assert any(
pkg.strip().strip("'").startswith("cloudpickle==") for pkg in package_list
), f"cloudpickle not found in packages: {all_packages}"

# Test case 2: packages already contains cloudpickle
result2 = resolve_imports_and_packages(
session=None,
object_type=TempObjectType.FUNCTION,
func=lambda: 1,
arg_names=[],
udf_name="test_udf2",
stage_location=None,
imports=None,
packages=["urllib3", "cloudpickle>=2.0", "requests"],
artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
)
_, _, _, all_packages2, _, _ = result2

# Verify cloudpickle was not duplicated
package_list2 = all_packages2.split(",") if all_packages2 else []
cloudpickle_count = sum(
1 for pkg in package_list2 if "cloudpickle" in pkg.strip().strip("'").lower()
)
assert (
cloudpickle_count == 1
), f"cloudpickle should appear exactly once, found {cloudpickle_count} times in: {all_packages2}"

# Test case 3: packages with various version specifiers
test_cases = [
["urllib3", "cloudpickle==2.2.1"],
["urllib3", "cloudpickle>=2.0"],
["urllib3", "cloudpickle~=2.2"],
["urllib3", "cloudpickle<=3.0"],
]

for packages in test_cases:
result = resolve_imports_and_packages(
session=None,
object_type=TempObjectType.FUNCTION,
func=lambda: 1,
arg_names=[],
udf_name="test_udf_versioned",
stage_location=None,
imports=None,
packages=packages,
artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
)
_, _, _, all_packages, _, _ = result
package_list = all_packages.split(",") if all_packages else []
cloudpickle_count = sum(
1 for pkg in package_list if "cloudpickle" in pkg.strip().strip("'").lower()
)
assert (
cloudpickle_count == 1
), f"For {packages}, cloudpickle should appear exactly once, found {cloudpickle_count} times"
Loading