diff --git a/src/datasets/io/abc.py b/src/datasets/io/abc.py
index 62cb696ab29..0e2fe42bb40 100644
--- a/src/datasets/io/abc.py
+++ b/src/datasets/io/abc.py
@@ -1,8 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Optional, Union
 
-from .. import DatasetDict, Features, NamedSplit
-from ..arrow_dataset import Dataset
+from .. import Dataset, DatasetDict, Features, IterableDataset, IterableDatasetDict, NamedSplit
 from ..utils.typing import NestedDataStructureLike, PathLike
@@ -14,6 +13,7 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         **kwargs,
     ):
         self.path_or_paths = path_or_paths
@@ -21,10 +21,11 @@ def __init__(
         self.features = features
         self.cache_dir = cache_dir
         self.keep_in_memory = keep_in_memory
+        self.streaming = streaming
         self.kwargs = kwargs
 
     @abstractmethod
-    def read(self) -> Union[Dataset, DatasetDict]:
+    def read(self) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]:
         pass
@@ -34,13 +35,15 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         **kwargs,
     ):
         self.features = features
         self.cache_dir = cache_dir
         self.keep_in_memory = keep_in_memory
+        self.streaming = streaming
         self.kwargs = kwargs
 
     @abstractmethod
-    def read(self) -> Dataset:
+    def read(self) -> Union[Dataset, IterableDataset]:
         pass
diff --git a/src/datasets/io/csv.py b/src/datasets/io/csv.py
index e59d09a2941..1c285d73144 100644
--- a/src/datasets/io/csv.py
+++ b/src/datasets/io/csv.py
@@ -18,10 +18,17 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         **kwargs,
     ):
         super().__init__(
-            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+            path_or_paths,
+            split=split,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            streaming=streaming,
+            **kwargs,
         )
         path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
         self.builder = Csv(
@@ -32,25 +39,28 @@ def __init__(
         )
 
     def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = False
-        use_auth_token = None
-        base_path = None
-
-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            # try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split=self.split)
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None
+
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
         return dataset
diff --git a/src/datasets/io/generator.py b/src/datasets/io/generator.py
index df1a6170c89..9d5d9ad67e6 100644
--- a/src/datasets/io/generator.py
+++ b/src/datasets/io/generator.py
@@ -12,10 +12,13 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         gen_kwargs: Optional[dict] = None,
         **kwargs,
     ):
-        super().__init__(features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs)
+        super().__init__(
+            features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, streaming=streaming, **kwargs
+        )
         self.builder = Generator(
             cache_dir=cache_dir,
             features=features,
@@ -25,23 +28,26 @@ def __init__(
         )
 
     def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = False
-        use_auth_token = None
-        base_path = None
-
-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            # try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split="train")
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None
 
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
         return dataset
diff --git a/src/datasets/io/json.py b/src/datasets/io/json.py
index e5196ec12e5..2888fadfd2d 100644
--- a/src/datasets/io/json.py
+++ b/src/datasets/io/json.py
@@ -20,11 +20,18 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         field: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(
-            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+            path_or_paths,
+            split=split,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            streaming=streaming,
+            **kwargs,
         )
         self.field = field
         path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
@@ -37,26 +44,28 @@ def __init__(
         )
 
     def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = True
-        try_from_hf_gcs = False
-        use_auth_token = None
-        base_path = None
-
-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split=self.split)
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None
+
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
         return dataset
diff --git a/src/datasets/io/parquet.py b/src/datasets/io/parquet.py
index 11ed2fd0337..e789a7ade63 100644
--- a/src/datasets/io/parquet.py
+++ b/src/datasets/io/parquet.py
@@ -20,10 +20,17 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         **kwargs,
     ):
         super().__init__(
-            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+            path_or_paths,
+            split=split,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            streaming=streaming,
+            **kwargs,
         )
         path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
         hash = _PACKAGED_DATASETS_MODULES["parquet"][1]
@@ -36,25 +43,28 @@ def __init__(
         )
 
     def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = False
-        use_auth_token = None
-        base_path = None
-
-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            # try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split=self.split)
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None
 
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
         return dataset
diff --git a/src/datasets/io/text.py b/src/datasets/io/text.py
index f4557f36e62..3612e70f219 100644
--- a/src/datasets/io/text.py
+++ b/src/datasets/io/text.py
@@ -14,10 +14,17 @@ def __init__(
         features: Optional[Features] = None,
         cache_dir: str = None,
         keep_in_memory: bool = False,
+        streaming: bool = False,
         **kwargs,
     ):
         super().__init__(
-            path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
+            path_or_paths,
+            split=split,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            streaming=streaming,
+            **kwargs,
         )
         path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths}
         self.builder = Text(
@@ -28,23 +35,26 @@ def __init__(
         )
 
     def read(self):
-        download_config = None
-        download_mode = None
-        ignore_verifications = False
-        use_auth_token = None
-        base_path = None
+        # Build iterable dataset
+        if self.streaming:
+            dataset = self.builder.as_streaming_dataset(split=self.split)
+        # Build regular (map-style) dataset
+        else:
+            download_config = None
+            download_mode = None
+            ignore_verifications = False
+            use_auth_token = None
+            base_path = None
 
-        self.builder.download_and_prepare(
-            download_config=download_config,
-            download_mode=download_mode,
-            ignore_verifications=ignore_verifications,
-            # try_from_hf_gcs=try_from_hf_gcs,
-            base_path=base_path,
-            use_auth_token=use_auth_token,
-        )
-
-        # Build dataset for splits
-        dataset = self.builder.as_dataset(
-            split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
-        )
+            self.builder.download_and_prepare(
+                download_config=download_config,
+                download_mode=download_mode,
+                ignore_verifications=ignore_verifications,
+                # try_from_hf_gcs=try_from_hf_gcs,
+                base_path=base_path,
+                use_auth_token=use_auth_token,
+            )
+            dataset = self.builder.as_dataset(
+                split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+            )
         return dataset
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index 390115d0a67..1e626187186 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -787,6 +787,41 @@ def __iter__(self):
         else:
             yield example
 
+    @staticmethod
+    def from_generator(
+        generator: Callable,
+        features: Optional[Features] = None,
+        gen_kwargs: Optional[dict] = None,
+    ):
+        """Create an Iterable Dataset from a generator.
+
+        Args:
+            generator (:obj:`Callable`): A generator function that `yields` examples.
+            features (:class:`Features`, optional): Dataset features.
+            gen_kwargs(:obj:`dict`, optional): Keyword arguments to be passed to the `generator` callable.
+
+        Returns:
+            :class:`IterableDataset`
+
+        Example:
+
+        ```py
+        >>> def gen():
+        ...     yield {"text": "Good", "label": 0}
+        ...     yield {"text": "Bad", "label": 1}
+        ...
+        >>> ds = IterableDataset.from_generator(gen)
+        ```
+        """
+        from .io.generator import GeneratorDatasetInputStream
+
+        return GeneratorDatasetInputStream(
+            generator=generator,
+            features=features,
+            gen_kwargs=gen_kwargs,
+            streaming=True,
+        ).read()
+
     def with_format(
         self,
         type: Optional[str] = None,
diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index 1fc009a39e8..6c25e747630 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -576,6 +576,22 @@ def test_iterable_dataset_factory():
     assert dataset._ex_iterable is ex_iterable
 
 
+def test_iterable_dataset_from_generator():
+    data = [
+        {"col_1": "0", "col_2": 0, "col_3": 0.0},
+        {"col_1": "1", "col_2": 1, "col_3": 1.0},
+        {"col_1": "2", "col_2": 2, "col_3": 2.0},
+        {"col_1": "3", "col_2": 3, "col_3": 3.0},
+    ]
+
+    def gen():
+        yield from data
+
+    dataset = IterableDataset.from_generator(gen)
+    assert isinstance(dataset, IterableDataset)
+    assert list(dataset) == data
+
+
 @require_torch
 def test_iterable_dataset_factory_torch_integration():
     import torch
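
For reference, a minimal usage sketch of the `IterableDataset.from_generator` API this diff introduces. The `gen` function, the `shards` keyword, and the explicit `Features` schema are illustrative assumptions, not part of the patch:

```py
from datasets import Features, IterableDataset, Value

# Illustrative generator: yields one example dict at a time. Via gen_kwargs,
# the same callable can be parameterized at read time (here, over "shards").
def gen(shards):
    for shard in shards:
        for i in range(3):
            yield {"shard": shard, "i": i}

# Optional explicit schema for the yielded examples (assumed here; from_generator
# also accepts features=None, as in the test above).
features = Features({"shard": Value("string"), "i": Value("int64")})

ds = IterableDataset.from_generator(
    gen,
    features=features,
    gen_kwargs={"shards": ["a", "b"]},
)

# The generator runs lazily: no example is produced until iteration starts.
for example in ds:
    print(example)
```

Because `from_generator` constructs `GeneratorDatasetInputStream` with `streaming=True`, `read()` takes the `as_streaming_dataset` branch and never calls `download_and_prepare`, so nothing is materialized in the cache directory.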