From 555e017292669d4c5aaae9b180bc3751e083ea19 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 7 Feb 2023 09:21:43 +0100 Subject: [PATCH 1/6] First draft to update tooling --- .gitignore | 3 + Makefile | 5 +- .../test_sentence_transformers.py | 11 +- contrib/utils.py | 5 +- pyproject.toml | 20 +- setup.cfg | 47 +- setup.py | 21 +- src/huggingface_hub/__init__.py | 8 +- src/huggingface_hub/_commit_api.py | 74 +-- src/huggingface_hub/_login.py | 17 +- src/huggingface_hub/_snapshot_download.py | 16 +- src/huggingface_hub/commands/delete_cache.py | 50 +- src/huggingface_hub/commands/env.py | 4 +- .../commands/huggingface_cli.py | 4 +- src/huggingface_hub/commands/lfs.py | 12 +- src/huggingface_hub/commands/scan_cache.py | 21 +- src/huggingface_hub/commands/user.py | 62 +-- src/huggingface_hub/constants.py | 20 +- src/huggingface_hub/fastai_utils.py | 26 +- src/huggingface_hub/file_download.py | 127 ++--- src/huggingface_hub/hf_api.py | 204 ++------ src/huggingface_hub/inference_api.py | 11 +- src/huggingface_hub/keras_mixin.py | 24 +- src/huggingface_hub/lfs.py | 32 +- src/huggingface_hub/repocard.py | 40 +- src/huggingface_hub/repocard_data.py | 26 +- src/huggingface_hub/repository.py | 145 ++---- src/huggingface_hub/utils/_cache_assets.py | 5 +- src/huggingface_hub/utils/_cache_manager.py | 56 +-- src/huggingface_hub/utils/_deprecation.py | 20 +- src/huggingface_hub/utils/_errors.py | 42 +- src/huggingface_hub/utils/_fixes.py | 8 +- src/huggingface_hub/utils/_git_credential.py | 21 +- src/huggingface_hub/utils/_http.py | 5 +- src/huggingface_hub/utils/_paths.py | 13 +- src/huggingface_hub/utils/_validators.py | 16 +- src/huggingface_hub/utils/endpoint_helpers.py | 4 +- src/huggingface_hub/utils/logging.py | 18 +- tests/conftest.py | 11 +- tests/test_cache_layout.py | 100 +--- tests/test_cache_no_symlinks.py | 24 +- tests/test_cli.py | 4 +- tests/test_command_delete_cache.py | 51 +- tests/test_commit_api.py | 41 +- tests/test_endpoint_helpers.py | 14 +- tests/test_fastai_integration.py | 12 +- tests/test_file_download.py | 89 +--- tests/test_hf_api.py | 459 +++++------------- tests/test_hubmixin.py | 24 +- tests/test_inference_api.py | 13 +- tests/test_init_lazy_loading.py | 6 +- tests/test_keras_integration.py | 80 +-- tests/test_lfs.py | 20 +- tests/test_login_utils.py | 8 +- tests/test_offline_utils.py | 2 +- tests/test_repocard.py | 58 +-- tests/test_repocard_data.py | 10 +- tests/test_repository.py | 34 +- tests/test_snapshot_download.py | 13 +- tests/test_utils_assets.py | 13 +- tests/test_utils_cache.py | 62 +-- tests/test_utils_cli.py | 12 +- tests/test_utils_datetime.py | 4 +- tests/test_utils_deprecation.py | 31 +- tests/test_utils_errors.py | 81 +--- tests/test_utils_fixes.py | 4 +- tests/test_utils_git_credentials.py | 12 +- tests/test_utils_headers.py | 24 +- tests/test_utils_http.py | 11 +- tests/test_utils_pagination.py | 4 +- tests/test_utils_validators.py | 20 +- tests/testing_constants.py | 4 +- tests/testing_utils.py | 32 +- utils/check_contrib_list.py | 12 +- utils/check_static_imports.py | 16 +- utils/push_repocard_examples.py | 19 +- 76 files changed, 725 insertions(+), 1952 deletions(-) diff --git a/.gitignore b/.gitignore index 4b6a96a102..73b196a5b8 100644 --- a/.gitignore +++ b/.gitignore @@ -132,3 +132,6 @@ dmypy.json .idea/ .DS_Store + +# Ruff +.ruff_cache \ No newline at end of file diff --git a/Makefile b/Makefile index 8c8b2c6741..70a8c5748a 100644 --- a/Makefile +++ b/Makefile @@ -6,15 +6,14 @@ check_dirs := contrib src tests utils setup.py quality: 
 	black --check $(check_dirs)
-	isort --check-only $(check_dirs)
-	flake8 $(check_dirs)
+	ruff $(check_dirs)
 	mypy src
 	python utils/check_contrib_list.py
 	python utils/check_static_imports.py
 
 style:
 	black $(check_dirs)
-	isort $(check_dirs)
+	ruff $(check_dirs) --fix
 	python utils/check_contrib_list.py --update
 	python utils/check_static_imports.py --update
diff --git a/contrib/sentence_transformers/test_sentence_transformers.py b/contrib/sentence_transformers/test_sentence_transformers.py
index 62d1593b5f..96217c477e 100644
--- a/contrib/sentence_transformers/test_sentence_transformers.py
+++ b/contrib/sentence_transformers/test_sentence_transformers.py
@@ -1,5 +1,4 @@
 import pytest
-
 from sentence_transformers import SentenceTransformer, util
 
 from ..utils import production_endpoint
@@ -23,12 +22,6 @@ def test_from_pretrained(multi_qa_model: SentenceTransformer) -> None:
     print("Similarity:", util.dot_score(query_embedding, passage_embedding))
 
 
-@pytest.mark.xfail(
-    reason=(
-        "Production endpoint is hardcoded in sentence_transformers when pushing to Hub."
-    )
-)
-def test_push_to_hub(
-    multi_qa_model: SentenceTransformer, repo_name: str, cleanup_repo: None
-) -> None:
+@pytest.mark.xfail(reason="Production endpoint is hardcoded in sentence_transformers when pushing to Hub.")
+def test_push_to_hub(multi_qa_model: SentenceTransformer, repo_name: str, cleanup_repo: None) -> None:
     multi_qa_model.save_to_hub(repo_name)
diff --git a/contrib/utils.py b/contrib/utils.py
index 396c87e206..e1681cd561 100644
--- a/contrib/utils.py
+++ b/contrib/utils.py
@@ -41,10 +41,7 @@ def test_push_to_hub():
 
     patchers = (
         [patch(target + ".ENDPOINT", PROD_ENDPOINT) for target in ENDPOINT_TARGETS]
-        + [
-            patch(target + ".HUGGINGFACE_CO_URL_TEMPLATE", PROD_URL_TEMPLATE)
-            for target in URL_TEMPLATE_TARGETS
-        ]
+        + [patch(target + ".HUGGINGFACE_CO_URL_TEMPLATE", PROD_URL_TEMPLATE) for target in URL_TEMPLATE_TARGETS]
        + [patch.object(api, "endpoint", PROD_URL_TEMPLATE)]
     )
 
diff --git a/pyproject.toml b/pyproject.toml
index b453ad5999..601b77ebda 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,9 +1,19 @@
 [tool.black]
-line-length = 88
+line-length = 119
 target_version = ['py37', 'py38', 'py39', 'py310']
 preview = true
 
-[tool.mypy]
-ignore_missing_imports = true
-no_implicit_optional = true
-scripts_are_modules = true
\ No newline at end of file
+[tool.ruff]
+# Ignored rules:
+# "E501" -> line length violation
+# "F821" -> undefined name in type annotation (e.g. 
Literal["something"])
+ignore = ["E501", "F821"]
+select = ["E", "F", "I", "W"]
+line-length = 119
+
+[tool.ruff.isort]
+lines-after-imports = 2
+known-first-party = ["huggingface_hub"]
+
+[tool.ruff.per-file-ignores]
+"src/huggingface_hub/__init__.py" = ["I001"] # Imports are autogenerated
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
index ee04492292..e3f091e5e0 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -4,53 +4,18 @@ ensure_newline_before_comments = True
 force_grid_wrap = 0
 include_trailing_comma = True
 known_first_party = huggingface_hub
-known_third_party =
-    absl
-    conllu
-    datasets
-    elasticsearch
-    fairseq
-    faiss-cpu
-    fastprogress
-    fire
-    fugashi
-    git
-    graphviz
-    h5py
-    matplotlib
-    nltk
-    numpy
-    packaging
-    pandas
-    pydot
-    PIL
-    psutil
-    pytest
-    pytorch_lightning
-    rouge_score
-    sacrebleu
-    seqeval
-    sklearn
-    streamlit
-    tensorboardX
-    tensorflow
-    tensorflow_datasets
-    timeout_decorator
-    torch
-    torchtext
-    torchvision
-    torch_xla
-    tqdm
-
-line_length = 88
+line_length = 119
 lines_after_imports = 2
 multi_line_output = 3
 use_parentheses = True
 
 [flake8]
 exclude = .git,__pycache__,old,build,dist,.venv*
-ignore = B028, E203, E501, E741, W503
-max-line-length = 88
+# ignore = B028, E203, E501, E741, W503
+ignore = E501, E741, F821, W605
+# select = ["E", "F", "I", "W"]
+max-line-length = 119
 
 [tool:pytest]
 # -Werror::FutureWarning -> test fails if FutureWarning is thrown
diff --git a/setup.py b/setup.py
index 69155f7084..3cc9510557 100644
--- a/setup.py
+++ b/setup.py
@@ -64,9 +64,8 @@ def get_version() -> str:
 ]
 
 extras["quality"] = [
-    "black==22.3",
-    "flake8>=3.8.3",
-    "flake8-bugbear",
+    "black~=23.1",
+    "ruff>=0.0.241",
     "isort>=5.5.4",
     "mypy==0.982",
 ]
@@ -81,26 +80,16 @@ def get_version() -> str:
     version=get_version(),
     author="Hugging Face, Inc.",
     author_email="julien@huggingface.co",
-    description=(
-        "Client library to download and publish models, datasets and other repos on the"
-        " huggingface.co hub"
-    ),
+    description="Client library to download and publish models, datasets and other repos on the huggingface.co hub",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
-    keywords=(
-        "model-hub machine-learning models natural-language-processing deep-learning"
-        " pytorch pretrained-models"
-    ),
+    keywords="model-hub machine-learning models natural-language-processing deep-learning pytorch pretrained-models",
     license="Apache",
     url="https://github.com/huggingface/huggingface_hub",
     package_dir={"": "src"},
     packages=find_packages("src"),
     extras_require=extras,
-    entry_points={
-        "console_scripts": [
-            "huggingface-cli=huggingface_hub.commands.huggingface_cli:main"
-        ]
-    },
+    entry_points={"console_scripts": ["huggingface-cli=huggingface_hub.commands.huggingface_cli:main"]},
     python_requires=">=3.7.0",
     install_requires=install_requires,
     classifiers=[
diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py
index 009226019c..3f890aa716 100644
--- a/src/huggingface_hub/__init__.py
+++ b/src/huggingface_hub/__init__.py
@@ -258,9 +258,7 @@ def _attach(package_name, submodules=None, submod_attrs=None):
     else:
         submodules = set(submodules)
 
-    attr_to_modules = {
-        attr: mod for mod, attrs in submod_attrs.items() for attr in attrs
-    }
+    attr_to_modules = {attr: mod for mod, attrs in submod_attrs.items() for attr in attrs}
 
     __all__ = list(submodules | attr_to_modules.keys())
 
@@ -293,9 +291,7 @@ def __dir__():
     return 
__getattr__, __dir__, list(__all__) -__getattr__, __dir__, __all__ = _attach( - __name__, submodules=[], submod_attrs=_SUBMOD_ATTRS -) +__getattr__, __dir__, __all__ = _attach(__name__, submodules=[], submod_attrs=_SUBMOD_ATTRS) # WARNING: any content below this statement is generated automatically. Any manual edit # will be lost when re-generating this file ! diff --git a/src/huggingface_hub/_commit_api.py b/src/huggingface_hub/_commit_api.py index fd62ca34b0..272a010ba9 100644 --- a/src/huggingface_hub/_commit_api.py +++ b/src/huggingface_hub/_commit_api.py @@ -11,15 +11,20 @@ from pathlib import Path, PurePosixPath from typing import Any, BinaryIO, Dict, Iterable, Iterator, List, Optional, Union -from tqdm.contrib.concurrent import thread_map - import requests +from tqdm.contrib.concurrent import thread_map from .constants import ENDPOINT from .lfs import UploadInfo, _validate_batch_actions, lfs_upload, post_lfs_batch_info -from .utils import build_hf_headers, chunk_iterable, hf_raise_for_status, logging +from .utils import ( + build_hf_headers, + chunk_iterable, + hf_raise_for_status, + logging, + tqdm_stream_file, + validate_hf_hub_args, +) from .utils import tqdm as hf_tqdm -from .utils import tqdm_stream_file, validate_hf_hub_args from .utils._deprecation import _deprecate_method from .utils._typing import Literal @@ -55,8 +60,7 @@ def __post_init__(self): self.is_folder = self.path_in_repo.endswith("/") if not isinstance(self.is_folder, bool): raise ValueError( - "Wrong value for `is_folder`. Must be one of [`True`, `False`," - f" `'auto'`]. Got '{self.is_folder}'." + f"Wrong value for `is_folder`. Must be one of [`True`, `False`, `'auto'`]. Got '{self.is_folder}'." ) @@ -97,10 +101,7 @@ def __post_init__(self) -> None: if isinstance(self.path_or_fileobj, str): path_or_fileobj = os.path.normpath(os.path.expanduser(self.path_or_fileobj)) if not os.path.isfile(path_or_fileobj): - raise ValueError( - f"Provided path: '{path_or_fileobj}' is not a file on the local" - " file system" - ) + raise ValueError(f"Provided path: '{path_or_fileobj}' is not a file on the local file system") elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)): # ^^ Inspired from: https://stackoverflow.com/questions/44584829/how-to-determine-if-file-is-opened-in-binary-or-text-mode raise ValueError( @@ -114,8 +115,7 @@ def __post_init__(self) -> None: self.path_or_fileobj.seek(0, os.SEEK_CUR) except (OSError, AttributeError) as exc: raise ValueError( - "path_or_fileobj is a file-like object but does not implement" - " seek() and tell()" + "path_or_fileobj is a file-like object but does not implement seek() and tell()" ) from exc # Compute "upload_info" attribute @@ -126,9 +126,7 @@ def __post_init__(self) -> None: else: self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj) - @_deprecate_method( - version="0.14", message="Operation is validated at initialization." 
- ) + @_deprecate_method(version="0.14", message="Operation is validated at initialization.") def validate(self) -> None: pass @@ -172,9 +170,7 @@ def as_file(self, with_tqdm: bool = False) -> Iterator[BinaryIO]: config.json: 100%|█████████████████████████| 8.19k/8.19k [00:02<00:00, 3.72kB/s] ``` """ - if isinstance(self.path_or_fileobj, str) or isinstance( - self.path_or_fileobj, Path - ): + if isinstance(self.path_or_fileobj, str) or isinstance(self.path_or_fileobj, Path): if with_tqdm: with tqdm_stream_file(self.path_or_fileobj) as file: yield file @@ -302,8 +298,7 @@ def upload_lfs_files( if batch_errors_chunk: message = "\n".join( [ - f'Encountered error for file with OID {err.get("oid")}:' - f' `{err.get("error", {}).get("message")}' + f'Encountered error for file with OID {err.get("oid")}: `{err.get("error", {}).get("message")}' for err in batch_errors_chunk ] ) @@ -331,17 +326,12 @@ def upload_lfs_files( def _inner_upload_lfs_object(batch_action): try: operation = oid2addop[batch_action["oid"]] - return _upload_lfs_object( - operation=operation, lfs_batch_action=batch_action, token=token - ) + return _upload_lfs_object(operation=operation, lfs_batch_action=batch_action, token=token) except Exception as exc: - raise RuntimeError( - f"Error while uploading '{operation.path_in_repo}' to the Hub." - ) from exc + raise RuntimeError(f"Error while uploading '{operation.path_in_repo}' to the Hub.") from exc logger.debug( - f"Uploading {len(filtered_actions)} LFS files to the Hub using up to" - f" {num_threads} threads concurrently" + f"Uploading {len(filtered_actions)} LFS files to the Hub using up to {num_threads} threads concurrently" ) thread_map( _inner_upload_lfs_object, @@ -352,9 +342,7 @@ def _inner_upload_lfs_object(batch_action): ) -def _upload_lfs_object( - operation: CommitOperationAdd, lfs_batch_action: dict, token: Optional[str] -): +def _upload_lfs_object(operation: CommitOperationAdd, lfs_batch_action: dict, token: Optional[str]): """ Handles uploading a given object to the Hub with the LFS protocol. @@ -379,10 +367,7 @@ def _upload_lfs_object( actions = lfs_batch_action.get("actions") if actions is None: # The file was already uploaded - logger.debug( - f"Content of file {operation.path_in_repo} is already present upstream" - " - skipping upload" - ) + logger.debug(f"Content of file {operation.path_in_repo} is already present upstream - skipping upload") return upload_action = lfs_batch_action["actions"].get("upload") verify_action = lfs_batch_action["actions"].get("verify") @@ -476,9 +461,7 @@ def fetch_upload_modes( ) hf_raise_for_status(resp) preupload_info = _validate_preupload_info(resp.json()) - upload_modes.update( - **{file["path"]: file["uploadMode"] for file in preupload_info["files"]} - ) + upload_modes.update(**{file["path"]: file["uploadMode"] for file in preupload_info["files"]}) # If a file is empty, it is most likely a mistake. # => a warning message is triggered to warn the user. @@ -490,10 +473,7 @@ def fetch_upload_modes( if addition.upload_info.size == 0: path = addition.path_in_repo if not path.endswith(".gitkeep"): - warnings.warn( - f"About to commit an empty file: '{path}'. Are you sure this is" - " intended?" - ) + warnings.warn(f"About to commit an empty file: '{path}'. Are you sure this is intended?") upload_modes[path] = "regular" return upload_modes @@ -527,10 +507,7 @@ def prepare_commit_payload( # 2. Send operations, one per line for operation in operations: # 2.a. 
Case adding a regular file - if ( - isinstance(operation, CommitOperationAdd) - and upload_modes.get(operation.path_in_repo) == "regular" - ): + if isinstance(operation, CommitOperationAdd) and upload_modes.get(operation.path_in_repo) == "regular": yield { "key": "file", "value": { @@ -540,10 +517,7 @@ def prepare_commit_payload( }, } # 2.b. Case adding an LFS file - elif ( - isinstance(operation, CommitOperationAdd) - and upload_modes.get(operation.path_in_repo) == "lfs" - ): + elif isinstance(operation, CommitOperationAdd) and upload_modes.get(operation.path_in_repo) == "lfs": yield { "key": "lfsFile", "value": { diff --git a/src/huggingface_hub/_login.py b/src/huggingface_hub/_login.py index 8b9eef0920..3adb7181fc 100644 --- a/src/huggingface_hub/_login.py +++ b/src/huggingface_hub/_login.py @@ -138,10 +138,7 @@ def interpreter_login() -> None: ) print(" Setting a new token will erase the existing one.") - print( - " To login, `huggingface_hub` requires a token generated from" - " https://huggingface.co/settings/tokens ." - ) + print(" To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .") if os.name == "nt": print("Token can be pasted using 'Right-Click'.") token = getpass("Token: ") @@ -193,14 +190,10 @@ def notebook_login() -> None: " Colab) and you need the `ipywidgets` module: `pip install ipywidgets`." ) - box_layout = widgets.Layout( - display="flex", flex_flow="column", align_items="center", width="50%" - ) + box_layout = widgets.Layout(display="flex", flex_flow="column", align_items="center", width="50%") token_widget = widgets.Password(description="Token:") - git_checkbox_widget = widgets.Checkbox( - value=True, description="Add token as git credential?" - ) + git_checkbox_widget = widgets.Checkbox(value=True, description="Add token as git credential?") token_finish_button = widgets.Button(description="Login") login_token_widget = widgets.VBox( @@ -304,8 +297,6 @@ def _set_store_as_git_credential_helper_globally() -> None: raise EnvironmentError(exc.stderr) -@_deprecate_method( - version="0.14", message="Please use `list_credential_helpers` instead." -) +@_deprecate_method(version="0.14", message="Please use `list_credential_helpers` instead.") def _currently_setup_credential_helpers(directory: Optional[str] = None) -> List[str]: return list_credential_helpers(directory) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index f4d6f83b61..0f0e0907d5 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -13,9 +13,8 @@ ) from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name from .hf_api import HfApi -from .utils import filter_repo_objects, logging +from .utils import filter_repo_objects, logging, validate_hf_hub_args from .utils import tqdm as hf_tqdm -from .utils import validate_hf_hub_args logger = logging.get_logger(__name__) @@ -124,14 +123,9 @@ def snapshot_download( if repo_type is None: repo_type = "model" if repo_type not in REPO_TYPES: - raise ValueError( - f"Invalid repo type: {repo_type}. Accepted repo types are:" - f" {str(REPO_TYPES)}" - ) + raise ValueError(f"Invalid repo type: {repo_type}. 
Accepted repo types are: {str(REPO_TYPES)}")
 
-    storage_folder = os.path.join(
-        cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type)
-    )
+    storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
 
     # if we have no internet connection we will look for an
     # appropriate folder in the cache
@@ -166,9 +160,7 @@ def snapshot_download(
             revision=revision,
             token=token,
         )
-        assert (
-            repo_info.sha is not None
-        ), "Repo info returned from server must have a revision sha."
+        assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
         filtered_repo_files = list(
             filter_repo_objects(
                 items=[f.rfilename for f in repo_info.siblings],
diff --git a/src/huggingface_hub/commands/delete_cache.py b/src/huggingface_hub/commands/delete_cache.py
index 1f9feedbc3..e33b60bae7 100644
--- a/src/huggingface_hub/commands/delete_cache.py
+++ b/src/huggingface_hub/commands/delete_cache.py
@@ -78,6 +78,7 @@ def require_inquirer_py(fn: Callable) -> Callable:
     """Decorator to flag methods that require `InquirerPy`."""
 
+    # TODO: refactor this + imports in a unified pattern across codebase
     @wraps(fn)
     def _inner(*args, **kwargs):
@@ -100,17 +101,13 @@ def _inner(*args, **kwargs):
 class DeleteCacheCommand(BaseHuggingfaceCLICommand):
     @staticmethod
     def register_subcommand(parser: _SubParsersAction):
-        delete_cache_parser = parser.add_parser(
-            "delete-cache", help="Delete revisions from the cache directory."
-        )
+        delete_cache_parser = parser.add_parser("delete-cache", help="Delete revisions from the cache directory.")
 
         delete_cache_parser.add_argument(
             "--dir",
             type=str,
             default=None,
-            help=(
-                "cache directory (optional). Default to the default HuggingFace cache."
-            ),
+            help="cache directory (optional). Default to the default HuggingFace cache.",
         )
 
         delete_cache_parser.add_argument(
@@ -141,10 +138,7 @@ def run(self):
 
         # If deletion is not cancelled
         if len(selected_hashes) > 0 and _CANCEL_DELETION_STR not in selected_hashes:
-            confirm_message = (
-                _get_expectations_str(hf_cache_info, selected_hashes)
-                + " Confirm deletion ?"
-            )
+            confirm_message = _get_expectations_str(hf_cache_info, selected_hashes) + " Confirm deletion ?"
 
             # Confirm deletion
             if self.disable_tui:
@@ -175,9 +169,7 @@ def _manual_review_tui(hf_cache_info: HFCacheInfo, preselected: List[str]) -> List[str]:
     Displays a multi-select menu in the terminal (TUI).
     """
     # Define multiselect list
-    choices = _get_tui_choices_from_scan(
-        repos=hf_cache_info.repos, preselected=preselected
-    )
+    choices = _get_tui_choices_from_scan(repos=hf_cache_info.repos, preselected=preselected)
     checkbox = inquirer.checkbox(
         message="Select revisions to delete:",
         choices=choices,  # List of revisions with some pre-selection
@@ -187,15 +179,10 @@ def _manual_review_tui(hf_cache_info: HFCacheInfo, preselected: List[str]) -> List[str]:
         # deletion.
         instruction=_get_expectations_str(
             hf_cache_info,
-            selected_hashes=[
-                c.value for c in choices if isinstance(c, Choice) and c.enabled
-            ],
+            selected_hashes=[c.value for c in choices if isinstance(c, Choice) and c.enabled],
         ),
         # We use the long instruction to show keybindings instructions to the user
-        long_instruction=(
-            "Press <space> to select, <enter> to validate and <ctrl+c> to quit"
-            " without modification."
-        ),
+        long_instruction="Press <space> to select, <enter> to validate and <ctrl+c> to quit without modification.",
         # Message that is displayed once the user validates its selection.
transformer=lambda result: f"{len(result)} revision(s) selected.", ) @@ -207,11 +194,7 @@ def _update_expectations(_) -> None: # a revision hash is selected/unselected. checkbox._instruction = _get_expectations_str( hf_cache_info, - selected_hashes=[ - choice["value"] - for choice in checkbox.content_control.choices - if choice["enabled"] - ], + selected_hashes=[choice["value"] for choice in checkbox.content_control.choices if choice["enabled"]], ) checkbox.kb_func_lookup["toggle"].append({"func": _update_expectations}) @@ -229,9 +212,7 @@ def _ask_for_confirmation_tui(message: str, default: bool = True) -> bool: return inquirer.confirm(message, default=default).execute() -def _get_tui_choices_from_scan( - repos: Iterable[CachedRepoInfo], preselected: List[str] -) -> List: +def _get_tui_choices_from_scan(repos: Iterable[CachedRepoInfo], preselected: List[str]) -> List: """Build a list of choices from the scanned repos. Args: @@ -282,9 +263,7 @@ def _get_tui_choices_from_scan( return choices -def _manual_review_no_tui( - hf_cache_info: HFCacheInfo, preselected: List[str] -) -> List[str]: +def _manual_review_no_tui(hf_cache_info: HFCacheInfo, preselected: List[str]) -> List[str]: """Ask the user for a manual review of the revisions to delete. Used when TUI is disabled. Manual review happens in a separate tmp file that the @@ -357,9 +336,7 @@ def _ask_for_confirmation_no_tui(message: str, default: bool = True) -> bool: print(f"Invalid input. Must be one of {ALL}") -def _get_expectations_str( - hf_cache_info: HFCacheInfo, selected_hashes: List[str] -) -> str: +def _get_expectations_str(hf_cache_info: HFCacheInfo, selected_hashes: List[str]) -> str: """Format a string to display to the user how much space would be saved. Example: @@ -371,10 +348,7 @@ def _get_expectations_str( if _CANCEL_DELETION_STR in selected_hashes: return "Nothing will be deleted." strategy = hf_cache_info.delete_revisions(*selected_hashes) - return ( - f"{len(selected_hashes)} revisions selected counting for" - f" {strategy.expected_freed_size_str}." - ) + return f"{len(selected_hashes)} revisions selected counting for {strategy.expected_freed_size_str}." def _read_manual_review_tmp_file(tmp_path: str) -> List[str]: diff --git a/src/huggingface_hub/commands/env.py b/src/huggingface_hub/commands/env.py index f987adcac0..26d0d7fb15 100644 --- a/src/huggingface_hub/commands/env.py +++ b/src/huggingface_hub/commands/env.py @@ -28,9 +28,7 @@ def __init__(self, args): @staticmethod def register_subcommand(parser: _SubParsersAction): - env_parser = parser.add_parser( - "env", help="Print information about the environment." 
-        )
+        env_parser = parser.add_parser("env", help="Print information about the environment.")
         env_parser.set_defaults(func=EnvironmentCommand)
 
     def run(self) -> None:
diff --git a/src/huggingface_hub/commands/huggingface_cli.py b/src/huggingface_hub/commands/huggingface_cli.py
index fc9a754dae..d5d4bbc79b 100644
--- a/src/huggingface_hub/commands/huggingface_cli.py
+++ b/src/huggingface_hub/commands/huggingface_cli.py
@@ -23,9 +23,7 @@
 
 
 def main():
-    parser = ArgumentParser(
-        "huggingface-cli", usage="huggingface-cli <command> [<args>]"
-    )
+    parser = ArgumentParser("huggingface-cli", usage="huggingface-cli <command> [<args>]")
     commands_parser = parser.add_subparsers(help="huggingface-cli command helpers")
 
     # Register commands
diff --git a/src/huggingface_hub/commands/lfs.py b/src/huggingface_hub/commands/lfs.py
index 3e362752db..5e87d6a570 100644
--- a/src/huggingface_hub/commands/lfs.py
+++ b/src/huggingface_hub/commands/lfs.py
@@ -24,6 +24,7 @@
 from typing import Dict, List, Optional
 
 import requests
+
 from huggingface_hub.commands import BaseHuggingfaceCLICommand
 from huggingface_hub.lfs import LFS_MULTIPART_UPLOAD_COMMAND, SliceFileObj
 
@@ -60,9 +61,7 @@ def register_subcommand(parser: _SubParsersAction):
             "lfs-enable-largefiles",
             help="Configure your repository to enable upload of files > 5GB.",
         )
-        enable_parser.add_argument(
-            "path", type=str, help="Local path to repository you want to configure."
-        )
+        enable_parser.add_argument("path", type=str, help="Local path to repository you want to configure.")
         enable_parser.set_defaults(func=lambda args: LfsEnableCommand(args))
 
         upload_parser = parser.add_parser(
@@ -87,8 +86,7 @@ def run(self):
             cwd=local_path,
         )
         subprocess.run(
-            "git config lfs.customtransfer.multipart.args"
-            f" {LFS_MULTIPART_UPLOAD_COMMAND}".split(),
+            f"git config lfs.customtransfer.multipart.args {LFS_MULTIPART_UPLOAD_COMMAND}".split(),
             check=True,
             cwd=local_path,
         )
@@ -126,9 +124,7 @@ def run(self):
         # sends initiation data to the process over stdin.
         # This tells the process useful information about the configuration.
         init_msg = json.loads(sys.stdin.readline().strip())
-        if not (
-            init_msg.get("event") == "init" and init_msg.get("operation") == "upload"
-        ):
+        if not (init_msg.get("event") == "init" and init_msg.get("operation") == "upload"):
             write_msg({"error": {"code": 32, "message": "Wrong lfs init operation"}})
             sys.exit(1)
 
diff --git a/src/huggingface_hub/commands/scan_cache.py b/src/huggingface_hub/commands/scan_cache.py
index 84b728a532..ff26fa9de5 100644
--- a/src/huggingface_hub/commands/scan_cache.py
+++ b/src/huggingface_hub/commands/scan_cache.py
@@ -32,18 +32,13 @@
 class ScanCacheCommand(BaseHuggingfaceCLICommand):
     @staticmethod
     def register_subcommand(parser: _SubParsersAction):
-        scan_cache_parser = parser.add_parser(
-            "scan-cache", help="Scan cache directory."
-        )
+        scan_cache_parser = parser.add_parser("scan-cache", help="Scan cache directory.")
 
         scan_cache_parser.add_argument(
             "--dir",
             type=str,
             default=None,
-            help=(
-                "cache directory to scan (optional). Default to the"
-                " default HuggingFace cache."
-            ),
+            help="cache directory to scan (optional). 
Default to the default HuggingFace cache.",
         )
         scan_cache_parser.add_argument(
             "-v",
@@ -98,9 +93,7 @@ def _print_hf_cache_info_as_table(self, hf_cache_info: HFCacheInfo) -> None:
                     ", ".join(sorted(repo.refs)),
                     str(repo.repo_path),
                 ]
-                for repo in sorted(
-                    hf_cache_info.repos, key=lambda repo: repo.repo_path
-                )
+                for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path)
             ],
             headers=[
                 "REPO ID",
@@ -128,12 +121,8 @@ def _print_hf_cache_info_as_table(self, hf_cache_info: HFCacheInfo) -> None:
                     ", ".join(sorted(revision.refs)),
                     str(revision.snapshot_path),
                 ]
-                for repo in sorted(
-                    hf_cache_info.repos, key=lambda repo: repo.repo_path
-                )
-                for revision in sorted(
-                    repo.revisions, key=lambda revision: revision.commit_hash
-                )
+                for repo in sorted(hf_cache_info.repos, key=lambda repo: repo.repo_path)
+                for revision in sorted(repo.revisions, key=lambda revision: revision.commit_hash)
             ],
             headers=[
                 "REPO ID",
diff --git a/src/huggingface_hub/commands/user.py b/src/huggingface_hub/commands/user.py
index 15e9fd3ba4..9441f31b91 100644
--- a/src/huggingface_hub/commands/user.py
+++ b/src/huggingface_hub/commands/user.py
@@ -14,6 +14,8 @@
 import subprocess
 from argparse import _SubParsersAction
 
+from requests.exceptions import HTTPError
+
 from huggingface_hub.commands import BaseHuggingfaceCLICommand
 from huggingface_hub.constants import (
     ENDPOINT,
@@ -22,21 +24,18 @@
     SPACES_SDK_TYPES,
 )
 from huggingface_hub.hf_api import HfApi
-from requests.exceptions import HTTPError
 
 from .._login import (  # noqa: F401 # for backward compatibility
     NOTEBOOK_LOGIN_PASSWORD_HTML,
     NOTEBOOK_LOGIN_TOKEN_HTML_END,
     NOTEBOOK_LOGIN_TOKEN_HTML_START,
-)
-from .._login import (  # noqa: F401 # for backward compatibility
-    _currently_setup_credential_helpers as currently_setup_credential_helpers,
-)
-from .._login import (  # noqa: F401 # for backward compatibility
     login,
     logout,
     notebook_login,
 )
+from .._login import (
+    _currently_setup_credential_helpers as currently_setup_credential_helpers,  # noqa: F401 # for backward compatibility
+)
 from ..utils import HfFolder
 from ._cli_utils import ANSI
 
@@ -44,13 +43,9 @@
 class UserCommands(BaseHuggingfaceCLICommand):
     @staticmethod
     def register_subcommand(parser: _SubParsersAction):
-        login_parser = parser.add_parser(
-            "login", help="Log in using a token from huggingface.co/settings/tokens"
-        )
+        login_parser = parser.add_parser("login", help="Log in using a token from huggingface.co/settings/tokens")
         login_parser.set_defaults(func=lambda args: LoginCommand(args))
-        whoami_parser = parser.add_parser(
-            "whoami", help="Find out which huggingface.co account you are logged in as."
-        )
+        whoami_parser = parser.add_parser("whoami", help="Find out which huggingface.co account you are logged in as.")
         whoami_parser.set_defaults(func=lambda args: WhoamiCommand(args))
         logout_parser = parser.add_parser("logout", help="Log out")
         logout_parser.set_defaults(func=lambda args: LogoutCommand(args))
@@ -58,43 +53,25 @@ def register_subcommand(parser: _SubParsersAction):
         # new system: git-based repo system
         repo_parser = parser.add_parser(
             "repo",
-            help=(
-                "{create, ls-files} Commands to interact with your huggingface.co"
-                " repos." 
- ), - ) - repo_subparsers = repo_parser.add_subparsers( - help="huggingface.co repos related commands" - ) - repo_create_parser = repo_subparsers.add_parser( - "create", help="Create a new repo on huggingface.co" + help="{create, ls-files} Commands to interact with your huggingface.co repos.", ) + repo_subparsers = repo_parser.add_subparsers(help="huggingface.co repos related commands") + repo_create_parser = repo_subparsers.add_parser("create", help="Create a new repo on huggingface.co") repo_create_parser.add_argument( "name", type=str, - help=( - "Name for your repo. Will be namespaced under your username to build" - " the repo id." - ), + help="Name for your repo. Will be namespaced under your username to build the repo id.", ) repo_create_parser.add_argument( "--type", type=str, - help=( - 'Optional: repo_type: set to "dataset" or "space" if creating a dataset' - " or space, default is model." - ), - ) - repo_create_parser.add_argument( - "--organization", type=str, help="Optional: organization namespace." + help='Optional: repo_type: set to "dataset" or "space" if creating a dataset or space, default is model.', ) + repo_create_parser.add_argument("--organization", type=str, help="Optional: organization namespace.") repo_create_parser.add_argument( "--space_sdk", type=str, - help=( - "Optional: Hugging Face Spaces SDK type. Required when --type is set to" - ' "space".' - ), + help='Optional: Hugging Face Spaces SDK type. Required when --type is set to "space".', choices=SPACES_SDK_TYPES, ) repo_create_parser.add_argument( @@ -169,9 +146,7 @@ def run(self): print("") user = self._api.whoami(token)["name"] - namespace = ( - self.args.organization if self.args.organization is not None else user - ) + namespace = self.args.organization if self.args.organization is not None else user repo_id = f"{namespace}/{self.args.name}" @@ -204,9 +179,6 @@ def run(self): exit(1) print("\nYour repo now lives at:") print(f" {ANSI.bold(url)}") - print( - "\nYou can clone it locally with the command below," - " and commit/push as usual." 
- ) + print("\nYou can clone it locally with the command below, and commit/push as usual.") print(f"\n git clone {url}") print("") diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py index 9c82d5f8e8..19d18585a0 100644 --- a/src/huggingface_hub/constants.py +++ b/src/huggingface_hub/constants.py @@ -39,9 +39,7 @@ def _is_true_or_auto(value: Optional[str]) -> bool: _staging_mode = _is_true(os.environ.get("HUGGINGFACE_CO_STAGING")) -ENDPOINT = os.getenv("HF_ENDPOINT") or ( - "https://hub-ci.huggingface.co" if _staging_mode else "https://huggingface.co" -) +ENDPOINT = os.getenv("HF_ENDPOINT") or ("https://hub-ci.huggingface.co" if _staging_mode else "https://huggingface.co") HUGGINGFACE_CO_URL_TEMPLATE = ENDPOINT + "/{repo_id}/resolve/{revision}/{filename}" HUGGINGFACE_HEADER_X_REPO_COMMIT = "X-Repo-Commit" @@ -83,9 +81,7 @@ def _is_true_or_auto(value: Optional[str]) -> bool: default_assets_cache_path = os.path.join(hf_cache_home, "assets") HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", default_cache_path) -HUGGINGFACE_ASSETS_CACHE = os.getenv( - "HUGGINGFACE_ASSETS_CACHE", default_assets_cache_path -) +HUGGINGFACE_ASSETS_CACHE = os.getenv("HUGGINGFACE_ASSETS_CACHE", default_assets_cache_path) HF_HUB_OFFLINE = _is_true(os.environ.get("HF_HUB_OFFLINE")) @@ -103,20 +99,14 @@ def _is_true_or_auto(value: Optional[str]) -> bool: # TL;DR: env variable has priority over code __HF_HUB_DISABLE_PROGRESS_BARS = os.environ.get("HF_HUB_DISABLE_PROGRESS_BARS") HF_HUB_DISABLE_PROGRESS_BARS: Optional[bool] = ( - _is_true(__HF_HUB_DISABLE_PROGRESS_BARS) - if __HF_HUB_DISABLE_PROGRESS_BARS is not None - else None + _is_true(__HF_HUB_DISABLE_PROGRESS_BARS) if __HF_HUB_DISABLE_PROGRESS_BARS is not None else None ) # Disable warning on machines that do not support symlinks (e.g. Windows non-developer) -HF_HUB_DISABLE_SYMLINKS_WARNING: bool = _is_true( - os.environ.get("HF_HUB_DISABLE_SYMLINKS_WARNING") -) +HF_HUB_DISABLE_SYMLINKS_WARNING: bool = _is_true(os.environ.get("HF_HUB_DISABLE_SYMLINKS_WARNING")) # Disable sending the cached token by default is all HTTP requests to the Hub -HF_HUB_DISABLE_IMPLICIT_TOKEN: bool = _is_true( - os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN") -) +HF_HUB_DISABLE_IMPLICIT_TOKEN: bool = _is_true(os.environ.get("HF_HUB_DISABLE_IMPLICIT_TOKEN")) # Enable fast-download using external dependency "hf_transfer" # See: diff --git a/src/huggingface_hub/fastai_utils.py b/src/huggingface_hub/fastai_utils.py index 9e5127303f..8ef18aa86b 100644 --- a/src/huggingface_hub/fastai_utils.py +++ b/src/huggingface_hub/fastai_utils.py @@ -144,16 +144,11 @@ def _check_fastai_fastcore_pyproject_versions( # If the package is specified but not the version (e.g. "fastai" instead of "fastai=2.4"), the default versions are the highest. fastai_packages = [pck for pck in package_versions if pck.startswith("fastai")] if len(fastai_packages) == 0: - logger.warning( - "The repository does not have a fastai version specified in the" - " `pyproject.toml`." 
- ) + logger.warning("The repository does not have a fastai version specified in the `pyproject.toml`.") # fastai_version is an empty string if not specified else: fastai_version = str(fastai_packages[0]).partition("=")[2] - if fastai_version != "" and version.Version(fastai_version) < version.Version( - fastai_min_version - ): + if fastai_version != "" and version.Version(fastai_version) < version.Version(fastai_min_version): raise ImportError( "`from_pretrained_fastai` requires" f" fastai>={fastai_min_version} version but the model to load uses" @@ -162,16 +157,11 @@ def _check_fastai_fastcore_pyproject_versions( fastcore_packages = [pck for pck in package_versions if pck.startswith("fastcore")] if len(fastcore_packages) == 0: - logger.warning( - "The repository does not have a fastcore version specified in the" - " `pyproject.toml`." - ) + logger.warning("The repository does not have a fastcore version specified in the `pyproject.toml`.") # fastcore_version is an empty string if not specified else: fastcore_version = str(fastcore_packages[0]).partition("=")[2] - if fastcore_version != "" and version.Version( - fastcore_version - ) < version.Version(fastcore_min_version): + if fastcore_version != "" and version.Version(fastcore_version) < version.Version(fastcore_min_version): raise ImportError( "`from_pretrained_fastai` requires" f" fastcore>={fastcore_min_version} version, but you are using fastcore" @@ -281,9 +271,7 @@ def _save_pretrained_fastai( # if the user provides config then we update it with the fastai and fastcore versions in CONFIG_TEMPLATE. if config is not None: if not isinstance(config, dict): - raise RuntimeError( - f"Provided config should be a dict. Got: '{type(config)}'" - ) + raise RuntimeError(f"Provided config should be a dict. Got: '{type(config)}'") path = os.path.join(save_directory, CONFIG_NAME) with open(path, "w") as f: json.dump(config, f) @@ -413,9 +401,7 @@ def push_to_hub_fastai( """ _check_fastai_fastcore_versions() api = HfApi(endpoint=api_endpoint) - api.create_repo( - repo_id=repo_id, repo_type="model", token=token, private=private, exist_ok=True - ) + api.create_repo(repo_id=repo_id, repo_type="model", token=token, private=private, exist_ok=True) # Push the files to the repo in a single commit with SoftTemporaryDirectory() as tmp: diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 12575a543f..d5637aad0c 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -19,9 +19,10 @@ import requests from filelock import FileLock -from huggingface_hub import constants from requests.exceptions import ConnectTimeout, ProxyError +from huggingface_hub import constants + from . 
import __version__ # noqa: F401 # for backward compatibility from .constants import ( DEFAULT_REVISION, @@ -36,27 +37,27 @@ REPO_TYPES, REPO_TYPES_URL_PREFIXES, ) -from .utils import get_fastai_version # noqa: F401 # for backward compatibility -from .utils import get_fastcore_version # noqa: F401 # for backward compatibility -from .utils import get_graphviz_version # noqa: F401 # for backward compatibility -from .utils import get_jinja_version # noqa: F401 # for backward compatibility -from .utils import get_pydot_version # noqa: F401 # for backward compatibility -from .utils import get_tf_version # noqa: F401 # for backward compatibility -from .utils import get_torch_version # noqa: F401 # for backward compatibility -from .utils import is_fastai_available # noqa: F401 # for backward compatibility -from .utils import is_fastcore_available # noqa: F401 # for backward compatibility -from .utils import is_graphviz_available # noqa: F401 # for backward compatibility -from .utils import is_jinja_available # noqa: F401 # for backward compatibility -from .utils import is_pydot_available # noqa: F401 # for backward compatibility -from .utils import is_tf_available # noqa: F401 # for backward compatibility -from .utils import is_torch_available # noqa: F401 # for backward compatibility from .utils import ( EntryNotFoundError, LocalEntryNotFoundError, SoftTemporaryDirectory, build_hf_headers, + get_fastai_version, # noqa: F401 # for backward compatibility + get_fastcore_version, # noqa: F401 # for backward compatibility + get_graphviz_version, # noqa: F401 # for backward compatibility + get_jinja_version, # noqa: F401 # for backward compatibility + get_pydot_version, # noqa: F401 # for backward compatibility + get_tf_version, # noqa: F401 # for backward compatibility + get_torch_version, # noqa: F401 # for backward compatibility hf_raise_for_status, http_backoff, + is_fastai_available, # noqa: F401 # for backward compatibility + is_fastcore_available, # noqa: F401 # for backward compatibility + is_graphviz_available, # noqa: F401 # for backward compatibility + is_jinja_available, # noqa: F401 # for backward compatibility + is_pydot_available, # noqa: F401 # for backward compatibility + is_tf_available, # noqa: F401 # for backward compatibility + is_torch_available, # noqa: F401 # for backward compatibility logging, tqdm, validate_hf_hub_args, @@ -347,9 +348,7 @@ def _raise_if_offline_mode_is_enabled(msg: Optional[str] = None): HF_HUB_OFFLINE is True.""" if constants.HF_HUB_OFFLINE: raise OfflineModeIsEnabled( - "Offline mode is enabled." - if msg is None - else "Offline mode is enabled. " + str(msg) + "Offline mode is enabled." if msg is None else "Offline mode is enabled. " + str(msg) ) @@ -477,7 +476,7 @@ def http_get( # Download file using an external Rust-based package. Download is faster # (~2x speed-up) but support less features (no error handling, no retries, # no progress bars). - from hf_transfer import download + from hf_transfer import download # type: ignore logger.debug(f"Download {url} using HF_TRANSFER.") max_files = 100 @@ -627,8 +626,10 @@ def cached_download( """ if not legacy_cache_layout: warnings.warn( - "`cached_download` is the legacy way to download files from the HF hub," - " please consider upgrading to `hf_hub_download`", + ( + "`cached_download` is the legacy way to download files from the HF hub," + " please consider upgrading to `hf_hub_download`" + ), FutureWarning, ) @@ -666,8 +667,7 @@ def cached_download( # If we don't have any of those, raise an error. 
if etag is None: raise OSError( - "Distant resource does not have an ETag, we won't be able to" - " reliably ensure reproducibility." + "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility." ) # In case of a redirect, save an extra redirect on the request.get call, # and ensure we download the exact atomic version even if it changed @@ -687,9 +687,7 @@ def cached_download( # etag is None pass - filename = ( - force_filename if force_filename is not None else url_to_filename(url, etag) - ) + filename = force_filename if force_filename is not None else url_to_filename(url, etag) # get cache path to put the file cache_path = os.path.join(cache_dir, filename) @@ -702,16 +700,10 @@ def cached_download( else: matching_files = [ file - for file in fnmatch.filter( - os.listdir(cache_dir), filename.split(".")[0] + ".*" - ) + for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*") if not file.endswith(".json") and not file.endswith(".lock") ] - if ( - len(matching_files) > 0 - and not force_download - and force_filename is None - ): + if len(matching_files) > 0 and not force_download and force_filename is None: return os.path.join(cache_dir, matching_files[-1]) else: # If files cannot be found and local_files_only=True, @@ -865,9 +857,7 @@ def _create_relative_symlink(src: str, dst: str, new_blob: bool = False) -> None shutil.copyfile(src, dst) -def _cache_commit_hash_for_specific_revision( - storage_folder: str, revision: str, commit_hash: str -) -> None: +def _cache_commit_hash_for_specific_revision(storage_folder: str, revision: str, commit_hash: str) -> None: """Cache reference between a revision (tag, branch or truncated commit hash) and the corresponding commit hash. Does nothing if `revision` is already a proper `commit_hash` or reference is already cached. @@ -1018,8 +1008,10 @@ def hf_hub_download( """ if force_filename is not None: warnings.warn( - "The `force_filename` parameter is deprecated as a new caching system, " - "which keeps the filenames as they are on the Hub, is now in place.", + ( + "The `force_filename` parameter is deprecated as a new caching system, " + "which keeps the filenames as they are on the Hub, is now in place." + ), FutureWarning, ) legacy_cache_layout = True @@ -1065,14 +1057,9 @@ def hf_hub_download( if repo_type is None: repo_type = "model" if repo_type not in REPO_TYPES: - raise ValueError( - f"Invalid repo type: {repo_type}. Accepted repo types are:" - f" {str(REPO_TYPES)}" - ) + raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}") - storage_folder = os.path.join( - cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type) - ) + storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type)) os.makedirs(storage_folder, exist_ok=True) # cross platform transcription of filename, to be used as a local file path. @@ -1081,9 +1068,7 @@ def hf_hub_download( # if user provides a commit_hash and they already have the file on disk, # shortcut everything. 
if REGEX_COMMIT_HASH.match(revision): - pointer_path = os.path.join( - storage_folder, "snapshots", revision, relative_filename - ) + pointer_path = os.path.join(storage_folder, "snapshots", revision, relative_filename) if os.path.exists(pointer_path): return pointer_path @@ -1110,30 +1095,18 @@ def hf_hub_download( ) except EntryNotFoundError as http_error: # Cache the non-existence of the file and raise - commit_hash = http_error.response.headers.get( - HUGGINGFACE_HEADER_X_REPO_COMMIT - ) + commit_hash = http_error.response.headers.get(HUGGINGFACE_HEADER_X_REPO_COMMIT) if commit_hash is not None and not legacy_cache_layout: - no_exist_file_path = ( - Path(storage_folder) - / ".no_exist" - / commit_hash - / relative_filename - ) + no_exist_file_path = Path(storage_folder) / ".no_exist" / commit_hash / relative_filename no_exist_file_path.parent.mkdir(parents=True, exist_ok=True) no_exist_file_path.touch() - _cache_commit_hash_for_specific_revision( - storage_folder, revision, commit_hash - ) + _cache_commit_hash_for_specific_revision(storage_folder, revision, commit_hash) raise # Commit hash must exist commit_hash = metadata.commit_hash if commit_hash is None: - raise OSError( - "Distant resource does not seem to be on huggingface.co (missing" - " commit header)." - ) + raise OSError("Distant resource does not seem to be on huggingface.co (missing commit header).") # Etag must exist etag = metadata.etag @@ -1142,8 +1115,7 @@ def hf_hub_download( # If we don't have any of those, raise an error. if etag is None: raise OSError( - "Distant resource does not have an ETag, we won't be able to" - " reliably ensure reproducibility." + "Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility." ) # In case of a redirect, save an extra redirect on the request.get call, @@ -1174,8 +1146,7 @@ def hf_hub_download( # In those cases, we cannot force download. if force_download: raise ValueError( - "We have no connection or you passed local_files_only, so" - " force_download is not an accepted option." + "We have no connection or you passed local_files_only, so force_download is not an accepted option." ) # Try to get "commit_hash" from "revision" @@ -1190,9 +1161,7 @@ def hf_hub_download( # Return pointer file if exists if commit_hash is not None: - pointer_path = os.path.join( - storage_folder, "snapshots", commit_hash, relative_filename - ) + pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, relative_filename) if os.path.exists(pointer_path): return pointer_path @@ -1218,9 +1187,7 @@ def hf_hub_download( assert etag is not None, "etag must have been retrieved from server" assert commit_hash is not None, "commit_hash must have been retrieved from server" blob_path = os.path.join(storage_folder, "blobs", etag) - pointer_path = os.path.join( - storage_folder, "snapshots", commit_hash, relative_filename - ) + pointer_path = os.path.join(storage_folder, "snapshots", commit_hash, relative_filename) os.makedirs(os.path.dirname(blob_path), exist_ok=True) os.makedirs(os.path.dirname(pointer_path), exist_ok=True) @@ -1356,10 +1323,7 @@ def try_to_load_from_cache( if repo_type is None: repo_type = "model" if repo_type not in REPO_TYPES: - raise ValueError( - f"Invalid repo type: {repo_type}. Accepted repo types are:" - f" {str(REPO_TYPES)}" - ) + raise ValueError(f"Invalid repo type: {repo_type}. 
Accepted repo types are: {str(REPO_TYPES)}") if cache_dir is None: cache_dir = HUGGINGFACE_HUB_CACHE @@ -1452,10 +1416,7 @@ def get_hf_file_metadata( # Do not use directly `url`, as `_request_wrapper` might have followed relative # redirects. location=r.headers.get("Location") or r.request.url, # type: ignore - size=_int_or_none( - r.headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE) - or r.headers.get("Content-Length") - ), + size=_int_or_none(r.headers.get(HUGGINGFACE_HEADER_X_LINKED_SIZE) or r.headers.get("Content-Length")), ) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 61fa0db927..5ac040f0b6 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -23,9 +23,10 @@ from urllib.parse import quote import requests -from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError from requests.exceptions import HTTPError +from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError + from ._commit_api import ( CommitOperation, CommitOperationAdd, @@ -90,9 +91,7 @@ logger = logging.get_logger(__name__) -def repo_type_and_id_from_hf_id( - hf_id: str, hub_url: Optional[str] = None -) -> Tuple[Optional[str], Optional[str], str]: +def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tuple[Optional[str], Optional[str], str]: """ Returns the repo type and ID from a huggingface.co URL linking to a repository @@ -154,9 +153,7 @@ def repo_type_and_id_from_hf_id( repo_id = url_segments[0] namespace, repo_type = None, None else: - raise ValueError( - f"Unable to retrieve user and repo ID from the passed HF ID: {hf_id}" - ) + raise ValueError(f"Unable to retrieve user and repo ID from the passed HF ID: {hf_id}") # Check if repo type is known (mapping "spaces" => "space" + empty value => `None`) if repo_type in REPO_TYPES_MAPPING: @@ -276,9 +273,7 @@ def __init__(self, url: Any, endpoint: Optional[str] = None) -> None: super().__init__() # Parse URL self.endpoint = endpoint or ENDPOINT - repo_type, namespace, repo_name = repo_type_and_id_from_hf_id( - self, hub_url=self.endpoint - ) + repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(self, hub_url=self.endpoint) # Populate fields self.namespace = namespace @@ -287,10 +282,7 @@ def __init__(self, url: Any, endpoint: Optional[str] = None) -> None: self.url = str(self) # just in case it's needed def __repr__(self) -> str: - return ( - f"RepoUrl('{self}', endpoint='{self.endpoint}'," - f" repo_type='{self.repo_type}', repo_id='{self.repo_id}')" - ) + return f"RepoUrl('{self}', endpoint='{self.endpoint}', repo_type='{self.repo_type}', repo_id='{self.repo_id}')" class RepoFile: @@ -389,9 +381,7 @@ def __init__( self.lastModified = lastModified self.tags = tags self.pipeline_tag = pipeline_tag - self.siblings = ( - [RepoFile(**x) for x in siblings] if siblings is not None else [] - ) + self.siblings = [RepoFile(**x) for x in siblings] if siblings is not None else [] self.private = private self.author = author self.config = config @@ -465,9 +455,7 @@ def __init__( self.description = description self.citation = citation self.cardData = cardData - self.siblings = ( - [RepoFile(**x) for x in siblings] if siblings is not None else [] - ) + self.siblings = [RepoFile(**x) for x in siblings] if siblings is not None else [] # Legacy stuff, "key" is always returned with an empty string # because of old versions of the datasets lib that need this field kwargs.pop("key", None) @@ -524,9 +512,7 @@ def __init__( self.id = id self.sha = sha 
self.lastModified = lastModified - self.siblings = ( - [RepoFile(**x) for x in siblings] if siblings is not None else [] - ) + self.siblings = [RepoFile(**x) for x in siblings] if siblings is not None else [] self.private = private self.author = author for k, v in kwargs.items(): @@ -1030,10 +1016,7 @@ def list_models( if emissions_thresholds is not None: if cardData is None: - raise ValueError( - "`emissions_thresholds` were passed without setting" - " `cardData=True`." - ) + raise ValueError("`emissions_thresholds` were passed without setting `cardData=True`.") else: return _filter_emissions(items, *emissions_thresholds) @@ -1058,11 +1041,7 @@ def _unpack_model_filter(self, model_filter: ModelFilter): # Handling tasks if model_filter.task is not None: - filter_list.extend( - [model_filter.task] - if isinstance(model_filter.task, str) - else model_filter.task - ) + filter_list.extend([model_filter.task] if isinstance(model_filter.task, str) else model_filter.task) # Handling dataset if model_filter.trained_dataset is not None: @@ -1076,18 +1055,12 @@ def _unpack_model_filter(self, model_filter: ModelFilter): # Handling library if model_filter.library: filter_list.extend( - [model_filter.library] - if isinstance(model_filter.library, str) - else model_filter.library + [model_filter.library] if isinstance(model_filter.library, str) else model_filter.library ) # Handling tags if model_filter.tags: - tags.extend( - [model_filter.tags] - if isinstance(model_filter.tags, str) - else model_filter.tags - ) + tags.extend([model_filter.tags] if isinstance(model_filter.tags, str) else model_filter.tags) query_dict: Dict[str, Any] = {} if model_str is not None: @@ -1520,8 +1493,7 @@ def list_liked_repos( user = me["name"] else: raise ValueError( - "Cannot list liked repos. You must provide a 'user' as input or be" - " logged in as a user." + "Cannot list liked repos. You must provide a 'user' as input or be logged in as a user." 
) path = f"{self.endpoint}/api/users/{user}/likes" @@ -1540,21 +1512,9 @@ def list_liked_repos( return UserLikes( user=user, total=len(likes), - models=[ - like["repo"]["name"] - for like in likes - if like["repo"]["type"] == "model" - ], - datasets=[ - like["repo"]["name"] - for like in likes - if like["repo"]["type"] == "dataset" - ], - spaces=[ - like["repo"]["name"] - for like in likes - if like["repo"]["type"] == "space" - ], + models=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "model"], + datasets=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "dataset"], + spaces=[like["repo"]["name"] for like in likes if like["repo"]["type"] == "space"], ) @validate_hf_hub_args @@ -1613,9 +1573,7 @@ def model_info( path = ( f"{self.endpoint}/api/models/{repo_id}" if revision is None - else ( - f"{self.endpoint}/api/models/{repo_id}/revision/{quote(revision, safe='')}" - ) + else (f"{self.endpoint}/api/models/{repo_id}/revision/{quote(revision, safe='')}") ) params = {} if securityStatus: @@ -1679,9 +1637,7 @@ def dataset_info( path = ( f"{self.endpoint}/api/datasets/{repo_id}" if revision is None - else ( - f"{self.endpoint}/api/datasets/{repo_id}/revision/{quote(revision, safe='')}" - ) + else (f"{self.endpoint}/api/datasets/{repo_id}/revision/{quote(revision, safe='')}") ) params = {} if files_metadata: @@ -1744,9 +1700,7 @@ def space_info( path = ( f"{self.endpoint}/api/spaces/{repo_id}" if revision is None - else ( - f"{self.endpoint}/api/spaces/{repo_id}/revision/{quote(revision, safe='')}" - ) + else (f"{self.endpoint}/api/spaces/{repo_id}/revision/{quote(revision, safe='')}") ) params = {} if files_metadata: @@ -1981,23 +1935,17 @@ def create_repo( f" of {SPACES_SDK_TYPES} when repo_type is 'space'`" ) if space_sdk not in SPACES_SDK_TYPES: - raise ValueError( - f"Invalid space_sdk. Please choose one of {SPACES_SDK_TYPES}." - ) + raise ValueError(f"Invalid space_sdk. Please choose one of {SPACES_SDK_TYPES}.") json["sdk"] = space_sdk if space_sdk is not None and repo_type != "space": - warnings.warn( - "Ignoring provided space_sdk because repo_type is not 'space'." - ) + warnings.warn("Ignoring provided space_sdk because repo_type is not 'space'.") if space_hardware is not None: if repo_type == "space": json["hardware"] = space_hardware else: - warnings.warn( - "Ignoring provided space_hardware because repo_type is not 'space'." - ) + warnings.warn("Ignoring provided space_hardware because repo_type is not 'space'.") if getattr(self, "_lfsmultipartthresh", None): # Testing purposes only. @@ -2164,16 +2112,10 @@ def move_repo( """ if len(from_id.split("/")) != 2: - raise ValueError( - f"Invalid repo_id: {from_id}. It should have a namespace" - " (:namespace:/:repo_name:)" - ) + raise ValueError(f"Invalid repo_id: {from_id}. It should have a namespace (:namespace:/:repo_name:)") if len(to_id.split("/")) != 2: - raise ValueError( - f"Invalid repo_id: {to_id}. It should have a namespace" - " (:namespace:/:repo_name:)" - ) + raise ValueError(f"Invalid repo_id: {to_id}. It should have a namespace (:namespace:/:repo_name:)") if repo_type is None: repo_type = REPO_TYPE_MODEL # Hub won't accept `None`. @@ -2298,22 +2240,17 @@ def create_commit( if parent_commit is not None and not REGEX_COMMIT_OID.fullmatch(parent_commit): raise ValueError( - "`parent_commit` is not a valid commit OID. It must match the" - f" following regex: {REGEX_COMMIT_OID}" + f"`parent_commit` is not a valid commit OID. 
It must match the following regex: {REGEX_COMMIT_OID}" ) if commit_message is None or len(commit_message) == 0: raise ValueError("`commit_message` can't be empty, please pass a value.") - commit_description = ( - commit_description if commit_description is not None else "" - ) + commit_description = commit_description if commit_description is not None else "" repo_type = repo_type if repo_type is not None else REPO_TYPE_MODEL if repo_type not in REPO_TYPES: raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") - revision = ( - quote(revision, safe="") if revision is not None else DEFAULT_REVISION - ) + revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION create_pr = create_pr if create_pr is not None else False operations = list(operations) @@ -2321,10 +2258,7 @@ def create_commit( nb_additions = len(additions) nb_deletions = len(operations) - nb_additions - logger.debug( - f"About to commit to the hub: {len(additions)} addition(s) and" - f" {nb_deletions} deletion(s)." - ) + logger.debug(f"About to commit to the hub: {len(additions)} addition(s) and {nb_deletions} deletion(s).") # If updating twice the same file or update then delete a file in a single commit warn_on_overwriting_operations(operations) @@ -2344,11 +2278,7 @@ def create_commit( raise upload_lfs_files( - additions=[ - addition - for addition in additions - if upload_modes[addition.path_in_repo] == "lfs" - ], + additions=[addition for addition in additions if upload_modes[addition.path_in_repo] == "lfs"], repo_type=repo_type, repo_id=repo_id, token=token or self.token, @@ -2525,9 +2455,7 @@ def upload_file( raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") commit_message = ( - commit_message - if commit_message is not None - else f"Upload {path_in_repo} with huggingface_hub" + commit_message if commit_message is not None else f"Upload {path_in_repo} with huggingface_hub" ) operation = CommitOperationAdd( path_or_fileobj=path_or_fileobj, @@ -2684,9 +2612,7 @@ def upload_folder( path_in_repo = "" commit_message = ( - commit_message - if commit_message is not None - else f"Upload {path_in_repo} with huggingface_hub" + commit_message if commit_message is not None else f"Upload {path_in_repo} with huggingface_hub" ) files_to_add = _prepare_upload_folder_commit( @@ -2787,9 +2713,7 @@ def delete_file( """ commit_message = ( - commit_message - if commit_message is not None - else f"Delete {path_in_repo} with huggingface_hub" + commit_message if commit_message is not None else f"Delete {path_in_repo} with huggingface_hub" ) operations = [CommitOperationDelete(path_in_repo=path_in_repo)] @@ -2861,14 +2785,10 @@ def delete_folder( repo_id=repo_id, repo_type=repo_type, token=token, - operations=[ - CommitOperationDelete(path_in_repo=path_in_repo, is_folder=True) - ], + operations=[CommitOperationDelete(path_in_repo=path_in_repo, is_folder=True)], revision=revision, commit_message=( - commit_message - if commit_message is not None - else f"Delete folder {path_in_repo} with huggingface_hub" + commit_message if commit_message is not None else f"Delete folder {path_in_repo} with huggingface_hub" ), commit_description=commit_description, create_pr=create_pr, @@ -3044,9 +2964,7 @@ def create_tag( """ if repo_type is None: repo_type = REPO_TYPE_MODEL - revision = ( - quote(revision, safe="") if revision is not None else DEFAULT_REVISION - ) + revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION # Prepare request tag_url = 
f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{revision}" @@ -3196,9 +3114,7 @@ def get_repo_discussions( headers = self._build_hf_headers(token=token) def _fetch_discussion_page(page_index: int): - path = ( - f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?p={page_index}" - ) + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?p={page_index}" resp = requests.get(path, headers=headers) hf_raise_for_status(resp) paginated_discussions = resp.json() @@ -3272,9 +3188,7 @@ def get_discussion_details( if repo_type is None: repo_type = REPO_TYPE_MODEL - path = ( - f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions/{discussion_num}" - ) + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions/{discussion_num}" headers = self._build_hf_headers(token=token) resp = requests.get(path, params={"diff": "1"}, headers=headers) hf_raise_for_status(resp) @@ -3282,17 +3196,9 @@ def get_discussion_details( discussion_details = resp.json() is_pull_request = discussion_details["isPullRequest"] - target_branch = ( - discussion_details["changes"]["base"] if is_pull_request else None - ) - conflicting_files = ( - discussion_details["filesWithConflicts"] if is_pull_request else None - ) - merge_commit_oid = ( - discussion_details["changes"].get("mergeCommitId", None) - if is_pull_request - else None - ) + target_branch = discussion_details["changes"]["base"] if is_pull_request else None + conflicting_files = discussion_details["filesWithConflicts"] if is_pull_request else None + merge_commit_oid = discussion_details["changes"].get("mergeCommitId", None) if is_pull_request else None return DiscussionWithDetails( title=discussion_details["title"], @@ -3860,8 +3766,7 @@ def hide_discussion_comment( """ warnings.warn( - "Hidden comments' content cannot be retrieved anymore. Hiding a comment is" - " irreversible.", + "Hidden comments' content cannot be retrieved anymore. Hiding a comment is irreversible.", UserWarning, ) resp = self._post_discussion_changes( @@ -3874,9 +3779,7 @@ def hide_discussion_comment( return deserialize_event(resp.json()["updatedComment"]) # type: ignore @validate_hf_hub_args - def add_space_secret( - self, repo_id: str, key: str, value: str, *, token: Optional[str] = None - ) -> None: + def add_space_secret(self, repo_id: str, key: str, value: str, *, token: Optional[str] = None) -> None: """Adds or updates a secret in a Space. Secrets allow to set secret keys or tokens to a Space without hardcoding them. @@ -3900,9 +3803,7 @@ def add_space_secret( hf_raise_for_status(r) @validate_hf_hub_args - def delete_space_secret( - self, repo_id: str, key: str, *, token: Optional[str] = None - ) -> None: + def delete_space_secret(self, repo_id: str, key: str, *, token: Optional[str] = None) -> None: """Deletes a secret from a Space. Secrets allow to set secret keys or tokens to a Space without hardcoding them. @@ -3924,9 +3825,7 @@ def delete_space_secret( hf_raise_for_status(r) @validate_hf_hub_args - def get_space_runtime( - self, repo_id: str, *, token: Optional[str] = None - ) -> SpaceRuntime: + def get_space_runtime(self, repo_id: str, *, token: Optional[str] = None) -> SpaceRuntime: """Gets runtime information about a Space. 
Args: @@ -3953,9 +3852,7 @@ def get_space_runtime( ) @validate_hf_hub_args - def request_space_hardware( - self, repo_id: str, hardware: SpaceHardware, *, token: Optional[str] = None - ) -> None: + def request_space_hardware(self, repo_id: str, hardware: SpaceHardware, *, token: Optional[str] = None) -> None: """Request new hardware for a Space. Args: @@ -4027,9 +3924,7 @@ def _prepare_upload_folder_commit( files_to_add.append( CommitOperationAdd( path_or_fileobj=abs_path, - path_in_repo=os.path.normpath( - os.path.join(path_in_repo, rel_path) - ).replace(os.sep, "/"), + path_in_repo=os.path.normpath(os.path.join(path_in_repo, rel_path)).replace(os.sep, "/"), ) ) @@ -4054,10 +3949,7 @@ def _parse_revision_from_pr_url(pr_url: str) -> str: """ re_match = re.match(_REGEX_DISCUSSION_URL, pr_url) if re_match is None: - raise RuntimeError( - "Unexpected response from the hub, expected a Pull Request URL but got:" - f" '{pr_url}'" - ) + raise RuntimeError(f"Unexpected response from the hub, expected a Pull Request URL but got: '{pr_url}'") return f"refs/pr/{re_match[1]}" diff --git a/src/huggingface_hub/inference_api.py b/src/huggingface_hub/inference_api.py index 95bafc1e01..7c26eb01f9 100644 --- a/src/huggingface_hub/inference_api.py +++ b/src/huggingface_hub/inference_api.py @@ -146,10 +146,7 @@ def __init__( def __repr__(self): # Do not add headers to repr to avoid leaking token. - return ( - f"InferenceAPI(api_url='{self.api_url}', task='{self.task}'," - f" options={self.options})" - ) + return f"InferenceAPI(api_url='{self.api_url}', task='{self.task}', options={self.options})" def __call__( self, @@ -183,9 +180,7 @@ def __call__( payload["parameters"] = params # Make API call - response = requests.post( - self.api_url, headers=self.headers, json=payload, data=data - ) + response = requests.post(self.api_url, headers=self.headers, json=payload, data=data) # Let the user handle the response if raw_response: @@ -202,7 +197,7 @@ def __call__( " the image by yourself." ) - from PIL import Image + from PIL import Image # type: ignore return Image.open(io.BytesIO(response.content)) elif content_type == "application/json": diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py index dd3bd37cf8..788a60bc04 100644 --- a/src/huggingface_hub/keras_mixin.py +++ b/src/huggingface_hub/keras_mixin.py @@ -63,9 +63,7 @@ def _create_hyperparameter_table(model): optimizer_params = model.optimizer.get_config() # flatten the configuration optimizer_params = _flatten_dict(optimizer_params) - optimizer_params[ - "training_precision" - ] = tf.keras.mixed_precision.global_policy().name + optimizer_params["training_precision"] = tf.keras.mixed_precision.global_policy().name table = "| Hyperparameters | Value |\n| :-- | :-- |\n" for key, value in optimizer_params.items(): table += f"| {key} | {value} |\n" @@ -170,9 +168,7 @@ def save_pretrained_keras( if is_tf_available(): import tensorflow as tf else: - raise ImportError( - "Called a Tensorflow-specific function but could not import it." - ) + raise ImportError("Called a Tensorflow-specific function but could not import it.") if not model.built: raise ValueError("Model should be built before trying to save") @@ -183,10 +179,7 @@ def save_pretrained_keras( # saving config if config: if not isinstance(config, dict): - raise RuntimeError( - "Provided config to save_pretrained_keras should be a dict. Got:" - f" '{type(config)}'" - ) + raise RuntimeError(f"Provided config to save_pretrained_keras should be a dict. 
Got: '{type(config)}'") with (save_directory / CONFIG_NAME).open("w") as f: json.dump(config, f) @@ -213,17 +206,14 @@ def save_pretrained_keras( path = save_directory / "history.json" if path.exists(): warnings.warn( - "`history.json` file already exists, it will be overwritten by the" - " history of this version.", + "`history.json` file already exists, it will be overwritten by the history of this version.", UserWarning, ) with path.open("w", encoding="utf-8") as f: json.dump(model.history.history, f, indent=2, sort_keys=True) _create_model_card(model, save_directory, plot_model, metadata) - tf.keras.models.save_model( - model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs - ) + tf.keras.models.save_model(model, save_directory, include_optimizer=include_optimizer, **model_save_kwargs) def from_pretrained_keras(*args, **kwargs): @@ -486,9 +476,7 @@ def _from_pretrained( if is_tf_available(): import tensorflow as tf else: - raise ImportError( - "Called a TensorFlow-specific function but could not import it." - ) + raise ImportError("Called a TensorFlow-specific function but could not import it.") # TODO - Figure out what to do about these config values. Config is not going to be needed to load model cfg = model_kwargs.pop("config", None) diff --git a/src/huggingface_hub/lfs.py b/src/huggingface_hub/lfs.py index 944d674662..514729c57b 100644 --- a/src/huggingface_hub/lfs.py +++ b/src/huggingface_hub/lfs.py @@ -23,9 +23,10 @@ from typing import BinaryIO, Iterable, List, Optional, Tuple import requests -from huggingface_hub.constants import ENDPOINT, REPO_TYPES_URL_PREFIXES from requests.auth import HTTPBasicAuth +from huggingface_hub.constants import ENDPOINT, REPO_TYPES_URL_PREFIXES + from .utils import ( get_token_to_send, hf_raise_for_status, @@ -92,10 +93,7 @@ def _validate_lfs_action(lfs_action: dict): """validates response from the LFS batch endpoint""" if not ( isinstance(lfs_action.get("href"), str) - and ( - lfs_action.get("header") is None - or isinstance(lfs_action.get("header"), dict) - ) + and (lfs_action.get("header") is None or isinstance(lfs_action.get("header"), dict)) ): raise ValueError("lfs_action is improperly formatted") return lfs_action @@ -103,10 +101,7 @@ def _validate_lfs_action(lfs_action: dict): def _validate_batch_actions(lfs_batch_actions: dict): """validates response from the LFS batch endpoint""" - if not ( - isinstance(lfs_batch_actions.get("oid"), str) - and isinstance(lfs_batch_actions.get("size"), int) - ): + if not (isinstance(lfs_batch_actions.get("oid"), str) and isinstance(lfs_batch_actions.get("size"), int)): raise ValueError("lfs_batch_actions is improperly formatted") upload_action = lfs_batch_actions.get("actions", {}).get("upload") @@ -120,10 +115,7 @@ def _validate_batch_actions(lfs_batch_actions: dict): def _validate_batch_error(lfs_batch_error: dict): """validates response from the LFS batch endpoint""" - if not ( - isinstance(lfs_batch_error.get("oid"), str) - and isinstance(lfs_batch_error.get("size"), int) - ): + if not (isinstance(lfs_batch_error.get("oid"), str) and isinstance(lfs_batch_error.get("size"), int)): raise ValueError("lfs_batch_error is improperly formatted") error_info = lfs_batch_error.get("error") if not ( @@ -258,10 +250,7 @@ def lfs_upload( if isinstance(chunk_size, str): chunk_size = int(chunk_size, 10) else: - raise ValueError( - "Malformed response from LFS batch endpoint: `chunk_size`" - " should be a string" - ) + raise ValueError("Malformed response from LFS batch endpoint: `chunk_size` 
should be a string")
             _upload_multi_part(
                 completion_url=upload_action["href"],
                 fileobj=fileobj,
@@ -385,10 +374,7 @@ def _upload_multi_part(
         hf_raise_for_status(part_upload_res)
         etag = part_upload_res.headers.get("etag")
         if etag is None or etag == "":
-            raise ValueError(
-                f"Invalid etag (`{etag}`) returned for part {part_idx +1} of"
-                f" {num_parts}"
-            )
+            raise ValueError(f"Invalid etag (`{etag}`) returned for part {part_idx + 1} of {num_parts}")
         completion_payload["parts"][part_idx]["etag"] = etag
 
     completion_res = requests.post(
@@ -466,9 +452,7 @@ def read(self, n: int = -1):
         if pos >= self._len:
             return b""
         remaining_amount = self._len - pos
-        data = self.fileobj.read(
-            remaining_amount if n < 0 else min(n, remaining_amount)
-        )
+        data = self.fileobj.read(remaining_amount if n < 0 else min(n, remaining_amount))
         return data
 
     def tell(self) -> int:
diff --git a/src/huggingface_hub/repocard.py b/src/huggingface_hub/repocard.py
index 84fa1dd42e..30ed3b19fa 100644
--- a/src/huggingface_hub/repocard.py
+++ b/src/huggingface_hub/repocard.py
@@ -5,6 +5,7 @@
 
 import requests
 import yaml
+
 from huggingface_hub.file_download import hf_hub_download
 from huggingface_hub.hf_api import upload_file
 from huggingface_hub.repocard_data import (
@@ -24,9 +25,7 @@
 
 TEMPLATE_MODELCARD_PATH = Path(__file__).parent / "templates" / "modelcard_template.md"
-TEMPLATE_DATASETCARD_PATH = (
-    Path(__file__).parent / "templates" / "datasetcard_template.md"
-)
+TEMPLATE_DATASETCARD_PATH = Path(__file__).parent / "templates" / "datasetcard_template.md"
 
 # exact same regex as in the Hub server. Please keep in sync.
 # See https://github.com/huggingface/moon-landing/blob/main/server/lib/ViewMarkdown.ts#L18
@@ -101,9 +100,7 @@ def content(self, content: str):
                 raise ValueError("repo card metadata block should be a dict")
         else:
             # Model card without metadata... create empty metadata
-            logger.warning(
-                "Repo card metadata block was not found. Setting CardData to empty."
-            )
+            logger.warning("Repo card metadata block was not found. Setting CardData to empty.")
             data_dict = {}
 
         self.text = content
@@ -176,9 +173,7 @@ def load(
                 token=token,
             )
         else:
-            raise ValueError(
-                f"Cannot load RepoCard: path not found on disk ({repo_id_or_path})."
-            )
+            raise ValueError(f"Cannot load RepoCard: path not found on disk ({repo_id_or_path}).")
 
         # Preserve newlines in the existing file.
with Path(card_path).open(mode="r", newline="", encoding="utf-8") as f: @@ -215,9 +210,7 @@ def validate(self, repo_type: Optional[str] = None): headers = {"Accept": "text/plain"} try: - r = requests.post( - "https://huggingface.co/api/validate-yaml", body, headers=headers - ) + r = requests.post("https://huggingface.co/api/validate-yaml", body, headers=headers) r.raise_for_status() except requests.exceptions.HTTPError as exc: if r.status_code == 400: @@ -321,9 +314,7 @@ def from_template( kwargs = card_data.to_dict().copy() kwargs.update(template_kwargs) # Template_kwargs have priority - template = jinja2.Template( - Path(template_path or cls.default_template_path).read_text() - ) + template = jinja2.Template(Path(template_path or cls.default_template_path).read_text()) content = template.render(card_data=card_data.to_yaml(), **kwargs) return cls(content) @@ -472,7 +463,7 @@ def from_template( # type: ignore # violates Liskov property but easier to use return super().from_template(card_data, template_path, **template_kwargs) -def _detect_line_ending(content: str) -> Literal["\r", "\n", "\r\n", None]: +def _detect_line_ending(content: str) -> Literal["\r", "\n", "\r\n", None]: # noqa: F722 """Detect the line ending of a string. Used by RepoCard to avoid making huge diff on newlines. Uses same implem as in Hub server, keep it in sync. @@ -531,11 +522,7 @@ def metadata_save(local_path: Union[str, Path], data: Dict) -> None: # sort_keys: keep dict order match = REGEX_YAML_BLOCK.search(content) if match: - output = ( - content[: match.start()] - + f"---{line_break}{data_yaml}---{line_break}" - + content[match.end() :] - ) + output = content[: match.start()] + f"---{line_break}{data_yaml}---{line_break}" + content[match.end() :] else: output = f"---{line_break}{data_yaml}---{line_break}{content}" @@ -739,11 +726,7 @@ def metadata_update( ``` """ - commit_message = ( - commit_message - if commit_message is not None - else "Update metadata with huggingface_hub" - ) + commit_message = commit_message if commit_message is not None else "Update metadata with huggingface_hub" # Card class given repo_type card_class: Type[RepoCard] @@ -762,10 +745,7 @@ def metadata_update( card = card_class.load(repo_id, token=token, repo_type=repo_type) except EntryNotFoundError: if repo_type == "space": - raise ValueError( - "Cannot update metadata on a Space that doesn't contain a `README.md`" - " file." - ) + raise ValueError("Cannot update metadata on a Space that doesn't contain a `README.md` file.") # Initialize a ModelCard or DatasetCard from default template and no data. card = card_class.from_template(CardData()) diff --git a/src/huggingface_hub/repocard_data.py b/src/huggingface_hub/repocard_data.py index bc70766db6..7db39a0f3b 100644 --- a/src/huggingface_hub/repocard_data.py +++ b/src/huggingface_hub/repocard_data.py @@ -266,9 +266,7 @@ def __init__( self.model_name = model_name self.eval_results = eval_results except KeyError: - logger.warning( - "Invalid model-index. Not loading eval results into CardData." - ) + logger.warning("Invalid model-index. Not loading eval results into CardData.") super().__init__(**kwargs) @@ -276,16 +274,12 @@ def __init__( if type(self.eval_results) == EvalResult: self.eval_results = [self.eval_results] if self.model_name is None: - raise ValueError( - "Passing `eval_results` requires `model_name` to be set." - ) + raise ValueError("Passing `eval_results` requires `model_name` to be set.") def _to_dict(self, data_dict): """Format the internal data dict. 
In this case, we convert eval results to a valid model index""" if self.eval_results is not None: - data_dict["model-index"] = eval_results_to_model_index( - self.model_name, self.eval_results - ) + data_dict["model-index"] = eval_results_to_model_index(self.model_name, self.eval_results) del data_dict["eval_results"], data_dict["model_name"] @@ -368,9 +362,7 @@ def _to_dict(self, data_dict): data_dict["train-eval-index"] = data_dict.pop("train_eval_index") -def model_index_to_eval_results( - model_index: List[Dict[str, Any]] -) -> Tuple[str, List[EvalResult]]: +def model_index_to_eval_results(model_index: List[Dict[str, Any]]) -> Tuple[str, List[EvalResult]]: """Takes in a model index and returns the model name and a list of `huggingface_hub.EvalResult` objects. A detailed spec of the model index can be found here: @@ -477,18 +469,12 @@ def _remove_none(obj): if isinstance(obj, (list, tuple, set)): return type(obj)(_remove_none(x) for x in obj if x is not None) elif isinstance(obj, dict): - return type(obj)( - (_remove_none(k), _remove_none(v)) - for k, v in obj.items() - if k is not None and v is not None - ) + return type(obj)((_remove_none(k), _remove_none(v)) for k, v in obj.items() if k is not None and v is not None) else: return obj -def eval_results_to_model_index( - model_name: str, eval_results: List[EvalResult] -) -> List[Dict[str, Any]]: +def eval_results_to_model_index(model_name: str, eval_results: List[EvalResult]) -> List[Dict[str, Any]]: """Takes in given model name and list of `huggingface_hub.EvalResult` and returns a valid model-index that will be compatible with the format expected by the Hugging Face Hub. diff --git a/src/huggingface_hub/repository.py b/src/huggingface_hub/repository.py index ddda4fb535..676d284831 100644 --- a/src/huggingface_hub/repository.py +++ b/src/huggingface_hub/repository.py @@ -124,9 +124,7 @@ def is_git_repo(folder: Union[str, Path]) -> bool: otherwise. """ folder_exists = os.path.exists(os.path.join(folder, ".git")) - git_branch = subprocess.run( - "git branch".split(), cwd=folder, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) + git_branch = subprocess.run("git branch".split(), cwd=folder, stdout=subprocess.PIPE, stderr=subprocess.PIPE) return folder_exists and git_branch.returncode == 0 @@ -234,17 +232,13 @@ def is_binary_file(filename: Union[str, Path]) -> bool: # Code sample taken from the following stack overflow thread # https://stackoverflow.com/questions/898669/how-can-i-detect-if-a-file-is-binary-non-text-in-python/7392391#7392391 - text_chars = bytearray( - {7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F} - ) + text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F}) return bool(content.translate(None, text_chars)) except UnicodeDecodeError: return True -def files_to_be_staged( - pattern: str = ".", folder: Union[str, Path, None] = None -) -> List[str]: +def files_to_be_staged(pattern: str = ".", folder: Union[str, Path, None] = None) -> List[str]: """ Returns a list of filenames that are to be staged. @@ -520,14 +514,9 @@ def __init__( if is_git_repo(self.local_dir): logger.debug("[Repository] is a valid git repo") else: - raise ValueError( - "If not specifying `clone_from`, you need to pass Repository a" - " valid git clone." 
- ) + raise ValueError("If not specifying `clone_from`, you need to pass Repository a valid git clone.") - if self.huggingface_token is not None and ( - git_email is None or git_user is None - ): + if self.huggingface_token is not None and (git_email is None or git_user is None): user = self.client.whoami(self.huggingface_token) if git_email is None: @@ -558,9 +547,7 @@ def current_branch(self) -> str: `str`: Current checked out branch. """ try: - result = run_subprocess( - "git rev-parse --abbrev-ref HEAD", self.local_dir - ).stdout.strip() + result = run_subprocess("git rev-parse --abbrev-ref HEAD", self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) @@ -577,14 +564,10 @@ def check_git_versions(self): try: git_version = run_subprocess("git --version", self.local_dir).stdout.strip() except FileNotFoundError: - raise EnvironmentError( - "Looks like you do not have git installed, please install." - ) + raise EnvironmentError("Looks like you do not have git installed, please install.") try: - lfs_version = run_subprocess( - "git-lfs --version", self.local_dir - ).stdout.strip() + lfs_version = run_subprocess("git-lfs --version", self.local_dir).stdout.strip() except FileNotFoundError: raise EnvironmentError( "Looks like you do not have git-lfs installed, please install." @@ -645,12 +628,8 @@ def clone_from(self, repo_url: str, token: Union[bool, str, None] = None): ) hub_url = self.client.endpoint - if hub_url in repo_url or ( - "http" not in repo_url and len(repo_url.split("/")) <= 2 - ): - repo_type, namespace, repo_name = repo_type_and_id_from_hf_id( - repo_url, hub_url=hub_url - ) + if hub_url in repo_url or ("http" not in repo_url and len(repo_url.split("/")) <= 2): + repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(repo_url, hub_url=hub_url) repo_id = f"{namespace}/{repo_name}" if namespace is not None else repo_name if repo_type is not None: @@ -710,9 +689,7 @@ def clone_from(self, repo_url: str, token: Union[bool, str, None] = None): " `repo.git_pull()`." ) else: - output = run_subprocess( - "git remote get-url origin", self.local_dir, check=False - ) + output = run_subprocess("git remote get-url origin", self.local_dir, check=False) error_msg = ( f"Tried to clone {clean_repo_url} in an unrelated git" @@ -720,21 +697,14 @@ def clone_from(self, repo_url: str, token: Union[bool, str, None] = None): f" a remote with the following URL: {clean_repo_url}." ) if output.returncode == 0: - clean_local_remote_url = re.sub( - r"https://.*@", "https://", output.stdout - ) - error_msg += ( - "\nLocal path has its origin defined as:" - f" {clean_local_remote_url}" - ) + clean_local_remote_url = re.sub(r"https://.*@", "https://", output.stdout) + error_msg += f"\nLocal path has its origin defined as: {clean_local_remote_url}" raise EnvironmentError(error_msg) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) - def git_config_username_and_email( - self, git_user: Optional[str] = None, git_email: Optional[str] = None - ): + def git_config_username_and_email(self, git_user: Optional[str] = None, git_email: Optional[str] = None): """ Sets git username and email (only in the current repo). 
@@ -746,14 +716,10 @@ def git_config_username_and_email( """ try: if git_user is not None: - run_subprocess( - "git config user.name".split() + [git_user], self.local_dir - ) + run_subprocess("git config user.name".split() + [git_user], self.local_dir) if git_email is not None: - run_subprocess( - f"git config user.email {git_email}".split(), self.local_dir - ) + run_subprocess(f"git config user.email {git_email}".split(), self.local_dir) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) @@ -836,14 +802,10 @@ def list_deleted_files(self) -> List[str]: modified_files_statuses = [status.strip() for status in git_status.split("\n")] # Only keep files that are deleted using the D prefix - deleted_files_statuses = [ - status for status in modified_files_statuses if "D" in status.split()[0] - ] + deleted_files_statuses = [status for status in modified_files_statuses if "D" in status.split()[0]] # Remove the D prefix and strip to keep only the relevant filename - deleted_files = [ - status.split()[-1].strip() for status in deleted_files_statuses - ] + deleted_files = [status.split()[-1].strip() for status in deleted_files_statuses] return deleted_files @@ -969,11 +931,7 @@ def auto_track_large_files(self, pattern: str = ".") -> List[str]: path_to_file = os.path.join(os.getcwd(), self.local_dir, filename) size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024) - if ( - size_in_mb >= 10 - and not is_tracked_with_lfs(path_to_file) - and not is_git_ignored(path_to_file) - ): + if size_in_mb >= 10 and not is_tracked_with_lfs(path_to_file) and not is_git_ignored(path_to_file): self.lfs_track(filename) files_to_be_tracked_with_lfs.append(filename) @@ -995,9 +953,7 @@ def lfs_prune(self, recent=False): """ try: with _lfs_log_progress(): - result = run_subprocess( - f"git lfs prune {'--recent' if recent else ''}", self.local_dir - ) + result = run_subprocess(f"git lfs prune {'--recent' if recent else ''}", self.local_dir) logger.info(result.stdout) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) @@ -1069,9 +1025,7 @@ def git_commit(self, commit_message: str = "commit files to HF hub"): The message attributed to the commit. """ try: - result = run_subprocess( - "git commit -v -m".split() + [commit_message], self.local_dir - ) + result = run_subprocess("git commit -v -m".split() + [commit_message], self.local_dir) logger.info(f"Committed:\n{result.stdout}\n") except subprocess.CalledProcessError as exc: if len(exc.stderr) > 0: @@ -1115,9 +1069,7 @@ def git_push( number_of_commits = commits_to_push(self.local_dir, upstream) if number_of_commits > 1: - logger.warning( - f"Several commits ({number_of_commits}) will be pushed upstream." 
- ) + logger.warning(f"Several commits ({number_of_commits}) will be pushed upstream.") if blocking: logger.warning("The progress bars may be unreliable.") @@ -1140,9 +1092,7 @@ def git_push( logger.warning(stderr) if return_code: - raise subprocess.CalledProcessError( - return_code, process.args, output=stdout, stderr=stderr - ) + raise subprocess.CalledProcessError(return_code, process.args, output=stdout, stderr=stderr) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) @@ -1197,12 +1147,9 @@ def git_checkout(self, revision: str, create_branch_ok: bool = False): raise EnvironmentError(exc.stderr) else: try: - result = run_subprocess( - f"git checkout -b {revision}", self.local_dir - ) + result = run_subprocess(f"git checkout -b {revision}", self.local_dir) logger.warning( - f"Revision `{revision}` does not exist. Created and checked out" - f" branch `{revision}`." + f"Revision `{revision}` does not exist. Created and checked out branch `{revision}`." ) logger.warning(result.stdout) except subprocess.CalledProcessError as exc: @@ -1224,9 +1171,7 @@ def tag_exists(self, tag_name: str, remote: Optional[str] = None) -> bool: """ if remote: try: - result = run_subprocess( - f"git ls-remote origin refs/tags/{tag_name}", self.local_dir - ).stdout.strip() + result = run_subprocess(f"git ls-remote origin refs/tags/{tag_name}", self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) @@ -1265,25 +1210,19 @@ def delete_tag(self, tag_name: str, remote: Optional[str] = None) -> bool: if delete_locally: try: - run_subprocess( - ["git", "tag", "-d", tag_name], self.local_dir - ).stdout.strip() + run_subprocess(["git", "tag", "-d", tag_name], self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) if remote and delete_remotely: try: - run_subprocess( - f"git push {remote} --delete {tag_name}", self.local_dir - ).stdout.strip() + run_subprocess(f"git push {remote} --delete {tag_name}", self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) return True - def add_tag( - self, tag_name: str, message: Optional[str] = None, remote: Optional[str] = None - ): + def add_tag(self, tag_name: str, message: Optional[str] = None, remote: Optional[str] = None): """ Add a tag at the current head and push it @@ -1313,9 +1252,7 @@ def add_tag( if remote: try: - run_subprocess( - f"git push {remote} {tag_name}", self.local_dir - ).stdout.strip() + run_subprocess(f"git push {remote} {tag_name}", self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) @@ -1327,9 +1264,7 @@ def is_repo_clean(self) -> bool: `bool`: `True` if the git status is clean, `False` otherwise. """ try: - git_status = run_subprocess( - "git status --porcelain", self.local_dir - ).stdout.strip() + git_status = run_subprocess("git status --porcelain", self.local_dir).stdout.strip() except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) @@ -1446,10 +1381,7 @@ def commit( logger.warning("Pulling changes ...") self.git_pull(rebase=True) else: - logger.warning( - "The current branch has no upstream branch. Will push to 'origin" - f" {self.current_branch}'" - ) + logger.warning(f"The current branch has no upstream branch. 
Will push to 'origin {self.current_branch}'") current_working_directory = os.getcwd() os.chdir(os.path.join(current_working_directory, self.local_dir)) @@ -1475,10 +1407,7 @@ def commit( except OSError as e: # If no changes are detected, there is nothing to commit. if "could not read Username" in str(e): - raise OSError( - "Couldn't authenticate user for push. Did you set" - " `token` to `True`?" - ) from e + raise OSError("Couldn't authenticate user for push. Did you set `token` to `True`?") from e else: raise e @@ -1514,17 +1443,13 @@ def wait_for_commands(self): """ index = 0 for command_failed in self.commands_failed: - logger.error( - f"The {command_failed.title} command with PID" - f" {command_failed._process.pid} failed." - ) + logger.error(f"The {command_failed.title} command with PID {command_failed._process.pid} failed.") logger.error(command_failed.stderr) while self.commands_in_progress: if index % 10 == 0: logger.error( - "Waiting for the following commands to finish before shutting" - f" down: {self.commands_in_progress}." + f"Waiting for the following commands to finish before shutting down: {self.commands_in_progress}." ) index += 1 diff --git a/src/huggingface_hub/utils/_cache_assets.py b/src/huggingface_hub/utils/_cache_assets.py index 467ddaae54..d6a6421e3b 100644 --- a/src/huggingface_hub/utils/_cache_assets.py +++ b/src/huggingface_hub/utils/_cache_assets.py @@ -129,10 +129,7 @@ def cached_assets_path( try: path.mkdir(exist_ok=True, parents=True) except (FileExistsError, NotADirectoryError): - raise ValueError( - "Corrupted assets folder: cannot create directory because of an existing" - f" file ({path})." - ) + raise ValueError(f"Corrupted assets folder: cannot create directory because of an existing file ({path}).") # Return return path diff --git a/src/huggingface_hub/utils/_cache_manager.py b/src/huggingface_hub/utils/_cache_manager.py index a4e966aeaf..3e1443a789 100644 --- a/src/huggingface_hub/utils/_cache_manager.py +++ b/src/huggingface_hub/utils/_cache_manager.py @@ -430,9 +430,7 @@ def delete_revisions(self, *revisions: str) -> DeleteCacheStrategy: """ hashes_to_delete: Set[str] = set(revisions) - repos_with_revisions: Dict[ - CachedRepoInfo, Set[CachedRevisionInfo] - ] = defaultdict(set) + repos_with_revisions: Dict[CachedRepoInfo, Set[CachedRevisionInfo]] = defaultdict(set) for repo in self.repos: for revision in repo.revisions: @@ -441,10 +439,7 @@ def delete_revisions(self, *revisions: str) -> DeleteCacheStrategy: hashes_to_delete.remove(revision.commit_hash) if len(hashes_to_delete) > 0: - logger.warning( - "Revision(s) not found - cannot delete them:" - f" {', '.join(hashes_to_delete)}" - ) + logger.warning(f"Revision(s) not found - cannot delete them: {', '.join(hashes_to_delete)}") delete_strategy_blobs: Set[Path] = set() delete_strategy_refs: Set[Path] = set() @@ -591,8 +586,10 @@ def scan_cache_dir(cache_dir: Optional[Union[str, Path]] = None) -> HFCacheInfo: cache_dir = Path(cache_dir).expanduser().resolve() if not cache_dir.exists(): raise CacheNotFound( - f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument" - " or set `HUGGINGFACE_HUB_CACHE` environment variable.", + ( + f"Cache directory not found: {cache_dir}. Please use `cache_dir`" + " argument or set `HUGGINGFACE_HUB_CACHE` environment variable." 
+ ), cache_dir=cache_dir, ) @@ -627,9 +624,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: raise CorruptedCacheException(f"Repo path is not a directory: {repo_path}") if "--" not in repo_path.name: - raise CorruptedCacheException( - f"Repo path is not a valid HuggingFace cache directory: {repo_path}" - ) + raise CorruptedCacheException(f"Repo path is not a valid HuggingFace cache directory: {repo_path}") repo_type, repo_id = repo_path.name.split("--", maxsplit=1) repo_type = repo_type[:-1] # "models" -> "model" @@ -637,8 +632,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: if repo_type not in {"dataset", "model", "space"}: raise CorruptedCacheException( - f"Repo type must be `dataset`, `model` or `space`, found `{repo_type}`" - f" ({repo_path})." + f"Repo type must be `dataset`, `model` or `space`, found `{repo_type}` ({repo_path})." ) blob_stats: Dict[Path, os.stat_result] = {} # Key is blob_path, value is blob stats @@ -647,9 +641,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: refs_path = repo_path / "refs" if not snapshots_path.exists() or not snapshots_path.is_dir(): - raise CorruptedCacheException( - f"Snapshots dir doesn't exist in cached repo: {snapshots_path}" - ) + raise CorruptedCacheException(f"Snapshots dir doesn't exist in cached repo: {snapshots_path}") # Scan over `refs` directory @@ -663,9 +655,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: # └── pr # └── 1 if refs_path.is_file(): - raise CorruptedCacheException( - f"Refs directory cannot be a file: {refs_path}" - ) + raise CorruptedCacheException(f"Refs directory cannot be a file: {refs_path}") for ref_path in refs_path.glob("**/*"): # glob("**/*") iterates over all files and directories -> skip directories @@ -682,9 +672,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: cached_revisions: Set[CachedRevisionInfo] = set() for revision_path in snapshots_path.iterdir(): if revision_path.is_file(): - raise CorruptedCacheException( - f"Snapshots folder corrupted. Found a file: {revision_path}" - ) + raise CorruptedCacheException(f"Snapshots folder corrupted. 
Found a file: {revision_path}") cached_files = set() for file_path in revision_path.glob("**/*"): @@ -694,9 +682,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: blob_path = Path(file_path).resolve() if not blob_path.exists(): - raise CorruptedCacheException( - f"Blob missing (broken symlink): {blob_path}" - ) + raise CorruptedCacheException(f"Blob missing (broken symlink): {blob_path}") if blob_path not in blob_stats: blob_stats[blob_path] = blob_path.stat() @@ -715,9 +701,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: # Last modified is either the last modified blob file or the revision folder # itself if it is empty if len(cached_files) > 0: - revision_last_modified = max( - blob_stats[file.blob_path].st_mtime for file in cached_files - ) + revision_last_modified = max(blob_stats[file.blob_path].st_mtime for file in cached_files) else: revision_last_modified = revision_path.stat().st_mtime @@ -727,8 +711,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: files=frozenset(cached_files), refs=frozenset(refs_by_hash.pop(revision_path.name, set())), size_on_disk=sum( - blob_stats[blob_path].st_size - for blob_path in set(file.blob_path for file in cached_files) + blob_stats[blob_path].st_size for blob_path in set(file.blob_path for file in cached_files) ), snapshot_path=revision_path, last_modified=revision_last_modified, @@ -738,8 +721,7 @@ def _scan_cached_repo(repo_path: Path) -> CachedRepoInfo: # Check that all refs referred to an existing revision if len(refs_by_hash) > 0: raise CorruptedCacheException( - "Reference(s) refer to missing commit hashes:" - f" {dict(refs_by_hash)} ({repo_path})." + f"Reference(s) refer to missing commit hashes: {dict(refs_by_hash)} ({repo_path})." ) # Last modified is either the last modified blob file or the repo folder itself if @@ -823,10 +805,6 @@ def _try_delete_path(path: Path, path_type: str) -> None: else: shutil.rmtree(path) except FileNotFoundError: - logger.warning( - f"Couldn't delete {path_type}: file not found ({path})", exc_info=True - ) + logger.warning(f"Couldn't delete {path_type}: file not found ({path})", exc_info=True) except PermissionError: - logger.warning( - f"Couldn't delete {path_type}: permission denied ({path})", exc_info=True - ) + logger.warning(f"Couldn't delete {path_type}: permission denied ({path})", exc_info=True) diff --git a/src/huggingface_hub/utils/_deprecation.py b/src/huggingface_hub/utils/_deprecation.py index 68a14b11be..b8031ca430 100644 --- a/src/huggingface_hub/utils/_deprecation.py +++ b/src/huggingface_hub/utils/_deprecation.py @@ -36,9 +36,11 @@ def inner_f(*args, **kwargs): ] args_msg = ", ".join(args_msg) warnings.warn( - f"Deprecated positional argument(s) used in '{f.__name__}': pass" - f" {args_msg} as keyword args. From version {version} passing these as" - " positional arguments will result in an error,", + ( + f"Deprecated positional argument(s) used in '{f.__name__}': pass" + f" {args_msg} as keyword args. From version {version} passing these" + " as positional arguments will result in an error," + ), FutureWarning, ) kwargs.update(zip(sig.parameters, args)) @@ -120,8 +122,7 @@ def _inner_deprecate_method(f): @wraps(f) def inner_f(*args, **kwargs): warning_message = ( - f"'{f.__name__}' (from '{f.__module__}') is deprecated and will be" - f" removed from version '{version}'." + f"'{f.__name__}' (from '{f.__module__}') is deprecated and will be removed from version '{version}'." 
) if message is not None: warning_message += " " + message @@ -189,14 +190,9 @@ class DeprecateListMetaclass(type): def __new__(cls, clsname, bases, attrs): # Check consistency if "_deprecate" not in attrs: - raise TypeError( - "A `_deprecate` method must be implemented to use" - " `DeprecateListMetaclass`." - ) + raise TypeError("A `_deprecate` method must be implemented to use `DeprecateListMetaclass`.") if list not in bases: - raise TypeError( - "Class must inherit from `list` to use `DeprecateListMetaclass`." - ) + raise TypeError("Class must inherit from `list` to use `DeprecateListMetaclass`.") # Create decorator to deprecate list-only methods, including magic ones def _with_deprecation(f, name): diff --git a/src/huggingface_hub/utils/_errors.py b/src/huggingface_hub/utils/_errors.py index 766cff7689..53037255ff 100644 --- a/src/huggingface_hub/utils/_errors.py +++ b/src/huggingface_hub/utils/_errors.py @@ -52,9 +52,7 @@ def __init__(self, message: str, response: Optional[Response]): server_message_from_headers = response.headers.get("X-Error-Message") server_message_from_body = server_data.get("error") server_multiple_messages_from_body = "\n".join( - error["message"] - for error in server_data.get("errors", []) - if "message" in error + error["message"] for error in server_data.get("errors", []) if "message" in error ) # Concatenate error messages @@ -203,9 +201,7 @@ class BadRequestError(HfHubHTTPError, ValueError): """ -def hf_raise_for_status( - response: Response, endpoint_name: Optional[str] = None -) -> None: +def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) -> None: """ Internal version of `response.raise_for_status()` that will refine a potential HTTPError. Raised exception will be an instance of `HfHubHTTPError`. @@ -266,26 +262,16 @@ def hf_raise_for_status( error_code = response.headers.get("X-Error-Code") if error_code == "RevisionNotFound": - message = ( - f"{response.status_code} Client Error." - + "\n\n" - + f"Revision Not Found for url: {response.url}." - ) + message = f"{response.status_code} Client Error." + "\n\n" + f"Revision Not Found for url: {response.url}." raise RevisionNotFoundError(message, response) from e elif error_code == "EntryNotFound": - message = ( - f"{response.status_code} Client Error." - + "\n\n" - + f"Entry Not Found for url: {response.url}." - ) + message = f"{response.status_code} Client Error." + "\n\n" + f"Entry Not Found for url: {response.url}." raise EntryNotFoundError(message, response) from e elif error_code == "GatedRepo": message = ( - f"{response.status_code} Client Error." - + "\n\n" - + f"Cannot access gated repo for url {response.url}." + f"{response.status_code} Client Error." + "\n\n" + f"Cannot access gated repo for url {response.url}." 
) raise GatedRepoError(message, response) from e @@ -307,9 +293,7 @@ def hf_raise_for_status( elif response.status_code == 400: message = ( - f"\n\nBad request for {endpoint_name} endpoint:" - if endpoint_name is not None - else "\n\nBad request:" + f"\n\nBad request for {endpoint_name} endpoint:" if endpoint_name is not None else "\n\nBad request:" ) raise BadRequestError(message, response=response) from e @@ -340,9 +324,7 @@ def _raise_convert_bad_request(response: Response, endpoint_name: str): hf_raise_for_status(response, endpoint_name=endpoint_name) -def _format_error_message( - message: str, request_id: Optional[str], server_message: Optional[str] -) -> str: +def _format_error_message(message: str, request_id: Optional[str], server_message: Optional[str]) -> str: """ Format the `HfHubHTTPError` error message based on initial message and information returned by the server. @@ -350,11 +332,7 @@ def _format_error_message( Used when initializing `HfHubHTTPError`. """ # Add message from response body - if ( - server_message is not None - and len(server_message) > 0 - and server_message.lower() not in message.lower() - ): + if server_message is not None and len(server_message) > 0 and server_message.lower() not in message.lower(): if "\n\n" in message: message += "\n" + server_message else: @@ -365,9 +343,7 @@ def _format_error_message( request_id_message = f" (Request ID: {request_id})" if "\n" in message: newline_index = message.index("\n") - message = ( - message[:newline_index] + request_id_message + message[newline_index:] - ) + message = message[:newline_index] + request_id_message + message[newline_index:] else: message += request_id_message diff --git a/src/huggingface_hub/utils/_fixes.py b/src/huggingface_hub/utils/_fixes.py index cd2ca678b1..ff4f9e2d70 100644 --- a/src/huggingface_hub/utils/_fixes.py +++ b/src/huggingface_hub/utils/_fixes.py @@ -31,9 +31,7 @@ # >>> yaml_dump({"emoji": "👀", "some unicode": "日本か"}) # 'emoji: "👀"\nsome unicode: "日本か"\n' # ``` -yaml_dump: Callable[..., str] = partial( # type: ignore - yaml.dump, stream=None, allow_unicode=True -) +yaml_dump: Callable[..., str] = partial(yaml.dump, stream=None, allow_unicode=True) # type: ignore @contextlib.contextmanager @@ -53,9 +51,7 @@ def SoftTemporaryDirectory( See https://www.scivision.dev/python-tempfile-permission-error-windows/. 
""" - tmpdir = tempfile.TemporaryDirectory( - prefix=prefix, suffix=suffix, dir=dir, **kwargs - ) + tmpdir = tempfile.TemporaryDirectory(prefix=prefix, suffix=suffix, dir=dir, **kwargs) yield tmpdir.name try: diff --git a/src/huggingface_hub/utils/_git_credential.py b/src/huggingface_hub/utils/_git_credential.py index 1cc5a2d061..48e319fe9f 100644 --- a/src/huggingface_hub/utils/_git_credential.py +++ b/src/huggingface_hub/utils/_git_credential.py @@ -40,18 +40,14 @@ def list_credential_helpers(folder: Optional[str] = None) -> List[str]: # See: https://github.com/huggingface/huggingface_hub/pull/1138#discussion_r1013324508 return sorted( # Sort for nice printing { # Might have some duplicates - line.split("=")[-1].split()[0] - for line in output.split("\n") - if "credential.helper" in line + line.split("=")[-1].split()[0] for line in output.split("\n") if "credential.helper" in line } ) except subprocess.CalledProcessError as exc: raise EnvironmentError(exc.stderr) -def set_git_credential( - token: str, username: str = "hf_user", folder: Optional[str] = None -) -> None: +def set_git_credential(token: str, username: str = "hf_user", folder: Optional[str] = None) -> None: """Save a username/token pair in git credential for HF Hub registry. Credentials are saved in all configured helpers (store, cache, macOS keychain,...). @@ -70,15 +66,11 @@ def set_git_credential( stdin, _, ): - stdin.write( - f"url={ENDPOINT}\nusername={username.lower()}\npassword={token}\n\n" - ) + stdin.write(f"url={ENDPOINT}\nusername={username.lower()}\npassword={token}\n\n") stdin.flush() -def unset_git_credential( - username: str = "hf_user", folder: Optional[str] = None -) -> None: +def unset_git_credential(username: str = "hf_user", folder: Optional[str] = None) -> None: """Erase credentials from git credential for HF Hub registry. Credentials are erased from the configured helpers (store, cache, macOS @@ -122,10 +114,7 @@ def write_to_credential_store(username: str, password: str) -> None: @_deprecate_method( version="0.14", - message=( - "Please open an issue on https://github.com/huggingface/huggingface_hub if this" - " a useful feature for you." - ), + message="Please open an issue on https://github.com/huggingface/huggingface_hub if this a useful feature for you.", ) def read_from_credential_store( username: Optional[str] = None, diff --git a/src/huggingface_hub/utils/_http.py b/src/huggingface_hub/utils/_http.py index 24b203162f..d0ee2c0953 100644 --- a/src/huggingface_hub/utils/_http.py +++ b/src/huggingface_hub/utils/_http.py @@ -131,10 +131,7 @@ def http_backoff( return response # Wrong status code returned (HTTP 503 for instance) - logger.warning( - f"HTTP Error {response.status_code} thrown while requesting" - f" {method} {url}" - ) + logger.warning(f"HTTP Error {response.status_code} thrown while requesting {method} {url}") if nb_tries > max_retries: response.raise_for_status() # Will raise uncaught exception # We return response to avoid infinite loop in the corner case where the diff --git a/src/huggingface_hub/utils/_paths.py b/src/huggingface_hub/utils/_paths.py index 3a496696cc..93a993c17d 100644 --- a/src/huggingface_hub/utils/_paths.py +++ b/src/huggingface_hub/utils/_paths.py @@ -97,10 +97,7 @@ def _identity(item: T) -> str: return item if isinstance(item, Path): return str(item) - raise ValueError( - f"Please provide `key` argument in `filter_repo_objects`: `{item}` is" - " not a string." 
- ) + raise ValueError(f"Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.") key = _identity # Items must be `str` or `Path`, otherwise raise ValueError @@ -108,15 +105,11 @@ def _identity(item: T) -> str: path = key(item) # Skip if there's an allowlist and path doesn't match any - if allow_patterns is not None and not any( - fnmatch(path, r) for r in allow_patterns - ): + if allow_patterns is not None and not any(fnmatch(path, r) for r in allow_patterns): continue # Skip if there's a denylist and path matches any - if ignore_patterns is not None and any( - fnmatch(path, r) for r in ignore_patterns - ): + if ignore_patterns is not None and any(fnmatch(path, r) for r in ignore_patterns): continue yield item diff --git a/src/huggingface_hub/utils/_validators.py b/src/huggingface_hub/utils/_validators.py index a41d67c25c..89caf19fd0 100644 --- a/src/huggingface_hub/utils/_validators.py +++ b/src/huggingface_hub/utils/_validators.py @@ -99,9 +99,7 @@ def validate_hf_hub_args(fn: CallableT) -> CallableT: # Should the validator switch `use_auth_token` values to `token`? In practice, always # True in `huggingface_hub`. Might not be the case in a downstream library. - check_use_auth_token = ( - "use_auth_token" not in signature.parameters and "token" in signature.parameters - ) + check_use_auth_token = "use_auth_token" not in signature.parameters and "token" in signature.parameters @wraps(fn) def _inner_fn(*args, **kwargs): @@ -117,9 +115,7 @@ def _inner_fn(*args, **kwargs): has_token = True if check_use_auth_token: - kwargs = smoothly_deprecate_use_auth_token( - fn_name=fn.__name__, has_token=has_token, kwargs=kwargs - ) + kwargs = smoothly_deprecate_use_auth_token(fn_name=fn.__name__, has_token=has_token, kwargs=kwargs) return fn(*args, **kwargs) @@ -158,9 +154,7 @@ def validate_repo_id(repo_id: str) -> None: """ if not isinstance(repo_id, str): # Typically, a Path is not a repo_id - raise HFValidationError( - f"Repo id must be a string, not {type(repo_id)}: '{repo_id}'." - ) + raise HFValidationError(f"Repo id must be a string, not {type(repo_id)}: '{repo_id}'.") if repo_id.count("/") > 1: raise HFValidationError( @@ -182,9 +176,7 @@ def validate_repo_id(repo_id: str) -> None: raise HFValidationError(f"Repo_id cannot end by '.git': '{repo_id}'.") -def smoothly_deprecate_use_auth_token( - fn_name: str, has_token: bool, kwargs: Dict[str, Any] -) -> Dict[str, Any]: +def smoothly_deprecate_use_auth_token(fn_name: str, has_token: bool, kwargs: Dict[str, Any]) -> Dict[str, Any]: """Smoothly deprecate `use_auth_token` in the `huggingface_hub` codebase. The long-term goal is to remove any mention of `use_auth_token` in the codebase in diff --git a/src/huggingface_hub/utils/endpoint_helpers.py b/src/huggingface_hub/utils/endpoint_helpers.py index 6049f8f9e2..4286feb9d3 100644 --- a/src/huggingface_hub/utils/endpoint_helpers.py +++ b/src/huggingface_hub/utils/endpoint_helpers.py @@ -36,9 +36,7 @@ def _filter_emissions( A maximum carbon threshold to filter by, such as 10. 
""" if minimum_threshold is None and maximum_threshold is None: - raise ValueError( - "Both `minimum_threshold` and `maximum_threshold` cannot both be `None`" - ) + raise ValueError("Both `minimum_threshold` and `maximum_threshold` cannot both be `None`") if minimum_threshold is None: minimum_threshold = -1 if maximum_threshold is None: diff --git a/src/huggingface_hub/utils/logging.py b/src/huggingface_hub/utils/logging.py index 20ba5a8634..187641d03b 100644 --- a/src/huggingface_hub/utils/logging.py +++ b/src/huggingface_hub/utils/logging.py @@ -16,14 +16,16 @@ import logging import os -from logging import CRITICAL # NOQA -from logging import DEBUG # NOQA -from logging import ERROR # NOQA -from logging import FATAL # NOQA -from logging import INFO # NOQA -from logging import NOTSET # NOQA -from logging import WARN # NOQA -from logging import WARNING # NOQA +from logging import ( + CRITICAL, # NOQA + DEBUG, # NOQA + ERROR, # NOQA + FATAL, # NOQA + INFO, # NOQA + NOTSET, # NOQA + WARN, # NOQA + WARNING, # NOQA +) from typing import Optional diff --git a/tests/conftest.py b/tests/conftest.py index b936b468ed..34fab9168a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,9 +4,9 @@ from typing import Generator import pytest +from _pytest.fixtures import SubRequest import huggingface_hub -from _pytest.fixtures import SubRequest from huggingface_hub import HfApi, HfFolder from huggingface_hub.utils import SoftTemporaryDirectory @@ -91,10 +91,7 @@ def test_space(self) -> None: """ # Check if production token exists if PRODUCTION_TOKEN is None: - pytest.skip( - "Skip Space tests. `HUGGINGFACE_PRODUCTION_USER_TOKEN` environment variable" - " is not set." - ) + pytest.skip("Skip Space tests. `HUGGINGFACE_PRODUCTION_USER_TOKEN` environment variable is not set.") # Generate repo id from prod token api = HfApi(token=PRODUCTION_TOKEN, endpoint=ENDPOINT_PRODUCTION) @@ -104,9 +101,7 @@ def test_space(self) -> None: request.cls.repo_id = repo_id # Create and clean space repo - api.create_repo( - repo_id=repo_id, repo_type="space", space_sdk="gradio", private=True - ) + api.create_repo(repo_id=repo_id, repo_type="space", space_sdk="gradio", private=True) api.upload_file( path_or_fileobj=_BASIC_APP_PY_TEMPLATE, repo_id=repo_id, diff --git a/tests/test_cache_layout.py b/tests/test_cache_layout.py index 2160225df3..a30aeb6a83 100644 --- a/tests/test_cache_layout.py +++ b/tests/test_cache_layout.py @@ -43,9 +43,7 @@ def test_file_downloaded_in_cache(self): revision=revision, ) - expected_directory_name = ( - f'models--{MODEL_IDENTIFIER.replace("/", "--")}' - ) + expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' expected_path = os.path.join(cache, expected_directory_name) refs = os.listdir(os.path.join(expected_path, "refs")) @@ -72,9 +70,7 @@ def test_file_downloaded_in_cache(self): self.assertTrue(os.path.islink(snapshot_content_path)) resolved_blob_relative = os.readlink(snapshot_content_path) - resolved_blob_absolute = os.path.normpath( - os.path.join(snapshot_path, resolved_blob_relative) - ) + resolved_blob_absolute = os.path.normpath(os.path.join(snapshot_path, resolved_blob_relative)) with open(resolved_blob_absolute) as f: blob_contents = f.readline().strip() @@ -90,19 +86,13 @@ def test_no_exist_file_is_cached(self): filename = "this_does_not_exist.txt" with self.assertRaises(EntryNotFoundError): # The file does not exist, so we get an exception. 
- hf_hub_download( - MODEL_IDENTIFIER, filename, cache_dir=cache, revision=revision - ) + hf_hub_download(MODEL_IDENTIFIER, filename, cache_dir=cache, revision=revision) - expected_directory_name = ( - f'models--{MODEL_IDENTIFIER.replace("/", "--")}' - ) + expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' expected_path = os.path.join(cache, expected_directory_name) refs = os.listdir(os.path.join(expected_path, "refs")) - no_exist_snapshots = os.listdir( - os.path.join(expected_path, ".no_exist") - ) + no_exist_snapshots = os.listdir(os.path.join(expected_path, ".no_exist")) # Only reference should be `main`. self.assertListEqual(refs, [expected_reference]) @@ -152,9 +142,7 @@ def test_file_download_happens_once_intra_revision(self): time.sleep(2) - path = hf_hub_download( - MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2" - ) + path = hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2") creation_time_1 = os.path.getmtime(path) self.assertEqual(creation_time_0, creation_time_1) @@ -163,9 +151,7 @@ def test_file_download_happens_once_intra_revision(self): def test_multiple_refs_for_same_file(self): with SoftTemporaryDirectory() as cache: hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache) - hf_hub_download( - MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2" - ) + hf_hub_download(MODEL_IDENTIFIER, "file_0.txt", cache_dir=cache, revision="file-2") expected_directory_name = f'models--{MODEL_IDENTIFIER.replace("/", "--")}' expected_path = os.path.join(cache, expected_directory_name) @@ -179,19 +165,14 @@ def test_multiple_refs_for_same_file(self): # Directory should contain two revisions self.assertListEqual(refs, ["file-2", "main"]) - refs_contents = [ - get_file_contents(os.path.join(expected_path, "refs", f)) for f in refs - ] + refs_contents = [get_file_contents(os.path.join(expected_path, "refs", f)) for f in refs] refs_contents.sort() # snapshots directory should contain two snapshots self.assertListEqual(refs_contents, snapshots) snapshot_links = [ - os.readlink( - os.path.join(expected_path, "snapshots", filename, "file_0.txt") - ) - for filename in snapshots + os.readlink(os.path.join(expected_path, "snapshots", filename, "file_0.txt")) for filename in snapshots ] # All snapshot links should point to the same file. 
@@ -216,9 +197,7 @@ def test_file_downloaded_in_cache(self): # Directory should contain two revisions self.assertListEqual(refs, ["main"]) - ref_content = get_file_contents( - os.path.join(expected_path, "refs", refs[0]) - ) + ref_content = get_file_contents(os.path.join(expected_path, "refs", refs[0])) # snapshots directory should contain two snapshots self.assertListEqual([ref_content], snapshots) @@ -227,17 +206,11 @@ def test_file_downloaded_in_cache(self): files_in_snapshot = os.listdir(snapshot_path) - snapshot_links = [ - os.readlink(os.path.join(snapshot_path, filename)) - for filename in files_in_snapshot - ] + snapshot_links = [os.readlink(os.path.join(snapshot_path, filename)) for filename in files_in_snapshot] - resolved_snapshot_links = [ - os.path.normpath(os.path.join(snapshot_path, link)) - for link in snapshot_links - ] + resolved_snapshot_links = [os.path.normpath(os.path.join(snapshot_path, link)) for link in snapshot_links] - self.assertTrue(all([os.path.isfile(l) for l in resolved_snapshot_links])) + self.assertTrue(all([os.path.isfile(link) for link in resolved_snapshot_links])) @xfail_on_windows(reason="Symlinks are deactivated in Windows tests.") def test_file_downloaded_in_cache_several_revisions(self): @@ -257,28 +230,21 @@ def test_file_downloaded_in_cache_several_revisions(self): # Directory should contain two revisions self.assertListEqual(refs, ["file-2", "file-3"]) - refs_content = [ - get_file_contents(os.path.join(expected_path, "refs", ref)) - for ref in refs - ] + refs_content = [get_file_contents(os.path.join(expected_path, "refs", ref)) for ref in refs] refs_content.sort() # snapshots directory should contain two snapshots self.assertListEqual(refs_content, snapshots) - snapshots_paths = [ - os.path.join(expected_path, "snapshots", s) for s in snapshots - ] + snapshots_paths = [os.path.join(expected_path, "snapshots", s) for s in snapshots] files_in_snapshots = {s: os.listdir(s) for s in snapshots_paths} links_in_snapshots = { - k: [os.readlink(os.path.join(k, _v)) for _v in v] - for k, v in files_in_snapshots.items() + k: [os.readlink(os.path.join(k, _v)) for _v in v] for k, v in files_in_snapshots.items() } resolved_snapshots_links = { - k: [os.path.normpath(os.path.join(k, link)) for link in v] - for k, v in links_in_snapshots.items() + k: [os.path.normpath(os.path.join(k, link)) for link in v] for k, v in links_in_snapshots.items() } all_links = [b for a in resolved_snapshots_links.values() for b in a] @@ -347,9 +313,7 @@ def test_update_reference(self): # Directory should contain two revisions self.assertListEqual(refs, ["main"]) - initial_ref_content = get_file_contents( - os.path.join(expected_path, "refs", refs[0]) - ) + initial_ref_content = get_file_contents(os.path.join(expected_path, "refs", refs[0])) # Upload a new file on the same branch self._api.upload_file( @@ -360,36 +324,18 @@ def test_update_reference(self): hf_hub_download(repo_id, "file.txt", cache_dir=cache) - final_ref_content = get_file_contents( - os.path.join(expected_path, "refs", refs[0]) - ) + final_ref_content = get_file_contents(os.path.join(expected_path, "refs", refs[0])) # The `main` reference should point to two different, but existing snapshots which contain # a 'file.txt' self.assertNotEqual(initial_ref_content, final_ref_content) + self.assertTrue(os.path.isdir(os.path.join(expected_path, "snapshots", initial_ref_content))) self.assertTrue( - os.path.isdir( - os.path.join(expected_path, "snapshots", initial_ref_content) - ) - ) - self.assertTrue( - os.path.isfile( 
- os.path.join( - expected_path, "snapshots", initial_ref_content, "file.txt" - ) - ) - ) - self.assertTrue( - os.path.isdir( - os.path.join(expected_path, "snapshots", final_ref_content) - ) + os.path.isfile(os.path.join(expected_path, "snapshots", initial_ref_content, "file.txt")) ) + self.assertTrue(os.path.isdir(os.path.join(expected_path, "snapshots", final_ref_content))) self.assertTrue( - os.path.isfile( - os.path.join( - expected_path, "snapshots", final_ref_content, "file.txt" - ) - ) + os.path.isfile(os.path.join(expected_path, "snapshots", final_ref_content, "file.txt")) ) except Exception: raise diff --git a/tests/test_cache_no_symlinks.py b/tests/test_cache_no_symlinks.py index fc4cff3e3b..1703cc31a2 100644 --- a/tests/test_cache_no_symlinks.py +++ b/tests/test_cache_no_symlinks.py @@ -27,9 +27,7 @@ def test_are_symlinks_supported_default(self) -> None: @patch("huggingface_hub.file_download.os.symlink") @patch("huggingface_hub.file_download._are_symlinks_supported_in_dir", {}) - def test_are_symlinks_supported_windows_specific_dir( - self, mock_symlink: Mock - ) -> None: + def test_are_symlinks_supported_windows_specific_dir(self, mock_symlink: Mock) -> None: mock_symlink.side_effect = [OSError(), None] # First dir not supported then yes this_dir = Path(__file__).parent @@ -49,9 +47,7 @@ def test_are_symlinks_supported_windows_specific_dir( self.assertTrue(are_symlinks_supported()) # True @patch("huggingface_hub.file_download.are_symlinks_supported") - def test_download_no_symlink_new_file( - self, mock_are_symlinks_supported: Mock - ) -> None: + def test_download_no_symlink_new_file(self, mock_are_symlinks_supported: Mock) -> None: mock_are_symlinks_supported.return_value = False filepath = Path( hf_hub_download( @@ -70,9 +66,7 @@ def test_download_no_symlink_new_file( self.assertEqual(len(list((Path(filepath).parents[2] / "blobs").glob("*"))), 0) @patch("huggingface_hub.file_download.are_symlinks_supported") - def test_download_no_symlink_existing_file( - self, mock_are_symlinks_supported: Mock - ) -> None: + def test_download_no_symlink_existing_file(self, mock_are_symlinks_supported: Mock) -> None: mock_are_symlinks_supported.return_value = True filepath = Path( hf_hub_download( @@ -110,9 +104,7 @@ def test_download_no_symlink_existing_file( self.assertTrue(blob_path.is_file()) @patch("huggingface_hub.file_download.are_symlinks_supported") - def test_scan_and_delete_cache_no_symlinks( - self, mock_are_symlinks_supported: Mock - ) -> None: + def test_scan_and_delete_cache_no_symlinks(self, mock_are_symlinks_supported: Mock) -> None: """Test scan_cache_dir works as well when cache-system doesn't use symlinks.""" OLDER_REVISION = "44c70f043cfe8162efc274ff531575e224a0e6f0" @@ -191,9 +183,7 @@ def test_scan_and_delete_cache_no_symlinks( # Since files are not shared (README.md is duplicated in cache), the total size # of the repo is the sum of each revision size. If symlinks were used, the total # size of the repo would be lower. 
-        self.assertEqual(
-            repo.size_on_disk, main_revision.size_on_disk + older_revision.size_on_disk
-        )
+        self.assertEqual(repo.size_on_disk, main_revision.size_on_disk + older_revision.size_on_disk)
 
         # Test delete repo strategy
         strategy_delete_repo = report.delete_revisions(main_ref, OLDER_REVISION)
@@ -209,9 +199,7 @@ def test_scan_and_delete_cache_no_symlinks(
             strategy_delete_revision.blobs,
             {file.blob_path for file in older_revision.files},
         )
-        self.assertEqual(
-            strategy_delete_revision.snapshots, {older_revision.snapshot_path}
-        )
+        self.assertEqual(strategy_delete_revision.snapshots, {older_revision.snapshot_path})
         self.assertEqual(len(strategy_delete_revision.refs), 0)
         self.assertEqual(len(strategy_delete_revision.repos), 0)
         strategy_delete_revision.execute()  # Execute without error
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 5717b6b06d..bfe64de76c 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -12,9 +12,7 @@ def setUp(self) -> None:
 
         TODO: add other subcommands.
         """
-        self.parser = ArgumentParser(
-            "huggingface-cli", usage="huggingface-cli <command> [<args>]"
-        )
+        self.parser = ArgumentParser("huggingface-cli", usage="huggingface-cli <command> [<args>]")
         commands_parser = self.parser.add_subparsers()
         ScanCacheCommand.register_subcommand(commands_parser)
         DeleteCacheCommand.register_subcommand(commands_parser)
diff --git a/tests/test_command_delete_cache.py b/tests/test_command_delete_cache.py
index 6a9717afe5..ca50418a8f 100644
--- a/tests/test_command_delete_cache.py
+++ b/tests/test_command_delete_cache.py
@@ -4,6 +4,9 @@
 from tempfile import mkstemp
 from unittest.mock import Mock, patch
 
+from InquirerPy.base.control import Choice
+from InquirerPy.separator import Separator
+
 from huggingface_hub.commands.delete_cache import (
     _CANCEL_DELETION_STR,
     DeleteCacheCommand,
@@ -14,8 +17,6 @@
     _read_manual_review_tmp_file,
 )
 from huggingface_hub.utils import SoftTemporaryDirectory
-from InquirerPy.base.control import Choice
-from InquirerPy.separator import Separator
 
 from .testing_utils import capture_output, handle_injection
 
@@ -48,9 +49,7 @@ def test_get_tui_choices_from_scan_with_preselection(self) -> None:
 
         # Dataset repo separator
         self.assertIsInstance(choices[1], Separator)
-        self.assertEqual(
-            choices[1]._line, "\nDataset dummy_dataset (8M, used 2 weeks ago)"
-        )
+        self.assertEqual(choices[1]._line, "\nDataset dummy_dataset (8M, used 2 weeks ago)")
 
         # Only revision of `dummy_dataset`
         self.assertIsInstance(choices[2], Choice)
@@ -64,9 +63,7 @@ def test_get_tui_choices_from_scan_with_preselection(self) -> None:
 
         # Model `dummy_model` separator
         self.assertIsInstance(choices[3], Separator)
-        self.assertEqual(
-            choices[3]._line, "\nModel dummy_model (1.4K, used 2 years ago)"
-        )
+        self.assertEqual(choices[3]._line, "\nModel dummy_model (1.4K, used 2 years ago)")
 
         # Oldest revision of `dummy_model`
         self.assertIsInstance(choices[4], Choice)
@@ -87,9 +84,7 @@ def test_get_tui_choices_from_scan_with_preselection(self) -> None:
         # Only revision of `gpt2`
         self.assertIsInstance(choices[7], Choice)
         self.assertEqual(choices[7].value, "abcdef123456789")
-        self.assertEqual(
-            choices[7].name, "abcdef12: main, refs/pr/1 # modified 2 years ago"
-        )
+        self.assertEqual(choices[7].name, "abcdef12: main, refs/pr/1 # modified 2 years ago")
         self.assertFalse(choices[7].enabled)
 
     def test_get_expectations_str_on_no_deletion_item(self) -> None:
@@ -182,9 +177,7 @@ def _input_answers():
         self.assertIn("# recent_hash_id", content)
 
         # Select dataset revision
-        content = content.replace(
-            "# dataset_revision_hash_id", 
"dataset_revision_hash_id" - ) + content = content.replace("# dataset_revision_hash_id", "dataset_revision_hash_id") # Deselect abcdef123456789 content = content.replace("abcdef123456789", "# abcdef123456789") with open(tmp_path, "w") as f: @@ -207,9 +200,7 @@ def _input_answers(): self.assertFalse(os.path.isfile(tmp_path)) # now deleted # User changed the selection - self.assertListEqual( - selected_hashes, ["dataset_revision_hash_id", "older_hash_id"] - ) + self.assertListEqual(selected_hashes, ["dataset_revision_hash_id", "older_hash_id"]) # Check printed instructions printed = output.getvalue() @@ -292,12 +283,8 @@ def test_run_and_delete_with_tui( mock__manual_review_tui.assert_called_once_with(cache_mock, preselected=[]) # Step 3: ask confirmation - mock__get_expectations_str.assert_called_once_with( - cache_mock, ["hash_1", "hash_2"] - ) - mock_confirm.assert_called_once_with( - "Will delete A and B. Confirm deletion ?", default=True - ) + mock__get_expectations_str.assert_called_once_with(cache_mock, ["hash_1", "hash_2"]) + mock_confirm.assert_called_once_with("Will delete A and B. Confirm deletion ?", default=True) mock_confirm().execute.assert_called_once_with() # Step 4: delete @@ -308,8 +295,7 @@ def test_run_and_delete_with_tui( # Check output self.assertEqual( output.getvalue(), - "Start deletion.\n" - "Done. Deleted 0 repo(s) and 0 revision(s) for a total of 5.1M.\n", + "Start deletion.\nDone. Deleted 0 repo(s) and 0 revision(s) for a total of 5.1M.\n", ) def test_run_nothing_selected_with_tui(self, mock__manual_review_tui: Mock) -> None: @@ -325,9 +311,7 @@ def test_run_nothing_selected_with_tui(self, mock__manual_review_tui: Mock) -> N # Check output self.assertEqual(output.getvalue(), "Deletion is cancelled. Do nothing.\n") - def test_run_stuff_selected_but_cancel_item_as_well_with_tui( - self, mock__manual_review_tui: Mock - ) -> None: + def test_run_stuff_selected_but_cancel_item_as_well_with_tui(self, mock__manual_review_tui: Mock) -> None: """Test command run when some are selected but "cancel item" as well.""" # Mock return value mock__manual_review_tui.return_value = [ @@ -371,12 +355,8 @@ def test_run_and_delete_no_tui( mock__manual_review_no_tui.assert_called_once_with(cache_mock, preselected=[]) # Step 3: ask confirmation - mock__get_expectations_str.assert_called_once_with( - cache_mock, ["hash_1", "hash_2"] - ) - mock__ask_for_confirmation_no_tui.assert_called_once_with( - "Will delete A and B. Confirm deletion ?" - ) + mock__get_expectations_str.assert_called_once_with(cache_mock, ["hash_1", "hash_2"]) + mock__ask_for_confirmation_no_tui.assert_called_once_with("Will delete A and B. Confirm deletion ?") # Step 4: delete cache_mock.delete_revisions.assert_called_once_with("hash_1", "hash_2") @@ -386,8 +366,7 @@ def test_run_and_delete_no_tui( # Check output self.assertEqual( output.getvalue(), - "Start deletion.\n" - "Done. Deleted 0 repo(s) and 0 revision(s) for a total of 5.1M.\n", + "Start deletion.\nDone. 
Deleted 0 repo(s) and 0 revision(s) for a total of 5.1M.\n",
         )
 
 
diff --git a/tests/test_commit_api.py b/tests/test_commit_api.py
index facd27c6a4..55a647ee80 100644
--- a/tests/test_commit_api.py
+++ b/tests/test_commit_api.py
@@ -10,42 +10,22 @@ class TestCommitOperationDelete(unittest.TestCase):
     def test_implicit_file(self):
         self.assertFalse(CommitOperationDelete(path_in_repo="path/to/file").is_folder)
-        self.assertFalse(
-            CommitOperationDelete(path_in_repo="path/to/file.md").is_folder
-        )
+        self.assertFalse(CommitOperationDelete(path_in_repo="path/to/file.md").is_folder)
 
     def test_implicit_folder(self):
         self.assertTrue(CommitOperationDelete(path_in_repo="path/to/folder/").is_folder)
-        self.assertTrue(
-            CommitOperationDelete(path_in_repo="path/to/folder.md/").is_folder
-        )
+        self.assertTrue(CommitOperationDelete(path_in_repo="path/to/folder.md/").is_folder)
 
     def test_explicit_file(self):
         # Weird case: if user explicitly set as file (`is_folder`=False) but path has a
         # trailing "/" => user input has priority
-        self.assertFalse(
-            CommitOperationDelete(
-                path_in_repo="path/to/folder/", is_folder=False
-            ).is_folder
-        )
-        self.assertFalse(
-            CommitOperationDelete(
-                path_in_repo="path/to/folder.md/", is_folder=False
-            ).is_folder
-        )
+        self.assertFalse(CommitOperationDelete(path_in_repo="path/to/folder/", is_folder=False).is_folder)
+        self.assertFalse(CommitOperationDelete(path_in_repo="path/to/folder.md/", is_folder=False).is_folder)
 
     def test_explicit_folder(self):
         # No need for the trailing "/" if `is_folder` is explicitly passed
-        self.assertTrue(
-            CommitOperationDelete(
-                path_in_repo="path/to/folder", is_folder=True
-            ).is_folder
-        )
-        self.assertTrue(
-            CommitOperationDelete(
-                path_in_repo="path/to/folder.md", is_folder=True
-            ).is_folder
-        )
+        self.assertTrue(CommitOperationDelete(path_in_repo="path/to/folder", is_folder=True).is_folder)
+        self.assertTrue(CommitOperationDelete(path_in_repo="path/to/folder.md", is_folder=True).is_folder)
 
     def test_is_folder_wrong_value(self):
         with self.assertRaises(ValueError):
@@ -53,13 +33,10 @@ def test_is_folder_wrong_value(self):
 
 
 class TestWarnOnOverwritingOperations(unittest.TestCase):
-
     add_file_ab = CommitOperationAdd(path_in_repo="a/b.txt", path_or_fileobj=b"data")
     add_file_abc = CommitOperationAdd(path_in_repo="a/b/c.md", path_or_fileobj=b"data")
     add_file_abd = CommitOperationAdd(path_in_repo="a/b/d.md", path_or_fileobj=b"data")
-    update_file_abc = CommitOperationAdd(
-        path_in_repo="a/b/c.md", path_or_fileobj=b"updated data"
-    )
+    update_file_abc = CommitOperationAdd(path_in_repo="a/b/c.md", path_or_fileobj=b"updated data")
     delete_file_abc = CommitOperationDelete(path_in_repo="a/b/c.md")
     delete_folder_a = CommitOperationDelete(path_in_repo="a/")
     delete_folder_e = CommitOperationDelete(path_in_repo="e/")
@@ -93,6 +70,4 @@ def test_delete_file_then_add(self) -> None:
         warn_on_overwriting_operations([self.delete_file_abc, self.add_file_abc])
 
     def test_delete_folder_then_add(self) -> None:
-        warn_on_overwriting_operations(
-            [self.delete_folder_a, self.add_file_ab, self.add_file_abc]
-        )
+        warn_on_overwriting_operations([self.delete_folder_a, self.add_file_ab, self.add_file_abc])
diff --git a/tests/test_endpoint_helpers.py b/tests/test_endpoint_helpers.py
index d9a8f34d76..33db1ec6aa 100644
--- a/tests/test_endpoint_helpers.py
+++ b/tests/test_endpoint_helpers.py
@@ -15,6 +15,7 @@
 import unittest
 
 import requests
+
 from huggingface_hub.hf_api import HfApi
 from huggingface_hub.utils.endpoint_helpers import (
     AttributeDictionary,
@@ -91,10 +92,7 @@ 
def test_repr(self): self._attrdict.itemB = 3 self._attrdict["1a"] = 2 self._attrdict["itemA?"] = 4 - repr_string = ( - "Available Attributes or Keys:\n * 1a (Key only)\n * itemA\n * itemA? (Key" - " only)\n * itemB\n" - ) + repr_string = "Available Attributes or Keys:\n * 1a (Key only)\n * itemA\n * itemA? (Key only)\n * itemB\n" self.assertEqual(repr_string, repr(self._attrdict)) @@ -128,17 +126,13 @@ def test_init(self): self.assertTrue("1Item_B" not in dir(languages)) - self.assertEqual( - licenses, AttributeDictionary({"ItemC": "itemC", "Item_D": "itemD"}) - ) + self.assertEqual(licenses, AttributeDictionary({"ItemC": "itemC", "Item_D": "itemD"})) def test_filter(self): _tags = GeneralTags(self._tag_dictionary, keys=["license"]) self.assertTrue(hasattr(_tags, "license")) self.assertFalse(hasattr(_tags, "languages")) - self.assertEqual( - _tags.license, AttributeDictionary({"ItemC": "itemC", "Item_D": "itemD"}) - ) + self.assertEqual(_tags.license, AttributeDictionary({"ItemC": "itemC", "Item_D": "itemD"})) class ModelTagsTest(unittest.TestCase): diff --git a/tests/test_fastai_integration.py b/tests/test_fastai_integration.py index 24843c9b27..ab8794a567 100644 --- a/tests/test_fastai_integration.py +++ b/tests/test_fastai_integration.py @@ -18,9 +18,7 @@ WORKING_REPO_SUBDIR = f"fixtures/working_repo_{__name__.split('.')[-1]}" -WORKING_REPO_DIR = os.path.join( - os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR -) +WORKING_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR) if is_fastai_available(): from fastai.data.block import DataBlock @@ -90,9 +88,7 @@ def test_save_pretrained_without_config(self): def test_save_pretrained_with_config(self): REPO_NAME = repo_name("fastai-save") - _save_pretrained_fastai( - dummy_model, f"{WORKING_REPO_DIR}/{REPO_NAME}", config=dummy_config - ) + _save_pretrained_fastai(dummy_model, f"{WORKING_REPO_DIR}/{REPO_NAME}", config=dummy_config) files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") self.assertTrue("config.json" in files) self.assertEqual(len(files), 4) @@ -109,7 +105,5 @@ def test_push_to_hub_and_from_pretrained_fastai(self): self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}") loaded_model = from_pretrained_fastai(f"{USER}/{REPO_NAME}") - self.assertEqual( - dummy_model.show_training_loop(), loaded_model.show_training_loop() - ) + self.assertEqual(dummy_model.show_training_loop(), loaded_model.show_training_loop()) self._api.delete_repo(repo_id=f"{REPO_NAME}") diff --git a/tests/test_file_download.py b/tests/test_file_download.py index 1abf31963c..4a0e0b3aa5 100644 --- a/tests/test_file_download.py +++ b/tests/test_file_download.py @@ -19,8 +19,8 @@ from unittest.mock import Mock, patch import pytest - import requests + from huggingface_hub import HfApi from huggingface_hub.constants import ( CONFIG_NAME, @@ -90,23 +90,15 @@ def test_no_connection(self): filename=CONFIG_NAME, revision=DUMMY_MODEL_ID_REVISION_INVALID, ) - valid_url = hf_hub_url( - DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT - ) - self.assertIsNotNone( - cached_download(valid_url, force_download=True, legacy_cache_layout=True) - ) + valid_url = hf_hub_url(DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT) + self.assertIsNotNone(cached_download(valid_url, force_download=True, legacy_cache_layout=True)) for offline_mode in OfflineSimulationMode: with offline(mode=offline_mode): with self.assertRaisesRegex(ValueError, "Connection error"): _ = cached_download(invalid_url, 
legacy_cache_layout=True) with self.assertRaisesRegex(ValueError, "Connection error"): - _ = cached_download( - valid_url, force_download=True, legacy_cache_layout=True - ) - self.assertIsNotNone( - cached_download(valid_url, legacy_cache_layout=True) - ) + _ = cached_download(valid_url, force_download=True, legacy_cache_layout=True) + self.assertIsNotNone(cached_download(valid_url, legacy_cache_layout=True)) def test_file_not_found_on_repo(self): # Valid revision (None) but missing file on repo. @@ -195,9 +187,7 @@ def test_repo_not_found(self): _ = cached_download(url, legacy_cache_layout=True) def test_standard_object(self): - url = hf_hub_url( - DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT - ) + url = hf_hub_url(DUMMY_MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT) filepath = cached_download(url, force_download=True, legacy_cache_layout=True) metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertEqual(metadata, (url, f'"{DUMMY_MODEL_ID_PINNED_SHA1}"')) @@ -215,9 +205,7 @@ def test_standard_object_rev(self): # Caution: check that the etag is *not* equal to the one from `test_standard_object` def test_lfs_object(self): - url = hf_hub_url( - DUMMY_MODEL_ID, filename=PYTORCH_WEIGHTS_NAME, revision=REVISION_ID_DEFAULT - ) + url = hf_hub_url(DUMMY_MODEL_ID, filename=PYTORCH_WEIGHTS_NAME, revision=REVISION_ID_DEFAULT) filepath = cached_download(url, force_download=True, legacy_cache_layout=True) metadata = filename_to_url(filepath, legacy_cache_layout=True) self.assertEqual(metadata, (url, f'"{DUMMY_MODEL_ID_PINNED_SHA256}"')) @@ -260,9 +248,7 @@ def test_hf_hub_download_custom_cache_permission(self): # Equivalent to umask u=rwx,g=r,o= previous_umask = os.umask(0o037) try: - filepath = hf_hub_download( - DUMMY_RENAMED_OLD_MODEL_ID, "config.json", cache_dir=tmpdir - ) + filepath = hf_hub_download(DUMMY_RENAMED_OLD_MODEL_ID, "config.json", cache_dir=tmpdir) # Permissions are honored (640: u=rw,g=r,o=) self.assertEqual(stat.S_IMODE(os.stat(filepath).st_mode), 0o640) finally: @@ -275,9 +261,7 @@ def test_download_from_a_renamed_repo_with_hf_hub_download(self): https://github.com/huggingface/huggingface_hub/issues/981 """ with SoftTemporaryDirectory() as tmpdir: - filepath = hf_hub_download( - DUMMY_RENAMED_OLD_MODEL_ID, "config.json", cache_dir=tmpdir - ) + filepath = hf_hub_download(DUMMY_RENAMED_OLD_MODEL_ID, "config.json", cache_dir=tmpdir) self.assertTrue(os.path.exists(filepath)) def test_download_from_a_renamed_repo_with_cached_download(self): @@ -373,9 +357,7 @@ def test_try_to_load_from_cache_exist(self): new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename=CONFIG_NAME) self.assertEqual(filepath, new_file_path) - new_file_path = try_to_load_from_cache( - DUMMY_MODEL_ID, filename=CONFIG_NAME, revision="main" - ) + new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename=CONFIG_NAME, revision="main") self.assertEqual(filepath, new_file_path) # If file is not cached, returns None @@ -393,28 +375,16 @@ def test_try_to_load_from_cache_exist(self): def test_try_to_load_from_cache_specific_pr_revision_exists(self): # Make sure the file is cached - file_path = hf_hub_download( - DUMMY_MODEL_ID, filename=CONFIG_NAME, revision="refs/pr/1" - ) + file_path = hf_hub_download(DUMMY_MODEL_ID, filename=CONFIG_NAME, revision="refs/pr/1") - new_file_path = try_to_load_from_cache( - DUMMY_MODEL_ID, filename=CONFIG_NAME, revision="refs/pr/1" - ) + new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, 
filename=CONFIG_NAME, revision="refs/pr/1") self.assertEqual(file_path, new_file_path) # If file is not cached, returns None - self.assertIsNone( - try_to_load_from_cache( - DUMMY_MODEL_ID, filename="conf.json", revision="refs/pr/1" - ) - ) + self.assertIsNone(try_to_load_from_cache(DUMMY_MODEL_ID, filename="conf.json", revision="refs/pr/1")) # If revision does not exist, returns None - self.assertIsNone( - try_to_load_from_cache( - DUMMY_MODEL_ID, filename=CONFIG_NAME, revision="does-not-exist" - ) - ) + self.assertIsNone(try_to_load_from_cache(DUMMY_MODEL_ID, filename=CONFIG_NAME, revision="does-not-exist")) def test_try_to_load_from_cache_no_exist(self): # Make sure the file is cached @@ -424,9 +394,7 @@ def test_try_to_load_from_cache_no_exist(self): new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename="dummy") self.assertEqual(new_file_path, _CACHED_NO_EXIST) - new_file_path = try_to_load_from_cache( - DUMMY_MODEL_ID, filename="dummy", revision="main" - ) + new_file_path = try_to_load_from_cache(DUMMY_MODEL_ID, filename="dummy", revision="main") self.assertEqual(new_file_path, _CACHED_NO_EXIST) # If file non-existence is not cached, returns None @@ -489,9 +457,7 @@ def test_get_hf_file_metadata_basic(self) -> None: metadata = get_hf_file_metadata(url) # Metadata - self.assertEqual( - metadata.commit_hash, DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT - ) + self.assertEqual(metadata.commit_hash, DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT) self.assertIsNotNone(metadata.etag) # example: "85c2fc2dcdd86563aaa85ef4911..." self.assertEqual(metadata.location, url) # no redirect self.assertEqual(metadata.size, 851) @@ -533,9 +499,7 @@ def test_download_from_a_gated_repo_with_hf_hub_download(self): # Create a gated repo on the fly. Repo is created by "other user" so that the # usual CI user don't have access to it. 
api = HfApi(token=OTHER_TOKEN) - repo_url = api.create_repo( - repo_id="gated_repo_for_huggingface_hub_ci", exist_ok=True - ) + repo_url = api.create_repo(repo_id="gated_repo_for_huggingface_hub_ci", exist_ok=True) requests.put( f"{repo_url.endpoint}/api/models/{repo_url.repo_id}/settings", headers=api._build_hf_headers(), @@ -546,8 +510,7 @@ def test_download_from_a_gated_repo_with_hf_hub_download(self): with SoftTemporaryDirectory() as tmpdir: with self.assertRaisesRegex( GatedRepoError, - "Access to model .* is restricted and you are not in the authorized" - " list", + "Access to model .* is restricted and you are not in the authorized list", ): hf_hub_download( repo_id=repo_url.repo_id, @@ -575,7 +538,9 @@ class StagingCachedDownloadOnAwfulFilenamesTest(unittest.TestCase): def setUpClass(cls): cls.api = HfApi(endpoint=ENDPOINT_STAGING, token=TOKEN) cls.repo_url = cls.api.create_repo(repo_id=repo_name("awful_filename")) - cls.expected_resolve_url = f"{cls.repo_url}/resolve/main/subfolder/to%3F/awful%3Ffilename%25you%3Ashould%2Cnever.give" + cls.expected_resolve_url = ( + f"{cls.repo_url}/resolve/main/subfolder/to%3F/awful%3Ffilename%25you%3Ashould%2Cnever.give" + ) cls.api.upload_file( path_or_fileobj=b"content", path_in_repo=cls.filepath, @@ -587,9 +552,7 @@ def tearDownClass(cls) -> None: cls.api.delete_repo(repo_id=cls.repo_url.repo_id) def test_hf_hub_url_on_awful_filepath(self): - self.assertEqual( - hf_hub_url(self.repo_url.repo_id, self.filepath), self.expected_resolve_url - ) + self.assertEqual(hf_hub_url(self.repo_url.repo_id, self.filepath), self.expected_resolve_url) def test_hf_hub_url_on_awful_subfolder_and_filename(self): self.assertEqual( @@ -599,9 +562,7 @@ def test_hf_hub_url_on_awful_subfolder_and_filename(self): @xfail_on_windows(reason="Windows paths cannot contain a '?'.") def test_hf_hub_download_on_awful_filepath(self): - local_path = hf_hub_download( - self.repo_url.repo_id, self.filepath, cache_dir=self.cache_dir - ) + local_path = hf_hub_download(self.repo_url.repo_id, self.filepath, cache_dir=self.cache_dir) # Local path is not url-encoded self.assertTrue(local_path.endswith(self.filepath)) @@ -620,9 +581,7 @@ def test_hf_hub_download_on_awful_subfolder_and_filename(self): class CreateSymlinkTest(unittest.TestCase): @unittest.skipIf(os.name == "nt", "No symlinks on Windows") @patch("huggingface_hub.file_download.are_symlinks_supported") - def test_create_relative_symlink_concurrent_access( - self, mock_are_symlinks_supported: Mock - ) -> None: + def test_create_relative_symlink_concurrent_access(self, mock_are_symlinks_supported: Mock) -> None: with SoftTemporaryDirectory() as tmpdir: src = os.path.join(tmpdir, "source") other = os.path.join(tmpdir, "other") diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index d2353c43ef..1b4d18dbd3 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -28,8 +28,9 @@ from urllib.parse import quote import pytest - import requests +from requests.exceptions import HTTPError + from huggingface_hub import Repository, SpaceHardware, SpaceStage from huggingface_hub._commit_api import ( CommitOperationAdd, @@ -76,7 +77,6 @@ ModelFilter, _filter_emissions, ) -from requests.exceptions import HTTPError from .testing_constants import ( ENDPOINT_STAGING, @@ -108,9 +108,7 @@ space_repo_name = partial(repo_name, prefix="my-space") large_file_repo_name = partial(repo_name, prefix="my-model-largefiles") -WORKING_REPO_DIR = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "fixtures/working_repo" -) 
+WORKING_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/working_repo") LARGE_FILE_14MB = "https://cdn-media.huggingface.co/lfs-largefiles/progit.epub" LARGE_FILE_18MB = "https://cdn-media.huggingface.co/lfs-largefiles/progit.pdf" @@ -131,18 +129,14 @@ def tearDownClass(cls) -> None: @expect_deprecation("read_from_credential_store") def test_login_git_credentials(self): - self.assertTupleEqual( - read_from_credential_store(USERNAME_PLACEHOLDER), (None, None) - ) + self.assertTupleEqual(read_from_credential_store(USERNAME_PLACEHOLDER), (None, None)) self._api.set_access_token(TOKEN) self.assertTupleEqual( read_from_credential_store(USERNAME_PLACEHOLDER), (USERNAME_PLACEHOLDER, TOKEN), ) erase_from_credential_store(username=USERNAME_PLACEHOLDER) - self.assertTupleEqual( - read_from_credential_store(USERNAME_PLACEHOLDER), (None, None) - ) + self.assertTupleEqual(read_from_credential_store(USERNAME_PLACEHOLDER), (None, None)) @expect_deprecation("read_from_credential_store") def test_login_cli(self): @@ -152,9 +146,7 @@ def test_login_cli(self): (USERNAME_PLACEHOLDER, TOKEN), ) erase_from_credential_store(username=USERNAME_PLACEHOLDER) - self.assertTupleEqual( - read_from_credential_store(USERNAME_PLACEHOLDER), (None, None) - ) + self.assertTupleEqual(read_from_credential_store(USERNAME_PLACEHOLDER), (None, None)) _set_store_as_git_credential_helper_globally() _login(token=TOKEN, add_to_git_credential=True) @@ -163,14 +155,10 @@ def test_login_cli(self): (USERNAME_PLACEHOLDER, TOKEN), ) erase_from_credential_store(username=USERNAME_PLACEHOLDER) - self.assertTupleEqual( - read_from_credential_store(USERNAME_PLACEHOLDER), (None, None) - ) + self.assertTupleEqual(read_from_credential_store(USERNAME_PLACEHOLDER), (None, None)) def test_login_cli_org_fail(self): - with pytest.raises( - ValueError, match="You must use your personal account token." 
- ): + with pytest.raises(ValueError, match="You must use your personal account token."): _login(token="api_org_dummy_token", add_to_git_credential=True) @@ -257,13 +245,9 @@ def test_create_update_and_delete_repo(self): def test_create_update_and_delete_model_repo(self): REPO_NAME = repo_name("crud") self._api.create_repo(repo_id=REPO_NAME, repo_type=REPO_TYPE_MODEL) - res = self._api.update_repo_visibility( - repo_id=REPO_NAME, private=True, repo_type=REPO_TYPE_MODEL - ) + res = self._api.update_repo_visibility(repo_id=REPO_NAME, private=True, repo_type=REPO_TYPE_MODEL) self.assertTrue(res["private"]) - res = self._api.update_repo_visibility( - repo_id=REPO_NAME, private=False, repo_type=REPO_TYPE_MODEL - ) + res = self._api.update_repo_visibility(repo_id=REPO_NAME, private=False, repo_type=REPO_TYPE_MODEL) self.assertFalse(res["private"]) self._api.delete_repo(repo_id=REPO_NAME, repo_type=REPO_TYPE_MODEL) @@ -271,13 +255,9 @@ def test_create_update_and_delete_model_repo(self): def test_create_update_and_delete_dataset_repo(self): DATASET_REPO_NAME = dataset_repo_name("crud") self._api.create_repo(repo_id=DATASET_REPO_NAME, repo_type=REPO_TYPE_DATASET) - res = self._api.update_repo_visibility( - repo_id=DATASET_REPO_NAME, private=True, repo_type=REPO_TYPE_DATASET - ) + res = self._api.update_repo_visibility(repo_id=DATASET_REPO_NAME, private=True, repo_type=REPO_TYPE_DATASET) self.assertTrue(res["private"]) - res = self._api.update_repo_visibility( - repo_id=DATASET_REPO_NAME, private=False, repo_type=REPO_TYPE_DATASET - ) + res = self._api.update_repo_visibility(repo_id=DATASET_REPO_NAME, private=False, repo_type=REPO_TYPE_DATASET) self.assertFalse(res["private"]) self._api.delete_repo(repo_id=DATASET_REPO_NAME, repo_type=REPO_TYPE_DATASET) @@ -290,26 +270,16 @@ def test_create_update_and_delete_dataset_repo(self): def test_create_update_and_delete_space_repo(self): SPACE_REPO_NAME = space_repo_name("failing") with pytest.raises(ValueError, match=r"No space_sdk provided.*"): - self._api.create_repo( - repo_id=SPACE_REPO_NAME, repo_type=REPO_TYPE_SPACE, space_sdk=None - ) + self._api.create_repo(repo_id=SPACE_REPO_NAME, repo_type=REPO_TYPE_SPACE, space_sdk=None) with pytest.raises(ValueError, match=r"Invalid space_sdk.*"): - self._api.create_repo( - repo_id=SPACE_REPO_NAME, repo_type=REPO_TYPE_SPACE, space_sdk="asdfasdf" - ) + self._api.create_repo(repo_id=SPACE_REPO_NAME, repo_type=REPO_TYPE_SPACE, space_sdk="asdfasdf") for sdk in SPACES_SDK_TYPES: SPACE_REPO_NAME = space_repo_name(sdk) - self._api.create_repo( - repo_id=SPACE_REPO_NAME, repo_type=REPO_TYPE_SPACE, space_sdk=sdk - ) - res = self._api.update_repo_visibility( - repo_id=SPACE_REPO_NAME, private=True, repo_type=REPO_TYPE_SPACE - ) + self._api.create_repo(repo_id=SPACE_REPO_NAME, repo_type=REPO_TYPE_SPACE, space_sdk=sdk) + res = self._api.update_repo_visibility(repo_id=SPACE_REPO_NAME, private=True, repo_type=REPO_TYPE_SPACE) self.assertTrue(res["private"]) - res = self._api.update_repo_visibility( - repo_id=SPACE_REPO_NAME, private=False, repo_type=REPO_TYPE_SPACE - ) + res = self._api.update_repo_visibility(repo_id=SPACE_REPO_NAME, private=False, repo_type=REPO_TYPE_SPACE) self.assertFalse(res["private"]) self._api.delete_repo(repo_id=SPACE_REPO_NAME, repo_type=REPO_TYPE_SPACE) @@ -335,12 +305,8 @@ def test_move_repo_target_already_exists(self) -> None: self._api.create_repo(repo_id=repo_id_1) self._api.create_repo(repo_id=repo_id_2) - with pytest.raises( - HfHubHTTPError, match=r"A model repository called .* already 
exists" - ): - self._api.move_repo( - from_id=repo_id_1, to_id=repo_id_2, repo_type=REPO_TYPE_MODEL - ) + with pytest.raises(HfHubHTTPError, match=r"A model repository called .* already exists"): + self._api.move_repo(from_id=repo_id_1, to_id=repo_id_2, repo_type=REPO_TYPE_MODEL) self._api.delete_repo(repo_id=repo_id_1) self._api.delete_repo(repo_id=repo_id_2) @@ -386,13 +352,9 @@ def test_commit_operation_validation(self): ValueError, msg="If you passed a file-like object, make sure it is in binary mode", ): - CommitOperationAdd( - path_or_fileobj=ftext, path_in_repo="README.md" # type: ignore - ) + CommitOperationAdd(path_or_fileobj=ftext, path_in_repo="README.md") # type: ignore - with self.assertRaises( - ValueError, msg="path_or_fileobj is str but does not point to a file" - ): + with self.assertRaises(ValueError, msg="path_or_fileobj is str but does not point to a file"): CommitOperationAdd( path_or_fileobj=os.path.join(self.tmp_dir, "nofile.pth"), path_in_repo="README.md", @@ -417,9 +379,7 @@ def test_upload_file_str_path(self): user=USER, repo=REPO_NAME, ) - filepath = cached_download( - url, force_download=True, legacy_cache_layout=True - ) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) with open(filepath) as downloaded_file: content = downloaded_file.read() self.assertEqual(content, self.tmp_file_content) @@ -457,12 +417,8 @@ def test_upload_file_fileobj(self): return_val, f"{self._api.endpoint}/{USER}/{REPO_NAME}/blob/main/temp/new_file.md", ) - url = "{}/{user}/{repo}/resolve/main/temp/new_file.md".format( - ENDPOINT_STAGING, user=USER, repo=REPO_NAME - ) - filepath = cached_download( - url, force_download=True, legacy_cache_layout=True - ) + url = "{}/{user}/{repo}/resolve/main/temp/new_file.md".format(ENDPOINT_STAGING, user=USER, repo=REPO_NAME) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) with open(filepath) as downloaded_file: content = downloaded_file.read() self.assertEqual(content, self.tmp_file_content) @@ -488,12 +444,8 @@ def test_upload_file_bytesio(self): f"{self._api.endpoint}/{USER}/{REPO_NAME}/blob/main/temp/new_file.md", ) - url = "{}/{user}/{repo}/resolve/main/temp/new_file.md".format( - ENDPOINT_STAGING, user=USER, repo=REPO_NAME - ) - filepath = cached_download( - url, force_download=True, legacy_cache_layout=True - ) + url = "{}/{user}/{repo}/resolve/main/temp/new_file.md".format(ENDPOINT_STAGING, user=USER, repo=REPO_NAME) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) with open(filepath) as downloaded_file: content = downloaded_file.read() self.assertEqual(content, filecontent.getvalue().decode()) @@ -515,17 +467,13 @@ def test_create_repo_return_value(self): @retry_endpoint def test_create_repo_org_token_fail(self): REPO_NAME = repo_name("org") - with pytest.raises( - ValueError, match="You must use your personal account token." - ): + with pytest.raises(ValueError, match="You must use your personal account token."): self._api.create_repo(repo_id=REPO_NAME, token="api_org_dummy_token") @retry_endpoint def test_create_repo_org_token_none_fail(self): HfFolder.save_token("api_org_dummy_token") - with pytest.raises( - ValueError, match="You must use your personal account token." 
- ): + with pytest.raises(ValueError, match="You must use your personal account token."): with patch.object(self._api, "token", None): # no default token self._api.create_repo(repo_id=repo_name("org")) @@ -546,12 +494,8 @@ def test_upload_buffer(self): f"{self._api.endpoint}/{USER}/{REPO_NAME}/blob/main/temp/new_file.md", ) - url = "{}/{user}/{repo}/resolve/main/temp/new_file.md".format( - ENDPOINT_STAGING, user=USER, repo=REPO_NAME - ) - filepath = cached_download( - url, force_download=True, legacy_cache_layout=True - ) + url = "{}/{user}/{repo}/resolve/main/temp/new_file.md".format(ENDPOINT_STAGING, user=USER, repo=REPO_NAME) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) with open(filepath) as downloaded_file: content = downloaded_file.read() self.assertEqual(content, self.tmp_file_content) @@ -583,9 +527,7 @@ def test_upload_file_create_pr(self): url = "{}/{user}/{repo}/resolve/{revision}/temp/new_file.md".format( ENDPOINT_STAGING, revision=pr_revision, user=USER, repo=REPO_NAME ) - filepath = cached_download( - url, force_download=True, legacy_cache_layout=True - ) + filepath = cached_download(url, force_download=True, legacy_cache_layout=True) with open(filepath) as downloaded_file: content = downloaded_file.read() self.assertEqual(content, self.tmp_file_content) @@ -605,9 +547,7 @@ def test_delete_file(self): path_in_repo="temp/new_file.md", repo_id=f"{USER}/{REPO_NAME}", ) - self._api.delete_file( - path_in_repo="temp/new_file.md", repo_id=f"{USER}/{REPO_NAME}" - ) + self._api.delete_file(path_in_repo="temp/new_file.md", repo_id=f"{USER}/{REPO_NAME}") with self.assertRaises(HTTPError): # Should raise a 404 @@ -622,9 +562,7 @@ def test_get_full_repo_name(self): repo_name_with_no_org = self._api.get_full_repo_name("model") self.assertEqual(repo_name_with_no_org, f"{USER}/model") - repo_name_with_no_org = self._api.get_full_repo_name( - "model", organization="org" - ) + repo_name_with_no_org = self._api.get_full_repo_name("model", organization="org") self.assertEqual(repo_name_with_no_org, "org/model") @retry_endpoint @@ -633,9 +571,7 @@ def test_upload_folder(self): visibility = "private" if private else "public" with self.subTest(f"{visibility} repo"): REPO_NAME = repo_name(f"upload_folder_{visibility}") - self._api.create_repo( - repo_id=REPO_NAME, private=private, exist_ok=False - ) + self._api.create_repo(repo_id=REPO_NAME, private=private, exist_ok=False) try: url = self._api.upload_folder( folder_path=self.tmp_dir, @@ -680,9 +616,7 @@ def test_upload_folder_create_pr(self): visibility = "private" if private else "public" with self.subTest(f"{visibility} repo"): REPO_NAME = repo_name(f"upload_folder_{visibility}") - self._api.create_repo( - repo_id=REPO_NAME, private=private, exist_ok=False - ) + self._api.create_repo(repo_id=REPO_NAME, private=private, exist_ok=False) try: return_val = self._api.upload_folder( folder_path=self.tmp_dir, @@ -718,9 +652,7 @@ def test_upload_folder_create_pr(self): def test_upload_folder_default_path_in_repo(self): REPO_NAME = repo_name("upload_folder_to_root") self._api.create_repo(repo_id=REPO_NAME, exist_ok=False) - url = self._api.upload_folder( - folder_path=self.tmp_dir, repo_id=f"{USER}/{REPO_NAME}" - ) + url = self._api.upload_folder(folder_path=self.tmp_dir, repo_id=f"{USER}/{REPO_NAME}") # URL to root of repository self.assertEqual(url, f"{self._api.endpoint}/{USER}/{REPO_NAME}/tree/main/") @@ -736,9 +668,7 @@ def test_create_commit_create_pr(self): ) operations = [ 
CommitOperationDelete(path_in_repo="temp/new_file.md"), - CommitOperationAdd( - path_in_repo="buffer", path_or_fileobj=b"Buffer data" - ), + CommitOperationAdd(path_in_repo="buffer", path_or_fileobj=b"Buffer data"), ] resp = self._api.create_commit( operations=operations, @@ -768,9 +698,7 @@ def test_create_commit_create_pr(self): with self.assertRaises(HTTPError) as ctx: # Should raise a 404 - hf_hub_download( - f"{USER}/{REPO_NAME}", "buffer", use_auth_token=self._token - ) + hf_hub_download(f"{USER}/{REPO_NAME}", "buffer", use_auth_token=self._token) self.assertEqual(ctx.exception.response.status_code, 404) filepath = hf_hub_download( filename="buffer", @@ -841,12 +769,8 @@ def test_create_commit_create_pr_on_foreign_repo(self): self._api.create_commit( operations=[ - CommitOperationAdd( - path_in_repo="regular.txt", path_or_fileobj=b"File content" - ), - CommitOperationAdd( - path_in_repo="lfs.pkl", path_or_fileobj=b"File content" - ), + CommitOperationAdd(path_in_repo="regular.txt", path_or_fileobj=b"File content"), + CommitOperationAdd(path_in_repo="lfs.pkl", path_or_fileobj=b"File content"), ], commit_message="PR on foreign repo", repo_id=foreign_repo_url.repo_id, @@ -861,9 +785,7 @@ def test_create_commit(self): visibility = "private" if private else "public" with self.subTest(f"{visibility} repo"): REPO_NAME = repo_name(f"create_commit_{visibility}") - self._api.create_repo( - repo_id=REPO_NAME, private=private, exist_ok=False - ) + self._api.create_repo(repo_id=REPO_NAME, private=private, exist_ok=False) try: self._api.upload_file( path_or_fileobj=self.tmp_file, @@ -873,16 +795,12 @@ def test_create_commit(self): with open(self.tmp_file, "rb") as fileobj: operations = [ CommitOperationDelete(path_in_repo="temp/new_file.md"), - CommitOperationAdd( - path_in_repo="buffer", path_or_fileobj=b"Buffer data" - ), + CommitOperationAdd(path_in_repo="buffer", path_or_fileobj=b"Buffer data"), CommitOperationAdd( path_in_repo="bytesio", path_or_fileobj=BytesIO(b"BytesIO data"), ), - CommitOperationAdd( - path_in_repo="fileobj", path_or_fileobj=fileobj - ), + CommitOperationAdd(path_in_repo="fileobj", path_or_fileobj=fileobj), CommitOperationAdd( path_in_repo="nested/path", path_or_fileobj=self.tmp_file, @@ -939,9 +857,7 @@ def test_create_commit_conflict(self): repo_id=f"{USER}/{REPO_NAME}", ) operations = [ - CommitOperationAdd( - path_in_repo="buffer", path_or_fileobj=b"Buffer data" - ), + CommitOperationAdd(path_in_repo="buffer", path_or_fileobj=b"Buffer data"), ] with self.assertRaises(HTTPError) as exc_ctx: self._api.create_commit( @@ -1021,18 +937,14 @@ def _inner(mock: Mock) -> None: # Upload a PNG file self._api.create_commit( operations=[ - CommitOperationAdd( - path_in_repo="image.png", path_or_fileobj=b"image data" - ), + CommitOperationAdd(path_in_repo="image.png", path_or_fileobj=b"image data"), ], commit_message="Test upload lfs file", repo_id=repo_id, ) # Check uploaded as LFS - info = self._api.model_info( - repo_id=repo_id, use_auth_token=self._token, files_metadata=True - ) + info = self._api.model_info(repo_id=repo_id, use_auth_token=self._token, files_metadata=True) siblings = {file.rfilename: file for file in info.siblings} self.assertIsInstance(siblings["image.png"].lfs, dict) # LFS file @@ -1131,25 +1043,19 @@ def tearDownClass(cls): def test_upload_empty_regular_file(self) -> None: with self.assertWarns(UserWarning): - self._api.upload_file( - repo_id=self.repo_id, path_in_repo="empty.txt", path_or_fileobj=b"" - ) + self._api.upload_file(repo_id=self.repo_id, 
path_in_repo="empty.txt", path_or_fileobj=b"") def test_upload_empty_gitkeep_file(self) -> None: # No warning in case of .gitkeep file with warnings.catch_warnings(record=True) as w: # Taken from https://stackoverflow.com/a/3892301 - self._api.upload_file( - repo_id=self.repo_id, path_in_repo="foo/.gitkeep", path_or_fileobj=b"" - ) + self._api.upload_file(repo_id=self.repo_id, path_in_repo="foo/.gitkeep", path_or_fileobj=b"") self.assertEqual(len(w), 0) def test_upload_empty_lfs_file(self) -> None: # Should have been an LFS file, but uploaded as regular (would fail otherwise) with self.assertWarns(UserWarning): - self._api.upload_file( - repo_id=self.repo_id, path_in_repo="empty.pkl", path_or_fileobj=b"" - ) + self._api.upload_file(repo_id=self.repo_id, path_in_repo="empty.pkl", path_or_fileobj=b"") info = self._api.repo_info(repo_id=self.repo_id, files_metadata=True) repo_file = {file.rfilename: file for file in info.siblings}["empty.pkl"] @@ -1216,9 +1122,7 @@ class HfApiTagEndpointTest(HfApiCommonTestWithLogin): @use_tmp_repo("model") def test_create_tag_on_main(self, repo_url: RepoUrl) -> None: """Check `create_tag` on default main branch works.""" - self._api.create_tag( - repo_url.repo_id, tag="v0", tag_message="This is a tag message." - ) + self._api.create_tag(repo_url.repo_id, tag="v0", tag_message="This is a tag message.") # Check tag is on `main` tag_info = self._api.model_info(repo_url.repo_id, revision="v0") @@ -1234,23 +1138,15 @@ def test_create_tag_on_pr(self, repo_url: RepoUrl) -> None: repo_id=repo_url.repo_id, create_pr=True, commit_message="upload readme", - operations=[ - CommitOperationAdd( - path_or_fileobj=b"this is a file content", path_in_repo="readme.md" - ) - ], + operations=[CommitOperationAdd(path_or_fileobj=b"this is a file content", path_in_repo="readme.md")], ) # Tag the PR - self._api.create_tag( - repo_url.repo_id, tag="v0", revision=commit_info.pr_revision - ) + self._api.create_tag(repo_url.repo_id, tag="v0", revision=commit_info.pr_revision) # Check tag is on `refs/pr/1` tag_info = self._api.model_info(repo_url.repo_id, revision="v0") - pr_info = self._api.model_info( - repo_url.repo_id, revision=commit_info.pr_revision - ) + pr_info = self._api.model_info(repo_url.repo_id, revision=commit_info.pr_revision) main_info = self._api.model_info(repo_url.repo_id) self.assertEqual(tag_info.sha, pr_info.sha) @@ -1268,21 +1164,13 @@ def test_create_tag_on_commit_oid(self, repo_url: RepoUrl) -> None: repo_id=repo_url.repo_id, repo_type="dataset", commit_message="upload readme", - operations=[ - CommitOperationAdd( - path_or_fileobj=b"this is a file content", path_in_repo="readme.md" - ) - ], + operations=[CommitOperationAdd(path_or_fileobj=b"this is a file content", path_in_repo="readme.md")], ) commit_info_2: CommitInfo = self._api.create_commit( repo_id=repo_url.repo_id, repo_type="dataset", commit_message="upload config", - operations=[ - CommitOperationAdd( - path_or_fileobj=b"{'hello': 'world'}", path_in_repo="config.json" - ) - ], + operations=[CommitOperationAdd(path_or_fileobj=b"{'hello': 'world'}", path_in_repo="config.json")], ) # Tag commits @@ -1439,21 +1327,15 @@ def test_create_branch_from_revision(self, repo_url: RepoUrl) -> None: initial_commit = self._api.model_info(repo_url.repo_id).sha self._api.create_commit( repo_url.repo_id, - operations=[ - CommitOperationAdd(path_in_repo="app.py", path_or_fileobj=b"content") - ], + operations=[CommitOperationAdd(path_in_repo="app.py", path_or_fileobj=b"content")], commit_message="test commit", ) 
latest_commit = self._api.model_info(repo_url.repo_id).sha # Create branches self._api.create_branch(repo_url.repo_id, branch="from-head") - self._api.create_branch( - repo_url.repo_id, branch="from-initial", revision=initial_commit - ) - self._api.create_branch( - repo_url.repo_id, branch="from-branch", revision="from-initial" - ) + self._api.create_branch(repo_url.repo_id, branch="from-initial", revision=initial_commit) + self._api.create_branch(repo_url.repo_id, branch="from-branch", revision="from-initial") # Checks branches start from expected commits self.assertEqual( @@ -1463,10 +1345,7 @@ def test_create_branch_from_revision(self, repo_url: RepoUrl) -> None: "from-initial": initial_commit, "from-branch": initial_commit, }, - { - ref.name: ref.target_commit - for ref in self._api.list_repo_refs(repo_id=repo_url.repo_id).branches - }, + {ref.name: ref.target_commit for ref in self._api.list_repo_refs(repo_id=repo_url.repo_id).branches}, ) @@ -1519,9 +1398,7 @@ def test_list_models_complex_query(self): # Let's list the 10 most recent models # with tags "bert" and "jax", # ordered by last modified date. - models = self._api.list_models( - filter=("bert", "jax"), sort="lastModified", direction=-1, limit=10 - ) + models = self._api.list_models(filter=("bert", "jax"), sort="lastModified", direction=-1, limit=10) # we have at least 1 models self.assertGreater(len(models), 1) self.assertLessEqual(len(models), 10) @@ -1530,15 +1407,11 @@ def test_list_models_complex_query(self): self.assertTrue(all(tag in model.tags for tag in ["bert", "jax"])) def test_list_models_with_config(self): - for model in self._api.list_models( - filter="adapter-transformers", fetch_config=True, limit=20 - ): + for model in self._api.list_models(filter="adapter-transformers", fetch_config=True, limit=20): self.assertIsNotNone(model.config) def test_list_models_without_config(self): - for model in self._api.list_models( - filter="adapter-transformers", fetch_config=False, limit=20 - ): + for model in self._api.list_models(filter="adapter-transformers", fetch_config=False, limit=20): self.assertIsNone(model.config) def test_model_info(self): @@ -1546,9 +1419,7 @@ def test_model_info(self): self.assertIsInstance(model, ModelInfo) self.assertNotEqual(model.sha, DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT) # One particular commit (not the top of `main`) - model = self._api.model_info( - repo_id=DUMMY_MODEL_ID, revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT - ) + model = self._api.model_info(repo_id=DUMMY_MODEL_ID, revision=DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT) self.assertIsInstance(model, ModelInfo) self.assertEqual(model.sha, DUMMY_MODEL_ID_REVISION_ONE_SPECIFIC_COMMIT) @@ -1637,42 +1508,32 @@ def test_filter_datasets_by_language_only(self): self.assertTrue("language:en" in datasets[0].tags) args = DatasetSearchArguments(api=self._api) - datasets = self._api.list_datasets( - filter=DatasetFilter(language=(args.language.en, args.language.fr)) - ) + datasets = self._api.list_datasets(filter=DatasetFilter(language=(args.language.en, args.language.fr))) self.assertGreater(len(datasets), 0) self.assertTrue("language:en" in datasets[0].tags) self.assertTrue("language:fr" in datasets[0].tags) @expect_deprecation("list_datasets") def test_filter_datasets_by_multilinguality(self): - datasets = self._api.list_datasets( - filter=DatasetFilter(multilinguality="multilingual") - ) + datasets = self._api.list_datasets(filter=DatasetFilter(multilinguality="multilingual")) self.assertGreater(len(datasets), 0) 
self.assertTrue("multilinguality:multilingual" in datasets[0].tags)
 
     @expect_deprecation("list_datasets")
     def test_filter_datasets_by_size_categories(self):
-        datasets = self._api.list_datasets(
-            filter=DatasetFilter(size_categories="100K<n<1M")
-        )
+        datasets = self._api.list_datasets(filter=DatasetFilter(size_categories="100K<n<1M"))
         self.assertGreater(len(datasets), 0)
         self.assertTrue("size_categories:100K<n<1M" in datasets[0].tags)
 
     def test_dataset_info(self):
         dataset = self._api.dataset_info(repo_id=DUMMY_DATASET_ID)
-        self.assertTrue(
-            isinstance(dataset.cardData, dict) and len(dataset.cardData) > 0
-        )
-        self.assertTrue(
-            isinstance(dataset.siblings, list) and len(dataset.siblings) > 0
-        )
+        self.assertTrue(isinstance(dataset.cardData, dict) and len(dataset.cardData) > 0)
+        self.assertTrue(isinstance(dataset.siblings, list) and len(dataset.siblings) > 0)
         self.assertIsInstance(dataset, DatasetInfo)
         self.assertNotEqual(dataset.sha, DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT)
         dataset = self._api.dataset_info(
@@ -1763,40 +1616,28 @@ def test_filter_models_by_author(self):
     @expect_deprecation("list_models")
     def test_filter_models_by_author_and_name(self):
         # Test we can search by an author and a name, but the model is not found
-        models = self._api.list_models(
-            filter=ModelFilter("facebook", model_name="bart-base")
-        )
+        models = self._api.list_models(filter=ModelFilter("facebook", model_name="bart-base"))
         self.assertTrue("facebook/bart-base" in models[0].modelId)
 
     @expect_deprecation("list_models")
     def test_failing_filter_models_by_author_and_model_name(self):
         # Test we can search by an author and a name, but the model is not found
-        models = self._api.list_models(
-            filter=ModelFilter(author="muellerzr", model_name="testme")
-        )
+        models = self._api.list_models(filter=ModelFilter(author="muellerzr", model_name="testme"))
         self.assertEqual(len(models), 0)
 
     @expect_deprecation("list_models")
     def test_filter_models_with_library(self):
         models = self._api.list_models(
-            filter=ModelFilter(
-                "microsoft", model_name="wavlm-base-sd", library="tensorflow"
-            )
+            filter=ModelFilter("microsoft", model_name="wavlm-base-sd", library="tensorflow")
         )
         self.assertEqual(len(models), 0)
 
-        models = self._api.list_models(
-            filter=ModelFilter(
-                "microsoft", model_name="wavlm-base-sd", library="pytorch"
-            )
-        )
+        models = self._api.list_models(filter=ModelFilter("microsoft", model_name="wavlm-base-sd", library="pytorch"))
         self.assertGreater(len(models), 0)
 
     @expect_deprecation("list_models")
     def test_filter_models_with_task(self):
-        models = self._api.list_models(
-            filter=ModelFilter(task="fill-mask", model_name="albert-base-v2")
-        )
+        models = self._api.list_models(filter=ModelFilter(task="fill-mask", model_name="albert-base-v2"))
         self.assertTrue("fill-mask" == models[0].pipeline_tag)
         self.assertTrue("albert-base-v2" in models[0].modelId)
 
@@ -1819,15 +1660,9 @@ def test_filter_models_with_complex_query(self):
         models = self._api.list_models(filter=f)
         self.assertGreater(len(models), 1)
         self.assertTrue(
-            [
-                "text-classification" in model.pipeline_tag
-                or "text-classification" in model.tags
-                for model in models
-            ]
-        )
-        self.assertTrue(
-            ["pytorch" in model.tags and "tf" in model.tags for model in models]
+            ["text-classification" in model.pipeline_tag or "text-classification" in model.tags for model in models]
         )
+        self.assertTrue(["pytorch" in model.tags and "tf" in model.tags for model in models])
 
     def test_filter_models_with_cardData(self):
         models = self._api.list_models(filter="co2_eq_emissions", cardData=True)
@@ -1916,12 +1751,8 @@ def test_list_spaces_search(self):
     def test_list_spaces_sort_and_direction(self):
         spaces_descending_likes = self._api.list_spaces(sort="likes", direction=-1)
         spaces_ascending_likes = self._api.list_spaces(sort="likes")
-        self.assertGreater(
-            spaces_descending_likes[0].likes, spaces_descending_likes[1].likes
-        )
-        self.assertLess(
-            
spaces_ascending_likes[-2].likes, spaces_ascending_likes[-1].likes - ) + self.assertGreater(spaces_descending_likes[0].likes, spaces_descending_likes[1].likes) + self.assertLess(spaces_ascending_likes[-2].likes, spaces_ascending_likes[-1].likes) @expect_deprecation("list_spaces") def test_list_spaces_limit(self): @@ -1940,24 +1771,14 @@ def test_list_spaces_with_datasets(self): def test_list_spaces_linked(self): spaces = self._api.list_spaces(linked=True) + self.assertTrue(any((getattr(space, "models", None) is not None) for space in spaces)) + self.assertTrue(any((getattr(space, "datasets", None) is not None) for space in spaces)) self.assertTrue( - any((getattr(space, "models", None) is not None) for space in spaces) - ) - self.assertTrue( - any((getattr(space, "datasets", None) is not None) for space in spaces) - ) - self.assertTrue( - any( - (getattr(space, "models", None) is not None) - and getattr(space, "datasets", None) is not None - ) + any((getattr(space, "models", None) is not None) and getattr(space, "datasets", None) is not None) for space in spaces ) self.assertTrue( - all( - (getattr(space, "models", None) is not None) - or getattr(space, "datasets", None) is not None - ) + all((getattr(space, "models", None) is not None) or getattr(space, "datasets", None) is not None) for space in spaces ) @@ -1987,9 +1808,7 @@ def test_model_info(self): ): _ = self._api.model_info(repo_id=f"{USER}/{self.REPO_NAME}") - model_info = self._api.model_info( - repo_id=f"{USER}/{self.REPO_NAME}", use_auth_token=self._token - ) + model_info = self._api.model_info(repo_id=f"{USER}/{self.REPO_NAME}", use_auth_token=self._token) self.assertIsInstance(model_info, ModelInfo) def test_dataset_info(self): @@ -2005,9 +1824,7 @@ def test_dataset_info(self): ): _ = self._api.dataset_info(repo_id=f"{USER}/{self.REPO_NAME}") - dataset_info = self._api.dataset_info( - repo_id=f"{USER}/{self.REPO_NAME}", use_auth_token=self._token - ) + dataset_info = self._api.dataset_info(repo_id=f"{USER}/{self.REPO_NAME}", use_auth_token=self._token) self.assertIsInstance(dataset_info, DatasetInfo) @expect_deprecation("list_datasets") @@ -2046,29 +1863,21 @@ def setUp(self): self.REPO_NAME_LARGE_FILE = large_file_repo_name() if os.path.exists(WORKING_REPO_DIR): rmtree_with_retry(WORKING_REPO_DIR) - logger.info( - f"Does {WORKING_REPO_DIR} exist: {os.path.exists(WORKING_REPO_DIR)}" - ) + logger.info(f"Does {WORKING_REPO_DIR} exist: {os.path.exists(WORKING_REPO_DIR)}") def tearDown(self): self._api.delete_repo(repo_id=self.REPO_NAME_LARGE_FILE) def setup_local_clone(self, REMOTE_URL): - REMOTE_URL_AUTH = REMOTE_URL.replace( - ENDPOINT_STAGING, ENDPOINT_STAGING_BASIC_AUTH - ) + REMOTE_URL_AUTH = REMOTE_URL.replace(ENDPOINT_STAGING, ENDPOINT_STAGING_BASIC_AUTH) subprocess.run( ["git", "clone", REMOTE_URL_AUTH, WORKING_REPO_DIR], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) - subprocess.run( - ["git", "lfs", "track", "*.pdf"], check=True, cwd=WORKING_REPO_DIR - ) - subprocess.run( - ["git", "lfs", "track", "*.epub"], check=True, cwd=WORKING_REPO_DIR - ) + subprocess.run(["git", "lfs", "track", "*.pdf"], check=True, cwd=WORKING_REPO_DIR) + subprocess.run(["git", "lfs", "track", "*.epub"], check=True, cwd=WORKING_REPO_DIR) @retry_endpoint def test_end_to_end_thresh_6M(self): @@ -2085,9 +1894,7 @@ def test_end_to_end_thresh_6M(self): cwd=WORKING_REPO_DIR, ) subprocess.run(["git", "add", "*"], check=True, cwd=WORKING_REPO_DIR) - subprocess.run( - ["git", "commit", "-m", "commit message"], check=True, 
cwd=WORKING_REPO_DIR - ) + subprocess.run(["git", "commit", "-m", "commit message"], check=True, cwd=WORKING_REPO_DIR) # This will fail as we haven't set up our custom transfer agent yet. failed_process = subprocess.run( @@ -2100,9 +1907,7 @@ def test_end_to_end_thresh_6M(self): self.assertIn("cli lfs-enable-largefiles", failed_process.stderr.decode()) # ^ Instructions on how to fix this are included in the error message. - subprocess.run( - ["huggingface-cli", "lfs-enable-largefiles", WORKING_REPO_DIR], check=True - ) + subprocess.run(["huggingface-cli", "lfs-enable-largefiles", WORKING_REPO_DIR], check=True) start_time = time.time() subprocess.run(["git", "push"], check=True, cwd=WORKING_REPO_DIR) @@ -2150,9 +1955,7 @@ def test_end_to_end_thresh_16M(self): cwd=WORKING_REPO_DIR, ) - subprocess.run( - ["huggingface-cli", "lfs-enable-largefiles", WORKING_REPO_DIR], check=True - ) + subprocess.run(["huggingface-cli", "lfs-enable-largefiles", WORKING_REPO_DIR], check=True) start_time = time.time() subprocess.run(["git", "push"], check=True, cwd=WORKING_REPO_DIR) @@ -2211,18 +2014,14 @@ def tearDown(self): super().tearDown() def test_create_discussion(self): - discussion = self._api.create_discussion( - repo_id=self.repo_name, title=" Test discussion ! " - ) + discussion = self._api.create_discussion(repo_id=self.repo_name, title=" Test discussion ! ") self.assertEqual(discussion.num, 3) self.assertEqual(discussion.author, USER) self.assertEqual(discussion.is_pull_request, False) self.assertEqual(discussion.title, "Test discussion !") def test_create_pull_request(self): - discussion = self._api.create_discussion( - repo_id=self.repo_name, title=" Test PR ! ", pull_request=True - ) + discussion = self._api.create_discussion(repo_id=self.repo_name, title=" Test PR ! ", pull_request=True) self.assertEqual(discussion.num, 3) self.assertEqual(discussion.author, USER) self.assertEqual(discussion.is_pull_request, True) @@ -2243,9 +2042,7 @@ def test_get_repo_discussion(self): ) def test_get_discussion_details(self): - retrieved = self._api.get_discussion_details( - repo_id=self.repo_name, discussion_num=2 - ) + retrieved = self._api.get_discussion_details(repo_id=self.repo_name, discussion_num=2) self.assertEqual(retrieved, self.discussion) def test_edit_discussion_comment(self): @@ -2258,13 +2055,9 @@ def get_first_comment(discussion: DiscussionWithDetails) -> DiscussionComment: comment_id=get_first_comment(self.pull_request).id, new_content="**Edited** comment 🤗", ) - retrieved = self._api.get_discussion_details( - repo_id=self.repo_name, discussion_num=self.pull_request.num - ) + retrieved = self._api.get_discussion_details(repo_id=self.repo_name, discussion_num=self.pull_request.num) self.assertEqual(get_first_comment(retrieved).edited, True) - self.assertEqual( - get_first_comment(retrieved).id, get_first_comment(self.pull_request).id - ) + self.assertEqual(get_first_comment(retrieved).id, get_first_comment(self.pull_request).id) self.assertEqual(get_first_comment(retrieved).content, "**Edited** comment 🤗") self.assertEqual(get_first_comment(retrieved), edited_comment) @@ -2280,9 +2073,7 @@ def test_comment_discussion(self): And even [links](http://hf.co)! 
💥🤯 """, ) - retrieved = self._api.get_discussion_details( - repo_id=self.repo_name, discussion_num=self.discussion.num - ) + retrieved = self._api.get_discussion_details(repo_id=self.repo_name, discussion_num=self.discussion.num) self.assertEqual(len(retrieved.events), 2) self.assertIn(new_comment.id, {event.id for event in retrieved.events}) @@ -2292,9 +2083,7 @@ def test_rename_discussion(self): discussion_num=self.discussion.num, new_title="New titlee", ) - retrieved = self._api.get_discussion_details( - repo_id=self.repo_name, discussion_num=self.discussion.num - ) + retrieved = self._api.get_discussion_details(repo_id=self.repo_name, discussion_num=self.discussion.num) self.assertIn(rename_event, retrieved.events) self.assertEqual(rename_event.old_title, self.discussion.title) self.assertEqual(rename_event.new_title, "New titlee") @@ -2305,9 +2094,7 @@ def test_change_discussion_status(self): discussion_num=self.discussion.num, new_status="closed", ) - retrieved = self._api.get_discussion_details( - repo_id=self.repo_name, discussion_num=self.discussion.num - ) + retrieved = self._api.get_discussion_details(repo_id=self.repo_name, discussion_num=self.discussion.num) self.assertIn(status_change_event, retrieved.events) self.assertEqual(status_change_event.new_status, "closed") @@ -2323,9 +2110,7 @@ def test_merge_pull_request(self): self._api.create_commit( repo_id=self.repo_name, commit_message="Commit some file", - operations=[ - CommitOperationAdd(path_in_repo="file.test", path_or_fileobj=b"Content") - ], + operations=[CommitOperationAdd(path_in_repo="file.test", path_or_fileobj=b"Content")], revision=self.pull_request.git_reference, ) self._api.change_discussion_status( @@ -2335,9 +2120,7 @@ def test_merge_pull_request(self): ) self._api.merge_pull_request(self.repo_name, self.pull_request.num) - retrieved = self._api.get_discussion_details( - repo_id=self.repo_name, discussion_num=self.pull_request.num - ) + retrieved = self._api.get_discussion_details(repo_id=self.repo_name, discussion_num=self.pull_request.num) self.assertEqual(retrieved.status, "merged") self.assertIsNotNone(retrieved.merge_commit_oid) @@ -2413,9 +2196,7 @@ def test_list_liked_repos_no_auth(self) -> None: # Fetch liked repos without auth likes = self.api.list_liked_repos(USER) self.assertEqual(likes.user, USER) - self.assertGreater( - len(likes.models) + len(likes.datasets) + len(likes.spaces), 0 - ) + self.assertGreater(len(likes.models) + len(likes.datasets) + len(likes.spaces), 0) self.assertIn(repo_url.repo_id, likes.models) def test_list_likes_repos_auth_and_implicit_user(self) -> None: @@ -2425,18 +2206,14 @@ def test_list_likes_repos_auth_and_implicit_user(self) -> None: def test_list_likes_repos_auth_and_explicit_user(self) -> None: # User is explicit even if auth - likes = self.api.list_liked_repos( - user="__DUMMY_DATASETS_SERVER_USER__", token=TOKEN - ) + likes = self.api.list_liked_repos(user="__DUMMY_DATASETS_SERVER_USER__", token=TOKEN) self.assertEqual(likes.user, "__DUMMY_DATASETS_SERVER_USER__") @with_production_testing def test_list_likes_on_production(self) -> None: # Test julien-c likes a lot of repos ! 
likes = HfApi().list_liked_repos("julien-c") - self.assertEqual( - len(likes.models) + len(likes.datasets) + len(likes.spaces), likes.total - ) + self.assertEqual(len(likes.models) + len(likes.datasets) + len(likes.spaces), likes.total) self.assertGreater(len(likes.models), 0) self.assertGreater(len(likes.datasets), 0) self.assertGreater(len(likes.spaces), 0) @@ -2547,9 +2324,7 @@ def test_list_refs_bigcode(self) -> None: main_branch = [branch for branch in refs.branches if branch.name == "main"][0] self.assertEqual(main_branch.ref, "refs/heads/main") - convert_branch = [ - branch for branch in refs.converts if branch.name == "parquet" - ][0] + convert_branch = [branch for branch in refs.converts if branch.name == "parquet"][0] self.assertEqual(convert_branch.ref, "refs/convert/parquet") # Can get info by convert revision @@ -2582,21 +2357,15 @@ def test_no_token_at_all(self, mock_build_hf_headers: Mock) -> None: HfApi()._build_hf_headers(token=None) self._assert_token_is(mock_build_hf_headers, None) - def _assert_token_is( - self, mock_build_hf_headers: Mock, expected_value: str - ) -> None: + def _assert_token_is(self, mock_build_hf_headers: Mock, expected_value: str) -> None: self.assertEqual(mock_build_hf_headers.call_args[1]["token"], expected_value) - def test_library_name_and_version_are_set( - self, mock_build_hf_headers: Mock - ) -> None: + def test_library_name_and_version_are_set(self, mock_build_hf_headers: Mock) -> None: HfApi(library_name="a", library_version="b")._build_hf_headers() self.assertEqual(mock_build_hf_headers.call_args[1]["library_name"], "a") self.assertEqual(mock_build_hf_headers.call_args[1]["library_version"], "b") - def test_library_name_and_version_are_overwritten( - self, mock_build_hf_headers: Mock - ) -> None: + def test_library_name_and_version_are_overwritten(self, mock_build_hf_headers: Mock) -> None: api = HfApi(library_name="a", library_version="b") api._build_hf_headers(library_name="A", library_version="B") self.assertEqual(mock_build_hf_headers.call_args[1]["library_name"], "A") @@ -2630,8 +2399,10 @@ def test_repo_url_class(self): # __repr__ is modified for debugging purposes self.assertEqual( repr(url), - "RepoUrl('https://huggingface.co/gpt2', endpoint='https://huggingface.co'," - " repo_type='model', repo_id='gpt2')", + ( + "RepoUrl('https://huggingface.co/gpt2'," + " endpoint='https://huggingface.co', repo_type='model', repo_id='gpt2')" + ), ) def test_repo_url_endpoint(self): diff --git a/tests/test_hubmixin.py b/tests/test_hubmixin.py index 46305e6ab9..a27163722c 100644 --- a/tests/test_hubmixin.py +++ b/tests/test_hubmixin.py @@ -13,9 +13,7 @@ logger = logging.get_logger(__name__) WORKING_REPO_SUBDIR = "fixtures/working_repo_2" -WORKING_REPO_DIR = os.path.join( - os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR -) +WORKING_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR) if is_torch_available(): import torch.nn as nn @@ -59,9 +57,7 @@ class HubMixingTest(HubMixingCommonTest): def tearDown(self) -> None: if os.path.exists(WORKING_REPO_DIR): rmtree_with_retry(WORKING_REPO_DIR) - logger.info( - f"Does {WORKING_REPO_DIR} exist: {os.path.exists(WORKING_REPO_DIR)}" - ) + logger.info(f"Does {WORKING_REPO_DIR} exist: {os.path.exists(WORKING_REPO_DIR)}") @classmethod @expect_deprecation("set_access_token") @@ -82,9 +78,7 @@ def test_save_pretrained(self): self.assertTrue("pytorch_model.bin" in files) self.assertEqual(len(files), 1) - model.save_pretrained( - f"{WORKING_REPO_DIR}/{REPO_NAME}", 
config={"num": 12, "act": "gelu"} - ) + model.save_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}", config={"num": 12, "act": "gelu"}) files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") self.assertTrue("config.json" in files) self.assertTrue("pytorch_model.bin" in files) @@ -103,9 +97,7 @@ def test_save_pretrained_with_push_to_hub(self): mocked_model.push_to_hub.assert_not_called() # Push to hub with repo_id - mocked_model.save_pretrained( - save_directory, push_to_hub=True, repo_id="CustomID", config=config - ) + mocked_model.save_pretrained(save_directory, push_to_hub=True, repo_id="CustomID", config=config) mocked_model.push_to_hub.assert_called_with(repo_id="CustomID", config=config) # Push to hub with default repo_id (based on dir name) @@ -119,9 +111,7 @@ def test_rel_path_from_pretrained(self): config={"num": 10, "act": "gelu_fast"}, ) - model = DummyModel.from_pretrained( - f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED" - ) + model = DummyModel.from_pretrained(f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED") self.assertTrue(model.config == {"num": 10, "act": "gelu_fast"}) def test_abs_path_from_pretrained(self): @@ -151,9 +141,7 @@ def test_push_to_hub_via_http_basic(self): self.assertEqual(model_info.modelId, repo_id) # Test config has been pushed to hub - tmp_config_path = hf_hub_download( - repo_id=repo_id, filename="config.json", use_auth_token=self._token - ) + tmp_config_path = hf_hub_download(repo_id=repo_id, filename="config.json", use_auth_token=self._token) with open(tmp_config_path) as f: self.assertEqual(json.load(f), {"num": 7, "act": "gelu_fast"}) diff --git a/tests/test_inference_api.py b/tests/test_inference_api.py index cc8b605f8e..d673062942 100644 --- a/tests/test_inference_api.py +++ b/tests/test_inference_api.py @@ -31,9 +31,7 @@ def read(self, filename: str) -> bytes: @classmethod @with_production_testing def setUpClass(cls) -> None: - cls.image_file = hf_hub_download( - repo_id="Narsil/image_dummy", repo_type="dataset", filename="lena.png" - ) + cls.image_file = hf_hub_download(repo_id="Narsil/image_dummy", repo_type="dataset", filename="lena.png") return super().setUpClass() def test_simple_inference(self): @@ -49,10 +47,7 @@ def test_simple_inference(self): def test_inference_with_params(self): api = InferenceApi("typeform/distilbert-base-uncased-mnli") - inputs = ( - "I bought a device but it is not working and I would like to get" - " reimbursed!" - ) + inputs = "I bought a device but it is not working and I would like to get reimbursed!" params = {"candidate_labels": ["refund", "legal", "faq"]} result = api(inputs, params) self.assertIsInstance(result, dict) @@ -119,9 +114,7 @@ def test_inference_overriding_task(self): self.assertIsInstance(result, list) def test_inference_overriding_invalid_task(self): - with self.assertRaises( - ValueError, msg="Invalid task invalid-task. Make sure it's valid." - ): + with self.assertRaises(ValueError, msg="Invalid task invalid-task. Make sure it's valid."): InferenceApi("bert-base-uncased", task="invalid-task") def test_inference_missing_input(self): diff --git a/tests/test_init_lazy_loading.py b/tests/test_init_lazy_loading.py index e7c92383cb..8faa320673 100644 --- a/tests/test_init_lazy_loading.py +++ b/tests/test_init_lazy_loading.py @@ -28,11 +28,7 @@ def test_autocomplete_on_root_imports(self) -> None: # the help section. 
signature_list = goto_list[0].get_signatures() self.assertEqual(len(signature_list), 1) - self.assertTrue( - signature_list[0] - .docstring() - .startswith("create_commit(repo_id: str,") - ) + self.assertTrue(signature_list[0].docstring().startswith("create_commit(repo_id: str,")) break else: self.fail( diff --git a/tests/test_keras_integration.py b/tests/test_keras_integration.py index 195b215bd7..c6ca14b6e0 100644 --- a/tests/test_keras_integration.py +++ b/tests/test_keras_integration.py @@ -33,13 +33,9 @@ logger = logging.get_logger(__name__) WORKING_REPO_SUBDIR = f"fixtures/working_repo_{__name__.split('.')[-1]}" -WORKING_REPO_DIR = os.path.join( - os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR -) +WORKING_REPO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), WORKING_REPO_SUBDIR) -PUSH_TO_HUB_KERAS_WARNING_REGEX = re.escape( - "Deprecated argument(s) used in 'push_to_hub_keras':" -) +PUSH_TO_HUB_KERAS_WARNING_REGEX = re.escape("Deprecated argument(s) used in 'push_to_hub_keras':") if is_tf_available(): import tensorflow as tf @@ -79,9 +75,7 @@ class CommonKerasTest(unittest.TestCase): def tearDown(self) -> None: if os.path.exists(WORKING_REPO_DIR): rmtree_with_retry(WORKING_REPO_DIR) - logger.info( - f"Does {WORKING_REPO_DIR} exist: {os.path.exists(WORKING_REPO_DIR)}" - ) + logger.info(f"Does {WORKING_REPO_DIR} exist: {os.path.exists(WORKING_REPO_DIR)}") @classmethod @expect_deprecation("set_access_token") @@ -107,9 +101,7 @@ def test_save_pretrained(self): self.assertTrue("model.png" in files) self.assertEqual(len(files), 7) - model.save_pretrained( - f"{WORKING_REPO_DIR}/{REPO_NAME}", config={"num": 12, "act": "gelu"} - ) + model.save_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}", config={"num": 12, "act": "gelu"}) files = os.listdir(f"{WORKING_REPO_DIR}/{REPO_NAME}") self.assertTrue("config.json" in files) self.assertTrue("saved_model.pb" in files) @@ -128,11 +120,7 @@ def test_keras_from_pretrained_weights(self): # Check a new model's weights are not the same as the reloaded model's weights another_model = DummyModel() another_model(tf.ones([2, 2])) - self.assertFalse( - tf.reduce_all(tf.equal(new_model.weights[0], another_model.weights[0])) - .numpy() - .item() - ) + self.assertFalse(tf.reduce_all(tf.equal(new_model.weights[0], another_model.weights[0])).numpy().item()) def test_rel_path_from_pretrained(self): model = DummyModel() @@ -142,18 +130,14 @@ def test_rel_path_from_pretrained(self): config={"num": 10, "act": "gelu_fast"}, ) - model = DummyModel.from_pretrained( - f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED" - ) + model = DummyModel.from_pretrained(f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED") self.assertTrue(model.config == {"num": 10, "act": "gelu_fast"}) def test_abs_path_from_pretrained(self): REPO_NAME = repo_name("FROM_PRETRAINED") model = DummyModel() model(model.dummy_inputs) - model.save_pretrained( - f"{WORKING_REPO_DIR}/{REPO_NAME}", config={"num": 10, "act": "gelu_fast"} - ) + model.save_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}", config={"num": 10, "act": "gelu_fast"}) model = DummyModel.from_pretrained(f"{WORKING_REPO_DIR}/{REPO_NAME}") self.assertDictEqual(model.config, {"num": 10, "act": "gelu_fast"}) @@ -178,9 +162,7 @@ def test_push_to_hub_keras_mixin_via_http_basic(self): self.assertEqual(model_info.modelId, repo_id) # Test config has been pushed to hub - tmp_config_path = hf_hub_download( - repo_id=repo_id, filename="config.json", use_auth_token=self._token - ) + tmp_config_path = hf_hub_download(repo_id=repo_id, 
filename="config.json", use_auth_token=self._token) with open(tmp_config_path) as f: self.assertEqual(json.load(f), {"num": 7, "act": "gelu_fast"}) @@ -254,22 +236,16 @@ def test_save_model_card_history_removal(self): model = self.model_fit(model) with SoftTemporaryDirectory() as tmpdirname: os.makedirs(f"{tmpdirname}/{WORKING_REPO_DIR}/{REPO_NAME}") - with open( - f"{tmpdirname}/{WORKING_REPO_DIR}/{REPO_NAME}/history.json", "w+" - ) as fp: + with open(f"{tmpdirname}/{WORKING_REPO_DIR}/{REPO_NAME}/history.json", "w+") as fp: fp.write("Keras FTW") - with pytest.warns( - UserWarning, match="`history.json` file already exists, *" - ): + with pytest.warns(UserWarning, match="`history.json` file already exists, *"): save_pretrained_keras( model, f"{tmpdirname}/{WORKING_REPO_DIR}/{REPO_NAME}", ) # assert that it's not the same as old history file and it's overridden - with open( - f"{tmpdirname}/{WORKING_REPO_DIR}/{REPO_NAME}/history.json", "r" - ) as f: + with open(f"{tmpdirname}/{WORKING_REPO_DIR}/{REPO_NAME}/history.json", "r") as f: history_content = f.read() self.assertNotEqual("Keras FTW", history_content) @@ -279,9 +255,7 @@ def test_save_model_card_history_removal(self): # Check that there is no "Training Metrics" section in the model card. # This was done in an older version. - with open( - f"{tmpdirname}/{WORKING_REPO_DIR}/{REPO_NAME}/README.md", "r" - ) as file: + with open(f"{tmpdirname}/{WORKING_REPO_DIR}/{REPO_NAME}/README.md", "r") as file: data = file.read() self.assertNotIn(data, "Training Metrics") @@ -290,9 +264,7 @@ def test_save_pretrained_optimizer_state(self): model = self.model_init() model.build((None, 2)) - save_pretrained_keras( - model, f"{WORKING_REPO_DIR}/{REPO_NAME}", include_optimizer=True - ) + save_pretrained_keras(model, f"{WORKING_REPO_DIR}/{REPO_NAME}", include_optimizer=True) loaded_model = from_pretrained_keras(f"{WORKING_REPO_DIR}/{REPO_NAME}") self.assertIsNotNone(loaded_model.optimizer) @@ -324,11 +296,7 @@ def test_from_pretrained_weights(self): # Check a new model's weights are not the same as the reloaded model's weights another_model = DummyModel() another_model(tf.ones([2, 2])) - self.assertFalse( - tf.reduce_all(tf.equal(new_model.weights[0], another_model.weights[0])) - .numpy() - .item() - ) + self.assertFalse(tf.reduce_all(tf.equal(new_model.weights[0], another_model.weights[0])).numpy().item()) def test_save_pretrained_task_name_deprecation(self): REPO_NAME = repo_name("save") @@ -356,9 +324,7 @@ def test_rel_path_from_pretrained(self): config={"num": 10, "act": "gelu_fast"}, ) - new_model = from_pretrained_keras( - f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED" - ) + new_model = from_pretrained_keras(f"tests/{WORKING_REPO_SUBDIR}/FROM_PRETRAINED") # Check the reloaded model's weights match the original model's weights self.assertTrue(tf.reduce_all(tf.equal(new_model.weights[0], model.weights[0]))) @@ -389,9 +355,7 @@ def test_push_to_hub_keras_sequential_via_http_basic(self): model = self.model_init() model = self.model_fit(model) - push_to_hub_keras( - model, repo_id=repo_id, token=self._token, api_endpoint=ENDPOINT_STAGING - ) + push_to_hub_keras(model, repo_id=repo_id, token=self._token, api_endpoint=ENDPOINT_STAGING) model_info = HfApi(endpoint=ENDPOINT_STAGING).model_info(repo_id) self.assertEqual(model_info.modelId, repo_id) self.assertTrue("README.md" in [f.rfilename for f in model_info.siblings]) @@ -447,12 +411,8 @@ def test_push_to_hub_keras_via_http_override_tensorboard(self): ) model_info = self._api.model_info(repo_id) - 
self.assertTrue( - "logs/override.txt" in [f.rfilename for f in model_info.siblings] - ) - self.assertFalse( - "logs/tensorboard.txt" in [f.rfilename for f in model_info.siblings] - ) + self.assertTrue("logs/override.txt" in [f.rfilename for f in model_info.siblings]) + self.assertFalse("logs/tensorboard.txt" in [f.rfilename for f in model_info.siblings]) self._api.delete_repo(repo_id=repo_id) @@ -476,9 +436,7 @@ def test_push_to_hub_keras_via_http_with_model_kwargs(self): self.assertEqual(model_info.modelId, repo_id) with SoftTemporaryDirectory() as tmpdirname: - Repository( - local_dir=tmpdirname, clone_from=ENDPOINT_STAGING + "/" + repo_id - ) + Repository(local_dir=tmpdirname, clone_from=ENDPOINT_STAGING + "/" + repo_id) from_pretrained_keras(tmpdirname) self._api.delete_repo(repo_id=f"{REPO_NAME}") diff --git a/tests/test_lfs.py b/tests/test_lfs.py index d80d7f7463..b0f9c18b29 100644 --- a/tests/test_lfs.py +++ b/tests/test_lfs.py @@ -118,9 +118,7 @@ def test_slice_fileobj_file(self): with open(filepath, "rb") as fileobj: prev_pos = fileobj.tell() # Test read - with SliceFileObj( - fileobj, seek_from=24, read_limit=18 - ) as fileobj_slice: + with SliceFileObj(fileobj, seek_from=24, read_limit=18) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.read(), self.content[24:42]) self.assertEqual(fileobj_slice.tell(), 18) @@ -129,9 +127,7 @@ def test_slice_fileobj_file(self): self.assertEqual(fileobj.tell(), prev_pos) - with SliceFileObj( - fileobj, seek_from=0, read_limit=990 - ) as fileobj_slice: + with SliceFileObj(fileobj, seek_from=0, read_limit=990) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) self.assertEqual(fileobj_slice.read(200), self.content[0:200]) self.assertEqual(fileobj_slice.read(500), self.content[200:700]) @@ -140,9 +136,7 @@ def test_slice_fileobj_file(self): self.assertEqual(fileobj_slice.read(200), b"") # Test seek with whence = os.SEEK_SET - with SliceFileObj( - fileobj, seek_from=100, read_limit=100 - ) as fileobj_slice: + with SliceFileObj(fileobj, seek_from=100, read_limit=100) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) fileobj_slice.seek(2, os.SEEK_SET) self.assertEqual(fileobj_slice.tell(), 2) @@ -155,9 +149,7 @@ def test_slice_fileobj_file(self): self.assertEqual(fileobj_slice.fileobj.tell(), 200) # Test seek with whence = os.SEEK_CUR - with SliceFileObj( - fileobj, seek_from=100, read_limit=100 - ) as fileobj_slice: + with SliceFileObj(fileobj, seek_from=100, read_limit=100) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) fileobj_slice.seek(-5, os.SEEK_CUR) self.assertEqual(fileobj_slice.tell(), 0) @@ -173,9 +165,7 @@ def test_slice_fileobj_file(self): self.assertEqual(fileobj_slice.fileobj.tell(), 100) # Test seek with whence = os.SEEK_END - with SliceFileObj( - fileobj, seek_from=100, read_limit=100 - ) as fileobj_slice: + with SliceFileObj(fileobj, seek_from=100, read_limit=100) as fileobj_slice: self.assertEqual(fileobj_slice.tell(), 0) fileobj_slice.seek(-5, os.SEEK_END) self.assertEqual(fileobj_slice.tell(), 95) diff --git a/tests/test_login_utils.py b/tests/test_login_utils.py index 65e0db167b..f2d524db48 100644 --- a/tests/test_login_utils.py +++ b/tests/test_login_utils.py @@ -12,9 +12,7 @@ class TestSetGlobalStore(unittest.TestCase): def setUp(self) -> None: """Get current global config value.""" try: - self.previous_config = run_subprocess( - "git config --global credential.helper" - ).stdout + self.previous_config = run_subprocess("git config --global 
credential.helper").stdout except subprocess.CalledProcessError: self.previous_config = None # Means global credential.helper value not set @@ -25,9 +23,7 @@ def tearDown(self) -> None: if self.previous_config is None: run_subprocess("git config --global --unset credential.helper") else: - run_subprocess( - f"git config --global credential.helper {self.previous_config}" - ) + run_subprocess(f"git config --global credential.helper {self.previous_config}") def test_set_store_as_git_credential_helper_globally(self) -> None: """Test `_set_store_as_git_credential_helper_globally` works as expected. diff --git a/tests/test_offline_utils.py b/tests/test_offline_utils.py index 076bb29a7d..cb9bf28fa2 100644 --- a/tests/test_offline_utils.py +++ b/tests/test_offline_utils.py @@ -1,8 +1,8 @@ from io import BytesIO import pytest - import requests + from huggingface_hub.file_download import http_get from .testing_utils import ( diff --git a/tests/test_repocard.py b/tests/test_repocard.py index 47d0967204..5730805c0d 100644 --- a/tests/test_repocard.py +++ b/tests/test_repocard.py @@ -20,9 +20,9 @@ from pathlib import Path import pytest - import requests import yaml + from huggingface_hub import ( DatasetCard, DatasetCardData, @@ -167,9 +167,7 @@ logger = logging.get_logger(__name__) -REPOCARD_DIR = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "fixtures/repocard" -) +REPOCARD_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/repocard") repo_name = partial(repo_name, prefix="dummy-hf-hub") @@ -285,9 +283,7 @@ def setUp(self) -> None: git_user="ci", git_email="ci@dummy.com", ) - self.existing_metadata = yaml.safe_load( - DUMMY_MODELCARD_EVAL_RESULT.strip().strip("-") - ) + self.existing_metadata = yaml.safe_load(DUMMY_MODELCARD_EVAL_RESULT.strip().strip("-")) def tearDown(self) -> None: self._api.delete_repo(repo_id=self.repo_id) @@ -305,9 +301,7 @@ def test_update_dataset_name(self): def test_update_existing_result_with_overwrite(self): new_metadata = copy.deepcopy(self.existing_metadata) - new_metadata["model-index"][0]["results"][0]["metrics"][0][ - "value" - ] = 0.2862102282047272 + new_metadata["model-index"][0]["results"][0]["metrics"][0]["value"] = 0.2862102282047272 metadata_update(self.repo_id, new_metadata, token=self._token, overwrite=True) self.repo.git_pull() @@ -320,9 +314,7 @@ def test_update_verify_token(self): Regression test for https://github.com/huggingface/huggingface_hub/issues/1210 """ new_metadata = copy.deepcopy(self.existing_metadata) - new_metadata["model-index"][0]["results"][0]["metrics"][0][ - "verifyToken" - ] = "1234" + new_metadata["model-index"][0]["results"][0]["metrics"][0]["verifyToken"] = "1234" metadata_update(self.repo_id, new_metadata, token=self._token, overwrite=True) self.repo.git_pull() updated_metadata = metadata_load(self.repo_path / self.REPO_NAME / "README.md") @@ -345,9 +337,7 @@ def test_metadata_update_upstream(self): def test_update_existing_result_without_overwrite(self): new_metadata = copy.deepcopy(self.existing_metadata) - new_metadata["model-index"][0]["results"][0]["metrics"][0][ - "value" - ] = 0.2862102282047272 + new_metadata["model-index"][0]["results"][0]["metrics"][0]["value"] = 0.2862102282047272 with pytest.raises( ValueError, @@ -356,9 +346,7 @@ def test_update_existing_result_without_overwrite(self): " accuracy'. Set `overwrite=True` to overwrite existing metrics." 
), ): - metadata_update( - self.repo_id, new_metadata, token=self._token, overwrite=False - ) + metadata_update(self.repo_id, new_metadata, token=self._token, overwrite=False) def test_update_existing_field_without_overwrite(self): new_datasets_data = {"datasets": "['test/test_dataset']"} @@ -425,9 +413,7 @@ def test_update_new_result_new_dataset(self): metadata_update(self.repo_id, new_result, token=self._token, overwrite=False) expected_metadata = copy.deepcopy(self.existing_metadata) - expected_metadata["model-index"][0]["results"].append( - new_result["model-index"][0]["results"][0] - ) + expected_metadata["model-index"][0]["results"].append(new_result["model-index"][0]["results"][0]) self.repo.git_pull() updated_metadata = metadata_load(self.repo_path / self.REPO_NAME / "README.md") self.assertDictEqual(updated_metadata, expected_metadata) @@ -452,15 +438,11 @@ def test_update_metadata_on_empty_text_content(self) -> None: def test_update_with_existing_name(self): new_metadata = copy.deepcopy(self.existing_metadata) new_metadata["model-index"][0].pop("name") - new_metadata["model-index"][0]["results"][0]["metrics"][0][ - "value" - ] = 0.2862102282047272 + new_metadata["model-index"][0]["results"][0]["metrics"][0]["value"] = 0.2862102282047272 metadata_update(self.repo_id, new_metadata, token=self._token, overwrite=True) card_data = ModelCard.load(self.repo_id, token=self._token) - self.assertEqual( - card_data.data.model_name, self.existing_metadata["model-index"][0]["name"] - ) + self.assertEqual(card_data.data.model_name, self.existing_metadata["model-index"][0]["name"]) def test_update_without_existing_name(self): # delete existing metadata @@ -598,9 +580,7 @@ def test_change_repocard_data(self): card.save(updated_card_path) updated_card = RepoCard.load(updated_card_path) - self.assertEqual( - updated_card.data.language, ["fr"], "Card data not updated properly" - ) + self.assertEqual(updated_card.data.language, ["fr"], "Card data not updated properly") @require_jinja def test_repo_card_from_default_template(self): @@ -661,9 +641,7 @@ def test_repo_card_from_custom_template(self): def test_repo_card_data_must_be_dict(self): sample_path = SAMPLE_CARDS_DIR / "sample_invalid_card_data.md" - with pytest.raises( - ValueError, match="repo card metadata block should be a dict" - ): + with pytest.raises(ValueError, match="repo card metadata block should be a dict"): RepoCard(sample_path.read_text()) def test_repo_card_without_metadata(self): @@ -874,9 +852,7 @@ def test_model_card_from_template_eval_results(self): self.assertIsInstance(card, ModelCard) self.assertTrue(card.text.endswith("asdf")) self.assertTrue(card.data.to_dict().get("eval_results") is None) - self.assertEqual( - str(card)[: len(DUMMY_MODELCARD_EVAL_RESULT)], DUMMY_MODELCARD_EVAL_RESULT - ) + self.assertEqual(str(card)[: len(DUMMY_MODELCARD_EVAL_RESULT)], DUMMY_MODELCARD_EVAL_RESULT) class DatasetCardTest(TestCaseWithCapLog): @@ -920,9 +896,7 @@ def test_dataset_card_from_default_template(self): # Here we pass the card data as kwargs as well so template picks up pretty_name. card = DatasetCard.from_template(card_data, **card_data.to_dict()) - self.assertTrue( - card.text.strip().startswith("# Dataset Card for My Cool Dataset") - ) + self.assertTrue(card.text.strip().startswith("# Dataset Card for My Cool Dataset")) self.assertIsInstance(card, DatasetCard) @@ -946,9 +920,7 @@ def test_dataset_card_from_default_template_with_template_variables(self): "in the dataset card template are working." 
), ) - self.assertTrue( - card.text.strip().startswith("# Dataset Card for My Cool Dataset") - ) + self.assertTrue(card.text.strip().startswith("# Dataset Card for My Cool Dataset")) self.assertIsInstance(card, DatasetCard) matches = re.findall(r"Homepage:\*\* https:\/\/huggingface\.co", str(card)) diff --git a/tests/test_repocard_data.py b/tests/test_repocard_data.py index 3f57e27987..c1e8894a29 100644 --- a/tests/test_repocard_data.py +++ b/tests/test_repocard_data.py @@ -1,8 +1,8 @@ import unittest import pytest - import yaml + from huggingface_hub.repocard_data import ( DatasetCardData, EvalResult, @@ -112,9 +112,7 @@ def test_model_index_to_eval_results(self): self.assertEqual(eval_results[2].verify_token, 1234) def test_card_data_requires_model_name_for_eval_results(self): - with pytest.raises( - ValueError, match="`eval_results` requires `model_name` to be set." - ): + with pytest.raises(ValueError, match="`eval_results` requires `model_name` to be set."): ModelCardData( eval_results=[ EvalResult( @@ -143,9 +141,7 @@ def test_card_data_requires_model_name_for_eval_results(self): model_index = eval_results_to_model_index(data.model_name, data.eval_results) self.assertEqual(model_index[0]["name"], "my-cool-model") - self.assertEqual( - model_index[0]["results"][0]["task"]["type"], "image-classification" - ) + self.assertEqual(model_index[0]["results"][0]["task"]["type"], "image-classification") def test_abitrary_incoming_card_data(self): data = ModelCardData( diff --git a/tests/test_repository.py b/tests/test_repository.py index 3ffe715cdb..3586129156 100644 --- a/tests/test_repository.py +++ b/tests/test_repository.py @@ -18,8 +18,8 @@ from pathlib import Path import pytest - import requests + from huggingface_hub import RepoUrl from huggingface_hub.hf_api import HfApi from huggingface_hub.repository import ( @@ -125,9 +125,7 @@ def test_clone_from_not_hf_url(self): # Should not error out Repository( self.repo_path, - clone_from=( - "https://hf.co/hf-internal-testing/huggingface-hub-dummy-repository" - ), + clone_from="https://hf.co/hf-internal-testing/huggingface-hub-dummy-repository", ) def test_clone_from_missing_repo(self): @@ -203,9 +201,7 @@ def test_is_tracked_upstream(self): def test_push_errors_on_wrong_checkout(self): repo = Repository(self.repo_path, clone_from=self.repo_id) - head_commit_ref = run_subprocess( - "git show --oneline -s", folder=self.repo_path - ).stdout.split()[0] + head_commit_ref = run_subprocess("git show --oneline -s", folder=self.repo_path).stdout.split()[0] repo.git_checkout(head_commit_ref) @@ -398,9 +394,7 @@ def test_commit_context_manager(self): @retry_endpoint def test_clone_skip_lfs_files(self): # Upload LFS file - self._api.upload_file( - path_or_fileobj=b"Bin file", path_in_repo="file.bin", repo_id=self.repo_id - ) + self._api.upload_file(path_or_fileobj=b"Bin file", path_in_repo="file.bin", repo_id=self.repo_id) repo = self.clone_repo(skip_lfs_files=True) file_bin = self.repo_path / "file.bin" @@ -553,13 +547,9 @@ def test_lfs_prune(self): f.write("Random string 2") root_directory = self.repo_path / ".git" / "lfs" - git_lfs_files_size = sum( - f.stat().st_size for f in root_directory.glob("**/*") if f.is_file() - ) + git_lfs_files_size = sum(f.stat().st_size for f in root_directory.glob("**/*") if f.is_file()) repo.lfs_prune() - post_prune_git_lfs_files_size = sum( - f.stat().st_size for f in root_directory.glob("**/*") if f.is_file() - ) + post_prune_git_lfs_files_size = sum(f.stat().st_size for f in root_directory.glob("**/*") if 
f.is_file()) # Size of the directory holding LFS files was reduced self.assertLess(post_prune_git_lfs_files_size, git_lfs_files_size) @@ -572,9 +562,7 @@ def test_lfs_prune_git_push(self): f.write("Random string 1") root_directory = self.repo_path / ".git" / "lfs" - git_lfs_files_size = sum( - f.stat().st_size for f in root_directory.glob("**/*") if f.is_file() - ) + git_lfs_files_size = sum(f.stat().st_size for f in root_directory.glob("**/*") if f.is_file()) with open(os.path.join(repo.local_dir, "file.bin"), "w+") as f: f.write("Random string 2") @@ -583,9 +571,7 @@ def test_lfs_prune_git_push(self): repo.git_commit("New commit") repo.git_push(auto_lfs_prune=True) - post_prune_git_lfs_files_size = sum( - f.stat().st_size for f in root_directory.glob("**/*") if f.is_file() - ) + post_prune_git_lfs_files_size = sum(f.stat().st_size for f in root_directory.glob("**/*") if f.is_file()) # Size of the directory holding LFS files is the exact same self.assertEqual(post_prune_git_lfs_files_size, git_lfs_files_size) @@ -816,9 +802,7 @@ def test_is_not_tracked_upstream(self): self.assertFalse(is_tracked_upstream(self.repo.local_dir)) def test_no_branch_checked_out_raises(self): - head_commit_ref = run_subprocess( - "git show --oneline -s", folder=self.repo_path - ).stdout.split()[0] + head_commit_ref = run_subprocess("git show --oneline -s", folder=self.repo_path).stdout.split()[0] self.repo.git_checkout(head_commit_ref) self.assertRaises(OSError, is_tracked_upstream, self.repo.local_dir) diff --git a/tests/test_snapshot_download.py b/tests/test_snapshot_download.py index 4d3216bc3e..105f4d4088 100644 --- a/tests/test_snapshot_download.py +++ b/tests/test_snapshot_download.py @@ -2,6 +2,7 @@ import unittest import requests + from huggingface_hub import HfApi, Repository, snapshot_download from huggingface_hub.utils import ( HfFolder, @@ -83,9 +84,7 @@ def tearDownClass(cls) -> None: def test_download_model(self): # Test `main` branch with SoftTemporaryDirectory() as tmpdirname: - storage_folder = snapshot_download( - f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname - ) + storage_folder = snapshot_download(f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname) # folder contains the two files contributed and the .gitattributes folder_contents = os.listdir(storage_folder) @@ -127,12 +126,8 @@ def test_download_private_model(self): # Test download fails without token with SoftTemporaryDirectory() as tmpdirname: - with self.assertRaisesRegex( - requests.exceptions.HTTPError, "401 Client Error" - ): - _ = snapshot_download( - f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname - ) + with self.assertRaisesRegex(requests.exceptions.HTTPError, "401 Client Error"): + _ = snapshot_download(f"{USER}/{REPO_NAME}", revision="main", cache_dir=tmpdirname) # Test we can download with token from cache with SoftTemporaryDirectory() as tmpdirname: diff --git a/tests/test_utils_assets.py b/tests/test_utils_assets.py index 1dbae04b72..0726dda4ae 100644 --- a/tests/test_utils_assets.py +++ b/tests/test_utils_assets.py @@ -26,16 +26,12 @@ def test_cached_assets_path_with_namespace_and_subfolder(self) -> None: self.assertTrue(path.is_dir()) # And dir is created def test_cached_assets_path_without_subfolder(self) -> None: - path = cached_assets_path( - library_name="datasets", namespace="SQuAD", assets_dir=self.cache_dir - ) + path = cached_assets_path(library_name="datasets", namespace="SQuAD", assets_dir=self.cache_dir) self.assertEqual(path, self.cache_dir / "datasets" / "SQuAD" / "default") 
self.assertTrue(path.is_dir()) def test_cached_assets_path_without_namespace(self) -> None: - path = cached_assets_path( - library_name="datasets", subfolder="download", assets_dir=self.cache_dir - ) + path = cached_assets_path(library_name="datasets", subfolder="download", assets_dir=self.cache_dir) self.assertEqual(path, self.cache_dir / "datasets" / "default" / "download") self.assertTrue(path.is_dir()) @@ -53,10 +49,7 @@ def test_cached_assets_path_forbidden_symbols(self) -> None: ) self.assertEqual( path, - self.cache_dir - / "ReAlLy--dumb" - / "user--repo_name" - / "this--is--not--clever", + self.cache_dir / "ReAlLy--dumb" / "user--repo_name" / "this--is--not--clever", ) self.assertTrue(path.is_dir()) diff --git a/tests/test_utils_cache.py b/tests/test_utils_cache.py index d760e65a46..9ab80a26e8 100644 --- a/tests/test_utils_cache.py +++ b/tests/test_utils_cache.py @@ -48,9 +48,7 @@ class TestMissingCacheUtils(unittest.TestCase): def test_cache_dir_is_missing(self) -> None: """Directory to scan does not exist raises CacheNotFound.""" - self.assertRaises( - CacheNotFound, scan_cache_dir, self.cache_dir / "does_not_exist" - ) + self.assertRaises(CacheNotFound, scan_cache_dir, self.cache_dir / "does_not_exist") def test_cache_dir_is_a_file(self) -> None: """Directory to scan is a file raises ValueError.""" @@ -169,9 +167,7 @@ def test_scan_cache_on_valid_cache_unix(self) -> None: ) # Check readme file from "main" revision - main_readme_file = [ - file for file in main_revision.files if file.file_name == "README.md" - ][0] + main_readme_file = [file for file in main_revision.files if file.file_name == "README.md"][0] main_readme_file_path = main_revision_path / "README.md" main_readme_blob_path = repo_a_path / "blobs" / REPO_A_MAIN_README_BLOB_HASH @@ -182,9 +178,7 @@ def test_scan_cache_on_valid_cache_unix(self) -> None: # Check readme file from "refs/pr/1" revision pr_1_revision = repo_a.refs[REF_1_NAME] pr_1_revision_path = repo_a_path / "snapshots" / REPO_A_PR_1_HASH - pr_1_readme_file = [ - file for file in pr_1_revision.files if file.file_name == "README.md" - ][0] + pr_1_readme_file = [file for file in pr_1_revision.files if file.file_name == "README.md"][0] pr_1_readme_file_path = pr_1_revision_path / "README.md" # file_path in "refs/pr/1" revision is different than "main" but same blob path @@ -253,25 +247,19 @@ def test_scan_cache_on_valid_cache_windows(self) -> None: ) # Check readme file from "main" revision - main_readme_file = [ - file for file in main_revision.files if file.file_name == "README.md" - ][0] + main_readme_file = [file for file in main_revision.files if file.file_name == "README.md"][0] main_readme_file_path = main_revision_path / "README.md" main_readme_blob_path = repo_a_path / "blobs" / REPO_A_MAIN_README_BLOB_HASH self.assertEqual(main_readme_file.file_name, "README.md") self.assertEqual(main_readme_file.file_path, main_readme_file_path) - self.assertEqual( # Windows-specific: no blob file - main_readme_file.blob_path, main_readme_file_path - ) + self.assertEqual(main_readme_file.blob_path, main_readme_file_path) # Windows-specific: no blob file self.assertFalse(main_readme_blob_path.exists()) # Windows-specific # Check readme file from "refs/pr/1" revision pr_1_revision = repo_a.refs[REF_1_NAME] pr_1_revision_path = repo_a_path / "snapshots" / REPO_A_PR_1_HASH - pr_1_readme_file = [ - file for file in pr_1_revision.files if file.file_name == "README.md" - ][0] + pr_1_readme_file = [file for file in pr_1_revision.files if file.file_name == "README.md"][0] 
pr_1_readme_file_path = pr_1_revision_path / "README.md" # file_path in "refs/pr/1" revision is different than "main" @@ -391,9 +379,7 @@ def test_repo_path_not_valid_dir(self) -> None: self.assertEqual(len(report.repos), 1) # Scan still worked ! self.assertEqual(len(report.warnings), 1) - self.assertEqual( - str(report.warnings[0]), f"Repo path is not a directory: {repo_path}" - ) + self.assertEqual(str(report.warnings[0]), f"Repo path is not a directory: {repo_path}") # Case 2: a folder with wrong naming os.remove(repo_path) @@ -420,8 +406,7 @@ def test_repo_path_not_valid_dir(self) -> None: self.assertEqual(len(report.warnings), 1) self.assertEqual( str(report.warnings[0]), - "Repo type must be `dataset`, `model` or `space`, found `not-model`" - f" ({repo_path}).", + f"Repo type must be `dataset`, `model` or `space`, found `not-model` ({repo_path}).", ) def test_snapshots_path_not_found(self) -> None: @@ -507,8 +492,7 @@ def test_ref_to_missing_revision(self) -> None: self.assertEqual(len(report.warnings), 1) self.assertEqual( str(report.warnings[0]), - "Reference(s) refer to missing commit hashes:" - " {'revision_hash_that_does_not_exist': {'not_main'}} " + "Reference(s) refer to missing commit hashes: {'revision_hash_that_does_not_exist': {'not_main'}} " + f"({self.repo_path }).", ) @@ -523,12 +507,8 @@ def test_scan_cache_last_modified_and_last_accessed(self) -> None: # Values from first report repo_1 = list(report_1.repos)[0] revision_1 = list(repo_1.revisions)[0] - readme_file_1 = [ - file for file in revision_1.files if file.file_name == "README.md" - ][0] - another_file_1 = [ - file for file in revision_1.files if file.file_name == ".gitattributes" - ][0] + readme_file_1 = [file for file in revision_1.files if file.file_name == "README.md"][0] + another_file_1 = [file for file in revision_1.files if file.file_name == ".gitattributes"][0] # Comparison of last_accessed/last_modified between file and repo self.assertLessEqual(readme_file_1.blob_last_accessed, repo_1.last_accessed) @@ -551,12 +531,8 @@ def test_scan_cache_last_modified_and_last_accessed(self) -> None: # Values from second report repo_2 = list(report_2.repos)[0] revision_2 = list(repo_2.revisions)[0] - readme_file_2 = [ - file for file in revision_2.files if file.file_name == "README.md" - ][0] - another_file_2 = [ - file for file in revision_1.files if file.file_name == ".gitattributes" - ][0] + readme_file_2 = [file for file in revision_2.files if file.file_name == "README.md"][0] + another_file_2 = [file for file in revision_1.files if file.file_name == ".gitattributes"][0] # Report 1 is not updated when cache changes self.assertLess(repo_1.last_accessed, repo_2.last_accessed) @@ -608,9 +584,7 @@ def setUp(self) -> None: pr_1_only_file.size_on_disk = 100 detached_and_pr_1_only_file = Mock() - detached_and_pr_1_only_file.blob_path = ( - blobs_path / "detached_and_pr_1_only_hash" - ) + detached_and_pr_1_only_file.blob_path = blobs_path / "detached_and_pr_1_only_hash" detached_and_pr_1_only_file.size_on_disk = 1000 shared_file = Mock() @@ -684,9 +658,7 @@ def test_delete_pr_1_revision(self) -> None: self.assertEqual(strategy, expected) def test_delete_pr_1_and_detached(self) -> None: - strategy = HFCacheInfo.delete_revisions( - self.cache_info, "repo_A_rev_detached", "repo_A_rev_pr_1" - ) + strategy = HFCacheInfo.delete_revisions(self.cache_info, "repo_A_rev_detached", "repo_A_rev_pr_1") expected = DeleteCacheStrategy( expected_freed_size=1110, blobs={ @@ -719,9 +691,7 @@ def test_delete_all_revisions(self) -> None: 
def test_delete_unknown_revision(self) -> None: with self.assertLogs() as captured: - strategy = HFCacheInfo.delete_revisions( - self.cache_info, "repo_A_rev_detached", "abcdef123456789" - ) + strategy = HFCacheInfo.delete_revisions(self.cache_info, "repo_A_rev_detached", "abcdef123456789") # Expected is same strategy as without "abcdef123456789" expected = HFCacheInfo.delete_revisions(self.cache_info, "repo_A_rev_detached") diff --git a/tests/test_utils_cli.py b/tests/test_utils_cli.py index e1d68118da..7419283808 100644 --- a/tests/test_utils_cli.py +++ b/tests/test_utils_cli.py @@ -59,11 +59,13 @@ def test_tabulate_utility(self) -> None: headers = ["Header 1", "something else", "a third column"] self.assertEqual( tabulate(rows=rows, headers=headers), - "Header 1 something else a third column \n" - "----------------- -------------- -------------- \n" - " 1 2 3 \n" - "a very long value foo bar \n" - " 123 456 ", + ( + "Header 1 something else a third column \n" + "----------------- -------------- -------------- \n" + " 1 2 3 \n" + "a very long value foo bar \n" + " 123 456 " + ), ) def test_tabulate_utility_with_too_short_row(self) -> None: diff --git a/tests/test_utils_datetime.py b/tests/test_utils_datetime.py index 4fb350a4e7..aac379ac6c 100644 --- a/tests/test_utils_datetime.py +++ b/tests/test_utils_datetime.py @@ -14,9 +14,7 @@ def test_parse_datetime(self): datetime(2022, 8, 19, 7, 19, 38, 123000, tzinfo=timezone.utc), ) - with pytest.raises( - ValueError, match=r".*Cannot parse '2022-08-19T07:19:38' as a datetime.*" - ): + with pytest.raises(ValueError, match=r".*Cannot parse '2022-08-19T07:19:38' as a datetime.*"): parse_datetime("2022-08-19T07:19:38") with pytest.raises( diff --git a/tests/test_utils_deprecation.py b/tests/test_utils_deprecation.py index 9dadfc8b6e..976adcd520 100644 --- a/tests/test_utils_deprecation.py +++ b/tests/test_utils_deprecation.py @@ -87,8 +87,10 @@ def dummy_deprecated_default_message(a: str = "a") -> None: self.assertEqual(len(record), 1) self.assertEqual( record[0].message.args[0], - "Deprecated argument(s) used in 'dummy_deprecated_default_message': a. Will" - " not be supported from version 'xxx'.", + ( + "Deprecated argument(s) used in 'dummy_deprecated_default_message': a." + " Will not be supported from version 'xxx'." + ), ) def test_deprecate_arguments_with_custom_warning_message(self) -> None: @@ -108,8 +110,11 @@ def dummy_deprecated_custom_message(a: str = "a") -> None: self.assertEqual(len(record), 1) self.assertEqual( record[0].message.args[0], - "Deprecated argument(s) used in 'dummy_deprecated_custom_message': a. Will" - " not be supported from version 'xxx'.\n\nThis is a custom message.", + ( + "Deprecated argument(s) used in 'dummy_deprecated_custom_message': a." + " Will not be supported from version 'xxx'.\n\nThis is a custom" + " message." + ), ) def test_deprecated_method(self) -> None: @@ -125,8 +130,10 @@ def dummy_deprecated() -> None: self.assertEqual(len(record), 1) self.assertEqual( record[0].message.args[0], - "'dummy_deprecated' (from 'tests.test_utils_deprecation') is deprecated" - " and will be removed from version 'xxx'. This is a custom message.", + ( + "'dummy_deprecated' (from 'tests.test_utils_deprecation') is deprecated" + " and will be removed from version 'xxx'. This is a custom message." 
+ ), ) def test_deprecate_list_output(self) -> None: @@ -148,11 +155,13 @@ def dummy_deprecated() -> None: # (check real message once) self.assertEqual( record[0].message.args[0], - "'dummy_deprecated' currently returns a list of objects but is planned to" - " be a generator starting from version xxx in order to implement" - " pagination. Please avoid to use `dummy_deprecated(...).__getitem__` or" - " explicitly convert the output to a list first with" - " `list(iter(dummy_deprecated)(...))`.", + ( + "'dummy_deprecated' currently returns a list of objects but is planned" + " to be a generator starting from version xxx in order to implement" + " pagination. Please avoid to use `dummy_deprecated(...).__getitem__`" + " or explicitly convert the output to a list first with" + " `list(iter(dummy_deprecated)(...))`." + ), ) # __setitem__ diff --git a/tests/test_utils_errors.py b/tests/test_utils_errors.py index a1d625d5e0..a761da1d76 100644 --- a/tests/test_utils_errors.py +++ b/tests/test_utils_errors.py @@ -1,6 +1,8 @@ import unittest from unittest.mock import Mock, patch +from requests.models import Response + from huggingface_hub.utils._errors import ( BadRequestError, EntryNotFoundError, @@ -12,7 +14,6 @@ _raise_with_request_id, hf_raise_for_status, ) -from requests.models import Response from .testing_utils import expect_deprecation @@ -22,9 +23,7 @@ def test_hf_raise_for_status_repo_not_found(self) -> None: response = Response() response.headers = {"X-Error-Code": "RepoNotFound", "X-Request-Id": 123} response.status_code = 404 - with self.assertRaisesRegex( - RepositoryNotFoundError, "Repository Not Found" - ) as context: + with self.assertRaisesRegex(RepositoryNotFoundError, "Repository Not Found") as context: hf_raise_for_status(response) self.assertEqual(context.exception.response.status_code, 404) @@ -34,9 +33,7 @@ def test_hf_raise_for_status_repo_not_found_without_error_code(self) -> None: response = Response() response.headers = {"X-Request-Id": 123} response.status_code = 401 - with self.assertRaisesRegex( - RepositoryNotFoundError, "Repository Not Found" - ) as context: + with self.assertRaisesRegex(RepositoryNotFoundError, "Repository Not Found") as context: hf_raise_for_status(response) self.assertEqual(context.exception.response.status_code, 401) @@ -46,9 +43,7 @@ def test_hf_raise_for_status_revision_not_found(self) -> None: response = Response() response.headers = {"X-Error-Code": "RevisionNotFound", "X-Request-Id": 123} response.status_code = 404 - with self.assertRaisesRegex( - RevisionNotFoundError, "Revision Not Found" - ) as context: + with self.assertRaisesRegex(RevisionNotFoundError, "Revision Not Found") as context: hf_raise_for_status(response) self.assertEqual(context.exception.response.status_code, 404) @@ -76,9 +71,7 @@ def test_hf_raise_for_status_bad_request_with_endpoint_name(self) -> None: """Test endpoint name is added to BadRequestError message.""" response = Response() response.status_code = 400 - with self.assertRaisesRegex( - BadRequestError, "Bad request for preupload endpoint:" - ) as context: + with self.assertRaisesRegex(BadRequestError, "Bad request for preupload endpoint:") as context: hf_raise_for_status(response, endpoint_name="preupload") self.assertEqual(context.exception.response.status_code, 400) @@ -119,9 +112,7 @@ def test_raise_convert_bad_request(self, mock_hf_raise_for_status: Mock) -> None response_mock = Mock() endpoint_name_mock = Mock() _raise_convert_bad_request(response_mock, endpoint_name_mock) - 
mock_hf_raise_for_status.assert_called_once_with( - response_mock, endpoint_name=endpoint_name_mock - ) + mock_hf_raise_for_status.assert_called_once_with(response_mock, endpoint_name=endpoint_name_mock) class TestHfHubHTTPError(unittest.TestCase): @@ -151,16 +142,10 @@ def test_hf_hub_http_error_init_with_request_id(self) -> None: def test_hf_hub_http_error_init_with_request_id_and_multiline_message(self) -> None: """Test request id is added to the end of the first line.""" self.response.headers = {"X-Request-Id": "test-id"} - error = HfHubHTTPError( - "this is a message\nthis is more details", response=self.response - ) - self.assertEqual( - str(error), "this is a message (Request ID: test-id)\nthis is more details" - ) + error = HfHubHTTPError("this is a message\nthis is more details", response=self.response) + self.assertEqual(str(error), "this is a message (Request ID: test-id)\nthis is more details") - error = HfHubHTTPError( - "this is a message\n\nthis is more details", response=self.response - ) + error = HfHubHTTPError("this is a message\n\nthis is more details", response=self.response) self.assertEqual( str(error), "this is a message (Request ID: test-id)\n\nthis is more details", @@ -169,39 +154,26 @@ def test_hf_hub_http_error_init_with_request_id_and_multiline_message(self) -> N def test_hf_hub_http_error_init_with_request_id_already_in_message(self) -> None: """Test request id is not duplicated in error message (case insensitive)""" self.response.headers = {"X-Request-Id": "test-id"} - error = HfHubHTTPError( - "this is a message on request TEST-ID", response=self.response - ) + error = HfHubHTTPError("this is a message on request TEST-ID", response=self.response) self.assertEqual(str(error), "this is a message on request TEST-ID") self.assertEqual(error.request_id, "test-id") def test_hf_hub_http_error_init_with_server_error(self) -> None: """Test server error is added to the error message.""" - self.response._content = ( - b'{"error": "This is a message returned by the server"}' - ) + self.response._content = b'{"error": "This is a message returned by the server"}' error = HfHubHTTPError("this is a message", response=self.response) - self.assertEqual( - str(error), "this is a message\n\nThis is a message returned by the server" - ) - self.assertEqual( - error.server_message, "This is a message returned by the server" - ) + self.assertEqual(str(error), "this is a message\n\nThis is a message returned by the server") + self.assertEqual(error.server_message, "This is a message returned by the server") def test_hf_hub_http_error_init_with_server_error_and_multiline_message( self, ) -> None: """Test server error is added to the error message after the details.""" - self.response._content = ( - b'{"error": "This is a message returned by the server"}' - ) - error = HfHubHTTPError( - "this is a message\n\nSome details.", response=self.response - ) + self.response._content = b'{"error": "This is a message returned by the server"}' + error = HfHubHTTPError("this is a message\n\nSome details.", response=self.response) self.assertEqual( str(error), - "this is a message\n\nSome details.\nThis is a message returned by the" - " server", + "this is a message\n\nSome details.\nThis is a message returned by the server", ) def test_hf_hub_http_error_init_with_multiple_server_errors( @@ -215,9 +187,7 @@ def test_hf_hub_http_error_init_with_multiple_server_errors( b'{"httpStatusCode": 400, "errors": [{"message": "this is error 1", "type":' b' "error"}, {"message": "this is error 2", "type": 
"error"}]}' ) - error = HfHubHTTPError( - "this is a message\n\nSome details.", response=self.response - ) + error = HfHubHTTPError("this is a message\n\nSome details.", response=self.response) self.assertEqual( str(error), "this is a message\n\nSome details.\nthis is error 1\nthis is error 2", @@ -277,8 +247,7 @@ def test_hf_hub_http_error_init_with_error_message_from_header_and_body( error = HfHubHTTPError("this is a message", response=self.response) self.assertEqual( str(error), - "this is a message\n\nError message from headers.\nError message from" - " body.", + "this is a message\n\nError message from headers.\nError message from body.", ) self.assertEqual( error.server_message, @@ -292,17 +261,11 @@ def test_hf_hub_http_error_init_with_error_message_duplicated_in_header_and_body Should not duplicate it in the raised `HfHubHTTPError`. """ - self.response._content = ( - b'{"error": "Error message duplicated in headers and body."}' - ) - self.response.headers = { - "X-Error-Message": "Error message duplicated in headers and body." - } + self.response._content = b'{"error": "Error message duplicated in headers and body."}' + self.response.headers = {"X-Error-Message": "Error message duplicated in headers and body."} error = HfHubHTTPError("this is a message", response=self.response) self.assertEqual( str(error), "this is a message\n\nError message duplicated in headers and body.", ) - self.assertEqual( - error.server_message, "Error message duplicated in headers and body." - ) + self.assertEqual(error.server_message, "Error message duplicated in headers and body.") diff --git a/tests/test_utils_fixes.py b/tests/test_utils_fixes.py index abe9b80e4d..582962aea1 100644 --- a/tests/test_utils_fixes.py +++ b/tests/test_utils_fixes.py @@ -12,9 +12,7 @@ def test_yaml_dump_japanese_characters(self) -> None: self.assertEqual(yaml_dump({"some unicode": "日本か"}), "some unicode: 日本か\n") def test_yaml_dump_explicit_no_unicode(self) -> None: - self.assertEqual( - yaml_dump({"emoji": "👀"}, allow_unicode=False), 'emoji: "\\U0001F440"\n' - ) + self.assertEqual(yaml_dump({"emoji": "👀"}, allow_unicode=False), 'emoji: "\\U0001F440"\n') class TestTemporaryDirectory(unittest.TestCase): diff --git a/tests/test_utils_git_credentials.py b/tests/test_utils_git_credentials.py index 476ae9238c..20f88ea70e 100644 --- a/tests/test_utils_git_credentials.py +++ b/tests/test_utils_git_credentials.py @@ -42,14 +42,10 @@ def test_set_and_unset_git_credential(self) -> None: username = "hf_test_user_" + str(round(time.time())) # make username unique # Set credentials - set_git_credential( - token="hf_test_token", username=username, folder=self.cache_dir - ) + set_git_credential(token="hf_test_token", username=username, folder=self.cache_dir) # Check credentials are stored - with run_interactive_subprocess( - "git credential fill", folder=self.cache_dir - ) as (stdin, stdout): + with run_interactive_subprocess("git credential fill", folder=self.cache_dir) as (stdin, stdout): stdin.write(f"url={ENDPOINT}\nusername={username}\n\n") stdin.flush() output = stdout.read() @@ -61,9 +57,7 @@ def test_set_and_unset_git_credential(self) -> None: # Check credentials are NOT stored # Cannot check with `git credential fill` as it would hang forever: only # checking `store` helper instead. 
- with run_interactive_subprocess( - "git credential-store get", folder=self.cache_dir - ) as (stdin, stdout): + with run_interactive_subprocess("git credential-store get", folder=self.cache_dir) as (stdin, stdout): stdin.write(f"url={ENDPOINT}\nusername={username}\n\n") stdin.flush() output = stdout.read() diff --git a/tests/test_utils_headers.py b/tests/test_utils_headers.py index 54eddacea5..0573da0170 100644 --- a/tests/test_utils_headers.py +++ b/tests/test_utils_headers.py @@ -110,10 +110,12 @@ def test_default_user_agent( mock_is_torch_available.return_value = True self.assertEqual( self._get_user_agent(), - f"unknown/None; hf_hub/{get_hf_hub_version()};" - f" python/{get_python_version()}; torch/torch_version;" - " tensorflow/tf_version; fastai/fastai_version;" - " fastcore/fastcore_version", + ( + f"unknown/None; hf_hub/{get_hf_hub_version()};" + f" python/{get_python_version()}; torch/torch_version;" + " tensorflow/tf_version; fastai/fastai_version;" + " fastcore/fastcore_version" + ), ) @patch("huggingface_hub.utils._headers.is_torch_available") @@ -136,18 +138,10 @@ def test_user_agent_with_library_name_and_version(self) -> None: ) def test_user_agent_with_library_name_no_version(self) -> None: - self.assertTrue( - self._get_user_agent(library_name="foo").startswith("foo/None;") - ) + self.assertTrue(self._get_user_agent(library_name="foo").startswith("foo/None;")) def test_user_agent_with_custom_agent_string(self) -> None: - self.assertTrue( - self._get_user_agent(user_agent="this is a custom agent").endswith( - "this is a custom agent" - ) - ) + self.assertTrue(self._get_user_agent(user_agent="this is a custom agent").endswith("this is a custom agent")) def test_user_agent_with_custom_agent_dict(self) -> None: - self.assertTrue( - self._get_user_agent(user_agent={"a": "b", "c": "d"}).endswith("a/b; c/d") - ) + self.assertTrue(self._get_user_agent(user_agent={"a": "b", "c": "d"}).endswith("a/b; c/d")) diff --git a/tests/test_utils_http.py b/tests/test_utils_http.py index ea39bac31a..d3f139ebdf 100644 --- a/tests/test_utils_http.py +++ b/tests/test_utils_http.py @@ -3,9 +3,10 @@ from typing import Generator from unittest.mock import Mock, call, patch -from huggingface_hub.utils._http import http_backoff from requests import ConnectTimeout, HTTPError +from huggingface_hub.utils._http import http_backoff + URL = "https://www.google.com" @@ -87,9 +88,7 @@ def test_backoff_on_valid_status_code(self, mock_request: Mock) -> None: mock_200.status_code = 200 mock_request.side_effect = (mock_200, mock_200, mock_200, mock_200) - response = http_backoff( - "GET", URL, base_wait_time=0.0, max_retries=3, retry_on_status_codes=200 - ) + response = http_backoff("GET", URL, base_wait_time=0.0, max_retries=3, retry_on_status_codes=200) self.assertEqual(mock_request.call_count, 4) self.assertIs(response, mock_200) @@ -116,9 +115,7 @@ def _side_effect_timer() -> Generator[ConnectTimeout, None, None]: mock_request.side_effect = _side_effect_timer() with self.assertRaises(ConnectTimeout): - http_backoff( - "GET", URL, base_wait_time=0.1, max_wait_time=0.5, max_retries=5 - ) + http_backoff("GET", URL, base_wait_time=0.1, max_wait_time=0.5, max_retries=5) self.assertEqual(mock_request.call_count, 6) diff --git a/tests/test_utils_pagination.py b/tests/test_utils_pagination.py index d1e6a00885..430a24402a 100644 --- a/tests/test_utils_pagination.py +++ b/tests/test_utils_pagination.py @@ -10,9 +10,7 @@ class TestPagination(unittest.TestCase): @patch("huggingface_hub.utils._pagination.requests.get") 
@patch("huggingface_hub.utils._pagination.hf_raise_for_status") @handle_injection_in_test - def test_mocked_paginate( - self, mock_get: Mock, mock_hf_raise_for_status: Mock - ) -> None: + def test_mocked_paginate(self, mock_get: Mock, mock_hf_raise_for_status: Mock) -> None: mock_params = Mock() mock_headers = Mock() diff --git a/tests/test_utils_validators.py b/tests/test_utils_validators.py index fd98727509..66161daf18 100644 --- a/tests/test_utils_validators.py +++ b/tests/test_utils_validators.py @@ -56,9 +56,7 @@ def test_valid_repo_ids(self) -> None: def test_not_valid_repo_ids(self) -> None: """Test `repo_id` validation on not valid values.""" for repo_id in self.NOT_VALID_VALUES: - with self.assertRaises( - HFValidationError, msg=f"'{repo_id}' must not be valid" - ): + with self.assertRaises(HFValidationError, msg=f"'{repo_id}' must not be valid"): validate_repo_id(repo_id) @@ -89,21 +87,15 @@ def test_token_with_smoothly_deprecated_use_auth_token(self) -> None: def test_input_kwargs_not_mutated_by_smooth_deprecation(self) -> None: initial_kwargs = {"a": "b", "use_auth_token": "token"} - kwargs = smoothly_deprecate_use_auth_token( - fn_name="name", has_token=False, kwargs=initial_kwargs - ) + kwargs = smoothly_deprecate_use_auth_token(fn_name="name", has_token=False, kwargs=initial_kwargs) self.assertEqual(kwargs, {"a": "b", "token": "token"}) - self.assertEqual( # not mutated! - initial_kwargs, {"a": "b", "use_auth_token": "token"} - ) + self.assertEqual(initial_kwargs, {"a": "b", "use_auth_token": "token"}) # not mutated! def test_with_both_token_and_use_auth_token(self) -> None: with self.assertWarns(UserWarning): # `use_auth_token` is ignored ! self.assertEqual( - self.dummy_token_function( - token="this_is_a_token", use_auth_token="this_is_a_use_auth_token" - ), + self.dummy_token_function(token="this_is_a_token", use_auth_token="this_is_a_use_auth_token"), ("this_is_a_token", {}), ) @@ -111,9 +103,7 @@ def test_not_deprecated_use_auth_token(self) -> None: # `use_auth_token` is accepted by `dummy_use_auth_token_function` # => `smoothly_deprecate_use_auth_token` is not called self.assertEqual( - self.dummy_use_auth_token_function( - use_auth_token="this_is_a_use_auth_token" - ), + self.dummy_use_auth_token_function(use_auth_token="this_is_a_use_auth_token"), ("this_is_a_use_auth_token", {}), ) diff --git a/tests/testing_constants.py b/tests/testing_constants.py index 46ba7f21ee..fd7a258263 100644 --- a/tests/testing_constants.py +++ b/tests/testing_constants.py @@ -17,9 +17,7 @@ ENDPOINT_STAGING = "https://hub-ci.huggingface.co" ENDPOINT_STAGING_BASIC_AUTH = f"https://{USER}:{PASS}@hub-ci.huggingface.co" -ENDPOINT_PRODUCTION_URL_SCHEME = ( - ENDPOINT_PRODUCTION + "/{repo_id}/resolve/{revision}/{filename}" -) +ENDPOINT_PRODUCTION_URL_SCHEME = ENDPOINT_PRODUCTION + "/{repo_id}/resolve/{revision}/{filename}" # Token to be set as environment variable. # Almost all features are tested on staging environment. However, Spaces are not supported diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 0f9c46685c..fa7e8c1d40 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -15,9 +15,9 @@ from unittest.mock import Mock, patch import pytest +from requests.exceptions import HTTPError from huggingface_hub.utils import logging -from requests.exceptions import HTTPError from tests.testing_constants import ENDPOINT_PRODUCTION, ENDPOINT_PRODUCTION_URL_SCHEME @@ -36,9 +36,7 @@ # This commit does not exist, so we should 404. 
DUMMY_MODEL_ID_PINNED_SHA1 = "d9e9f15bc825e4b2c9249e9578f884bbcb5e3684" # Sha-1 of config.json on the top of `main`, for checking purposes -DUMMY_MODEL_ID_PINNED_SHA256 = ( - "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3" -) +DUMMY_MODEL_ID_PINNED_SHA256 = "4b243c475af8d0a7754e87d7d096c92e5199ec2fe168a2ee7998e3b8e9bcb1d3" # Sha-256 of pytorch_model.bin on the top of `main`, for checking purposes # "hf-internal-testing/dummy-will-be-renamed" has been renamed to "hf-internal-testing/dummy-renamed" @@ -48,9 +46,7 @@ SAMPLE_DATASET_IDENTIFIER = "lhoestq/custom_squad" # Example dataset ids DUMMY_DATASET_ID = "lhoestq/test" -DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT = ( # on branch "test-branch" - "81d06f998585f8ee10e6e3a2ea47203dc75f2a16" -) +DUMMY_DATASET_ID_REVISION_ONE_SPECIFIC_COMMIT = "81d06f998585f8ee10e6e3a2ea47203dc75f2a16" # on branch "test-branch" YES = ("y", "yes", "t", "true", "on", "1") NO = ("n", "no", "f", "false", "off", "0") @@ -154,8 +150,7 @@ def timeout_request(method, url, **kwargs): invalid_url = "https://10.255.255.1" if kwargs.get("timeout") is None: raise RequestWouldHangIndefinitelyError( - f"Tried a call to {url} in offline mode with no timeout set. Please set" - " a timeout." + f"Tried a call to {url} in offline mode with no timeout set. Please set a timeout." ) kwargs["timeout"] = timeout try: @@ -164,9 +159,7 @@ def timeout_request(method, url, **kwargs): # The following changes in the error are just here to make the offline timeout error prettier e.request.url = url max_retry_error = e.args[0] - max_retry_error.args = ( - max_retry_error.args[0].replace("10.255.255.1", f"OfflineMock[{url}]"), - ) + max_retry_error.args = (max_retry_error.args[0].replace("10.255.255.1", f"OfflineMock[{url}]"),) e.args = (max_retry_error,) raise @@ -318,9 +311,7 @@ def xfail_on_windows(reason: str, raises: Optional[Type[Exception]] = None): """ def _inner_decorator(test_function: Callable) -> Callable: - return pytest.mark.xfail( - os.name == "nt", reason=reason, raises=raises, strict=True, run=True - )(test_function) + return pytest.mark.xfail(os.name == "nt", reason=reason, raises=raises, strict=True, run=True)(test_function) return _inner_decorator @@ -450,10 +441,9 @@ def _inner(*args, **kwargs): if name == "self": continue assert parameter.annotation is Mock - assert name in mocks, ( - f"Mock `{name}` not found for test `{fn.__name__}`. Available:" - f" {', '.join(sorted(mocks.keys()))}" - ) + assert ( + name in mocks + ), f"Mock `{name}` not found for test `{fn.__name__}`. 
Available: {', '.join(sorted(mocks.keys()))}" new_kwargs[name] = mocks[name] # Run test only with a subset of mocks @@ -492,9 +482,7 @@ def _inner(*args, **kwargs): self = args[0] assert isinstance(self, unittest.TestCase) - repo_url = self._api.create_repo( - repo_id=repo_name(prefix=repo_type), repo_type=repo_type - ) + repo_url = self._api.create_repo(repo_id=repo_name(prefix=repo_type), repo_type=repo_type) try: return test_fn(*args, **kwargs, repo_url=repo_url) finally: diff --git a/utils/check_contrib_list.py b/utils/check_contrib_list.py index ea3d0a73f1..2c3ae08b0e 100644 --- a/utils/check_contrib_list.py +++ b/utils/check_contrib_list.py @@ -45,23 +45,17 @@ def check_contrib_list(update: bool) -> NoReturn: the list.""" # List contrib test suites contrib_list = sorted( - path.name - for path in CONTRIB_PATH.glob("*") - if path.is_dir() and not path.name.startswith("_") + path.name for path in CONTRIB_PATH.glob("*") if path.is_dir() and not path.name.startswith("_") ) # Check Makefile is consistent with list makefile_content = MAKEFILE_PATH.read_text() - makefile_expected_content = MAKEFILE_REGEX.sub( - f"CONTRIB_LIBS := {' '.join(contrib_list)}", makefile_content - ) + makefile_expected_content = MAKEFILE_REGEX.sub(f"CONTRIB_LIBS := {' '.join(contrib_list)}", makefile_content) # Check workflow is consistent with list workflow_content = WORKFLOW_PATH.read_text() _substitute = "\n".join(f'{" "*10}"{lib}",' for lib in contrib_list) - workflow_content_expected = WORKFLOW_REGEX.sub( - rf"\g{_substitute}\n\g", workflow_content - ) + workflow_content_expected = WORKFLOW_REGEX.sub(rf"\g{_substitute}\n\g", workflow_content) # failed = False diff --git a/utils/check_static_imports.py b/utils/check_static_imports.py index 2d9a10e3bf..19dd2e5e9c 100644 --- a/utils/check_static_imports.py +++ b/utils/check_static_imports.py @@ -19,6 +19,7 @@ from typing import NoReturn import isort + from huggingface_hub import _SUBMOD_ATTRS @@ -46,10 +47,7 @@ def check_static_imports(update: bool) -> NoReturn: # Search and replace `_SUBMOD_ATTRS` dictionary definition. This ensures modules # and functions that can be lazy-loaded are alphabetically ordered for readability. if SUBMOD_ATTRS_PATTERN.search(init_content_before_static_checks) is None: - print( - "Error: _SUBMOD_ATTRS dictionary definition not found in" - " `./src/huggingface_hub/__init__.py`." - ) + print("Error: _SUBMOD_ATTRS dictionary definition not found in `./src/huggingface_hub/__init__.py`.") exit(1) _submod_attrs_definition = ( @@ -75,10 +73,7 @@ def check_static_imports(update: bool) -> NoReturn: # Generate the expected `__init__.py` file content and apply formatter on it. expected_init_content = isort.code( - reordered_content_before_static_checks - + IF_TYPE_CHECKING_LINE - + "\n".join(static_imports) - + "\n", + reordered_content_before_static_checks + IF_TYPE_CHECKING_LINE + "\n".join(static_imports) + "\n", config=isort.Config(settings_path=SETUP_CFG_PATH), ) @@ -113,10 +108,7 @@ def check_static_imports(update: bool) -> NoReturn: parser.add_argument( "--update", action="store_true", - help=( - "Whether to fix `./src/huggingface_hub/__init__.py` if a change is" - " detected." 
- ), + help="Whether to fix `./src/huggingface_hub/__init__.py` if a change is detected.", ) args = parser.parse_args() diff --git a/utils/push_repocard_examples.py b/utils/push_repocard_examples.py index 432d6b8b35..2184796936 100644 --- a/utils/push_repocard_examples.py +++ b/utils/push_repocard_examples.py @@ -17,6 +17,7 @@ from pathlib import Path import jinja2 + from huggingface_hub import DatasetCard, ModelCard, hf_hub_download, upload_file, whoami from huggingface_hub.constants import REPOCARD_NAME @@ -34,10 +35,7 @@ def check_can_push(): print("You must be logged in to push repo card examples.") if all(org["name"] != ORG_NAME for org in me.get("orgs", [])): - print( - f"❌ You must have access to organization '{ORG_NAME}' to push repo card" - " examples." - ) + print(f"❌ You must have access to organization '{ORG_NAME}' to push repo card examples.") exit(1) @@ -58,9 +56,7 @@ def push_model_card_example(overwrite: bool) -> None: ), ) if not overwrite: - existing_content = Path( - hf_hub_download(MODEL_CARD_REPO_ID, REPOCARD_NAME, repo_type="model") - ).read_text() + existing_content = Path(hf_hub_download(MODEL_CARD_REPO_ID, REPOCARD_NAME, repo_type="model")).read_text() if content == existing_content: print("Model Card not pushed: did not change.") return @@ -90,9 +86,7 @@ def push_dataset_card_example(overwrite: bool) -> None: ), ) if not overwrite: - existing_content = Path( - hf_hub_download(DATASET_CARD_REPO_ID, REPOCARD_NAME, repo_type="dataset") - ).read_text() + existing_content = Path(hf_hub_download(DATASET_CARD_REPO_ID, REPOCARD_NAME, repo_type="dataset")).read_text() if content == existing_content: print("Dataset Card not pushed: did not change.") return @@ -110,10 +104,7 @@ def push_dataset_card_example(overwrite: bool) -> None: parser.add_argument( "--overwrite", action="store_true", - help=( - "Whether to force updating examples. By default, push to hub only if card" - " is updated." - ), + help="Whether to force updating examples. 
By default, push to hub only if card is updated.", ) args = parser.parse_args() From 6175d23adc69895b752340eeba9e589485fda90c Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 7 Feb 2023 09:47:52 +0100 Subject: [PATCH 2/6] Fixed __init__.py generation --- pyproject.toml | 5 +- src/huggingface_hub/__init__.py | 266 ++++++++++++++++++-------------- utils/check_static_imports.py | 16 +- 3 files changed, 158 insertions(+), 129 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 601b77ebda..5b231157d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,4 @@ line-length = 119 [tool.ruff.isort] lines-after-imports = 2 -known-first-party = ["huggingface_hub"] - -[tool.ruff.per-file-ignores] -"src/huggingface_hub/__init__.py" = ["I001"] # Imports are autogenerated \ No newline at end of file +known-first-party = ["huggingface_hub"] \ No newline at end of file diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 3f890aa716..1d55a2c710 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -305,126 +305,152 @@ def __dir__(): # make style # ``` if TYPE_CHECKING: # pragma: no cover - from ._login import interpreter_login # noqa: F401 - from ._login import login # noqa: F401 - from ._login import logout # noqa: F401 - from ._login import notebook_login # noqa: F401 + from ._login import ( + interpreter_login, # noqa: F401 + login, # noqa: F401 + logout, # noqa: F401 + notebook_login, # noqa: F401 + ) from ._snapshot_download import snapshot_download # noqa: F401 - from ._space_api import SpaceHardware # noqa: F401 - from ._space_api import SpaceRuntime # noqa: F401 - from ._space_api import SpaceStage # noqa: F401 - from .community import Discussion # noqa: F401 - from .community import DiscussionComment # noqa: F401 - from .community import DiscussionCommit # noqa: F401 - from .community import DiscussionEvent # noqa: F401 - from .community import DiscussionStatusChange # noqa: F401 - from .community import DiscussionTitleChange # noqa: F401 - from .community import DiscussionWithDetails # noqa: F401 - from .constants import CONFIG_NAME # noqa: F401 - from .constants import FLAX_WEIGHTS_NAME # noqa: F401 - from .constants import HUGGINGFACE_CO_URL_HOME # noqa: F401 - from .constants import HUGGINGFACE_CO_URL_TEMPLATE # noqa: F401 - from .constants import PYTORCH_WEIGHTS_NAME # noqa: F401 - from .constants import REPO_TYPE_DATASET # noqa: F401 - from .constants import REPO_TYPE_MODEL # noqa: F401 - from .constants import REPO_TYPE_SPACE # noqa: F401 - from .constants import TF2_WEIGHTS_NAME # noqa: F401 - from .constants import TF_WEIGHTS_NAME # noqa: F401 - from .fastai_utils import _save_pretrained_fastai # noqa: F401 - from .fastai_utils import from_pretrained_fastai # noqa: F401 - from .fastai_utils import push_to_hub_fastai # noqa: F401 - from .file_download import _CACHED_NO_EXIST # noqa: F401 - from .file_download import HfFileMetadata # noqa: F401 - from .file_download import cached_download # noqa: F401 - from .file_download import get_hf_file_metadata # noqa: F401 - from .file_download import hf_hub_download # noqa: F401 - from .file_download import hf_hub_url # noqa: F401 - from .file_download import try_to_load_from_cache # noqa: F401 - from .hf_api import CommitInfo # noqa: F401 - from .hf_api import CommitOperation # noqa: F401 - from .hf_api import CommitOperationAdd # noqa: F401 - from .hf_api import CommitOperationDelete # noqa: F401 - from .hf_api import DatasetSearchArguments # noqa: F401 - from 
.hf_api import GitRefInfo # noqa: F401 - from .hf_api import GitRefs # noqa: F401 - from .hf_api import HfApi # noqa: F401 - from .hf_api import ModelSearchArguments # noqa: F401 - from .hf_api import RepoUrl # noqa: F401 - from .hf_api import UserLikes # noqa: F401 - from .hf_api import add_space_secret # noqa: F401 - from .hf_api import change_discussion_status # noqa: F401 - from .hf_api import comment_discussion # noqa: F401 - from .hf_api import create_branch # noqa: F401 - from .hf_api import create_commit # noqa: F401 - from .hf_api import create_discussion # noqa: F401 - from .hf_api import create_pull_request # noqa: F401 - from .hf_api import create_repo # noqa: F401 - from .hf_api import create_tag # noqa: F401 - from .hf_api import dataset_info # noqa: F401 - from .hf_api import delete_branch # noqa: F401 - from .hf_api import delete_file # noqa: F401 - from .hf_api import delete_folder # noqa: F401 - from .hf_api import delete_repo # noqa: F401 - from .hf_api import delete_space_secret # noqa: F401 - from .hf_api import delete_tag # noqa: F401 - from .hf_api import edit_discussion_comment # noqa: F401 - from .hf_api import get_dataset_tags # noqa: F401 - from .hf_api import get_discussion_details # noqa: F401 - from .hf_api import get_full_repo_name # noqa: F401 - from .hf_api import get_model_tags # noqa: F401 - from .hf_api import get_repo_discussions # noqa: F401 - from .hf_api import get_space_runtime # noqa: F401 - from .hf_api import like # noqa: F401 - from .hf_api import list_datasets # noqa: F401 - from .hf_api import list_liked_repos # noqa: F401 - from .hf_api import list_metrics # noqa: F401 - from .hf_api import list_models # noqa: F401 - from .hf_api import list_repo_files # noqa: F401 - from .hf_api import list_repo_refs # noqa: F401 - from .hf_api import list_spaces # noqa: F401 - from .hf_api import merge_pull_request # noqa: F401 - from .hf_api import model_info # noqa: F401 - from .hf_api import move_repo # noqa: F401 - from .hf_api import rename_discussion # noqa: F401 - from .hf_api import repo_type_and_id_from_hf_id # noqa: F401 - from .hf_api import request_space_hardware # noqa: F401 - from .hf_api import set_access_token # noqa: F401 - from .hf_api import space_info # noqa: F401 - from .hf_api import unlike # noqa: F401 - from .hf_api import unset_access_token # noqa: F401 - from .hf_api import update_repo_visibility # noqa: F401 - from .hf_api import upload_file # noqa: F401 - from .hf_api import upload_folder # noqa: F401 - from .hf_api import whoami # noqa: F401 - from .hub_mixin import ModelHubMixin # noqa: F401 - from .hub_mixin import PyTorchModelHubMixin # noqa: F401 + from ._space_api import ( + SpaceHardware, # noqa: F401 + SpaceRuntime, # noqa: F401 + SpaceStage, # noqa: F401 + ) + from .community import ( + Discussion, # noqa: F401 + DiscussionComment, # noqa: F401 + DiscussionCommit, # noqa: F401 + DiscussionEvent, # noqa: F401 + DiscussionStatusChange, # noqa: F401 + DiscussionTitleChange, # noqa: F401 + DiscussionWithDetails, # noqa: F401 + ) + from .constants import ( + CONFIG_NAME, # noqa: F401 + FLAX_WEIGHTS_NAME, # noqa: F401 + HUGGINGFACE_CO_URL_HOME, # noqa: F401 + HUGGINGFACE_CO_URL_TEMPLATE, # noqa: F401 + PYTORCH_WEIGHTS_NAME, # noqa: F401 + REPO_TYPE_DATASET, # noqa: F401 + REPO_TYPE_MODEL, # noqa: F401 + REPO_TYPE_SPACE, # noqa: F401 + TF2_WEIGHTS_NAME, # noqa: F401 + TF_WEIGHTS_NAME, # noqa: F401 + ) + from .fastai_utils import ( + _save_pretrained_fastai, # noqa: F401 + from_pretrained_fastai, # noqa: F401 + 
push_to_hub_fastai, # noqa: F401 + ) + from .file_download import ( + _CACHED_NO_EXIST, # noqa: F401 + HfFileMetadata, # noqa: F401 + cached_download, # noqa: F401 + get_hf_file_metadata, # noqa: F401 + hf_hub_download, # noqa: F401 + hf_hub_url, # noqa: F401 + try_to_load_from_cache, # noqa: F401 + ) + from .hf_api import ( + CommitInfo, # noqa: F401 + CommitOperation, # noqa: F401 + CommitOperationAdd, # noqa: F401 + CommitOperationDelete, # noqa: F401 + DatasetSearchArguments, # noqa: F401 + GitRefInfo, # noqa: F401 + GitRefs, # noqa: F401 + HfApi, # noqa: F401 + ModelSearchArguments, # noqa: F401 + RepoUrl, # noqa: F401 + UserLikes, # noqa: F401 + add_space_secret, # noqa: F401 + change_discussion_status, # noqa: F401 + comment_discussion, # noqa: F401 + create_branch, # noqa: F401 + create_commit, # noqa: F401 + create_discussion, # noqa: F401 + create_pull_request, # noqa: F401 + create_repo, # noqa: F401 + create_tag, # noqa: F401 + dataset_info, # noqa: F401 + delete_branch, # noqa: F401 + delete_file, # noqa: F401 + delete_folder, # noqa: F401 + delete_repo, # noqa: F401 + delete_space_secret, # noqa: F401 + delete_tag, # noqa: F401 + edit_discussion_comment, # noqa: F401 + get_dataset_tags, # noqa: F401 + get_discussion_details, # noqa: F401 + get_full_repo_name, # noqa: F401 + get_model_tags, # noqa: F401 + get_repo_discussions, # noqa: F401 + get_space_runtime, # noqa: F401 + like, # noqa: F401 + list_datasets, # noqa: F401 + list_liked_repos, # noqa: F401 + list_metrics, # noqa: F401 + list_models, # noqa: F401 + list_repo_files, # noqa: F401 + list_repo_refs, # noqa: F401 + list_spaces, # noqa: F401 + merge_pull_request, # noqa: F401 + model_info, # noqa: F401 + move_repo, # noqa: F401 + rename_discussion, # noqa: F401 + repo_type_and_id_from_hf_id, # noqa: F401 + request_space_hardware, # noqa: F401 + set_access_token, # noqa: F401 + space_info, # noqa: F401 + unlike, # noqa: F401 + unset_access_token, # noqa: F401 + update_repo_visibility, # noqa: F401 + upload_file, # noqa: F401 + upload_folder, # noqa: F401 + whoami, # noqa: F401 + ) + from .hub_mixin import ( + ModelHubMixin, # noqa: F401 + PyTorchModelHubMixin, # noqa: F401 + ) from .inference_api import InferenceApi # noqa: F401 - from .keras_mixin import KerasModelHubMixin # noqa: F401 - from .keras_mixin import from_pretrained_keras # noqa: F401 - from .keras_mixin import push_to_hub_keras # noqa: F401 - from .keras_mixin import save_pretrained_keras # noqa: F401 - from .repocard import DatasetCard # noqa: F401 - from .repocard import ModelCard # noqa: F401 - from .repocard import metadata_eval_result # noqa: F401 - from .repocard import metadata_load # noqa: F401 - from .repocard import metadata_save # noqa: F401 - from .repocard import metadata_update # noqa: F401 - from .repocard_data import CardData # noqa: F401 - from .repocard_data import DatasetCardData # noqa: F401 - from .repocard_data import EvalResult # noqa: F401 - from .repocard_data import ModelCardData # noqa: F401 + from .keras_mixin import ( + KerasModelHubMixin, # noqa: F401 + from_pretrained_keras, # noqa: F401 + push_to_hub_keras, # noqa: F401 + save_pretrained_keras, # noqa: F401 + ) + from .repocard import ( + DatasetCard, # noqa: F401 + ModelCard, # noqa: F401 + metadata_eval_result, # noqa: F401 + metadata_load, # noqa: F401 + metadata_save, # noqa: F401 + metadata_update, # noqa: F401 + ) + from .repocard_data import ( + CardData, # noqa: F401 + DatasetCardData, # noqa: F401 + EvalResult, # noqa: F401 + ModelCardData, # noqa: F401 + ) from 
.repository import Repository # noqa: F401 - from .utils import CachedFileInfo # noqa: F401 - from .utils import CachedRepoInfo # noqa: F401 - from .utils import CachedRevisionInfo # noqa: F401 - from .utils import CacheNotFound # noqa: F401 - from .utils import CorruptedCacheException # noqa: F401 - from .utils import DeleteCacheStrategy # noqa: F401 - from .utils import HFCacheInfo # noqa: F401 - from .utils import HfFolder # noqa: F401 - from .utils import cached_assets_path # noqa: F401 - from .utils import dump_environment_info # noqa: F401 - from .utils import logging # noqa: F401 - from .utils import scan_cache_dir # noqa: F401 - from .utils.endpoint_helpers import DatasetFilter # noqa: F401 - from .utils.endpoint_helpers import ModelFilter # noqa: F401 + from .utils import ( + CachedFileInfo, # noqa: F401 + CachedRepoInfo, # noqa: F401 + CachedRevisionInfo, # noqa: F401 + CacheNotFound, # noqa: F401 + CorruptedCacheException, # noqa: F401 + DeleteCacheStrategy, # noqa: F401 + HFCacheInfo, # noqa: F401 + HfFolder, # noqa: F401 + cached_assets_path, # noqa: F401 + dump_environment_info, # noqa: F401 + logging, # noqa: F401 + scan_cache_dir, # noqa: F401 + ) + from .utils.endpoint_helpers import ( + DatasetFilter, # noqa: F401 + ModelFilter, # noqa: F401 + ) diff --git a/utils/check_static_imports.py b/utils/check_static_imports.py index 19dd2e5e9c..c245c07bca 100644 --- a/utils/check_static_imports.py +++ b/utils/check_static_imports.py @@ -14,11 +14,13 @@ # limitations under the License. """Contains a tool to reformat static imports in `huggingface_hub.__init__.py`.""" import argparse +import os import re +import tempfile from pathlib import Path from typing import NoReturn -import isort +from ruff.__main__ import find_ruff_bin from huggingface_hub import _SUBMOD_ATTRS @@ -72,10 +74,14 @@ def check_static_imports(update: bool) -> NoReturn: ] # Generate the expected `__init__.py` file content and apply formatter on it. - expected_init_content = isort.code( - reordered_content_before_static_checks + IF_TYPE_CHECKING_LINE + "\n".join(static_imports) + "\n", - config=isort.Config(settings_path=SETUP_CFG_PATH), - ) + with tempfile.TemporaryDirectory() as tmpdir: + filepath = Path(tmpdir) / "__init__.py" + filepath.write_text( + reordered_content_before_static_checks + IF_TYPE_CHECKING_LINE + "\n".join(static_imports) + "\n" + ) + ruff_bin = find_ruff_bin() + os.spawnv(os.P_WAIT, ruff_bin, ["ruff", str(filepath), "--fix"]) + expected_init_content = filepath.read_text() # If expected `__init__.py` content is different, test fails. If '--update-init-file' # is used, `__init__.py` file is updated before the test fails. From 9085a29a1df0f7e72e309c2294ac3bcfa303e1fc Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 7 Feb 2023 09:49:38 +0100 Subject: [PATCH 3/6] quiet make quality --- utils/check_static_imports.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/check_static_imports.py b/utils/check_static_imports.py index c245c07bca..e2caeb9fec 100644 --- a/utils/check_static_imports.py +++ b/utils/check_static_imports.py @@ -80,7 +80,7 @@ def check_static_imports(update: bool) -> NoReturn: reordered_content_before_static_checks + IF_TYPE_CHECKING_LINE + "\n".join(static_imports) + "\n" ) ruff_bin = find_ruff_bin() - os.spawnv(os.P_WAIT, ruff_bin, ["ruff", str(filepath), "--fix"]) + os.spawnv(os.P_WAIT, ruff_bin, ["ruff", str(filepath), "--fix", "--quiet"]) expected_init_content = filepath.read_text() # If expected `__init__.py` content is different, test fails. 
If '--update-init-file' From d2f42fcbd7675a4c657ead5679f30620bb766003 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 7 Feb 2023 10:06:02 +0100 Subject: [PATCH 4/6] update pre-commit and contributing guide --- .github/workflows/python-quality.yml | 5 ++--- .pre-commit-config.yaml | 28 +++------------------------- CONTRIBUTING.md | 8 ++++---- setup.cfg | 21 +-------------------- setup.py | 2 -- 5 files changed, 10 insertions(+), 54 deletions(-) diff --git a/.github/workflows/python-quality.yml b/.github/workflows/python-quality.yml index 69a1b58729..9c785d7109 100644 --- a/.github/workflows/python-quality.yml +++ b/.github/workflows/python-quality.yml @@ -27,9 +27,8 @@ jobs: run: | pip install --upgrade pip pip install .[dev] - - run: black --check tests src - - run: isort --check-only tests src - - run: flake8 tests src + - run: black --check tests src contrib + - run: ruff tests src contrib - run: python utils/check_contrib_list.py - run: python utils/check_static_imports.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index df44588bf0..b3760a9db8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,29 +12,7 @@ repos: rev: 22.3.0 hooks: - id: black - - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + - repo: https://github.com/charliermarsh/ruff-pre-commit # https://github.com/charliermarsh/ruff#usage + rev: 'v0.0.243' hooks: - - id: flake8 - types: [file, python] - - repo: https://github.com/PyCQA/isort - rev: 5.10.1 - hooks: - - id: isort - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.981 - hooks: - - id: mypy - # taken from https://github.com/pre-commit/mirrors-mypy/issues/33#issuecomment-735449356 - args: [src, --config-file=pyproject.toml] - pass_filenames: false - # Same list of dependencies as in `setup.py` - additional_dependencies: - [ - "types-PyYAML", - "types-requests", - "types-simplejson", - "types-toml", - "types-tqdm", - "types-urllib3", - ] + - id: ruff diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index aa002a4007..8a5feb16fe 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -155,7 +155,7 @@ Follow these steps to start contributing: $ make test ``` - `hugginface_hub` relies on `black` and `isort` to format its source code + `hugginface_hub` relies on `black` and `ruff` to format its source code consistently. You can install pre-commit hooks so that these styles are applied and checked on files that you have touched in each commit: @@ -181,7 +181,7 @@ Follow these steps to start contributing: $ make style ``` - `huggingface_hub` also uses `flake8` and a few custom scripts to check for coding mistakes. Quality + `huggingface_hub` also uses `ruff` and a few custom scripts to check for coding mistakes. Quality control runs in CI, however you can also run the same checks with: ```bash @@ -223,7 +223,7 @@ Follow these steps to start contributing: too! So everyone can see the changes in the Pull request, work in your local branch and push the changes to your fork. They will automatically appear in the pull request. - + 8. Once your changes have been approved, one of the project maintainers will merge your pull request for you. @@ -286,4 +286,4 @@ Fully testing Spaces is not possible on staging. We need to use the production e the tests locally. The token requires write permission and a credit card must be set on your account. - Note that if the token is not find, the related tests are skipped. 
\ No newline at end of file + Note that if the token is not find, the related tests are skipped. diff --git a/setup.cfg b/setup.cfg index e3f091e5e0..ddbf9ef125 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,22 +1,3 @@ -[isort] -default_section = FIRSTPARTY -ensure_newline_before_comments = True -force_grid_wrap = 0 -include_trailing_comma = True -known_first_party = huggingface_hub -line_length = 119 -lines_after_imports = 2 -multi_line_output = 3 -use_parentheses = True - -[flake8] -exclude = .git,__pycache__,old,build,dist,.venv* -# ignore = B028, E203, E501, E741, W503 -# ignore = B028, E203, E501, E741, W503 -ignore = E501, E741, E821, W605 -# select = ["E", "F", "I", "W"] -max-line-length = 119 - [tool:pytest] # -Werror::FutureWarning -> test fails if FutureWarning is thrown # -s -> logs are not captured @@ -26,4 +7,4 @@ max-line-length = 119 addopts = -Werror::FutureWarning --log-cli-level=INFO -sv --durations=0 env = HUGGINGFACE_CO_STAGING=1 - DISABLE_SYMLINKS_IN_WINDOWS_TESTS=1 \ No newline at end of file + DISABLE_SYMLINKS_IN_WINDOWS_TESTS=1 diff --git a/setup.py b/setup.py index 3cc9510557..60cf5afbeb 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,6 @@ def get_version() -> str: extras["tensorflow"] = ["tensorflow", "pydot", "graphviz"] extras["testing"] = extras["cli"] + [ - "isort>=5.5.4", "jedi", "Jinja2", "pytest", @@ -66,7 +65,6 @@ def get_version() -> str: extras["quality"] = [ "black~=23.1", "ruff>=0.0.241", - "isort>=5.5.4", "mypy==0.982", ] From 24f798b8a1ac91376a9269d6a8dbc56f7d41e4af Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 7 Feb 2023 10:21:27 +0100 Subject: [PATCH 5/6] set back mypy config --- .pre-commit-config.yaml | 2 +- pyproject.toml | 7 ++++++- src/huggingface_hub/file_download.py | 2 +- src/huggingface_hub/inference_api.py | 2 +- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b3760a9db8..7963b172a3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: - id: check-case-conflict - id: check-merge-conflict - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 23.1.0 hooks: - id: black - repo: https://github.com/charliermarsh/ruff-pre-commit # https://github.com/charliermarsh/ruff#usage diff --git a/pyproject.toml b/pyproject.toml index 5b231157d5..c7e8c4c3ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,11 @@ line-length = 119 target_version = ['py37', 'py38', 'py39', 'py310'] preview = true +[tool.mypy] +ignore_missing_imports = true +no_implicit_optional = true +scripts_are_modules = true + [tool.ruff] # Ignored rules: # "E501" -> line length violation @@ -13,4 +18,4 @@ line-length = 119 [tool.ruff.isort] lines-after-imports = 2 -known-first-party = ["huggingface_hub"] \ No newline at end of file +known-first-party = ["huggingface_hub"] diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 0b606e5039..bf0fd3fc8a 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -478,7 +478,7 @@ def http_get( # Download file using an external Rust-based package. Download is faster # (~2x speed-up) but support less features (no error handling, no retries, # no progress bars). 
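
Right below, the patch drops the `# type: ignore` on the `hf_transfer` import: with `ignore_missing_imports = true` restored in `pyproject.toml` above, per-line pragmas for optional packages are no longer needed. For readers unfamiliar with the pattern, here is a generic sketch of such an optional-backend guard (illustrative only; the library's actual code path is gated behind a flag and raises rather than silently falling back):

```python
import importlib.util


def get_fast_downloader():
    """Return `hf_transfer.download` if the optional backend is installed, else None."""
    # Probe for the package first so a missing optional dependency never raises here.
    if importlib.util.find_spec("hf_transfer") is None:
        return None
    from hf_transfer import download

    return download
```

Callers can then fall back to a plain HTTP download loop whenever `get_fast_downloader()` returns `None`.
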
- from hf_transfer import download # type: ignore + from hf_transfer import download logger.debug(f"Download {url} using HF_TRANSFER.") max_files = 100 diff --git a/src/huggingface_hub/inference_api.py b/src/huggingface_hub/inference_api.py index 7c26eb01f9..ce56794ae5 100644 --- a/src/huggingface_hub/inference_api.py +++ b/src/huggingface_hub/inference_api.py @@ -197,7 +197,7 @@ def __call__( " the image by yourself." ) - from PIL import Image # type: ignore + from PIL import Image return Image.open(io.BytesIO(response.content)) elif content_type == "application/json": From 0f595385cfade0eb3bc13470dab3c1969e7deb70 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Tue, 7 Feb 2023 10:33:47 +0100 Subject: [PATCH 6/6] FIX test due to moon-landing wording update --- tests/test_hf_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 1b4d18dbd3..7d64d6cca3 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -1108,7 +1108,7 @@ def test_create_commit_delete_folder_explicit(self): def test_create_commit_failing_implicit_delete_folder(self): with self.assertRaisesRegex( EntryNotFoundError, - "Make sure to differentiate file and folder paths", + 'A file with the name "1" does not exist', ): self._api.create_commit( operations=[CommitOperationDelete(path_in_repo="1")],
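
A closing note on this last fix: asserting on the server's exact wording (here, the moon-landing error copy) is what made the test break in the first place. When the wording itself is not the contract, matching a stable fragment keeps the regression coverage while surviving copy edits. An illustrative, self-contained version of the pattern (using a stand-in exception rather than `EntryNotFoundError`):

```python
import unittest


class LooseErrorMessageTest(unittest.TestCase):
    def test_missing_entry_error(self) -> None:
        def delete_missing_folder() -> None:
            # Stand-in for the API call that raises EntryNotFoundError.
            raise FileNotFoundError('A file with the name "1" does not exist')

        with self.assertRaises(FileNotFoundError) as ctx:
            delete_missing_folder()
        # Match a stable fragment rather than the full server sentence.
        self.assertIn("does not exist", str(ctx.exception))


if __name__ == "__main__":
    unittest.main()
```
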