Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Release History

## 1.46.0 (TBD)

### Snowpark Python API Updates

#### Bug Fixes

- Fixed a bug where `cloudpickle` was not automatically added to the package list when using `artifact_repository` with custom packages, causing `ModuleNotFoundError` at runtime.

## 1.45.0 (TBD)

### Snowpark Python API Updates
Expand Down
8 changes: 7 additions & 1 deletion src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)

import cloudpickle
from packaging.requirements import Requirement

import snowflake.snowpark
from snowflake.connector.options import installed_pandas, pandas
Expand Down Expand Up @@ -1234,7 +1235,12 @@ def resolve_imports_and_packages(
raise TypeError(
"Artifact repository requires that all packages be passed as str."
)
resolved_packages = packages
resolved_packages = list(packages)
if not any(
Requirement(pkg).name.lower() == "cloudpickle"
for pkg in resolved_packages
):
resolved_packages.append(f"cloudpickle=={cloudpickle.__version__}")
else:
# resolve packages
if session is None: # In case of sandbox
Expand Down
40 changes: 40 additions & 0 deletions tests/integ/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2916,6 +2916,46 @@ def test_urllib() -> str:
session._run_query(f"drop function if exists {temp_func_name}(int)")


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="artifact repository not supported in local testing",
)
@pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources")
@pytest.mark.skipif(
    sys.version_info < (3, 9), reason="artifact repository requires Python 3.9+"
)
def test_register_artifact_repository_with_packages_includes_cloudpickle(session):
    """Register a UDF through an artifact repository and use cloudpickle inside it.

    The ``packages`` list passed at registration does not mention cloudpickle,
    yet the UDF body imports it. The call is expected to return ``"6"``.
    """

    def test_cloudpickle() -> str:
        import cloudpickle

        # Round-trip a small local function through cloudpickle to show the
        # module is importable and functional in the UDF environment.
        def add_one(value):
            return value + 1

        restored = cloudpickle.loads(cloudpickle.dumps(add_one))
        return str(restored(5))

    temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION)

    try:
        # Only urllib3 and requests are requested explicitly; cloudpickle is
        # expected to be added automatically.
        udf(
            func=test_cloudpickle,
            name=temp_func_name,
            artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
            packages=["urllib3", "requests"],
        )

        # add_one(5) == 6, returned by the UDF as a string.
        single_row = session.create_dataframe([1]).to_df(["a"])
        udf_result = single_row.select(call_udf(temp_func_name))
        Utils.check_answer(udf_result, [Row("6")])
    finally:
        # Drop the zero-argument function whether or not the checks passed.
        session._run_query(f"drop function if exists {temp_func_name}()")


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="artifact repository not supported in local testing",
Expand Down
78 changes: 78 additions & 0 deletions tests/unit/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,81 @@ def mock_callback(extension_function_properties):
"schema": "some_schema",
"application_roles": ["app_viewer"],
}


def test_artifact_repository_adds_cloudpickle():
    """Check cloudpickle handling when packages are given with an artifact repository.

    ``resolve_imports_and_packages`` is called with ``session=None`` and an
    explicit ``packages`` list. The fourth element of its six-element result is
    the comma-separated package string inspected by the assertions below.
    """
    from snowflake.snowpark._internal.udf_utils import resolve_imports_and_packages

    def resolve(udf_name, requested):
        # Run the resolver for `requested` and return its package string.
        outcome = resolve_imports_and_packages(
            session=None,
            object_type=TempObjectType.FUNCTION,
            func=lambda: 1,
            arg_names=[],
            udf_name=udf_name,
            stage_location=None,
            imports=None,
            packages=requested,
            artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
        )
        _, _, _, joined, _, _ = outcome
        return joined

    def split_entries(joined):
        # An empty or None package string yields no entries.
        return joined.split(",") if joined else []

    def count_cloudpickle(joined):
        # Number of entries whose text mentions cloudpickle (case-insensitive).
        return sum(
            1
            for entry in split_entries(joined)
            if "cloudpickle" in entry.strip().strip("'").lower()
        )

    # Case 1: no cloudpickle requirement in the input. The similarly named
    # "cloudpickle-non-existing" entry must not prevent the automatic addition.
    resolved_without = resolve(
        "test_udf", ["urllib3", "requests", "cloudpickle-non-existing"]
    )
    assert resolved_without is not None
    assert any(
        entry.strip().strip("'").startswith("cloudpickle==")
        for entry in split_entries(resolved_without)
    ), f"cloudpickle not found in packages: {resolved_without}"

    # Case 2: the caller already lists cloudpickle, so it must not be duplicated.
    resolved_with = resolve("test_udf2", ["urllib3", "cloudpickle>=2.0", "requests"])
    occurrences = count_cloudpickle(resolved_with)
    assert (
        occurrences == 1
    ), f"cloudpickle should appear exactly once, found {occurrences} times in: {resolved_with}"

    # Case 3: a user-supplied cloudpickle is recognised whatever the specifier.
    specifier_variants = (
        ["urllib3", "cloudpickle==2.2.1"],
        ["urllib3", "cloudpickle>=2.0"],
        ["urllib3", "cloudpickle~=2.2"],
        ["urllib3", "cloudpickle<=3.0"],
    )
    for requested in specifier_variants:
        occurrences = count_cloudpickle(resolve("test_udf_versioned", requested))
        assert (
            occurrences == 1
        ), f"For {requested}, cloudpickle should appear exactly once, found {occurrences} times"
Loading