Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/datasets/arrow_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def make_file_instructions(
dataset_name=name,
split=info.name,
filetype_suffix=filetype_suffix,
shard_lengths=name2shard_lengths[info.name],
num_shards=len(name2shard_lengths[info.name] or ()),
)
for info in split_infos
}
Expand Down
22 changes: 17 additions & 5 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Dict, Iterable, Mapping, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, Union
from unittest.mock import patch

import fsspec
Expand Down Expand Up @@ -63,7 +63,7 @@
from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo
from .iterable_dataset import ArrowExamplesIterable, ExamplesIterable, IterableDataset
from .keyhash import DuplicatedKeysError
from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase
from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase, filenames_for_dataset_split
from .splits import Split, SplitDict, SplitGenerator, SplitInfo
from .streaming import extend_dataset_builder_for_streaming
from .table import CastError
Expand Down Expand Up @@ -744,6 +744,7 @@ def _rename(self, src: str, dst: str):
def download_and_prepare(
self,
output_dir: Optional[str] = None,
splits: Optional[List[str]] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[Union[DownloadMode, str]] = None,
verification_mode: Optional[Union[VerificationMode, str]] = None,
Expand Down Expand Up @@ -928,12 +929,23 @@ def download_and_prepare(
# File locking only with local paths; no file locking on GCS or S3
with FileLock(lock_path) if is_local else contextlib.nullcontext():
# Check if the data already exists
data_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME))
if data_exists and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS:
logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})")
info_exists = self._fs.exists(posixpath.join(self._output_dir, config.DATASET_INFO_FILENAME))
if info_exists:
# We need to update the info in case some splits were added in the meantime
# for example when calling load_dataset from multiple workers.
self.info = self._load_info()
_dataset_name = self.name if self._check_legacy_cache() else self.dataset_name
if splits is not None:
for split in splits:
num_shards = len(self.info.splits[split].shard_lengths or ()) if self.info.splits else 1
_filename = filenames_for_dataset_split(
self._output_dir, _dataset_name, split, filetype_suffix=file_format, num_shards=num_shards
)[0]
if self._fs.exists(_filename):
splits.pop(split) # split is already cached
requested_splits_exist = not splits
if info_exists and requested_splits_exist and download_mode == DownloadMode.REUSE_DATASET_IF_EXISTS:
logger.info(f"Found cached dataset {self.dataset_name} ({self._output_dir})")
self.download_post_processing_resources(dl_manager)
return

Expand Down
15 changes: 7 additions & 8 deletions src/datasets/naming.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Lint as: python3
"""Utilities for file names."""
import itertools
import os
import posixpath
import re


Expand Down Expand Up @@ -45,13 +45,13 @@ def snakecase_to_camelcase(name):


def filename_prefix_for_name(name):
    """Return the snake_case filename prefix for a dataset name.

    Args:
        name: Bare dataset name; it must not contain a path component.

    Returns:
        The value of ``camelcase_to_snakecase(name)``.

    Raises:
        ValueError: If ``name`` is not equal to its own POSIX basename,
            i.e. it looks like a path rather than a plain name.
    """
    # posixpath (rather than os.path) keeps this check identical on every
    # host OS: only "/" is treated as a path separator.
    if posixpath.basename(name) != name:
        raise ValueError(f"Should be a dataset name, not a path: {name}")
    return camelcase_to_snakecase(name)


def filename_prefix_for_split(name, split):
if os.path.basename(name) != name:
if posixpath.basename(name) != name:
raise ValueError(f"Should be a dataset name, not a path: {name}")
if not re.match(_split_re, split):
raise ValueError(f"Split name should match '{_split_re}'' but got '{split}'.")
Expand All @@ -62,16 +62,15 @@ def filepattern_for_dataset_split(dataset_name, split, data_dir, filetype_suffix
prefix = filename_prefix_for_split(dataset_name, split)
if filetype_suffix:
prefix += f".{filetype_suffix}"
filepath = os.path.join(data_dir, prefix)
filepath = posixpath.join(data_dir, prefix)
return f"{filepath}*"


def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, shard_lengths=None):
def filenames_for_dataset_split(path, dataset_name, split, filetype_suffix=None, num_shards=1):
prefix = filename_prefix_for_split(dataset_name, split)
prefix = os.path.join(path, prefix)
prefix = posixpath.join(path, prefix)

