Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/datasets/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1674,13 +1674,17 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExample
"""Shuffle the underlying iterable, then repeat."""
return RepeatExamplesIterable(self.ex_iterable.shuffle_data_sources(generator), num_times=self.num_times)

def shard_data_sources(self, worker_id: int, num_workers: int) -> "RepeatExamplesIterable":
def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "RepeatExamplesIterable":
"""Shard, then repeat shards."""
return RepeatExamplesIterable(
self.ex_iterable.shard_data_sources(worker_id, num_workers),
self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
num_times=self.num_times,
)

@property
def num_shards(self) -> int:
return self.ex_iterable.num_shards

@property
def n_shards(self) -> int:
return self.ex_iterable.n_shards
Expand Down