From d33c55e6ba5880b565f431a7f6a049c112c0ec4a Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 9 Jun 2021 19:06:13 +0200
Subject: [PATCH 1/3] Test dataset from JSON with ClassLabel

Add a JSON Lines fixture whose first column holds string labels and a
test that loads it with a ClassLabel feature. The fixture writes one
record per line so that the file is valid JSON Lines.

---
 tests/conftest.py           | 16 ++++++++++++++++
 tests/test_arrow_dataset.py | 13 +++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/tests/conftest.py b/tests/conftest.py
index 4f110012a69..3a38fa853bb 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -130,6 +130,13 @@ def xml_file(tmp_path_factory):
     "col_3": [0.0, 1.0, 2.0, 3.0],
 }
 
+DATA_STR = [
+    {"col_1": "s0", "col_2": 0, "col_3": 0.0},
+    {"col_1": "s1", "col_2": 1, "col_3": 1.0},
+    {"col_1": "s2", "col_2": 2, "col_3": 2.0},
+    {"col_1": "s3", "col_2": 3, "col_3": 3.0},
+]
+
 
 @pytest.fixture(scope="session")
 def dataset_dict():
@@ -182,6 +189,15 @@ def jsonl_path(tmp_path_factory):
     return path
 
 
+@pytest.fixture(scope="session")
+def jsonl_str_path(tmp_path_factory):
+    path = str(tmp_path_factory.mktemp("data") / "dataset-str.jsonl")
+    with open(path, "w") as f:
+        for item in DATA_STR:
+            f.write(json.dumps(item) + "\n")
+    return path
+
+
 @pytest.fixture(scope="session")
 def text_path(tmp_path_factory):
     data = ["0", "1", "2", "3"]
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 9b36a5c2da2..c8cd0a04660 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -2564,6 +2564,19 @@ def test_dataset_from_json_features(features, jsonl_path, tmp_path):
     _check_json_dataset(dataset, expected_features)
 
 
+def test_dataset_from_json_with_class_label_feature(jsonl_str_path, tmp_path):
+    features = Features(
+        {"col_1": ClassLabel(names=["s0", "s1", "s2", "s3"]), "col_2": Value("int64"), "col_3": Value("float64")}
+    )
+    cache_dir = tmp_path / "cache"
+    dataset = Dataset.from_json(jsonl_str_path, features=features, cache_dir=cache_dir)
+    # import pdb
+    #
+    # pdb.set_trace()
+    # dataset = dataset.map(features.encode_example, features=features)
+    assert dataset.features["col_1"].dtype == "int64"
+
+
 @pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
 def test_dataset_from_json_split(split, jsonl_path, tmp_path):
     cache_dir = tmp_path / "cache"

From 8f4ea507009baab673a94a409be0b34dbca14ede Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 9 Jun 2021 19:07:18 +0200
Subject: [PATCH 2/3] Implement encoding if ClassLabel

Encode ClassLabel columns from their string names to integer ids before
casting the table to the configured schema.

The column position is looked up by name in the table itself, because
the key order of the JSON records is not guaranteed to match the order
of the features. The loop does not reuse `i`, which is the key yielded
together with the table right after this block. Values are converted to
plain Python objects before being handed to `ClassLabel.str2int`.

---
 src/datasets/features.py                   | 2 +-
 src/datasets/packaged_modules/json/json.py | 6 ++++++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/datasets/features.py b/src/datasets/features.py
index bb18cd80d96..265c19f49eb 100644
--- a/src/datasets/features.py
+++ b/src/datasets/features.py
@@ -604,7 +604,7 @@ def str2int(self, values: Union[str, Iterable]):
             if self._str2int:
                 # strip key if not in dict
                 if value not in self._str2int:
-                    value = value.strip()
+                    value = str(value).strip()
                 output.append(self._str2int[str(value)])
             else:
                 # No names provided, try to integerize
diff --git a/src/datasets/packaged_modules/json/json.py b/src/datasets/packaged_modules/json/json.py
index f5916f72f81..fef61ddbb70 100644
--- a/src/datasets/packaged_modules/json/json.py
+++ b/src/datasets/packaged_modules/json/json.py
@@ -92,6 +92,12 @@ def _generate_tables(self, files):
                         f"Select the correct one and provide it as `field='XXX'` to the dataset loading method. "
                     )
             if self.config.schema:
+                # Encode ClassLabel columns (str -> int); look up by name as JSON column order may differ
+                for col, feature in self.config.features.items():
+                    if isinstance(feature, datasets.ClassLabel):
+                        col_idx = pa_table.schema.get_field_index(col)
+                        encoded = feature.str2int(pa_table[col].to_pylist())
+                        pa_table = pa_table.set_column(col_idx, self.config.schema.field(col), [encoded])
                 # Cast allows str <-> int/float, while parse_option explicit_schema does NOT
                 pa_table = pa_table.cast(self.config.schema)
             yield i, pa_table

From 4a5c4e0527496ada60e5e26f45bbfacd387a81f3 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Sat, 12 Jun 2021 15:49:36 +0200
Subject: [PATCH 3/3] Cleanup

Remove the commented-out debugging lines from the ClassLabel JSON test.

---
 tests/test_arrow_dataset.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index c8cd0a04660..45cb0360001 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -2570,10 +2570,6 @@ def test_dataset_from_json_with_class_label_feature(jsonl_str_path, tmp_path):
     )
     cache_dir = tmp_path / "cache"
     dataset = Dataset.from_json(jsonl_str_path, features=features, cache_dir=cache_dir)
-    # import pdb
-    #
-    # pdb.set_trace()
-    # dataset = dataset.map(features.encode_example, features=features)
     assert dataset.features["col_1"].dtype == "int64"
 
 