if shard_lengths:
num_shards = len(shard_lengths)
if num_shards > 1:
filenames = [f"{prefix}-{shard_id:05d}-of-{num_shards:05d}" for shard_id in range(num_shards)]
if filetype_suffix:
filenames = [filename + f".{filetype_suffix}" for filename in filenames]
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/packaged_modules/cache/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _split_generators(self, dl_manager):
dataset_name=self.dataset_name,
split=split_info.name,
filetype_suffix="arrow",
shard_lengths=split_info.shard_lengths,
num_shards=len(split_info.shard_lengths or ()),
)
},
)
Expand Down
21 changes: 12 additions & 9 deletions tests/test_arrow_reader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import posixpath
import tempfile
from pathlib import Path
from unittest import TestCase
Expand Down Expand Up @@ -103,8 +104,8 @@ def test_read_files(self):
reader = ReaderTest(tmp_dir, info)

files = [
{"filename": os.path.join(tmp_dir, "train")},
{"filename": os.path.join(tmp_dir, "test"), "skip": 10, "take": 10},
{"filename": posixpath.join(tmp_dir, "train")},
{"filename": posixpath.join(tmp_dir, "test"), "skip": 10, "take": 10},
]
dset = Dataset(**reader.read_files(files, original_instructions="train+test[10:20]"))
self.assertEqual(dset.num_rows, 110)
Expand Down Expand Up @@ -169,18 +170,18 @@ def test_make_file_instructions_basic():
assert isinstance(file_instructions, FileInstructions)
assert file_instructions.num_examples == 33
assert file_instructions.file_instructions == [
{"filename": os.path.join(prefix_path, f"{name}-train.arrow"), "skip": 0, "take": 33}
{"filename": posixpath.join(prefix_path, f"{name}-train.arrow"), "skip": 0, "take": 33}
]

split_infos = [SplitInfo(name="train", num_examples=100, shard_lengths=[10] * 10)]
file_instructions = make_file_instructions(name, split_infos, instruction, filetype_suffix, prefix_path)
assert isinstance(file_instructions, FileInstructions)
assert file_instructions.num_examples == 33
assert file_instructions.file_instructions == [
{"filename": os.path.join(prefix_path, f"{name}-train-00000-of-00010.arrow"), "skip": 0, "take": -1},
{"filename": os.path.join(prefix_path, f"{name}-train-00001-of-00010.arrow"), "skip": 0, "take": -1},
{"filename": os.path.join(prefix_path, f"{name}-train-00002-of-00010.arrow"), "skip": 0, "take": -1},
{"filename": os.path.join(prefix_path, f"{name}-train-00003-of-00010.arrow"), "skip": 0, "take": 3},
{"filename": posixpath.join(prefix_path, f"{name}-train-00000-of-00010.arrow"), "skip": 0, "take": -1},
{"filename": posixpath.join(prefix_path, f"{name}-train-00001-of-00010.arrow"), "skip": 0, "take": -1},
{"filename": posixpath.join(prefix_path, f"{name}-train-00002-of-00010.arrow"), "skip": 0, "take": -1},
{"filename": posixpath.join(prefix_path, f"{name}-train-00003-of-00010.arrow"), "skip": 0, "take": 3},
]


Expand Down Expand Up @@ -217,7 +218,7 @@ def test_make_file_instructions(split_name, instruction, shard_lengths, read_ran
if not isinstance(shard_lengths, list):
assert file_instructions.file_instructions == [
{
"filename": os.path.join(prefix_path, f"{name}-{split_name}.arrow"),
"filename": posixpath.join(prefix_path, f"{name}-{split_name}.arrow"),
"skip": read_range[0],
"take": read_range[1] - read_range[0],
}
Expand All @@ -226,7 +227,9 @@ def test_make_file_instructions(split_name, instruction, shard_lengths, read_ran
file_instructions_list = []
shard_offset = 0
for i, shard_length in enumerate(shard_lengths):
filename = os.path.join(prefix_path, f"{name}-{split_name}-{i:05d}-of-{len(shard_lengths):05d}.arrow")
filename = posixpath.join(
prefix_path, f"{name}-{split_name}-{i:05d}-of-{len(shard_lengths):05d}.arrow"
)
if shard_offset <= read_range[0] < shard_offset + shard_length:
file_instructions_list.append(
{
Expand Down