@@ -601,7 +601,8 @@ def __init__(
601601 self ._conn = conn
602602 self ._query_tag = None
603603 self ._import_paths : Dict [str , Tuple [Optional [str ], Optional [str ]]] = {}
604- # unused, needed for test infra?
604+ # packages under the DEFAULT_ARTIFACT_REPOSITORY
605+ # due to server side accessing private session members, this cannot be merged with _artifact_repository_packages
605606 self ._packages : Dict [str , str ] = {}
606607 # map of artifact repository name -> packages that should be added to functions under that repository
607608 self ._artifact_repository_packages : DefaultDict [
@@ -1601,6 +1602,14 @@ def _list_files_in_stage(
16011602 prefix_length = get_stage_file_prefix_length (stage_location )
16021603 return {str (row [0 ])[prefix_length :] for row in file_list }
16031604
1605+ def _get_packages_by_artifact_repository (
1606+ self , artifact_repository : str
1607+ ) -> Dict [str , str ]:
1608+ if artifact_repository == _DEFAULT_ARTIFACT_REPOSITORY :
1609+ return self ._packages
1610+ else :
1611+ return self ._artifact_repository_packages [artifact_repository ]
1612+
16041613 def get_packages (self , artifact_repository : Optional [str ] = None ) -> Dict [str , str ]:
16051614 """
16061615 Returns a ``dict`` of packages added for user-defined functions (UDFs).
@@ -1615,7 +1624,7 @@ def get_packages(self, artifact_repository: Optional[str] = None) -> Dict[str, s
16151624 artifact_repository = self ._get_default_artifact_repository ()
16161625
16171626 with self ._package_lock :
1618- return self ._artifact_repository_packages [ artifact_repository ] .copy ()
1627+ return self ._get_packages_by_artifact_repository ( artifact_repository ) .copy ()
16191628
16201629 def add_packages (
16211630 self ,
@@ -1688,8 +1697,10 @@ def add_packages(
16881697
16891698 self ._resolve_packages (
16901699 parse_positional_args_to_list (* packages ),
1691- artifact_repository ,
1692- self ._artifact_repository_packages [artifact_repository ],
1700+ artifact_repository = artifact_repository ,
1701+ existing_packages_dict = self ._get_packages_by_artifact_repository (
1702+ artifact_repository
1703+ ),
16931704 )
16941705
16951706 def remove_package (
@@ -1726,7 +1737,7 @@ def remove_package(
17261737 artifact_repository = self ._get_default_artifact_repository ()
17271738
17281739 with self ._package_lock :
1729- packages = self ._artifact_repository_packages [ artifact_repository ]
1740+ packages = self ._get_packages_by_artifact_repository ( artifact_repository )
17301741 if package_name in packages :
17311742 packages .pop (package_name )
17321743 else :
@@ -1744,7 +1755,7 @@ def clear_packages(
17441755 artifact_repository = self ._get_default_artifact_repository ()
17451756
17461757 with self ._package_lock :
1747- self ._artifact_repository_packages [ artifact_repository ] .clear ()
1758+ self ._get_packages_by_artifact_repository ( artifact_repository ) .clear ()
17481759
17491760 def add_requirements (
17501761 self ,
@@ -2112,11 +2123,11 @@ def _get_req_identifiers_list(
21122123 def _resolve_packages (
21132124 self ,
21142125 packages : List [Union [str , ModuleType ]],
2115- artifact_repository : str ,
2116- existing_packages_dict : Dict [str , str ],
2126+ existing_packages_dict : Dict [str , str ] = None ,
21172127 validate_package : bool = True ,
21182128 include_pandas : bool = False ,
21192129 statement_params : Optional [Dict [str , str ]] = None ,
2130+ artifact_repository : str = None ,
21202131 ** kwargs ,
21212132 ) -> List [str ]:
21222133 """
@@ -2134,6 +2145,13 @@ def _resolve_packages(
21342145 Returns:
21352146 List[str]: List of package specifiers
21362147 """
2148+ if artifact_repository is None :
2149+ artifact_repository = self ._get_default_artifact_repository ()
2150+ if existing_packages_dict is None :
2151+ existing_packages_dict = self ._get_packages_by_artifact_repository (
2152+ artifact_repository
2153+ )
2154+
21372155 # Always include cloudpickle
21382156 extra_modules = [cloudpickle ]
21392157 if include_pandas :
0 commit comments