
Commit 55f53a2

mariosasko and lhoestq authored
Add Dataset.from_generator (#4957)
* Add `Dataset.from_generator`
* Add tests
* Docs
* Doc typo
* Add Returns to docstring
* Docstring for some params
* Remove docs changes to test CI
* Add from_generator to package reference
* Return doc
* Fix docstring
* Update src/datasets/arrow_dataset.py (Co-authored-by: Quentin Lhoest <[email protected]>)
* Use for loop
* Call close on writer instances

Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 5b23f58 commit 55f53a2
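In short, this commit adds a `Dataset.from_generator` constructor that builds a dataset from a Python generator function. A minimal usage sketch, adapted from the docstring example added in `src/datasets/arrow_dataset.py` below:

```py
>>> from datasets import Dataset
>>> def gen():
...     yield {"text": "Good", "label": 0}
...     yield {"text": "Bad", "label": 1}
...
>>> ds = Dataset.from_generator(gen)
>>> ds[0]
{'text': 'Good', 'label': 0}
```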

File tree

9 files changed (+211 -1 lines):

- docs/source/loading.mdx
- docs/source/package_reference/main_classes.mdx
- src/datasets/arrow_dataset.py
- src/datasets/builder.py
- src/datasets/io/abc.py
- src/datasets/io/generator.py
- src/datasets/packaged_modules/generator/__init__.py
- src/datasets/packaged_modules/generator/generator.py
- tests/test_arrow_dataset.py

docs/source/loading.mdx

Lines changed: 15 additions & 0 deletions
@@ -220,6 +220,21 @@ Load a list of Python dictionaries with [`~Dataset.from_list`]:
 >>> dataset = Dataset.from_list(my_list)
 ```
 
+### Python generator
+
+Create a dataset from a Python generator with [`~Dataset.from_generator`]:
+
+```py
+>>> from datasets import Dataset
+>>> def my_gen():
+...     for i in range(1, 4):
+...         yield {"a": i}
+...
+>>> dataset = Dataset.from_generator(my_gen)
+```
+
+This approach supports loading data larger than available memory.
+
 ### Pandas DataFrame
 
 Load Pandas DataFrames with [`~Dataset.from_pandas`]:
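The generator is consumed lazily and written to an on-disk Arrow cache (unless `keep_in_memory=True`), which is what makes the larger-than-memory claim above work. A hedged sketch of that usage, assuming a hypothetical JSON Lines file named `large_file.jsonl`:

```py
>>> import json
>>> from datasets import Dataset
>>> def read_jsonl(path):
...     # Yield one example at a time instead of reading the whole file into memory.
...     with open(path, encoding="utf-8") as f:
...         for line in f:
...             yield json.loads(line)
...
>>> dataset = Dataset.from_generator(read_jsonl, gen_kwargs={"path": "large_file.jsonl"})
```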

docs/source/package_reference/main_classes.mdx

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table.
     - from_buffer
     - from_pandas
     - from_dict
+    - from_generator
     - data
     - cache_files
     - num_columns

src/datasets/arrow_dataset.py

Lines changed: 40 additions & 0 deletions
@@ -928,6 +928,46 @@ def from_csv(
             path_or_paths, split=split, features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs
         ).read()
 
+    @staticmethod
+    def from_generator(
+        generator: Callable,
+        features: Optional[Features] = None,
+        cache_dir: str = None,
+        keep_in_memory: bool = False,
+        gen_kwargs: Optional[dict] = None,
+    ):
+        """Create a Dataset from a generator.
+
+        Args:
+            generator (:obj:`Callable`): A generator function that `yields` examples.
+            features (:class:`Features`, optional): Dataset features.
+            cache_dir (:obj:`str`, optional, default ``"~/.cache/huggingface/datasets"``): Directory to cache data.
+            keep_in_memory (:obj:`bool`, default ``False``): Whether to copy the data in-memory.
+            gen_kwargs (:obj:`dict`, optional): Keyword arguments to be passed to the `generator` callable.
+
+        Returns:
+            :class:`Dataset`
+
+        Example:
+
+        ```py
+        >>> def gen():
+        ...     yield {"text": "Good", "label": 0}
+        ...     yield {"text": "Bad", "label": 1}
+        ...
+        >>> ds = Dataset.from_generator(gen)
+        ```
+        """
+        from .io.generator import GeneratorDatasetInputStream
+
+        return GeneratorDatasetInputStream(
+            generator=generator,
+            features=features,
+            cache_dir=cache_dir,
+            keep_in_memory=keep_in_memory,
+            gen_kwargs=gen_kwargs,
+        ).read()
+
     @staticmethod
     def from_json(
         path_or_paths: Union[PathLike, List[PathLike]],
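The new static method mirrors the other `from_*` constructors: everything except `generator` is optional. A small sketch exercising the remaining parameters from the signature above (the schema, cache path, and column names are illustrative assumptions):

```py
>>> from datasets import Dataset, Features, Value
>>> def gen(n):
...     for i in range(n):
...         yield {"id": i, "score": float(i)}
...
>>> features = Features({"id": Value("int32"), "score": Value("float32")})
>>> ds = Dataset.from_generator(
...     gen,
...     features=features,        # fix the schema instead of letting it be inferred
...     cache_dir="./my_cache",   # where the prepared Arrow files are written
...     keep_in_memory=False,     # True loads the result into RAM instead of memory-mapping it
...     gen_kwargs={"n": 10},     # forwarded to gen(...)
... )
>>> ds.features["id"].dtype
'int32'
```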

src/datasets/builder.py

Lines changed: 5 additions & 1 deletion
@@ -1371,6 +1371,7 @@ def _prepare_split(
                 ):
                     if max_shard_size is not None and writer._num_bytes > max_shard_size:
                         num_examples, num_bytes = writer.finalize()
+                        writer.close()
                         total_num_examples += num_examples
                         total_num_bytes += num_bytes
                         shard_id += 1
@@ -1382,11 +1383,12 @@
                             check_duplicates=check_duplicate_keys,
                             storage_options=self._fs.storage_options,
                         )
-                    example = self.info.features.encode_example(record)
+                    example = self.info.features.encode_example(record) if self.info.features is not None else record
                     writer.write(example, key)
         finally:
             num_shards = shard_id + 1
             num_examples, num_bytes = writer.finalize()
+            writer.close()
             total_num_examples += num_examples
             total_num_bytes += num_bytes
 
@@ -1492,6 +1494,7 @@ def _prepare_split(
                 ):
                     if max_shard_size is not None and writer._num_bytes > max_shard_size:
                         num_examples, num_bytes = writer.finalize()
+                        writer.close()
                         total_num_examples += num_examples
                         total_num_bytes += num_bytes
                         shard_id += 1
@@ -1504,6 +1507,7 @@
         finally:
             num_shards = shard_id + 1
             num_examples, num_bytes = writer.finalize()
+            writer.close()
             total_num_examples += num_examples
             total_num_bytes += num_bytes
 

src/datasets/io/abc.py

Lines changed: 18 additions & 0 deletions
@@ -26,3 +26,21 @@ def __init__(
     @abstractmethod
     def read(self) -> Union[Dataset, DatasetDict]:
         pass
+
+
+class AbstractDatasetInputStream(ABC):
+    def __init__(
+        self,
+        features: Optional[Features] = None,
+        cache_dir: str = None,
+        keep_in_memory: bool = False,
+        **kwargs,
+    ):
+        self.features = features
+        self.cache_dir = cache_dir
+        self.keep_in_memory = keep_in_memory
+        self.kwargs = kwargs
+
+    @abstractmethod
+    def read(self) -> Dataset:
+        pass

src/datasets/io/generator.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+from typing import Callable, Optional
+
+from .. import Features
+from ..packaged_modules.generator.generator import Generator
+from .abc import AbstractDatasetInputStream
+
+
+class GeneratorDatasetInputStream(AbstractDatasetInputStream):
+    def __init__(
+        self,
+        generator: Callable,
+        features: Optional[Features] = None,
+        cache_dir: str = None,
+        keep_in_memory: bool = False,
+        gen_kwargs: Optional[dict] = None,
+        **kwargs,
+    ):
+        super().__init__(features=features, cache_dir=cache_dir, keep_in_memory=keep_in_memory, **kwargs)
+        self.builder = Generator(
+            cache_dir=cache_dir,
+            features=features,
+            generator=generator,
+            gen_kwargs=gen_kwargs,
+            **kwargs,
+        )
+
+    def read(self):
+        download_config = None
+        download_mode = None
+        ignore_verifications = False
+        use_auth_token = None
+        base_path = None
+
+        self.builder.download_and_prepare(
+            download_config=download_config,
+            download_mode=download_mode,
+            ignore_verifications=ignore_verifications,
+            # try_from_hf_gcs=try_from_hf_gcs,
+            base_path=base_path,
+            use_auth_token=use_auth_token,
+        )
+
+        # Build dataset for splits
+        dataset = self.builder.as_dataset(
+            split="train", ignore_verifications=ignore_verifications, in_memory=self.keep_in_memory
+        )
+        return dataset
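Putting the pieces together: `Dataset.from_generator` (added in `arrow_dataset.py` above) just instantiates this input stream and calls `read()`, which in turn drives the packaged `Generator` builder. A hedged sketch of that equivalence, using only names introduced in this commit:

```py
>>> from datasets import Dataset
>>> from datasets.io.generator import GeneratorDatasetInputStream
>>> def gen():
...     yield {"a": 1}
...     yield {"a": 2}
...
>>> ds_public = Dataset.from_generator(gen)                          # public entry point
>>> ds_internal = GeneratorDatasetInputStream(generator=gen).read()  # what it delegates to
>>> ds_public.column_names == ds_internal.column_names == ["a"]
True
```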

src/datasets/packaged_modules/generator/__init__.py

Whitespace-only changes.
src/datasets/packaged_modules/generator/generator.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+from dataclasses import dataclass
+from typing import Callable, Optional
+
+import datasets
+
+
+@dataclass
+class GeneratorConfig(datasets.BuilderConfig):
+    generator: Optional[Callable] = None
+    gen_kwargs: Optional[dict] = None
+    features: Optional[datasets.Features] = None
+
+    def __post_init__(self):
+        assert self.generator is not None, "generator must be specified"
+
+        if self.gen_kwargs is None:
+            self.gen_kwargs = {}
+
+
+class Generator(datasets.GeneratorBasedBuilder):
+    BUILDER_CONFIG_CLASS = GeneratorConfig
+
+    def _info(self):
+        return datasets.DatasetInfo(features=self.config.features)
+
+    def _split_generators(self, dl_manager):
+        return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={})]
+
+    def _generate_examples(self):
+        for idx, ex in enumerate(self.config.generator(**self.config.gen_kwargs)):
+            yield idx, ex

tests/test_arrow_dataset.py

Lines changed: 54 additions & 0 deletions
@@ -3178,6 +3178,60 @@ def test_dataset_from_text_path_type(path_type, text_path, tmp_path):
     _check_text_dataset(dataset, expected_features)
 
 
+@pytest.fixture
+def data_generator():
+    def _gen():
+        data = [
+            {"col_1": "0", "col_2": 0, "col_3": 0.0},
+            {"col_1": "1", "col_2": 1, "col_3": 1.0},
+            {"col_1": "2", "col_2": 2, "col_3": 2.0},
+            {"col_1": "3", "col_2": 3, "col_3": 3.0},
+        ]
+        for item in data:
+            yield item
+
+    return _gen
+
+
+def _check_generator_dataset(dataset, expected_features):
+    assert isinstance(dataset, Dataset)
+    assert dataset.num_rows == 4
+    assert dataset.num_columns == 3
+    assert dataset.column_names == ["col_1", "col_2", "col_3"]
+    for feature, expected_dtype in expected_features.items():
+        assert dataset.features[feature].dtype == expected_dtype
+
+
+@pytest.mark.parametrize("keep_in_memory", [False, True])
+def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, tmp_path):
+    cache_dir = tmp_path / "cache"
+    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
+        dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
+    _check_generator_dataset(dataset, expected_features)
+
+
+@pytest.mark.parametrize(
+    "features",
+    [
+        None,
+        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
+        {"col_1": "string", "col_2": "string", "col_3": "string"},
+        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
+        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
+    ],
+)
+def test_dataset_from_generator_features(features, data_generator, tmp_path):
+    cache_dir = tmp_path / "cache"
+    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+    expected_features = features.copy() if features else default_expected_features
+    features = (
+        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
+    )
+    dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir)
+    _check_generator_dataset(dataset, expected_features)
+
+
 def test_dataset_to_json(dataset, tmp_path):
     file_path = tmp_path / "test_path.jsonl"
     bytes_written = dataset.to_json(path_or_buf=file_path)
