From 4c7a688f07328d791e0aff42f4f51025b01f7edc Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Tue, 15 Jun 2021 15:58:08 +0200
Subject: [PATCH 1/2] Test loading JSON with mismatched features

Add a `jsonl_312_path` fixture that writes DATA_312 as JSON Lines (one
object per line) with fields ordered col_3, col_1, col_2, and a test that
reads it with features ordered col_2, col_3, col_1, expecting the loaded
dataset to follow the order of the passed features.

---
 tests/conftest.py     | 14 ++++++++++++++
 tests/io/test_json.py | 17 +++++++++++++++++
 2 files changed, 31 insertions(+)

diff --git a/tests/conftest.py b/tests/conftest.py
index 4f110012a69..a4b2e4e6669 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -130,6 +130,11 @@ def xml_file(tmp_path_factory):
     "col_3": [0.0, 1.0, 2.0, 3.0],
 }
 
+DATA_312 = [
+    {"col_3": 0.0, "col_1": "0", "col_2": 0},
+    {"col_3": 1.0, "col_1": "1", "col_2": 1},
+]
+
 
 @pytest.fixture(scope="session")
 def dataset_dict():
@@ -182,6 +187,15 @@ def jsonl_path(tmp_path_factory):
     return path
 
 
+@pytest.fixture(scope="session")
+def jsonl_312_path(tmp_path_factory):
+    path = str(tmp_path_factory.mktemp("data") / "dataset_312.jsonl")
+    with open(path, "w") as f:
+        for item in DATA_312:
+            f.write(json.dumps(item) + "\n")
+    return path
+
+
 @pytest.fixture(scope="session")
 def text_path(tmp_path_factory):
     data = ["0", "1", "2", "3"]
diff --git a/tests/io/test_json.py b/tests/io/test_json.py
index 024c98da025..f74d36bfa48 100644
--- a/tests/io/test_json.py
+++ b/tests/io/test_json.py
@@ -49,6 +49,23 @@ def test_dataset_from_json_features(features, jsonl_path, tmp_path):
     _check_json_dataset(dataset, expected_features)
 
 
+def test_dataset_from_json_with_mismatched_features(jsonl_312_path, tmp_path):
+    # jsonl_312_path features are {"col_3": "float64", "col_1": "string", "col_2": "int64"}
+    features = {"col_2": "int64", "col_3": "float64", "col_1": "string"}
+    expected_features = features.copy()
+    features = (
+        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
+    )
+    cache_dir = tmp_path / "cache"
+    dataset = JsonDatasetReader(jsonl_312_path, features=features, cache_dir=cache_dir).read()
+    assert isinstance(dataset, Dataset)
+    assert dataset.num_rows == 2
+    assert dataset.num_columns == 3
+    assert dataset.column_names == ["col_2", "col_3", "col_1"]
+    for feature, expected_dtype in expected_features.items():
+        assert dataset.features[feature].dtype == expected_dtype
+
+
 @pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
 def test_dataset_from_json_split(split, jsonl_path, tmp_path):
     cache_dir = tmp_path / "cache"

From 8b781ba92cf12b660e7a9871f2e8397a5ae9827d Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Tue, 15 Jun 2021 16:01:06 +0200
Subject: [PATCH 2/2] Rearrange JSON field names to match passed features schema field names

When features are passed, rebuild the table from its columns selected in
the order of the passed features, using the configured schema, instead of
casting the table in the field order found in the JSON file.

---
 src/datasets/packaged_modules/json/json.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index f5916f72f81..7c0a74b65e1 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -91,7 +91,10 @@ def _generate_tables(self, files):
                         f"This JSON file contain the following fields: {str(list(dataset.keys()))}. "
                         f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
                     )
-            if self.config.schema:
+            if self.config.features:
                 # Cast allows str <-> int/float, while parse_option explicit_schema does NOT
-                pa_table = pa_table.cast(self.config.schema)
+                # Before casting, rearrange JSON field names to match passed features schema field names order
+                pa_table = pa.Table.from_arrays(
+                    [pa_table[name] for name in self.config.features], schema=self.config.schema
+                )
             yield i, pa_table