Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#### Bug Fixes

- Fixed a bug where `cloudpickle` was not automatically added to the package list when using `artifact_repository` with custom packages, causing `ModuleNotFoundError` at runtime.
- Fixed a bug when reading XML with a custom schema where the result incorrectly included element attributes when the column was not of `StructType`.

#### Improvements
Expand Down
22 changes: 21 additions & 1 deletion src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)

import cloudpickle
from packaging.requirements import Requirement

import snowflake.snowpark
from snowflake.connector.options import installed_pandas, pandas
Expand Down Expand Up @@ -1234,7 +1235,26 @@ def resolve_imports_and_packages(
raise TypeError(
"Artifact repository requires that all packages be passed as str."
)
resolved_packages = packages
try:
has_cloudpickle = bool(
any(
Requirement(pkg).name.lower() == "cloudpickle"
for pkg in packages
)
)
except BaseException:
# backward compatibility, we don't raise an error here
# based on PyPI search (https://pypi.org/search/?q=cloudpickle), and Anaconda search (https://anaconda.org/search?q=cloudpickle),
# "cloudpickle" is the only package with this prefix, making startswith() check safe.
has_cloudpickle = bool(
any(pkg.startswith("cloudpickle") for pkg in packages)
)
resolved_packages = packages + (
[f"cloudpickle=={cloudpickle.__version__}"]
if not has_cloudpickle
else []
)

else:
# resolve packages
if session is None: # In case of sandbox
Expand Down
37 changes: 36 additions & 1 deletion tests/integ/test_stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2331,7 +2331,42 @@ def artifact_repo_test(_):
)
except SnowparkSQLException as ex:
if "No matching distribution found for snowflake-snowpark-python" in str(ex):
pytest.mark.xfail(
pytest.xfail(
"Unreleased snowpark versions are unavailable in artifact repository."
)


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="artifact repository not supported in local testing",
)
@pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources")
def test_sproc_artifact_repository_with_packages_includes_cloudpickle(session):
    """Verify cloudpickle is auto-added when registering a stored procedure
    with ``artifact_repository`` and an explicit ``packages`` list.

    The sproc round-trips a function through cloudpickle at runtime; if
    cloudpickle were missing server-side, the call would raise
    ``ModuleNotFoundError`` instead of returning ``"6"``.
    """

    def test_cloudpickle(_: Session) -> str:
        import cloudpickle

        # Round-trip a closure through cloudpickle to prove the module both
        # imports and works inside the sproc environment.
        def test_func(x):
            return x + 1

        pickled = cloudpickle.dumps(test_func)
        unpickled = cloudpickle.loads(pickled)
        return str(unpickled(5))

    try:
        test_cloudpickle_sproc = sproc(
            test_cloudpickle,
            session=session,
            return_type=StringType(),
            artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
            packages=["urllib3", "requests"],  # cloudpickle should be auto-added
        )
        assert test_cloudpickle_sproc(session=session) == "6"
    except SnowparkSQLException as ex:
        if "No matching distribution found for snowflake-snowpark-python" in str(ex):
            pytest.xfail(
                "Unreleased snowpark versions are unavailable in artifact repository."
            )
        # Any other SQL failure is a genuine error — re-raise instead of
        # silently swallowing it (the original except block let unrelated
        # exceptions pass and the test would spuriously succeed).
        raise

