Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
a802ba5
multiprocessing-compatible naming scheme and refactor
TevenLeScao Oct 11, 2022
ea56329
multiprocessed shard writing for GeneratorBasedBuilder
TevenLeScao Oct 12, 2022
9536184
multiprocessed shard writing for ArrowBasedBuilder
TevenLeScao Oct 12, 2022
31d8395
style
TevenLeScao Oct 12, 2022
9c5843a
multiprocessed dataset loading
TevenLeScao Oct 15, 2022
328112e
compatibility with non-sharded datasets
TevenLeScao Oct 15, 2022
9dc8539
bugfix
TevenLeScao Oct 17, 2022
21a603a
bugfix
TevenLeScao Oct 17, 2022
55cb365
Merge remote-tracking branch 'origin/multiprocessed_dataset_prep' int…
TevenLeScao Oct 19, 2022
94efbdb
removed unused import
TevenLeScao Oct 19, 2022
bac2b2f
fixed bad ordering
TevenLeScao Oct 19, 2022
3e4f337
less misleading tqdm
TevenLeScao Oct 19, 2022
b2f634d
fix gen_kwargs distribution + read shards
lhoestq Oct 20, 2022
296302f
minor
lhoestq Oct 20, 2022
9b312d4
minor2
lhoestq Oct 20, 2022
d2e70f2
support beam datasets
lhoestq Oct 21, 2022
e3a30fa
docstrings + minor
lhoestq Oct 25, 2022
cf6fd25
add iflatmap_unordered for parallel write & progress updates
lhoestq Oct 26, 2022
3e5d0cc
use 1 tqdm bar receiving updates from subprocesses
lhoestq Oct 26, 2022
09c13a7
docs
lhoestq Oct 26, 2022
a2e83d5
add test_iflatmap_unordered
lhoestq Oct 27, 2022
e3bc7a7
style
lhoestq Oct 27, 2022
e8923e2
test arrow_reader.py
lhoestq Oct 27, 2022
ef9c7f1
fix test_iflatmap_unordered
lhoestq Oct 28, 2022
088dbb1
add Beam test_download_and_prepare_sharded
lhoestq Oct 28, 2022
eb1fc58
test gen_kwargs distribution
lhoestq Oct 28, 2022
e035339
test download_and_prepare with num_proc
lhoestq Oct 28, 2022
06c5d33
Merge branch 'main' into multiprocessed_dataset_prep
lhoestq Oct 28, 2022
e50ec74
style
lhoestq Oct 28, 2022
525c829
improve test
lhoestq Nov 2, 2022
eae6491
don't close the pool
lhoestq Nov 2, 2022
93f355d
Merge branch 'main' into multiprocessed_dataset_prep
lhoestq Nov 2, 2022
b321c61
fix multiprocessing on windows
lhoestq Nov 2, 2022
b05e551
keep multiprocessing disabled by default
lhoestq Nov 2, 2022
020eb89
again + docs
lhoestq Nov 2, 2022
142f822
more docs
lhoestq Nov 2, 2022
f22c162
more docs
lhoestq Nov 2, 2022
08b8626
Merge remote-tracking branch 'upstream/main' into multiprocessed_data…
lhoestq Nov 3, 2022
4ce2d12
some var renaming
lhoestq Nov 3, 2022
e05ad83
style
lhoestq Nov 3, 2022
c621cb6
Apply suggestions from code review
lhoestq Nov 8, 2022
22d965e
Apply suggestions from code review
lhoestq Nov 8, 2022
dc0ef15
added utils/sharding.py
lhoestq Nov 8, 2022
95cdd0b
Merge remote-tracking branch 'upstream/main' into multiprocessed_data…
lhoestq Nov 8, 2022
12d69f3
style
lhoestq Nov 8, 2022
db45b3b
style
lhoestq Nov 8, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions docs/source/dataset_script.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,59 @@ Congratulations, you can now load your dataset from the Hub! 🥳
>>> from datasets import load_dataset
>>> load_dataset("<username>/my_dataset")
```

## Advanced features

### Sharding

If your dataset is made of many big files, 🤗 Datasets automatically runs your script in parallel to make it super fast!
It can help if you have hundreds or thousands of TAR archives, or JSONL files like [oscar](https://huggingface.co/datasets/oscar/blob/main/oscar.py) for example.

To make it work, we consider lists of files in `gen_kwargs` to be shards.
Therefore 🤗 Datasets can automatically spawn several workers to run `_generate_examples` in parallel, and each worker is given a subset of shards to process.


```python

