2 changes: 2 additions & 0 deletions docs/source/processing.rst
@@ -651,6 +651,8 @@ Enable or disable caching
Locally you can prevent the library from reloading a cached file by passing ``load_from_cache_file=False`` to transforms such as :func:`datasets.Dataset.map`.
You can also specify the path where the cache file will be written using the ``cache_file_name`` parameter.

By setting ``use_caching=False`` in :func:`datasets.load_dataset`, you can disable caching on a specific dataset instance.

It is also possible to disable caching globally with :func:`datasets.set_caching_enabled`.

If caching is disabled, the library will no longer reload cached dataset files when applying transforms to the datasets.
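
For example, a minimal sketch of both options (the ``squad`` dataset name is only illustrative; ``use_caching`` is the parameter added in this pull request)::

    from datasets import load_dataset, set_caching_enabled

    # Disable transform caching for this dataset instance only
    ds = load_dataset("squad", split="train", use_caching=False)

    # Or disable caching globally, for every dataset
    set_caching_enabled(False)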
52 changes: 43 additions & 9 deletions src/datasets/arrow_dataset.py
@@ -249,6 +249,7 @@ def __init__(
self._format_columns: Optional[list] = None
self._output_all_columns: bool = False
self._fingerprint: str = fingerprint
self._cache_dir: Optional[str] = None

# Read metadata

@@ -262,6 +263,7 @@ def __init__(
self._fingerprint = metadata["fingerprint"]

# Infer features if None

inferred_features = Features.from_arrow_schema(arrow_table.schema)
if self.info.features is None:
self.info.features = inferred_features
@@ -273,6 +275,11 @@ def __init__(
if self._fingerprint is None:
self._fingerprint = generate_fingerprint(self)

# Infer cache directory if None

if self._cache_dir is None and self.cache_files:
self._cache_dir = os.path.dirname(self.cache_files[0]["filename"])

# Sanity checks

assert self.features is not None, "Features can't be None in a Dataset object"
@@ -620,6 +627,8 @@ def save_to_disk(self, dataset_path: str, fs=None):
json.dumps(state["_format_kwargs"][k])
except TypeError as e:
raise TypeError(str(e) + f"\nThe format kwargs must be JSON serializable, but key '{k}' isn't.")
# Record whether this dataset caches its transforms so it can be restored on load
if self._cache_dir:
    state["_cached"] = True

# Get json serializable dataset info
dataset_info = asdict(self._info)
@@ -706,14 +715,20 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] =
split = state["_split"]
split = NamedSplit(split) if split is not None else split

return Dataset(
dataset = Dataset(
arrow_table=arrow_table,
indices_table=indices_table,
info=dataset_info,
split=split,
fingerprint=state["_fingerprint"],
)

# Update cache directory for in-memory datasets that use caching
if dataset._cache_dir is None and state.get("_cached"):
dataset._cache_dir = str(Path(dataset_path).expanduser())

return dataset

@property
def data(self) -> Table:
"""The Apache Arrow table backing the dataset."""
@@ -1498,9 +1513,9 @@ def cleanup_cache_files(self) -> int:
return len(files_to_remove)

def _get_cache_file_path(self, fingerprint):
if is_caching_enabled() and self.cache_files:
if is_caching_enabled() and self._cache_dir is not None:
cache_file_name = "cache-" + fingerprint + ".arrow"
cache_directory = os.path.dirname(self.cache_files[0]["filename"])
cache_directory = self._cache_dir
else:
cache_file_name = "cache-" + generate_random_fingerprint() + ".arrow"
cache_directory = get_temporary_cache_files_directory()
@@ -1789,7 +1804,7 @@ def _map_single(
batch_size = self.num_rows

# Check if we've already cached this computation (indexed by a hash)
if self.cache_files:
if self._cache_dir is not None:
if cache_file_name is None:
# we create a unique hash from the function,
# current dataset file and the mapping args
@@ -1798,7 +1813,15 @@ def _map_single(
logger.warning("Loading cached processed dataset at %s", cache_file_name)
info = self.info.copy()
info.features = features
return Dataset.from_file(cache_file_name, info=info, split=self.split)
dataset = Dataset.from_file(
cache_file_name, info=info, split=self.split, in_memory=not self.cache_files
)
# Preserve transform caching: if the loaded dataset has no cache directory of its
# own (e.g. it was loaded in memory), fall back to the directory of the cache file
if dataset._cache_dir is None and self._cache_dir is not None:
    dataset._cache_dir = os.path.dirname(cache_file_name)
return dataset

# We set this variable to True after processing the first example/batch in
# `apply_function_on_filtered_inputs` if the map function returns a dict.
@@ -1984,7 +2007,15 @@ def init_buffer_and_writer():
info = self.info.copy()
info.features = writer._features
if buf_writer is None:
return Dataset.from_file(cache_file_name, info=info, split=self.split)
dataset = Dataset.from_file(
cache_file_name, info=info, split=self.split, in_memory=not self.cache_files
)
# Preserve transform caching: if the loaded dataset has no cache directory of its
# own (e.g. it was loaded in memory), fall back to the directory of the cache file
if dataset._cache_dir is None and self._cache_dir is not None:
    dataset._cache_dir = os.path.dirname(cache_file_name)
return dataset
@mariosasko (Collaborator, Author) commented on lines +2010 to +2018 on Jun 8, 2021:

@lhoestq I tried to address the first point from your earlier comment, but this change breaks the tests and it's getting late here, so I don't have time to fix this now. Think this is due to BaseDatasetTest._to in test_arrow_dataset.py relying on Dataset.map.

else:
return Dataset.from_buffer(buf_writer.getvalue(), info=info, split=self.split)
else:
@@ -2291,7 +2322,7 @@ def sort(
)

# Check if we've already cached this computation (indexed by a hash)
if self.cache_files:
if self._cache_dir is not None:
if indices_cache_file_name is None:
# we create a unique hash from the function, current dataset file and the mapping args
indices_cache_file_name = self._get_cache_file_path(new_fingerprint)
@@ -2374,7 +2405,7 @@ def shuffle(
generator = np.random.default_rng(seed)

# Check if we've already cached this computation (indexed by a hash)
if self.cache_files:
if self._cache_dir is not None:
if indices_cache_file_name is None:
# we create a unique hash from the function, current dataset file and the mapping args
indices_cache_file_name = self._get_cache_file_path(new_fingerprint)
@@ -2540,7 +2571,7 @@ def train_test_split(
generator = np.random.default_rng(seed)

# Check if we've already cached this computation (indexed by a hash)
if self.cache_files:
if self._cache_dir is not None:
if train_indices_cache_file_name is None or test_indices_cache_file_name is None:
# we create a unique hash from the function, current dataset file and the mapping args

@@ -3182,6 +3213,9 @@ def apply_offset_to_indices_table(table, offset):
indices_table=indices_table,
fingerprint=fingerprint,
)
cache_dirs = [dset._cache_dir for dset in dsets if dset._cache_dir is not None]
if concatenated_dataset._cache_dir is None and cache_dirs:
concatenated_dataset._cache_dir = cache_dirs[0]
concatenated_dataset.set_format(**format)
return concatenated_dataset

13 changes: 12 additions & 1 deletion src/datasets/builder.py
@@ -700,7 +700,12 @@ def _make_split_generators_kwargs(self, prepare_split_kwargs):
return {}

def as_dataset(
self, split: Optional[Split] = None, run_post_process=True, ignore_verifications=False, in_memory=False
self,
split: Optional[Split] = None,
run_post_process=True,
ignore_verifications=False,
in_memory=False,
use_caching=True,
) -> Union[Dataset, DatasetDict]:
"""Return a Dataset for the specified split.

@@ -711,6 +716,7 @@
ignore_verifications (bool, default=False): Whether to ignore the verifications of the
downloaded/processed dataset information (checksums/size/splits/...).
in_memory (bool, default=False): Whether to copy the data in-memory.
use_caching (bool, default=True): Whether to cache the dataset transforms applied with ``Dataset.map``.

Returns:
datasets.Dataset
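
A minimal usage sketch of these arguments (it assumes a builder obtained with ``datasets.load_dataset_builder``, which may only be available in more recent versions of the library, and an illustrative dataset name):

from datasets import load_dataset_builder

builder = load_dataset_builder("squad")  # illustrative dataset name
builder.download_and_prepare()
# Copy the split into memory while still caching later transforms on disk
ds = builder.as_dataset(split="train", in_memory=True, use_caching=True)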
@@ -740,6 +746,7 @@
run_post_process=run_post_process,
ignore_verifications=ignore_verifications,
in_memory=in_memory,
use_caching=use_caching,
),
split,
map_tuple=True,
@@ -754,6 +761,7 @@
run_post_process: bool,
ignore_verifications: bool,
in_memory: bool = False,
use_caching: bool = True,
):
"""as_dataset for a single split."""
verify_infos = not ignore_verifications
@@ -765,6 +773,9 @@
split=split,
in_memory=in_memory,
)
# Enables caching of dataset transforms for in-memory datasets
if in_memory and use_caching:
ds._cache_dir = self._cache_dir
if run_post_process:
for resource_file_name in self._post_processing_resources(split).values():
if os.sep in resource_file_name:
2 changes: 1 addition & 1 deletion src/datasets/info.py
@@ -204,7 +204,7 @@ def _dump_license(self, file):
file.write(self.license.encode("utf-8"))

@classmethod
def from_merge(cls, dataset_infos: List["DatasetInfo"]):
def from_merge(cls, dataset_infos: List["DatasetInfo"]) -> "DatasetInfo":
def unique(values):
seen = set()
for value in values:
6 changes: 5 additions & 1 deletion src/datasets/load.py
@@ -636,6 +636,7 @@ def load_dataset(
save_infos: bool = False,
script_version: Optional[Union[str, Version]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
use_caching: bool = True,
task: Optional[Union[str, TaskTemplate]] = None,
**config_kwargs,
) -> Union[DatasetDict, Dataset]:
@@ -696,6 +697,7 @@
You can specify a different version than the default "main" by using a commit sha or a git tag of the dataset repository.
use_auth_token (``str`` or ``bool``, optional): Optional string or boolean to use as Bearer token for remote files on the Datasets Hub.
If True, will get token from `"~/.huggingface"`.
use_caching (``bool``, default ``True``): Whether to cache the dataset transforms applied with :func:`datasets.Dataset.map`.
task (``str``): The task to prepare the dataset for during training and evaluation. Casts the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`.
**config_kwargs: Keyword arguments to be passed to the :class:`BuilderConfig` and used in the :class:`DatasetBuilder`.
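
A minimal sketch of how ``use_caching`` is meant to combine with ``keep_in_memory`` (dataset and column names are illustrative):

from datasets import load_dataset

# Load the split into memory; with use_caching=True (the default), the result
# of .map() can still be written to and reloaded from the on-disk cache
ds = load_dataset("squad", split="train", keep_in_memory=True, use_caching=True)
ds = ds.map(lambda example: {"question_length": len(example["question"])})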

@@ -754,7 +756,9 @@
keep_in_memory = (
keep_in_memory if keep_in_memory is not None else is_small_dataset(builder_instance.info.dataset_size)
)
ds = builder_instance.as_dataset(split=split, ignore_verifications=ignore_verifications, in_memory=keep_in_memory)
ds = builder_instance.as_dataset(
split=split, ignore_verifications=ignore_verifications, in_memory=keep_in_memory, use_caching=use_caching
)
# Rename and cast features to match task schema
if task is not None:
ds = ds.prepare_for_task(task)
2 changes: 1 addition & 1 deletion src/datasets/utils/info_utils.py
@@ -84,7 +84,7 @@ def get_size_checksum_dict(path: str) -> dict:
return {"num_bytes": os.path.getsize(path), "checksum": m.hexdigest()}


def is_small_dataset(dataset_size):
def is_small_dataset(dataset_size: int) -> bool:
"""Check if `dataset_size` is smaller than `config.IN_MEMORY_MAX_SIZE`.

Args: