44 changes: 32 additions & 12 deletions src/datasets/arrow_dataset.py
@@ -245,6 +245,7 @@ def __init__(
self._format_columns: Optional[list] = None
self._output_all_columns: bool = False
self._fingerprint: str = fingerprint
self._cache_dir: Optional[str] = None

# Read metadata

@@ -258,6 +259,7 @@ def __init__(
self._fingerprint = metadata["fingerprint"]

# Infer features if None

inferred_features = Features.from_arrow_schema(arrow_table.schema)
if self.info.features is None:
self.info.features = inferred_features
@@ -267,6 +269,11 @@ def __init__(
if self._fingerprint is None:
self._fingerprint = generate_fingerprint(self)

# Infer cache directory if None

if self._cache_dir is None and self.cache_files:
self._cache_dir = os.path.dirname(self.cache_files[0]["filename"])

# Sanity checks

assert self.features is not None, "Features can't be None in a Dataset object"
@@ -300,7 +307,7 @@ def from_file(
info: Optional[DatasetInfo] = None,
split: Optional[NamedSplit] = None,
indices_filename: Optional[str] = None,
in_memory: bool = False,
keep_in_memory: bool = False,
Reviewer comment (Member):

This is a breaking change. I agree that keep_in_memory is nice for consistency, but we try to avoid breaking changes as much as possible.

) -> "Dataset":
"""Instantiate a Dataset backed by an Arrow table at filename.

@@ -309,15 +316,15 @@ def from_file(
info (:class:`DatasetInfo`, optional): Dataset information, like description, citation, etc.
split (:class:`NamedSplit`, optional): Name of the dataset split.
indices_filename (:obj:`str`, optional): File names of the indices.
in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory.
keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory.

Returns:
:class:`Dataset`
"""
table = ArrowReader.read_table(filename, in_memory=in_memory)
table = ArrowReader.read_table(filename, in_memory=keep_in_memory)

if indices_filename is not None:
indices_pa_table = ArrowReader.read_table(indices_filename, in_memory=in_memory)
indices_pa_table = ArrowReader.read_table(indices_filename, in_memory=keep_in_memory)
else:
indices_pa_table = None
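
One way to get the naming consistency without the breaking change raised in the review comment above would be to keep the old keyword as a deprecated alias for a few releases. The helper below is only a sketch of that idea and is not part of this diff; the function name and warning text are made up for illustration.

import warnings
from typing import Optional

def resolve_keep_in_memory(keep_in_memory: bool = False, in_memory: Optional[bool] = None) -> bool:
    # If a caller still passes the old `in_memory` keyword, honour it but warn,
    # so the rename can land without immediately breaking existing code.
    if in_memory is not None:
        warnings.warn(
            "'in_memory' was renamed to 'keep_in_memory' and will be removed in a future release.",
            FutureWarning,
        )
        return in_memory
    return keep_in_memory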

@@ -614,6 +621,8 @@ def save_to_disk(self, dataset_path: str, fs=None):
json.dumps(state["_format_kwargs"][k])
except TypeError as e:
raise TypeError(str(e) + f"\nThe format kwargs must be JSON serializable, but key '{k}' isn't.")
if self._cache_dir:
state["_cached"] = bool(self._cache_dir)

# Get json serializable dataset info
dataset_info = asdict(self._info)
@@ -697,14 +706,20 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] =
split = state["_split"]
split = NamedSplit(split) if split is not None else split

return Dataset(
dataset = Dataset(
arrow_table=arrow_table,
indices_table=indices_table,
info=dataset_info,
split=split,
fingerprint=state["_fingerprint"],
)

# Update cache directory for in-memory datasets that use caching
if dataset._cache_dir is None and state.get("_cached"):
dataset._cache_dir = str(Path(dataset_path).expanduser())

return dataset

@property
def data(self) -> Table:
"""The Apache Arrow table backing the dataset."""
@@ -1481,9 +1496,9 @@ def cleanup_cache_files(self) -> int:
return len(files_to_remove)

def _get_cache_file_path(self, fingerprint):
if is_caching_enabled() and self.cache_files:
if is_caching_enabled() and self._cache_dir is not None:
cache_file_name = "cache-" + fingerprint + ".arrow"
cache_directory = os.path.dirname(self.cache_files[0]["filename"])
cache_directory = self._cache_dir
else:
cache_file_name = "cache-" + generate_random_fingerprint() + ".arrow"
cache_directory = get_temporary_cache_files_directory()
@@ -1766,7 +1781,7 @@ def _map_single(
batch_size = self.num_rows

# Check if we've already cached this computation (indexed by a hash)
if self.cache_files:
if self._cache_dir is not None:
if cache_file_name is None:
# we create a unique hash from the function,
# current dataset file and the mapping args
@@ -1775,7 +1790,9 @@
logger.warning("Loading cached processed dataset at %s", cache_file_name)
info = self.info.copy()
info.features = features
return Dataset.from_file(cache_file_name, info=info, split=self.split)
return Dataset.from_file(
cache_file_name, info=info, split=self.split, keep_in_memory=not self.cache_files
)

# We set this variable to True after processing the first example/batch in
# `apply_function_on_filtered_inputs` if the map function returns a dict.
@@ -2268,7 +2285,7 @@ def sort(
)

# Check if we've already cached this computation (indexed by a hash)
if self.cache_files:
if self._cache_dir is not None:
if indices_cache_file_name is None:
# we create a unique hash from the function, current dataset file and the mapping args
indices_cache_file_name = self._get_cache_file_path(new_fingerprint)
@@ -2351,7 +2368,7 @@ def shuffle(
generator = np.random.default_rng(seed)

# Check if we've already cached this computation (indexed by a hash)
if self.cache_files:
if self._cache_dir is not None:
if indices_cache_file_name is None:
# we create a unique hash from the function, current dataset file and the mapping args
indices_cache_file_name = self._get_cache_file_path(new_fingerprint)
@@ -2517,7 +2534,7 @@ def train_test_split(
generator = np.random.default_rng(seed)

# Check if we've already cached this computation (indexed by a hash)
if self.cache_files:
if self._cache_dir is not None:
if train_indices_cache_file_name is None or test_indices_cache_file_name is None:
# we create a unique hash from the function, current dataset file and the mapping args

@@ -3159,6 +3176,9 @@ def apply_offset_to_indices_table(table, offset):
indices_table=indices_table,
fingerprint=fingerprint,
)
cache_dirs = [dset._cache_dir for dset in dsets if dset._cache_dir is not None]
if concatenated_dataset._cache_dir is None and cache_dirs:
concatenated_dataset._cache_dir = cache_dirs[0]
concatenated_dataset.set_format(**format)
return concatenated_dataset

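Taken together, the arrow_dataset.py changes let an in-memory dataset keep track of a cache directory, so map, sort, shuffle and train_test_split can still write and reuse cache-<fingerprint>.arrow files instead of always falling back to temporary ones. A rough usage sketch of the intended behaviour, assuming this PR is applied (the dataset name and column are placeholders):

from datasets import load_dataset

# Load a small dataset fully into RAM; with this PR the builder's cache
# directory is still recorded on the resulting Dataset.
ds = load_dataset("imdb", split="train", keep_in_memory=True)

def add_length(example):
    return {"text_len": len(example["text"])}

# First call: computed and written as cache-<fingerprint>.arrow in the
# recorded cache directory.
ds_mapped = ds.map(add_length)

# An identical second call should hit that cache file and simply reload it.
ds_mapped_again = ds.map(add_length)
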
12 changes: 11 additions & 1 deletion src/datasets/builder.py
@@ -700,7 +700,12 @@ def _make_split_generators_kwargs(self, prepare_split_kwargs):
return {}

def as_dataset(
self, split: Optional[Split] = None, run_post_process=True, ignore_verifications=False, in_memory=False
self,
split: Optional[Split] = None,
run_post_process=True,
ignore_verifications=False,
in_memory=False,
use_caching=True,
) -> Union[Dataset, DatasetDict]:
"""Return a Dataset for the specified split.

@@ -740,6 +745,7 @@ def as_dataset(
run_post_process=run_post_process,
ignore_verifications=ignore_verifications,
in_memory=in_memory,
use_caching=use_caching,
),
split,
map_tuple=True,
@@ -754,6 +760,7 @@ def _build_single_dataset(
run_post_process: bool,
ignore_verifications: bool,
in_memory: bool = False,
use_caching: bool = True,
):
"""as_dataset for a single split."""
verify_infos = not ignore_verifications
@@ -765,6 +772,9 @@
split=split,
in_memory=in_memory,
)
# Enables caching of dataset transforms for in-memory datasets
if in_memory and use_caching:
ds._cache_dir = self._cache_dir
if run_post_process:
for resource_file_name in self._post_processing_resources(split).values():
if os.sep in resource_file_name:
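For completeness, a hedged sketch of how the new use_caching flag would be exercised directly at the builder level, assuming this PR is applied and that load_dataset_builder is available in this version (otherwise the builder class can be instantiated directly):

from datasets import load_dataset_builder

builder = load_dataset_builder("imdb")
builder.download_and_prepare()

# in_memory=True copies the Arrow data into RAM; with use_caching=True (the
# default added here) the builder also hands its cache directory to the
# Dataset, so later transforms can still be cached on disk.
ds = builder.as_dataset(split="train", in_memory=True, use_caching=True)
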
2 changes: 1 addition & 1 deletion src/datasets/info.py
@@ -190,7 +190,7 @@ def _dump_license(self, file):
file.write(self.license.encode("utf-8"))

@classmethod
def from_merge(cls, dataset_infos: List["DatasetInfo"]):
def from_merge(cls, dataset_infos: List["DatasetInfo"]) -> "DatasetInfo":
def unique(values):
seen = set()
for value in values:
5 changes: 4 additions & 1 deletion src/datasets/load.py
@@ -636,6 +636,7 @@ def load_dataset(
save_infos: bool = False,
script_version: Optional[Union[str, Version]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
use_caching: bool = True,
task: Optional[Union[str, TaskTemplate]] = None,
**config_kwargs,
) -> Union[DatasetDict, Dataset]:
@@ -754,7 +755,9 @@
keep_in_memory = (
keep_in_memory if keep_in_memory is not None else is_small_dataset(builder_instance.info.dataset_size)
)
ds = builder_instance.as_dataset(split=split, ignore_verifications=ignore_verifications, in_memory=keep_in_memory)
ds = builder_instance.as_dataset(
split=split, ignore_verifications=ignore_verifications, in_memory=keep_in_memory, use_caching=use_caching
)
# Rename and cast features to match task schema
if task is not None:
ds = ds.prepare_for_task(task)
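At the user-facing level the new argument simply flows from load_dataset through as_dataset. A minimal sketch of the opt-out path, assuming this PR is applied:

from datasets import load_dataset

# keep_in_memory=True loads the data into RAM; use_caching=False tells the
# builder not to attach a cache directory, so transform results land in
# temporary cache files and are not reused across sessions.
ds = load_dataset("imdb", split="train", keep_in_memory=True, use_caching=False)
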
2 changes: 1 addition & 1 deletion src/datasets/utils/info_utils.py
@@ -84,7 +84,7 @@ def get_size_checksum_dict(path: str) -> dict:
return {"num_bytes": os.path.getsize(path), "checksum": m.hexdigest()}


def is_small_dataset(dataset_size):
def is_small_dataset(dataset_size: int) -> bool:
"""Check if `dataset_size` is smaller than `config.MAX_IN_MEMORY_DATASET_SIZE_IN_BYTES`.

Args:
6 changes: 3 additions & 3 deletions tests/test_arrow_dataset.py
@@ -2305,7 +2305,7 @@ def test_dataset_add_item(item, in_memory, dataset_dict, arrow_path, transform):
def test_dataset_from_file(in_memory, dataset, arrow_file):
filename = arrow_file
with assert_arrow_memory_increases() if in_memory else assert_arrow_memory_doesnt_increase():
dataset_from_file = Dataset.from_file(filename, in_memory=in_memory)
dataset_from_file = Dataset.from_file(filename, keep_in_memory=in_memory)
assert dataset_from_file.features.type == dataset.features.type
assert dataset_from_file.features == dataset.features
assert dataset_from_file.cache_files == ([{"filename": filename}] if not in_memory else [])
@@ -2555,8 +2555,8 @@ def test_dataset_to_json(dataset, tmp_path):
)
def test_pickle_dataset_after_transforming_the_table(in_memory, method_and_params, arrow_file):
method, args, kwargs = method_and_params
with Dataset.from_file(arrow_file, in_memory=in_memory) as dataset, Dataset.from_file(
arrow_file, in_memory=in_memory
with Dataset.from_file(arrow_file, keep_in_memory=in_memory) as dataset, Dataset.from_file(
arrow_file, keep_in_memory=in_memory
) as reference_dataset:
out = getattr(dataset, method)(*args, **kwargs)
dataset = out if out is not None else dataset