Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#### Bug Fixes

- Fixed a bug where `cloudpickle` was not automatically added to the package list when using `artifact_repository` with custom packages, causing `ModuleNotFoundError` at runtime.
- Fixed a bug when reading XML with a custom schema, where the result would include element attributes when the column is not of `StructType` type.

#### Improvements
Expand Down
19 changes: 14 additions & 5 deletions src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,11 +1230,20 @@ def resolve_imports_and_packages(
session._resolve_packages([], artifact_repository=artifact_repository)
)
elif packages:
if not all(isinstance(package, str) for package in packages):
raise TypeError(
"Artifact repository requires that all packages be passed as str."
)
resolved_packages = packages
has_cloudpickle = False
for pkg in packages:
if not isinstance(pkg, str):
raise TypeError(
"Artifact repository requires that all packages be passed as str."
)
# Note: According to PyPI search (https://pypi.org/search/?q=cloudpickle), and Anaconda search (https://anaconda.org/search?q=cloudpickle),
# "cloudpickle" is the only package with this prefix, making startswith() check safe.
if not has_cloudpickle and pkg.startswith("cloudpickle"):
has_cloudpickle = True

resolved_packages = list(packages)
if not has_cloudpickle:
resolved_packages.append(f"cloudpickle=={cloudpickle.__version__}")
else:
# resolve packages
if session is None: # In case of sandbox
Expand Down
35 changes: 35 additions & 0 deletions tests/integ/test_stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2336,6 +2336,41 @@ def artifact_repo_test(_):
)


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="artifact repository not supported in local testing",
)
@pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources")
def test_sproc_artifact_repository_with_packages_includes_cloudpickle(session):
    """Test that cloudpickle is available when using artifact_repository with packages."""

    def test_cloudpickle(_: Session) -> str:
        import cloudpickle

        # Test that cloudpickle is available and works
        def test_func(x):
            return x + 1

        pickled = cloudpickle.dumps(test_func)
        unpickled = cloudpickle.loads(pickled)
        return str(unpickled(5))

    try:
        test_cloudpickle_sproc = sproc(
            test_cloudpickle,
            session=session,
            return_type=StringType(),
            artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
            packages=["urllib3", "requests"],  # cloudpickle should be auto-added
        )
        assert test_cloudpickle_sproc(session=session) == "6"
    except SnowparkSQLException as ex:
        if "No matching distribution found for snowflake-snowpark-python" in str(ex):
            # Fix: pytest.mark.xfail(...) called here only builds a marker object
            # and silently discards it — the test would be reported as *passed*.
            # The imperative pytest.xfail() is what actually marks it xfail.
            pytest.xfail(
                "Unreleased snowpark versions are unavailable in artifact repository."
            )
        # Any other SQL error is a genuine failure, not the expected
        # unreleased-version condition — don't swallow it.
        raise


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="artifact repository not supported in local testing",
Expand Down
37 changes: 37 additions & 0 deletions tests/integ/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2916,6 +2916,43 @@ def test_urllib() -> str:
session._run_query(f"drop function if exists {temp_func_name}(int)")


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="artifact repository not supported in local testing",
)
@pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources")
def test_register_artifact_repository_with_packages_includes_cloudpickle(session):
    """Test that cloudpickle is available when using artifact_repository with packages."""

    def test_cloudpickle() -> str:
        import cloudpickle

        # Round-trip a function through cloudpickle to prove the module is
        # importable and functional inside the UDF sandbox.
        def increment(value):
            return value + 1

        restored = cloudpickle.loads(cloudpickle.dumps(increment))
        return str(restored(5))

    temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION)

    try:
        # Register with an explicit package list that omits cloudpickle; the
        # registration path is expected to add it automatically.
        udf(
            func=test_cloudpickle,
            name=temp_func_name,
            artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
            packages=["urllib3", "requests"],  # cloudpickle should be auto-added
        )

        # The call returns "6" only if cloudpickle round-tripped correctly.
        df = session.create_dataframe([1]).to_df(["a"])
        Utils.check_answer(df.select(call_udf(temp_func_name)), [Row("6")])
    finally:
        session._run_query(f"drop function if exists {temp_func_name}()")


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="artifact repository not supported in local testing",
Expand Down
78 changes: 78 additions & 0 deletions tests/unit/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,81 @@ def mock_callback(extension_function_properties):
"schema": "some_schema",
"application_roles": ["app_viewer"],
}


def test_artifact_repository_adds_cloudpickle():
    """Test that cloudpickle is automatically added when using artifact_repository with packages."""
    from snowflake.snowpark._internal.udf_utils import resolve_imports_and_packages

    def _resolve(pkgs, name):
        # Run the resolution path with minimal arguments and return the raw
        # resolved package string (4th element of the 6-tuple result).
        return resolve_imports_and_packages(
            session=None,
            object_type=TempObjectType.FUNCTION,
            func=lambda: 1,
            arg_names=[],
            udf_name=name,
            stage_location=None,
            imports=None,
            packages=pkgs,
            artifact_repository="SNOWPARK_PYTHON_TEST_REPOSITORY",
        )[3]

    def _entries(resolved):
        # Split the comma-joined package string into cleaned entries.
        return [p.strip().strip("'") for p in (resolved.split(",") if resolved else [])]

    def _cloudpickle_count(resolved):
        return sum(1 for entry in _entries(resolved) if "cloudpickle" in entry.lower())

    # Case 1: cloudpickle absent from the caller's list -> it must be appended.
    resolved = _resolve(["urllib3", "requests"], "test_udf")
    assert resolved is not None
    assert any(
        entry.startswith("cloudpickle==") for entry in _entries(resolved)
    ), f"cloudpickle not found in packages: {resolved}"

    # Case 2: cloudpickle already present -> it must not be duplicated.
    resolved = _resolve(["urllib3", "cloudpickle>=2.0", "requests"], "test_udf2")
    count = _cloudpickle_count(resolved)
    assert (
        count == 1
    ), f"cloudpickle should appear exactly once, found {count} times in: {resolved}"

    # Case 3: every common version-specifier form is recognized as "present".
    for packages in (
        ["urllib3", "cloudpickle==2.2.1"],
        ["urllib3", "cloudpickle>=2.0"],
        ["urllib3", "cloudpickle~=2.2"],
        ["urllib3", "cloudpickle<=3.0"],
    ):
        resolved = _resolve(packages, "test_udf_versioned")
        count = _cloudpickle_count(resolved)
        assert (
            count == 1
        ), f"For {packages}, cloudpickle should appear exactly once, found {count} times"
Loading