Expand Down
37 changes: 37 additions & 0 deletions tests/integ/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2916,6 +2916,43 @@ def test_urllib() -> str:
session._run_query(f"drop function if exists {temp_func_name}(int)")


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="artifact repository not supported in local testing",
)
@pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources")
def test_register_artifact_repository_with_packages_includes_cloudpickle(session):
    """Verify cloudpickle is auto-added when registering a UDF with
    ``artifact_repository`` and an explicit ``packages`` list.

    The UDF round-trips a function through cloudpickle at runtime; if
    cloudpickle were missing server-side, the call would raise
    ``ModuleNotFoundError`` instead of returning ``"6"``.
    """

    def test_cloudpickle() -> str:
        import cloudpickle

        # Round-trip a closure through cloudpickle to prove the module both
        # imports and works inside the UDF environment.
        def test_func(x):
            return x + 1

        pickled = cloudpickle.dumps(test_func)
        unpickled = cloudpickle.loads(pickled)
        return str(unpickled(5))

    temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION)

    try:
        # Register with an explicit package list; cloudpickle should be
        # auto-added by resolve_imports_and_packages.
        udf(
            func=test_cloudpickle,
            name=temp_func_name,
            artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
            packages=["urllib3", "requests"],  # cloudpickle should be auto-added
        )

        # The UDF call returns "6" only if cloudpickle works server-side.
        df = session.create_dataframe([1]).to_df(["a"])
        Utils.check_answer(df.select(call_udf(temp_func_name)), [Row("6")])
    except SnowparkSQLException as ex:
        # Consistent with the sproc variant of this test: an unreleased
        # snowpark version not yet published to the artifact repository is an
        # expected failure, not a test bug.
        if "No matching distribution found for snowflake-snowpark-python" in str(ex):
            pytest.xfail(
                "Unreleased snowpark versions are unavailable in artifact repository."
            )
        raise
    finally:
        session._run_query(f"drop function if exists {temp_func_name}()")


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="artifact repository not supported in local testing",
Expand Down
78 changes: 78 additions & 0 deletions tests/unit/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,81 @@ def mock_callback(extension_function_properties):
"schema": "some_schema",
"application_roles": ["app_viewer"],
}


def test_artifact_repository_adds_cloudpickle():
    """Test that cloudpickle is automatically added exactly once when using
    artifact_repository with an explicit package list.

    Covers three paths:
      1. packages without cloudpickle (including one unparsable requirement,
         which exercises the ``startswith()`` fallback) -> cloudpickle is added;
      2. packages that already name cloudpickle -> no duplicate is added;
      3. cloudpickle pinned with assorted version specifiers -> still exactly one.
    """
    from snowflake.snowpark._internal.udf_utils import resolve_imports_and_packages

    def _resolved_packages(packages):
        # Resolve without a session (sandbox path) and return the individual
        # package entries parsed out of the comma-separated result string.
        result = resolve_imports_and_packages(
            session=None,
            object_type=TempObjectType.FUNCTION,
            func=lambda: 1,
            arg_names=[],
            udf_name="test_udf",
            stage_location=None,
            imports=None,
            packages=packages,
            artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
        )
        _, _, _, all_packages, _, _ = result
        if not all_packages:
            return []
        return [pkg.strip().strip("'") for pkg in all_packages.split(",")]

    def _cloudpickle_count(pkg_list):
        # Count entries mentioning cloudpickle, case-insensitively.
        return sum(1 for pkg in pkg_list if "cloudpickle" in pkg.lower())

    # Case 1: cloudpickle absent; the unparsable requirement forces the
    # startswith() fallback branch in resolve_imports_and_packages.
    pkgs = _resolved_packages(["urllib3", "requests", "invalid package!!!"])
    assert any(
        pkg.startswith("cloudpickle==") for pkg in pkgs
    ), f"cloudpickle not found in packages: {pkgs}"

    # Cases 2 and 3: cloudpickle already present (with various version
    # specifiers) -> it must appear exactly once, never duplicated.
    for packages in [
        ["urllib3", "cloudpickle>=2.0", "requests"],
        ["urllib3", "cloudpickle==2.2.1"],
        ["urllib3", "cloudpickle~=2.2"],
        ["urllib3", "cloudpickle<=3.0"],
    ]:
        pkgs = _resolved_packages(packages)
        count = _cloudpickle_count(pkgs)
        assert (
            count == 1
        ), f"For {packages}, cloudpickle should appear exactly once, found {count} times in: {pkgs}"
Loading