Skip to content

Commit 0b2a4c2

Browse files
authored
Keep hffs cache in workers when streaming (huggingface#7820)
* keep hffs cache in workers when streaming
* bonus: reorder hffs args to improve caching
1 parent 12f5aca commit 0b2a4c2

File tree

3 files changed: +14 −3 lines changed

src/datasets/download/download_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def copy(self) -> "DownloadConfig":
     def __setattr__(self, name, value):
         if name == "token" and getattr(self, "storage_options", None) is not None:
             if "hf" not in self.storage_options:
-                self.storage_options["hf"] = {"token": value, "endpoint": config.HF_ENDPOINT}
+                self.storage_options["hf"] = {"endpoint": config.HF_ENDPOINT, "token": value}
             elif getattr(self.storage_options["hf"], "token", None) is None:
                 self.storage_options["hf"]["token"] = value
         super().__setattr__(name, value)

src/datasets/iterable_dataset.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,15 @@
 import pandas as pd
 import pyarrow as pa
 import pyarrow.parquet as pq
-from huggingface_hub import CommitInfo, CommitOperationAdd, CommitOperationDelete, DatasetCard, DatasetCardData, HfApi
+from huggingface_hub import (
+    CommitInfo,
+    CommitOperationAdd,
+    CommitOperationDelete,
+    DatasetCard,
+    DatasetCardData,
+    HfApi,
+    HfFileSystem,
+)
 from huggingface_hub.hf_api import RepoFile
 from huggingface_hub.utils import HfHubHTTPError, RepositoryNotFoundError
 from multiprocess import Pool
@@ -2151,6 +2159,7 @@ def __init__(
         self._token_per_repo_id: dict[str, Union[str, bool, None]] = token_per_repo_id or {}
         self._epoch: Union[int, "torch.Tensor"] = _maybe_share_with_torch_persistent_workers(0)
         self._starting_state_dict: Optional[dict] = None
+        self.__hffs_cache = HfFileSystem._cache  # keep the cache on pickling (e.g. for dataloader workers)
         self._prepare_ex_iterable_for_iteration()  # set state_dict
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)  # subclass of torch IterableDataset

@@ -2299,6 +2308,8 @@ def __setstate__(self, d):
         self.__dict__ = d
         # Re-add torch shared memory, since shared memory is not always kept when pickling
         self._epoch = _maybe_share_with_torch_persistent_workers(self._epoch)
+        # Re-add the cache to keep on pickling (e.g. for dataloader workers)
+        self.__hffs_cache = HfFileSystem._cache
         # Re-add torch iterable dataset as a parent class, since dynamically added parent classes are not kept when pickling
         _maybe_add_torch_iterable_dataset_parent_class(self.__class__)
23042315

src/datasets/utils/file_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -897,8 +897,8 @@ def _prepare_single_hop_path_and_storage_options(
         storage_options["headers"] = {"Accept-Encoding": "identity", **headers}
     elif protocol == "hf":
         storage_options = {
-            "token": token,
             "endpoint": config.HF_ENDPOINT,
+            "token": token,
             **storage_options,
         }
     if storage_options:

0 commit comments

Comments (0)