diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 298b306d308..ee5646c1b13 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -64,7 +64,7 @@ jobs: run: pip install --upgrade pyarrow huggingface-hub dill - name: Install depencencies (minimum versions) if: ${{ matrix.deps_versions != 'deps-latest' }} - run: pip install pyarrow==8.0.0 huggingface-hub==0.14.0 transformers dill==0.3.1.1 + run: pip install pyarrow==8.0.0 huggingface-hub==0.18.0 transformers dill==0.3.1.1 - name: Test with pytest run: | python -m pytest -rfExX -m ${{ matrix.test }} -n 2 --dist loadfile -sv ./tests/ diff --git a/setup.py b/setup.py index acad2a0da45..e6d0a96ffc0 100644 --- a/setup.py +++ b/setup.py @@ -131,7 +131,7 @@ "aiohttp", # To get datasets from the Datasets Hub on huggingface.co # minimum 0.14.0 to support HfFileSystem - "huggingface-hub>=0.14.0,<1.0.0", + "huggingface_hub>=0.18.0", # Utilities from PyPA to e.g., compare versions "packaging", # To parse YAML metadata from dataset cards diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 9a5c51c9d07..445dc7452d4 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -17,8 +17,10 @@ import contextlib import copy +import fnmatch import itertools import json +import math import os import posixpath import re @@ -31,7 +33,6 @@ from collections import Counter from collections.abc import Mapping from copy import deepcopy -from fnmatch import fnmatch from functools import partial, wraps from io import BytesIO from math import ceil, floor @@ -58,7 +59,7 @@ import pandas as pd import pyarrow as pa import pyarrow.compute as pc -from huggingface_hub import DatasetCard, DatasetCardData, HfApi, HfFolder +from huggingface_hub import CommitOperationAdd, CommitOperationDelete, DatasetCard, DatasetCardData, HfApi from multiprocess import Pool from requests import HTTPError @@ -66,7 +67,6 @@ from .arrow_reader import ArrowReader from .arrow_writer import 
ArrowWriter, OptimizedTypedSequence from .data_files import sanitize_patterns -from .download.download_config import DownloadConfig from .download.streaming_download_manager import xgetsize from .features import Audio, ClassLabel, Features, Image, Sequence, Value from .features.features import ( @@ -112,8 +112,7 @@ from .tasks import TaskTemplate from .utils import logging from .utils.deprecation_utils import deprecated -from .utils.file_utils import _retry, cached_path, estimate_dataset_size -from .utils.hub import hf_hub_url +from .utils.file_utils import _retry, estimate_dataset_size from .utils.info_utils import is_small_dataset from .utils.metadata import MetadataConfigs from .utils.py_utils import ( @@ -5151,101 +5150,20 @@ def _push_parquet_shards_to_hub( repo_id: str, data_dir: str = "data", split: Optional[str] = None, - private: Optional[bool] = False, token: Optional[str] = None, - branch: Optional[str] = None, + revision: Optional[str] = None, + create_pr: Optional[bool] = False, max_shard_size: Optional[Union[int, str]] = None, num_shards: Optional[int] = None, embed_external_files: bool = True, ) -> Tuple[str, str, int, int, List[str], int]: - """Pushes the dataset to the hub. - The dataset is pushed using HTTP requests and does not need to have neither git or git-lfs installed. - - Args: - repo_id (`str`): - The ID of the repository to push to in the following format: `/` or - `/`. Also accepts ``, which will default to the namespace - of the logged-in user. - data_dir (`str`): - The name of directory to store parquet files. Defaults to "data". - split (Optional, `str`): - The name of the split that will be given to that dataset. Defaults to `self.split`. - private (Optional `bool`, defaults to `False`): - Whether the dataset repository should be set to private or not. Only affects repository creation: - a repository that already exists will not be affected by that parameter. 
- token (Optional `str`): - An optional authentication token for the Hugging Face Hub. If no token is passed, will default - to the token saved locally when logging in with ``huggingface-cli login``. Will raise an error - if no token is passed and the user is not logged-in. - branch (Optional `str`): - The git branch on which to push the dataset. This defaults to the default branch as specified - in your repository, which defaults to `"main"`. - max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): - The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a - a unit (like `"5MB"`). - num_shards (`int`, *optional*): - Number of shards to write. By default the number of shards depends on `max_shard_size`. - - - embed_external_files (`bool`, default ``True``): - Whether to embed file bytes in the shards. - In particular, this will do the following before the push for the fields of type: - - - :class:`Audio` and class:`Image`: remove local path information and embed file content in the Parquet files. + """Pushes the dataset shards as Parquet files to the hub. Returns: - repo_id (`str`): ID of the repository in /` or `/` format - split (`str`): name of the uploaded split + additions (`List[CommitOperation]`): list of the `CommitOperationAdd` of the uploaded shards uploaded_size (`int`): number of uploaded bytes to the repository dataset_nbytes (`int`): approximate size in bytes of the uploaded dataset afer uncompression - repo_files (`List[str]`): list of files in the repository - deleted_size (`int`): number of deleted bytes in the repository - - Example: - - ```python - >>> dataset.push_to_hub("/", split="evaluation") - ``` """ - if max_shard_size is not None and num_shards is not None: - raise ValueError( - "Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both." 
- ) - - api = HfApi(endpoint=config.HF_ENDPOINT) - token = token if token is not None else HfFolder.get_token() - - if token is None: - raise EnvironmentError( - "You need to provide a `token` or be logged in to Hugging Face with `huggingface-cli login`." - ) - - if split is None: - split = str(self.split) if self.split is not None else "train" - - if not re.match(_split_re, split): - raise ValueError(f"Split name should match '{_split_re}' but got '{split}'.") - - identifier = repo_id.split("/") - - if len(identifier) > 2: - raise ValueError( - f"The identifier should be in the format or /. It is {identifier}, " - "which doesn't conform to either format." - ) - elif len(identifier) == 1: - dataset_name = identifier[0] - organization_or_username = api.whoami(token)["name"] - repo_id = f"{organization_or_username}/{dataset_name}" - - api.create_repo( - repo_id, - token=token, - repo_type="dataset", - private=private, - exist_ok=True, - ) - # Find decodable columns, because if there are any, we need to: # embed the bytes from the files in the shards decodable_columns = ( @@ -5280,86 +5198,52 @@ def shards_with_embedded_external_files(shards): shards = shards_with_embedded_external_files(shards) - files = api.list_repo_files(repo_id, repo_type="dataset", revision=branch, token=token) - data_files = [file for file in files if file.startswith(f"{data_dir}/")] - - def path_in_repo(_index, shard): - return f"{data_dir}/{split}-{_index:05d}-of-{num_shards:05d}-{shard._fingerprint}.parquet" - - shards_iter = iter(shards) - first_shard = next(shards_iter) - first_shard_path_in_repo = path_in_repo(0, first_shard) - if first_shard_path_in_repo in data_files and num_shards < len(data_files): - logger.info("Resuming upload of the dataset shards.") + api = HfApi(endpoint=config.HF_ENDPOINT, token=token) uploaded_size = 0 - shards_path_in_repo = [] + additions = [] for index, shard in logging.tqdm( - enumerate(itertools.chain([first_shard], shards_iter)), - desc="Pushing dataset 
shards to the dataset hub", + enumerate(shards), + desc="Uploading the dataset shards", total=num_shards, disable=not logging.is_progress_bar_enabled(), ): - shard_path_in_repo = path_in_repo(index, shard) - # Upload a shard only if it doesn't already exist in the repository - if shard_path_in_repo not in data_files: - buffer = BytesIO() - shard.to_parquet(buffer) - uploaded_size += buffer.tell() - _retry( - api.upload_file, - func_kwargs={ - "path_or_fileobj": buffer.getvalue(), - "path_in_repo": shard_path_in_repo, - "repo_id": repo_id, - "token": token, - "repo_type": "dataset", - "revision": branch, - }, - exceptions=HTTPError, - status_codes=[504], - base_wait_time=2.0, - max_retries=5, - max_wait_time=20.0, - ) - shards_path_in_repo.append(shard_path_in_repo) - - # Cleanup to remove unused files - data_files_to_delete = [ - data_file - for data_file in data_files - if data_file.startswith(f"{data_dir}/{split}-") and data_file not in shards_path_in_repo - ] - download_config = DownloadConfig(token=token) - deleted_size = sum( - xgetsize(hf_hub_url(repo_id, data_file, revision=branch), download_config=download_config) - for data_file in data_files_to_delete - ) - - def delete_file(file): - api.delete_file(file, repo_id=repo_id, token=token, repo_type="dataset", revision=branch) - - if len(data_files_to_delete): - for data_file in logging.tqdm( - data_files_to_delete, - desc="Deleting unused files from dataset repository", - total=len(data_files_to_delete), - disable=not logging.is_progress_bar_enabled(), - ): - delete_file(data_file) - - repo_files = list(set(files) - set(data_files_to_delete)) + shard_path_in_repo = f"{data_dir}/{split}-{index:05d}-of-{num_shards:05d}.parquet" + buffer = BytesIO() + shard.to_parquet(buffer) + uploaded_size += buffer.tell() + shard_addition = CommitOperationAdd(path_in_repo=shard_path_in_repo, path_or_fileobj=buffer) + _retry( + api.preupload_lfs_files, + func_kwargs={ + "repo_id": repo_id, + "additions": [shard_addition], + 
"token": token, + "repo_type": "dataset", + "revision": revision, + "create_pr": create_pr, + }, + exceptions=HTTPError, + status_codes=[504], + base_wait_time=2.0, + max_retries=5, + max_wait_time=20.0, + ) + additions.append(shard_addition) - return repo_id, split, uploaded_size, dataset_nbytes, repo_files, deleted_size + return additions, uploaded_size, dataset_nbytes def push_to_hub( self, repo_id: str, config_name: str = "default", split: Optional[str] = None, + commit_message: Optional[str] = None, private: Optional[bool] = False, token: Optional[str] = None, - branch: Optional[str] = None, + revision: Optional[str] = None, + branch="deprecated", + create_pr: Optional[bool] = False, max_shard_size: Optional[Union[int, str]] = None, num_shards: Optional[int] = None, embed_external_files: bool = True, @@ -5377,9 +5261,11 @@ def push_to_hub( `/`. Also accepts ``, which will default to the namespace of the logged-in user. config_name (`str`, defaults to "default"): - The configuration name (or subset) of a dataset. Defaults to "default" + The configuration name (or subset) of a dataset. Defaults to "default". split (`str`, *optional*): The name of the split that will be given to that dataset. Defaults to `self.split`. + commit_message (`str`, *optional*): + Message to commit while pushing. Will default to `"Upload dataset"`. private (`bool`, *optional*, defaults to `False`): Whether the dataset repository should be set to private or not. Only affects repository creation: a repository that already exists will not be affected by that parameter. @@ -5387,9 +5273,23 @@ def push_to_hub( An optional authentication token for the Hugging Face Hub. If no token is passed, will default to the token saved locally when logging in with `huggingface-cli login`. Will raise an error if no token is passed and the user is not logged-in. + revision (`str`, *optional*): + Branch to push the uploaded files to. Defaults to the `"main"` branch. 
+ + branch (`str`, *optional*): The git branch on which to push the dataset. This defaults to the default branch as specified in your repository, which defaults to `"main"`. + + + + `branch` was deprecated in favor of `revision` in version 2.15.0 and will be removed in 3.0.0. + + + create_pr (`bool`, *optional*, defaults to `False`): + Whether or not to create a PR with the uploaded files or directly commit. + + max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). @@ -5439,20 +5339,76 @@ def push_to_hub( raise ValueError( "Failed to push_to_hub: please specify either max_shard_size or num_shards, but not both." ) + + if split is None: + split = str(self.split) if self.split is not None else "train" + + if not re.match(_split_re, split): + raise ValueError(f"Split name should match '{_split_re}' but got '{split}'.") + + if branch != "deprecated": + warnings.warn( + "'branch' was deprecated in favor of 'revision' in version 2.15.0 and will be removed in 3.0.0.\n" + f"You can remove this warning by passing 'revision={branch}' instead.", + FutureWarning, + ) + revision = branch + + api = HfApi(endpoint=config.HF_ENDPOINT, token=token) + + repo_url = api.create_repo( + repo_id, + token=token, + repo_type="dataset", + private=private, + exist_ok=True, + ) + repo_id = repo_url.repo_id + + if revision is not None: + api.create_branch(repo_id, branch=revision, token=token, repo_type="dataset", exist_ok=True) + data_dir = config_name if config_name != "default" else "data" # for backward compatibility - repo_id, split, uploaded_size, dataset_nbytes, repo_files, deleted_size = self._push_parquet_shards_to_hub( + additions, uploaded_size, dataset_nbytes = self._push_parquet_shards_to_hub( repo_id=repo_id, data_dir=data_dir, split=split, - private=private, token=token, - branch=branch, + revision=revision, 
max_shard_size=max_shard_size, num_shards=num_shards, + create_pr=create_pr, embed_external_files=embed_external_files, ) - organization, dataset_name = repo_id.split("/") + + # Check if the repo already has a README.md and/or a dataset_infos.json to update them with the new split info (size and pattern) + # and delete old split shards (if they exist) + repo_with_dataset_card, repo_with_dataset_infos = False, False + deletions, deleted_size = [], 0 + repo_splits = [] # use a list to keep the order of the splits + repo_files_to_add = [addition.path_in_repo for addition in additions] + for repo_file in api.list_files_info(repo_id, revision=revision, repo_type="dataset", token=token): + if repo_file.rfilename == "README.md": + repo_with_dataset_card = True + elif repo_file.rfilename == config.DATASETDICT_INFOS_FILENAME: + repo_with_dataset_infos = True + elif ( + repo_file.rfilename.startswith(f"{data_dir}/{split}-") and repo_file.rfilename not in repo_files_to_add + ): + deletions.append(CommitOperationDelete(path_in_repo=repo_file.rfilename)) + deleted_size += repo_file.size + elif fnmatch.fnmatch( + repo_file.rfilename, PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED.replace("{split}", "*") + ): + repo_split = string_to_dict( + repo_file.rfilename, + glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED), + )["split"] + if repo_split not in repo_splits: + repo_splits.append(repo_split) + + organization, dataset_name = repo_id.split("/") if "/" in repo_id else (None, repo_id) info_to_dump = self.info.copy() info_to_dump.download_checksums = None info_to_dump.download_size = uploaded_size @@ -5463,15 +5419,9 @@ def push_to_hub( {split: SplitInfo(split, num_bytes=dataset_nbytes, num_examples=len(self), dataset_name=dataset_name)} ) # get the info from the README to update them - if "README.md" in repo_files: - download_config = DownloadConfig() - download_config.download_desc = "Downloading metadata" - download_config.token = 
token - dataset_readme_path = cached_path( - hf_hub_url(repo_id, "README.md", revision=branch), - download_config=download_config, - ) - dataset_card = DatasetCard.load(Path(dataset_readme_path)) + if repo_with_dataset_card: + dataset_card_path = api.hf_hub_download(repo_id, "README.md", repo_type="dataset", revision=revision) + dataset_card = DatasetCard.load(Path(dataset_card_path)) dataset_card_data = dataset_card.data metadata_configs = MetadataConfigs.from_dataset_card_data(dataset_card_data) dataset_infos: DatasetInfosDict = DatasetInfosDict.from_dataset_card_data(dataset_card_data) @@ -5480,16 +5430,12 @@ def push_to_hub( else: repo_info = None # get the deprecated dataset_infos.json to update them - elif config.DATASETDICT_INFOS_FILENAME in repo_files: + elif repo_with_dataset_infos: dataset_card = None dataset_card_data = DatasetCardData() - download_config = DownloadConfig() metadata_configs = MetadataConfigs() - download_config.download_desc = "Downloading metadata" - download_config.token = token - dataset_infos_path = cached_path( - hf_hub_url(repo_id, config.DATASETDICT_INFOS_FILENAME, revision=branch), - download_config=download_config, + dataset_infos_path = api.hf_hub_download( + repo_id, config.DATASETDICT_INFOS_FILENAME, repo_type="dataset", revision=revision ) with open(dataset_infos_path, encoding="utf-8") as f: dataset_infos: dict = json.load(f) @@ -5517,32 +5463,17 @@ def push_to_hub( repo_info.download_size = (repo_info.download_size or 0) + uploaded_size repo_info.dataset_size = (repo_info.dataset_size or 0) + dataset_nbytes repo_info.size_in_bytes = repo_info.download_size + repo_info.dataset_size + repo_info.splits.pop(split, None) repo_info.splits[split] = SplitInfo( split, num_bytes=dataset_nbytes, num_examples=len(self), dataset_name=dataset_name ) info_to_dump = repo_info # create the metadata configs if it was uploaded with push_to_hub before metadata configs existed - if not metadata_configs: - _matched_paths = [ - p - for p in 
repo_files - if fnmatch(p, PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED.replace("{split}", "*")) - ] - if len(_matched_paths) > 0: - # it was uploaded with push_to_hub before metadata configs existed - _resolved_splits = { - string_to_dict( - p, glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED) - )["split"] - for p in _matched_paths - } - default_metadata_configs_to_dump = { - "data_files": [ - {"split": _resolved_split, "path": f"data/{_resolved_split}-*"} - for _resolved_split in _resolved_splits - ] - } - MetadataConfigs({"default": default_metadata_configs_to_dump}).to_dataset_card_data(dataset_card_data) + if not metadata_configs and repo_splits: + default_metadata_configs_to_dump = { + "data_files": [{"split": split, "path": f"data/{split}-*"} for split in repo_splits] + } + MetadataConfigs({"default": default_metadata_configs_to_dump}).to_dataset_card_data(dataset_card_data) # update the metadata configs if config_name in metadata_configs: metadata_config = metadata_configs[config_name] @@ -5564,48 +5495,60 @@ def push_to_hub( else: metadata_config_to_dump = {"data_files": [{"split": split, "path": f"{data_dir}/{split}-*"}]} # push to the deprecated dataset_infos.json - if config.DATASETDICT_INFOS_FILENAME in repo_files: - download_config = DownloadConfig() - download_config.download_desc = "Downloading deprecated dataset_infos.json" - download_config.use_auth_token = token - dataset_infos_path = cached_path( - hf_hub_url(repo_id, config.DATASETDICT_INFOS_FILENAME, revision=branch), - download_config=download_config, + if repo_with_dataset_infos: + dataset_infos_path = api.hf_hub_download( + repo_id, config.DATASETDICT_INFOS_FILENAME, repo_type="dataset", revision=revision ) with open(dataset_infos_path, encoding="utf-8") as f: dataset_infos: dict = json.load(f) dataset_infos[config_name] = asdict(info_to_dump) buffer = BytesIO() buffer.write(json.dumps(dataset_infos, indent=4).encode("utf-8")) - 
HfApi(endpoint=config.HF_ENDPOINT).upload_file( - path_or_fileobj=buffer.getvalue(), - path_in_repo=config.DATASETDICT_INFOS_FILENAME, - repo_id=repo_id, - token=token, - repo_type="dataset", - revision=branch, + additions.append( + CommitOperationAdd(path_in_repo=config.DATASETDICT_INFOS_FILENAME, path_or_fileobj=buffer) ) # push to README DatasetInfosDict({config_name: info_to_dump}).to_dataset_card_data(dataset_card_data) MetadataConfigs({config_name: metadata_config_to_dump}).to_dataset_card_data(dataset_card_data) - dataset_card = ( - DatasetCard( - "---\n" - + str(dataset_card_data) - + "\n---\n" - + f'# Dataset Card for "{repo_id.split("/")[-1]}"\n\n[More Information needed](https://github.com/huggingface/datasets/blob/main/CONTRIBUTING.md#how-to-contribute-to-the-dataset-cards)' + dataset_card = DatasetCard(f"---\n{dataset_card_data}\n---\n") if dataset_card is None else dataset_card + additions.append(CommitOperationAdd(path_in_repo="README.md", path_or_fileobj=str(dataset_card).encode())) + + if len(additions) <= config.UPLOADS_MAX_NUMBER_PER_COMMIT: + api.create_commit( + repo_id, + operations=additions + deletions, + commit_message=commit_message if commit_message is not None else "Upload dataset", + token=token, + repo_type="dataset", + revision=revision, + create_pr=create_pr, ) - if dataset_card is None - else dataset_card - ) - HfApi(endpoint=config.HF_ENDPOINT).upload_file( - path_or_fileobj=str(dataset_card).encode(), - path_in_repo="README.md", - repo_id=repo_id, - token=token, - repo_type="dataset", - revision=branch, - ) + else: + logger.info( + f"Number of files to upload is larger than {config.UPLOADS_MAX_NUMBER_PER_COMMIT}. Splitting the push into multiple commits." 
+ ) + num_commits = math.ceil(len(additions) / config.UPLOADS_MAX_NUMBER_PER_COMMIT) + for i in range(0, num_commits): + operations = additions[ + i * config.UPLOADS_MAX_NUMBER_PER_COMMIT : (i + 1) * config.UPLOADS_MAX_NUMBER_PER_COMMIT + ] + (deletions if i == 0 else []) + part_commit_message = ( + commit_message if commit_message is not None else "Upload dataset" + ) + f" (part {i:05d}-of-{num_commits:05d})" + api.create_commit( + repo_id, + operations=operations, + commit_message=part_commit_message, + token=token, + repo_type="dataset", + revision=revision, + create_pr=create_pr, + ) + logger.info( + f"Commit #{i+1} completed" + + (f" (still {num_commits - i - 1} to go)" if num_commits - i - 1 else "") + + "." + ) @transmit_format @fingerprint_transform(inplace=False) diff --git a/src/datasets/config.py b/src/datasets/config.py index 798d31c72eb..3e1f20475ac 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -222,3 +222,6 @@ # Progress bars PBAR_REFRESH_TIME_INTERVAL = 0.05 # 20 progress updates per sec + +# Maximum number of uploaded files per commit +UPLOADS_MAX_NUMBER_PER_COMMIT = 50 diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py index 190a5d37095..23c8dcf6c75 100644 --- a/src/datasets/dataset_dict.py +++ b/src/datasets/dataset_dict.py @@ -1,22 +1,28 @@ import contextlib import copy +import fnmatch import json +import math import os import posixpath import re import warnings -from fnmatch import fnmatch from io import BytesIO from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import fsspec import numpy as np -from huggingface_hub import DatasetCard, DatasetCardData, HfApi +from huggingface_hub import ( + CommitOperationAdd, + CommitOperationDelete, + DatasetCard, + DatasetCardData, + HfApi, +) from . 
import config from .arrow_dataset import PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED, Dataset -from .download import DownloadConfig from .features import Features from .features.features import FeatureType from .filesystems import extract_path_from_uri, is_remote_filesystem @@ -28,8 +34,6 @@ from .utils import logging from .utils.deprecation_utils import deprecated from .utils.doc_utils import is_documented_by -from .utils.file_utils import cached_path -from .utils.hub import hf_hub_url from .utils.metadata import MetadataConfigs from .utils.py_utils import asdict, glob_pattern_to_regex, string_to_dict from .utils.typing import PathLike @@ -1559,9 +1563,12 @@ def push_to_hub( self, repo_id, config_name: str = "default", + commit_message: Optional[str] = None, private: Optional[bool] = False, token: Optional[str] = None, - branch: Optional[None] = None, + revision: Optional[str] = None, + branch="deprecated", + create_pr: Optional[bool] = False, max_shard_size: Optional[Union[int, str]] = None, num_shards: Optional[Dict[str, int]] = None, embed_external_files: bool = True, @@ -1580,17 +1587,34 @@ def push_to_hub( The ID of the repository to push to in the following format: `/` or `/`. Also accepts ``, which will default to the namespace of the logged-in user. + config_name (`str`): + Configuration name of a dataset. Defaults to "default". + commit_message (`str`, *optional*): + Message to commit while pushing. Will default to `"Upload dataset"`. private (`bool`, *optional*): Whether the dataset repository should be set to private or not. Only affects repository creation: a repository that already exists will not be affected by that parameter. - config_name (`str`): - Configuration name of a dataset. Defaults to "default". token (`str`, *optional*): An optional authentication token for the Hugging Face Hub. If no token is passed, will default to the token saved locally when logging in with `huggingface-cli login`. 
Will raise an error if no token is passed and the user is not logged-in. + revision (`str`, *optional*): + Branch to push the uploaded files to. Defaults to the `"main"` branch. + + branch (`str`, *optional*): - The git branch on which to push the dataset. + The git branch on which to push the dataset. This defaults to the default branch as specified + in your repository, which defaults to `"main"`. + + + + `branch` was deprecated in favor of `revision` in version 2.15.0 and will be removed in 3.0.0. + + + create_pr (`bool`, *optional*, defaults to `False`): + Whether or not to create a PR with the uploaded files or directly commit. + + max_shard_size (`int` or `str`, *optional*, defaults to `"500MB"`): The maximum size of the dataset shards to be uploaded to the hub. If expressed as a string, needs to be digits followed by a unit (like `"500MB"` or `"1GB"`). @@ -1632,6 +1656,14 @@ def push_to_hub( "Please provide one `num_shards` per dataset in the dataset dictionary, e.g. {{'train': 128, 'test': 4}}" ) + if branch != "deprecated": + warnings.warn( + "'branch' was deprecated in favor of 'revision' in version 2.15.0 and will be removed in 3.0.0.\n" + f"You can remove this warning by passing 'revision={branch}' instead.", + FutureWarning, + ) + revision = branch + self._check_values_type() self._check_values_features() total_uploaded_size = 0 @@ -1644,21 +1676,38 @@ def push_to_hub( if not re.match(_split_re, split): raise ValueError(f"Split name should match '{_split_re}' but got '{split}'.") + api = HfApi(endpoint=config.HF_ENDPOINT, token=token) + + repo_url = api.create_repo( + repo_id, + token=token, + repo_type="dataset", + private=private, + exist_ok=True, + ) + repo_id = repo_url.repo_id + + if revision is not None: + api.create_branch(repo_id, branch=revision, token=token, repo_type="dataset", exist_ok=True) + data_dir = config_name if config_name != "default" else "data" # for backward compatibility + + additions = [] for split in self.keys(): 
logger.info(f"Pushing split {split} to the Hub.") # The split=key needs to be removed before merging - repo_id, split, uploaded_size, dataset_nbytes, _, _ = self[split]._push_parquet_shards_to_hub( + split_additions, uploaded_size, dataset_nbytes = self[split]._push_parquet_shards_to_hub( repo_id, data_dir=data_dir, split=split, - private=private, token=token, - branch=branch, + revision=revision, + create_pr=create_pr, max_shard_size=max_shard_size, num_shards=num_shards.get(split), embed_external_files=embed_external_files, ) + additions += split_additions total_uploaded_size += uploaded_size total_dataset_nbytes += dataset_nbytes info_to_dump.splits[split] = SplitInfo(str(split), num_bytes=dataset_nbytes, num_examples=len(self[split])) @@ -1671,23 +1720,40 @@ def push_to_hub( "data_files": [{"split": split, "path": f"{data_dir}/{split}-*"} for split in self.keys()], } - api = HfApi(endpoint=config.HF_ENDPOINT) - repo_files = api.list_repo_files(repo_id, repo_type="dataset", revision=branch, token=token) + # Check if the repo already has a README.md and/or a dataset_infos.json to update them with the new split info (size and pattern) + # and delete old split shards (if they exist) + repo_with_dataset_card, repo_with_dataset_infos = False, False + repo_splits = [] # use a list to keep the order of the splits + deletions = [] + repo_files_to_add = [addition.path_in_repo for addition in additions] + for repo_file in api.list_files_info(repo_id, revision=revision, repo_type="dataset", token=token): + if repo_file.rfilename == "README.md": + repo_with_dataset_card = True + elif repo_file.rfilename == config.DATASETDICT_INFOS_FILENAME: + repo_with_dataset_infos = True + elif ( + repo_file.rfilename.startswith(tuple(f"{data_dir}/{split}-" for split in self.keys())) + and repo_file.rfilename not in repo_files_to_add + ): + deletions.append(CommitOperationDelete(path_in_repo=repo_file.rfilename)) + elif fnmatch.fnmatch( + repo_file.rfilename, 
PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED.replace("{split}", "*") + ): + repo_split = string_to_dict( + repo_file.rfilename, + glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED), + )["split"] + if repo_split not in repo_splits: + repo_splits.append(repo_split) # get the info from the README to update them - if "README.md" in repo_files: - download_config = DownloadConfig() - download_config.download_desc = "Downloading metadata" - download_config.token = token - dataset_readme_path = cached_path( - hf_hub_url(repo_id, "README.md", revision=branch), - download_config=download_config, - ) - dataset_card = DatasetCard.load(Path(dataset_readme_path)) + if repo_with_dataset_card: + dataset_card_path = api.hf_hub_download(repo_id, "README.md", repo_type="dataset", revision=revision) + dataset_card = DatasetCard.load(Path(dataset_card_path)) dataset_card_data = dataset_card.data metadata_configs = MetadataConfigs.from_dataset_card_data(dataset_card_data) # get the deprecated dataset_infos.json to update them - elif config.DATASETDICT_INFOS_FILENAME in repo_files: + elif repo_with_dataset_infos: dataset_card = None dataset_card_data = DatasetCardData() metadata_configs = MetadataConfigs() @@ -1696,70 +1762,66 @@ push_to_hub( dataset_card_data = DatasetCardData() metadata_configs = MetadataConfigs() # create the metadata configs if it was uploaded with push_to_hub before metadata configs existed - if not metadata_configs: - _matched_paths = [ - p - for p in repo_files - if fnmatch(p, PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED.replace("{split}", "*")) - ] - if len(_matched_paths) > 0: - # it was uploaded with push_to_hub before metadata configs existed - _resolved_splits = { - string_to_dict( - p, glob_pattern_to_regex(PUSH_TO_HUB_WITHOUT_METADATA_CONFIGS_SPLIT_PATTERN_SHARDED) - )["split"] - for p in _matched_paths - } - default_metadata_configs_to_dump = { - "data_files": [ - {"split": _resolved_split, "path": 
f"data/{_resolved_split}-*"} - for _resolved_split in _resolved_splits - ] - } - MetadataConfigs({"default": default_metadata_configs_to_dump}).to_dataset_card_data(dataset_card_data) + if not metadata_configs and repo_splits: + default_metadata_configs_to_dump = { + "data_files": [{"split": split, "path": f"data/{split}-*"} for split in repo_splits] + } + MetadataConfigs({"default": default_metadata_configs_to_dump}).to_dataset_card_data(dataset_card_data) # push to the deprecated dataset_infos.json - if config.DATASETDICT_INFOS_FILENAME in repo_files: - download_config = DownloadConfig() - download_config.download_desc = "Downloading metadata" - download_config.token = token - dataset_infos_path = cached_path( - hf_hub_url(repo_id, config.DATASETDICT_INFOS_FILENAME, revision=branch), - download_config=download_config, + if repo_with_dataset_infos: + dataset_infos_path = api.hf_hub_download( + repo_id, config.DATASETDICT_INFOS_FILENAME, repo_type="dataset", revision=revision ) with open(dataset_infos_path, encoding="utf-8") as f: dataset_infos: dict = json.load(f) dataset_infos[config_name] = asdict(info_to_dump) buffer = BytesIO() buffer.write(json.dumps(dataset_infos, indent=4).encode("utf-8")) - HfApi(endpoint=config.HF_ENDPOINT).upload_file( - path_or_fileobj=buffer.getvalue(), - path_in_repo=config.DATASETDICT_INFOS_FILENAME, - repo_id=repo_id, - token=token, - repo_type="dataset", - revision=branch, + additions.append( + CommitOperationAdd(path_in_repo=config.DATASETDICT_INFOS_FILENAME, path_or_fileobj=buffer) ) # push to README DatasetInfosDict({config_name: info_to_dump}).to_dataset_card_data(dataset_card_data) MetadataConfigs({config_name: metadata_config_to_dump}).to_dataset_card_data(dataset_card_data) - dataset_card = ( - DatasetCard( - "---\n" - + str(dataset_card_data) - + "\n---\n" - + f'# Dataset Card for "{repo_id.split("/")[-1]}"\n\n[More Information 
needed](https://github.com/huggingface/datasets/blob/main/CONTRIBUTING.md#how-to-contribute-to-the-dataset-cards)' + dataset_card = DatasetCard(f"---\n{dataset_card_data}\n---\n") if dataset_card is None else dataset_card + additions.append(CommitOperationAdd(path_in_repo="README.md", path_or_fileobj=str(dataset_card).encode())) + + if len(additions) <= config.UPLOADS_MAX_NUMBER_PER_COMMIT: + api.create_commit( + repo_id, + operations=additions + deletions, + commit_message=commit_message if commit_message is not None else "Upload dataset", + token=token, + repo_type="dataset", + revision=revision, + create_pr=create_pr, ) - if dataset_card is None - else dataset_card - ) - HfApi(endpoint=config.HF_ENDPOINT).upload_file( - path_or_fileobj=str(dataset_card).encode(), - path_in_repo="README.md", - repo_id=repo_id, - token=token, - repo_type="dataset", - revision=branch, - ) + else: + logger.info( + f"Number of files to upload is larger than {config.UPLOADS_MAX_NUMBER_PER_COMMIT}. Splitting the push into multiple commits." + ) + num_commits = math.ceil(len(additions) / config.UPLOADS_MAX_NUMBER_PER_COMMIT) + for i in range(0, num_commits): + operations = additions[ + i * config.UPLOADS_MAX_NUMBER_PER_COMMIT : (i + 1) * config.UPLOADS_MAX_NUMBER_PER_COMMIT + ] + (deletions if i == 0 else []) + commit_message = ( + commit_message if commit_message is not None else "Upload dataset" + ) + f" (part {i:05d}-of-{num_commits:05d})" + api.create_commit( + repo_id, + operations=operations, + commit_message=commit_message, + token=token, + repo_type="dataset", + revision=revision, + create_pr=create_pr, + ) + logger.info( + f"Commit #{i+1} completed" + + (f" (still {num_commits - i - 1} to go)" if num_commits - i - 1 else "") + + "." 
+ ) class IterableDatasetDict(dict): diff --git a/tests/test_upstream_hub.py b/tests/test_upstream_hub.py index 6d4b33f85c5..b240df8d7a9 100644 --- a/tests/test_upstream_hub.py +++ b/tests/test_upstream_hub.py @@ -1,4 +1,5 @@ import fnmatch +import gc import os import tempfile import time @@ -53,12 +54,7 @@ def test_push_dataset_dict_to_hub_no_token(self, temporary_repo, set_ci_hub_acce # Ensure that there is a single file on the repository that has the correct name files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset")) - assert all( - fnmatch.fnmatch(file, expected_file) - for file, expected_file in zip( - files, [".gitattributes", "README.md", "data/train-00000-of-00001-*.parquet"] - ) - ) + assert files == [".gitattributes", "README.md", "data/train-00000-of-00001.parquet"] def test_push_dataset_dict_to_hub_name_without_namespace(self, temporary_repo): ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) @@ -75,12 +71,7 @@ def test_push_dataset_dict_to_hub_name_without_namespace(self, temporary_repo): # Ensure that there is a single file on the repository that has the correct name files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset")) - assert all( - fnmatch.fnmatch(file, expected_file) - for file, expected_file in zip( - files, [".gitattributes", "README.md", "data/train-00000-of-00001-*.parquet"] - ) - ) + assert files == [".gitattributes", "README.md", "data/train-00000-of-00001.parquet"] def test_push_dataset_dict_to_hub_datasets_with_different_features(self, cleanup_repo): ds_train = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) @@ -111,12 +102,7 @@ def test_push_dataset_dict_to_hub_private(self, temporary_repo): # Ensure that there is a single file on the repository that has the correct name files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) - assert all( - fnmatch.fnmatch(file, expected_file) - for file, expected_file in zip( - files, [".gitattributes", "README.md", 
"data/train-00000-of-00001-*.parquet"] - ) - ) + assert files == [".gitattributes", "README.md", "data/train-00000-of-00001.parquet"] def test_push_dataset_dict_to_hub(self, temporary_repo): ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) @@ -133,12 +119,43 @@ def test_push_dataset_dict_to_hub(self, temporary_repo): # Ensure that there is a single file on the repository that has the correct name files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) - assert all( - fnmatch.fnmatch(file, expected_file) - for file, expected_file in zip( - files, [".gitattributes", "README.md", "data/train-00000-of-00001-*.parquet"] - ) + assert files == [".gitattributes", "README.md", "data/train-00000-of-00001.parquet"] + + def test_push_dataset_dict_to_hub_with_pull_request(self, temporary_repo): + ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) + + local_ds = DatasetDict({"train": ds}) + + with temporary_repo() as ds_name: + local_ds.push_to_hub(ds_name, token=self._token, create_pr=True) + hub_ds = load_dataset(ds_name, revision="refs/pr/1", download_mode="force_redownload") + + assert local_ds["train"].features == hub_ds["train"].features + assert list(local_ds.keys()) == list(hub_ds.keys()) + assert local_ds["train"].features == hub_ds["train"].features + + # Ensure that there is a single file on the repository that has the correct name + files = sorted( + self._api.list_repo_files(ds_name, revision="refs/pr/1", repo_type="dataset", token=self._token) ) + assert files == [".gitattributes", "README.md", "data/train-00000-of-00001.parquet"] + + def test_push_dataset_dict_to_hub_with_revision(self, temporary_repo): + ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) + + local_ds = DatasetDict({"train": ds}) + + with temporary_repo() as ds_name: + local_ds.push_to_hub(ds_name, token=self._token, revision="dev") + hub_ds = load_dataset(ds_name, revision="dev", download_mode="force_redownload") + + assert 
local_ds["train"].features == hub_ds["train"].features + assert list(local_ds.keys()) == list(hub_ds.keys()) + assert local_ds["train"].features == hub_ds["train"].features + + # Ensure that there is a single file on the repository that has the correct name + files = sorted(self._api.list_repo_files(ds_name, revision="dev", repo_type="dataset", token=self._token)) + assert files == [".gitattributes", "README.md", "data/train-00000-of-00001.parquet"] def test_push_dataset_dict_to_hub_multiple_files(self, temporary_repo): ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))}) @@ -156,18 +173,12 @@ def test_push_dataset_dict_to_hub_multiple_files(self, temporary_repo): # Ensure that there are two files on the repository that have the correct name files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) - assert all( - fnmatch.fnmatch(file, expected_file) - for file, expected_file in zip( - files, - [ - ".gitattributes", - "README.md", - "data/train-00000-of-00002-*.parquet", - "data/train-00001-of-00002-*.parquet", - ], - ) - ) + assert files == [ + ".gitattributes", + "README.md", + "data/train-00000-of-00002.parquet", + "data/train-00001-of-00002.parquet", + ] def test_push_dataset_dict_to_hub_multiple_files_with_max_shard_size(self, temporary_repo): ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))}) @@ -184,18 +195,12 @@ def test_push_dataset_dict_to_hub_multiple_files_with_max_shard_size(self, tempo # Ensure that there are two files on the repository that have the correct name files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) - assert all( - fnmatch.fnmatch(file, expected_file) - for file, expected_file in zip( - files, - [ - ".gitattributes", - "README.md", - "data/train-00000-of-00002-*.parquet", - "data/train-00001-of-00002-*.parquet", - ], - ) - ) + assert files == [ + ".gitattributes", + "README.md", + "data/train-00000-of-00002.parquet", + 
"data/train-00001-of-00002.parquet", + ] def test_push_dataset_dict_to_hub_multiple_files_with_num_shards(self, temporary_repo): ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))}) @@ -212,18 +217,42 @@ def test_push_dataset_dict_to_hub_multiple_files_with_num_shards(self, temporary # Ensure that there are two files on the repository that have the correct name files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) - assert all( - fnmatch.fnmatch(file, expected_file) - for file, expected_file in zip( - files, - [ - ".gitattributes", - "README.md", - "data/train-00000-of-00002-*.parquet", - "data/train-00001-of-00002-*.parquet", - ], - ) - ) + assert files == [ + ".gitattributes", + "README.md", + "data/train-00000-of-00002.parquet", + "data/train-00001-of-00002.parquet", + ] + + def test_push_dataset_dict_to_hub_with_multiple_commits(self, temporary_repo): + ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))}) + + local_ds = DatasetDict({"train": ds}) + + with temporary_repo() as ds_name: + self._api.create_repo(ds_name, token=self._token, repo_type="dataset") + num_commits_before_push = len(self._api.list_repo_commits(ds_name, repo_type="dataset", token=self._token)) + with patch("datasets.config.MAX_SHARD_SIZE", "16KB"), patch( + "datasets.config.UPLOADS_MAX_NUMBER_PER_COMMIT", 1 + ): + local_ds.push_to_hub(ds_name, token=self._token) + hub_ds = load_dataset(ds_name, download_mode="force_redownload") + + assert local_ds.column_names == hub_ds.column_names + assert list(local_ds["train"].features.keys()) == list(hub_ds["train"].features.keys()) + assert local_ds["train"].features == hub_ds["train"].features + + # Ensure that there are two files on the repository that have the correct name + files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) + assert files == [ + ".gitattributes", + "README.md", + "data/train-00000-of-00002.parquet", + 
"data/train-00001-of-00002.parquet", + ] + + num_commits_after_push = len(self._api.list_repo_commits(ds_name, repo_type="dataset", token=self._token)) + assert num_commits_after_push - num_commits_before_push > 1 def test_push_dataset_dict_to_hub_overwrite_files(self, temporary_repo): ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))}) @@ -254,21 +283,14 @@ def test_push_dataset_dict_to_hub_overwrite_files(self, temporary_repo): # Ensure that there are two files on the repository that have the correct name files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) - - assert all( - fnmatch.fnmatch(file, expected_file) - for file, expected_file in zip( - files, - [ - ".gitattributes", - "README.md", - "data/random-00000-of-00001-*.parquet", - "data/train-00000-of-00002-*.parquet", - "data/train-00001-of-00002-*.parquet", - "datafile.txt", - ], - ) - ) + assert files == [ + ".gitattributes", + "README.md", + "data/random-00000-of-00001.parquet", + "data/train-00000-of-00002.parquet", + "data/train-00001-of-00002.parquet", + "datafile.txt", + ] self._api.delete_file("datafile.txt", repo_id=ds_name, repo_type="dataset", token=self._token) @@ -280,6 +302,9 @@ def test_push_dataset_dict_to_hub_overwrite_files(self, temporary_repo): del hub_ds + # To ensure the reference to the memory-mapped Arrow file is dropped to avoid the PermissionError on Windows + gc.collect() + # Push to hub two times, but the second time with fewer files. # Verify that the new files contain the correct dataset and that non-necessary files have been deleted. 
with temporary_repo(ds_name): @@ -303,20 +328,13 @@ def test_push_dataset_dict_to_hub_overwrite_files(self, temporary_repo): # Ensure that there are two files on the repository that have the correct name files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token)) - - assert all( - fnmatch.fnmatch(file, expected_file) - for file, expected_file in zip( - files, - [ - ".gitattributes", - "README.md", - "data/random-00000-of-00001-*.parquet", - "data/train-00000-of-00001-*.parquet", - "datafile.txt", - ], - ) - ) + assert files == [ + ".gitattributes", + "README.md", + "data/random-00000-of-00001.parquet", + "data/train-00000-of-00001.parquet", + "datafile.txt", + ] # Keeping the "datafile.txt" breaks the load_dataset to think it's a text-based dataset self._api.delete_file("datafile.txt", repo_id=ds_name, repo_type="dataset", token=self._token) @@ -450,31 +468,6 @@ def test_push_dataset_to_hub_custom_splits(self, temporary_repo): assert list(ds.features.keys()) == list(hub_ds["random"].features.keys()) assert ds.features == hub_ds["random"].features - def test_push_dataset_to_hub_skip_identical_files(self, temporary_repo): - ds = Dataset.from_dict({"x": list(range(1000)), "y": list(range(1000))}) - with temporary_repo() as ds_name: - with patch("datasets.arrow_dataset.HfApi.upload_file", side_effect=self._api.upload_file) as mock_hf_api: - # Initial push - ds.push_to_hub(ds_name, token=self._token, max_shard_size="1KB") - call_count_old = mock_hf_api.call_count - mock_hf_api.reset_mock() - - # Remove a data file - files = self._api.list_repo_files(ds_name, repo_type="dataset", token=self._token) - data_files = [f for f in files if f.startswith("data/")] - assert len(data_files) > 1 - self._api.delete_file(data_files[0], repo_id=ds_name, repo_type="dataset", token=self._token) - - # "Resume" push - push missing files - ds.push_to_hub(ds_name, token=self._token, max_shard_size="1KB") - call_count_new = mock_hf_api.call_count - assert 
call_count_old > call_count_new - - hub_ds = load_dataset(ds_name, split="train", download_mode="force_redownload") - assert ds.column_names == hub_ds.column_names - assert list(ds.features.keys()) == list(hub_ds.features.keys()) - assert ds.features == hub_ds.features - def test_push_dataset_to_hub_multiple_splits_one_by_one(self, temporary_repo): ds = Dataset.from_dict({"x": [1, 2, 3], "y": [4, 5, 6]}) with temporary_repo() as ds_name: @@ -560,16 +553,13 @@ def test_push_multiple_dataset_configs_to_hub_load_dataset(self, temporary_repo) ds_config2.push_to_hub(ds_name, "config2", token=self._token) files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset")) - expected_files = sorted( - [ - ".gitattributes", - "README.md", - "config1/train-00000-of-00001-*.parquet", - "config2/train-00000-of-00001-*.parquet", - "data/train-00000-of-00001-*.parquet", - ] - ) - assert all(fnmatch.fnmatch(file, expected_file) for file, expected_file in zip(files, expected_files)) + assert files == [ + ".gitattributes", + "README.md", + "config1/train-00000-of-00001.parquet", + "config2/train-00000-of-00001.parquet", + "data/train-00000-of-00001.parquet", + ] hub_ds_default = load_dataset(ds_name, download_mode="force_redownload") hub_ds_config1 = load_dataset(ds_name, "config1", download_mode="force_redownload") @@ -680,19 +670,16 @@ def test_push_multiple_dataset_dict_configs_to_hub_load_dataset(self, temporary_ ds_config2.push_to_hub(ds_name, "config2", token=self._token) files = sorted(self._api.list_repo_files(ds_name, repo_type="dataset")) - expected_files = sorted( - [ - ".gitattributes", - "README.md", - "config1/random-00000-of-00001-*.parquet", - "config1/train-00000-of-00001-*.parquet", - "config2/random-00000-of-00001-*.parquet", - "config2/train-00000-of-00001-*.parquet", - "data/random-00000-of-00001-*.parquet", - "data/train-00000-of-00001-*.parquet", - ] - ) - assert all(fnmatch.fnmatch(file, expected_file) for file, expected_file in zip(files, 
expected_files)) + assert files == [ + ".gitattributes", + "README.md", + "config1/random-00000-of-00001.parquet", + "config1/train-00000-of-00001.parquet", + "config2/random-00000-of-00001.parquet", + "config2/train-00000-of-00001.parquet", + "data/random-00000-of-00001.parquet", + "data/train-00000-of-00001.parquet", + ] hub_ds_default = load_dataset(ds_name, download_mode="force_redownload") hub_ds_config1 = load_dataset(ds_name, "config1", download_mode="force_redownload") @@ -790,7 +777,7 @@ def test_push_dataset_to_hub_with_config_no_metadata_configs(self, temporary_rep assert len(ds_another_config_builder.config.data_files["train"]) == 1 assert fnmatch.fnmatch( ds_another_config_builder.config.data_files["train"][0], - "*/another_config/train-00000-of-00001-*.parquet", + "*/another_config/train-00000-of-00001.parquet", ) def test_push_dataset_dict_to_hub_with_config_no_metadata_configs(self, temporary_repo): @@ -824,5 +811,5 @@ def test_push_dataset_dict_to_hub_with_config_no_metadata_configs(self, temporar assert len(ds_another_config_builder.config.data_files["random"]) == 1 assert fnmatch.fnmatch( ds_another_config_builder.config.data_files["random"][0], - "*/another_config/random-00000-of-00001-*.parquet", + "*/another_config/random-00000-of-00001.parquet", )