From 3c317edf3acd49a7eeb1aa1a7c47c390ae1f2ad0 Mon Sep 17 00:00:00 2001 From: Hamid Vakilzadeh Date: Fri, 30 Sep 2022 16:55:41 -0500 Subject: [PATCH 1/4] added from_generator method to IterableDataset class. --- src/datasets/io/generator.py | 13 +++++++---- src/datasets/iterable_dataset.py | 40 ++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 4 deletions(-) diff --git a/src/datasets/io/generator.py b/src/datasets/io/generator.py index df1a6170c89..2872b334fe2 100644 --- a/src/datasets/io/generator.py +++ b/src/datasets/io/generator.py @@ -24,7 +24,7 @@ def __init__( **kwargs, ) - def read(self): + def read(self, streaming: bool = False): download_config = None download_mode = None ignore_verifications = False @@ -40,8 +40,13 @@ def read(self): use_auth_token=use_auth_token, ) + # Build iterable dataset + if streaming: + dataset = self.builder.as_streaming_dataset(split="train") + # Build dataset for splits - dataset = self.builder.as_dataset( - split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory - ) + else: + dataset = self.builder.as_dataset( + split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory + ) return dataset diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 390115d0a67..d86257e4aa7 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -1339,6 +1339,46 @@ def _resolve_features(self): token_per_repo_id=self._token_per_repo_id, ) + @staticmethod + def from_generator( + generator: Callable, + features: Optional[Features] = None, + cache_dir: str = None, + keep_in_memory: bool = False, + gen_kwargs: Optional[dict] = None, + ): + """Create an Iterable Dataset from a generator. + + Args: + generator (:obj:`Callable`): A generator function that `yields` examples. + features (:class:`Features`, optional): Dataset features. + cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data. + keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory. + gen_kwargs(:obj:`dict`, optional): Keyword arguments to be passed to the `generator` callable. + + Returns: + :class:`IterableDataset` + + Example: + + ```py + >>> def gen(): + ... yield {"text": "Good", "label": 0} + ... yield {"text": "Bad", "label": 1} + ... + >>> ds = IterableDataset.from_generator(gen) + ``` + """ + from .io.generator import GeneratorDatasetInputStream + + return GeneratorDatasetInputStream( + generator=generator, + features=features, + cache_dir=cache_dir, + keep_in_memory=keep_in_memory, + gen_kwargs=gen_kwargs, + ).read(streaming=True) + def iterable_dataset( ex_iterable: Iterable, From 0473a0a0bfd949a83188a40377718603dd4c686a Mon Sep 17 00:00:00 2001 From: mariosasko Date: Tue, 4 Oct 2022 13:24:04 +0200 Subject: [PATCH 2/4] Move streaming param to __init__ --- src/datasets/io/abc.py | 10 ++++- src/datasets/io/csv.py | 50 ++++++++++++--------- src/datasets/io/generator.py | 41 ++++++++--------- src/datasets/io/json.py | 51 +++++++++++++--------- src/datasets/io/parquet.py | 48 ++++++++++++-------- src/datasets/io/text.py | 48 ++++++++++++-------- src/datasets/iterable_dataset.py | 75 +++++++++++++++----------------- 7 files changed, 182 insertions(+), 141 deletions(-) diff --git a/src/datasets/io/abc.py b/src/datasets/io/abc.py index 62cb696ab29..7498a72092b 100644 --- a/src/datasets/io/abc.py +++ b/src/datasets/io/abc.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from typing import Optional, Union +from datasets.iterable_dataset import IterableDataset + from .. import DatasetDict, Features, NamedSplit from ..arrow_dataset import Dataset from ..utils.typing import NestedDataStructureLike, PathLike @@ -14,6 +16,7 @@ def __init__( features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, + streaming: bool = False, **kwargs, ): self.path_or_paths = path_or_paths @@ -21,10 +24,11 @@ def __init__( self.features = features self.cache_dir = cache_dir self.keep_in_memory = keep_in_memory + self.streaming = streaming self.kwargs = kwargs @abstractmethod - def read(self) -> Union[Dataset, DatasetDict]: + def read(self) -> Union[Dataset, DatasetDict, IterableDataset, IterableDataset]: pass @@ -34,13 +38,15 @@ def __init__( features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, + streaming: bool = False, **kwargs, ): self.features = features self.cache_dir = cache_dir self.keep_in_memory = keep_in_memory + self.streaming = streaming self.kwargs = kwargs @abstractmethod - def read(self) -> Dataset: + def read(self) -> Union[Dataset, IterableDataset]: pass diff --git a/src/datasets/io/csv.py b/src/datasets/io/csv.py index e59d09a2941..1c285d73144 100644 --- a/src/datasets/io/csv.py +++ b/src/datasets/io/csv.py @@ -18,10 +18,17 @@ def __init__( features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, + streaming: bool = False, **kwargs, ): super().__init__( - path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs + path_or_paths, + split=split, + features=features, + cache_dir=cache_dir, + keep_in_memory=keep_in_memory, + streaming=streaming, + **kwargs, ) path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths} self.builder = Csv( @@ -32,25 +39,28 @@ def __init__( ) def read(self): - download_config = None - download_mode = None - ignore_verifications = False - use_auth_token = None - base_path = None - - self.builder.download_and_prepare( - download_config=download_config, - download_mode=download_mode, - ignore_verifications=ignore_verifications, - # try_from_hf_gcs=try_from_hf_gcs, - base_path=base_path, - use_auth_token=use_auth_token, - ) - - # Build dataset for splits - dataset = self.builder.as_dataset( - split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory - ) + # Build iterable dataset + if self.streaming: + dataset = self.builder.as_streaming_dataset(split=self.split) + # Build regular (map-style) dataset + else: + download_config = None + download_mode = None + ignore_verifications = False + use_auth_token = None + base_path = None + + self.builder.download_and_prepare( + download_config=download_config, + download_mode=download_mode, + ignore_verifications=ignore_verifications, + # try_from_hf_gcs=try_from_hf_gcs, + base_path=base_path, + use_auth_token=use_auth_token, + ) + dataset = self.builder.as_dataset( + split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory + ) return dataset diff --git a/src/datasets/io/generator.py b/src/datasets/io/generator.py index 2872b334fe2..9d5d9ad67e6 100644 --- a/src/datasets/io/generator.py +++ b/src/datasets/io/generator.py @@ -12,10 +12,13 @@ def __init__( features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, + streaming: bool = False, gen_kwargs: Optional[dict] = None, **kwargs, ): - super().__init__(features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs) + super().__init__( + features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, streaming=streaming, **kwargs + ) self.builder = Generator( cache_dir=cache_dir, features=features, @@ -24,28 +27,26 @@ def __init__( **kwargs, ) - def read(self, streaming: bool = False): - download_config = None - download_mode = None - ignore_verifications = False - use_auth_token = None - base_path = None - - self.builder.download_and_prepare( - download_config=download_config, - download_mode=download_mode, - ignore_verifications=ignore_verifications, - # try_from_hf_gcs=try_from_hf_gcs, - base_path=base_path, - use_auth_token=use_auth_token, - ) - + def read(self): # Build iterable dataset - if streaming: + if self.streaming: dataset = self.builder.as_streaming_dataset(split="train") - - # Build dataset for splits + # Build regular (map-style) dataset else: + download_config = None + download_mode = None + ignore_verifications = False + use_auth_token = None + base_path = None + + self.builder.download_and_prepare( + download_config=download_config, + download_mode=download_mode, + ignore_verifications=ignore_verifications, + # try_from_hf_gcs=try_from_hf_gcs, + base_path=base_path, + use_auth_token=use_auth_token, + ) dataset = self.builder.as_dataset( split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory ) diff --git a/src/datasets/io/json.py b/src/datasets/io/json.py index e5196ec12e5..2888fadfd2d 100644 --- a/src/datasets/io/json.py +++ b/src/datasets/io/json.py @@ -20,11 +20,18 @@ def __init__( features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, + streaming: bool = False, field: Optional[str] = None, **kwargs, ): super().__init__( - path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs + path_or_paths, + split=split, + features=features, + cache_dir=cache_dir, + keep_in_memory=keep_in_memory, + streaming=streaming, + **kwargs, ) self.field = field path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths} @@ -37,26 +44,28 @@ def __init__( ) def read(self): - download_config = None - download_mode = None - ignore_verifications = True - try_from_hf_gcs = False - use_auth_token = None - base_path = None - - self.builder.download_and_prepare( - download_config=download_config, - download_mode=download_mode, - ignore_verifications=ignore_verifications, - try_from_hf_gcs=try_from_hf_gcs, - base_path=base_path, - use_auth_token=use_auth_token, - ) - - # Build dataset for splits - dataset = self.builder.as_dataset( - split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory - ) + # Build iterable dataset + if self.streaming: + dataset = self.builder.as_streaming_dataset(split=self.split) + # Build regular (map-style) dataset + else: + download_config = None + download_mode = None + ignore_verifications = False + use_auth_token = None + base_path = None + + self.builder.download_and_prepare( + download_config=download_config, + download_mode=download_mode, + ignore_verifications=ignore_verifications, + # try_from_hf_gcs=try_from_hf_gcs, + base_path=base_path, + use_auth_token=use_auth_token, + ) + dataset = self.builder.as_dataset( + split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory + ) return dataset diff --git a/src/datasets/io/parquet.py b/src/datasets/io/parquet.py index 11ed2fd0337..e789a7ade63 100644 --- a/src/datasets/io/parquet.py +++ b/src/datasets/io/parquet.py @@ -20,10 +20,17 @@ def __init__( features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, + streaming: bool = False, **kwargs, ): super().__init__( - path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs + path_or_paths, + split=split, + features=features, + cache_dir=cache_dir, + keep_in_memory=keep_in_memory, + streaming=streaming, + **kwargs, ) path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths} hash = _PACKAGED_DATASETS_MODULES["parquet"][1] @@ -36,25 +43,28 @@ def __init__( ) def read(self): - download_config = None - download_mode = None - ignore_verifications = False - use_auth_token = None - base_path = None - - self.builder.download_and_prepare( - download_config=download_config, - download_mode=download_mode, - ignore_verifications=ignore_verifications, - # try_from_hf_gcs=try_from_hf_gcs, - base_path=base_path, - use_auth_token=use_auth_token, - ) + # Build iterable dataset + if self.streaming: + dataset = self.builder.as_streaming_dataset(split=self.split) + # Build regular (map-style) dataset + else: + download_config = None + download_mode = None + ignore_verifications = False + use_auth_token = None + base_path = None - # Build dataset for splits - dataset = self.builder.as_dataset( - split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory - ) + self.builder.download_and_prepare( + download_config=download_config, + download_mode=download_mode, + ignore_verifications=ignore_verifications, + # try_from_hf_gcs=try_from_hf_gcs, + base_path=base_path, + use_auth_token=use_auth_token, + ) + dataset = self.builder.as_dataset( + split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory + ) return dataset diff --git a/src/datasets/io/text.py b/src/datasets/io/text.py index f4557f36e62..3612e70f219 100644 --- a/src/datasets/io/text.py +++ b/src/datasets/io/text.py @@ -14,10 +14,17 @@ def __init__( features: Optional[Features] = None, cache_dir: str = None, keep_in_memory: bool = False, + streaming: bool = False, **kwargs, ): super().__init__( - path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs + path_or_paths, + split=split, + features=features, + cache_dir=cache_dir, + keep_in_memory=keep_in_memory, + streaming=streaming, + **kwargs, ) path_or_paths = path_or_paths if isinstance(path_or_paths, dict) else {self.split: path_or_paths} self.builder = Text( @@ -28,23 +35,26 @@ def __init__( ) def read(self): - download_config = None - download_mode = None - ignore_verifications = False - use_auth_token = None - base_path = None + # Build iterable dataset + if self.streaming: + dataset = self.builder.as_streaming_dataset(split=self.split) + # Build regular (map-style) dataset + else: + download_config = None + download_mode = None + ignore_verifications = False + use_auth_token = None + base_path = None - self.builder.download_and_prepare( - download_config=download_config, - download_mode=download_mode, - ignore_verifications=ignore_verifications, - # try_from_hf_gcs=try_from_hf_gcs, - base_path=base_path, - use_auth_token=use_auth_token, - ) - - # Build dataset for splits - dataset = self.builder.as_dataset( - split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory - ) + self.builder.download_and_prepare( + download_config=download_config, + download_mode=download_mode, + ignore_verifications=ignore_verifications, + # try_from_hf_gcs=try_from_hf_gcs, + base_path=base_path, + use_auth_token=use_auth_token, + ) + dataset = self.builder.as_dataset( + split=self.split, ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory + ) return dataset diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index d86257e4aa7..1e626187186 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -787,6 +787,41 @@ def __iter__(self): else: yield example + @staticmethod + def from_generator( + generator: Callable, + features: Optional[Features] = None, + gen_kwargs: Optional[dict] = None, + ): + """Create an Iterable Dataset from a generator. + + Args: + generator (:obj:`Callable`): A generator function that `yields` examples. + features (:class:`Features`, optional): Dataset features. + gen_kwargs(:obj:`dict`, optional): Keyword arguments to be passed to the `generator` callable. + + Returns: + :class:`IterableDataset` + + Example: + + ```py + >>> def gen(): + ... yield {"text": "Good", "label": 0} + ... yield {"text": "Bad", "label": 1} + ... + >>> ds = IterableDataset.from_generator(gen) + ``` + """ + from .io.generator import GeneratorDatasetInputStream + + return GeneratorDatasetInputStream( + generator=generator, + features=features, + gen_kwargs=gen_kwargs, + streaming=True, + ).read() + def with_format( self, type: Optional[str] = None, @@ -1339,46 +1374,6 @@ def _resolve_features(self): token_per_repo_id=self._token_per_repo_id, ) - @staticmethod - def from_generator( - generator: Callable, - features: Optional[Features] = None, - cache_dir: str = None, - keep_in_memory: bool = False, - gen_kwargs: Optional[dict] = None, - ): - """Create an Iterable Dataset from a generator. - - Args: - generator (:obj:`Callable`): A generator function that `yields` examples. - features (:class:`Features`, optional): Dataset features. - cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data. - keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory. - gen_kwargs(:obj:`dict`, optional): Keyword arguments to be passed to the `generator` callable. - - Returns: - :class:`IterableDataset` - - Example: - - ```py - >>> def gen(): - ... yield {"text": "Good", "label": 0} - ... yield {"text": "Bad", "label": 1} - ... - >>> ds = IterableDataset.from_generator(gen) - ``` - """ - from .io.generator import GeneratorDatasetInputStream - - return GeneratorDatasetInputStream( - generator=generator, - features=features, - cache_dir=cache_dir, - keep_in_memory=keep_in_memory, - gen_kwargs=gen_kwargs, - ).read(streaming=True) - def iterable_dataset( ex_iterable: Iterable, From 84d79ced23f047d99c8f0179b4ac17ed72ea4271 Mon Sep 17 00:00:00 2001 From: mariosasko Date: Tue, 4 Oct 2022 13:24:24 +0200 Subject: [PATCH 3/4] Test --- tests/test_iterable_dataset.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py index 1fc009a39e8..6c25e747630 100644 --- a/tests/test_iterable_dataset.py +++ b/tests/test_iterable_dataset.py @@ -576,6 +576,22 @@ def test_iterable_dataset_factory(): assert dataset._ex_iterable is ex_iterable +def test_iterable_dataset_from_generator(): + data = [ + {"col_1": "0", "col_2": 0, "col_3": 0.0}, + {"col_1": "1", "col_2": 1, "col_3": 1.0}, + {"col_1": "2", "col_2": 2, "col_3": 2.0}, + {"col_1": "3", "col_2": 3, "col_3": 3.0}, + ] + + def gen(): + yield from data + + dataset = IterableDataset.from_generator(gen) + assert isinstance(dataset, IterableDataset) + assert list(dataset) == data + + @require_torch def test_iterable_dataset_factory_torch_integration(): import torch From ee3224b855fe8c83349c935f32c06cfc44d9813c Mon Sep 17 00:00:00 2001 From: mariosasko Date: Wed, 5 Oct 2022 13:50:43 +0200 Subject: [PATCH 4/4] Type hint fix --- src/datasets/io/abc.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/datasets/io/abc.py b/src/datasets/io/abc.py index 7498a72092b..0e2fe42bb40 100644 --- a/src/datasets/io/abc.py +++ b/src/datasets/io/abc.py @@ -1,10 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional, Union -from datasets.iterable_dataset import IterableDataset - -from .. import DatasetDict, Features, NamedSplit -from ..arrow_dataset import Dataset +from .. import Dataset, DatasetDict, Features, IterableDataset, IterableDatasetDict, NamedSplit from ..utils.typing import NestedDataStructureLike, PathLike @@ -28,7 +25,7 @@ def __init__( self.kwargs = kwargs @abstractmethod - def read(self) -> Union[Dataset, DatasetDict, IterableDataset, IterableDataset]: + def read(self) -> Union[Dataset, DatasetDict, IterableDataset, IterableDatasetDict]: pass