Problematic rank after calling split_dataset_by_node twice #6990

@yzhangcs

Describe the bug

I'm trying to split an IterableDataset with split_dataset_by_node.
But when splitting an already split dataset, the resulting rank is greater than world_size.

Steps to reproduce the bug

Here is the minimal code for reproduction:

>>> from datasets import load_dataset
>>> from datasets.distributed import split_dataset_by_node
>>> dataset = load_dataset('fla-hub/slimpajama-test', split='train', streaming=True)      
>>> dataset = split_dataset_by_node(dataset, 1, 32)
>>> dataset._distributed
DistributedConfig(rank=1, world_size=32)
>>> dataset = split_dataset_by_node(dataset, 1, 15)
>>> dataset._distributed                           
DistributedConfig(rank=481, world_size=480)

As you can see, after the second split the rank (481) exceeds the world_size (480), which is invalid: a rank must be smaller than world_size.

Expected behavior

I think this error comes from these lines (cc @lhoestq):

world_size = world_size * dataset._distributed.world_size
rank = world_size * dataset._distributed.rank + rank

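The rank is computed from world_size after world_size has already been multiplied by the previous world size. We may need to compute the rank first, before world_size is overwritten. With that order, the reproduction code gives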

>>> dataset._distributed                           
DistributedConfig(rank=16, world_size=480)
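
The arithmetic can be checked without loading any dataset. Below is a minimal sketch using plain integers; the helper names compose_buggy and compose_fixed are hypothetical (they are not part of datasets) and only mirror the two statement orders. The fix actually applied in the library may look different.

```python
from typing import Tuple


def compose_buggy(prev_rank: int, prev_world_size: int,
                  rank: int, world_size: int) -> Tuple[int, int]:
    """Hypothetical helper: same statement order as the two lines quoted in this issue."""
    # world_size is overwritten first, so the rank formula below
    # multiplies by the combined size instead of the new split's size.
    world_size = world_size * prev_world_size
    rank = world_size * prev_rank + rank
    return rank, world_size


def compose_fixed(prev_rank: int, prev_world_size: int,
                  rank: int, world_size: int) -> Tuple[int, int]:
    """Hypothetical helper: compute the rank before overwriting world_size."""
    rank = world_size * prev_rank + rank
    world_size = world_size * prev_world_size
    return rank, world_size


# Same numbers as the reproduction: first split (rank=1, world_size=32),
# second split (rank=1, world_size=15).
print(compose_buggy(1, 32, 1, 15))  # (481, 480)
print(compose_fixed(1, 32, 1, 15))  # (16, 480)
```

With the reordered version, the largest possible combined rank is world_size * (prev_world_size - 1) + (world_size - 1), which is one less than the combined world size, so the rank always stays in range.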

Environment info

datasets==2.20.0
