src/datasets/arrow_dataset.py (2 changes: 1 addition & 1 deletion)
@@ -3077,7 +3077,7 @@ def add_item(self, item: dict, new_fingerprint: str):
Returns:
:class:`Dataset`
"""
- item_table = InMemoryTable.from_pydict({k: [v] for k, v in item.items()})
+ item_table = InMemoryTable.from_pydict({k: [item[k]] for k in self.features.keys() if k in item})
# Cast item
schema = pa.schema(self.features.type)
item_table = item_table.cast(schema)
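Note: `add_item` used to build the single-row table in whatever key order the incoming dict happened to have; it now follows the order of `self.features`, so the cast to the dataset schema just below lines up column for column, and keys outside the features are dropped by the comprehension. A minimal sketch of the reordering rule (hypothetical column names, not library code):

# Sketch: reorder an item's columns to match the dataset's feature order.
features_order = ["col_1", "col_2", "col_3"]        # stands in for self.features.keys()
item = {"col_3": 2.0, "col_1": "2", "extra": None}  # arbitrary key order, one unknown key

# Old: {k: [v] for k, v in item.items()} kept the item's own key order.
# New: the feature order wins and unknown keys are dropped.
item_columns = {k: [item[k]] for k in features_order if k in item}
assert list(item_columns) == ["col_1", "col_3"]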
src/datasets/arrow_writer.py (7 changes: 6 additions & 1 deletion)
@@ -272,7 +272,12 @@ def write_examples_on_file(self):
return

# Since current_examples contains (example, key) tuples
cols = sorted(self.current_examples[0][0].keys())
cols = (
[col for col in self._schema.names if col in self.current_examples[0][0]]
+ [col for col in self.current_examples[0][0].keys() if col not in self._schema.names]
if self._schema
else self.current_examples[0][0].keys()
)

schema = None if self.pa_writer is None and self.update_features else self._schema
try_schema = self._schema if self.pa_writer is None and self.update_features else None
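Note: rather than sorting the first buffered example's keys alphabetically, the writer now emits schema columns first, in schema order, and appends any columns the schema does not know about (falling back to the example's own key order when there is no schema yet). A minimal sketch of the new ordering rule (hypothetical values):

# Sketch: schema order first, then extra columns in example order.
schema_names = ["col_1", "col_2"]                   # stands in for self._schema.names
example = {"col_3": 0.0, "col_2": 0, "col_1": "0"}  # first buffered example

cols = [c for c in schema_names if c in example] + [c for c in example if c not in schema_names]
assert cols == ["col_1", "col_2", "col_3"]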
src/datasets/features.py (6 changes: 3 additions & 3 deletions)
@@ -817,8 +817,8 @@ def get_nested_type(schema: FeatureType) -> pa.DataType:
# Nested structures: we allow dict, list/tuples, sequences
if isinstance(schema, Features):
return pa.struct(
- {key: get_nested_type(schema[key]) for key in sorted(schema)}
- ) # sort to make the order of columns deterministic
+ {key: get_nested_type(schema[key]) for key in schema}
+ ) # Features is subclass of dict, and dict order is deterministic since Python 3.6
elif isinstance(schema, dict):
return pa.struct(
{key: get_nested_type(schema[key]) for key in schema}
@@ -829,7 +829,7 @@ def get_nested_type(schema: FeatureType) -> pa.DataType:
return pa.list_(value_type)
elif isinstance(schema, Sequence):
value_type = get_nested_type(schema.feature)
- # We allow to reverse list of dict => dict of list for compatiblity with tfds
+ # We allow to reverse list of dict => dict of list for compatibility with tfds
if isinstance(value_type, pa.StructType):
return pa.struct({f.name: pa.list_(f.type, schema.length) for f in value_type})
return pa.list_(value_type, schema.length)
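Note: `dict` preserves insertion order in CPython 3.6 as an implementation detail and by language guarantee since Python 3.7, so sorting is no longer needed for a deterministic struct: the field order now simply mirrors the `Features` mapping. A quick check of the behavior this relies on (assumes pyarrow is installed):

import pyarrow as pa

# pa.struct keeps the mapping's insertion order as the field order.
struct_type = pa.struct({"col_3": pa.float64(), "col_1": pa.string(), "col_2": pa.int64()})
assert [f.name for f in struct_type] == ["col_3", "col_1", "col_2"]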
tests/conftest.py (14 changes: 14 additions & 0 deletions)
@@ -130,6 +130,11 @@ def xml_file(tmp_path_factory):
"col_3": [0.0, 1.0, 2.0, 3.0],
}

+ DATA_312 = [
+     {"col_3": 0.0, "col_1": "0", "col_2": 0},
+     {"col_3": 1.0, "col_1": "1", "col_2": 1},
+ ]


@pytest.fixture(scope="session")
def dataset_dict():
@@ -182,6 +187,15 @@ def jsonl_path(tmp_path_factory):
return path


+ @pytest.fixture(scope="session")
+ def jsonl_312_path(tmp_path_factory):
+     path = str(tmp_path_factory.mktemp("data") / "dataset_312.jsonl")
+     with open(path, "w") as f:
+         for item in DATA_312:
+             f.write(json.dumps(item) + "\n")
+     return path


@pytest.fixture(scope="session")
def text_path(tmp_path_factory):
data = ["0", "1", "2", "3"]
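For reference, `json.dumps` keeps each dict's insertion order (it only sorts keys with `sort_keys=True`), so the fixture produces a two-line JSONL file whose keys are deliberately out of alphabetical order:

{"col_3": 0.0, "col_1": "0", "col_2": 0}
{"col_3": 1.0, "col_1": "1", "col_2": 1}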
tests/io/test_json.py (24 changes: 23 additions & 1 deletion)
@@ -39,7 +39,6 @@ def test_dataset_from_json_keep_in_memory(keep_in_memory, jsonl_path, tmp_path):
)
def test_dataset_from_json_features(features, jsonl_path, tmp_path):
cache_dir = tmp_path / "cache"
- # CSV file loses col_1 string dtype information: default now is "int64" instead of "string"
default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
expected_features = features.copy() if features else default_expected_features
features = (
@@ -49,6 +48,29 @@ def test_dataset_from_json_features(features, jsonl_path, tmp_path):
_check_json_dataset(dataset, expected_features)


+ @pytest.mark.parametrize(
+     "features",
+     [
+         None,
+         {"col_3": "float64", "col_1": "string", "col_2": "int64"},
+     ],
+ )
+ def test_dataset_from_json_with_unsorted_column_names(features, jsonl_312_path, tmp_path):
+     cache_dir = tmp_path / "cache"
+     default_expected_features = {"col_3": "float64", "col_1": "string", "col_2": "int64"}
+     expected_features = features.copy() if features else default_expected_features
+     features = (
+         Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
+     )
+     dataset = JsonDatasetReader(jsonl_312_path, features=features, cache_dir=cache_dir).read()
+     assert isinstance(dataset, Dataset)
+     assert dataset.num_rows == 2
+     assert dataset.num_columns == 3
+     assert dataset.column_names == ["col_3", "col_1", "col_2"]
+     for feature, expected_dtype in expected_features.items():
+         assert dataset.features[feature].dtype == expected_dtype


@pytest.mark.parametrize("split", [None, NamedSplit("train"), "train", "test"])
def test_dataset_from_json_split(split, jsonl_path, tmp_path):
cache_dir = tmp_path / "cache"
tests/test_arrow_dataset.py (6 changes: 5 additions & 1 deletion)
@@ -2393,7 +2393,11 @@ def test_dataset_add_column(column, expected_dtype, in_memory, transform, datase
original_dataset: Dataset = getattr(original_dataset, transform_name)(*args, **kwargs)
dataset = original_dataset.add_column(column_name, column)
assert dataset.data.shape == (4, 4)
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64", column_name: expected_dtype}
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
# Sort expected features as in the original dataset
expected_features = {feature: expected_features[feature] for feature in original_dataset.features}
# Add new column feature
expected_features[column_name] = expected_dtype
assert dataset.data.column_names == list(expected_features.keys())
for feature, expected_dtype in expected_features.items():
assert dataset.features[feature].dtype == expected_dtype
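Taken together, the tests pin the new contract: column order follows the data file or the `Features` definition instead of being sorted alphabetically. A hedged end-to-end sketch (hypothetical file name, assumes a datasets build that includes this change):

from datasets import load_dataset

ds = load_dataset("json", data_files="dataset_312.jsonl", split="train")
print(ds.column_names)  # expected: ['col_3', 'col_1', 'col_2'], i.e. file order, not sorted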