class MyShardedDataset(datasets.GeneratorBasedBuilder):

def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
downloaded_files = dl_manager.download([f"data/shard_{i}.jsonl" for i in range(1024)])
return [
datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepaths": downloaded_files}),
]

def _generate_examples(self, filepaths):
# Each worker can be given a slice of the original `filepaths` list defined in the `gen_kwargs`
# so that this code can run in parallel on several shards at the same time
for filepath in filepaths:
...
```

Users can also specify `num_proc=` in `load_dataset()` to specify the number of processes to use as workers.

### ArrowBasedBuilder

For some datasets it can be much faster to yield batches of data rather than examples one by one.
You can speed up the dataset generation by yielding Arrow tables directly, instead of examples.
This is especially useful if your data comes from Pandas DataFrames for example, since the conversion from Pandas to Arrow is as simple as:

```python
import pyarrow as pa
pa_table = pa.Table.from_pandas(df)
```

To yield Arrow tables instead of single examples, make your dataset builder inherit from [`ArrowBasedBuilder`] instead of [`GeneratorBasedBuilder`], and use `_generate_tables` instead of `_generate_examples`:

```python
class MySuperFastDataset(datasets.ArrowBasedBuilder):

def _generate_tables(self, filepaths):
idx = 0
for filepath in filepaths:
...
yield idx, pa_table
idx += 1
```

Don't forget to keep your script memory efficient, in case users run them on machines with a low amount of RAM.
15 changes: 15 additions & 0 deletions docs/source/loading.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,21 @@ You can specify [`Dataset.from_sql#con`] as a [URI string](https://docs.sqlalche

</Tip>

## Multiprocessing

When a dataset is made of several files (that we call "shards"), it is possible to significantly speed up the dataset downloading and preparation step.

You can choose how many processes you'd like to use to prepare a dataset in parallel using `num_proc`.
In this case, each process is given a subset of shards to prepare:

```python
from datasets import load_dataset

oscar_afrikaans = load_dataset("oscar-corpus/OSCAR-2201", "af", num_proc=8)
imagenet = load_dataset("imagenet-1k", num_proc=8)
ml_librispeech_spanish = load_dataset("facebook/multilingual_librispeech", "spanish", num_proc=8)
```

## In-memory data

🤗 Datasets will also allow you to create a [`Dataset`] directly from in-memory data structures like Python dictionaries and Pandas DataFrames.
Expand Down
63 changes: 44 additions & 19 deletions src/datasets/arrow_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@
import pyarrow.parquet as pq

from .download.download_config import DownloadConfig
from .naming import _split_re, filename_for_dataset_split
from .naming import _split_re, filenames_for_dataset_split
from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables
from .utils import logging
from .utils.file_utils import cached_path


if TYPE_CHECKING:
from .info import DatasetInfo # noqa: F401
from .splits import Split # noqa: F401
from .splits import Split, SplitInfo # noqa: F401


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -88,7 +88,13 @@ class FileInstructions:
file_instructions: List[dict]


def make_file_instructions(name, split_infos, instruction, filetype_suffix=None):
def make_file_instructions(
name: str,
split_infos: List["SplitInfo"],
instruction: Union[str, "ReadInstruction"],
filetype_suffix: Optional[str] = None,
prefix_path: Optional[str] = None,
):
"""Returns instructions of the split dict.

Args:
Expand All @@ -101,31 +107,48 @@ def make_file_instructions(name, split_infos, instruction, filetype_suffix=None)
file_intructions: FileInstructions instance
"""
name2len = {info.name: info.num_examples for info in split_infos}
name2shard_lengths = {info.name: info.shard_lengths for info in split_infos}
name2filenames = {
info.name: filenames_for_dataset_split(
path=prefix_path,
dataset_name=name,
split=info.name,
filetype_suffix=filetype_suffix,
shard_lengths=name2shard_lengths[info.name],
)
for info in split_infos
}
if not isinstance(instruction, ReadInstruction):
instruction = ReadInstruction.from_spec(instruction)
# Create the absolute instruction (per split)
absolute_instructions = instruction.to_absolute(name2len)

