Skip to content

Commit 71f9611

Browse files
sfc-gh-bkogansfc-gh-yuwang
authored andcommitted
SNOW-3052213: fix session package management (#4088)
1 parent cd296a0 commit 71f9611

File tree

10 files changed

+126
-84
lines changed

10 files changed

+126
-84
lines changed

src/snowflake/snowpark/_internal/udf_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,9 +1142,9 @@ def add_snowpark_package_to_sproc_packages(
11421142
packages = [this_package]
11431143
else:
11441144
with session._package_lock:
1145-
existing_packages = session._artifact_repository_packages[
1145+
existing_packages = session._get_packages_by_artifact_repository(
11461146
artifact_repository
1147-
]
1147+
)
11481148
if package_name not in existing_packages:
11491149
packages = list(existing_packages.values()) + [this_package]
11501150
return packages
@@ -1239,7 +1239,9 @@ def resolve_imports_and_packages(
12391239
)
12401240

12411241
existing_packages_dict = (
1242-
session._artifact_repository_packages[artifact_repository] if session else {}
1242+
session._get_packages_by_artifact_repository(artifact_repository)
1243+
if session
1244+
else {}
12431245
)
12441246

12451247
if artifact_repository != _ANACONDA_SHARED_REPOSITORY:
@@ -1248,7 +1250,9 @@ def resolve_imports_and_packages(
12481250
if not packages and session:
12491251
resolved_packages = list(
12501252
session._resolve_packages(
1251-
[], artifact_repository, existing_packages_dict
1253+
[],
1254+
artifact_repository=artifact_repository,
1255+
existing_packages_dict=existing_packages_dict,
12521256
)
12531257
)
12541258
elif packages:
@@ -1286,17 +1290,17 @@ def resolve_imports_and_packages(
12861290
resolved_packages = (
12871291
session._resolve_packages(
12881292
packages,
1289-
artifact_repository,
1290-
{}, # ignore session packages if passed in explicitly
1293+
artifact_repository=artifact_repository,
1294+
existing_packages_dict={}, # ignore session packages if passed in explicitly
12911295
include_pandas=is_pandas_udf,
12921296
statement_params=statement_params,
12931297
_suppress_local_package_warnings=_suppress_local_package_warnings,
12941298
)
12951299
if packages is not None
12961300
else session._resolve_packages(
12971301
[],
1298-
artifact_repository,
1299-
existing_packages_dict,
1302+
artifact_repository=artifact_repository,
1303+
existing_packages_dict=existing_packages_dict,
13001304
validate_package=False,
13011305
include_pandas=is_pandas_udf,
13021306
statement_params=statement_params,

src/snowflake/snowpark/session.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,9 @@ def __init__(
601601
self._conn = conn
602602
self._query_tag = None
603603
self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
604+
# packages under the DEFAULT_ARTIFACT_REPOSITORY
605+
# due to server side accessing private session members, this cannot be merged with _artifact_repository_packages
606+
self._packages: Dict[str, str] = {}
604607
# map of artifact repository name -> packages that should be added to functions under that repository
605608
self._artifact_repository_packages: DefaultDict[
606609
str, Dict[str, str]
@@ -1599,6 +1602,14 @@ def _list_files_in_stage(
15991602
prefix_length = get_stage_file_prefix_length(stage_location)
16001603
return {str(row[0])[prefix_length:] for row in file_list}
16011604

1605+
def _get_packages_by_artifact_repository(
1606+
self, artifact_repository: str
1607+
) -> Dict[str, str]:
1608+
if artifact_repository == _DEFAULT_ARTIFACT_REPOSITORY:
1609+
return self._packages
1610+
else:
1611+
return self._artifact_repository_packages[artifact_repository]
1612+
16021613
def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, str]:
16031614
"""
16041615
Returns a ``dict`` of packages added for user-defined functions (UDFs).
@@ -1613,7 +1624,7 @@ def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, s
16131624
artifact_repository = self._get_default_artifact_repository()
16141625

16151626
with self._package_lock:
1616-
return self._artifact_repository_packages[artifact_repository].copy()
1627+
return self._get_packages_by_artifact_repository(artifact_repository).copy()
16171628

16181629
def add_packages(
16191630
self,
@@ -1686,8 +1697,10 @@ def add_packages(
16861697

16871698
self._resolve_packages(
16881699
parse_positional_args_to_list(*packages),
1689-
artifact_repository,
1690-
self._artifact_repository_packages[artifact_repository],
1700+
artifact_repository=artifact_repository,
1701+
existing_packages_dict=self._get_packages_by_artifact_repository(
1702+
artifact_repository
1703+
),
16911704
)
16921705

16931706
def remove_package(
@@ -1724,7 +1737,7 @@ def remove_package(
17241737
artifact_repository = self._get_default_artifact_repository()
17251738

17261739
with self._package_lock:
1727-
packages = self._artifact_repository_packages[artifact_repository]
1740+
packages = self._get_packages_by_artifact_repository(artifact_repository)
17281741
if package_name in packages:
17291742
packages.pop(package_name)
17301743
else:
@@ -1742,7 +1755,7 @@ def clear_packages(
17421755
artifact_repository = self._get_default_artifact_repository()
17431756

17441757
with self._package_lock:
1745-
self._artifact_repository_packages[artifact_repository].clear()
1758+
self._get_packages_by_artifact_repository(artifact_repository).clear()
17461759

17471760
def add_requirements(
17481761
self,
@@ -2110,11 +2123,11 @@ def _get_req_identifiers_list(
21102123
def _resolve_packages(
21112124
self,
21122125
packages: List[Union[str, ModuleType]],
2113-
artifact_repository: str,
2114-
existing_packages_dict: Dict[str, str],
2126+
existing_packages_dict: Dict[str, str] = None,
21152127
validate_package: bool = True,
21162128
include_pandas: bool = False,
21172129
statement_params: Optional[Dict[str, str]] = None,
2130+
artifact_repository: str = None,
21182131
**kwargs,
21192132
) -> List[str]:
21202133
"""
@@ -2132,6 +2145,13 @@ def _resolve_packages(
21322145
Returns:
21332146
List[str]: List of package specifiers
21342147
"""
2148+
if artifact_repository is None:
2149+
artifact_repository = self._get_default_artifact_repository()
2150+
if existing_packages_dict is None:
2151+
existing_packages_dict = self._get_packages_by_artifact_repository(
2152+
artifact_repository
2153+
)
2154+
21352155
# Always include cloudpickle
21362156
extra_modules = [cloudpickle]
21372157
if include_pandas:
@@ -2404,7 +2424,10 @@ def _get_default_artifact_repository(self) -> str:
24042424
if isinstance(self._conn, MockServerConnection):
24052425
return _DEFAULT_ARTIFACT_REPOSITORY
24062426

2407-
cache_key = (self.get_current_database(), self.get_current_schema())
2427+
account = self.get_current_account()
2428+
database = self.get_current_database()
2429+
schema = self.get_current_schema()
2430+
cache_key = (database, schema)
24082431

24092432
if (
24102433
self._default_artifact_repository_cache is not None
@@ -2414,8 +2437,15 @@ def _get_default_artifact_repository(self) -> str:
24142437

24152438
try:
24162439
python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
2440+
entity_selector_args = (
2441+
f"'schema', '{schema}'"
2442+
if schema
2443+
else f"'database', '{database}'"
2444+
if database
2445+
else f"'account', '{account}'"
2446+
)
24172447
result = self._run_query(
2418-
f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}')"
2448+
f"SELECT SYSTEM$GET_DEFAULT_PYTHON_ARTIFACT_REPOSITORY('{python_version}', {entity_selector_args})"
24192449
)
24202450
value = result[0][0] if result else None
24212451
resolved = value or _DEFAULT_ARTIFACT_REPOSITORY

tests/integ/test_packaging.py

Lines changed: 9 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,8 @@ def extract_major_minor_patch(version_string):
271271

272272
resolved_packages = session._resolve_packages(
273273
[numpy, pandas, dateutil],
274-
_ANACONDA_SHARED_REPOSITORY,
275-
{},
274+
artifact_repository=_ANACONDA_SHARED_REPOSITORY,
275+
existing_packages_dict={},
276276
validate_package=False,
277277
)
278278
# resolved_packages is a list of strings like
@@ -1204,17 +1204,10 @@ def test_replicate_local_environment(session):
12041204
"force_push": True,
12051205
}
12061206

1207-
assert not any(
1208-
[
1209-
package.startswith("cloudpickle")
1210-
for package in session._artifact_repository_packages[
1211-
_ANACONDA_SHARED_REPOSITORY
1212-
]
1213-
]
1214-
)
1207+
assert not any([package.startswith("cloudpickle") for package in session._packages])
12151208

12161209
def naive_add_packages(self, packages):
1217-
self._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY] = packages
1210+
self._packages = packages
12181211

12191212
with patch.object(session, "_is_anaconda_terms_acknowledged", lambda: True):
12201213
with patch.object(Session, "add_packages", new=naive_add_packages):
@@ -1228,22 +1221,10 @@ def naive_add_packages(self, packages):
12281221
},
12291222
)
12301223

1231-
assert any(
1232-
[
1233-
package.startswith("cloudpickle==")
1234-
for package in session._artifact_repository_packages[
1235-
_ANACONDA_SHARED_REPOSITORY
1236-
]
1237-
]
1238-
)
1224+
assert any([package.startswith("cloudpickle==") for package in session._packages])
12391225
for default_package in DEFAULT_PACKAGES:
12401226
assert not any(
1241-
[
1242-
package.startswith(default_package)
1243-
for package in session._artifact_repository_packages[
1244-
_ANACONDA_SHARED_REPOSITORY
1245-
]
1246-
]
1227+
[package.startswith(default_package) for package in session._packages]
12471228
)
12481229

12491230
session.clear_packages()
@@ -1262,29 +1243,12 @@ def naive_add_packages(self, packages):
12621243
ignore_packages=ignored_packages, relax=True
12631244
)
12641245

1265-
assert any(
1266-
[
1267-
package == "cloudpickle"
1268-
for package in session._artifact_repository_packages[
1269-
_ANACONDA_SHARED_REPOSITORY
1270-
]
1271-
]
1272-
)
1246+
assert any([package == "cloudpickle" for package in session._packages])
12731247
for default_package in DEFAULT_PACKAGES:
12741248
assert not any(
1275-
[
1276-
package.startswith(default_package)
1277-
for package in session._artifact_repository_packages[
1278-
_ANACONDA_SHARED_REPOSITORY
1279-
]
1280-
]
1249+
[package.startswith(default_package) for package in session._packages]
12811250
)
12821251
for ignored_package in ignored_packages:
12831252
assert not any(
1284-
[
1285-
package.startswith(ignored_package)
1286-
for package in session._artifact_repository_packages[
1287-
_ANACONDA_SHARED_REPOSITORY
1288-
]
1289-
]
1253+
[package.startswith(ignored_package) for package in session._packages]
12901254
)

tests/integ/test_udf.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3108,18 +3108,17 @@ def test_urllib() -> str:
31083108
)
31093109
def test_use_default_artifact_repository(db_parameters):
31103110
with Session.builder.configs(db_parameters).create() as session:
3111+
temp_database = Utils.random_temp_database()
31113112
temp_schema = Utils.random_temp_schema()
3112-
session.sql(f"create schema {temp_schema}").collect()
3113-
session.sql(f"use schema {temp_schema}").collect()
3113+
session.sql(f"create database {temp_database}").collect()
3114+
session.sql(f"use database {temp_database}").collect()
31143115
session.sql(
31153116
"ALTER SESSION SET ENABLE_DEFAULT_PYTHON_ARTIFACT_REPOSITORY = true"
31163117
).collect()
31173118
session.sql(
3118-
"ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = testdb_snowpark_python.testschema_snowpark_python.SNOWPARK_PYTHON_TEST_REPOSITORY"
3119+
"ALTER database set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = snowflake.snowpark.anaconda_shared_repository"
31193120
).collect()
31203121

3121-
session.add_packages("art", "cloudpickle")
3122-
31233122
def test_art() -> str:
31243123
import art # art is not available in the conda channel, but is in pypi
31253124

@@ -3128,6 +3127,25 @@ def test_art() -> str:
31283127

31293128
temp_func_name = Utils.random_name_for_temp_object(TempObjectType.FUNCTION)
31303129

3130+
# should not work in the database where the default is anaconda
3131+
with pytest.raises(
3132+
Exception,
3133+
match="Cannot add package art because it is not available in Snowflake",
3134+
):
3135+
udf(
3136+
session=session,
3137+
func=test_art,
3138+
name=temp_func_name,
3139+
packages=["art", "cloudpickle"],
3140+
)
3141+
3142+
session.sql(f"create schema {temp_schema}").collect()
3143+
session.use_schema(temp_schema)
3144+
session.sql(
3145+
"ALTER schema set DEFAULT_PYTHON_ARTIFACT_REPOSITORY = testdb_snowpark_python.testschema_snowpark_python.SNOWPARK_PYTHON_TEST_REPOSITORY"
3146+
).collect()
3147+
session.add_packages("art", "cloudpickle")
3148+
31313149
try:
31323150
# Test function registration
31333151
udf(
@@ -3142,4 +3160,4 @@ def test_art() -> str:
31423160
finally:
31433161
session._run_query(f"drop function if exists {temp_func_name}(int)")
31443162

3145-
session.sql(f"drop schema {temp_schema}").collect()
3163+
session.sql(f"drop database {temp_database}").collect()

0 commit comments

Comments
 (0)