Conversation

@lhoestq (Member) commented on May 18, 2021

Dataset Streaming

API

The current API is:

from datasets import load_dataset

# Load an IterableDataset without downloading data
snli = load_dataset("snli", streaming=True)

# Access examples by streaming data
print(next(iter(snli["train"]))) 
# {'premise': 'A person on a horse jumps over a broken down airplane.',
#  'hypothesis': 'A person is training his horse for a competition.',
#  'label': 1}

I already implemented a few methods (a quick usage sketch follows the list):

  • IterableDataset.map: apply transforms on the fly to the examples
  • IterableDataset.shuffle: shuffle the data à la TFDS, i.e. with a shuffling buffer
  • IterableDataset.with_format: set the format to "torch" to get a torch.utils.data.IterableDataset
  • merge_datasets: merge several iterable datasets by alternating between them (you can specify sampling probabilities)
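
For concreteness, here is a quick sketch of how these could compose (the exact keyword names like buffer_size and seed, and the map semantics shown, are my assumptions based on the descriptions above):

from datasets import load_dataset

snli = load_dataset("snli", streaming=True)
train = snli["train"]

# Apply a transform on the fly to each example
train = train.map(lambda example: {**example, "premise": example["premise"].lower()})

# Shuffle with a fixed-size shuffling buffer, TFDS-style
train = train.shuffle(buffer_size=10_000, seed=42)

# Get a torch.utils.data.IterableDataset
train = train.with_format("torch")

print(next(iter(train)))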

I would love to have your opinion on the API design :)

Implementation details

Streaming

Data streaming is done using fsspec which has nice caching features.

To make dataset streaming work, I extend the open function used in dataset scripts so that it can open remote files without downloading them entirely. It also works with remote compressed archives (currently only zip is supported):

# Get a file-like object by streaming data from a remote file
open("https://github.com/davidsbatista/NER-datasets/raw/master/CONLL2003/train.txt")

# Get a file-like object by streaming data from a remote compressed archive by using the hop separator "::"
open("zip://snli_1.0_train.txt::https://nlp.stanford.edu/projects/snli/snli_1.0.zip")

I also extend the os.path.join function to support navigation inside remote compressed archives, since it has to deal with the "::" hop separator used by fsspec.
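
For illustration, here is a minimal sketch of what such a join has to handle (the name xjoin and the splitting logic are my assumptions, not necessarily this PR's implementation):

import posixpath

def xjoin(path, *parts):
    # Hypothetical sketch: join paths while preserving fsspec's "::" hop separator.
    # For a chained URL like "zip://::https://host/archive.zip", only the part
    # before the first "::" (the path inside the archive) receives the joined components.
    if "::" not in path:
        return posixpath.join(path, *parts)
    inner, _, outer = path.partition("::")
    return "::".join([posixpath.join(inner, *parts), outer])

# xjoin("zip://::https://nlp.stanford.edu/projects/snli/snli_1.0.zip", "snli_1.0_train.txt")
# -> "zip://snli_1.0_train.txt::https://nlp.stanford.edu/projects/snli/snli_1.0.zip"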

Finally I also added a retry mechanism in case the connection fails during data streaming.
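
For context, a minimal sketch of what such a retry loop could look like (the function name, the exceptions caught, and the backoff values are all my assumptions):

import time

def open_with_retry(do_open, max_retries=3, base_backoff=1.0):
    # Hypothetical sketch: retry a flaky open/read with exponential backoff
    for attempt in range(max_retries):
        try:
            return do_open()
        except (ConnectionError, TimeoutError):
            if attempt == max_retries - 1:
                raise
            time.sleep(base_backoff * 2 ** attempt)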

Transforms

An IterableDataset wraps an ExamplesIterable instance. There are different subclasses depending on the transforms we want to apply (a minimal sketch follows the list):

  • ExamplesIterable: the basic one
  • MappedExamplesIterable: an iterable with a map function applied on the fly
  • BufferShuffledExamplesIterable: an iterable with a shuffling buffer
  • CyclingMultiSourcesExamplesIterable: alternates between several ExamplesIterable
  • RandomlyCyclingMultiSourcesExamplesIterable: randomly alternates between several ExamplesIterable
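
To make the wrapping pattern concrete, here is a minimal sketch of the first two (assuming examples are yielded as (key, example) pairs, as in the usual _generate_examples; the constructor signatures are my assumptions):

class ExamplesIterable:
    # The basic iterable: wraps a generator function and its kwargs
    def __init__(self, generate_examples_fn, kwargs):
        self.generate_examples_fn = generate_examples_fn
        self.kwargs = kwargs

    def __iter__(self):
        yield from self.generate_examples_fn(**self.kwargs)

class MappedExamplesIterable:
    # Wraps another iterable and applies a function to each example lazily
    def __init__(self, ex_iterable, function):
        self.ex_iterable = ex_iterable
        self.function = function

    def __iter__(self):
        for key, example in self.ex_iterable:
            yield key, self.function(example)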

DatasetBuilder

I use the same builders as usual. I just added a new method _get_examples_iterable_for_split to get an ExamplesIterable for a given split. Currently only the GeneratorBasedBuilder and the ArrowBasedBuilder implement it.

The BeamBasedBuilder doesn't implement it yet.
This means that datasets like wikipedia and natural_questions can't be loaded as an IterableDataset for now.
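
For reference, a rough sketch of what the new method could look like on a GeneratorBasedBuilder, reusing the ExamplesIterable sketch above (my assumption, not necessarily the PR's exact code):

def _get_examples_iterable_for_split(self, split_generator):
    # Wrap the builder's usual _generate_examples generator so that
    # examples are produced lazily, streaming data as the dataset is iterated
    return ExamplesIterable(self._generate_examples, split_generator.gen_kwargs)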

Other details

I may have to change many dataset scripts to use download instead of download_and_extract when extraction is not needed. This avoids errors during streaming.

EDIT: Actually, I just check the file extension and extract only when needed.

EDIT2: It's not possible to stream from .tar.gz files without downloading the file completely. For now I raise an error if one tries to load a streaming dataset based on .tar.gz files.
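
As an illustration of the extension check (this helper and its extension handling are hypothetical; download and download_and_extract are the download manager methods mentioned above):

def download_and_maybe_extract(dl_manager, url):
    # Hypothetical sketch: only extract when the extension requires it
    if url.endswith((".tar.gz", ".tgz")):
        # Streaming from tar.gz would require downloading the whole file
        raise NotImplementedError("Streaming from .tar.gz archives is not supported")
    if url.endswith(".zip"):
        return dl_manager.download_and_extract(url)  # served via "zip://...::" URLs
    return dl_manager.download(url)  # plain files are streamed directly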

TODO

usual stuff:

  • make the streaming dependency "aiohttp" optional: pip install datasets[streaming]
  • tests
  • docs

@lhoestq force-pushed the dataset-streaming branch from 4f8ee69 to f5cf3f3 on June 11, 2021 08:26
@lewtun (Member) left a comment:

this feature looks mega cool and will solve a major pain point i've experienced with task templates for large datasets!

i left a few nits in the docs and a couple of questions

>>> print(next(iter(dataset)))
{'text': 'Mtendere Village was inspired by the vision of Chief Napoleon Dzombe, which he shared with John Blanchard during his first visit to Malawi. Chief Napoleon conveyed the desperate need for a program to intervene and care for the orphans and vulnerable children (OVC) in Malawi, and John committed to help...
Even though the dataset is 1.2 terabytes of data, you can start using it right away. Under the hood, it downloaded only the first examples of the dataset
@lewtun (Member):

nit:

Suggested change
Even though the dataset is 1.2 terabytes of data, you can start using it right away. Under the hood, it downloaded only the first examples of the dataset
Even though the dataset is 1.2 terabytes of data, you can start using it right away! Under the hood, it only downloaded the first example of the dataset.

also, does it download one example or several when we use next(iter(dataset))?

@lhoestq (Member Author):

It downloads the first examples (buffering + caching) and yields the first one.


@lewtun (Member):

perfect!

It is common to use several datasets to train a model. For example, BERT was trained on a mix of Wikipedia and BookCorpus.
You can mix several iterable datasets together using :func:`datasets.interleave_datasets`.

By default, the resulting dataset alternates between the original datasets, but you can also define sampling probabilities to sample randomly from the different datasets.
@lewtun (Member):

wow, this is very cool!
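
For reference, a usage sketch of the interleaving described in the quoted docs (the second config name and the probability values are made up):

from datasets import interleave_datasets, load_dataset

en = load_dataset("oscar", "unshuffled_deduplicated_en", split="train", streaming=True)
fr = load_dataset("oscar", "unshuffled_deduplicated_fr", split="train", streaming=True)

# Alternate between the two streams, sampling "en" 80% of the time
mixed = interleave_datasets([en, fr], probabilities=[0.8, 0.2], seed=42)
print(next(iter(mixed)))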


.. note::

The dataset that is returned is a :class:`datasets.IterableDataset`, not the classic map-style :class:`datasets.Dataset`. To get examples from an iterable dataset, you have to iterate over it, for example with a for loop. To get the very last example of the dataset, you first have to iterate over all the previous examples.
@lewtun (Member):

is it possible to create a classic Dataset from an IterableDataset?

one application that i have in mind is picking the first N examples of a huge dataset, collecting them in a standard Dataset and then doing all my exploration / preprocessing / task preparation etc on that dataset.

e.g. something like

from datasets import load_dataset 

dataset = load_dataset('oscar', "unshuffled_deduplicated_en", split='train', streaming=True)
# create a `Dataset` i can play with?
sample = dataset.select(range(100))

@lhoestq (Member Author):

Sure definitely :)

I was thinking of adding something like this in a future PR.
Maybe IterableDataset.to_map_style_dataset()?
To get only the first examples, we could also add IterableDataset.take(n)
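
As an aside, a rough sketch of how this can already be emulated today (Dataset.from_dict and the column-wise re-packing are my own illustration, not the proposed API):

from itertools import islice
from datasets import Dataset, load_dataset

dataset = load_dataset("oscar", "unshuffled_deduplicated_en", split="train", streaming=True)

# Emulate .take(100): pull the first 100 examples from the stream
head = list(islice(iter(dataset), 100))

# Emulate .to_map_style_dataset(): re-pack the example dicts column-wise
sample = Dataset.from_dict({key: [ex[key] for ex in head] for key in head[0]})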

@lewtun (Member):

yeah, both those features would be great!

Member:

I like the IterableDataset.take(n) as well. Could we also have an IterableDataset.sample(n) taking a random sample?

@lewtun (Member):

a random sample would be very neat as well. here we might want to use something like reservoir sampling to deal with unbounded streams: https://en.wikipedia.org/wiki/Reservoir_sampling
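
For context, a minimal sketch of reservoir sampling (Algorithm R), which keeps a uniform random sample of n items from a stream of unknown length:

import random

def reservoir_sample(stream, n, seed=42):
    # Algorithm R: every item ends up in the sample with equal
    # probability n / total, without knowing total in advance
    rng = random.Random(seed)
    reservoir = []
    for i, item in enumerate(stream):
        if i < n:
            reservoir.append(item)
        else:
            j = rng.randint(0, i)  # replace a slot with probability n / (i + 1)
            if j < n:
                reservoir[j] = item
    return reservoir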

@lhoestq (Member Author):

As soon as we have .take(), you can do iterable_dataset.shuffle(buffer_size=buffer_size, seed=seed).take(n) to take random samples.
This could indeed be simplified by adding a .sample() method.

Contributor:

How about slicing support, i.e. iterable_dataset[100:200], to get an iterator or Dataset for a particular slice?

@lhoestq (Member Author):

I would like to avoid allowing users to get items using __getitem__ since it's not a map-style dataset.
So I agree it would be nice to get a slice of the data, but with a different API. Maybe something like

sliced_dataset = iterable_dataset.skip(100).take(100)

What do you think?

This is pretty close to the tf.data.Dataset API, which is also an iterable dataset.
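
For illustration, skip/take would compose naturally with the ExamplesIterable wrappers described above (this class is hypothetical):

from itertools import islice

class SkipTakeExamplesIterable:
    # Hypothetical wrapper: yields examples with indices in [skip, skip + take)
    def __init__(self, ex_iterable, skip, take):
        self.ex_iterable = ex_iterable
        self.skip = skip
        self.take = take

    def __iter__(self):
        yield from islice(self.ex_iterable, self.skip, self.skip + self.take)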

Contributor:

Yes, very cool @lhoestq

@lhoestq (Member Author) left a comment:

Thanks for the feedback :) I took your comments into account


@lewtun (Member) left a comment:

LGTM 🚀



@thomwolf (Member) left a comment:

This is really super cool!

@lhoestq merged commit 3c49355 into master on Jun 23, 2021
@lhoestq deleted the dataset-streaming branch on June 23, 2021 16:35
@lhoestq mentioned this pull request on Jul 2, 2021