diff --git a/src/datasets/load.py b/src/datasets/load.py index f9bbad0f04c..c28335ecb7c 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -49,6 +49,7 @@ hf_github_url, hf_hub_url, init_hf_modules, + relative_to_absolute_path, url_or_path_join, url_or_path_parent, ) @@ -246,7 +247,7 @@ def prepare_module( script_version (Optional ``Union[str, datasets.Version]``): If specified, the module will be loaded from the datasets repository at this version. By default: - - it is set to the local version fo the lib. + - it is set to the local version of the lib. - it will also try to load it from the master branch if it's not available at the local version fo the lib. Specifying a version that is different from your local version of the lib might cause compatibility issues. download_config (Optional ``datasets.DownloadConfig``: specific download configuration parameters. @@ -307,15 +308,18 @@ def prepare_module( # - if os.path.join(path, name) is a file or a remote url # - if path is a file or a remote url # - otherwise we assume path/name is a path to our S3 bucket - combined_path = os.path.join(path, name) + combined_path = path if path.endswith(name) else os.path.join(path, name) + if os.path.isfile(combined_path): file_path = combined_path - local_path = file_path + local_path = combined_path elif os.path.isfile(path): file_path = path local_path = path else: # Try github (canonical datasets/metrics) and then S3 (users datasets/metrics) + + combined_path_abs = relative_to_absolute_path(combined_path) try: head_hf_s3(path, filename=name, dataset=dataset, max_retries=download_config.max_retries) script_version = str(script_version) if script_version is not None else None @@ -326,7 +330,7 @@ def prepare_module( except FileNotFoundError: if script_version is not None: raise FileNotFoundError( - "Couldn't find remote file with version {} at {}. Please provide a valid version and a valid {} name".format( + "Couldn't find remote file with version {} at {}. Please provide a valid version and a valid {} name.".format( script_version, file_path, "dataset" if dataset else "metric" ) ) @@ -338,14 +342,14 @@ def prepare_module( logger.warning( "Couldn't find file locally at {}, or remotely at {}.\n" "The file was picked from the master branch on github instead at {}.".format( - combined_path, github_file_path, file_path + combined_path_abs, github_file_path, file_path ) ) except FileNotFoundError: raise FileNotFoundError( "Couldn't find file locally at {}, or remotely at {}.\n" "The file is also not present on the master branch on github.".format( - combined_path, github_file_path + combined_path_abs, github_file_path ) ) elif path.count("/") == 1: # users datasets/metrics: s3 path (hub for datasets and s3 for metrics) @@ -357,14 +361,14 @@ def prepare_module( local_path = cached_path(file_path, download_config=download_config) except FileNotFoundError: raise FileNotFoundError( - "Couldn't find file locally at {}, or remotely at {}. Please provide a valid {} name".format( - combined_path, file_path, "dataset" if dataset else "metric" + "Couldn't find file locally at {}, or remotely at {}. Please provide a valid {} name.".format( + combined_path_abs, file_path, "dataset" if dataset else "metric" ) ) else: raise FileNotFoundError( - "Couldn't find file locally at {}. Please provide a valid {} name".format( - combined_path, "dataset" if dataset else "metric" + "Couldn't find file locally at {}. Please provide a valid {} name.".format( + combined_path_abs, "dataset" if dataset else "metric" ) ) except Exception as e: # noqa: all the attempts failed, before raising the error we should check if the module already exists. @@ -382,7 +386,7 @@ def _get_modification_time(module_hash): logger.warning( f"Using the latest cached version of the module from {os.path.join(main_folder_path, hash)} " f"(last modified on {time.ctime(_get_modification_time(hash))}) since it " - f"couldn't be found locally at {combined_path} or remotely ({type(e).__name__})." + f"couldn't be found locally at {combined_path_abs}, or remotely ({type(e).__name__})." ) if return_resolved_file_path: with open(os.path.join(main_folder_path, hash, short_name + ".json")) as cache_metadata: @@ -421,7 +425,7 @@ def _get_modification_time(module_hash): f"Error in {module_type} script at {file_path}, importing relative {import_name} module " f"but {import_name} is the name of the {module_type} script. " f"Please change relative import {import_name} to another name and add a '# From: URL_OR_PATH' " - f"comment pointing to the original realtive import file path." + f"comment pointing to the original relative import file path." ) if import_type == "internal": url_or_filename = url_or_path_join(base_path, import_path + ".py") diff --git a/src/datasets/utils/__init__.py b/src/datasets/utils/__init__.py index 5f3165cb489..79e79ae3834 100644 --- a/src/datasets/utils/__init__.py +++ b/src/datasets/utils/__init__.py @@ -19,7 +19,7 @@ from . import logging from .download_manager import DownloadManager, GenerateMode -from .file_utils import DownloadConfig, cached_path, hf_bucket_url, is_remote_url, temp_seed +from .file_utils import DownloadConfig, cached_path, hf_bucket_url, is_remote_url, relative_to_absolute_path, temp_seed from .mock_download_manager import MockDownloadManager from .py_utils import ( NonMutableDict, diff --git a/src/datasets/utils/file_utils.py b/src/datasets/utils/file_utils.py index fa967c0feae..6035bc69ff5 100644 --- a/src/datasets/utils/file_utils.py +++ b/src/datasets/utils/file_utils.py @@ -18,7 +18,7 @@ from functools import partial from hashlib import sha256 from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, TypeVar, Union from urllib.parse import urlparse import numpy as np @@ -37,6 +37,8 @@ INCOMPLETE_SUFFIX = ".incomplete" +T = TypeVar("T", str, Path) + def init_hf_modules(hf_modules_cache: Optional[Union[Path, str]] = None) -> str: """ @@ -127,6 +129,12 @@ def is_relative_path(url_or_filename: str) -> bool: return urlparse(url_or_filename).scheme == "" and not os.path.isabs(url_or_filename) +def relative_to_absolute_path(path: T) -> T: + """Convert relative path to absolute path.""" + abs_path_str = os.path.abspath(os.path.expanduser(os.path.expandvars(str(path)))) + return Path(abs_path_str) if isinstance(path, Path) else abs_path_str + + def hf_bucket_url(identifier: str, filename: str, use_cdn=False, dataset=True) -> str: if dataset: endpoint = config.CLOUDFRONT_DATASETS_DISTRIB_PREFIX if use_cdn else config.S3_DATASETS_BUCKET_PREFIX diff --git a/src/datasets/utils/filelock.py b/src/datasets/utils/filelock.py index fe4452dcb78..d440348e7c2 100644 --- a/src/datasets/utils/filelock.py +++ b/src/datasets/utils/filelock.py @@ -353,8 +353,10 @@ class WindowsFileLock(BaseFileLock): """ def __init__(self, lock_file, timeout=-1, max_filename_length=255): + from .file_utils import relative_to_absolute_path + super().__init__(lock_file, timeout=timeout, max_filename_length=max_filename_length) - self._lock_file = "\\\\?\\" + os.path.abspath(os.path.expanduser(os.path.expandvars(self._lock_file))) + self._lock_file = "\\\\?\\" + relative_to_absolute_path(self.lock_file) def _acquire(self): open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC diff --git a/tests/test_load.py b/tests/test_load.py index 68c6c87f27c..cc690dcf800 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1,5 +1,6 @@ import importlib import os +import re import shutil import tempfile import time @@ -224,7 +225,10 @@ def test_load_dataset_local(dataset_loading_script_dir, data_dir, keep_in_memory assert "Using the latest cached version of the module" in caplog.text with pytest.raises(FileNotFoundError) as exc_info: datasets.load_dataset("_dummy") - assert "at " + os.path.join("_dummy", "_dummy.py") in str(exc_info.value) + m_combined_path = re.search(fr"\S*{re.escape(os.path.join('_dummy', '_dummy.py'))}\b", str(exc_info.value)) + assert m_combined_path is not None and os.path.isabs(m_combined_path.group()) + m_path = re.search(r"\S*_dummy\b", str(exc_info.value)) + assert m_path is not None and os.path.isabs(m_path.group()) @require_streaming