31 changes: 24 additions & 7 deletions src/datasets/arrow_dataset.py
@@ -244,6 +244,7 @@ def __init__(
self._format_columns: Optional[list] = None
self._output_all_columns: bool = False
self._fingerprint: str = fingerprint
+ self._cache_dir: Optional[str] = None

# Read metadata

@@ -257,6 +258,7 @@ def __init__(
self._fingerprint = metadata["fingerprint"]

# Infer features if None

inferred_features = Features.from_arrow_schema(arrow_table.schema)
if self.info.features is None:
self.info.features = inferred_features
@@ -266,6 +268,11 @@ def __init__(
if self._fingerprint is None:
self._fingerprint = generate_fingerprint(self)

+ # Infer cache directory if None

+ if self._cache_dir is None and self.cache_files:
+ self._cache_dir = os.path.dirname(self.cache_files[0]["filename"])

# Sanity checks

assert self.features is not None, "Features can't be None in a Dataset object"
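
The constructor change above comes down to one rule: if the dataset is backed by Arrow cache files on disk, its cache directory is the directory of the first cache file; otherwise `_cache_dir` stays `None` until the builder or `load_from_disk` fills it in. A minimal standalone sketch of that inference, with a made-up path:

import os

# Hypothetical cache_files entry, mimicking what a memory-mapped Dataset exposes.
cache_files = [{"filename": "/home/user/.cache/huggingface/datasets/demo/train.arrow"}]

# Same inference as the added __init__ lines: the cache directory is simply the
# directory containing the first cache file; in-memory datasets (no cache files)
# get None here and rely on the builder / load_from_disk fallbacks further down.
cache_dir = os.path.dirname(cache_files[0]["filename"]) if cache_files else None
print(cache_dir)  # /home/user/.cache/huggingface/datasets/demo
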
@@ -613,6 +620,7 @@ def save_to_disk(self, dataset_path: str, fs=None):
json.dumps(state["_format_kwargs"][k])
except TypeError as e:
raise TypeError(str(e) + f"\nThe format kwargs must be JSON serializable, but key '{k}' isn't.")
state["_cached"] = bool(self._cache_dir)

# Get json serializable dataset info
dataset_info = asdict(self._info)
@@ -696,14 +704,20 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] =
split = state["_split"]
split = NamedSplit(split) if split is not None else split

- return Dataset(
+ dataset = Dataset(
arrow_table=arrow_table,
indices_table=indices_table,
info=dataset_info,
split=split,
fingerprint=state["_fingerprint"],
)

+ # Update cache directory for in-memory datasets created with load_dataset
+ if state["_cached"] and dataset._cache_dir is None:
+ dataset._cache_dir = str(Path(dataset_path).expanduser())

+ return dataset

@property
def data(self) -> Table:
"""The Apache Arrow table backing the dataset."""
@@ -1442,9 +1456,9 @@ def cleanup_cache_files(self) -> int:
return len(files_to_remove)

def _get_cache_file_path(self, fingerprint):
- if is_caching_enabled() and self.cache_files:
+ if is_caching_enabled() and self._cache_dir is not None:
cache_file_name = "cache-" + fingerprint + ".arrow"
- cache_directory = os.path.dirname(self.cache_files[0]["filename"])
+ cache_directory = self._cache_dir
else:
cache_file_name = "cache-" + generate_random_fingerprint() + ".arrow"
cache_directory = get_temporary_cache_files_directory()
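
With `_get_cache_file_path` keyed off `_cache_dir` rather than `cache_files`, transforms on in-memory datasets can be cached persistently whenever a cache directory is known. A usage sketch of the intended effect; the dataset and column names are examples, and it assumes `load_dataset` accepts `keep_in_memory=True`:

from datasets import load_dataset

ds = load_dataset("squad", split="train", keep_in_memory=True)
# ds.cache_files is empty because the Arrow table lives in memory, but the
# builder hands its cache directory over to the dataset (see the builder.py
# change below), so the map below writes cache-<fingerprint>.arrow into that
# directory instead of a throwaway temporary one, and can be reused next run.
ds = ds.map(lambda example: {"context_len": len(example["context"])})
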
@@ -1727,7 +1741,7 @@ def _map_single(
batch_size = self.num_rows

# Check if we've already cached this computation (indexed by a hash)
- if self.cache_files:
+ if self._cache_dir is not None:
if cache_file_name is None:
# we create a unique hash from the function,
# current dataset file and the mapping args
@@ -2229,7 +2243,7 @@ def sort(
)

# Check if we've already cached this computation (indexed by a hash)
- if self.cache_files:
+ if self._cache_dir is not None:
if indices_cache_file_name is None:
# we create a unique hash from the function, current dataset file and the mapping args
indices_cache_file_name = self._get_cache_file_path(new_fingerprint)
@@ -2312,7 +2326,7 @@ def shuffle(
generator = np.random.default_rng(seed)

# Check if we've already cached this computation (indexed by a hash)
- if self.cache_files:
+ if self._cache_dir is not None:
if indices_cache_file_name is None:
# we create a unique hash from the function, current dataset file and the mapping args
indices_cache_file_name = self._get_cache_file_path(new_fingerprint)
@@ -2478,7 +2492,7 @@ def train_test_split(
generator = np.random.default_rng(seed)

# Check if we've already cached this computation (indexed by a hash)
- if self.cache_files:
+ if self._cache_dir is not None:
if train_indices_cache_file_name is None or test_indices_cache_file_name is None:
# we create a unique hash from the function, current dataset file and the mapping args

@@ -3108,6 +3122,9 @@ def apply_offset_to_indices_table(table, offset):
indices_table=indices_table,
fingerprint=fingerprint,
)
+ cache_dirs = [dset._cache_dir for dset in dsets if dset._cache_dir is not None]
+ if concatenated_dataset._cache_dir is None and cache_dirs:
+ concatenated_dataset._cache_dir = cache_dirs[0]
concatenated_dataset.set_format(**format)
return concatenated_dataset
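
Concatenation gets the same treatment: the result of `concatenate_datasets` built from in-memory tables has no cache files of its own, so it borrows the first available `_cache_dir` from its inputs. A hedged sketch; the dataset name and split slices are only examples:

from datasets import concatenate_datasets, load_dataset

a = load_dataset("squad", split="train[:100]", keep_in_memory=True)
b = load_dataset("squad", split="train[100:200]", keep_in_memory=True)

combined = concatenate_datasets([a, b])
# combined exposes no cache files, so its constructor cannot infer a cache
# directory; the lines added above copy the first non-None _cache_dir from the
# inputs, so combined.map(...) and combined.sort(...) stay cached on disk.
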

3 changes: 3 additions & 0 deletions src/datasets/builder.py
@@ -765,6 +765,9 @@ def _build_single_dataset(
split=split,
in_memory=in_memory,
)
+ # Required for in-memory datasets
+ if not ds.cache_files:
+ ds._cache_dir = self._cache_dir
if run_post_process:
for resource_file_name in self._post_processing_resources(split).values():
if os.sep in resource_file_name:
2 changes: 1 addition & 1 deletion src/datasets/info.py
@@ -173,7 +173,7 @@ def _dump_license(self, file):
file.write(self.license.encode("utf-8"))

@classmethod
- def from_merge(cls, dataset_infos: List["DatasetInfo"]):
+ def from_merge(cls, dataset_infos: List["DatasetInfo"]) -> "DatasetInfo":
def unique(values):
seen = set()
for value in values: