diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 472a09c9a6e..032b4d6b3c2 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -1361,6 +1361,11 @@ def from_sql( **kwargs, ).read() + def __setstate__(self, state): + self.__dict__.update(state) + maybe_register_dataset_for_temp_dir_deletion(self) + return self + def __del__(self): if hasattr(self, "_data"): del self._data diff --git a/src/datasets/config.py b/src/datasets/config.py index 851eb7d899f..d01f4e6f735 100644 --- a/src/datasets/config.py +++ b/src/datasets/config.py @@ -231,6 +231,9 @@ MAX_DATASET_CONFIG_ID_READABLE_LENGTH = 255 +# Temporary cache directory prefix +TEMP_CACHE_DIR_PREFIX = "hf_datasets-" + # Streaming STREAMING_READ_MAX_RETRIES = 20 STREAMING_READ_RETRY_INTERVAL = 5 diff --git a/src/datasets/fingerprint.py b/src/datasets/fingerprint.py index 7d2f2bdb98a..d7b3220ad96 100644 --- a/src/datasets/fingerprint.py +++ b/src/datasets/fingerprint.py @@ -11,6 +11,7 @@ import numpy as np import xxhash +from . import config from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH from .utils._dill import dumps from .utils.deprecation_utils import deprecated @@ -38,28 +39,30 @@ ################# _CACHING_ENABLED = True -_TEMP_DIR_FOR_TEMP_CACHE_FILES: Optional["_TempDirWithCustomCleanup"] = None +_TEMP_DIR_FOR_TEMP_CACHE_FILES: Optional["_TempCacheDir"] = None _DATASETS_WITH_TABLE_IN_TEMP_DIR: Optional[weakref.WeakSet] = None -class _TempDirWithCustomCleanup: +class _TempCacheDir: """ - A temporary directory with a custom cleanup function. - We need a custom temporary directory cleanup in order to delete the dataset objects that have - cache files in the temporary directory before deleting the dorectory itself. + A temporary directory for storing cached Arrow files with a cleanup that frees references to the Arrow files + before deleting the directory itself to avoid permission errors on Windows. """ - def __init__(self, cleanup_func=None, *cleanup_func_args, **cleanup_func_kwargs): - self.name = tempfile.mkdtemp() + def __init__(self): + self.name = tempfile.mkdtemp(prefix=config.TEMP_CACHE_DIR_PREFIX) self._finalizer = weakref.finalize(self, self._cleanup) - self._cleanup_func = cleanup_func - self._cleanup_func_args = cleanup_func_args - self._cleanup_func_kwargs = cleanup_func_kwargs def _cleanup(self): - self._cleanup_func(*self._cleanup_func_args, **self._cleanup_func_kwargs) + for dset in get_datasets_with_cache_file_in_temp_dir(): + dset.__del__() if os.path.exists(self.name): - shutil.rmtree(self.name) + try: + shutil.rmtree(self.name) + except Exception as e: + raise OSError( + f"An error occured while trying to delete temporary cache directory {self.name}. Please delete it manually." + ) from e def cleanup(self): if self._finalizer.detach(): @@ -180,13 +183,7 @@ def get_temporary_cache_files_directory() -> str: """Return a directory that is deleted when session closes.""" global _TEMP_DIR_FOR_TEMP_CACHE_FILES if _TEMP_DIR_FOR_TEMP_CACHE_FILES is None: - # Avoids a PermissionError on Windows caused by the datasets referencing - # the files from the cache directory on clean-up - def cleanup_func(): - for dset in get_datasets_with_cache_file_in_temp_dir(): - dset.__del__() - - _TEMP_DIR_FOR_TEMP_CACHE_FILES = _TempDirWithCustomCleanup(cleanup_func=cleanup_func) + _TEMP_DIR_FOR_TEMP_CACHE_FILES = _TempCacheDir() return _TEMP_DIR_FOR_TEMP_CACHE_FILES.name diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 4d4ecc9802b..0b9b4e5a0a7 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -1392,8 +1392,14 @@ def test_map_caching(self, in_memory): self.assertEqual(len(dset_test2.cache_files), 1) self.assertNotIn("Loading cached processed dataset", self._caplog.text) # make sure the arrow files are going to be removed - self.assertIn("tmp", dset_test1.cache_files[0]["filename"]) - self.assertIn("tmp", dset_test2.cache_files[0]["filename"]) + self.assertIn( + Path(tempfile.gettempdir()), + Path(dset_test1.cache_files[0]["filename"]).parents, + ) + self.assertIn( + Path(tempfile.gettempdir()), + Path(dset_test2.cache_files[0]["filename"]).parents, + ) finally: datasets.enable_caching() @@ -3985,11 +3991,11 @@ def test_build_local_temp_path(uri_or_path): extracted_path = strip_protocol(uri_or_path) local_temp_path = Dataset._build_local_temp_path(extracted_path).as_posix() extracted_path_without_anchor = Path(extracted_path).relative_to(Path(extracted_path).anchor).as_posix() - path_relative_to_tmp_dir = local_temp_path.split("tmp")[-1].split("/", 1)[1] + # Check that the local temp path is relative to the system temp dir + path_relative_to_tmp_dir = Path(local_temp_path).relative_to(Path(tempfile.gettempdir())).as_posix() assert ( - "tmp" in local_temp_path - and "hdfs" not in path_relative_to_tmp_dir + "hdfs" not in path_relative_to_tmp_dir and "s3" not in path_relative_to_tmp_dir and not local_temp_path.startswith(extracted_path_without_anchor) and local_temp_path.endswith(extracted_path_without_anchor)