Skip to content

Commit 637246b

Browse files
yzhangcslhoestq
andauthored
Fix incorrect rank value in data splitting (#6994)
* Fix incorrect rank value in data splitting (#6990) * Add tests for splitting distributed datasets * make style --------- Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 1e1d313 commit 637246b

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

src/datasets/iterable_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3013,8 +3013,8 @@ def _split_by_node_iterable_dataset(dataset: IterableDataset, rank: int, world_s
30133013
[`IterableDataset`]: The iterable dataset to be used on the node at rank `rank`.
30143014
"""
30153015
if dataset._distributed:
3016-
world_size = world_size * dataset._distributed.world_size
30173016
rank = world_size * dataset._distributed.rank + rank
3017+
world_size = world_size * dataset._distributed.world_size
30183018
distributed = DistributedConfig(rank=rank, world_size=world_size)
30193019
return IterableDataset(
30203020
ex_iterable=dataset._ex_iterable,

tests/test_distributed.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,26 @@ def gen(shards):
5555
assert len({tuple(x.values()) for ds in datasets_per_rank for x in ds}) == full_size
5656

5757

58+
def test_split_dataset_by_node_iterable_distributed():
59+
def gen():
60+
return ({"i": i} for i in range(100))
61+
62+
world_size = 3
63+
num_workers = 3
64+
full_ds = IterableDataset.from_generator(gen)
65+
full_size = len(list(full_ds))
66+
datasets_per_rank = [
67+
split_dataset_by_node(full_ds, rank=rank, world_size=world_size) for rank in range(world_size)
68+
]
69+
datasets_per_rank_per_worker = [
70+
split_dataset_by_node(ds, rank=worker, world_size=num_workers)
71+
for ds in datasets_per_rank
72+
for worker in range(num_workers)
73+
]
74+
assert sum(len(list(ds)) for ds in datasets_per_rank_per_worker) == full_size
75+
assert len({tuple(x.values()) for ds in datasets_per_rank_per_worker for x in ds}) == full_size
76+
77+
5878
def test_distributed_shuffle_iterable():
5979
def gen():
6080
return ({"i": i} for i in range(17))

0 commit comments

Comments
 (0)