Optimize contiguous shard and select #4466
Conversation
The documentation is not available anymore as the PR was closed or merged.
I thought I'd mention the benefits I got. Here's the code that @lhoestq provided (I added the missing `Dataset` import and `exist_ok=True` so the snippet runs as-is):

```python
import os

from datasets import Dataset, load_dataset
from tqdm.auto import tqdm

ds = load_dataset("squad", split="train")
os.makedirs("tmp", exist_ok=True)

num_shards = 5
for index in tqdm(range(num_shards)):
    size = len(ds) // num_shards
    shard = Dataset(ds.data.slice(size * index, size), fingerprint=f"{ds._fingerprint}_{index}")
    shard.to_json(f"tmp/data_{index}.jsonl")
```

It took 1.64s. Previously the code was:

```python
num_shards = 5
for index in tqdm(range(num_shards)):
    shard = ds.shard(num_shards=num_shards, index=index, contiguous=True)
    shard.to_json(f"tmp/data_{index}.jsonl")
    # upload_to_gcs(f"tmp/data_{index}.jsonl")
```

It took 2min31s. I ran it on my humble MacBook Pro:
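Note that `len(ds) // num_shards` in the snippet above drops the remainder rows when the dataset size isn't divisible by `num_shards`. Here is a self-contained sketch of boundary arithmetic that covers every row by spreading the remainder over the first shards (the helper name is mine, not part of the `datasets` API):

```python
def contiguous_shard_bounds(num_rows: int, num_shards: int) -> list[tuple[int, int]]:
    """Return (offset, length) pairs covering all rows contiguously,
    spreading the remainder over the first shards so no row is dropped."""
    base, remainder = divmod(num_rows, num_shards)
    bounds = []
    offset = 0
    for index in range(num_shards):
        # The first `remainder` shards each take one extra row.
        length = base + (1 if index < remainder else 0)
        bounds.append((offset, length))
        offset += length
    return bounds

print(contiguous_shard_bounds(10, 3))  # → [(0, 4), (4, 3), (7, 3)]
```

Each `(offset, length)` pair can then be passed directly to an Arrow table's `slice`.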
Thanks, good performance gain!!!
Just some comments/questions below.
```python
except StopIteration:
    return self._select_contiguous(0, 0, new_fingerprint=new_fingerprint)
```
Naive question: which use case is this for?
It's for the case where `indices` is an empty iterable; let me add a comment.
albertvillanova left a comment
See additional comment on performance.
```python
try:
    start = next(iter(indices))
except StopIteration:
    return self._select_contiguous(0, 0, new_fingerprint=new_fingerprint)
if start >= 0:
    counter_from_start = itertools.count(start=start)
    if all(i == j for i, j in zip(indices, counter_from_start)):
```
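For reference, the check in the diff above can be written as a standalone function: the input is contiguous if it counts up by one from a non-negative start. This is a sketch of the idea, not the exact library code (empty input is treated as contiguous, matching the empty-selection branch):

```python
import itertools


def is_contiguous(indices) -> bool:
    """Return True if `indices` increases by 1 from a non-negative start."""
    it = iter(indices)
    try:
        start = next(it)
    except StopIteration:
        return True  # empty iterable: maps to an empty selection
    if start < 0:
        return False
    # Compare the remaining indices against start+1, start+2, ...
    counter_from_start = itertools.count(start=start + 1)
    return all(i == j for i, j in zip(it, counter_from_start))
```

For example, `is_contiguous([3, 4, 5])` is `True` while `is_contiguous([3, 5])` is `False`.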
Also note this implementation has an overhead for `np.array` and `pd.Series`, compared to a regular Python list:

```python
In [34]: def check_with_counter(indices):
    ...:     start = next(iter(indices))
    ...:     counter_from_start = itertools.count(start=start)
    ...:     if all(i == j for i, j in zip(indices, counter_from_start)):
    ...:         return True
    ...:     else:
    ...:         return False

In [81]: lis = list(range(10_000_000))

In [82]: %timeit check_with_counter(lis)
657 ms ± 5.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [83]: arr = np.array(range(10_000_000))

In [84]: %timeit check_with_counter(arr)
2.43 s ± 37.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [85]: ser = pd.Series(list(range(10_000_000)))

In [86]: %timeit check_with_counter(ser)
1.35 s ± 18.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
Thanks for the info! In the docstring of `select` we should maybe encourage users to pass a `range` instead of a list/array/series in such cases anyway.
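One reason a `range` is preferable: it exposes its `start` and `step` attributes, so contiguity can be decided in O(1) without iterating at all. A sketch of that fast path (my own illustration, not the library's code):

```python
def range_is_contiguous(r: range) -> bool:
    """O(1) contiguity check: a range is contiguous iff it steps by 1
    from a non-negative start. No iteration over the elements needed."""
    return r.step == 1 and r.start >= 0
```

For example, `range_is_contiguous(range(10_000_000))` returns instantly, whereas the counter-based check above has to walk all ten million elements.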
I addressed your comments @albertvillanova, let me know what you think :)
albertvillanova left a comment
Thanks, good job.

Currently `.shard()` and `.select()` always create an indices mapping. However, if the requested data are contiguous, it's much more optimized to simply slice the Arrow table instead of building an indices mapping. In particular:

- Since `.shard()` is also used for `.map()` with `num_proc>1`, it will also significantly improve the reading speed of multiprocessed `.map()` operations.

Here is an example of speed-up:
while previously it was
In this simple case the speed-up is x10, but @sayakpaul experienced a x100 speed-up on their data when exporting to JSON.
Implementation details
I mostly improved `.select()`: it now checks whether the input corresponds to a contiguous chunk of data, and if so it slices the main Arrow table (or the indices mapping table if one exists). To check whether the input indices are contiguous, it considers two possibilities:

- if the input is a `range`, it checks that `start >= 0` and `step == 1`
- otherwise, it iterates over the indices and checks that they increase one by one from a non-negative start (using `itertools.count`)

Having to iterate over the indices doesn't cause performance issues IMO because: