Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ def __init__(
inferred_features = Features.from_arrow_schema(arrow_table.schema)
if self.info.features is None:
self.info.features = inferred_features
else: # make sure the nested columns are in the right order
self.info.features = self.info.features.reorder_fields_as(inferred_features)

# Infer fingerprint if None

Expand Down
58 changes: 57 additions & 1 deletion src/datasets/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ def get_nested_type(schema: FeatureType) -> pa.DataType:
value_type = get_nested_type(schema.feature)
# We allow to reverse list of dict => dict of list for compatiblity with tfds
if isinstance(value_type, pa.StructType):
return pa.struct(dict(sorted((f.name, pa.list_(f.type, schema.length)) for f in value_type)))
return pa.struct({f.name: pa.list_(f.type, schema.length) for f in value_type})
return pa.list_(value_type, schema.length)

# Other objects are callable which returns their data type (ClassLabel, Array2D, Translation, Arrow datatype creation methods)
Expand Down Expand Up @@ -963,3 +963,59 @@ def encode_batch(self, batch):

def copy(self) -> "Features":
return copy.deepcopy(self)

def reorder_fields_as(self, other: "Features") -> "Features":
"""
The order of the fields is important since it matters for the underlying arrow data.
This method is used to re-order your features to match the fields orders of other features.

Re-ordering the fields allows to make the underlying arrow data type match.

Example::

>>> from datasets import Features, Sequence, Value
>>> # let's say we have to features with a different order of nested fields (for a and b for example)
>>> f1 = Features({"root": Sequence({"a": Value("string"), "b": Value("string")})})
>>> f2 = Features({"root": {"b": Sequence(Value("string")), "a": Sequence(Value("string"))}})
>>> assert f1.type != f2.type
>>> # re-ordering keeps the base structure (here Sequence is defined at the root level), but make the fields order match
>>> f1.reorder_fields_as(f2)
{'root': Sequence(feature={'b': Value(dtype='string', id=None), 'a': Value(dtype='string', id=None)}, length=-1, id=None)}
>>> assert f1.reorder_fields_as(f2).type == f2.type

"""

def recursive_reorder(source, target, stack=""):
stack_position = " at " + stack[1:] if stack else ""
if isinstance(target, Sequence):
target = target.feature
if isinstance(target, dict):
target = {k: [v] for k, v in target.items()}
else:
target = [target]
if isinstance(source, Sequence):
source, id_, length = source.feature, source.id, source.length
if isinstance(source, dict):
source = {k: [v] for k, v in source.items()}
reordered = recursive_reorder(source, target, stack)
return Sequence({k: v[0] for k, v in reordered.items()}, id=id_, length=length)
else:
source = [source]
reordered = recursive_reorder(source, target, stack)
return Sequence(reordered[0], id=id_, length=length)
elif isinstance(source, dict):
if not isinstance(target, dict):
raise ValueError(f"Type mismatch: between {source} and {target}" + stack_position)
if sorted(source) != sorted(target):
raise ValueError(f"Keys mismatch: between {source} and {target}" + stack_position)
return {key: recursive_reorder(source[key], target[key], stack + f".{key}") for key in target}
elif isinstance(source, list):
if not isinstance(target, list):
raise ValueError(f"Type mismatch: between {source} and {target}" + stack_position)
if len(source) != len(target):
raise ValueError(f"Length mismatch: between {source} and {target}" + stack_position)
return [recursive_reorder(source[i], target[i], stack + f".<list>") for i in range(len(target))]
else:
return source

return Features(recursive_reorder(self, other))
120 changes: 120 additions & 0 deletions tests/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,126 @@ def test_feature_named_type(self):
reloaded_features = Features.from_dict(asdict(ds_info)["features"])
assert features == reloaded_features

def test_reorder_fields_as(self):
features = Features(
{
"id": Value("string"),
"document": {
"title": Value("string"),
"url": Value("string"),
"html": Value("string"),
"tokens": Sequence({"token": Value("string"), "is_html": Value("bool")}),
},
"question": {
"text": Value("string"),
"tokens": Sequence(Value("string")),
},
"annotations": Sequence(
{
"id": Value("string"),
"long_answer": {
"start_token": Value("int64"),
"end_token": Value("int64"),
"start_byte": Value("int64"),
"end_byte": Value("int64"),
},
"short_answers": Sequence(
{
"start_token": Value("int64"),
"end_token": Value("int64"),
"start_byte": Value("int64"),
"end_byte": Value("int64"),
"text": Value("string"),
}
),
"yes_no_answer": ClassLabel(names=["NO", "YES"]),
}
),
}
)

other = Features( # same but with [] instead of sequences, and with a shuffled fields order
{
"id": Value("string"),
"document": {
"tokens": Sequence({"token": Value("string"), "is_html": Value("bool")}),
"title": Value("string"),
"url": Value("string"),
"html": Value("string"),
},
"question": {
"text": Value("string"),
"tokens": [Value("string")],
},
"annotations": {
"yes_no_answer": [ClassLabel(names=["NO", "YES"])],
"id": [Value("string")],
"long_answer": [
{
"end_byte": Value("int64"),
"start_token": Value("int64"),
"end_token": Value("int64"),
"start_byte": Value("int64"),
}
],
"short_answers": [
Sequence(
{
"text": Value("string"),
"start_token": Value("int64"),
"end_token": Value("int64"),
"start_byte": Value("int64"),
"end_byte": Value("int64"),
}
)
],
},
}
)

expected = Features(
{
"id": Value("string"),
"document": {
"tokens": Sequence({"token": Value("string"), "is_html": Value("bool")}),
"title": Value("string"),
"url": Value("string"),
"html": Value("string"),
},
"question": {
"text": Value("string"),
"tokens": Sequence(Value("string")),
},
"annotations": Sequence(
{
"yes_no_answer": ClassLabel(names=["NO", "YES"]),
"id": Value("string"),
"long_answer": {
"end_byte": Value("int64"),
"start_token": Value("int64"),
"end_token": Value("int64"),
"start_byte": Value("int64"),
},
"short_answers": Sequence(
{
"text": Value("string"),
"start_token": Value("int64"),
"end_token": Value("int64"),
"start_byte": Value("int64"),
"end_byte": Value("int64"),
}
),
}
),
}
)

reordered_features = features.reorder_fields_as(other)
self.assertDictEqual(reordered_features, expected)
self.assertEqual(reordered_features.type, other.type)
self.assertEqual(reordered_features.type, expected.type)
self.assertNotEqual(reordered_features.type, features.type)


def test_classlabel_init(tmp_path_factory):
names = ["negative", "positive"]
Expand Down
6 changes: 5 additions & 1 deletion tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,13 @@ def test_in_memory_table_from_buffer(in_memory_pa_table):
def test_in_memory_table_from_pandas(in_memory_pa_table):
df = in_memory_pa_table.to_pandas()
with assert_arrow_memory_increases():
# with no schema it might infer another order of the fields in the schema
table = InMemoryTable.from_pandas(df)
assert table.table == in_memory_pa_table
assert isinstance(table, InMemoryTable)
# by specifying schema we get the same order of features, and so the exact same table
table = InMemoryTable.from_pandas(df, schema=in_memory_pa_table.schema)
assert table.table == in_memory_pa_table
assert isinstance(table, InMemoryTable)


def test_in_memory_table_from_arrays(in_memory_pa_table):
Expand Down