Conversation

@lhoestq (Member) commented Jun 9, 2022

Currently `.shard()` and `.select()` always create an indices mapping. However, if the requested data is contiguous, it is much more efficient to simply slice the Arrow table instead of building an indices mapping. In particular:

  • the shard/select operation will be much faster
  • reading speed will be much faster in the resulting dataset, since it won't have to do a lookup step in the indices mapping
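
The contrast between the two read paths can be sketched with a plain-Python stand-in for the Arrow table (illustrative only, not the actual Arrow code):

```python
# Plain-Python stand-in for the Arrow table, to contrast the two read paths.
table = list(range(1_000_000))

# With an indices mapping: one lookup (indirection) per row read.
indices_mapping = list(range(250_000))
via_mapping = [table[i] for i in indices_mapping]

# Contiguous selection: a single slice, no per-row lookups.
via_slice = table[:250_000]

assert via_mapping == via_slice
```

The slice is a single bulk operation, which is why skipping the mapping makes both the selection and all subsequent reads faster.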

Since `.shard()` is also used by `.map()` with `num_proc>1`, this also significantly improves the reading speed of multiprocessed `.map()` operations.

Here is an example of speed-up:

>>> import io
>>> import numpy as np
>>> from datasets import Dataset
>>> ds = Dataset.from_dict({"a": np.random.rand(10_000_000)})
>>> shard = ds.shard(num_shards=4, index=0, contiguous=True)  # this calls `.select(range(2_500_000))`
>>> buf = io.BytesIO()
>>> %time shard.to_json(buf)
Creating json from Arrow format: 100%|██████████████████| 100/100 [00:00<00:00, 376.17ba/s]
CPU times: user 258 ms, sys: 9.06 ms, total: 267 ms
Wall time: 266 ms

while previously it was

Creating json from Arrow format: 100%|███████████████████| 100/100 [00:03<00:00, 29.41ba/s]
CPU times: user 3.33 s, sys: 69.1 ms, total: 3.39 s
Wall time: 3.4 s

In this simple case the speed-up is x10, but @sayakpaul experienced a x100 speed-up on their data when exporting to JSON.

Implementation details

I mostly improved `.select()`: it now checks whether the input corresponds to a contiguous chunk of data, and if so it slices the main Arrow table (or the indices mapping table if one exists). To determine whether the input indices are contiguous it handles two cases:

  • if `indices` is a `range`, it checks that `start >= 0` and `step == 1`
  • otherwise, in the general case, it iterates over the indices: if they are all contiguous we're good, otherwise we have to build an indices mapping

Having to iterate over the indices doesn't cause performance issues IMO because:

  • either they are contiguous, and in this case the cost of iterating over the indices is much lower than the cost of creating an indices mapping
  • or they are not contiguous, and then iteration generally stops quickly, as soon as it encounters the first index that breaks contiguity
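
A minimal standalone sketch of that check (illustrative names, not the actual library code):

```python
import itertools
from typing import Iterable, Optional, Tuple

def contiguous_span(indices: Iterable[int]) -> Optional[Tuple[int, int]]:
    """Return (start, length) if `indices` is a contiguous ascending run
    starting at a non-negative index, else None (an indices mapping is needed)."""
    # Fast path: a step-1 range with non-negative start is contiguous by construction.
    if isinstance(indices, range):
        if indices.step == 1 and indices.start >= 0:
            return indices.start, len(indices)
        return None
    # General case: compare against a counter; stops at the first mismatch.
    it = iter(indices)
    try:
        start = next(it)
    except StopIteration:
        return 0, 0  # empty selection -> empty contiguous slice
    if start < 0:
        return None
    length = 1
    for expected, i in zip(itertools.count(start + 1), it):
        if i != expected:
            return None
        length += 1
    return start, length
```

For example, `contiguous_span([3, 4, 5])` finds a contiguous run, while `[0, 2, 3]` bails out at the first gap.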

@lhoestq lhoestq requested a review from albertvillanova June 9, 2022 13:45
@HuggingFaceDocBuilderDev commented Jun 9, 2022

The documentation is not available anymore as the PR was closed or merged.

@sayakpaul (Member) commented:

I thought I'd just mention the benefits I got. Here's the code that @lhoestq provided:

import os
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm

ds = load_dataset("squad", split="train")
os.makedirs("tmp", exist_ok=True)

num_shards = 5
size = len(ds) // num_shards
for index in tqdm(range(num_shards)):
    # slice the Arrow table directly instead of building an indices mapping
    shard = Dataset(ds.data.slice(size * index, size), fingerprint=f"{ds._fingerprint}_{index}")
    shard.to_json(f"tmp/data_{index}.jsonl")

It took 1.64s. Previously the code was:

num_shards = 5
for index in tqdm(range(num_shards)):
    shard = ds.shard(num_shards=num_shards, index=index, contiguous=True)
    shard.to_json(f"tmp/data_{index}.jsonl")
    # upload_to_gcs(f"tmp/data_{index}.jsonl")

That took 2min 31s.

I ran it on my humble MacBook Pro.

@albertvillanova (Member) left a comment:

Thanks, good performance gain!!!

Just some comments/questions below.

Comment on lines 3043 to 3044
except StopIteration:
    return self._select_contiguous(0, 0, new_fingerprint=new_fingerprint)

Naive question: which use case is this for?

@lhoestq (Member, Author) replied Jun 14, 2022:
It's in case `indices` is an empty iterable; let me add a comment.

@albertvillanova (Member) left a comment:

See additional comment on performance.

Comment on lines 3041 to 3047
try:
    start = next(iter(indices))
except StopIteration:
    return self._select_contiguous(0, 0, new_fingerprint=new_fingerprint)
if start >= 0:
    counter_from_start = itertools.count(start=start)
    if all(i == j for i, j in zip(indices, counter_from_start)):

Also note this implementation has an overhead for `np.array` and `pd.Series`, compared to a regular Python list:

In [34]: def check_with_counter(indices):
    ...:     start = next(iter(indices))
    ...:     counter_from_start = itertools.count(start=start)
    ...:     if all(i == j for i, j in zip(indices, counter_from_start)):
    ...:         return True
    ...:     else:
    ...:         return False

In [81]: lis = list(range(10_000_000))

In [82]: %timeit check_with_counter(lis)
657 ms ± 5.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [83]: arr = np.array(range(10_000_000))

In [84]: %timeit check_with_counter(arr)
2.43 s ± 37.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [85]: ser = pd.Series(list(range(10_000_000)))

In [86]: %timeit check_with_counter(ser)
1.35 s ± 18.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
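
One way to avoid that overhead for array inputs (an illustrative alternative using NumPy, not part of this PR) is to vectorize the check instead of iterating element by element:

```python
import numpy as np

def check_contiguous_np(arr: np.ndarray) -> bool:
    """Vectorized contiguity check for a 1-D integer array:
    contiguous iff the start is non-negative and every step between
    consecutive indices is exactly 1."""
    if arr.size == 0:
        return True
    return bool(arr[0] >= 0 and (np.diff(arr) == 1).all())
```

This trades the early-exit behavior of the counter-based scan for a single vectorized pass, which is typically much faster on large NumPy inputs.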

@lhoestq (Member, Author) replied:

Thanks for the info! In the docstring of `select` we should maybe encourage users to pass a `range` instead of a list/array/series in such cases anyway.
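
To make the difference concrete, a small hypothetical benchmark (illustrative names, not library code) showing why a `range` is cheaper to validate than a materialized list:

```python
import itertools
import timeit

n = 1_000_000
as_range = range(n)
as_list = list(as_range)

def check(indices):
    # `range` fast path: O(1) attribute checks; any other iterable needs an O(n) scan.
    if isinstance(indices, range):
        return indices.step == 1 and indices.start >= 0
    it = iter(indices)
    start = next(it, 0)
    return start >= 0 and all(i == j for i, j in zip(it, itertools.count(start + 1)))

print("range:", timeit.timeit(lambda: check(as_range), number=100))
print("list: ", timeit.timeit(lambda: check(as_list), number=100))
```

The `range` variant never touches individual elements, so its cost is independent of the dataset size.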

@lhoestq (Member, Author) commented Jun 14, 2022

I addressed your comments, @albertvillanova. Let me know what you think :)

@albertvillanova (Member) left a comment:

Thanks, good job.

@lhoestq lhoestq merged commit 5994036 into master Jun 14, 2022
@lhoestq lhoestq deleted the optimize-contiguous-shard-and-select branch June 14, 2022 15:54