Skip to content

Commit c59a539

Browse files
committed
keep the PRIVATE api consistent from before
1 parent 33bfffa commit c59a539

9 files changed

Lines changed: 66 additions & 77 deletions

File tree

src/snowflake/snowpark/_internal/udf_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,9 +1142,9 @@ def add_snowpark_package_to_sproc_packages(
11421142
packages = [this_package]
11431143
else:
11441144
with session._package_lock:
1145-
existing_packages = session._artifact_repository_packages[
1145+
existing_packages = session._get_packages_by_artifact_repository(
11461146
artifact_repository
1147-
]
1147+
)
11481148
if package_name not in existing_packages:
11491149
packages = list(existing_packages.values()) + [this_package]
11501150
return packages
@@ -1239,7 +1239,9 @@ def resolve_imports_and_packages(
12391239
)
12401240

12411241
existing_packages_dict = (
1242-
session._artifact_repository_packages[artifact_repository] if session else {}
1242+
session._get_packages_by_artifact_repository(artifact_repository)
1243+
if session
1244+
else {}
12431245
)
12441246

12451247
if artifact_repository != _ANACONDA_SHARED_REPOSITORY:
@@ -1248,7 +1250,9 @@ def resolve_imports_and_packages(
12481250
if not packages and session:
12491251
resolved_packages = list(
12501252
session._resolve_packages(
1251-
[], artifact_repository, existing_packages_dict
1253+
[],
1254+
artifact_repository=artifact_repository,
1255+
existing_packages_dict=existing_packages_dict,
12521256
)
12531257
)
12541258
elif packages:
@@ -1286,17 +1290,17 @@ def resolve_imports_and_packages(
12861290
resolved_packages = (
12871291
session._resolve_packages(
12881292
packages,
1289-
artifact_repository,
1290-
{}, # ignore session packages if passed in explicitly
1293+
artifact_repository=artifact_repository,
1294+
existing_packages_dict={}, # ignore session packages if passed in explicitly
12911295
include_pandas=is_pandas_udf,
12921296
statement_params=statement_params,
12931297
_suppress_local_package_warnings=_suppress_local_package_warnings,
12941298
)
12951299
if packages is not None
12961300
else session._resolve_packages(
12971301
[],
1298-
artifact_repository,
1299-
existing_packages_dict,
1302+
artifact_repository=artifact_repository,
1303+
existing_packages_dict=existing_packages_dict,
13001304
validate_package=False,
13011305
include_pandas=is_pandas_udf,
13021306
statement_params=statement_params,

src/snowflake/snowpark/session.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,8 @@ def __init__(
601601
self._conn = conn
602602
self._query_tag = None
603603
self._import_paths: Dict[str, Tuple[Optional[str], Optional[str]]] = {}
604-
# unused, needed for test infra?
604+
# packages under the DEFAULT_ARTIFACT_REPOSITORY
605+
# due to server side accessing private session members, this cannot be merged with _artifact_repository_packages
605606
self._packages: Dict[str, str] = {}
606607
# map of artifact repository name -> packages that should be added to functions under that repository
607608
self._artifact_repository_packages: DefaultDict[
@@ -1601,6 +1602,14 @@ def _list_files_in_stage(
16011602
prefix_length = get_stage_file_prefix_length(stage_location)
16021603
return {str(row[0])[prefix_length:] for row in file_list}
16031604

1605+
def _get_packages_by_artifact_repository(
1606+
self, artifact_repository: str
1607+
) -> Dict[str, str]:
1608+
if artifact_repository == _DEFAULT_ARTIFACT_REPOSITORY:
1609+
return self._packages
1610+
else:
1611+
return self._artifact_repository_packages[artifact_repository]
1612+
16041613
def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, str]:
16051614
"""
16061615
Returns a ``dict`` of packages added for user-defined functions (UDFs).
@@ -1615,7 +1624,7 @@ def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, s
16151624
artifact_repository = self._get_default_artifact_repository()
16161625

16171626
with self._package_lock:
1618-
return self._artifact_repository_packages[artifact_repository].copy()
1627+
return self._get_packages_by_artifact_repository(artifact_repository).copy()
16191628

16201629
def add_packages(
16211630
self,
@@ -1688,8 +1697,10 @@ def add_packages(
16881697

16891698
self._resolve_packages(
16901699
parse_positional_args_to_list(*packages),
1691-
artifact_repository,
1692-
self._artifact_repository_packages[artifact_repository],
1700+
artifact_repository=artifact_repository,
1701+
existing_packages_dict=self._get_packages_by_artifact_repository(
1702+
artifact_repository
1703+
),
16931704
)
16941705

16951706
def remove_package(
@@ -1726,7 +1737,7 @@ def remove_package(
17261737
artifact_repository = self._get_default_artifact_repository()
17271738

17281739
with self._package_lock:
1729-
packages = self._artifact_repository_packages[artifact_repository]
1740+
packages = self._get_packages_by_artifact_repository(artifact_repository)
17301741
if package_name in packages:
17311742
packages.pop(package_name)
17321743
else:
@@ -1744,7 +1755,7 @@ def clear_packages(
17441755
artifact_repository = self._get_default_artifact_repository()
17451756

17461757
with self._package_lock:
1747-
self._artifact_repository_packages[artifact_repository].clear()
1758+
self._get_packages_by_artifact_repository(artifact_repository).clear()
17481759

17491760
def add_requirements(
17501761
self,
@@ -2112,11 +2123,11 @@ def _get_req_identifiers_list(
21122123
def _resolve_packages(
21132124
self,
21142125
packages: List[Union[str, ModuleType]],
2115-
artifact_repository: str,
2116-
existing_packages_dict: Dict[str, str],
2126+
existing_packages_dict: Dict[str, str] = None,
21172127
validate_package: bool = True,
21182128
include_pandas: bool = False,
21192129
statement_params: Optional[Dict[str, str]] = None,
2130+
artifact_repository: str = None,
21202131
**kwargs,
21212132
) -> List[str]:
21222133
"""
@@ -2134,6 +2145,13 @@ def _resolve_packages(
21342145
Returns:
21352146
List[str]: List of package specifiers
21362147
"""
2148+
if artifact_repository is None:
2149+
artifact_repository = self._get_default_artifact_repository()
2150+
if existing_packages_dict is None:
2151+
existing_packages_dict = self._get_packages_by_artifact_repository(
2152+
artifact_repository
2153+
)
2154+
21372155
# Always include cloudpickle
21382156
extra_modules = [cloudpickle]
21392157
if include_pandas:

tests/integ/test_packaging.py

Lines changed: 9 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,8 @@ def extract_major_minor_patch(version_string):
271271

272272
resolved_packages = session._resolve_packages(
273273
[numpy, pandas, dateutil],
274-
_ANACONDA_SHARED_REPOSITORY,
275-
{},
274+
artifact_repository=_ANACONDA_SHARED_REPOSITORY,
275+
existing_packages_dict={},
276276
validate_package=False,
277277
)
278278
# resolved_packages is a list of strings like
@@ -1204,17 +1204,10 @@ def test_replicate_local_environment(session):
12041204
"force_push": True,
12051205
}
12061206

1207-
assert not any(
1208-
[
1209-
package.startswith("cloudpickle")
1210-
for package in session._artifact_repository_packages[
1211-
_ANACONDA_SHARED_REPOSITORY
1212-
]
1213-
]
1214-
)
1207+
assert not any([package.startswith("cloudpickle") for package in session._packages])
12151208

12161209
def naive_add_packages(self, packages):
1217-
self._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY] = packages
1210+
self._packages = packages
12181211

12191212
with patch.object(session, "_is_anaconda_terms_acknowledged", lambda: True):
12201213
with patch.object(Session, "add_packages", new=naive_add_packages):
@@ -1228,22 +1221,10 @@ def naive_add_packages(self, packages):
12281221
},
12291222
)
12301223

1231-
assert any(
1232-
[
1233-
package.startswith("cloudpickle==")
1234-
for package in session._artifact_repository_packages[
1235-
_ANACONDA_SHARED_REPOSITORY
1236-
]
1237-
]
1238-
)
1224+
assert any([package.startswith("cloudpickle==") for package in session._packages])
12391225
for default_package in DEFAULT_PACKAGES:
12401226
assert not any(
1241-
[
1242-
package.startswith(default_package)
1243-
for package in session._artifact_repository_packages[
1244-
_ANACONDA_SHARED_REPOSITORY
1245-
]
1246-
]
1227+
[package.startswith(default_package) for package in session._packages]
12471228
)
12481229

12491230
session.clear_packages()
@@ -1262,29 +1243,12 @@ def naive_add_packages(self, packages):
12621243
ignore_packages=ignored_packages, relax=True
12631244
)
12641245

1265-
assert any(
1266-
[
1267-
package == "cloudpickle"
1268-
for package in session._artifact_repository_packages[
1269-
_ANACONDA_SHARED_REPOSITORY
1270-
]
1271-
]
1272-
)
1246+
assert any([package == "cloudpickle" for package in session._packages])
12731247
for default_package in DEFAULT_PACKAGES:
12741248
assert not any(
1275-
[
1276-
package.startswith(default_package)
1277-
for package in session._artifact_repository_packages[
1278-
_ANACONDA_SHARED_REPOSITORY
1279-
]
1280-
]
1249+
[package.startswith(default_package) for package in session._packages]
12811250
)
12821251
for ignored_package in ignored_packages:
12831252
assert not any(
1284-
[
1285-
package.startswith(ignored_package)
1286-
for package in session._artifact_repository_packages[
1287-
_ANACONDA_SHARED_REPOSITORY
1288-
]
1289-
]
1253+
[package.startswith(ignored_package) for package in session._packages]
12901254
)

tests/unit/test_session.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True
213213

214214
session._resolve_packages(
215215
["random_package_name"],
216-
_ANACONDA_SHARED_REPOSITORY,
217-
{},
216+
artifact_repository=_ANACONDA_SHARED_REPOSITORY,
217+
existing_packages_dict={},
218218
validate_package=True,
219219
include_pandas=False,
220220
)
@@ -248,8 +248,8 @@ def run_query(sql: str):
248248
):
249249
session._resolve_packages(
250250
["random_package_name"],
251-
_ANACONDA_SHARED_REPOSITORY,
252-
{},
251+
artifact_repository=_ANACONDA_SHARED_REPOSITORY,
252+
existing_packages_dict={},
253253
validate_package=True,
254254
include_pandas=False,
255255
)
@@ -273,8 +273,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True
273273

274274
resolved_packages = session._resolve_packages(
275275
["random_package_name"],
276-
_ANACONDA_SHARED_REPOSITORY,
277-
existing_packages,
276+
artifact_repository=_ANACONDA_SHARED_REPOSITORY,
277+
existing_packages_dict=existing_packages,
278278
validate_package=True,
279279
include_pandas=False,
280280
)
@@ -305,8 +305,8 @@ def mock_get_information_schema_packages(table_name: str, _emit_ast: bool = True
305305
):
306306
session._resolve_packages(
307307
["snowflake-snowpark-python"],
308-
_ANACONDA_SHARED_REPOSITORY,
309-
{},
308+
artifact_repository=_ANACONDA_SHARED_REPOSITORY,
309+
existing_packages_dict={},
310310
validate_package=True,
311311
include_pandas=False,
312312
_suppress_local_package_warnings=True,
@@ -333,16 +333,16 @@ def assert_packages(packages):
333333

334334
packages = session._resolve_packages(
335335
["snowflake-snowpark-python==1.0.0", "cloudpickle==1.0.0"],
336-
"snowflake.snowpark.pypi_shared_repository",
337-
existing_packages,
336+
artifact_repository="snowflake.snowpark.pypi_shared_repository",
337+
existing_packages_dict=existing_packages,
338338
)
339339

340340
assert_packages(packages)
341341

342342
packages = session._resolve_packages(
343343
[],
344-
"snowflake.snowpark.pypi_shared_repository",
345-
existing_packages,
344+
artifact_repository="snowflake.snowpark.pypi_shared_repository",
345+
existing_packages_dict=existing_packages,
346346
)
347347

348348
assert_packages(packages)

tests/unit/test_stored_procedure.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def test_stored_procedure_execute_as(execute_as):
4444
fake_session._analyzer = Analyzer(fake_session)
4545
fake_session._runtime_version_from_requirement = None
4646
fake_session._artifact_repository_packages = defaultdict(dict)
47+
fake_session._packages = {}
4748

4849
def return1(_):
4950
return 1
@@ -92,6 +93,7 @@ def test_do_register_sp_negative(cleanup_registration_patch):
9293
fake_session._run_query = mock.Mock(side_effect=ProgrammingError())
9394
fake_session.sproc = StoredProcedureRegistration(fake_session)
9495
fake_session._artifact_repository_packages = defaultdict(dict)
96+
fake_session._packages = {}
9597
with pytest.raises(SnowparkSQLException) as ex_info:
9698
sproc(lambda: 1, session=fake_session, return_type=IntegerType(), packages=[])
9799
assert ex_info.value.error_code == "1304"

tests/unit/test_udaf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def test_do_register_udaf_negative(cleanup_registration_patch):
5959
fake_session._runtime_version_from_requirement = None
6060
fake_session.udaf = UDAFRegistration(fake_session)
6161
fake_session._artifact_repository_packages = defaultdict(dict)
62+
fake_session._packages = {}
6263
with pytest.raises(SnowparkSQLException) as ex_info:
6364

6465
@udaf(session=fake_session)

tests/unit/test_udf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def test_do_register_sp_negative(cleanup_registration_patch):
3434
fake_session._run_query = mock.Mock(side_effect=ProgrammingError())
3535
fake_session.udf = UDFRegistration(fake_session)
3636
fake_session._artifact_repository_packages = defaultdict(dict)
37+
fake_session._packages = {}
3738
with pytest.raises(SnowparkSQLException) as ex_info:
3839
udf(lambda: 1, session=fake_session, return_type=IntegerType(), packages=[])
3940
assert ex_info.value.error_code == "1304"

tests/unit/test_udf_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
33
#
44

5-
from collections import defaultdict
65
import logging
76
import os
87
import pickle
@@ -254,8 +253,7 @@ def test_add_snowpark_package_to_sproc_packages_does_not_replace_package():
254253

255254
def test_add_snowpark_package_to_sproc_packages_to_session():
256255
fake_session = mock.create_autospec(Session)
257-
fake_session._artifact_repository_packages = defaultdict(dict)
258-
fake_session._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY] = {
256+
fake_session._packages = {
259257
"random_package_one": "random_package_one",
260258
"random_package_two": "random_package_two",
261259
}
@@ -272,7 +270,7 @@ def test_add_snowpark_package_to_sproc_packages_to_session():
272270
assert len(result) == 3
273271
assert final_name in result
274272

275-
fake_session._artifact_repository_packages[_ANACONDA_SHARED_REPOSITORY][
273+
fake_session._packages[
276274
"snowflake-snowpark-python"
277275
] = "snowflake-snowpark-python==1.12.0"
278276
result = add_snowpark_package_to_sproc_packages(

tests/unit/test_udtf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def test_do_register_sp_negative(cleanup_registration_patch):
4242
fake_session._runtime_version_from_requirement = None
4343
fake_session.udtf = UDTFRegistration(fake_session)
4444
fake_session._artifact_repository_packages = defaultdict(dict)
45+
fake_session._packages = {}
4546
with pytest.raises(SnowparkSQLException) as ex_info:
4647

4748
@udtf(output_schema=["num"], session=fake_session)

0 commit comments

Comments
 (0)