diff --git a/docs/source/processing.rst b/docs/source/processing.rst
index bd178b3161b..7e5b17f9b58 100644
--- a/docs/source/processing.rst
+++ b/docs/source/processing.rst
@@ -651,6 +651,8 @@ Enable or disable caching
 Locally you can prevent the library from reloading a cached file by using ``load_from_cache=False`` in transforms like :func:`datasets.Dataset.map` for example.
 You can also specify the name of path where the cache file will be written using the parameter ``cache_file_name``.
 
+By setting ``use_caching=False`` in :func:`datasets.load_dataset`, you can disable caching on a specific dataset instance.
+
 It is also possible to disable caching globally with :func:`datasets.set_caching_enabled`.
 If the caching is disabled, the library will no longer reload cached dataset files when applying transforms to the datasets.
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index dfad4b883bc..53a8c112743 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -249,6 +249,7 @@ def __init__(
         self._format_columns: Optional[list] = None
         self._output_all_columns: bool = False
         self._fingerprint: str = fingerprint
+        self._cache_dir: Optional[str] = None
 
         # Read metadata
 
@@ -262,6 +263,7 @@ def __init__(
             self._fingerprint = metadata["fingerprint"]
 
         # Infer features if None
+
         inferred_features = Features.from_arrow_schema(arrow_table.schema)
         if self.info.features is None:
             self.info.features = inferred_features
@@ -273,6 +275,11 @@ def __init__(
         if self._fingerprint is None:
             self._fingerprint = generate_fingerprint(self)
 
+        # Infer cache directory if None
+
+        if self._cache_dir is None and self.cache_files:
+            self._cache_dir = os.path.dirname(self.cache_files[0]["filename"])
+
         # Sanity checks
         assert self.features is not None, "Features can't be None in a Dataset object"
@@ -620,6 +627,8 @@ def save_to_disk(self, dataset_path: str, fs=None):
                 json.dumps(state["_format_kwargs"][k])
             except TypeError as e:
                 raise TypeError(str(e) + f"\nThe format kwargs must be JSON serializable, but key '{k}' isn't.")
+        if self._cache_dir:
+            state["_cached"] = bool(self._cache_dir)
 
         # Get json serializable dataset info
         dataset_info = asdict(self._info)
@@ -706,7 +715,7 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] =
         split = state["_split"]
         split = NamedSplit(split) if split is not None else split
 
-        return Dataset(
+        dataset = Dataset(
             arrow_table=arrow_table,
             indices_table=indices_table,
             info=dataset_info,
@@ -714,6 +723,12 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] =
             fingerprint=state["_fingerprint"],
         )
 
+        # Update cache directory for in-memory datasets that use caching
+        if dataset._cache_dir is None and state.get("_cached"):
+            dataset._cache_dir = str(Path(dataset_path).expanduser())
+
+        return dataset
+
     @property
     def data(self) -> Table:
         """The Apache Arrow table backing the dataset."""
@@ -1498,9 +1513,9 @@ def cleanup_cache_files(self) -> int:
         return len(files_to_remove)
 
     def _get_cache_file_path(self, fingerprint):
-        if is_caching_enabled() and self.cache_files:
+        if is_caching_enabled() and self._cache_dir is not None:
             cache_file_name = "cache-" + fingerprint + ".arrow"
-            cache_directory = os.path.dirname(self.cache_files[0]["filename"])
+            cache_directory = self._cache_dir
         else:
             cache_file_name = "cache-" + generate_random_fingerprint() + ".arrow"
             cache_directory = get_temporary_cache_files_directory()
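The `_get_cache_file_path` change above is the heart of the patch: the cache location now comes from the remembered `_cache_dir` instead of being derived from `cache_files`. As a reading aid, here is a simplified standalone sketch of that selection logic (the helper name `pick_cache_file` is hypothetical, not the library's internals):

```python
import os
import tempfile
import uuid


def pick_cache_file(fingerprint: str, cache_dir, caching_enabled: bool = True) -> str:
    """Sketch of the selection logic: a known cache directory yields a deterministic,
    fingerprint-based file name that can be found again on the next run, while the
    fallback is a throwaway random name in a temporary directory."""
    if caching_enabled and cache_dir is not None:
        cache_file_name = "cache-" + fingerprint + ".arrow"
        cache_directory = cache_dir
    else:
        cache_file_name = "cache-" + uuid.uuid4().hex + ".arrow"
        cache_directory = tempfile.gettempdir()
    return os.path.join(cache_directory, cache_file_name)
```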
@@ -1789,7 +1804,7 @@ def _map_single(
             batch_size = self.num_rows
 
         # Check if we've already cached this computation (indexed by a hash)
-        if self.cache_files:
+        if self._cache_dir is not None:
             if cache_file_name is None:
                 # we create a unique hash from the function,
                 # current dataset file and the mapping args
@@ -1798,7 +1813,15 @@ def _map_single(
                 logger.warning("Loading cached processed dataset at %s", cache_file_name)
                 info = self.info.copy()
                 info.features = features
-                return Dataset.from_file(cache_file_name, info=info, split=self.split)
+                dataset = Dataset.from_file(
+                    cache_file_name, info=info, split=self.split, in_memory=not self.cache_files
+                )
+                dataset._cache_dir = (
+                    os.path.dirname(cache_file_name)
+                    if dataset._cache_dir is None and self._cache_dir is not None
+                    else None
+                )
+                return dataset
 
         # We set this variable to True after processing the first example/batch in
         # `apply_function_on_filtered_inputs` if the map function returns a dict.
@@ -1984,7 +2007,15 @@ def init_buffer_and_writer():
             info = self.info.copy()
             info.features = writer._features
             if buf_writer is None:
-                return Dataset.from_file(cache_file_name, info=info, split=self.split)
+                dataset = Dataset.from_file(
+                    cache_file_name, info=info, split=self.split, in_memory=not self.cache_files
+                )
+                dataset._cache_dir = (
+                    os.path.dirname(cache_file_name)
+                    if dataset._cache_dir is None and self._cache_dir is not None
+                    else None
+                )
+                return dataset
             else:
                 return Dataset.from_buffer(buf_writer.getvalue(), info=info, split=self.split)
         else:
@@ -2291,7 +2322,7 @@ def sort(
         )
 
         # Check if we've already cached this computation (indexed by a hash)
-        if self.cache_files:
+        if self._cache_dir is not None:
             if indices_cache_file_name is None:
                 # we create a unique hash from the function, current dataset file and the mapping args
                 indices_cache_file_name = self._get_cache_file_path(new_fingerprint)
@@ -2374,7 +2405,7 @@ def shuffle(
         generator = np.random.default_rng(seed)
 
         # Check if we've already cached this computation (indexed by a hash)
-        if self.cache_files:
+        if self._cache_dir is not None:
             if indices_cache_file_name is None:
                 # we create a unique hash from the function, current dataset file and the mapping args
                 indices_cache_file_name = self._get_cache_file_path(new_fingerprint)
@@ -2540,7 +2571,7 @@ def train_test_split(
             generator = np.random.default_rng(seed)
 
         # Check if we've already cached this computation (indexed by a hash)
-        if self.cache_files:
+        if self._cache_dir is not None:
             if train_indices_cache_file_name is None or test_indices_cache_file_name is None:
                 # we create a unique hash from the function, current dataset file and the mapping args
@@ -3182,6 +3213,9 @@ def apply_offset_to_indices_table(table, offset):
         indices_table=indices_table,
         fingerprint=fingerprint,
     )
+    cache_dirs = [dset._cache_dir for dset in dsets if dset._cache_dir is not None]
+    if concatenated_dataset._cache_dir is None and cache_dirs:
+        concatenated_dataset._cache_dir = cache_dirs[0]
     concatenated_dataset.set_format(**format)
     return concatenated_dataset
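The concatenation hunk above propagates the cache directory so that transforms applied to a concatenated dataset can still be cached. A minimal sketch of that propagation rule, using a hypothetical free function rather than the library's actual code:

```python
from typing import List, Optional


def pick_concat_cache_dir(cache_dirs: List[Optional[str]]) -> Optional[str]:
    """Sketch of the rule added to concatenation: the first input dataset that
    knows its cache directory provides it for the concatenated result."""
    known = [d for d in cache_dirs if d is not None]
    return known[0] if known else None


# The paths below are made-up examples.
assert pick_concat_cache_dir([None, "/cache/squad", "/cache/glue"]) == "/cache/squad"
assert pick_concat_cache_dir([None, None]) is None
```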
diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 2167a777f25..d0ce4d01866 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -700,7 +700,12 @@ def _make_split_generators_kwargs(self, prepare_split_kwargs):
         return {}
 
     def as_dataset(
-        self, split: Optional[Split] = None, run_post_process=True, ignore_verifications=False, in_memory=False
+        self,
+        split: Optional[Split] = None,
+        run_post_process=True,
+        ignore_verifications=False,
+        in_memory=False,
+        use_caching=True,
     ) -> Union[Dataset, DatasetDict]:
         """Return a Dataset for the specified split.
@@ -711,6 +716,7 @@ def as_dataset(
             ignore_verifications (bool, default=False): Whether to ignore the verifications of the
                 downloaded/processed dataset information (checksums/size/splits/...).
             in_memory (bool, default=False): Whether to copy the data in-memory.
+            use_caching (bool, default=True): Whether to cache the dataset transforms.
 
         Returns:
             datasets.Dataset
@@ -740,6 +746,7 @@ def as_dataset(
                 run_post_process=run_post_process,
                 ignore_verifications=ignore_verifications,
                 in_memory=in_memory,
+                use_caching=use_caching,
             ),
             split,
             map_tuple=True,
@@ -754,6 +761,7 @@ def _build_single_dataset(
         run_post_process: bool,
         ignore_verifications: bool,
         in_memory: bool = False,
+        use_caching: bool = True,
     ):
         """as_dataset for a single split."""
         verify_infos = not ignore_verifications
@@ -765,6 +773,9 @@ def _build_single_dataset(
             split=split,
             in_memory=in_memory,
         )
+        # Enables caching of dataset transforms for in-memory datasets
+        if in_memory and use_caching:
+            ds._cache_dir = self._cache_dir
         if run_post_process:
             for resource_file_name in self._post_processing_resources(split).values():
                 if os.sep in resource_file_name:
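In `_build_single_dataset`, the builder's cache directory is only handed to the dataset when the data is copied in memory and the caller has not opted out. A standalone sketch of that gating (hypothetical function, not the builder's actual method):

```python
from typing import Optional


def resolve_cache_dir(builder_cache_dir: str, in_memory: bool, use_caching: bool) -> Optional[str]:
    """Sketch of the new gating: in-memory datasets receive the builder's cache
    directory only while use_caching stays enabled; on-disk datasets keep
    inferring their cache directory from cache_files in Dataset.__init__."""
    if in_memory and use_caching:
        return builder_cache_dir
    return None
```

The design choice is an opt-out flag: on-disk datasets already cache through their cache files, so the new parameter only changes behaviour for in-memory datasets.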
diff --git a/src/datasets/info.py b/src/datasets/info.py
index 0d1931d8bb1..99fa0507598 100644
--- a/src/datasets/info.py
+++ b/src/datasets/info.py
@@ -204,7 +204,7 @@ def _dump_license(self, file):
         file.write(self.license.encode("utf-8"))
 
     @classmethod
-    def from_merge(cls, dataset_infos: List["DatasetInfo"]):
+    def from_merge(cls, dataset_infos: List["DatasetInfo"]) -> "DatasetInfo":
         def unique(values):
             seen = set()
             for value in values:
diff --git a/src/datasets/load.py b/src/datasets/load.py
index e7281d38ad8..ab2ec3549e9 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -636,6 +636,7 @@ def load_dataset(
     save_infos: bool = False,
     script_version: Optional[Union[str, Version]] = None,
     use_auth_token: Optional[Union[bool, str]] = None,
+    use_caching: bool = True,
     task: Optional[Union[str, TaskTemplate]] = None,
     **config_kwargs,
 ) -> Union[DatasetDict, Dataset]:
@@ -696,6 +697,7 @@ def load_dataset(
            You can specify a different version that the default "main" by using a commit sha or a git tag of the dataset repository.
        use_auth_token (``str`` or ``bool``, optional): Optional string or boolean to use as Bearer token for remote files on the Datasets Hub.
            If True, will get token from `"~/.huggingface"`.
+       use_caching (``bool``, default ``True``): Cache the dataset transforms created with :func:`datasets.Dataset.map`.
        task (``str``): The task to prepare the dataset for during training and evaluation. Casts the dataset's :class:`Features` to standardized column names and types as detailed in :py:mod:`datasets.tasks`.
        **config_kwargs: Keyword arguments to be passed to the :class:`BuilderConfig` and used in the :class:`DatasetBuilder`.
@@ -754,7 +756,9 @@ def load_dataset(
     keep_in_memory = (
         keep_in_memory if keep_in_memory is not None else is_small_dataset(builder_instance.info.dataset_size)
     )
-    ds = builder_instance.as_dataset(split=split, ignore_verifications=ignore_verifications, in_memory=keep_in_memory)
+    ds = builder_instance.as_dataset(
+        split=split, ignore_verifications=ignore_verifications, in_memory=keep_in_memory, use_caching=use_caching
+    )
 
     # Rename and cast features to match task schema
     if task is not None:
         ds = ds.prepare_for_task(task)
diff --git a/src/datasets/utils/info_utils.py b/src/datasets/utils/info_utils.py
index d99f9c4413e..3ed5728388b 100644
--- a/src/datasets/utils/info_utils.py
+++ b/src/datasets/utils/info_utils.py
@@ -84,7 +84,7 @@ def get_size_checksum_dict(path: str) -> dict:
     return {"num_bytes": os.path.getsize(path), "checksum": m.hexdigest()}
 
 
-def is_small_dataset(dataset_size):
+def is_small_dataset(dataset_size: int) -> bool:
     """Check if `dataset_size` is smaller than `config.IN_MEMORY_MAX_SIZE`.
 
     Args:
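Putting the pieces together, this is how the new flag is meant to surface to users. A hedged usage sketch, assuming this patch is applied; "squad" is just an example dataset, and the exact cache-hit behaviour depends on the fingerprinting of the map function:

```python
from datasets import load_dataset

# Small datasets default to being loaded in memory (see is_small_dataset and
# config.IN_MEMORY_MAX_SIZE). With use_caching=True (the default), their .map
# results can still be written to and reloaded from the builder's cache directory.
ds = load_dataset("squad", split="validation", keep_in_memory=True, use_caching=True)

def add_question_length(example):
    example["question_length"] = len(example["question"])
    return example

ds_a = ds.map(add_question_length)  # computes the transform and writes cache-<fingerprint>.arrow
ds_b = ds.map(add_question_length)  # same input and same function, so the cached file should be reloaded

# Opting out for this dataset only (the global datasets.set_caching_enabled switch
# is untouched): transforms on ds_no_cache are not reloaded from cached files.
ds_no_cache = load_dataset("squad", split="validation", keep_in_memory=True, use_caching=False)
```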