Merged
10 changes: 7 additions & 3 deletions src/datasets/arrow_dataset.py
@@ -53,7 +53,7 @@
update_fingerprint,
)
from .formatting import format_table, get_format_type_from_alias, get_formatter, query_table
from .info import DATASET_INFO_FILENAME, DatasetInfo
from .info import DatasetInfo
from .search import IndexableMixin
from .splits import NamedSplit
from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables, list_table_cache_files
@@ -616,7 +616,9 @@ def save_to_disk(self, dataset_path: str, fs=None):
Path(dataset_path, config.DATASET_STATE_JSON_FILENAME).as_posix(), "w", encoding="utf-8"
) as state_file:
json.dump(state, state_file, indent=2, sort_keys=True)
with fs.open(Path(dataset_path, DATASET_INFO_FILENAME).as_posix(), "w", encoding="utf-8") as dataset_info_file:
with fs.open(
Path(dataset_path, config.DATASET_INFO_FILENAME).as_posix(), "w", encoding="utf-8"
) as dataset_info_file:
json.dump(dataset_info, dataset_info_file, indent=2, sort_keys=True)
logger.info("Dataset saved in {}".format(dataset_path))

@@ -653,7 +655,9 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] =
Path(dataset_path, config.DATASET_STATE_JSON_FILENAME).as_posix(), "r", encoding="utf-8"
) as state_file:
state = json.load(state_file)
with open(Path(dataset_path, DATASET_INFO_FILENAME).as_posix(), "r", encoding="utf-8") as dataset_info_file:
with open(
Path(dataset_path, config.DATASET_INFO_FILENAME).as_posix(), "r", encoding="utf-8"
) as dataset_info_file:
dataset_info = DatasetInfo.from_dict(json.load(dataset_info_file))

dataset_size = estimate_dataset_size(
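For context on the hunks above: `save_to_disk` and `load_from_disk` now resolve the info filename through `datasets.config` instead of a constant re-exported by `datasets.info`. A minimal sketch of reading the saved info back, assuming a hypothetical dataset directory produced by `save_to_disk`:

```python
# Minimal sketch, assuming a hypothetical dataset directory; mirrors the call load_from_disk makes.
import json
import os

from datasets import config
from datasets.info import DatasetInfo

dataset_path = "/tmp/my_dataset"  # hypothetical path
info_path = os.path.join(dataset_path, config.DATASET_INFO_FILENAME)  # ".../dataset_info.json"

if os.path.exists(info_path):
    with open(info_path, "r", encoding="utf-8") as f:
        dataset_info = DatasetInfo.from_dict(json.load(f))
```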
27 changes: 7 additions & 20 deletions src/datasets/builder.py
@@ -36,14 +36,7 @@
from .arrow_writer import ArrowWriter, BeamWriter
from .dataset_dict import DatasetDict
from .fingerprint import Hasher
from .info import (
DATASET_INFO_FILENAME,
DATASET_INFOS_DICT_FILE_NAME,
LICENSE_FILENAME,
DatasetInfo,
DatasetInfosDict,
PostProcessedInfo,
)
from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo
from .naming import camelcase_to_snakecase, filename_prefix_for_split
from .splits import Split, SplitDict, SplitGenerator
from .utils.download_manager import DownloadManager, GenerateMode
@@ -55,12 +48,6 @@

logger = get_logger(__name__)

FORCE_REDOWNLOAD = GenerateMode.FORCE_REDOWNLOAD
REUSE_CACHE_IF_EXISTS = GenerateMode.REUSE_CACHE_IF_EXISTS
REUSE_DATASET_IF_EXISTS = GenerateMode.REUSE_DATASET_IF_EXISTS

MAX_DIRECTORY_NAME_LENGTH = 255


class InvalidConfigName(ValueError):
pass
@@ -175,7 +162,7 @@ def create_config_id(self, config_kwargs: dict, custom_features: Optional[Featur

if suffix:
config_id = self.name + "-" + suffix
if len(config_id) > MAX_DIRECTORY_NAME_LENGTH:
if len(config_id) > config.MAX_DATASET_CONFIG_ID_READABLE_LENGTH:
config_id = self.name + "-" + Hasher.hash(suffix)
return config_id
else:
@@ -297,7 +284,7 @@ def manual_download_instructions(self) -> Optional[str]:
@classmethod
def get_all_exported_dataset_infos(cls) -> dict:
"""Empty dict if doesn't exist"""
dset_infos_file_path = os.path.join(cls.get_imported_module_dir(), DATASET_INFOS_DICT_FILE_NAME)
dset_infos_file_path = os.path.join(cls.get_imported_module_dir(), config.DATASETDICT_INFOS_FILENAME)
if os.path.exists(dset_infos_file_path):
return DatasetInfosDict.from_directory(cls.get_imported_module_dir())
return {}
@@ -496,7 +483,7 @@ def download_and_prepare(
if download_config is None:
download_config = DownloadConfig(
cache_dir=os.path.join(self._cache_dir_root, "downloads"),
force_download=bool(download_mode == FORCE_REDOWNLOAD),
force_download=bool(download_mode == GenerateMode.FORCE_REDOWNLOAD),
use_etag=False,
use_auth_token=use_auth_token,
) # We don't use etag for data files to speed up the process
@@ -515,7 +502,7 @@
lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock")
with FileLock(lock_path):
data_exists = os.path.exists(self._cache_dir)
if data_exists and download_mode == REUSE_DATASET_IF_EXISTS:
if data_exists and download_mode == GenerateMode.REUSE_DATASET_IF_EXISTS:
logger.warning("Reusing dataset %s (%s)", self.name, self._cache_dir)
# We need to update the info in case some splits were added in the meantime
# for example when calling load_dataset from multiple workers.
@@ -1174,9 +1161,9 @@ def _save_info(self):
import apache_beam as beam

fs = beam.io.filesystems.FileSystems
with fs.create(os.path.join(self._cache_dir, DATASET_INFO_FILENAME)) as f:
with fs.create(os.path.join(self._cache_dir, config.DATASET_INFO_FILENAME)) as f:
self.info._dump_info(f)
with fs.create(os.path.join(self._cache_dir, LICENSE_FILENAME)) as f:
with fs.create(os.path.join(self._cache_dir, config.LICENSE_FILENAME)) as f:
self.info._dump_license(f)

def _prepare_split(self, split_generator, pipeline):
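The truncation rule in `create_config_id` now reads its limit from the shared config module rather than a module-level `MAX_DIRECTORY_NAME_LENGTH`. A rough sketch of that logic, with hypothetical `name` and `suffix` values standing in for the builder's actual inputs:

```python
# Sketch of the config-id truncation rule; `name` and `suffix` are hypothetical.
from datasets import config
from datasets.fingerprint import Hasher

name = "my_dataset"  # hypothetical builder config name
suffix = "a" * 300   # hypothetical kwargs-derived suffix

config_id = name + "-" + suffix
if len(config_id) > config.MAX_DATASET_CONFIG_ID_READABLE_LENGTH:  # 255
    # Fall back to a hash so the cache directory name stays filesystem-safe
    config_id = name + "-" + Hasher.hash(suffix)
```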
12 changes: 7 additions & 5 deletions src/datasets/commands/run_beam.py
@@ -5,10 +5,10 @@
from typing import List

from datasets import config
from datasets.builder import FORCE_REDOWNLOAD, REUSE_CACHE_IF_EXISTS, DatasetBuilder, DownloadConfig
from datasets.builder import DatasetBuilder
from datasets.commands import BaseTransformersCLICommand
from datasets.info import DATASET_INFOS_DICT_FILE_NAME
from datasets.load import import_main_class, prepare_module
from datasets.utils.download_manager import DownloadConfig, GenerateMode


def run_beam_command_factory(args):
@@ -113,7 +113,9 @@ def run(self):

for builder in builders:
builder.download_and_prepare(
download_mode=REUSE_CACHE_IF_EXISTS if not self._force_redownload else FORCE_REDOWNLOAD,
download_mode=GenerateMode.REUSE_CACHE_IF_EXISTS
if not self._force_redownload
else GenerateMode.FORCE_REDOWNLOAD,
download_config=DownloadConfig(cache_dir=os.path.join(config.HF_DATASETS_CACHE, "downloads")),
save_infos=self._save_infos,
ignore_verifications=self._ignore_verifications,
@@ -126,7 +128,7 @@
# Let's move it to the original directory of the dataset script, to allow the user to
# upload them on S3 at the same time afterwards.
if self._save_infos:
dataset_infos_path = os.path.join(builder_cls.get_imported_module_dir(), DATASET_INFOS_DICT_FILE_NAME)
dataset_infos_path = os.path.join(builder_cls.get_imported_module_dir(), config.DATASETDICT_INFOS_FILENAME)

name = Path(path).name + ".py"

@@ -140,6 +142,6 @@
exit(1)

# Move datasetinfo back to the user
user_dataset_infos_path = os.path.join(dataset_dir, DATASET_INFOS_DICT_FILE_NAME)
user_dataset_infos_path = os.path.join(dataset_dir, config.DATASETDICT_INFOS_FILENAME)
copyfile(dataset_infos_path, user_dataset_infos_path)
print("Dataset Infos file saved at {}".format(user_dataset_infos_path))
15 changes: 10 additions & 5 deletions src/datasets/commands/test.py
@@ -4,10 +4,11 @@
from shutil import copyfile, rmtree
from typing import Generator

from datasets.builder import FORCE_REDOWNLOAD, REUSE_CACHE_IF_EXISTS, DatasetBuilder
import datasets.config
from datasets.builder import DatasetBuilder
from datasets.commands import BaseTransformersCLICommand
from datasets.info import DATASET_INFOS_DICT_FILE_NAME
from datasets.load import import_main_class, prepare_module
from datasets.utils.download_manager import GenerateMode
from datasets.utils.filelock import logger as fl_logger
from datasets.utils.logging import ERROR, get_logger

@@ -136,7 +137,9 @@ def get_builders() -> Generator[DatasetBuilder, None, None]:
for j, builder in enumerate(get_builders()):
print(f"Testing builder '{builder.config.name}' ({j + 1}/{n_builders})")
builder.download_and_prepare(
download_mode=REUSE_CACHE_IF_EXISTS if not self._force_redownload else FORCE_REDOWNLOAD,
download_mode=GenerateMode.REUSE_CACHE_IF_EXISTS
if not self._force_redownload
else GenerateMode.FORCE_REDOWNLOAD,
ignore_verifications=self._ignore_verifications,
try_from_hf_gcs=False,
)
@@ -148,7 +151,9 @@ def get_builders() -> Generator[DatasetBuilder, None, None]:
# Let's move it to the original directory of the dataset script, to allow the user to
# upload them on S3 at the same time afterwards.
if self._save_infos:
dataset_infos_path = os.path.join(builder_cls.get_imported_module_dir(), DATASET_INFOS_DICT_FILE_NAME)
dataset_infos_path = os.path.join(
builder_cls.get_imported_module_dir(), datasets.config.DATASETDICT_INFOS_FILENAME
)
name = Path(path).name + ".py"
combined_path = os.path.join(path, name)
if os.path.isfile(path):
@@ -161,7 +166,7 @@

# Move dataset_info back to the user
if dataset_dir is not None:
user_dataset_infos_path = os.path.join(dataset_dir, DATASET_INFOS_DICT_FILE_NAME)
user_dataset_infos_path = os.path.join(dataset_dir, datasets.config.DATASETDICT_INFOS_FILENAME)
copyfile(dataset_infos_path, user_dataset_infos_path)
print("Dataset Infos file saved at {}".format(user_dataset_infos_path))

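As in `run_beam.py`, the infos filename is now taken from the config module, here via a plain `import datasets.config`. A short sketch, with a hypothetical module directory:

```python
# Sketch only; the directory path is hypothetical.
import os

import datasets.config

imported_module_dir = "/path/to/datasets/my_dataset"  # hypothetical
dataset_infos_path = os.path.join(
    imported_module_dir, datasets.config.DATASETDICT_INFOS_FILENAME
)  # ".../dataset_infos.json"
```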
35 changes: 18 additions & 17 deletions src/datasets/config.py
@@ -61,25 +61,20 @@
TF_AVAILABLE = importlib.util.find_spec("tensorflow") is not None
if TF_AVAILABLE:
# For the metadata, we have to look for both tensorflow and tensorflow-cpu
try:
TF_VERSION = importlib_metadata.version("tensorflow")
except importlib_metadata.PackageNotFoundError:
for package in [
"tensorflow",
"tensorflow-cpu",
"tensorflow-gpu",
"tf-nightly",
"tf-nightly-cpu",
"tf-nightly-gpu",
]:
try:
TF_VERSION = importlib_metadata.version("tensorflow-cpu")
TF_VERSION = importlib_metadata.version(package)
except importlib_metadata.PackageNotFoundError:
try:
TF_VERSION = importlib_metadata.version("tensorflow-gpu")
except importlib_metadata.PackageNotFoundError:
try:
TF_VERSION = importlib_metadata.version("tf-nightly")
except importlib_metadata.PackageNotFoundError:
try:
TF_VERSION = importlib_metadata.version("tf-nightly-cpu")
except importlib_metadata.PackageNotFoundError:
try:
TF_VERSION = importlib_metadata.version("tf-nightly-gpu")
except importlib_metadata.PackageNotFoundError:
pass
continue
else:
break
if TF_AVAILABLE:
if version.parse(TF_VERSION) < version.parse("2"):
logger.info(f"TensorFlow found but with version {TF_VERSION}. `datasets` requires version 2 minimum.")
@@ -155,3 +150,9 @@
DATASET_ARROW_FILENAME = "dataset.arrow"
DATASET_INDICES_FILENAME = "indices.arrow"
DATASET_STATE_JSON_FILENAME = "state.json"
DATASET_INFO_FILENAME = "dataset_info.json"
DATASETDICT_INFOS_FILENAME = "dataset_infos.json"
LICENSE_FILENAME = "LICENSE"
METRIC_INFO_FILENAME = "metric_info.json"

MAX_DATASET_CONFIG_ID_READABLE_LENGTH = 255
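The nested try/except chain for TensorFlow detection is replaced by a single loop over candidate distribution names, using try/except/else inside the loop. A standalone sketch of that pattern, using the stdlib `importlib.metadata` (the real module may use the `importlib_metadata` backport):

```python
# Sketch of the probe loop: keep the first distribution whose metadata resolves.
import importlib.metadata as importlib_metadata

TF_VERSION = "N/A"
for package in ["tensorflow", "tensorflow-cpu", "tensorflow-gpu", "tf-nightly"]:
    try:
        TF_VERSION = importlib_metadata.version(package)
    except importlib_metadata.PackageNotFoundError:
        continue  # not installed under this name, try the next one
    else:
        break  # found a matching distribution; stop probing

print(TF_VERSION)
```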
25 changes: 10 additions & 15 deletions src/datasets/info.py
@@ -36,6 +36,7 @@
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union

from . import config
from .features import Features, Value
from .splits import SplitDict
from .utils import Version
@@ -44,12 +45,6 @@

logger = get_logger(__name__)

# Name of the file to output the DatasetInfo protobuf object.
DATASET_INFO_FILENAME = "dataset_info.json"
DATASET_INFOS_DICT_FILE_NAME = "dataset_infos.json"
LICENSE_FILENAME = "LICENSE"
METRIC_INFO_FILENAME = "metric_info.json"


@dataclass
class SupervisedKeysData:
@@ -156,17 +151,17 @@ def __post_init__(self):
self.supervised_keys = SupervisedKeysData(**self.supervised_keys)

def _license_path(self, dataset_info_dir):
return os.path.join(dataset_info_dir, LICENSE_FILENAME)
return os.path.join(dataset_info_dir, config.LICENSE_FILENAME)

def write_to_directory(self, dataset_info_dir):
"""Write `DatasetInfo` as JSON to `dataset_info_dir`.

Also save the license separately in LICENCE.
"""
with open(os.path.join(dataset_info_dir, DATASET_INFO_FILENAME), "wb") as f:
with open(os.path.join(dataset_info_dir, config.DATASET_INFO_FILENAME), "wb") as f:
self._dump_info(f)

with open(os.path.join(dataset_info_dir, LICENSE_FILENAME), "wb") as f:
with open(os.path.join(dataset_info_dir, config.LICENSE_FILENAME), "wb") as f:
self._dump_license(f)

def _dump_info(self, file):
@@ -220,7 +215,7 @@ def from_directory(cls, dataset_info_dir: str) -> "DatasetInfo":
if not dataset_info_dir:
raise ValueError("Calling DatasetInfo.from_directory() with undefined dataset_info_dir.")

with open(os.path.join(dataset_info_dir, DATASET_INFO_FILENAME), "r", encoding="utf-8") as f:
with open(os.path.join(dataset_info_dir, config.DATASET_INFO_FILENAME), "r", encoding="utf-8") as f:
dataset_info_dict = json.load(f)
return cls.from_dict(dataset_info_dict)

@@ -246,7 +241,7 @@ def copy(self) -> "DatasetInfo":
class DatasetInfosDict(dict):
def write_to_directory(self, dataset_infos_dir, overwrite=False):
total_dataset_infos = {}
dataset_infos_path = os.path.join(dataset_infos_dir, DATASET_INFOS_DICT_FILE_NAME)
dataset_infos_path = os.path.join(dataset_infos_dir, config.DATASETDICT_INFOS_FILENAME)
if os.path.exists(dataset_infos_path) and not overwrite:
logger.info("Dataset Infos already exists in {}. Completing it with new infos.".format(dataset_infos_dir))
total_dataset_infos = self.from_directory(dataset_infos_dir)
@@ -259,7 +254,7 @@ def write_to_directory(self, dataset_infos_dir, overwrite=False):
@classmethod
def from_directory(cls, dataset_infos_dir):
logger.info("Loading Dataset Infos from {}".format(dataset_infos_dir))
with open(os.path.join(dataset_infos_dir, DATASET_INFOS_DICT_FILE_NAME), "r", encoding="utf-8") as f:
with open(os.path.join(dataset_infos_dir, config.DATASETDICT_INFOS_FILENAME), "r", encoding="utf-8") as f:
dataset_infos_dict = {
config_name: DatasetInfo.from_dict(dataset_info_dict)
for config_name, dataset_info_dict in json.load(f).items()
@@ -308,10 +303,10 @@ def write_to_directory(self, metric_info_dir):
"""Write `MetricInfo` as JSON to `metric_info_dir`.
Also save the license separately in LICENCE.
"""
with open(os.path.join(metric_info_dir, METRIC_INFO_FILENAME), "w", encoding="utf-8") as f:
with open(os.path.join(metric_info_dir, config.METRIC_INFO_FILENAME), "w", encoding="utf-8") as f:
json.dump(asdict(self), f)

with open(os.path.join(metric_info_dir, LICENSE_FILENAME), "w", encoding="utf-8") as f:
with open(os.path.join(metric_info_dir, config.LICENSE_FILENAME), "w", encoding="utf-8") as f:
f.write(self.license)

@classmethod
@@ -326,7 +321,7 @@ def from_directory(cls, metric_info_dir) -> "MetricInfo":
if not metric_info_dir:
raise ValueError("Calling MetricInfo.from_directory() with undefined metric_info_dir.")

with open(os.path.join(metric_info_dir, METRIC_INFO_FILENAME), "r", encoding="utf-8") as f:
with open(os.path.join(metric_info_dir, config.METRIC_INFO_FILENAME), "r", encoding="utf-8") as f:
metric_info_dict = json.load(f)
return cls.from_dict(metric_info_dict)

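The filename constants removed here are the ones added to `config.py` above, so downstream code should import them from `datasets.config` rather than `datasets.info`. A quick sketch of the new import path:

```python
# Sketch of the migration: these constants now live in datasets.config.
from datasets import config

print(config.DATASET_INFO_FILENAME)       # "dataset_info.json"
print(config.DATASETDICT_INFOS_FILENAME)  # "dataset_infos.json"
print(config.LICENSE_FILENAME)            # "LICENSE"
print(config.METRIC_INFO_FILENAME)        # "metric_info.json"
```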
5 changes: 2 additions & 3 deletions src/datasets/load.py
@@ -35,7 +35,6 @@
from .dataset_dict import DatasetDict
from .features import Features
from .filesystems import extract_path_from_uri, is_remote_filesystem
from .info import DATASET_INFOS_DICT_FILE_NAME
from .metric import Metric
from .packaged_modules import _PACKAGED_DATASETS_MODULES, hash_python_lines
from .splits import Split
@@ -395,7 +394,7 @@ def _get_modification_time(module_hash):
# 2. copy from the local file system inside the modules cache to import it

base_path = url_or_path_parent(file_path) # remove the filename
dataset_infos = url_or_path_join(base_path, DATASET_INFOS_DICT_FILE_NAME)
dataset_infos = url_or_path_join(base_path, config.DATASETDICT_INFOS_FILENAME)

# Download the dataset infos file if available
try:
Expand Down Expand Up @@ -462,7 +461,7 @@ def _get_modification_time(module_hash):
hash_folder_path = force_local_path

local_file_path = os.path.join(hash_folder_path, name)
dataset_infos_path = os.path.join(hash_folder_path, DATASET_INFOS_DICT_FILE_NAME)
dataset_infos_path = os.path.join(hash_folder_path, config.DATASETDICT_INFOS_FILENAME)

# Prevent parallel disk operations
lock_path = local_path + ".lock"
3 changes: 2 additions & 1 deletion tests/test_arrow_dataset.py
@@ -13,11 +13,12 @@
from absl.testing import parameterized

import datasets.arrow_dataset
from datasets import NamedSplit, concatenate_datasets, load_from_disk, temp_seed
from datasets import concatenate_datasets, load_from_disk, temp_seed
from datasets.arrow_dataset import Dataset, transmit_format, update_metadata_with_features
from datasets.dataset_dict import DatasetDict
from datasets.features import Array2D, Array3D, ClassLabel, Features, Sequence, Value
from datasets.info import DatasetInfo
from datasets.splits import NamedSplit
from datasets.table import ConcatenationTable, InMemoryTable, MemoryMappedTable
from datasets.utils.logging import WARNING
