diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 32c568ec11f..cef58f560ed 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -53,7 +53,7 @@
     update_fingerprint,
 )
 from .formatting import format_table, get_format_type_from_alias, get_formatter, query_table
-from .info import DATASET_INFO_FILENAME, DatasetInfo
+from .info import DatasetInfo
 from .search import IndexableMixin
 from .splits import NamedSplit
 from .table import InMemoryTable, MemoryMappedTable, Table, concat_tables, list_table_cache_files
@@ -616,7 +616,9 @@ def save_to_disk(self, dataset_path: str, fs=None):
             Path(dataset_path, config.DATASET_STATE_JSON_FILENAME).as_posix(), "w", encoding="utf-8"
         ) as state_file:
             json.dump(state, state_file, indent=2, sort_keys=True)
-        with fs.open(Path(dataset_path, DATASET_INFO_FILENAME).as_posix(), "w", encoding="utf-8") as dataset_info_file:
+        with fs.open(
+            Path(dataset_path, config.DATASET_INFO_FILENAME).as_posix(), "w", encoding="utf-8"
+        ) as dataset_info_file:
             json.dump(dataset_info, dataset_info_file, indent=2, sort_keys=True)
         logger.info("Dataset saved in {}".format(dataset_path))
 
@@ -653,7 +655,9 @@ def load_from_disk(dataset_path: str, fs=None, keep_in_memory: Optional[bool] =
             Path(dataset_path, config.DATASET_STATE_JSON_FILENAME).as_posix(), "r", encoding="utf-8"
         ) as state_file:
             state = json.load(state_file)
