diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 0836a04ebe9..43301d23041 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1120,6 +1120,7 @@ def from_generator(
         gen_kwargs: Optional[dict] = None,
         num_proc: Optional[int] = None,
         split: NamedSplit = Split.TRAIN,
+        fingerprint: Optional[str] = None,
         **kwargs,
     ):
         """Create a Dataset from a generator.
@@ -1146,6 +1147,12 @@ def from_generator(
 
                 Split name to be assigned to the dataset.
 
+            fingerprint (`str`, *optional*):
+                Fingerprint that will be used to generate dataset ID.
+                By default `fingerprint` is generated by hashing the generator function and all the args which can be slow
+                if it uses large objects like AI models.
+
+
             **kwargs (additional keyword arguments):
                 Keyword arguments to be passed to :[`GeneratorConfig`].
@@ -1183,6 +1190,7 @@ def from_generator(
             gen_kwargs=gen_kwargs,
             num_proc=num_proc,
             split=split,
+            fingerprint=fingerprint,
             **kwargs,
         ).read()
 
diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index e63960dcabf..b88aa0bf8f9 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -313,6 +313,7 @@ def __init__(
         data_dir: Optional[str] = None,
         storage_options: Optional[dict] = None,
         writer_batch_size: Optional[int] = None,
+        config_id: Optional[str] = None,
         **config_kwargs,
     ):
         # DatasetBuilder name
@@ -343,6 +344,7 @@ def __init__(
         self.config, self.config_id = self._create_builder_config(
             config_name=config_name,
             custom_features=features,
+            config_id=config_id,
             **config_kwargs,
         )
 
@@ -502,7 +504,7 @@ def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> st
         return legacy_relative_data_dir
 
     def _create_builder_config(
-        self, config_name=None, custom_features=None, **config_kwargs
+        self, config_name=None, custom_features=None, config_id=None, **config_kwargs
     ) -> tuple[BuilderConfig, str]:
         """Create and validate BuilderConfig object as well as a unique config id for this config.
 
         Raises ValueError if there are multiple builder configs and config_name and DEFAULT_CONFIG_NAME are None.
@@ -570,10 +572,11 @@ def _create_builder_config(
             )
 
         # compute the config id that is going to be used for caching
-        config_id = builder_config.create_config_id(
-            config_kwargs,
-            custom_features=custom_features,
-        )
+        if config_id is None:
+            config_id = builder_config.create_config_id(
+                config_kwargs,
+                custom_features=custom_features,
+            )
         is_custom = (config_id not in self.builder_configs) and config_id != "default"
         if is_custom:
             logger.info(f"Using custom data configuration {config_id}")
diff --git a/src/datasets/io/generator.py b/src/datasets/io/generator.py
index b10609cac23..6c1eaee9b0f 100644
--- a/src/datasets/io/generator.py
+++ b/src/datasets/io/generator.py
@@ -16,6 +16,7 @@ def __init__(
         gen_kwargs: Optional[dict] = None,
         num_proc: Optional[int] = None,
         split: NamedSplit = Split.TRAIN,
+        fingerprint: Optional[str] = None,
         **kwargs,
     ):
         super().__init__(
@@ -32,8 +33,10 @@ def __init__(
             generator=generator,
             gen_kwargs=gen_kwargs,
             split=split,
+            config_id="default-fingerprint=" + fingerprint if fingerprint else None,
             **kwargs,
         )
+        self.fingerprint = fingerprint
 
     def read(self):
         # Build iterable dataset
@@ -56,4 +59,6 @@ def read(self):
         dataset = self.builder.as_dataset(
             split=self.builder.config.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
         )
+        if self.fingerprint:
+            dataset._fingerprint = self.fingerprint
         return dataset
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 4661e8c6dd7..8e76952d6ca 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -4114,6 +4114,16 @@ def test_dataset_from_generator_split(split, data_generator, tmp_path):
     _check_generator_dataset(dataset, expected_features, expected_split)
 
 
+@pytest.mark.parametrize("fingerprint", [None, "test-dataset"])
+def test_dataset_from_generator_fingerprint(fingerprint, data_generator, tmp_path):
+    cache_dir = tmp_path / "cache"
+    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+    dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, fingerprint=fingerprint)
+    _check_generator_dataset(dataset, expected_features, NamedSplit("train"))
+    if fingerprint:
+        assert dataset._fingerprint == fingerprint
+
+
 @require_not_windows
 @require_dill_gt_0_3_2
 @require_pyspark