From 3524459d4a1372e2c9ac319d7afd2fd26544a9d9 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Mon, 1 Jul 2024 09:59:04 +0200 Subject: [PATCH 1/8] add split argument to Generator, from_generator, AbstractDatasetInputStream, GeneratorDatasetInputStream --- src/datasets/arrow_dataset.py | 4 ++++ src/datasets/io/abc.py | 2 ++ src/datasets/io/generator.py | 9 ++++++--- src/datasets/io/sql.py | 2 +- src/datasets/iterable_dataset.py | 9 ++++----- .../packaged_modules/generator/generator.py | 12 +++++++++++- tests/test_arrow_dataset.py | 15 ++++++++++----- 7 files changed, 38 insertions(+), 15 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 779091a75af..0e5bf79166f 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1068,6 +1068,7 @@ def from_generator( keep_in_memory: bool = False, gen_kwargs: Optional[dict] = None, num_proc: Optional[int] = None, + split: Optional[NamedSplit] = None, **kwargs, ): """Create a Dataset from a generator. @@ -1088,6 +1089,8 @@ def from_generator( Number of processes when downloading and generating the dataset locally. This is helpful if the dataset is made of multiple files. Multiprocessing is disabled by default. If `num_proc` is greater than one, then all list values in `gen_kwargs` must be the same length. These values will be split between calls to the generator. The number of shards will be the minimum of the shortest list in `gen_kwargs` and `num_proc`. + split (`str`, defaults to `"train"`): + Split name to be assigned to the dataset. **kwargs (additional keyword arguments): @@ -1126,6 +1129,7 @@ def from_generator( keep_in_memory=keep_in_memory, gen_kwargs=gen_kwargs, num_proc=num_proc, + split=split, **kwargs, ).read() diff --git a/src/datasets/io/abc.py b/src/datasets/io/abc.py index a1913cc20e3..9e0290d3c6f 100644 --- a/src/datasets/io/abc.py +++ b/src/datasets/io/abc.py @@ -39,6 +39,7 @@ def __init__( keep_in_memory: bool = False, streaming: bool = False, num_proc: Optional[int] = None, + split: Optional[NamedSplit] = None, **kwargs, ): self.features = features @@ -47,6 +48,7 @@ def __init__( self.streaming = streaming self.num_proc = num_proc self.kwargs = kwargs + self.split = split if split else "train" @abstractmethod def read(self) -> Union[Dataset, IterableDataset]: diff --git a/src/datasets/io/generator.py b/src/datasets/io/generator.py index 2566d5fcdcc..c2b47f7be98 100644 --- a/src/datasets/io/generator.py +++ b/src/datasets/io/generator.py @@ -1,6 +1,6 @@ from typing import Callable, Optional -from .. import Features +from .. import Features, NamedSplit from ..packaged_modules.generator.generator import Generator from .abc import AbstractDatasetInputStream @@ -15,6 +15,7 @@ def __init__( streaming: bool = False, gen_kwargs: Optional[dict] = None, num_proc: Optional[int] = None, + split: Optional[NamedSplit] = None, **kwargs, ): super().__init__( @@ -23,6 +24,7 @@ def __init__( keep_in_memory=keep_in_memory, streaming=streaming, num_proc=num_proc, + split=split, **kwargs, ) self.builder = Generator( @@ -30,13 +32,14 @@ def __init__( features=features, generator=generator, gen_kwargs=gen_kwargs, + split=split, **kwargs, ) def read(self): # Build iterable dataset if self.streaming: - dataset = self.builder.as_streaming_dataset(split="train") + dataset = self.builder.as_streaming_dataset(split=self.split) # Build regular (map-style) dataset else: download_config = None @@ -52,6 +55,6 @@ def read(self): num_proc=self.num_proc, ) dataset = self.builder.as_dataset( - split="train", verification_mode=verification_mode, in_memory=self.keep_in_memory + split=self.split, verification_mode=verification_mode, in_memory=self.keep_in_memory ) return dataset diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py index 2331e3e6407..1157f0b7643 100644 --- a/src/datasets/io/sql.py +++ b/src/datasets/io/sql.py @@ -48,7 +48,7 @@ def read(self): # Build dataset for splits dataset = self.builder.as_dataset( - split="train", verification_mode=verification_mode, in_memory=self.keep_in_memory + split=self.split, verification_mode=verification_mode, in_memory=self.keep_in_memory ) return dataset diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 3d0b3ce1cf3..f79d0bd9b83 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -2062,6 +2062,7 @@ def from_generator( generator: Callable, features: Optional[Features] = None, gen_kwargs: Optional[dict] = None, + split: Optional[NamedSplit] = None, ) -> "IterableDataset": """Create an Iterable Dataset from a generator. @@ -2074,7 +2075,8 @@ def from_generator( Keyword arguments to be passed to the `generator` callable. You can define a sharded iterable dataset by passing the list of shards in `gen_kwargs`. This can be used to improve shuffling and when iterating over the dataset with multiple workers. - + split(`str`, default="train"): + Split name to be assigned to the dataset. Returns: `IterableDataset` @@ -2105,10 +2107,7 @@ def from_generator( from .io.generator import GeneratorDatasetInputStream return GeneratorDatasetInputStream( - generator=generator, - features=features, - gen_kwargs=gen_kwargs, - streaming=True, + generator=generator, features=features, gen_kwargs=gen_kwargs, streaming=True, split=split ).read() @staticmethod diff --git a/src/datasets/packaged_modules/generator/generator.py b/src/datasets/packaged_modules/generator/generator.py index 336942f2edc..dec44b5f023 100644 --- a/src/datasets/packaged_modules/generator/generator.py +++ b/src/datasets/packaged_modules/generator/generator.py @@ -22,11 +22,21 @@ def __post_init__(self): class Generator(datasets.GeneratorBasedBuilder): BUILDER_CONFIG_CLASS = GeneratorConfig + def __init__( + self, + split: Optional[datasets.NamedSplit] = None, + **kwargs, + ): + self.split = split if split is not None else datasets.Split.TRAIN + return super().__init__( + **kwargs, + ) + def _info(self): return datasets.DatasetInfo(features=self.config.features) def _split_generators(self, dl_manager): - return [datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs=self.config.gen_kwargs)] + return [datasets.SplitGenerator(name=self.split, gen_kwargs=self.config.gen_kwargs)] def _generate_examples(self, **gen_kwargs): for idx, ex in enumerate(self.config.generator(**gen_kwargs)): diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index efa7b7ae4c8..3d0b9db186f 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -3867,10 +3867,11 @@ def _gen(): return _gen -def _check_generator_dataset(dataset, expected_features): +def _check_generator_dataset(dataset, expected_features, split): assert isinstance(dataset, Dataset) assert dataset.num_rows == 4 assert dataset.num_columns == 3 + assert dataset.split == split assert dataset.column_names == ["col_1", "col_2", "col_3"] for feature, expected_dtype in expected_features.items(): assert dataset.features[feature].dtype == expected_dtype @@ -3880,9 +3881,12 @@ def _check_generator_dataset(dataset, expected_features): def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, tmp_path): cache_dir = tmp_path / "cache" expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + split = NamedSplit("validation") with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): - dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory) - _check_generator_dataset(dataset, expected_features) + dataset = Dataset.from_generator( + data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory, split=split + ) + _check_generator_dataset(dataset, expected_features, split) @pytest.mark.parametrize( @@ -3898,12 +3902,13 @@ def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, t def test_dataset_from_generator_features(features, data_generator, tmp_path): cache_dir = tmp_path / "cache" default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + split = NamedSplit("validation") expected_features = features.copy() if features else default_expected_features features = ( Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None ) - dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir) - _check_generator_dataset(dataset, expected_features) + dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir, split=split) + _check_generator_dataset(dataset, expected_features, split) @require_not_windows From eef7c96a5272d5939d6c50b8e985085b41acd736 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Wed, 10 Jul 2024 07:48:17 +0200 Subject: [PATCH 2/8] split generator review feedbacks --- src/datasets/arrow_dataset.py | 8 +++--- src/datasets/io/abc.py | 2 -- src/datasets/io/generator.py | 7 +++--- src/datasets/io/sql.py | 2 +- src/datasets/iterable_dataset.py | 6 ++--- .../packaged_modules/generator/generator.py | 5 ++-- tests/test_arrow_dataset.py | 25 +++++++++++++------ 7 files changed, 31 insertions(+), 24 deletions(-) diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 0e5bf79166f..ca380a9c3d7 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1068,7 +1068,7 @@ def from_generator( keep_in_memory: bool = False, gen_kwargs: Optional[dict] = None, num_proc: Optional[int] = None, - split: Optional[NamedSplit] = None, + split: NamedSplit = Split.TRAIN, **kwargs, ): """Create a Dataset from a generator. @@ -1089,10 +1089,12 @@ def from_generator( Number of processes when downloading and generating the dataset locally. This is helpful if the dataset is made of multiple files. Multiprocessing is disabled by default. If `num_proc` is greater than one, then all list values in `gen_kwargs` must be the same length. These values will be split between calls to the generator. The number of shards will be the minimum of the shortest list in `gen_kwargs` and `num_proc`. - split (`str`, defaults to `"train"`): - Split name to be assigned to the dataset. + split ([`NamedSplit`], defaults to `Split.TRAIN`): + Split name to be assigned to the dataset. + + **kwargs (additional keyword arguments): Keyword arguments to be passed to :[`GeneratorConfig`]. diff --git a/src/datasets/io/abc.py b/src/datasets/io/abc.py index 9e0290d3c6f..a1913cc20e3 100644 --- a/src/datasets/io/abc.py +++ b/src/datasets/io/abc.py @@ -39,7 +39,6 @@ def __init__( keep_in_memory: bool = False, streaming: bool = False, num_proc: Optional[int] = None, - split: Optional[NamedSplit] = None, **kwargs, ): self.features = features @@ -48,7 +47,6 @@ def __init__( self.streaming = streaming self.num_proc = num_proc self.kwargs = kwargs - self.split = split if split else "train" @abstractmethod def read(self) -> Union[Dataset, IterableDataset]: diff --git a/src/datasets/io/generator.py b/src/datasets/io/generator.py index c2b47f7be98..28009c74c48 100644 --- a/src/datasets/io/generator.py +++ b/src/datasets/io/generator.py @@ -15,7 +15,7 @@ def __init__( streaming: bool = False, gen_kwargs: Optional[dict] = None, num_proc: Optional[int] = None, - split: Optional[NamedSplit] = None, + split: NamedSplit = Split.TRAIN, **kwargs, ): super().__init__( @@ -24,7 +24,6 @@ def __init__( keep_in_memory=keep_in_memory, streaming=streaming, num_proc=num_proc, - split=split, **kwargs, ) self.builder = Generator( @@ -39,7 +38,7 @@ def __init__( def read(self): # Build iterable dataset if self.streaming: - dataset = self.builder.as_streaming_dataset(split=self.split) + dataset = self.builder.as_streaming_dataset(split=self.builder.config.split) # Build regular (map-style) dataset else: download_config = None @@ -55,6 +54,6 @@ def read(self): num_proc=self.num_proc, ) dataset = self.builder.as_dataset( - split=self.split, verification_mode=verification_mode, in_memory=self.keep_in_memory + split=self.builder.config.split, verification_mode=verification_mode, in_memory=self.keep_in_memory ) return dataset diff --git a/src/datasets/io/sql.py b/src/datasets/io/sql.py index 1157f0b7643..2331e3e6407 100644 --- a/src/datasets/io/sql.py +++ b/src/datasets/io/sql.py @@ -48,7 +48,7 @@ def read(self): # Build dataset for splits dataset = self.builder.as_dataset( - split=self.split, verification_mode=verification_mode, in_memory=self.keep_in_memory + split="train", verification_mode=verification_mode, in_memory=self.keep_in_memory ) return dataset diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index f79d0bd9b83..29114689ea4 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -2062,7 +2062,7 @@ def from_generator( generator: Callable, features: Optional[Features] = None, gen_kwargs: Optional[dict] = None, - split: Optional[NamedSplit] = None, + split: NamedSplit = Split.TRAIN, ) -> "IterableDataset": """Create an Iterable Dataset from a generator. @@ -2075,7 +2075,7 @@ def from_generator( Keyword arguments to be passed to the `generator` callable. You can define a sharded iterable dataset by passing the list of shards in `gen_kwargs`. This can be used to improve shuffling and when iterating over the dataset with multiple workers. - split(`str`, default="train"): + split(`NamedSplit`, default=Split.TRAIN): Split name to be assigned to the dataset. Returns: `IterableDataset` @@ -2867,7 +2867,7 @@ def _resolve_features(self): def _concatenate_iterable_datasets( dsets: List[IterableDataset], info: Optional[DatasetInfo] = None, - split: Optional[NamedSplit] = None, + split: NamedSplit = None, axis: int = 0, ) -> IterableDataset: """ diff --git a/src/datasets/packaged_modules/generator/generator.py b/src/datasets/packaged_modules/generator/generator.py index dec44b5f023..4ae972bd1c2 100644 --- a/src/datasets/packaged_modules/generator/generator.py +++ b/src/datasets/packaged_modules/generator/generator.py @@ -9,6 +9,7 @@ class GeneratorConfig(datasets.BuilderConfig): generator: Optional[Callable] = None gen_kwargs: Optional[dict] = None features: Optional[datasets.Features] = None + split: datasets.NamedSplit = datasets.Split.TRAIN def __post_init__(self): super().__post_init__() @@ -24,10 +25,8 @@ class Generator(datasets.GeneratorBasedBuilder): def __init__( self, - split: Optional[datasets.NamedSplit] = None, **kwargs, ): - self.split = split if split is not None else datasets.Split.TRAIN return super().__init__( **kwargs, ) @@ -36,7 +35,7 @@ def _info(self): return datasets.DatasetInfo(features=self.config.features) def _split_generators(self, dl_manager): - return [datasets.SplitGenerator(name=self.split, gen_kwargs=self.config.gen_kwargs)] + return [datasets.SplitGenerator(name=self.config.split, gen_kwargs=self.config.gen_kwargs)] def _generate_examples(self, **gen_kwargs): for idx, ex in enumerate(self.config.generator(**gen_kwargs)): diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 3d0b9db186f..329d264df75 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -3881,12 +3881,9 @@ def _check_generator_dataset(dataset, expected_features, split): def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, tmp_path): cache_dir = tmp_path / "cache" expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} - split = NamedSplit("validation") with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase(): - dataset = Dataset.from_generator( - data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory, split=split - ) - _check_generator_dataset(dataset, expected_features, split) + dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory) + _check_generator_dataset(dataset, expected_features, NamedSplit("train")) @pytest.mark.parametrize( @@ -3902,13 +3899,25 @@ def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, t def test_dataset_from_generator_features(features, data_generator, tmp_path): cache_dir = tmp_path / "cache" default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} - split = NamedSplit("validation") expected_features = features.copy() if features else default_expected_features features = ( Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None ) - dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir, split=split) - _check_generator_dataset(dataset, expected_features, split) + dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir) + _check_generator_dataset(dataset, expected_features, NamedSplit("train")) + + +@pytest.mark.parametrize( + "split", + [None, NamedSplit("train"), "train", NamedSplit("foo"), "foo"], +) +def test_dataset_from_generator_split(split, data_generator, tmp_path): + cache_dir = tmp_path / "cache" + default_expected_split = "train" + expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} + expected_split = split if split else default_expected_split + dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, split=split) + _check_generator_dataset(dataset, expected_features, expected_split) @require_not_windows From bdd96629cc28098866f2d1e374f2496f82a21432 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Wed, 10 Jul 2024 07:58:54 +0200 Subject: [PATCH 3/8] import Split --- src/datasets/io/generator.py | 2 +- src/datasets/iterable_dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/io/generator.py b/src/datasets/io/generator.py index 28009c74c48..b10609cac23 100644 --- a/src/datasets/io/generator.py +++ b/src/datasets/io/generator.py @@ -1,6 +1,6 @@ from typing import Callable, Optional -from .. import Features, NamedSplit +from .. import Features, NamedSplit, Split from ..packaged_modules.generator.generator import Generator from .abc import AbstractDatasetInputStream diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index 29114689ea4..e31d6dc0eee 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -19,7 +19,7 @@ from .features.features import FeatureType, _align_features, _check_if_features_can_be_aligned, cast_to_python_objects from .formatting import PythonFormatter, TensorFormatter, get_format_type_from_alias, get_formatter from .info import DatasetInfo -from .splits import NamedSplit +from .splits import NamedSplit, Split from .table import cast_table_to_features, read_schema_from_file, table_cast from .utils.logging import get_logger from .utils.py_utils import Literal From 5512e3fbac2116c4f4ad875573b7545839b7e559 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Wed, 10 Jul 2024 08:08:06 +0200 Subject: [PATCH 4/8] tag added version in iterable_dataset, rollback change in _concatenate_iterable_datasets --- src/datasets/iterable_dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index e31d6dc0eee..d1be6d0c2ce 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -2077,6 +2077,8 @@ def from_generator( This can be used to improve shuffling and when iterating over the dataset with multiple workers. split(`NamedSplit`, default=Split.TRAIN): Split name to be assigned to the dataset. + + Returns: `IterableDataset` @@ -2867,7 +2869,7 @@ def _resolve_features(self): def _concatenate_iterable_datasets( dsets: List[IterableDataset], info: Optional[DatasetInfo] = None, - split: NamedSplit = None, + split: Optional[NamedSplit] = None, axis: int = 0, ) -> IterableDataset: """ From 6f1c18bf0a639d8d17dda50ccdd9b2d8bc2f6a66 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Wed, 10 Jul 2024 08:11:38 +0200 Subject: [PATCH 5/8] rm useless Generator __init__ --- src/datasets/packaged_modules/generator/generator.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/datasets/packaged_modules/generator/generator.py b/src/datasets/packaged_modules/generator/generator.py index 4ae972bd1c2..8a42ba05aa6 100644 --- a/src/datasets/packaged_modules/generator/generator.py +++ b/src/datasets/packaged_modules/generator/generator.py @@ -23,14 +23,6 @@ def __post_init__(self): class Generator(datasets.GeneratorBasedBuilder): BUILDER_CONFIG_CLASS = GeneratorConfig - def __init__( - self, - **kwargs, - ): - return super().__init__( - **kwargs, - ) - def _info(self): return datasets.DatasetInfo(features=self.config.features) From d74a86249a0b6df989113acfb1d2e8276d78727e Mon Sep 17 00:00:00 2001 From: Colle Date: Wed, 10 Jul 2024 09:02:51 +0200 Subject: [PATCH 6/8] docstring formatting Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> --- src/datasets/iterable_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index d1be6d0c2ce..de7c11d341c 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -2075,7 +2075,7 @@ def from_generator( Keyword arguments to be passed to the `generator` callable. You can define a sharded iterable dataset by passing the list of shards in `gen_kwargs`. This can be used to improve shuffling and when iterating over the dataset with multiple workers. - split(`NamedSplit`, default=Split.TRAIN): + split ([`NamedSplit`], defaults to `Split.TRAIN`) Split name to be assigned to the dataset. From 7e50f23815df06eb79e0321c33d9459d3880a898 Mon Sep 17 00:00:00 2001 From: Colle Date: Wed, 10 Jul 2024 10:12:43 +0200 Subject: [PATCH 7/8] format docstring Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> --- src/datasets/iterable_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py index de7c11d341c..019d38e73b4 100644 --- a/src/datasets/iterable_dataset.py +++ b/src/datasets/iterable_dataset.py @@ -2075,7 +2075,7 @@ def from_generator( Keyword arguments to be passed to the `generator` callable. You can define a sharded iterable dataset by passing the list of shards in `gen_kwargs`. This can be used to improve shuffling and when iterating over the dataset with multiple workers. - split ([`NamedSplit`], defaults to `Split.TRAIN`) + split ([`NamedSplit`], defaults to `Split.TRAIN`): Split name to be assigned to the dataset. From 96b9e372c3584ff01829a1619bf3f6077919b9f4 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Thu, 11 Jul 2024 09:22:12 +0200 Subject: [PATCH 8/8] fix test_dataset_from_generator_split[None] --- tests/test_arrow_dataset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 329d264df75..06c40608b61 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -3916,7 +3916,10 @@ def test_dataset_from_generator_split(split, data_generator, tmp_path): default_expected_split = "train" expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"} expected_split = split if split else default_expected_split - dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, split=split) + if split: + dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, split=split) + else: + dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir) _check_generator_dataset(dataset, expected_features, expected_split)