diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 0836a04ebe9..43301d23041 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1120,6 +1120,7 @@ def from_generator(
gen_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
split: NamedSplit = Split.TRAIN,
+ fingerprint: Optional[str] = None,
**kwargs,
):
"""Create a Dataset from a generator.
@@ -1146,6 +1147,12 @@ def from_generator(
Split name to be assigned to the dataset.
+            fingerprint (`str`, *optional*):
+                Fingerprint that will be used to generate the dataset ID.
+                By default, `fingerprint` is generated by hashing the generator function and all of its arguments,
+                which can be slow if they include large objects such as AI models.
+
+
**kwargs (additional keyword arguments):
Keyword arguments to be passed to :[`GeneratorConfig`].
@@ -1183,6 +1190,7 @@ def from_generator(
gen_kwargs=gen_kwargs,
num_proc=num_proc,
split=split,
+ fingerprint=fingerprint,
**kwargs,
).read()
diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index e63960dcabf..b88aa0bf8f9 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -313,6 +313,7 @@ def __init__(
data_dir: Optional[str] = None,
storage_options: Optional[dict] = None,
writer_batch_size: Optional[int] = None,
+ config_id: Optional[str] = None,
**config_kwargs,
):
# DatasetBuilder name
@@ -343,6 +344,7 @@ def __init__(
self.config, self.config_id = self._create_builder_config(
config_name=config_name,
custom_features=features,
+ config_id=config_id,
**config_kwargs,
)
@@ -502,7 +504,7 @@ def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> st
return legacy_relative_data_dir
def _create_builder_config(
- self, config_name=None, custom_features=None, **config_kwargs
+ self, config_name=None, custom_features=None, config_id=None, **config_kwargs
) -> tuple[BuilderConfig, str]:
"""Create and validate BuilderConfig object as well as a unique config id for this config.
Raises ValueError if there are multiple builder configs and config_name and DEFAULT_CONFIG_NAME are None.
@@ -570,10 +572,11 @@ def _create_builder_config(
)
# compute the config id that is going to be used for caching
- config_id = builder_config.create_config_id(
- config_kwargs,
- custom_features=custom_features,
- )
+ if config_id is None:
+ config_id = builder_config.create_config_id(
+ config_kwargs,
+ custom_features=custom_features,
+ )
is_custom = (config_id not in self.builder_configs) and config_id != "default"
if is_custom:
logger.info(f"Using custom data configuration {config_id}")
diff --git a/src/datasets/io/generator.py b/src/datasets/io/generator.py
index b10609cac23..6c1eaee9b0f 100644
--- a/src/datasets/io/generator.py
+++ b/src/datasets/io/generator.py
@@ -16,6 +16,7 @@ def __init__(
gen_kwargs: Optional[dict] = None,
num_proc: Optional[int] = None,
split: NamedSplit = Split.TRAIN,
+ fingerprint: Optional[str] = None,
**kwargs,
):
super().__init__(
@@ -32,8 +33,10 @@ def __init__(
generator=generator,
gen_kwargs=gen_kwargs,
split=split,
+ config_id="default-fingerprint=" + fingerprint if fingerprint else None,
**kwargs,
)
+ self.fingerprint = fingerprint
def read(self):
# Build iterable dataset
@@ -56,4 +59,6 @@ def read(self):
dataset = self.builder.as_dataset(
split=self.builder.config.split, verification_mode=verification_mode, in_memory=self.keep_in_memory
)
+ if self.fingerprint:
+ dataset._fingerprint = self.fingerprint
return dataset
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 4661e8c6dd7..8e76952d6ca 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -4114,6 +4114,16 @@ def test_dataset_from_generator_split(split, data_generator, tmp_path):
_check_generator_dataset(dataset, expected_features, expected_split)
+@pytest.mark.parametrize("fingerprint", [None, "test-dataset"])
+def test_dataset_from_generator_fingerprint(fingerprint, data_generator, tmp_path):
+ cache_dir = tmp_path / "cache"
+ expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+ dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, fingerprint=fingerprint)
+ _check_generator_dataset(dataset, expected_features, NamedSplit("train"))
+ if fingerprint:
+ assert dataset._fingerprint == fingerprint
+
+
@require_not_windows
@require_dill_gt_0_3_2
@require_pyspark