Hoist the batching helper used by batch() to module level.

The nested `batch_fn` closure inside batch() is removed and an identical
function, `_batch_fn`, is appended at the end of the module; batch() now
passes `_batch_fn` to self.map(). The helper body itself is unchanged.
NOTE(review): presumably done so the mapped function is a module-level
(importable/picklable) object -- confirm against the change description.

diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index e438b901694..7db4fff9483 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -3581,15 +3581,12 @@ def batch(self, batch_size: int, drop_last_batch: bool = False) -> "IterableData
         ```
         """
 
-        def batch_fn(unbatched):
-            return {k: [v] for k, v in unbatched.items()}
-
         if self.features:
             features = Features({col: List(feature) for col, feature in self.features.items()})
         else:
             features = None
         return self.map(
-            batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch, features=features
+            _batch_fn, batched=True, batch_size=batch_size, drop_last_batch=drop_last_batch, features=features
         )
 
     def to_dict(self, batch_size: Optional[int] = None, batched: bool = False) -> Union[dict, Iterator[dict]]:
@@ -4659,3 +4656,7 @@ async def _apply_async(pool, func, x):
             return future.get()
         else:
             await asyncio.sleep(0)
+
+
+def _batch_fn(unbatched):
+    return {k: [v] for k, v in unbatched.items()}