From 815ac932e612978b80fe29901504d1de4ce5ca28 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 22 May 2024 15:21:54 +0200 Subject: [PATCH 1/6] Make configs call super post_init in packaged modules --- src/datasets/packaged_modules/arrow/arrow.py | 3 +++ src/datasets/packaged_modules/audiofolder/audiofolder.py | 3 +++ src/datasets/packaged_modules/csv/csv.py | 1 + .../folder_based_builder/folder_based_builder.py | 3 +++ src/datasets/packaged_modules/generator/generator.py | 4 +++- src/datasets/packaged_modules/imagefolder/imagefolder.py | 3 +++ src/datasets/packaged_modules/json/json.py | 3 +++ src/datasets/packaged_modules/pandas/pandas.py | 3 +++ src/datasets/packaged_modules/parquet/parquet.py | 3 +++ src/datasets/packaged_modules/spark/spark.py | 3 +++ src/datasets/packaged_modules/sql/sql.py | 1 + src/datasets/packaged_modules/text/text.py | 1 + 12 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/datasets/packaged_modules/arrow/arrow.py b/src/datasets/packaged_modules/arrow/arrow.py index cd1ecbf12da..86a4cb66246 100644 --- a/src/datasets/packaged_modules/arrow/arrow.py +++ b/src/datasets/packaged_modules/arrow/arrow.py @@ -17,6 +17,9 @@ class ArrowConfig(datasets.BuilderConfig): features: Optional[datasets.Features] = None + def __post_init__(self): + super().__post_init__() + class Arrow(datasets.ArrowBasedBuilder): BUILDER_CONFIG_CLASS = ArrowConfig diff --git a/src/datasets/packaged_modules/audiofolder/audiofolder.py b/src/datasets/packaged_modules/audiofolder/audiofolder.py index 51044143039..a86927934c1 100644 --- a/src/datasets/packaged_modules/audiofolder/audiofolder.py +++ b/src/datasets/packaged_modules/audiofolder/audiofolder.py @@ -15,6 +15,9 @@ class AudioFolderConfig(folder_based_builder.FolderBasedBuilderConfig): drop_labels: bool = None drop_metadata: bool = None + def __post_init__(self): + super().__post_init__() + class 
AudioFolder(folder_based_builder.FolderBasedBuilder): BASE_FEATURE = datasets.Audio diff --git a/src/datasets/packaged_modules/csv/csv.py b/src/datasets/packaged_modules/csv/csv.py index 181f52799b4..b7e8b6e220a 100644 --- a/src/datasets/packaged_modules/csv/csv.py +++ b/src/datasets/packaged_modules/csv/csv.py @@ -68,6 +68,7 @@ class CsvConfig(datasets.BuilderConfig): date_format: Optional[str] = None def __post_init__(self): + super().__post_init__() if self.delimiter is not None: self.sep = self.delimiter if self.column_names is not None: diff --git a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py index 24c32a746e8..7b71dc407ae 100644 --- a/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py +++ b/src/datasets/packaged_modules/folder_based_builder/folder_based_builder.py @@ -28,6 +28,9 @@ class FolderBasedBuilderConfig(datasets.BuilderConfig): drop_labels: bool = None drop_metadata: bool = None + def __post_init__(self): + super().__post_init__() + class FolderBasedBuilder(datasets.GeneratorBasedBuilder): """ diff --git a/src/datasets/packaged_modules/generator/generator.py b/src/datasets/packaged_modules/generator/generator.py index 1efa721b159..336942f2edc 100644 --- a/src/datasets/packaged_modules/generator/generator.py +++ b/src/datasets/packaged_modules/generator/generator.py @@ -11,7 +11,9 @@ class GeneratorConfig(datasets.BuilderConfig): features: Optional[datasets.Features] = None def __post_init__(self): - assert self.generator is not None, "generator must be specified" + super().__post_init__() + if self.generator is None: + raise ValueError("generator must be specified") if self.gen_kwargs is None: self.gen_kwargs = {} diff --git a/src/datasets/packaged_modules/imagefolder/imagefolder.py b/src/datasets/packaged_modules/imagefolder/imagefolder.py index bd2dd0d419a..16fbcd005d4 100644 --- 
a/src/datasets/packaged_modules/imagefolder/imagefolder.py +++ b/src/datasets/packaged_modules/imagefolder/imagefolder.py @@ -15,6 +15,9 @@ class ImageFolderConfig(folder_based_builder.FolderBasedBuilderConfig): drop_labels: bool = None drop_metadata: bool = None + def __post_init__(self): + super().__post_init__() + class ImageFolder(folder_based_builder.FolderBasedBuilder): BASE_FEATURE = datasets.Image diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py index 8fa1e975d5a..2250186ce08 100644 --- a/src/datasets/packaged_modules/json/json.py +++ b/src/datasets/packaged_modules/json/json.py @@ -36,6 +36,9 @@ class JsonConfig(datasets.BuilderConfig): chunksize: int = 10 << 20 # 10MB newlines_in_values: Optional[bool] = None + def __post_init__(self): + super().__post_init__() + class Json(datasets.ArrowBasedBuilder): BUILDER_CONFIG_CLASS = JsonConfig diff --git a/src/datasets/packaged_modules/pandas/pandas.py b/src/datasets/packaged_modules/pandas/pandas.py index c17f389945e..d1eb50d33c8 100644 --- a/src/datasets/packaged_modules/pandas/pandas.py +++ b/src/datasets/packaged_modules/pandas/pandas.py @@ -16,6 +16,9 @@ class PandasConfig(datasets.BuilderConfig): features: Optional[datasets.Features] = None + def __post_init__(self): + super().__post_init__() + class Pandas(datasets.ArrowBasedBuilder): BUILDER_CONFIG_CLASS = PandasConfig diff --git a/src/datasets/packaged_modules/parquet/parquet.py b/src/datasets/packaged_modules/parquet/parquet.py index 399a2609f7e..e78638ecc6a 100644 --- a/src/datasets/packaged_modules/parquet/parquet.py +++ b/src/datasets/packaged_modules/parquet/parquet.py @@ -20,6 +20,9 @@ class ParquetConfig(datasets.BuilderConfig): columns: Optional[List[str]] = None features: Optional[datasets.Features] = None + def __post_init__(self): + super().__post_init__() + class Parquet(datasets.ArrowBasedBuilder): BUILDER_CONFIG_CLASS = ParquetConfig diff --git 
a/src/datasets/packaged_modules/spark/spark.py b/src/datasets/packaged_modules/spark/spark.py index fee5f7c4c61..f7c8a469ea8 100644 --- a/src/datasets/packaged_modules/spark/spark.py +++ b/src/datasets/packaged_modules/spark/spark.py @@ -30,6 +30,9 @@ class SparkConfig(datasets.BuilderConfig): features: Optional[datasets.Features] = None + def __post_init__(self): + super().__post_init__() + def _reorder_dataframe_by_partition(df: "pyspark.sql.DataFrame", new_partition_order: List[int]): df_combined = df.select("*").where(f"part_id = {new_partition_order[0]}") diff --git a/src/datasets/packaged_modules/sql/sql.py b/src/datasets/packaged_modules/sql/sql.py index b0791ba8859..152a8dc2089 100644 --- a/src/datasets/packaged_modules/sql/sql.py +++ b/src/datasets/packaged_modules/sql/sql.py @@ -35,6 +35,7 @@ class SqlConfig(datasets.BuilderConfig): features: Optional[datasets.Features] = None def __post_init__(self): + super().__post_init__() if self.sql is None: raise ValueError("sql must be specified") if self.con is None: diff --git a/src/datasets/packaged_modules/text/text.py b/src/datasets/packaged_modules/text/text.py index 47e07a0e4b3..e5754c3c06a 100644 --- a/src/datasets/packaged_modules/text/text.py +++ b/src/datasets/packaged_modules/text/text.py @@ -27,6 +27,7 @@ class TextConfig(datasets.BuilderConfig): sample_by: str = "line" def __post_init__(self, errors): + super().__post_init__() if errors != "deprecated": warnings.warn( "'errors' was deprecated in favor of 'encoding_errors' in version 2.14.0 and will be removed in 3.0.0.\n" From ac3fbf72229d4518235796fcfb1ff62ece9b835d Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 22 May 2024 17:01:36 +0200 Subject: [PATCH 2/6] Update hash in test --- tests/test_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_load.py b/tests/test_load.py index 4b2b9cbf58c..3ecebee69fb 100644 --- a/tests/test_load.py +++ 
b/tests/test_load.py @@ -1749,7 +1749,7 @@ def test_resolve_trust_remote_code_future(trust_remote_code, expected): def test_reload_old_cache_from_2_15(tmp_path: Path): cache_dir = tmp_path / "test_reload_old_cache_from_2_15" builder_cache_dir = ( - cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata/v2-374bfde4f55442bc/0.0.0/7896925d64deea5d" + cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata/v2-374bfde4f55442bc/0.0.0/cf191ad706de653e" ) builder_cache_dir.mkdir(parents=True) arrow_path = builder_cache_dir / "audiofolder_two_configs_in_metadata-train.arrow" From ebb8413198e802cb871bc9a1fef293fadc53656f Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 5 Jun 2024 08:43:38 +0200 Subject: [PATCH 3/6] Add tests --- tests/packaged_modules/test_arrow.py | 16 ++++++++++++++++ tests/packaged_modules/test_audiofolder.py | 16 ++++++++++++++-- tests/packaged_modules/test_csv.py | 15 ++++++++++++++- .../test_folder_based_builder.py | 14 +++++++++++++- tests/packaged_modules/test_imagefolder.py | 16 ++++++++++++++-- tests/packaged_modules/test_json.py | 15 ++++++++++++++- tests/packaged_modules/test_pandas.py | 16 ++++++++++++++++ tests/packaged_modules/test_parquet.py | 16 ++++++++++++++++ tests/packaged_modules/test_spark.py | 15 +++++++++++++++ tests/packaged_modules/test_sql.py | 16 ++++++++++++++++ tests/packaged_modules/test_text.py | 15 ++++++++++++++- 11 files changed, 162 insertions(+), 8 deletions(-) create mode 100644 tests/packaged_modules/test_arrow.py create mode 100644 tests/packaged_modules/test_pandas.py create mode 100644 tests/packaged_modules/test_parquet.py create mode 100644 tests/packaged_modules/test_sql.py diff --git a/tests/packaged_modules/test_arrow.py b/tests/packaged_modules/test_arrow.py new file mode 100644 index 00000000000..6e355e49cad --- /dev/null +++ b/tests/packaged_modules/test_arrow.py @@ -0,0 +1,16 @@ +import pytest + +from datasets.builder 
import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.arrow.arrow import ArrowConfig + + +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = ArrowConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = ArrowConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_audiofolder.py b/tests/packaged_modules/test_audiofolder.py index 712e6aeac4f..418b35072da 100644 --- a/tests/packaged_modules/test_audiofolder.py +++ b/tests/packaged_modules/test_audiofolder.py @@ -7,9 +7,10 @@ import soundfile as sf from datasets import Audio, ClassLabel, Features, Value -from datasets.data_files import DataFilesDict, get_data_patterns +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesDict, DataFilesList, get_data_patterns from datasets.download.streaming_download_manager import StreamingDownloadManager -from datasets.packaged_modules.audiofolder.audiofolder import AudioFolder +from datasets.packaged_modules.audiofolder.audiofolder import AudioFolder, AudioFolderConfig from ..utils import require_sndfile @@ -230,6 +231,17 @@ def data_files_with_zip_archives(tmp_path, audio_file): return data_files_with_zip_archives +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = AudioFolderConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = AudioFolderConfig(name="name", 
data_files=data_files) + + @require_sndfile # check that labels are inferred correctly from dir names def test_generate_examples_with_labels(data_files_with_labels_no_metadata, cache_dir): diff --git a/tests/packaged_modules/test_csv.py b/tests/packaged_modules/test_csv.py index 6cfa5e4ca23..f824837ec86 100644 --- a/tests/packaged_modules/test_csv.py +++ b/tests/packaged_modules/test_csv.py @@ -5,7 +5,9 @@ import pytest from datasets import ClassLabel, Features, Image -from datasets.packaged_modules.csv.csv import Csv +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.csv.csv import Csv, CsvConfig from ..utils import require_pil @@ -86,6 +88,17 @@ def csv_file_with_int_list(tmp_path): return str(filename) +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = CsvConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = CsvConfig(name="name", data_files=data_files) + + def test_csv_generate_tables_raises_error_with_malformed_csv(csv_file, malformed_csv_file, caplog): csv = Csv() generator = csv._generate_tables([[csv_file, malformed_csv_file]]) diff --git a/tests/packaged_modules/test_folder_based_builder.py b/tests/packaged_modules/test_folder_based_builder.py index c6aad5ded09..2b0ef8f1ec6 100644 --- a/tests/packaged_modules/test_folder_based_builder.py +++ b/tests/packaged_modules/test_folder_based_builder.py @@ -5,7 +5,8 @@ import pytest from datasets import ClassLabel, DownloadManager, Features, Value -from datasets.data_files import DataFilesDict, get_data_patterns +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesDict, DataFilesList, get_data_patterns 
from datasets.download.streaming_download_manager import StreamingDownloadManager from datasets.packaged_modules.folder_based_builder.folder_based_builder import ( FolderBasedBuilder, @@ -265,6 +266,17 @@ def data_files_with_zip_archives(tmp_path, auto_text_file): return data_files_with_zip_archives +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = FolderBasedBuilderConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = FolderBasedBuilderConfig(name="name", data_files=data_files) + + def test_inferring_labels_from_data_dirs(data_files_with_labels_no_metadata, cache_dir): autofolder = DummyFolderBasedBuilder( data_files=data_files_with_labels_no_metadata, cache_dir=cache_dir, drop_labels=False diff --git a/tests/packaged_modules/test_imagefolder.py b/tests/packaged_modules/test_imagefolder.py index 3be9195d6aa..095c4909c11 100644 --- a/tests/packaged_modules/test_imagefolder.py +++ b/tests/packaged_modules/test_imagefolder.py @@ -5,9 +5,10 @@ import pytest from datasets import ClassLabel, Features, Image, Value -from datasets.data_files import DataFilesDict, get_data_patterns +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesDict, DataFilesList, get_data_patterns from datasets.download.streaming_download_manager import StreamingDownloadManager -from datasets.packaged_modules.imagefolder.imagefolder import ImageFolder +from datasets.packaged_modules.imagefolder.imagefolder import ImageFolder, ImageFolderConfig from ..utils import require_pil @@ -239,6 +240,17 @@ def data_files_with_zip_archives(tmp_path, image_file): return data_files_with_zip_archives +def test_config_raises_when_invalid_name() -> None: + with 
pytest.raises(InvalidConfigName, match="Bad characters"): + _ = ImageFolderConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = ImageFolderConfig(name="name", data_files=data_files) + + @require_pil # check that labels are inferred correctly from dir names def test_generate_examples_with_labels(data_files_with_labels_no_metadata, cache_dir): diff --git a/tests/packaged_modules/test_json.py b/tests/packaged_modules/test_json.py index 9375fd443b4..e760a9208d6 100644 --- a/tests/packaged_modules/test_json.py +++ b/tests/packaged_modules/test_json.py @@ -4,7 +4,9 @@ import pytest from datasets import Features, Value -from datasets.packaged_modules.json.json import Json +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.json.json import Json, JsonConfig @pytest.fixture @@ -92,6 +94,17 @@ def json_file_with_list_of_dicts_field(tmp_path): return str(filename) +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = JsonConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = JsonConfig(name="name", data_files=data_files) + + @pytest.mark.parametrize( "file_fixture, config_kwargs", [ diff --git a/tests/packaged_modules/test_pandas.py b/tests/packaged_modules/test_pandas.py new file mode 100644 index 00000000000..d7049bfa304 --- /dev/null +++ b/tests/packaged_modules/test_pandas.py @@ -0,0 +1,16 @@ +import pytest + +from datasets.builder import InvalidConfigName +from 
datasets.data_files import DataFilesList +from datasets.packaged_modules.pandas.pandas import PandasConfig + + +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = PandasConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = PandasConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_parquet.py b/tests/packaged_modules/test_parquet.py new file mode 100644 index 00000000000..995f60d1465 --- /dev/null +++ b/tests/packaged_modules/test_parquet.py @@ -0,0 +1,16 @@ +import pytest + +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.parquet.parquet import ParquetConfig + + +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = ParquetConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = ParquetConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_spark.py b/tests/packaged_modules/test_spark.py index cabc61c682f..f2f22cc1d0c 100644 --- a/tests/packaged_modules/test_spark.py +++ b/tests/packaged_modules/test_spark.py @@ -1,9 +1,13 @@ from unittest.mock import patch import pyspark +import pytest +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList from datasets.packaged_modules.spark.spark import ( Spark, + SparkConfig, SparkExamplesIterable, _generate_iterable_examples, ) @@ -23,6 +27,17 @@ def 
_get_expected_row_ids_and_row_dicts_for_partition_order(df, partition_order) return expected_row_ids_and_row_dicts +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = SparkConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = SparkConfig(name="name", data_files=data_files) + + @require_not_windows @require_dill_gt_0_3_2 def test_repartition_df_if_needed(): diff --git a/tests/packaged_modules/test_sql.py b/tests/packaged_modules/test_sql.py new file mode 100644 index 00000000000..ddc2fb0b205 --- /dev/null +++ b/tests/packaged_modules/test_sql.py @@ -0,0 +1,16 @@ +import pytest + +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.sql.sql import SqlConfig + + +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = SqlConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = SqlConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_text.py b/tests/packaged_modules/test_text.py index 0d1b3f3b5a4..6533245f1d1 100644 --- a/tests/packaged_modules/test_text.py +++ b/tests/packaged_modules/test_text.py @@ -4,7 +4,9 @@ import pytest from datasets import Features, Image -from datasets.packaged_modules.text.text import Text +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.text.text import Text, TextConfig 
from ..utils import require_pil @@ -39,6 +41,17 @@ def text_file_with_image(tmp_path, image_file): return str(filename) +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = TextConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = TextConfig(name="name", data_files=data_files) + + @pytest.mark.parametrize("keep_linebreaks", [True, False]) def test_text_linebreaks(text_file, keep_linebreaks): with open(text_file, encoding="utf-8") as f: From 58970f2782f915bbd13aa3352296adddf04484e3 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 5 Jun 2024 08:47:39 +0200 Subject: [PATCH 4/6] Add tests for BuilderConfig --- tests/test_builder.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/tests/test_builder.py b/tests/test_builder.py index 81966044fc3..57052f964e3 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -17,7 +17,15 @@ from datasets.arrow_dataset import Dataset from datasets.arrow_reader import DatasetNotOnHfGcsError from datasets.arrow_writer import ArrowWriter -from datasets.builder import ArrowBasedBuilder, BeamBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder +from datasets.builder import ( + ArrowBasedBuilder, + BeamBasedBuilder, + BuilderConfig, + DatasetBuilder, + GeneratorBasedBuilder, + InvalidConfigName, +) +from datasets.data_files import DataFilesList from datasets.dataset_dict import DatasetDict, IterableDatasetDict from datasets.download.download_manager import DownloadMode from datasets.features import Features, Value @@ -836,6 +844,17 @@ def test_cache_dir_for_configured_builder(self): 
self.assertNotEqual(builder.cache_dir, other_builder.cache_dir) +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = BuilderConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = BuilderConfig(name="name", data_files=data_files) + + def test_arrow_based_download_and_prepare(tmp_path): builder = DummyArrowBasedBuilder(cache_dir=tmp_path) builder.download_and_prepare() From 6ac8c6ea149accaac68f9e25f7c81a4f79afd132 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 5 Jun 2024 09:34:55 +0200 Subject: [PATCH 5/6] Fix syntax --- tests/packaged_modules/test_arrow.py | 2 +- tests/packaged_modules/test_audiofolder.py | 2 +- tests/packaged_modules/test_csv.py | 2 +- tests/packaged_modules/test_folder_based_builder.py | 2 +- tests/packaged_modules/test_imagefolder.py | 2 +- tests/packaged_modules/test_json.py | 2 +- tests/packaged_modules/test_pandas.py | 2 +- tests/packaged_modules/test_parquet.py | 2 +- tests/packaged_modules/test_spark.py | 2 +- tests/packaged_modules/test_sql.py | 2 +- tests/packaged_modules/test_text.py | 2 +- tests/test_builder.py | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/tests/packaged_modules/test_arrow.py b/tests/packaged_modules/test_arrow.py index 6e355e49cad..dda3720efe3 100644 --- a/tests/packaged_modules/test_arrow.py +++ b/tests/packaged_modules/test_arrow.py @@ -10,7 +10,7 @@ def test_config_raises_when_invalid_name() -> None: _ = ArrowConfig(name="name-with-*-invalid-character") -@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], 
DataFilesList(["str_path"], [()])]) def test_config_raises_when_invalid_data_files(data_files) -> None: with pytest.raises(ValueError, match="Expected a DataFilesDict"): _ = ArrowConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_audiofolder.py b/tests/packaged_modules/test_audiofolder.py index 418b35072da..3351fccf604 100644 --- a/tests/packaged_modules/test_audiofolder.py +++ b/tests/packaged_modules/test_audiofolder.py @@ -236,7 +236,7 @@ def test_config_raises_when_invalid_name() -> None: _ = AudioFolderConfig(name="name-with-*-invalid-character") -@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) def test_config_raises_when_invalid_data_files(data_files) -> None: with pytest.raises(ValueError, match="Expected a DataFilesDict"): _ = AudioFolderConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_csv.py b/tests/packaged_modules/test_csv.py index f824837ec86..e85dc1e3b09 100644 --- a/tests/packaged_modules/test_csv.py +++ b/tests/packaged_modules/test_csv.py @@ -93,7 +93,7 @@ def test_config_raises_when_invalid_name() -> None: _ = CsvConfig(name="name-with-*-invalid-character") -@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) def test_config_raises_when_invalid_data_files(data_files) -> None: with pytest.raises(ValueError, match="Expected a DataFilesDict"): _ = CsvConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_folder_based_builder.py b/tests/packaged_modules/test_folder_based_builder.py index 2b0ef8f1ec6..3623c4b1680 100644 --- a/tests/packaged_modules/test_folder_based_builder.py +++ b/tests/packaged_modules/test_folder_based_builder.py @@ -271,7 +271,7 @@ def 
test_config_raises_when_invalid_name() -> None: _ = FolderBasedBuilderConfig(name="name-with-*-invalid-character") -@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) def test_config_raises_when_invalid_data_files(data_files) -> None: with pytest.raises(ValueError, match="Expected a DataFilesDict"): _ = FolderBasedBuilderConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_imagefolder.py b/tests/packaged_modules/test_imagefolder.py index 095c4909c11..835d3a7db0c 100644 --- a/tests/packaged_modules/test_imagefolder.py +++ b/tests/packaged_modules/test_imagefolder.py @@ -245,7 +245,7 @@ def test_config_raises_when_invalid_name() -> None: _ = ImageFolderConfig(name="name-with-*-invalid-character") -@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) def test_config_raises_when_invalid_data_files(data_files) -> None: with pytest.raises(ValueError, match="Expected a DataFilesDict"): _ = ImageFolderConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_json.py b/tests/packaged_modules/test_json.py index 4b3a5c3e2d1..07bbba7adec 100644 --- a/tests/packaged_modules/test_json.py +++ b/tests/packaged_modules/test_json.py @@ -178,7 +178,7 @@ def test_config_raises_when_invalid_name() -> None: _ = JsonConfig(name="name-with-*-invalid-character") -@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) def test_config_raises_when_invalid_data_files(data_files) -> None: with pytest.raises(ValueError, match="Expected a DataFilesDict"): _ = JsonConfig(name="name", data_files=data_files) diff --git 
a/tests/packaged_modules/test_pandas.py b/tests/packaged_modules/test_pandas.py index d7049bfa304..60b3bb22107 100644 --- a/tests/packaged_modules/test_pandas.py +++ b/tests/packaged_modules/test_pandas.py @@ -10,7 +10,7 @@ def test_config_raises_when_invalid_name() -> None: _ = PandasConfig(name="name-with-*-invalid-character") -@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) def test_config_raises_when_invalid_data_files(data_files) -> None: with pytest.raises(ValueError, match="Expected a DataFilesDict"): _ = PandasConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_parquet.py b/tests/packaged_modules/test_parquet.py index 995f60d1465..b5c1808d8f9 100644 --- a/tests/packaged_modules/test_parquet.py +++ b/tests/packaged_modules/test_parquet.py @@ -10,7 +10,7 @@ def test_config_raises_when_invalid_name() -> None: _ = ParquetConfig(name="name-with-*-invalid-character") -@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) def test_config_raises_when_invalid_data_files(data_files) -> None: with pytest.raises(ValueError, match="Expected a DataFilesDict"): _ = ParquetConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_spark.py b/tests/packaged_modules/test_spark.py index be5083883e4..c91bdd571ea 100644 --- a/tests/packaged_modules/test_spark.py +++ b/tests/packaged_modules/test_spark.py @@ -32,7 +32,7 @@ def test_config_raises_when_invalid_name() -> None: _ = SparkConfig(name="name-with-*-invalid-character") -@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) def 
test_config_raises_when_invalid_data_files(data_files) -> None: with pytest.raises(ValueError, match="Expected a DataFilesDict"): _ = SparkConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_sql.py b/tests/packaged_modules/test_sql.py index ddc2fb0b205..e745cb03d2e 100644 --- a/tests/packaged_modules/test_sql.py +++ b/tests/packaged_modules/test_sql.py @@ -10,7 +10,7 @@ def test_config_raises_when_invalid_name() -> None: _ = SqlConfig(name="name-with-*-invalid-character") -@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) def test_config_raises_when_invalid_data_files(data_files) -> None: with pytest.raises(ValueError, match="Expected a DataFilesDict"): _ = SqlConfig(name="name", data_files=data_files) diff --git a/tests/packaged_modules/test_text.py b/tests/packaged_modules/test_text.py index 6533245f1d1..a21b3e223d9 100644 --- a/tests/packaged_modules/test_text.py +++ b/tests/packaged_modules/test_text.py @@ -46,7 +46,7 @@ def test_config_raises_when_invalid_name() -> None: _ = TextConfig(name="name-with-*-invalid-character") -@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) def test_config_raises_when_invalid_data_files(data_files) -> None: with pytest.raises(ValueError, match="Expected a DataFilesDict"): _ = TextConfig(name="name", data_files=data_files) diff --git a/tests/test_builder.py b/tests/test_builder.py index 57052f964e3..6698a79cbf8 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -849,7 +849,7 @@ def test_config_raises_when_invalid_name() -> None: _ = BuilderConfig(name="name-with-*-invalid-character") -@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList["str_path"]]) 
+@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) def test_config_raises_when_invalid_data_files(data_files) -> None: with pytest.raises(ValueError, match="Expected a DataFilesDict"): _ = BuilderConfig(name="name", data_files=data_files) From e6ecc34b36798ab6645269d88d797dec9d6dae39 Mon Sep 17 00:00:00 2001 From: Quentin Lhoest Date: Wed, 5 Jun 2024 14:51:58 +0200 Subject: [PATCH 6/6] use old hash for 2.15 cache reload --- src/datasets/builder.py | 4 ++-- src/datasets/packaged_modules/__init__.py | 12 ++++++++++++ tests/test_load.py | 2 +- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 84e005f3914..37de71824b6 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -494,7 +494,7 @@ def _check_legacy_cache2(self, dataset_module: "DatasetModule") -> Optional[str] and not is_remote_url(self._cache_dir_root) and not (set(self.config_kwargs) - {"data_files", "data_dir"}) ): - from .packaged_modules import _PACKAGED_DATASETS_MODULES + from .packaged_modules import _PACKAGED_DATASETS_MODULES_2_15_HASHES from .utils._dill import Pickler def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str: @@ -516,7 +516,7 @@ def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> st namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True): config_id = self.config.name + "-" + Hasher.hash({"data_files": self.config.data_files}) - hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1] + hash = _PACKAGED_DATASETS_MODULES_2_15_HASHES.get(self.name, "missing") if ( dataset_module.builder_configs_parameters.metadata_configs and self.config.name in dataset_module.builder_configs_parameters.metadata_configs diff --git a/src/datasets/packaged_modules/__init__.py 
b/src/datasets/packaged_modules/__init__.py index 984dc0f03a3..3513f9ae59e 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -43,6 +43,18 @@ def _hash_python_lines(lines: List[str]) -> str: "webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())), } +# module hashes from datasets 2.15, kept so that caches created with that version can still be reloaded +_PACKAGED_DATASETS_MODULES_2_15_HASHES = { + "csv": "eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d", + "json": "8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96", + "pandas": "3ac4ffc4563c796122ef66899b9485a3f1a977553e2d2a8a318c72b8cc6f2202", + "parquet": "ca31c69184d9832faed373922c2acccec0b13a0bb5bbbe19371385c3ff26f1d1", + "arrow": "74f69db2c14c2860059d39860b1f400a03d11bf7fb5a8258ca38c501c878c137", + "text": "c4a140d10f020282918b5dd1b8a49f0104729c6177f60a6b49ec2a365ec69f34", + "imagefolder": "7b7ce5247a942be131d49ad4f3de5866083399a0f250901bd8dc202f8c5f7ce5", + "audiofolder": "d3c1655c66c8f72e4efb5c79e952975fa6e2ce538473a6890241ddbddee9071c", +} + # Used to infer the module to use based on the data files extensions _EXTENSION_TO_MODULE: Dict[str, Tuple[str, dict]] = { ".csv": ("csv", {}), diff --git a/tests/test_load.py b/tests/test_load.py index 337e2811792..c7c413ae10b 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1764,7 +1764,7 @@ def test_resolve_trust_remote_code_future(trust_remote_code, expected): def test_reload_old_cache_from_2_15(tmp_path: Path): cache_dir = tmp_path / "test_reload_old_cache_from_2_15" builder_cache_dir = ( - cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata/v2-374bfde4f55442bc/0.0.0/cf191ad706de653e" + cache_dir / "polinaeterna___audiofolder_two_configs_in_metadata/v2-374bfde4f55442bc/0.0.0/7896925d64deea5d" ) builder_cache_dir.mkdir(parents=True) arrow_path = builder_cache_dir / "audiofolder_two_configs_in_metadata-train.arrow"