-        with open(Path(dataset_path, DATASET_INFO_FILENAME).as_posix(), "r", encoding="utf-8") as dataset_info_file:
+        with open(
+            Path(dataset_path, config.DATASET_INFO_FILENAME).as_posix(), "r", encoding="utf-8"
+        ) as dataset_info_file:
             dataset_info = DatasetInfo.from_dict(json.load(dataset_info_file))
 
         dataset_size = estimate_dataset_size(
diff --git a/src/datasets/builder.py b/src/datasets/builder.py
index 37f3a5dcba5..e509dbe8a2a 100644
--- a/src/datasets/builder.py
+++ b/src/datasets/builder.py
@@ -36,14 +36,7 @@
 from .arrow_writer import ArrowWriter, BeamWriter
 from .dataset_dict import DatasetDict
 from .fingerprint import Hasher
-from .info import (
-    DATASET_INFO_FILENAME,
-    DATASET_INFOS_DICT_FILE_NAME,
-    LICENSE_FILENAME,
-    DatasetInfo,
-    DatasetInfosDict,
-    PostProcessedInfo,
-)
+from .info import DatasetInfo, DatasetInfosDict, PostProcessedInfo
 from .naming import camelcase_to_snakecase, filename_prefix_for_split
 from .splits import Split, SplitDict, SplitGenerator
 from .utils.download_manager import DownloadManager, GenerateMode
@@ -55,12 +48,6 @@
 logger = get_logger(__name__)
 
-FORCE_REDOWNLOAD = GenerateMode.FORCE_REDOWNLOAD
-REUSE_CACHE_IF_EXISTS = GenerateMode.REUSE_CACHE_IF_EXISTS
-REUSE_DATASET_IF_EXISTS = GenerateMode.REUSE_DATASET_IF_EXISTS
-
-MAX_DIRECTORY_NAME_LENGTH = 255
-
 
 class InvalidConfigName(ValueError):
     pass
 
@@ -175,7 +162,7 @@ def create_config_id(self, config_kwargs: dict, custom_features: Optional[Featur
 
         if suffix:
             config_id = self.name + "-" + suffix
-            if len(config_id) > MAX_DIRECTORY_NAME_LENGTH:
+            if len(config_id) > config.MAX_DATASET_CONFIG_ID_READABLE_LENGTH:
                 config_id = self.name + "-" + Hasher.hash(suffix)
             return config_id
         else:
@@ -297,7 +284,7 @@ def manual_download_instructions(self) -> Optional[str]:
     @classmethod
     def get_all_exported_dataset_infos(cls) -> dict:
         """Empty dict if doesn't exist"""
-        dset_infos_file_path = os.path.join(cls.get_imported_module_dir(), DATASET_INFOS_DICT_FILE_NAME)
+        dset_infos_file_path = os.path.join(cls.get_imported_module_dir(), config.DATASETDICT_INFOS_FILENAME)
         if os.path.exists(dset_infos_file_path):
             return DatasetInfosDict.from_directory(cls.get_imported_module_dir())
         return {}
@@ -496,7 +483,7 @@ def download_and_prepare(
         if download_config is None:
             download_config = DownloadConfig(
                 cache_dir=os.path.join(self._cache_dir_root, "downloads"),
-                force_download=bool(download_mode == FORCE_REDOWNLOAD),
+                force_download=bool(download_mode == GenerateMode.FORCE_REDOWNLOAD),
                 use_etag=False,
                 use_auth_token=use_auth_token,
             )  # We don't use etag for data files to speed up the process
@@ -515,7 +502,7 @@ def download_and_prepare(
         lock_path = os.path.join(self._cache_dir_root, self._cache_dir.replace(os.sep, "_") + ".lock")
         with FileLock(lock_path):
             data_exists = os.path.exists(self._cache_dir)
-            if data_exists and download_mode == REUSE_DATASET_IF_EXISTS:
+            if data_exists and download_mode == GenerateMode.REUSE_DATASET_IF_EXISTS:
                 logger.warning("Reusing dataset %s (%s)", self.name, self._cache_dir)
                 # We need to update the info in case some splits were added in the meantime
                 # for example when calling load_dataset from multiple workers.
@@ -1174,9 +1161,9 @@ def _save_info(self):
         import apache_beam as beam
 
         fs = beam.io.filesystems.FileSystems
-        with fs.create(os.path.join(self._cache_dir, DATASET_INFO_FILENAME)) as f:
+        with fs.create(os.path.join(self._cache_dir, config.DATASET_INFO_FILENAME)) as f:
             self.info._dump_info(f)
-        with fs.create(os.path.join(self._cache_dir, LICENSE_FILENAME)) as f:
+        with fs.create(os.path.join(self._cache_dir, config.LICENSE_FILENAME)) as f:
             self.info._dump_license(f)
 
     def _prepare_split(self, split_generator, pipeline):
diff --git a/src/datasets/commands/run_beam.py b/src/datasets/commands/run_beam.py
index 5314d6f5f0c..7b671722efa 100644
--- a/src/datasets/commands/run_beam.py
+++ b/src/datasets/commands/run_beam.py
@@ -5,10 +5,10 @@
 from typing import List
 
 from datasets import config
-from datasets.builder import FORCE_REDOWNLOAD, REUSE_CACHE_IF_EXISTS, DatasetBuilder, DownloadConfig
+from datasets.builder import DatasetBuilder
 from datasets.commands import BaseTransformersCLICommand
-from datasets.info import DATASET_INFOS_DICT_FILE_NAME
 from datasets.load import import_main_class, prepare_module
+from datasets.utils.download_manager import DownloadConfig, GenerateMode
 
 
 def run_beam_command_factory(args):
@@ -113,7 +113,9 @@ def run(self):
 
         for builder in builders:
             builder.download_and_prepare(
-                download_mode=REUSE_CACHE_IF_EXISTS if not self._force_redownload else FORCE_REDOWNLOAD,
+                download_mode=GenerateMode.REUSE_CACHE_IF_EXISTS
+                if not self._force_redownload
+                else GenerateMode.FORCE_REDOWNLOAD,
                 download_config=DownloadConfig(cache_dir=os.path.join(config.HF_DATASETS_CACHE, "downloads")),
                 save_infos=self._save_infos,
                 ignore_verifications=self._ignore_verifications,
@@ -126,7 +128,7 @@ def run(self):
         # Let's move it to the original directory of the dataset script, to allow the user to
         # upload them on S3 at the same time afterwards.
         if self._save_infos:
-            dataset_infos_path = os.path.join(builder_cls.get_imported_module_dir(), DATASET_INFOS_DICT_FILE_NAME)
+            dataset_infos_path = os.path.join(builder_cls.get_imported_module_dir(), config.DATASETDICT_INFOS_FILENAME)
 
             name = Path(path).name + ".py"
 
@@ -140,6 +142,6 @@ def run(self):
                 exit(1)
 
             # Move datasetinfo back to the user
-            user_dataset_infos_path = os.path.join(dataset_dir, DATASET_INFOS_DICT_FILE_NAME)
+            user_dataset_infos_path = os.path.join(dataset_dir, config.DATASETDICT_INFOS_FILENAME)
             copyfile(dataset_infos_path, user_dataset_infos_path)
             print("Dataset Infos file saved at {}".format(user_dataset_infos_path))
diff --git a/src/datasets/commands/test.py b/src/datasets/commands/test.py
index 217744e8687..84aac9a6d6c 100644
--- a/src/datasets/commands/test.py
+++ b/src/datasets/commands/test.py
@@ -4,10 +4,11 @@
 from shutil import copyfile, rmtree
 from typing import Generator
 
-from datasets.builder import FORCE_REDOWNLOAD, REUSE_CACHE_IF_EXISTS, DatasetBuilder
+import datasets.config
+from datasets.builder import DatasetBuilder
 from datasets.commands import BaseTransformersCLICommand
-from datasets.info import DATASET_INFOS_DICT_FILE_NAME
 from datasets.load import import_main_class, prepare_module
+from datasets.utils.download_manager import GenerateMode
 from datasets.utils.filelock import logger as fl_logger
 from datasets.utils.logging import ERROR, get_logger
 
@@ -136,7 +137,9 @@ def get_builders() -> Generator[DatasetBuilder, None, None]:
         for j, builder in enumerate(get_builders()):
             print(f"Testing builder '{builder.config.name}' ({j + 1}/{n_builders})")
             builder.download_and_prepare(
-                download_mode=REUSE_CACHE_IF_EXISTS if not self._force_redownload else FORCE_REDOWNLOAD,
+                download_mode=GenerateMode.REUSE_CACHE_IF_EXISTS
+                if not self._force_redownload
+                else GenerateMode.FORCE_REDOWNLOAD,
                 ignore_verifications=self._ignore_verifications,
                 try_from_hf_gcs=False,
             )
@@ -148,7 +151,9 @@ def get_builders() -> Generator[DatasetBuilder, None, None]:
             # Let's move it to the original directory of the dataset script, to allow the user to
             # upload them on S3 at the same time afterwards.
             if self._save_infos:
-                dataset_infos_path = os.path.join(builder_cls.get_imported_module_dir(), DATASET_INFOS_DICT_FILE_NAME)
+                dataset_infos_path = os.path.join(
+                    builder_cls.get_imported_module_dir(), datasets.config.DATASETDICT_INFOS_FILENAME
+                )
                 name = Path(path).name + ".py"
                 combined_path = os.path.join(path, name)
                 if os.path.isfile(path):
@@ -161,7 +166,7 @@ def get_builders() -> Generator[DatasetBuilder, None, None]:
 
                 # Move dataset_info back to the user
                 if dataset_dir is not None:
-                    user_dataset_infos_path = os.path.join(dataset_dir, DATASET_INFOS_DICT_FILE_NAME)
+                    user_dataset_infos_path = os.path.join(dataset_dir, datasets.config.DATASETDICT_INFOS_FILENAME)
                     copyfile(dataset_infos_path, user_dataset_infos_path)
                     print("Dataset Infos file saved at {}".format(user_dataset_infos_path))
diff --git a/src/datasets/config.py b/src/datasets/config.py
index 4239248fa77..5ee05da6118 100644
--- a/src/datasets/config.py
+++ b/src/datasets/config.py
@@ -61,25 +61,20 @@
 TF_AVAILABLE = importlib.util.find_spec("tensorflow") is not None
 if TF_AVAILABLE:
     # For the metadata, we have to look for both tensorflow and tensorflow-cpu
-    try:
-        TF_VERSION = importlib_metadata.version("tensorflow")
-    except importlib_metadata.PackageNotFoundError:
+    for package in [
+        "tensorflow",
+        "tensorflow-cpu",
+        "tensorflow-gpu",
+        "tf-nightly",
+        "tf-nightly-cpu",
+        "tf-nightly-gpu",
+    ]:
         try:
-            TF_VERSION = importlib_metadata.version("tensorflow-cpu")
+            TF_VERSION = importlib_metadata.version(package)
         except importlib_metadata.PackageNotFoundError:
-            try:
-                TF_VERSION = importlib_metadata.version("tensorflow-gpu")
-            except importlib_metadata.PackageNotFoundError:
-                try:
-                    TF_VERSION = importlib_metadata.version("tf-nightly")
-                except importlib_metadata.PackageNotFoundError:
-                    try:
-                        TF_VERSION = importlib_metadata.version("tf-nightly-cpu")
-                    except importlib_metadata.PackageNotFoundError:
-                        try:
-                            TF_VERSION = importlib_metadata.version("tf-nightly-gpu")
-                        except importlib_metadata.PackageNotFoundError:
-                            pass
+            continue
+        else:
+            break
 if TF_AVAILABLE:
     if version.parse(TF_VERSION) < version.parse("2"):
         logger.info(f"TensorFlow found but with version {TF_VERSION}. `datasets` requires version 2 minimum.")
@@ -155,3 +150,9 @@
 DATASET_ARROW_FILENAME = "dataset.arrow"
 DATASET_INDICES_FILENAME = "indices.arrow"
 DATASET_STATE_JSON_FILENAME = "state.json"
+DATASET_INFO_FILENAME = "dataset_info.json"
+DATASETDICT_INFOS_FILENAME = "dataset_infos.json"
+LICENSE_FILENAME = "LICENSE"
+METRIC_INFO_FILENAME = "metric_info.json"
+
+MAX_DATASET_CONFIG_ID_READABLE_LENGTH = 255
diff --git a/src/datasets/info.py b/src/datasets/info.py
index 4d518e8b41e..77775191092 100644
--- a/src/datasets/info.py
+++ b/src/datasets/info.py
@@ -36,6 +36,7 @@
 from dataclasses import asdict, dataclass, field
 from typing import List, Optional, Union
 
+from . import config
 from .features import Features, Value
 from .splits import SplitDict
 from .utils import Version
@@ -44,12 +45,6 @@
 logger = get_logger(__name__)
 
-# Name of the file to output the DatasetInfo protobuf object.
-DATASET_INFO_FILENAME = "dataset_info.json"
-DATASET_INFOS_DICT_FILE_NAME = "dataset_infos.json"
-LICENSE_FILENAME = "LICENSE"
-METRIC_INFO_FILENAME = "metric_info.json"
-
 
 @dataclass
 class SupervisedKeysData:
@@ -156,17 +151,17 @@ def __post_init__(self):
             self.supervised_keys = SupervisedKeysData(**self.supervised_keys)
 
     def _license_path(self, dataset_info_dir):
-        return os.path.join(dataset_info_dir, LICENSE_FILENAME)
+        return os.path.join(dataset_info_dir, config.LICENSE_FILENAME)
 
     def write_to_directory(self, dataset_info_dir):
         """Write `DatasetInfo` as JSON to `dataset_info_dir`.
 
         Also save the license separately in LICENCE.
         """
-        with open(os.path.join(dataset_info_dir, DATASET_INFO_FILENAME), "wb") as f:
+        with open(os.path.join(dataset_info_dir, config.DATASET_INFO_FILENAME), "wb") as f:
             self._dump_info(f)
 
-        with open(os.path.join(dataset_info_dir, LICENSE_FILENAME), "wb") as f:
+        with open(os.path.join(dataset_info_dir, config.LICENSE_FILENAME), "wb") as f:
             self._dump_license(f)
 
     def _dump_info(self, file):
@@ -220,7 +215,7 @@ def from_directory(cls, dataset_info_dir: str) -> "DatasetInfo":
         if not dataset_info_dir:
             raise ValueError("Calling DatasetInfo.from_directory() with undefined dataset_info_dir.")
 
-        with open(os.path.join(dataset_info_dir, DATASET_INFO_FILENAME), "r", encoding="utf-8") as f:
+        with open(os.path.join(dataset_info_dir, config.DATASET_INFO_FILENAME), "r", encoding="utf-8") as f:
             dataset_info_dict = json.load(f)
         return cls.from_dict(dataset_info_dict)
 
@@ -246,7 +241,7 @@ def copy(self) -> "DatasetInfo":
 class DatasetInfosDict(dict):
     def write_to_directory(self, dataset_infos_dir, overwrite=False):
         total_dataset_infos = {}
-        dataset_infos_path = os.path.join(dataset_infos_dir, DATASET_INFOS_DICT_FILE_NAME)
+        dataset_infos_path = os.path.join(dataset_infos_dir, config.DATASETDICT_INFOS_FILENAME)
         if os.path.exists(dataset_infos_path) and not overwrite:
             logger.info("Dataset Infos already exists in {}. Completing it with new infos.".format(dataset_infos_dir))
             total_dataset_infos = self.from_directory(dataset_infos_dir)
@@ -259,7 +254,7 @@ def write_to_directory(self, dataset_infos_dir, overwrite=False):
     @classmethod
     def from_directory(cls, dataset_infos_dir):
         logger.info("Loading Dataset Infos from {}".format(dataset_infos_dir))
-        with open(os.path.join(dataset_infos_dir, DATASET_INFOS_DICT_FILE_NAME), "r", encoding="utf-8") as f:
+        with open(os.path.join(dataset_infos_dir, config.DATASETDICT_INFOS_FILENAME), "r", encoding="utf-8") as f:
             dataset_infos_dict = {
                 config_name: DatasetInfo.from_dict(dataset_info_dict)
                 for config_name, dataset_info_dict in json.load(f).items()
@@ -308,10 +303,10 @@ def write_to_directory(self, metric_info_dir):
         """Write `MetricInfo` as JSON to `metric_info_dir`.
         Also save the license separately in LICENCE.
""" - with open(os.path.join(metric_info_dir, METRIC_INFO_FILENAME), "w", encoding="utf-8") as f: + with open(os.path.join(metric_info_dir, config.METRIC_INFO_FILENAME), "w", encoding="utf-8") as f: json.dump(asdict(self), f) - with open(os.path.join(metric_info_dir, LICENSE_FILENAME), "w", encoding="utf-8") as f: + with open(os.path.join(metric_info_dir, config.LICENSE_FILENAME), "w", encoding="utf-8") as f: f.write(self.license) @classmethod @@ -326,7 +321,7 @@ def from_directory(cls, metric_info_dir) -> "MetricInfo": if not metric_info_dir: raise ValueError("Calling MetricInfo.from_directory() with undefined metric_info_dir.") - with open(os.path.join(metric_info_dir, METRIC_INFO_FILENAME), "r", encoding="utf-8") as f: + with open(os.path.join(metric_info_dir, config.METRIC_INFO_FILENAME), "r", encoding="utf-8") as f: metric_info_dict = json.load(f) return cls.from_dict(metric_info_dict) diff --git a/src/datasets/load.py b/src/datasets/load.py index cb1b2580296..40b4dfcf754 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -35,7 +35,6 @@ from .dataset_dict import DatasetDict from .features import Features from .filesystems import extract_path_from_uri, is_remote_filesystem -from .info import DATASET_INFOS_DICT_FILE_NAME from .metric import Metric from .packaged_modules import _PACKAGED_DATASETS_MODULES, hash_python_lines from .splits import Split @@ -395,7 +394,7 @@ def _get_modification_time(module_hash): # 2. copy from the local file system inside the modules cache to import it base_path = url_or_path_parent(file_path) # remove the filename - dataset_infos = url_or_path_join(base_path, DATASET_INFOS_DICT_FILE_NAME) + dataset_infos = url_or_path_join(base_path, config.DATASETDICT_INFOS_FILENAME) # Download the dataset infos file if available try: @@ -462,7 +461,7 @@ def _get_modification_time(module_hash): hash_folder_path = force_local_path local_file_path = os.path.join(hash_folder_path, name) - dataset_infos_path = os.path.join(hash_folder_path, DATASET_INFOS_DICT_FILE_NAME) + dataset_infos_path = os.path.join(hash_folder_path, config.DATASETDICT_INFOS_FILENAME) # Prevent parallel disk operations lock_path = local_path + ".lock" diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py index 4c07bd35f42..245136e48b7 100644 --- a/tests/test_arrow_dataset.py +++ b/tests/test_arrow_dataset.py @@ -13,11 +13,12 @@ from absl.testing import parameterized import datasets.arrow_dataset -from datasets import NamedSplit, concatenate_datasets, load_from_disk, temp_seed +from datasets import concatenate_datasets, load_from_disk, temp_seed from datasets.arrow_dataset import Dataset, transmit_format, update_metadata_with_features from datasets.dataset_dict import DatasetDict from datasets.features import Array2D, Array3D, ClassLabel, Features, Sequence, Value from datasets.info import DatasetInfo +from datasets.splits import NamedSplit from datasets.table import ConcatenationTable, InMemoryTable, MemoryMappedTable from datasets.utils.logging import WARNING diff --git a/tests/test_builder.py b/tests/test_builder.py index 5c16a5af478..47b0163b4e4 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -9,17 +9,12 @@ from datasets.arrow_dataset import Dataset from datasets.arrow_writer import ArrowWriter -from datasets.builder import ( - FORCE_REDOWNLOAD, - REUSE_DATASET_IF_EXISTS, - BuilderConfig, - DatasetBuilder, - GeneratorBasedBuilder, -) +from datasets.builder import BuilderConfig, DatasetBuilder, GeneratorBasedBuilder from datasets.dataset_dict import 
DatasetDict from datasets.features import Features, Value from datasets.info import DatasetInfo, PostProcessedInfo from datasets.splits import Split, SplitDict, SplitGenerator, SplitInfo +from datasets.utils.download_manager import GenerateMode from .utils import assert_arrow_memory_doesnt_increase, assert_arrow_memory_increases, require_faiss @@ -123,7 +118,7 @@ def _split_generators(self, dl_manager): def _run_concurrent_download_and_prepare(tmp_dir): dummy_builder = DummyBuilder(cache_dir=tmp_dir, name="dummy") - dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=REUSE_DATASET_IF_EXISTS) + dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=GenerateMode.REUSE_DATASET_IF_EXISTS) return dummy_builder @@ -131,7 +126,7 @@ class BuilderTest(TestCase): def test_download_and_prepare(self): with tempfile.TemporaryDirectory() as tmp_dir: dummy_builder = DummyBuilder(cache_dir=tmp_dir, name="dummy") - dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=FORCE_REDOWNLOAD) + dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=GenerateMode.FORCE_REDOWNLOAD) self.assertTrue( os.path.exists(os.path.join(tmp_dir, "dummy_builder", "dummy", "0.0.0", "dummy_builder-train.arrow")) ) @@ -170,13 +165,13 @@ def test_download_and_prepare_with_base_path(self): dummy_builder = DummyBuilderWithDownload(cache_dir=tmp_dir, name="dummy", rel_path=rel_path) with self.assertRaises(FileNotFoundError): dummy_builder.download_and_prepare( - try_from_hf_gcs=False, download_mode=FORCE_REDOWNLOAD, base_path=tmp_dir + try_from_hf_gcs=False, download_mode=GenerateMode.FORCE_REDOWNLOAD, base_path=tmp_dir ) # test absolute path is missing dummy_builder = DummyBuilderWithDownload(cache_dir=tmp_dir, name="dummy", abs_path=abs_path) with self.assertRaises(FileNotFoundError): dummy_builder.download_and_prepare( - try_from_hf_gcs=False, download_mode=FORCE_REDOWNLOAD, base_path=tmp_dir + try_from_hf_gcs=False, download_mode=GenerateMode.FORCE_REDOWNLOAD, base_path=tmp_dir ) # test that they are both properly loaded when they exist open(os.path.join(tmp_dir, rel_path), "w") @@ -185,7 +180,7 @@ def test_download_and_prepare_with_base_path(self): cache_dir=tmp_dir, name="dummy", rel_path=rel_path, abs_path=abs_path ) dummy_builder.download_and_prepare( - try_from_hf_gcs=False, download_mode=FORCE_REDOWNLOAD, base_path=tmp_dir + try_from_hf_gcs=False, download_mode=GenerateMode.FORCE_REDOWNLOAD, base_path=tmp_dir ) self.assertTrue( os.path.exists( @@ -417,7 +412,7 @@ def _post_processing_resources(self, split): ) dummy_builder._post_process = types.MethodType(_post_process, dummy_builder) dummy_builder._post_processing_resources = types.MethodType(_post_processing_resources, dummy_builder) - dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=FORCE_REDOWNLOAD) + dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=GenerateMode.FORCE_REDOWNLOAD) self.assertTrue( os.path.exists(os.path.join(tmp_dir, "dummy_builder", "dummy", "0.0.0", "dummy_builder-train.arrow")) ) @@ -437,7 +432,7 @@ def _post_process(self, dataset, resources_paths): with tempfile.TemporaryDirectory() as tmp_dir: dummy_builder = DummyBuilder(cache_dir=tmp_dir, name="dummy") dummy_builder._post_process = types.MethodType(_post_process, dummy_builder) - dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=FORCE_REDOWNLOAD) + dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=GenerateMode.FORCE_REDOWNLOAD) 
             self.assertTrue(
                 os.path.exists(os.path.join(tmp_dir, "dummy_builder", "dummy", "0.0.0", "dummy_builder-train.arrow"))
             )
@@ -466,7 +461,7 @@ def _post_processing_resources(self, split):
             dummy_builder = DummyBuilder(cache_dir=tmp_dir, name="dummy")
             dummy_builder._post_process = types.MethodType(_post_process, dummy_builder)
             dummy_builder._post_processing_resources = types.MethodType(_post_processing_resources, dummy_builder)
-            dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=FORCE_REDOWNLOAD)
+            dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=GenerateMode.FORCE_REDOWNLOAD)
             self.assertTrue(
                 os.path.exists(os.path.join(tmp_dir, "dummy_builder", "dummy", "0.0.0", "dummy_builder-train.arrow"))
             )
@@ -485,14 +480,17 @@ def _prepare_split(self, split_generator, **kwargs):
             dummy_builder = DummyBuilder(cache_dir=tmp_dir, name="dummy")
             dummy_builder._prepare_split = types.MethodType(_prepare_split, dummy_builder)
             self.assertRaises(
-                ValueError, dummy_builder.download_and_prepare, try_from_hf_gcs=False, download_mode=FORCE_REDOWNLOAD
+                ValueError,
+                dummy_builder.download_and_prepare,
+                try_from_hf_gcs=False,
+                download_mode=GenerateMode.FORCE_REDOWNLOAD,
             )
             self.assertRaises(AssertionError, dummy_builder.as_dataset)
 
     def test_generator_based_download_and_prepare(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             dummy_builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, name="dummy")
-            dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=FORCE_REDOWNLOAD)
+            dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=GenerateMode.FORCE_REDOWNLOAD)
             self.assertTrue(
                 os.path.exists(
                     os.path.join(
@@ -700,7 +698,7 @@ def test_generator_based_builder_as_dataset(in_memory, tmp_path):
     cache_dir.mkdir()
     cache_dir = str(cache_dir)
     dummy_builder = DummyGeneratorBasedBuilder(cache_dir=cache_dir, name="dummy")
-    dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=FORCE_REDOWNLOAD)
+    dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=GenerateMode.FORCE_REDOWNLOAD)
     with assert_arrow_memory_increases() if in_memory else assert_arrow_memory_doesnt_increase():
         dataset = dummy_builder.as_dataset("train", in_memory=in_memory)
     assert dataset.data.to_pydict() == {"text": ["foo"] * 100}
@@ -715,6 +713,6 @@ def test_custom_writer_batch_size(tmp_path, writer_batch_size, default_writer_ba
     DummyGeneratorBasedBuilder.DEFAULT_WRITER_BATCH_SIZE = default_writer_batch_size
     dummy_builder = DummyGeneratorBasedBuilder(cache_dir=cache_dir, name="dummy", writer_batch_size=writer_batch_size)
     assert dummy_builder._writer_batch_size == (writer_batch_size or default_writer_batch_size)
-    dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=FORCE_REDOWNLOAD)
+    dummy_builder.download_and_prepare(try_from_hf_gcs=False, download_mode=GenerateMode.FORCE_REDOWNLOAD)
     dataset = dummy_builder.as_dataset("train")
     assert len(dataset.data[0].chunks) == expected_chunks
diff --git a/tests/test_dataset_common.py b/tests/test_dataset_common.py
index e0ff52cc618..a2f0a943f89 100644
--- a/tests/test_dataset_common.py
+++ b/tests/test_dataset_common.py
@@ -24,25 +24,15 @@
 
 from absl.testing import parameterized
 
-from datasets import (
-    BuilderConfig,
-    DatasetBuilder,
-    DownloadConfig,
-    Features,
-    GenerateMode,
-    MockDownloadManager,
-    Value,
-    cached_path,
-    hf_api,
-    import_main_class,
-    load_dataset,
-    prepare_module,
-)
-from datasets.features import ClassLabel
+from datasets import cached_path, hf_api, import_main_class, load_dataset, prepare_module
+from datasets.builder import BuilderConfig, DatasetBuilder
+from datasets.features import ClassLabel, Features, Value
 from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
 from datasets.search import _has_faiss
-from datasets.utils.file_utils import is_remote_url
+from datasets.utils.download_manager import GenerateMode
+from datasets.utils.file_utils import DownloadConfig, is_remote_url
 from datasets.utils.logging import get_logger
+from datasets.utils.mock_download_manager import MockDownloadManager
 
 from .utils import OfflineSimulationMode, for_all_test_methods, local, offline, packaged, remote, slow
diff --git a/tests/test_dataset_dict.py b/tests/test_dataset_dict.py
index a55eb8a5cb9..be2c1ef1f58 100644
--- a/tests/test_dataset_dict.py
+++ b/tests/test_dataset_dict.py
@@ -6,9 +6,10 @@
 import pandas as pd
 import pytest
 
-from datasets import Features, Sequence, Value, load_from_disk
+from datasets import load_from_disk
 from datasets.arrow_dataset import Dataset
 from datasets.dataset_dict import DatasetDict
+from datasets.features import Features, Sequence, Value
 
 from .conftest import s3_test_bucket_name
 from .utils import (
diff --git a/tests/test_download_manager.py b/tests/test_download_manager.py
index a9eff208652..c0a5a6ea900 100644
--- a/tests/test_download_manager.py
+++ b/tests/test_download_manager.py
@@ -4,8 +4,8 @@
 
 import pytest
 
-from datasets.utils.download_manager import DownloadConfig, DownloadManager
-from datasets.utils.file_utils import hash_url_to_filename
+from datasets.utils.download_manager import DownloadManager
+from datasets.utils.file_utils import DownloadConfig, hash_url_to_filename
 
 
 URL = "http://www.mocksite.com/file1.txt"
diff --git a/tests/test_dummy_data_autogenerate.py b/tests/test_dummy_data_autogenerate.py
index 8de67ab01ff..174784d8e1c 100644
--- a/tests/test_dummy_data_autogenerate.py
+++ b/tests/test_dummy_data_autogenerate.py
@@ -3,9 +3,12 @@
 from tempfile import TemporaryDirectory
 from unittest import TestCase
 
-from datasets.builder import DatasetInfo, DownloadConfig, GeneratorBasedBuilder, Split, SplitGenerator
+from datasets.builder import GeneratorBasedBuilder
 from datasets.commands.dummy_data import DummyDataGeneratorDownloadManager, MockDownloadManager
 from datasets.features import Features, Value
+from datasets.info import DatasetInfo
+from datasets.splits import Split, SplitGenerator
+from datasets.utils.download_manager import DownloadConfig
 from datasets.utils.version import Version
diff --git a/tests/test_features.py b/tests/test_features.py
index 73c106480ae..216bae63b1a 100644
--- a/tests/test_features.py
+++ b/tests/test_features.py
@@ -7,7 +7,6 @@
 import pyarrow as pa
 import pytest
 
-from datasets import DatasetInfo
 from datasets.arrow_dataset import Dataset
 from datasets.features import (
     ClassLabel,
@@ -19,6 +18,7 @@
     cast_to_python_objects,
     string_to_arrow,
 )
+from datasets.info import DatasetInfo
 
 from .utils import require_tf, require_torch
diff --git a/tests/test_hf_gcp.py b/tests/test_hf_gcp.py
index 445f47d97f7..096c31267d3 100644
--- a/tests/test_hf_gcp.py
+++ b/tests/test_hf_gcp.py
@@ -4,9 +4,9 @@
 
 from absl.testing import parameterized
 
+from datasets import config
 from datasets.arrow_reader import HF_GCP_BASE_URL
 from datasets.builder import DatasetBuilder
-from datasets.info import DATASET_INFO_FILENAME
 from datasets.load import import_main_class, prepare_module
 from datasets.utils import cached_path
 
@@ -65,7 +65,7 @@ def test_dataset_info_available(self, dataset, config_name):
             )
 
             dataset_info_url = os.path.join(
-                HF_GCP_BASE_URL, builder_instance._relative_data_dir(with_hash=False), DATASET_INFO_FILENAME
+                HF_GCP_BASE_URL, builder_instance._relative_data_dir(with_hash=False), config.DATASET_INFO_FILENAME
             ).replace(os.sep, "/")
             datset_info_path = cached_path(dataset_info_url, cache_dir=tmp_dir)
             self.assertTrue(os.path.exists(datset_info_path))