diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 36b744a024a..bc57174c90e 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1514,6 +1514,7 @@ def save_to_disk(
         num_shards: Optional[int] = None,
         num_proc: Optional[int] = None,
         storage_options: Optional[dict] = None,
+        flatten_indices: bool = True,
     ):
         """
         Saves a dataset to a dataset directory, or in a filesystem using any implementation of `fsspec.spec.AbstractFileSystem`.
@@ -1611,10 +1612,13 @@ def save_to_disk(
             total=len(self),
             desc=f"Saving the dataset ({shards_done}/{num_shards} shards)",
         )
+        # Passing flatten_indices=False skips rebuilding the dataset from its indices mapping,
+        # which can significantly speed up save_to_disk for datasets that have one.
+        dataset_for_sharding = self.flatten_indices() if (self._indices is not None and flatten_indices) else self
         kwargs_per_job = (
             {
                 "job_id": shard_idx,
-                "shard": self.shard(num_shards=num_shards, index=shard_idx, contiguous=True),
+                "shard": dataset_for_sharding.shard(num_shards=num_shards, index=shard_idx, contiguous=True),
                 "fpath": posixpath.join(dataset_path, f"data-{shard_idx:05d}-of-{num_shards:05d}.arrow"),
                 "storage_options": storage_options,
             }
diff --git a/src/datasets/dataset_dict.py b/src/datasets/dataset_dict.py
index 995103d26e0..5b93abfda42 100644
--- a/src/datasets/dataset_dict.py
+++ b/src/datasets/dataset_dict.py
@@ -1298,6 +1298,7 @@ def save_to_disk(
         num_shards: Optional[dict[str, int]] = None,
         num_proc: Optional[int] = None,
         storage_options: Optional[dict] = None,
+        flatten_indices: bool = True,
     ):
         """
         Saves a dataset dict to a filesystem using `fsspec.spec.AbstractFileSystem`.
@@ -1363,6 +1364,7 @@ def save_to_disk(
                 max_shard_size=max_shard_size,
                 num_proc=num_proc,
                 storage_options=storage_options,
+                flatten_indices=flatten_indices,
             )

    @staticmethod
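
Usage sketch of the new parameter (a minimal illustration assuming this patch is applied; the dataset construction and output paths below are placeholders, not part of the change):

    from datasets import Dataset

    ds = Dataset.from_dict({"x": list(range(1_000_000))})
    ds = ds.shuffle(seed=42)  # shuffle creates an indices mapping over the underlying Arrow table

    # Default behavior: the indices mapping is flattened (dataset rebuilt) before sharding and writing.
    ds.save_to_disk("out_flattened")

    # flatten_indices=False skips the upfront flatten_indices() call, which can significantly
    # speed up save_to_disk for datasets that carry an indices mapping.
    ds.save_to_disk("out_not_flattened", flatten_indices=False)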