Skip to content

Commit d4280f4

Browse files
SNOW-3052213: Add support for the DEFAULT_PYTHON_ARTIFACT_REPOSITORY parameter (#4073)
1 parent ab49a74 commit d4280f4

File tree

14 files changed

+359
-66
lines changed

14 files changed

+359
-66
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#### New Features
2121

2222
- Added support for the `DECFLOAT` data type that allows users to represent decimal numbers exactly with 38 digits of precision and a dynamic base-10 exponent.
23+
- Added support for the `DEFAULT_PYTHON_ARTIFACT_REPOSITORY` parameter that allows users to configure the default artifact repository at the account, database, and schema level.
2324

2425
#### Bug Fixes
2526

src/snowflake/snowpark/_internal/udf_utils.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@
5959
)
6060
from snowflake.snowpark.types import DataType, StructField, StructType
6161
from snowflake.snowpark.version import VERSION
62+
from snowflake.snowpark.context import (
63+
_ANACONDA_SHARED_REPOSITORY,
64+
_DEFAULT_ARTIFACT_REPOSITORY,
65+
)
6266

6367
if installed_pandas:
6468
from snowflake.snowpark.types import (
@@ -1122,6 +1126,7 @@ def {_DEFAULT_HANDLER_NAME}({wrapper_params}):
11221126
def add_snowpark_package_to_sproc_packages(
11231127
session: Optional["snowflake.snowpark.Session"],
11241128
packages: Optional[List[Union[str, ModuleType]]],
1129+
artifact_repository: str,
11251130
) -> List[Union[str, ModuleType]]:
11261131
major, minor, patch = VERSION
11271132
package_name = "snowflake-snowpark-python"
@@ -1137,8 +1142,11 @@ def add_snowpark_package_to_sproc_packages(
11371142
packages = [this_package]
11381143
else:
11391144
with session._package_lock:
1140-
if package_name not in session._packages:
1141-
packages = list(session._packages.values()) + [this_package]
1145+
existing_packages = session._artifact_repository_packages[
1146+
artifact_repository
1147+
]
1148+
if package_name not in existing_packages:
1149+
packages = list(existing_packages.values()) + [this_package]
11421150
return packages
11431151

11441152
return add_package_to_existing_packages(packages, package_name, this_package)
@@ -1223,17 +1231,30 @@ def resolve_imports_and_packages(
12231231
Optional[str],
12241232
bool,
12251233
]:
1226-
if artifact_repository and artifact_repository != "conda":
1227-
# Artifact Repository packages are not resolved
1234+
if artifact_repository is None:
1235+
artifact_repository = (
1236+
session._get_default_artifact_repository()
1237+
if session
1238+
else _DEFAULT_ARTIFACT_REPOSITORY
1239+
)
1240+
1241+
existing_packages_dict = (
1242+
session._artifact_repository_packages[artifact_repository] if session else {}
1243+
)
1244+
1245+
if artifact_repository != _ANACONDA_SHARED_REPOSITORY:
1246+
# Non-conda artifact repository - skip conda-based package resolution
12281247
resolved_packages = []
12291248
if not packages and session:
12301249
resolved_packages = list(
1231-
session._resolve_packages([], artifact_repository=artifact_repository)
1250+
session._resolve_packages(
1251+
[], artifact_repository, existing_packages_dict
1252+
)
12321253
)
12331254
elif packages:
12341255
if not all(isinstance(package, str) for package in packages):
12351256
raise TypeError(
1236-
"Artifact repository requires that all packages be passed as str."
1257+
"Non-conda artifact repository requires that all packages be passed as str."
12371258
)
12381259
try:
12391260
has_cloudpickle = bool(
@@ -1256,7 +1277,7 @@ def resolve_imports_and_packages(
12561277
)
12571278

12581279
else:
1259-
# resolve packages
1280+
# resolve packages using conda channel
12601281
if session is None: # In case of sandbox
12611282
resolved_packages = resolve_packages_in_client_side_sandbox(
12621283
packages=packages
@@ -1265,14 +1286,17 @@ def resolve_imports_and_packages(
12651286
resolved_packages = (
12661287
session._resolve_packages(
12671288
packages,
1289+
artifact_repository,
1290+
{}, # ignore session packages if passed in explicitly
12681291
include_pandas=is_pandas_udf,
12691292
statement_params=statement_params,
12701293
_suppress_local_package_warnings=_suppress_local_package_warnings,
12711294
)
12721295
if packages is not None
12731296
else session._resolve_packages(
12741297
[],
1275-
session._packages,
1298+
artifact_repository,
1299+
existing_packages_dict,
12761300
validate_package=False,
12771301
include_pandas=is_pandas_udf,
12781302
statement_params=statement_params,

src/snowflake/snowpark/context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@
4848
# example: _integral_type_default_precision = {IntegerType: 9}, IntegerType default _precision is 9 now
4949
_integral_type_default_precision = {}
5050

51+
# The fully qualified name of the Anaconda shared repository (conda channel).
52+
_ANACONDA_SHARED_REPOSITORY = "snowflake.snowpark.anaconda_shared_repository"
53+
# If the lookup fails or no default artifact repository is set, we fall back to this value.
54+
_DEFAULT_ARTIFACT_REPOSITORY = _ANACONDA_SHARED_REPOSITORY
55+
5156

5257
def configure_development_features(
5358
*,

src/snowflake/snowpark/session.py

Lines changed: 85 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@
156156
from snowflake.snowpark.context import (
157157
_is_execution_environment_sandboxed_for_client,
158158
_use_scoped_temp_objects,
159+
_ANACONDA_SHARED_REPOSITORY,
160+
_DEFAULT_ARTIFACT_REPOSITORY,
159161
)
160162
from snowflake.snowpark.dataframe import DataFrame
161163
from snowflake.snowpark.dataframe_reader import DataFrameReader
@@ -599,10 +601,17 @@ def __init__(
599601
self._conn = conn
600602
self._query_tag = None
601603
self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
602-
self._packages: Dict[str, str] = {}
604+
# map of artifact repository name -> packages that should be added to functions under that repository
603605
self._artifact_repository_packages: DefaultDict[
604606
str, Dict[str, str]
605607
] = defaultdict(dict)
608+
# Single-entry cache for the default artifact repository value.
609+
# Stores a tuple of ((database, schema), cached_value). Only one entry is
610+
# kept at a time – switching to a different database/schema will evict the old
611+
# value and trigger a fresh query on the next call.
612+
self._default_artifact_repository_cache: Optional[
613+
Tuple[Tuple[Optional[str], Optional[str]], str]
614+
] = None
606615
self._session_id = self._conn.get_session_id()
607616
self._session_info = f"""
608617
"version" : {get_version()},
@@ -1598,11 +1607,13 @@ def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, s
15981607
15991608
Args:
16001609
artifact_repository: When set, this function returns the packages for that specific artifact repository.
1610+
Otherwise, uses the default artifact repository configured in the current context.
16011611
"""
1612+
if artifact_repository is None:
1613+
artifact_repository = self._get_default_artifact_repository()
1614+
16021615
with self._package_lock:
1603-
if artifact_repository:
1604-
return self._artifact_repository_packages[artifact_repository].copy()
1605-
return self._packages.copy()
1616+
return self._artifact_repository_packages[artifact_repository].copy()
16061617

16071618
def add_packages(
16081619
self,
@@ -1629,7 +1640,8 @@ def add_packages(
16291640
for this argument. If a ``module`` object is provided, the package will be
16301641
installed with the version in the local environment.
16311642
artifact_repository: When set this parameter specifies the artifact repository that packages will be added from. Only functions
1632-
using that repository will use the packages. (Default None)
1643+
using that repository will use the packages. (Default None) When omitted, uses the default artifact repository configured in the
1644+
current context.
16331645
16341646
Example::
16351647
@@ -1669,10 +1681,13 @@ def add_packages(
16691681
to ensure the consistent experience of a UDF between your local environment
16701682
and the Snowflake server.
16711683
"""
1684+
if artifact_repository is None:
1685+
artifact_repository = self._get_default_artifact_repository()
1686+
16721687
self._resolve_packages(
16731688
parse_positional_args_to_list(*packages),
1674-
self._packages,
1675-
artifact_repository=artifact_repository,
1689+
artifact_repository,
1690+
self._artifact_repository_packages[artifact_repository],
16761691
)
16771692

16781693
def remove_package(
@@ -1686,7 +1701,8 @@ def remove_package(
16861701
Args:
16871702
package: The package name.
16881703
artifact_repository: When set this parameter specifies that the package should be removed
1689-
from the default packages for a specific artifact repository.
1704+
from the default packages for a specific artifact repository. Otherwise, uses the default
1705+
artifact repository configured in the current context.
16901706
16911707
Examples::
16921708
@@ -1704,17 +1720,13 @@ def remove_package(
17041720
0
17051721
"""
17061722
package_name = Requirement(package).name
1723+
if artifact_repository is None:
1724+
artifact_repository = self._get_default_artifact_repository()
1725+
17071726
with self._package_lock:
1708-
if (
1709-
artifact_repository is not None
1710-
and package_name
1711-
in self._artifact_repository_packages.get(artifact_repository, {})
1712-
):
1713-
self._artifact_repository_packages[artifact_repository].pop(
1714-
package_name
1715-
)
1716-
elif package_name in self._packages:
1717-
self._packages.pop(package_name)
1727+
packages = self._artifact_repository_packages[artifact_repository]
1728+
if package_name in packages:
1729+
packages.pop(package_name)
17181730
else:
17191731
raise ValueError(f"{package_name} is not in the package list")
17201732

@@ -1726,11 +1738,11 @@ def clear_packages(
17261738
Clears all third-party packages of a user-defined function (UDF). When artifact_repository
17271739
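is set, packages are only cleared from the specified repository. Otherwise, they are cleared from the default artifact repository configured in the current context.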
17281740
"""
1741+
if artifact_repository is None:
1742+
artifact_repository = self._get_default_artifact_repository()
1743+
17291744
with self._package_lock:
1730-
if artifact_repository is not None:
1731-
self._artifact_repository_packages.get(artifact_repository, {}).clear()
1732-
else:
1733-
self._packages.clear()
1745+
self._artifact_repository_packages[artifact_repository].clear()
17341746

17351747
def add_requirements(
17361748
self,
@@ -1747,7 +1759,8 @@ def add_requirements(
17471759
Args:
17481760
file_path: The path of a local requirement file.
17491761
artifact_repository: When set this parameter specifies the artifact repository that packages will be added from. Only functions
1750-
using that repository will use the packages. (Default None)
1762+
using that repository will use the packages. (Default None) When omitted, uses the default artifact repository configured in
1763+
the current context.
17511764
17521765
Example::
17531766
@@ -2097,11 +2110,11 @@ def _get_req_identifiers_list(
20972110
def _resolve_packages(
20982111
self,
20992112
packages: List[Union[str, ModuleType]],
2100-
existing_packages_dict: Optional[Dict[str, str]] = None,
2113+
artifact_repository: str,
2114+
existing_packages_dict: Dict[str, str],
21012115
validate_package: bool = True,
21022116
include_pandas: bool = False,
21032117
statement_params: Optional[Dict[str, str]] = None,
2104-
artifact_repository: Optional[str] = None,
21052118
**kwargs,
21062119
) -> List[str]:
21072120
"""
@@ -2128,18 +2141,12 @@ def _resolve_packages(
21282141
package_dict = self._parse_packages(packages)
21292142
if (
21302143
isinstance(self._conn, MockServerConnection)
2131-
or artifact_repository is not None
2144+
or artifact_repository != _ANACONDA_SHARED_REPOSITORY
21322145
):
2133-
# in local testing we don't resolve the packages, we just return what is added
2146+
# in local testing or non-conda, we don't resolve the packages, we just return what is added
21342147
errors = []
21352148
with self._package_lock:
2136-
if artifact_repository is None:
2137-
result_dict = self._packages
2138-
else:
2139-
result_dict = self._artifact_repository_packages[
2140-
artifact_repository
2141-
]
2142-
2149+
result_dict = existing_packages_dict
21432150
for pkg_name, _, pkg_req in package_dict.values():
21442151
if (
21452152
pkg_name in result_dict
@@ -2377,6 +2384,50 @@ def _upload_unsupported_packages(
23772384

23782385
return supported_dependencies + new_dependencies
23792386

2387+
def _get_default_artifact_repository(self) -> str:
    """
    Look up the default Python artifact repository for the current context.

    The value comes from ``SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY``,
    queried with the local interpreter's ``major.minor`` version. The most
    recent answer is remembered together with the (database, schema) pair it
    was obtained under: a call made under the same pair returns the
    remembered value without querying, while a call under a different pair
    queries again and overwrites the single remembered entry.

    ``_DEFAULT_ARTIFACT_REPOSITORY`` is returned instead when:
        - the connection is a ``MockServerConnection`` (local testing); in
          this case nothing is queried or cached, or
        - the query raises; a warning is logged, or
        - the query yields no rows or a falsy value such as NULL.

    In the last two cases the fallback itself is what gets cached for the
    current (database, schema) pair.

    Returns:
        The artifact repository name to use when the caller did not pass one.
    """
    with self._package_lock:
        # Mock connections cannot run the system function; answer immediately.
        if isinstance(self._conn, MockServerConnection):
            return _DEFAULT_ARTIFACT_REPOSITORY

        context_key = (self.get_current_database(), self.get_current_schema())

        # Serve from the single-entry cache when the context has not changed.
        cached = self._default_artifact_repository_cache
        if cached is not None and cached[0] == context_key:
            return cached[1]

        try:
            python_version = "{}.{}".format(
                sys.version_info.major, sys.version_info.minor
            )
            rows = self._run_query(
                f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')"
            )
            first_cell = rows[0][0] if rows else None
            # An empty result or a NULL/empty value means "never configured".
            if first_cell:
                repository = first_cell
            else:
                repository = _DEFAULT_ARTIFACT_REPOSITORY
        except Exception as e:
            # Best effort: any failure degrades to the built-in default.
            _logger.warning(
                f"Error getting default artifact repository: {e}. Using fallback: {_DEFAULT_ARTIFACT_REPOSITORY}."
            )
            repository = _DEFAULT_ARTIFACT_REPOSITORY

        # Replace whatever entry was held before with the one for this context.
        self._default_artifact_repository_cache = (context_key, repository)
        return repository
2430+
23802431
def _is_anaconda_terms_acknowledged(self) -> bool:
    """Query ``system$are_anaconda_terms_acknowledged()`` and return the first column of the first row."""
    rows = self._run_query("select system$are_anaconda_terms_acknowledged()")
    first_row = rows[0]
    return first_row[0]
23822433

src/snowflake/snowpark/stored_procedure.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -939,10 +939,21 @@ def _do_register_sp(
939939
UDFColumn(dt, arg_name) for dt, arg_name in zip(input_types, arg_names[1:])
940940
]
941941

942+
effective_artifact_repository = artifact_repository
943+
if effective_artifact_repository is None:
944+
from snowflake.snowpark.session import _DEFAULT_ARTIFACT_REPOSITORY
945+
946+
effective_artifact_repository = (
947+
self._session._get_default_artifact_repository()
948+
if self._session
949+
else _DEFAULT_ARTIFACT_REPOSITORY
950+
)
951+
942952
# Add in snowflake-snowpark-python if it is not already in the package list.
943953
packages = add_snowpark_package_to_sproc_packages(
944954
session=self._session,
945955
packages=packages,
956+
artifact_repository=effective_artifact_repository,
946957
)
947958

948959
(
@@ -967,7 +978,7 @@ def _do_register_sp(
967978
skip_upload_on_content_match=skip_upload_on_content_match,
968979
is_permanent=is_permanent,
969980
force_inline_code=force_inline_code,
970-
artifact_repository=artifact_repository,
981+
artifact_repository=effective_artifact_repository,
971982
_suppress_local_package_warnings=kwargs.get(
972983
"_suppress_local_package_warnings", False
973984
),

0 commit comments

Comments
 (0)