return _make_file_instructions_from_absolutes(
name=name, name2len=name2len, absolute_instructions=absolute_instructions, filetype_suffix=filetype_suffix
)


def _make_file_instructions_from_absolutes(name, name2len, absolute_instructions, filetype_suffix=None):
"""Returns the files instructions from the absolute instructions list."""
# For each split, return the files instruction (skip/take)
file_instructions = []
num_examples = 0
for abs_instr in absolute_instructions:
length = name2len[abs_instr.splitname]
filename = filename_for_dataset_split(
dataset_name=name, split=abs_instr.splitname, filetype_suffix=filetype_suffix
)
split_length = name2len[abs_instr.splitname]
filenames = name2filenames[abs_instr.splitname]
shard_lengths = name2shard_lengths[abs_instr.splitname]
from_ = 0 if abs_instr.from_ is None else abs_instr.from_
to = length if abs_instr.to is None else abs_instr.to
num_examples += to - from_
single_file_instructions = [{"filename": filename, "skip": from_, "take": to - from_}]
file_instructions.extend(single_file_instructions)
to = split_length if abs_instr.to is None else abs_instr.to
if shard_lengths is None: # not sharded
for filename in filenames:
num_examples += to - from_
file_instructions.append({"filename": filename, "skip": from_, "take": to - from_})
else: # sharded
index_start = 0 # Beginning (included) of moving window.
index_end = 0 # End (excluded) of moving window.
for filename, shard_length in zip(filenames, shard_lengths):
index_end += shard_length
if from_ < index_end and to > index_start: # There is something to take.
skip = from_ - index_start if from_ > index_start else 0
take = to - index_start - skip if to < index_end else -1
if take == 0:
continue
file_instructions.append({"filename": filename, "skip": skip, "take": take})
num_examples += shard_length - skip if take == -1 else take
index_start += shard_length
return FileInstructions(
num_examples=num_examples,
file_instructions=file_instructions,
Expand Down Expand Up @@ -182,7 +205,7 @@ def _read_files(self, files, in_memory=False) -> Table:
def get_file_instructions(self, name, instruction, split_infos):
"""Return list of dict {'filename': str, 'skip': int, 'take': int}"""
file_instructions = make_file_instructions(
name, split_infos, instruction, filetype_suffix=self._filetype_suffix
name, split_infos, instruction, filetype_suffix=self._filetype_suffix, prefix_path=self._path
)
files = file_instructions.file_instructions
return files
Expand Down Expand Up @@ -304,6 +327,8 @@ def _get_table_from_filename(self, filename_skip_take, in_memory=False) -> Table
filename_skip_take["take"] if "take" in filename_skip_take else None,
)
table = ArrowReader.read_table(filename, in_memory=in_memory)
if take == -1:
take = len(table) - skip
# here we don't want to slice an empty table, or it may segfault
if skip is not None and take is not None and not (skip == 0 and take == len(table)):
table = table.slice(skip, take)
Expand Down
74 changes: 44 additions & 30 deletions src/datasets/arrow_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,19 +636,33 @@ def finalize(self, metrics_query_result: dict):

from .utils import beam_utils

shards_metadata = [
metadata
for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[0].metadata_list
]
shards = [metadata.path for metadata in shards_metadata]
num_bytes = sum([metadata.size_in_bytes for metadata in shards_metadata])
shard_lengths = get_parquet_lengths(shards)

# Convert to arrow
if self._path.endswith(".arrow"):
logger.info(f"Converting parquet file {self._parquet_path} to arrow {self._path}")
logger.info(f"Converting parquet files {self._parquet_path} to arrow {self._path}")
shards = [
metadata.path
for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[
0
].metadata_list
]
try: # stream conversion
sources = [beam.io.filesystems.FileSystems.open(shard) for shard in shards]
with beam.io.filesystems.FileSystems.create(self._path) as dest:
parquet_to_arrow(sources, dest)
disable = not logging.is_progress_bar_enabled()
num_bytes = 0
for shard in logging.tqdm(shards, unit="shards", disable=disable):
with beam.io.filesystems.FileSystems.open(shard) as source:
with beam.io.filesystems.FileSystems.create(
shard.replace(".parquet", ".arrow")
) as destination:
shard_num_bytes, _ = parquet_to_arrow(source, destination)
num_bytes += shard_num_bytes
except OSError as e: # broken pipe can happen if the connection is unstable, do local conversion instead
if e.errno != errno.EPIPE: # not a broken pipe
raise
Expand All @@ -657,41 +671,41 @@ def finalize(self, metrics_query_result: dict):
)
local_convert_dir = os.path.join(self._cache_dir, "beam_convert")
os.makedirs(local_convert_dir, exist_ok=True)
local_arrow_path = os.path.join(local_convert_dir, hash_url_to_filename(self._parquet_path) + ".arrow")
local_shards = []
for shard in shards:
disable = not logging.is_progress_bar_enabled()
num_bytes = 0
for shard in logging.tqdm(shards, unit="shards", disable=disable):
local_parquet_path = os.path.join(local_convert_dir, hash_url_to_filename(shard) + ".parquet")
local_shards.append(local_parquet_path)
beam_utils.download_remote_to_local(shard, local_parquet_path)
parquet_to_arrow(local_shards, local_arrow_path)
beam_utils.upload_local_to_remote(local_arrow_path, self._path)
output_file_metadata = beam.io.filesystems.FileSystems.match([self._path], limits=[1])[0].metadata_list[0]
num_bytes = output_file_metadata.size_in_bytes
else:
num_bytes = sum(
[
metadata.size_in_bytes
for metadata in beam.io.filesystems.FileSystems.match([self._parquet_path + "*.parquet"])[
0
].metadata_list
]
)
local_arrow_path = local_parquet_path.replace(".parquet", ".arrow")
shard_num_bytes, _ = parquet_to_arrow(local_parquet_path, local_arrow_path)
num_bytes += shard_num_bytes
remote_arrow_path = shard.replace(".parquet", ".arrow")
beam_utils.upload_local_to_remote(local_arrow_path, remote_arrow_path)

# Save metrics
counters_dict = {metric.key.metric.name: metric.result for metric in metrics_query_result["counters"]}
self._num_examples = counters_dict["num_examples"]
self._num_bytes = num_bytes
self._shard_lengths = shard_lengths
return self._num_examples, self._num_bytes


def parquet_to_arrow(sources, destination):
"""Convert parquet files to arrow file. Inputs can be str paths or file-like objects"""
stream = None if isinstance(destination, str) else destination
def get_parquet_lengths(sources) -> List[int]:
shard_lengths = []
disable = not logging.is_progress_bar_enabled()
for source in logging.tqdm(sources, unit="parquet files", disable=disable):
parquet_file = pa.parquet.ParquetFile(source)
shard_lengths.append(parquet_file.metadata.num_rows)
return shard_lengths


def parquet_to_arrow(source, destination) -> List[int]:
"""Convert parquet file to arrow file. Inputs can be str paths or file-like objects"""
stream = None if isinstance(destination, str) else destination
with ArrowWriter(path=destination, stream=stream) as writer:
for source in logging.tqdm(sources, unit="sources", disable=disable):
parquet_file = pa.parquet.ParquetFile(source)
for record_batch in parquet_file.iter_batches():
pa_table = pa.Table.from_batches([record_batch])
writer.write_table(pa_table)
return destination
parquet_file = pa.parquet.ParquetFile(source)
for record_batch in parquet_file.iter_batches():
pa_table = pa.Table.from_batches([record_batch])
writer.write_table(pa_table)
num_bytes, num_examples = writer.finalize()
return num_bytes, num_examples
Loading