Merged

55 commits
00d69e0
add streaming module
lhoestq Apr 14, 2021
587bbb9
make oscar streaming compatible
lhoestq Apr 14, 2021
ed6f7c3
minor
lhoestq Apr 14, 2021
bcd5b92
minor
lhoestq Apr 14, 2021
00a4a57
use right name for format type
lhoestq Apr 14, 2021
45bd319
add shuffle
lhoestq Apr 14, 2021
ca4a4b3
shuffle examples generator's source data files
lhoestq Apr 15, 2021
8abcc73
add transform_batch_size
lhoestq Apr 15, 2021
40e6014
clean shuffling buffer
lhoestq Apr 15, 2021
bd9b940
iterable_dataset factory
lhoestq Apr 15, 2021
7fb47d9
support hub datasets and arrow based builder
lhoestq Apr 20, 2021
c461499
Merge branch 'master' into streaming
lhoestq Apr 21, 2021
ff0702d
add merge_datasets with probabilities
lhoestq Apr 22, 2021
f14b51b
style
lhoestq Apr 22, 2021
eae82d2
add retries if server disconnection occurs
lhoestq Apr 22, 2021
5b7843a
add aiohttp to setup.py
lhoestq Apr 28, 2021
a3df1e4
allow streaming from zip files
lhoestq Apr 28, 2021
97bce1c
replace download_and_extract by simple download when there's no extra…
lhoestq Apr 28, 2021
4b655a9
Merge branch 'master' into streaming
lhoestq Apr 28, 2021
dc5309a
re-organize code
lhoestq May 18, 2021
3008988
Merge branch 'master' into dataset-streaming
lhoestq Jun 2, 2021
8579c71
start tests
lhoestq Jun 2, 2021
080a083
more tests
lhoestq Jun 2, 2021
6172953
more tests
lhoestq Jun 3, 2021
f4b84eb
add `streaming` argument to `load_dataset`
lhoestq Jun 4, 2021
daede36
allow streaming from private repos
lhoestq Jun 7, 2021
e2f26dc
Merge branch 'master' into dataset-streaming
lhoestq Jun 7, 2021
064ab00
Revert "replace download_and_extract by simple download when there's …
lhoestq Jun 7, 2021
bbd1389
fix import
lhoestq Jun 7, 2021
ed174e5
use py int instead of np int
lhoestq Jun 7, 2021
5fd7edb
start documentation
lhoestq Jun 7, 2021
79014f8
add batched parameter, add n_shards
lhoestq Jun 9, 2021
369d238
import from main init
lhoestq Jun 9, 2021
602b985
docs
lhoestq Jun 9, 2021
42da548
Merge branch 'master' into dataset-streaming
lhoestq Jun 9, 2021
ba7bbca
docs
lhoestq Jun 9, 2021
4fa549e
fix docs
lhoestq Jun 9, 2021
4fa1e0f
add missing language codes for oscar using pycountry
lhoestq Jun 9, 2021
6a6e21f
add missing sections in oscar dataset card
lhoestq Jun 9, 2021
39f717a
add gz support + add tests
lhoestq Jun 10, 2021
c9acd44
remove constrains on fsspec and s3fs for py3.6
lhoestq Jun 10, 2021
f5cf3f3
fix test
lhoestq Jun 11, 2021
c1a63bf
fix test on windows
lhoestq Jun 11, 2021
1bf093f
style
lhoestq Jun 11, 2021
20aba4d
rename to interleave_datasets + comments
lhoestq Jun 17, 2021
9c1a2e1
Merge branch 'master' into dataset-streaming
lhoestq Jun 18, 2021
b180bc8
lewis' comments
lhoestq Jun 18, 2021
a657d03
typing in gzip
lhoestq Jun 18, 2021
cd23946
move interleave_datasets in combine.py
lhoestq Jun 18, 2021
7f19f4a
add pretty_name to OSCAR
lhoestq Jun 18, 2021
5ab438c
docs
lhoestq Jun 21, 2021
8f68a43
Merge branch 'master' into dataset-streaming
lhoestq Jun 21, 2021
0ee60af
Update src/datasets/combine.py
lhoestq Jun 23, 2021
ed9569a
fix docstring
lhoestq Jun 23, 2021
593e229
docstrings again
lhoestq Jun 23, 2021
15 changes: 15 additions & 0 deletions datasets/oscar/README.md
@@ -1,4 +1,5 @@
---
pretty_name: OSCAR
annotations_creators:
- no-annotation
language_creators:
@@ -6329,16 +6330,30 @@ Filtering and cleaning processes at line level are done before feeding each line

### Source Data

#### Initial Data Collection and Normalization

[Common Crawl](https://commoncrawl.org/) is a non-profit foundation which produces and maintains an open repository of web crawled data that is both accessible and analysable. Common Crawl's complete web archive consists of petabytes of data collected over 8 years of web crawling. The repository contains raw web page HTML data (WARC files), metadata extracts (WAT files) and plain text extracts (WET files). The organisation's crawlers have always respected [nofollow](http://microformats.org/wiki/rel-nofollow) and [robots.txt](https://www.robotstxt.org/) policies.

Each monthly Common Crawl snapshot is in itself a massive multilingual corpus, where every single file contains data coming from multiple web pages written in a large variety of languages and covering all possible types of topics.

To construct OSCAR, the WET files of Common Crawl were used. These contain the extracted plain texts from the websites, mostly converted to UTF-8, as well as headers containing the metadata of each crawled document. Each WET file comes compressed in gzip format and is stored on Amazon Web Services. In the case of OSCAR, the **November 2018** snapshot was used. It surpasses 20TB of uncompressed data and contains more than 50 thousand plain text files, where each file consists of the plain text from multiple websites along with its metadata header.

#### Who are the source language producers?

The data comes from multiple web pages in a large variety of languages.

### Annotations

The dataset does not contain any additional annotations.

#### Annotation process

N/A

#### Who are the annotators?

N/A

### Personal and Sensitive Information

Since the dataset is constructed from Common Crawl, personal and sensitive information might be present. This **must** be considered before training deep learning models with OSCAR, especially in the case of text-generation models.
2 changes: 1 addition & 1 deletion datasets/oscar/oscar.py
@@ -355,7 +355,7 @@ def _generate_examples(self, filepaths):
current_lines = []
for filepath in filepaths:
logger.info("generating examples from = %s", filepath)
with gzip.open(filepath, "rt", encoding="utf-8") as f:
with gzip.open(open(filepath, "rb"), "rt", encoding="utf-8") as f:
for line in f:
if len(line.strip()) > 0:
current_lines.append(line)
138 changes: 138 additions & 0 deletions docs/source/dataset_streaming.rst
@@ -0,0 +1,138 @@
Load a Dataset in Streaming mode
==============================================================

When a dataset is in streaming mode, you can iterate over it directly without having to download the entire dataset.
The data are downloaded progressively as you iterate over the dataset.
You can enable dataset streaming by passing ``streaming=True`` in the :func:`load_dataset` function to get an iterable dataset.

This is useful if you don't have enough space on your disk to download the dataset, or if you don't want to wait for your dataset to be downloaded before using it.

Here is a demonstration:

.. code-block::

>>> from datasets import load_dataset
>>> dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
>>> print(next(iter(dataset)))
{'text': 'Mtendere Village was inspired by the vision of Chief Napoleon Dzombe, which he shared with John Blanchard during his first visit to Malawi. Chief Napoleon conveyed the desperate need for a program to intervene and care for the orphans and vulnerable children (OVC) in Malawi, and John committed to help...

Even though the dataset is 1.2 terabytes, you can start using it right away. Under the hood, only the first examples of the dataset are downloaded for buffering before the first example is returned.

.. note::

    The dataset that is returned is a :class:`datasets.IterableDataset`, not the classic map-style :class:`datasets.Dataset`. To get examples from an iterable dataset, you have to iterate over it, for example with a ``for`` loop. To get the very last example of the dataset, you first have to iterate over all the previous examples.

Member
is it possible to create a classic Dataset from an IterableDataset?

one application that i have in mind is picking the first N examples of a huge dataset, collecting them in a standard Dataset and then doing all my exploration / preprocessing / task preparation etc on that dataset.

e.g. something like

from datasets import load_dataset 

dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
# create a `Dataset` i can play with?
sample = dataset.select(range(100))

Member Author

Sure definitely :)

I was thinking of adding something like this in a next PR.
Maybe IterableDataset.to_map_style_dataset() ?
To get only the first examples we can also add IterableDataset.take(n)
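
In the meantime, a minimal workaround sketch for collecting the first N streamed examples into a map-style dataset (this assumes only the existing Dataset.from_dict API, since take() / to_map_style_dataset() don't exist yet):

from itertools import islice
from datasets import Dataset, load_dataset

streaming_dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
# materialize the first 100 streamed examples
head = list(islice(iter(streaming_dataset), 100))
# turn the list of example dicts into a columnar dict and build a map-style Dataset
sample = Dataset.from_dict({key: [example[key] for example in head] for key in head[0]})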

Member

yeah, both those features would be great!

Member

I like the IterableDataset.take(n) as well. Could we also have a IterableDataset.sample(n) taking a random sample?

Member

a random sample would be very neat as well. here we might want to use something like reservoir sampling to deal with unbounded streams: https://en.wikipedia.org/wiki/Reservoir_sampling
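
For reference, a minimal sketch of reservoir sampling (Algorithm R) over an unbounded iterator could look like this (hypothetical helper, not library code):

import random

def reservoir_sample(iterable, k, seed=None):
    rng = random.Random(seed)
    reservoir = []
    for i, example in enumerate(iterable):
        if i < k:
            # fill the reservoir with the first k examples
            reservoir.append(example)
        else:
            # keep each later example with probability k / (i + 1)
            j = rng.randint(0, i)
            if j < k:
                reservoir[j] = example
    return reservoir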

Member Author

As soon as we have .take(), you can do iterable_dataset.shuffle(buffer_size=buffer_size, seed=seed).take(n) to take random samples.
This could be simplified by adding a .sample() method indeed

Contributor

How about slicing support i.e iterable_dataset[100:200] to get an iterator or Dataset at a particular slice?

Member Author

I would like to avoid allowing users to get items using __getitem__ since it's not a map-style dataset.
So I agree it would be nice to get a slice of the data, but with a different API. Maybe something like

sliced_dataset = iterable_dataset.skip(100).take(100)

What do you think ?

This is pretty close to the tf.data.Dataset API, which is also an iterable dataset.

Contributor

Yes, very cool @lhoestq

    Iterable datasets are therefore mostly useful for iterative jobs like training a model, but not for jobs that require random access to examples.


Shuffling the dataset: ``shuffle``
--------------------------------------------------

Shuffle the dataset
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To shuffle your dataset, the :func:`datasets.IterableDataset.shuffle` method fills a buffer of size ``buffer_size`` and randomly samples examples from this buffer.
The selected examples in the buffer are replaced by new examples.

For instance, if your dataset contains 1,000,000 examples but ``buffer_size`` is set to 1,000, then ``shuffle`` will initially select examples at random from only the first 1,000 examples in the buffer.
Once an example is selected, its spot in the buffer is filled by the next (i.e. the 1,001st) example, maintaining the 1,000-example buffer.

.. note::
    For perfect shuffling, you need to set ``buffer_size`` to be greater than or equal to the size of your dataset, but in that case the full dataset will be downloaded into the buffer.

Moreover, for larger datasets that are sharded into multiple files, :func:`datasets.IterableDataset.shuffle` also shuffles the order of the shards.

.. code-block::

>>> shuffled_dataset = dataset.shuffle(buffer_size=10_000, seed=42)
>>> print(next(iter(shuffled_dataset)))
{'text': 'In this role, she oversees the day-to-day operations of the agency’s motoring services divisions (Vehicle Titles & Registration, Motor Vehicles, Motor Carrier, Enforcement, Consumer Relations and the Automobile Burglary & Theft Prevention Authority) to ensure they are constantly improving and identifying opportunities to become more efficient and effective in service delivery...
>>> print(dataset.n_shards)
670

In this example, the shuffle buffer contains 10,000 examples that were downloaded from one random shard of the dataset (here they actually come from the 480th shard out of 670).
The example was selected at random from this buffer and replaced by the 10,001st example of the dataset shard.
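
To make the buffer mechanism more concrete, here is a minimal, hypothetical sketch of buffer-based shuffling over a plain Python iterable (an illustration, not the library's internal implementation):

.. code-block::

import random

def shuffle_with_buffer(iterable, buffer_size, seed=None):
    rng = random.Random(seed)
    buffer = []
    for example in iterable:
        if len(buffer) < buffer_size:
            # fill the buffer with the first `buffer_size` examples
            buffer.append(example)
        else:
            # yield a random buffered example and replace it with the incoming one
            index = rng.randrange(buffer_size)
            yield buffer[index]
            buffer[index] = example
    # flush whatever remains in the buffer at the end of the stream
    rng.shuffle(buffer)
    yield from buffer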

Reshuffle the dataset at each epoch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The seed used to shuffle the dataset is the one you specify in :func:`datasets.IterableDataset.shuffle`. However, you often want to reshuffle the dataset with a different seed at each epoch.
To do so, between epochs you simply tell the dataset which epoch you're at, and the data will be shuffled using an effective seed of ``seed + epoch``.

For example, your training loop can look like this:

.. code-block::

>>> for epoch in range(epochs):
... shuffled_dataset.set_epoch(epoch)
... for example in shuffled_dataset:
... ...

In this case, the dataset is shuffled with ``seed + 0`` in the first epoch and with ``seed + 1`` in the second epoch, so it is reshuffled at every epoch. This randomizes both the shuffle buffer and the order of the shards.


Processing data with ``map``
--------------------------------------------------

As with :class:`datasets.Dataset` objects, you can process your data using ``map``. This is useful if you want to transform the data or rename/remove columns.
Since the examples of a :class:`datasets.IterableDataset` are downloaded progressively, the :func:`datasets.IterableDataset.map` method processes the examples on the fly as you iterate over the dataset (unlike :func:`datasets.Dataset.map`, which processes all the examples immediately).

This example shows how to tokenize your dataset:

.. code-block::

>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
>>> tokenized_dataset = dataset.map(lambda x: tokenizer(x["text"]))
>>> print(next(iter(tokenized_dataset)))
{'input_ids': [101, 11047, 10497, 7869, 2352...], 'token_type_ids': [0, 0, 0, 0, 0...], 'attention_mask': [1, 1, 1, 1, 1...]}

Tokenizers are written in Rust and use parallelism to speed up tokenization. To leverage parallelism, you can process the examples batch by batch. Note that the output examples are still returned one by one.

.. code-block::

>>> tokenized_dataset = dataset.map(lambda x: tokenizer(x["text"]), batched=True)  # default batch_size is 1000 but you can specify another batch_size if needed
>>> print(next(iter(tokenized_dataset)))
{'input_ids': [101, 11047, 10497, 7869, 2352...], 'token_type_ids': [0, 0, 0, 0, 0...], 'attention_mask': [1, 1, 1, 1, 1...]}
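
To illustrate how batched, on-the-fly processing can work over an iterator, here is a hypothetical sketch (not the library's actual implementation):

.. code-block::

from itertools import islice

def lazy_map_batched(iterable, function, batch_size=1000):
    iterator = iter(iterable)
    while True:
        examples = list(islice(iterator, batch_size))
        if not examples:
            return
        # group the example dicts into a columnar batch, e.g. {"text": [...]}
        batch = {key: [example[key] for example in examples] for key in examples[0]}
        outputs = function(batch)
        # yield the processed examples back one by one
        for i in range(len(examples)):
            yield {key: values[i] for key, values in outputs.items()}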


Mix several iterable datasets together with ``interleave_datasets``
----------------------------------------------------------------------------------------------------

It is common to use several datasets to train a model. For example, BERT was trained on a mix of Wikipedia and BookCorpus.
You can mix several iterable datasets together using :func:`datasets.interleave_datasets`.

By default, the resulting dataset alternates between the original datasets, but you can also define sampling probabilities to sample randomly from the different datasets.
Member

wow, this is very cool!


For example, if you want a dataset in several languages:

.. code-block::

>>> from datasets import interleave_datasets
>>> from itertools import islice
>>> en_dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
>>> fr_dataset = load_dataset('oscar', "unshuffled_deduplicated_fr", split='train', streaming=True)
>>>
>>> multilingual_dataset = interleave_datasets([en_dataset, fr_dataset])
>>> print(list(islice(multilingual_dataset, 2)))
[{'text': 'Mtendere Village was inspired by the vision...}, {'text': "Média de débat d'idées, de culture et de littérature....}]
>>>
>>> multilingual_dataset_with_oversampling = interleave_datasets([en_dataset, fr_dataset], probabilities=[0.8, 0.2], seed=42)
>>> print(list(islice(multilingual_dataset_with_oversampling, 2)))
[{'text': 'Mtendere Village was inspired by the vision...}, {'text': 'Lily James cannot fight the music...}]
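
To illustrate the two interleaving strategies, here is a minimal, hypothetical sketch over plain Python iterators (not the library's implementation):

.. code-block::

import random

def interleave(iterables, probabilities=None, seed=None):
    rng = random.Random(seed)
    iterators = [iter(iterable) for iterable in iterables]
    indices = list(range(len(iterators)))
    while True:
        if probabilities is None:
            # alternate: take one example from each source in turn
            next_sources = indices
        else:
            # sample the next source at random according to the given probabilities
            next_sources = rng.choices(indices, weights=probabilities, k=1)
        for source in next_sources:
            try:
                yield next(iterators[source])
            except StopIteration:
                # stop as soon as one of the sources is exhausted
                return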


Working with NumPy, pandas, PyTorch and TensorFlow
--------------------------------------------------

This part is still experimental and breaking changes may happen in the near future.

It is possible to get a ``torch.utils.data.IterableDataset`` from a :class:`datasets.IterableDataset` by setting the dataset format to "torch", as you would for a :class:`datasets.Dataset`:

.. code-block::

>>> import torch
>>> tokenized_dataset = dataset.map(lambda x: tokenizer(x["text"], return_tensors="pt"))
>>> torch_tokenized_dataset = tokenized_dataset.with_format("torch")
>>> assert isinstance(torch_tokenized_dataset, torch.utils.data.IterableDataset)
>>> print(next(iter(torch_tokenized_dataset)))
{'input_ids': tensor([[101, 11047, 10497, 7869, 2352...]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0...]]), 'attention_mask': tensor([[1, 1, 1, 1, 1...]])}
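
The formatted dataset can then be consumed like any other ``torch.utils.data.IterableDataset``, for example with a ``DataLoader`` (here ``batch_size=None`` keeps the examples as they are yielded, since variable-length examples would otherwise need a collate function):

.. code-block::

>>> from torch.utils.data import DataLoader
>>> dataloader = DataLoader(torch_tokenized_dataset, batch_size=None)
>>> for example in dataloader:
...     ...  # your training step here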

For now, only the PyTorch format is supported, but support for TensorFlow and others will be added soon.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -45,6 +45,7 @@ The documentation is organized in five parts:
torch_tensorflow
filesystems
faiss_and_ea
dataset_streaming

.. toctree::
:maxdepth: 2
22 changes: 22 additions & 0 deletions docs/source/package_reference/main_classes.rst
@@ -60,6 +60,28 @@ It also has dataset transform methods like map or filter, to process all the spl
prepare_for_task, align_labels_with_mapping


``IterableDataset``
~~~~~~~~~~~~~~~~~~~~~

The base class :class:`datasets.IterableDataset` implements an iterable dataset backed by Python generators.

.. autoclass:: datasets.IterableDataset
:members:
__iter__,
map, shuffle,
info, split, builder_name, citation, config_name, dataset_size,
description, download_checksums, download_size, features, homepage,
license, size_in_bytes, supervised_keys, version,


``IterableDatasetDict``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Dictionary with split names as keys ('train', 'test' for example), and :obj:`datasets.IterableDataset` objects as values.

.. autoclass:: datasets.IterableDatasetDict


``Features``
~~~~~~~~~~~~~~~~~~~~~

25 changes: 8 additions & 17 deletions setup.py
@@ -58,17 +58,6 @@
DOCLINES = __doc__.split("\n")


# Pin some dependencies for old python versions
_deps = {
"fsspec": "fsspec"
if sys.version_info >= (3, 7)
else "fsspec<0.8.1", # fsspec>=0.8.1 requires py>=3.7 for async stuff
"s3fs": "s3fs"
if sys.version_info >= (3, 7)
else "s3fs==0.4.2", # later versions of s3fs have issues downloading directories recursively for py36
}


REQUIRED_PKGS = [
# We use numpy>=1.17 to have np.random.Generator (Dataset shuffling)
"numpy>=1.17",
@@ -93,7 +82,8 @@
# to get metadata of optional dependencies such as torch or tensorflow for Python versions that don't have it
"importlib_metadata;python_version<'3.8'",
# to save datasets locally or on any filesystem
_deps["fsspec"],
# minimum 2021.05.0 to have the AbstractArchiveFileSystem
"fsspec>=2021.05.0",
# To get datasets from the Datasets Hub on huggingface.co
"huggingface_hub<0.1.0",
# Utilities from PyPA to e.g., compare versions
@@ -122,7 +112,7 @@
"fsspec[s3]",
"moto[s3,server]==2.0.4",
"rarfile>=4.0",
_deps["s3fs"],
"s3fs",
"tensorflow>=2.3",
"torch",
"transformers",
@@ -182,11 +172,12 @@
"tensorflow_gpu": ["tensorflow-gpu>=2.2.0"],
"torch": ["torch"],
"s3": [
_deps["fsspec"],
"fsspec",
"boto3==1.16.43",
"botocore==1.19.52",
_deps["s3fs"],
"s3fs",
],
"streaming": ["aiohttp"],
"dev": TESTS_REQUIRE + QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
"quality": QUALITY_REQUIRE,
@@ -199,8 +190,8 @@
"sphinx-rtd-theme==0.4.3",
"sphinxext-opengraph==0.4.1",
"sphinx-copybutton",
_deps["fsspec"],
_deps["s3fs"],
"fsspec",
"s3fs",
],
}

4 changes: 3 additions & 1 deletion src/datasets/__init__.py
@@ -34,7 +34,8 @@
from .arrow_reader import ArrowReader, ReadInstruction
from .arrow_writer import ArrowWriter
from .builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
from .dataset_dict import DatasetDict
from .combine import interleave_datasets
from .dataset_dict import DatasetDict, IterableDatasetDict
from .features import (
Array2D,
Array3D,
@@ -57,6 +58,7 @@
list_datasets,
list_metrics,
)
from .iterable_dataset import IterableDataset
from .keyhash import KeyHasher
from .load import import_main_class, load_dataset, load_from_disk, load_metric, prepare_module
from .metric import Metric