Merged

24 commits
8aac920
initial implementation
sfc-gh-bkogan Feb 5, 2026
e481ef8
remove cloudpickle todo
sfc-gh-bkogan Feb 5, 2026
8e3950f
simple int test
sfc-gh-bkogan Feb 5, 2026
62cf6a0
update tests + fix bug
sfc-gh-bkogan Feb 5, 2026
e8dd42d
non conda resolve test
sfc-gh-bkogan Feb 5, 2026
97ce30d
remove _packages, merge with _artifact_repository_packages
sfc-gh-bkogan Feb 6, 2026
24a0185
remove more _packages
sfc-gh-bkogan Feb 6, 2026
07cc8b9
add default cache
sfc-gh-bkogan Feb 6, 2026
bde687f
fix UT
sfc-gh-bkogan Feb 6, 2026
97ef1aa
doc udpates
sfc-gh-bkogan Feb 6, 2026
1f45379
fix more tests
sfc-gh-bkogan Feb 6, 2026
44ee0e4
remove debug log
sfc-gh-bkogan Feb 6, 2026
e2b71f0
Merge branch 'main' of github.com:snowflakedb/snowpark-python into bk…
sfc-gh-bkogan Feb 6, 2026
026c024
cleanup
sfc-gh-bkogan Feb 6, 2026
eea6081
wrap getting the default in a lock
sfc-gh-bkogan Feb 6, 2026
44f7114
Merge branch 'main' of github.com:snowflakedb/snowpark-python into bk…
sfc-gh-bkogan Feb 6, 2026
9746f7f
update comment
sfc-gh-bkogan Feb 10, 2026
8562a72
update test
sfc-gh-bkogan Feb 10, 2026
4a51691
try using test artifact repo
sfc-gh-bkogan Feb 10, 2026
273b028
filter out system call checks in modin
sfc-gh-bkogan Feb 10, 2026
1bbd89e
Merge branch 'main' of github.com:snowflakedb/snowpark-python into bk…
sfc-gh-bkogan Feb 17, 2026
4f3d174
private
sfc-gh-bkogan Feb 17, 2026
70e97e7
move constants to context + temp schema in test
sfc-gh-bkogan Feb 18, 2026
348210f
Merge branch 'main' of github.com:snowflakedb/snowpark-python into bk…
sfc-gh-bkogan Feb 20, 2026
38 changes: 32 additions & 6 deletions src/snowflake/snowpark/_internal/udf_utils.py
@@ -1222,21 +1222,44 @@ def resolve_imports_and_packages(
     Optional[str],
     bool,
 ]:
-    if artifact_repository and artifact_repository != "conda":
-        # Artifact Repository packages are not resolved
+    from snowflake.snowpark.session import ANACONDA_SHARED_REPOSITORY
+
+    use_default_artifact_repository = artifact_repository is None
+    if use_default_artifact_repository:
+        artifact_repository = (
+            session._get_default_artifact_repository()
+            if session is not None
+            else ANACONDA_SHARED_REPOSITORY
+        )
+
+    # TODO: if the user explicitly passes in the current default, should we use self._packages?
+    # note that the current default could change after calling session.add_packages, so it's hard
+    # to know what the intended default is
+    existing_packages_dict = {}
+    if session:
+        existing_packages_dict = (
+            session._packages
+            if use_default_artifact_repository
+            else session._artifact_repository_packages[artifact_repository]
+        )
+
+    if artifact_repository != ANACONDA_SHARED_REPOSITORY:
+        # Non-conda artifact repository - skip conda-based package resolution
         resolved_packages = []
         if not packages and session:
             resolved_packages = list(
-                session._resolve_packages([], artifact_repository=artifact_repository)
+                session._resolve_packages(
+                    [], artifact_repository, existing_packages_dict
+                )
             )
         elif packages:
             if not all(isinstance(package, str) for package in packages):
                 raise TypeError(
-                    "Artifact repository requires that all packages be passed as str."
+                    "Non-conda artifact repository requires that all packages be passed as str."
                 )
             resolved_packages = packages
     else:
-        # resolve packages
+        # resolve packages using conda channel
        if session is None:  # In case of sandbox
             resolved_packages = resolve_packages_in_client_side_sandbox(
                 packages=packages
@@ -1245,14 +1268,17 @@
             resolved_packages = (
                 session._resolve_packages(
                     packages,
+                    artifact_repository,
+                    {},  # ignore session packages if passed in explicitly
                     include_pandas=is_pandas_udf,
                     statement_params=statement_params,
                     _suppress_local_package_warnings=_suppress_local_package_warnings,
                 )
                 if packages is not None
                 else session._resolve_packages(
                     [],
-                    session._packages,
+                    artifact_repository,
+                    existing_packages_dict,
                     validate_package=False,
                     include_pandas=is_pandas_udf,
                     statement_params=statement_params,
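Note: taken together, this change makes resolution a two-step selection: first pick the repository (explicit argument, else the session default, else the conda channel), then pick which session-level package dict to merge against. Below is a condensed, illustrative sketch of just that selection logic, not the actual implementation; "session" stands in for any object exposing the attributes shown in the diff.

from typing import Dict, Optional, Tuple

ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository"

def select_repo_and_packages(
    session, artifact_repository: Optional[str]
) -> Tuple[str, Dict[str, str]]:
    # Step 1: no explicit repository -> use the session default,
    # falling back to the conda channel when there is no session.
    use_default = artifact_repository is None
    if use_default:
        artifact_repository = (
            session._get_default_artifact_repository()
            if session is not None
            else ANACONDA_SHARED_REPOSITORY
        )
    # Step 2: the default repository merges against session._packages;
    # an explicit repository gets its own per-repository dict.
    existing: Dict[str, str] = {}
    if session:
        existing = (
            session._packages
            if use_default
            else session._artifact_repository_packages[artifact_repository]
        )
    return artifact_repository, existing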
59 changes: 46 additions & 13 deletions src/snowflake/snowpark/session.py
@@ -321,6 +321,10 @@
 WRITE_PANDAS_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None
 WRITE_ARROW_CHUNK_SIZE: int = 100000 if is_in_stored_procedure() else None
 
+# The fully qualified name of the Anaconda shared repository (conda channel).
+# Used as the fallback/default when the system function is unavailable or returns NULL.
+ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository"
+
 
 def _get_active_session() -> "Session":
     with _session_management_lock:
@@ -599,7 +603,10 @@ def __init__(
         self._conn = conn
         self._query_tag = None
         self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
+        # packages that should be added under the default artifact repository
+        # TODO: now that we have dynamic defaults, should we remove this and just use _artifact_repository_packages always?
         self._packages: Dict[str, str] = {}
+        # packages that should be added under an explicit artifact repository
         self._artifact_repository_packages: DefaultDict[
             str, Dict[str, str]
         ] = defaultdict(dict)
@@ -1669,10 +1676,20 @@ def add_packages(
         to ensure the consistent experience of a UDF between your local environment
         and the Snowflake server.
         """
+        use_default_artifact_repository = artifact_repository is None
+        if use_default_artifact_repository:
+            artifact_repository = self._get_default_artifact_repository()
+
+        existing_packages_dict = (
+            self._packages
+            if use_default_artifact_repository
+            else self._artifact_repository_packages[artifact_repository]
+        )
+
         self._resolve_packages(
             parse_positional_args_to_list(*packages),
-            self._packages,
-            artifact_repository=artifact_repository,
+            artifact_repository,
+            existing_packages_dict,
         )
 
     def remove_package(
@@ -2097,11 +2114,11 @@ def _get_req_identifiers_list(
     def _resolve_packages(
         self,
         packages: List[Union[str, ModuleType]],
-        existing_packages_dict: Optional[Dict[str, str]] = None,
+        artifact_repository: str,
+        existing_packages_dict: Dict[str, str],
         validate_package: bool = True,
         include_pandas: bool = False,
         statement_params: Optional[Dict[str, str]] = None,
-        artifact_repository: Optional[str] = None,
         **kwargs,
     ) -> List[str]:
         """
@@ -2128,18 +2145,12 @@ def _resolve_packages(
Expand All @@ -2128,18 +2145,12 @@ def _resolve_packages(
package_dict = self._parse_packages(packages)
if (
isinstance(self._conn, MockServerConnection)
or artifact_repository is not None
or artifact_repository != ANACONDA_SHARED_REPOSITORY
):
# in local testing we don't resolve the packages, we just return what is added
# in local testing or non-conda, we don't resolve the packages, we just return what is added
errors = []
with self._package_lock:
if artifact_repository is None:
result_dict = self._packages
else:
result_dict = self._artifact_repository_packages[
artifact_repository
]

result_dict = existing_packages_dict
for pkg_name, _, pkg_req in package_dict.values():
if (
pkg_name in result_dict
Expand Down Expand Up @@ -2377,6 +2388,28 @@ def _upload_unsupported_packages(

         return supported_dependencies + new_dependencies
 
+    def _get_default_artifact_repository(self) -> str:
+        """
+        Returns the default artifact repository for the current session context
+        by calling SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY.
+
+        Falls back to the Anaconda shared repository (conda) if:
+        - the system function is not available / fails, or
+        - the system function returns NULL (value was never set).
+        """
+        if isinstance(self._conn, MockServerConnection):
+            return ANACONDA_SHARED_REPOSITORY
+
+        try:
+            python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
+            result = self._run_query(
+                f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')"
+            )
+            value = result[0][0] if result else None
+            return value or ANACONDA_SHARED_REPOSITORY
+        except Exception:
+            return ANACONDA_SHARED_REPOSITORY
+
     def _is_anaconda_terms_acknowledged(self) -> bool:
         return self._run_query("select system$are_anaconda_terms_acknowledged()")[0][0]
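Note: one subtlety in _get_default_artifact_repository is that the final "value or ANACONDA_SHARED_REPOSITORY" expression sends every unset shape (no rows, a NULL, even an empty string) to the conda fallback, and the bare "except Exception" keeps deployments where the system function does not exist working unchanged. A tiny standalone check of that expression, plain Python with no connection required:

ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository"

# Simulated query results: a set value, NULL, an empty string, and no rows.
for result in (
    [["snowflake.snowpark.pypi_shared_repository"]],
    [[None]],
    [[""]],
    [],
):
    value = result[0][0] if result else None
    print(value or ANACONDA_SHARED_REPOSITORY)
# Prints the pypi repository once, then the conda channel three times.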
4 changes: 4 additions & 0 deletions tests/integ/conftest.py
@@ -357,6 +357,10 @@ def session(
         "alter session set ENABLE_EXTRACTION_PUSHDOWN_EXTERNAL_PARQUET_FOR_COPY_PHASE_I='Track';"
     ).collect()
     session.sql("alter session set ENABLE_ROW_ACCESS_POLICY=true").collect()
+    # TODO: remove
+    session.sql(
+        "ALTER SESSION SET ENABLE_DEFAULT_PYTHON_ARTIFACT_REPOSITORY = true"
+    ).collect()
 
     try:
         yield session
40 changes: 40 additions & 0 deletions tests/integ/test_udf.py
@@ -3009,3 +3009,43 @@ def test_urllib() -> str:
     )
     df = session.create_dataframe([1]).to_df(["a"])
     Utils.check_answer(df.select(ar_udf()), [Row("test")])
+
+
+@pytest.mark.skipif(
+    "config.getoption('local_testing_mode', default=False)",
+    reason="artifact repository not supported in local testing",
+)
+# @pytest.mark.skipif(IS_NOT_ON_GITHUB, reason="need resources")
+@pytest.mark.skipif(
+    sys.version_info < (3, 9), reason="artifact repository requires Python 3.9+"
+)
+def test_use_default_artifact_repository(session):
+    # TODO: is this safe with parallel testing?
+    session.sql(
+        "ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = snowflake.snowpark.pypi_shared_repository"
+    ).collect()
+
+    session.add_packages("art", "cloudpickle")
+
+    def test_art() -> str:
+        import art  # art is not available in the conda channel, but is in pypi
+
+        _ = art.text2art("test")
+        return "art works!"
+
+    temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION)
+
+    try:
+        # Test function registration
+        udf(
+            func=test_art,
+            name=temp_func_name,
+        )
+
+        # Test UDF call
+        df = session.create_dataframe([1]).to_df(["a"])
+        Utils.check_answer(df.select(call_udf(temp_func_name)), [Row("art works!")])
+    finally:
+        session._run_query(f"drop function if exists {temp_func_name}(int)")
+
+    session.sql("ALTER schema unset DEFAULT_PYTHON_ARTIFACT_REPOSITORY").collect()
85 changes: 81 additions & 4 deletions tests/unit/test_session.py
@@ -28,7 +28,10 @@
     SnowparkInvalidObjectNameException,
     SnowparkSessionException,
 )
-from snowflake.snowpark.session import _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING
+from snowflake.snowpark.session import (
+    _PYTHON_SNOWPARK_USE_SCOPED_TEMP_OBJECTS_STRING,
+    ANACONDA_SHARED_REPOSITORY,
+)
 from snowflake.snowpark.types import StructField, StructType
 
 
@@ -211,7 +214,11 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True
     session.table.side_effect = mock_get_information_schema_packages
 
     session._resolve_packages(
-        ["random_package_name"], validate_package=True, include_pandas=False
+        ["random_package_name"],
+        ANACONDA_SHARED_REPOSITORY,
+        {},
+        validate_package=True,
+        include_pandas=False,
     )
 
 
@@ -242,7 +249,11 @@ def run_query(sql: str):
"#using-third-party-packages-from-anaconda.",
):
session._resolve_packages(
["random_package_name"], validate_package=True, include_pandas=False
["random_package_name"],
ANACONDA_SHARED_REPOSITORY,
{},
validate_package=True,
include_pandas=False,
)


Expand All @@ -264,7 +275,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True

resolved_packages = session._resolve_packages(
["random_package_name"],
existing_packages_dict=existing_packages,
ANACONDA_SHARED_REPOSITORY,
existing_packages,
validate_package=True,
include_pandas=False,
)
Expand Down Expand Up @@ -295,6 +307,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True
     ):
         session._resolve_packages(
             ["snowflake-snowpark-python"],
+            ANACONDA_SHARED_REPOSITORY,
+            {},
             validate_package=True,
             include_pandas=False,
             _suppress_local_package_warnings=True,
@@ -304,6 +318,38 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True
     assert caplog.text == ""
 
 
+def test_resolve_packages_non_conda_artifact_repository(mock_server_connection):
+    session = Session(mock_server_connection)
+
+    existing_packages = {}
+
+    def assert_packages(packages):
+        assert sorted(packages) == [
+            "cloudpickle==1.0.0",
+            "snowflake-snowpark-python==1.0.0",
+        ]
+        assert existing_packages == {
+            "snowflake-snowpark-python": "snowflake-snowpark-python==1.0.0",
+            "cloudpickle": "cloudpickle==1.0.0",
+        }
+
+    packages = session._resolve_packages(
+        ["snowflake-snowpark-python==1.0.0", "cloudpickle==1.0.0"],
+        "snowflake.snowpark.pypi_shared_repository",
+        existing_packages,
+    )
+
+    assert_packages(packages)
+
+    packages = session._resolve_packages(
+        [],
+        "snowflake.snowpark.pypi_shared_repository",
+        existing_packages,
+    )
+
+    assert_packages(packages)
+
+
 @pytest.mark.skipif(not is_pandas_available, reason="requires pandas for write_pandas")
 def test_write_pandas_wrong_table_type(mock_server_connection):
     session = Session(mock_server_connection)
@@ -674,3 +720,34 @@ def test_parameter_version(version_value, expected_parameter_value, parameter_name
     )
     session = Session(fake_server_connection)
     assert getattr(session, parameter_name, None) is expected_parameter_value
+
+
+def test_get_default_artifact_repository():
+    fake_server_connection = mock.create_autospec(ServerConnection)
+    fake_server_connection._thread_safe_session_enabled = True
+    session = Session(fake_server_connection)
+
+    with mock.patch.object(
+        session,
+        "_run_query",
+        return_value=[["snowflake.snowpark.pypi_shared_repository"]],
+    ):
+        result = session._get_default_artifact_repository()
+        assert result == "snowflake.snowpark.pypi_shared_repository"
+
+    with mock.patch.object(
+        session,
+        "_run_query",
+        return_value=[[None]],
+    ):
+        result = session._get_default_artifact_repository()
+        assert result == ANACONDA_SHARED_REPOSITORY
+
+    # throws error
+    with mock.patch.object(
+        session,
+        "_run_query",
+        side_effect=ProgrammingError("Function not found"),
+    ):
+        result = session._get_default_artifact_repository()
+        assert result == ANACONDA_SHARED_REPOSITORY