Closed
Description
Describe the bug
I'm trying to split an IterableDataset with split_dataset_by_node.
But when splitting an already split dataset, the resulting rank is greater than world_size.
Steps to reproduce the bug
Here is the minimal code for reproduction:
>>> from datasets import load_dataset
>>> from datasets.distributed import split_dataset_by_node
>>> dataset = load_dataset('fla-hub/slimpajama-test', split='train', streaming=True)
>>> dataset = split_dataset_by_node(dataset, 1, 32)
>>> dataset._distributed
DistributedConfig(rank=1, world_size=32)
>>> dataset = split_dataset_by_node(dataset, 1, 15)
>>> dataset._distributed
DistributedConfig(rank=481, world_size=480)
As you can see, the second rank 481 > 480, which is problematic.
Expected behavior
I think this error comes from these lines @lhoestq
datasets/src/datasets/iterable_dataset.py
Lines 2943 to 2944 in a6ccf94
world_size = world_size * dataset._distributed.world_size
rank = world_size * dataset._distributed.rank + rank
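The arithmetic of those two lines can be checked without the library. This is only a minimal sketch: `DistributedConfig` here is a stand-in dataclass, and `nest_buggy` / `nest_fixed` are hypothetical helper names for illustration, not part of the datasets API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DistributedConfig:
    # Stand-in for illustration only, not the class from the datasets library.
    rank: int
    world_size: int


def nest_buggy(outer: DistributedConfig, rank: int, world_size: int) -> DistributedConfig:
    # Same order as the two quoted lines: world_size is overwritten with the
    # combined size before it is used to compute the rank.
    world_size = world_size * outer.world_size
    rank = world_size * outer.rank + rank
    return DistributedConfig(rank, world_size)


def nest_fixed(outer: DistributedConfig, rank: int, world_size: int) -> DistributedConfig:
    # Swapped order: the rank uses the new split's own world_size,
    # and only then is world_size scaled to the combined size.
    rank = world_size * outer.rank + rank
    world_size = world_size * outer.world_size
    return DistributedConfig(rank, world_size)


outer = DistributedConfig(rank=1, world_size=32)
print(nest_buggy(outer, 1, 15))  # DistributedConfig(rank=481, world_size=480)
print(nest_fixed(outer, 1, 15))  # DistributedConfig(rank=16, world_size=480)
```

With the swapped order, every (outer rank, inner rank) pair maps to a distinct rank in the range 0 to 479 for this 32 x 15 case.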
We may need to compute the rank first, before world_size is overwritten. With that order, the reproduction above gives
>>> dataset._distributed
DistributedConfig(rank=16, world_size=480)
Environment info
datasets==2.20.0