diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 5aedd42c4b2..893c9f9a1b0 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1937,7 +1937,7 @@ def encode_column(self, column, column_name: str): `list[Any]` """ column = cast_to_python_objects(column) - return [encode_nested_example(self[column_name], obj) for obj in column] + return [encode_nested_example(self[column_name], obj, level=1) for obj in column] def encode_batch(self, batch): """ @@ -1955,7 +1955,7 @@ def encode_batch(self, batch): raise ValueError(f"Column mismatch between batch {set(batch)} and features {set(self)}") for key, column in batch.items(): column = cast_to_python_objects(column) - encoded_batch[key] = [encode_nested_example(self[key], obj) for obj in column] + encoded_batch[key] = [encode_nested_example(self[key], obj, level=1) for obj in column] return encoded_batch def decode_example(self, example: dict, token_per_repo_id: Optional[Dict[str, Union[str, bool, None]]] = None): diff --git a/tests/features/test_features.py b/tests/features/test_features.py index 8c14f157817..b1ba7e9745f 100644 --- a/tests/features/test_features.py +++ b/tests/features/test_features.py @@ -405,6 +405,16 @@ def test_encode_batch_with_example_with_empty_first_elem(): assert encoded_batch == {"x": [[[0], [1]], [[], [1]]]} +def test_encode_column_dict_with_none(): + features = Features( + { + "x": {"a": ClassLabel(names=["a", "b"]), "b": Value("int32")}, + } + ) + encoded_column = features.encode_column([{"a": "a", "b": 1}, None], "x") + assert encoded_column == [{"a": 0, "b": 1}, None] + + @pytest.mark.parametrize( "feature", [