Rename the sharding API on the iterable whose methods return RepeatExamplesIterable:
shard_data_sources(worker_id, num_workers) becomes shard_data_sources(num_shards, index, contiguous=True),
and the n_shards property becomes num_shards.
NOTE(review): the positional order flips from (index, count) to (count, index); positional callers must be updated.

diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 54c05d25761..f49d0f3b8fc 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -1674,16 +1674,16 @@ def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExample
         """Shuffle the underlying iterable, then repeat."""
         return RepeatExamplesIterable(self.ex_iterable.shuffle_data_sources(generator), num_times=self.num_times)
 
-    def shard_data_sources(self, worker_id: int, num_workers: int) -> "RepeatExamplesIterable":
+    def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "RepeatExamplesIterable":
         """Shard, then repeat shards."""
         return RepeatExamplesIterable(
-            self.ex_iterable.shard_data_sources(worker_id, num_workers),
+            self.ex_iterable.shard_data_sources(num_shards, index, contiguous=contiguous),
             num_times=self.num_times,
         )
 
     @property
-    def n_shards(self) -> int:
-        return self.ex_iterable.n_shards
+    def num_shards(self) -> int:
+        return self.ex_iterable.num_shards
 
 
 class TakeExamplesIterable(_BaseExamplesIterable):