Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
hf_github_url,
hf_hub_url,
init_hf_modules,
relative_to_absolute_path,
url_or_path_join,
url_or_path_parent,
)
Expand Down Expand Up @@ -246,7 +247,7 @@ def prepare_module(
script_version (Optional ``Union[str, datasets.Version]``):
If specified, the module will be loaded from the datasets repository at this version.
By default:
- it is set to the local version fo the lib.
- it is set to the local version of the lib.
- it will also try to load it from the master branch if it's not available at the local version of the lib.
Specifying a version that is different from your local version of the lib might cause compatibility issues.
download_config (Optional ``datasets.DownloadConfig``: specific download configuration parameters.
Expand Down Expand Up @@ -307,15 +308,18 @@ def prepare_module(
# - if os.path.join(path, name) is a file or a remote url
# - if path is a file or a remote url
# - otherwise we assume path/name is a path to our S3 bucket
combined_path = os.path.join(path, name)
combined_path = path if path.endswith(name) else os.path.join(path, name)

if os.path.isfile(combined_path):
file_path = combined_path
local_path = file_path
local_path = combined_path
elif os.path.isfile(path):
file_path = path
local_path = path
else:
# Try github (canonical datasets/metrics) and then S3 (users datasets/metrics)

combined_path_abs = relative_to_absolute_path(combined_path)
try:
head_hf_s3(path, filename=name, dataset=dataset, max_retries=download_config.max_retries)
script_version = str(script_version) if script_version is not None else None
Expand All @@ -326,7 +330,7 @@ def prepare_module(
except FileNotFoundError:
if script_version is not None:
raise FileNotFoundError(
"Couldn't find remote file with version {} at {}. Please provide a valid version and a valid {} name".format(
"Couldn't find remote file with version {} at {}. Please provide a valid version and a valid {} name.".format(
Copy link
Contributor

@stas00 stas00 Jul 20, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lhoestq, do you guys plan to keep the same style as transformers? If so, the latter fully switched to f"" strings from format.

This could be a good https://github.com/huggingface/datasets/contribute Issue if you choose to do so.

If not, please ignore my comment.

Copy link
Member

@lhoestq lhoestq Jul 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we prefer f-strings than using format, and when it's possible we try to follow the same style as transformers

The changes can be done in another PR :)

script_version, file_path, "dataset" if dataset else "metric"
)
)
Expand All @@ -338,14 +342,14 @@ def prepare_module(
logger.warning(
"Couldn't find file locally at {}, or remotely at {}.\n"
"The file was picked from the master branch on github instead at {}.".format(
combined_path, github_file_path, file_path
combined_path_abs, github_file_path, file_path
)
)
except FileNotFoundError:
raise FileNotFoundError(
"Couldn't find file locally at {}, or remotely at {}.\n"
"The file is also not present on the master branch on github.".format(
combined_path, github_file_path
combined_path_abs, github_file_path
)
)
elif path.count("/") == 1: # users datasets/metrics: s3 path (hub for datasets and s3 for metrics)
Expand All @@ -357,14 +361,14 @@ def prepare_module(
local_path = cached_path(file_path, download_config=download_config)
except FileNotFoundError:
raise FileNotFoundError(
"Couldn't find file locally at {}, or remotely at {}. Please provide a valid {} name".format(
combined_path, file_path, "dataset" if dataset else "metric"
"Couldn't find file locally at {}, or remotely at {}. Please provide a valid {} name.".format(
combined_path_abs, file_path, "dataset" if dataset else "metric"
)
)
else:
raise FileNotFoundError(
"Couldn't find file locally at {}. Please provide a valid {} name".format(
combined_path, "dataset" if dataset else "metric"
"Couldn't find file locally at {}. Please provide a valid {} name.".format(
combined_path_abs, "dataset" if dataset else "metric"
)
)
except Exception as e: # noqa: all the attempts failed, before raising the error we should check if the module already exists.
Expand All @@ -382,7 +386,7 @@ def _get_modification_time(module_hash):
logger.warning(
f"Using the latest cached version of the module from {os.path.join(main_folder_path, hash)} "
f"(last modified on {time.ctime(_get_modification_time(hash))}) since it "
f"couldn't be found locally at {combined_path} or remotely ({type(e).__name__})."
f"couldn't be found locally at {combined_path_abs}, or remotely ({type(e).__name__})."
)
if return_resolved_file_path:
with open(os.path.join(main_folder_path, hash, short_name + ".json")) as cache_metadata:
Expand Down Expand Up @@ -421,7 +425,7 @@ def _get_modification_time(module_hash):
f"Error in {module_type} script at {file_path}, importing relative {import_name} module "
f"but {import_name} is the name of the {module_type} script. "
f"Please change relative import {import_name} to another name and add a '# From: URL_OR_PATH' "
f"comment pointing to the original realtive import file path."
f"comment pointing to the original relative import file path."
)
if import_type == "internal":
url_or_filename = url_or_path_join(base_path, import_path + ".py")
Expand Down
2 changes: 1 addition & 1 deletion src/datasets/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from . import logging
from .download_manager import DownloadManager, GenerateMode
from .file_utils import DownloadConfig, cached_path, hf_bucket_url, is_remote_url, temp_seed
from .file_utils import DownloadConfig, cached_path, hf_bucket_url, is_remote_url, relative_to_absolute_path, temp_seed
from .mock_download_manager import MockDownloadManager
from .py_utils import (
NonMutableDict,
Expand Down
10 changes: 9 additions & 1 deletion src/datasets/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from functools import partial
from hashlib import sha256
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, TypeVar, Union
from urllib.parse import urlparse

import numpy as np
Expand All @@ -37,6 +37,8 @@

INCOMPLETE_SUFFIX = ".incomplete"

T = TypeVar("T", str, Path)


def init_hf_modules(hf_modules_cache: Optional[Union[Path, str]] = None) -> str:
"""
Expand Down Expand Up @@ -127,6 +129,12 @@ def is_relative_path(url_or_filename: str) -> bool:
return urlparse(url_or_filename).scheme == "" and not os.path.isabs(url_or_filename)


def relative_to_absolute_path(path: T) -> T:
    """Resolve *path* to an absolute path, expanding ``~`` and environment variables.

    The return type mirrors the argument type: a :class:`pathlib.Path` input
    yields a ``Path``, while a ``str`` input yields a ``str``.
    """
    # Expand env vars first, then the user home marker, then make absolute.
    expanded = os.path.expandvars(str(path))
    expanded = os.path.expanduser(expanded)
    resolved = os.path.abspath(expanded)
    if isinstance(path, Path):
        return Path(resolved)
    return resolved


def hf_bucket_url(identifier: str, filename: str, use_cdn=False, dataset=True) -> str:
if dataset:
endpoint = config.CLOUDFRONT_DATASETS_DISTRIB_PREFIX if use_cdn else config.S3_DATASETS_BUCKET_PREFIX
Expand Down
4 changes: 3 additions & 1 deletion src/datasets/utils/filelock.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,10 @@ class WindowsFileLock(BaseFileLock):
"""

def __init__(self, lock_file, timeout=-1, max_filename_length=255):
from .file_utils import relative_to_absolute_path

super().__init__(lock_file, timeout=timeout, max_filename_length=max_filename_length)
self._lock_file = "\\\\?\\" + os.path.abspath(os.path.expanduser(os.path.expandvars(self._lock_file)))
self._lock_file = "\\\\?\\" + relative_to_absolute_path(self.lock_file)

def _acquire(self):
open_mode = os.O_RDWR | os.O_CREAT | os.O_TRUNC
Expand Down
6 changes: 5 additions & 1 deletion tests/test_load.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import importlib
import os
import re
import shutil
import tempfile
import time
Expand Down Expand Up @@ -224,7 +225,10 @@ def test_load_dataset_local(dataset_loading_script_dir, data_dir, keep_in_memory
assert "Using the latest cached version of the module" in caplog.text
with pytest.raises(FileNotFoundError) as exc_info:
datasets.load_dataset("_dummy")
assert "at " + os.path.join("_dummy", "_dummy.py") in str(exc_info.value)
m_combined_path = re.search(fr"\S*{re.escape(os.path.join('_dummy', '_dummy.py'))}\b", str(exc_info.value))
assert m_combined_path is not None and os.path.isabs(m_combined_path.group())
m_path = re.search(r"\S*_dummy\b", str(exc_info.value))
assert m_path is not None and os.path.isabs(m_path.group())
Comment on lines +230 to +231
Copy link
Collaborator Author

@mariosasko mariosasko Jul 22, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@lhoestq Actually, this check doesn't do anything (m_path returns a substring of m_combined_path without .py file extension). We can replace this check with a check which verifies that the error message returns a remote URL.

m_paths = re.findall(r"\S*_dummy/_dummy.py\b", str(exc_info.value))  # on Linux this will match a URL as well as a local_path due to different os.sep, so take the last element (a URL always comes last in the list)
assert len(m_paths) > 0 and is_remote_url(m_paths[-1])  # is_remote_url comes from datasets.utils.file_utils



@require_streaming
Expand Down