diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py index 151d58ee35e..a78e97b5467 100644 --- a/src/datasets/arrow_dataset.py +++ b/src/datasets/arrow_dataset.py @@ -265,6 +265,8 @@ def __init__( inferred_features = Features.from_arrow_schema(arrow_table.schema) if self.info.features is None: self.info.features = inferred_features + else: # make sure the nested columns are in the right order + self.info.features = self.info.features.reorder_fields_as(inferred_features) # Infer fingerprint if None diff --git a/src/datasets/features.py b/src/datasets/features.py index 26c77cf56a3..bb18cd80d96 100644 --- a/src/datasets/features.py +++ b/src/datasets/features.py @@ -831,7 +831,7 @@ def get_nested_type(schema: FeatureType) -> pa.DataType: value_type = get_nested_type(schema.feature) # We allow to reverse list of dict => dict of list for compatiblity with tfds if isinstance(value_type, pa.StructType): - return pa.struct(dict(sorted((f.name, pa.list_(f.type, schema.length)) for f in value_type))) + return pa.struct({f.name: pa.list_(f.type, schema.length) for f in value_type}) return pa.list_(value_type, schema.length) # Other objects are callable which returns their data type (ClassLabel, Array2D, Translation, Arrow datatype creation methods) @@ -963,3 +963,59 @@ def encode_batch(self, batch): def copy(self) -> "Features": return copy.deepcopy(self) + + def reorder_fields_as(self, other: "Features") -> "Features": + """ + The order of the fields is important since it matters for the underlying arrow data. + This method is used to re-order your features to match the fields orders of other features. + + Re-ordering the fields allows to make the underlying arrow data type match. + + Example:: + + >>> from datasets import Features, Sequence, Value + >>> # let's say we have to features with a different order of nested fields (for a and b for example) + >>> f1 = Features({"root": Sequence({"a": Value("string"), "b": Value("string")})}) + >>> f2 = Features({"root": {"b": Sequence(Value("string")), "a": Sequence(Value("string"))}}) + >>> assert f1.type != f2.type + >>> # re-ordering keeps the base structure (here Sequence is defined at the root level), but make the fields order match + >>> f1.reorder_fields_as(f2) + {'root': Sequence(feature={'b': Value(dtype='string', id=None), 'a': Value(dtype='string', id=None)}, length=-1, id=None)} + >>> assert f1.reorder_fields_as(f2).type == f2.type + + """ + + def recursive_reorder(source, target, stack=""): + stack_position = " at " + stack[1:] if stack else "" + if isinstance(target, Sequence): + target = target.feature + if isinstance(target, dict): + target = {k: [v] for k, v in target.items()} + else: + target = [target] + if isinstance(source, Sequence): + source, id_, length = source.feature, source.id, source.length + if isinstance(source, dict): + source = {k: [v] for k, v in source.items()} + reordered = recursive_reorder(source, target, stack) + return Sequence({k: v[0] for k, v in reordered.items()}, id=id_, length=length) + else: + source = [source] + reordered = recursive_reorder(source, target, stack) + return Sequence(reordered[0], id=id_, length=length) + elif isinstance(source, dict): + if not isinstance(target, dict): + raise ValueError(f"Type mismatch: between {source} and {target}" + stack_position) + if sorted(source) != sorted(target): + raise ValueError(f"Keys mismatch: between {source} and {target}" + stack_position) + return {key: recursive_reorder(source[key], target[key], stack + f".{key}") for key in target} + elif isinstance(source, list): + if not isinstance(target, list): + raise ValueError(f"Type mismatch: between {source} and {target}" + stack_position) + if len(source) != len(target): + raise ValueError(f"Length mismatch: between {source} and {target}" + stack_position) + return [recursive_reorder(source[i], target[i], stack + f".") for i in range(len(target))] + else: + return source + + return Features(recursive_reorder(self, other)) diff --git a/tests/test_features.py b/tests/test_features.py index 216bae63b1a..9bb37b9b695 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -76,6 +76,126 @@ def test_feature_named_type(self): reloaded_features = Features.from_dict(asdict(ds_info)["features"]) assert features == reloaded_features + def test_reorder_fields_as(self): + features = Features( + { + "id": Value("string"), + "document": { + "title": Value("string"), + "url": Value("string"), + "html": Value("string"), + "tokens": Sequence({"token": Value("string"), "is_html": Value("bool")}), + }, + "question": { + "text": Value("string"), + "tokens": Sequence(Value("string")), + }, + "annotations": Sequence( + { + "id": Value("string"), + "long_answer": { + "start_token": Value("int64"), + "end_token": Value("int64"), + "start_byte": Value("int64"), + "end_byte": Value("int64"), + }, + "short_answers": Sequence( + { + "start_token": Value("int64"), + "end_token": Value("int64"), + "start_byte": Value("int64"), + "end_byte": Value("int64"), + "text": Value("string"), + } + ), + "yes_no_answer": ClassLabel(names=["NO", "YES"]), + } + ), + } + ) + + other = Features( # same but with [] instead of sequences, and with a shuffled fields order + { + "id": Value("string"), + "document": { + "tokens": Sequence({"token": Value("string"), "is_html": Value("bool")}), + "title": Value("string"), + "url": Value("string"), + "html": Value("string"), + }, + "question": { + "text": Value("string"), + "tokens": [Value("string")], + }, + "annotations": { + "yes_no_answer": [ClassLabel(names=["NO", "YES"])], + "id": [Value("string")], + "long_answer": [ + { + "end_byte": Value("int64"), + "start_token": Value("int64"), + "end_token": Value("int64"), + "start_byte": Value("int64"), + } + ], + "short_answers": [ + Sequence( + { + "text": Value("string"), + "start_token": Value("int64"), + "end_token": Value("int64"), + "start_byte": Value("int64"), + "end_byte": Value("int64"), + } + ) + ], + }, + } + ) + + expected = Features( + { + "id": Value("string"), + "document": { + "tokens": Sequence({"token": Value("string"), "is_html": Value("bool")}), + "title": Value("string"), + "url": Value("string"), + "html": Value("string"), + }, + "question": { + "text": Value("string"), + "tokens": Sequence(Value("string")), + }, + "annotations": Sequence( + { + "yes_no_answer": ClassLabel(names=["NO", "YES"]), + "id": Value("string"), + "long_answer": { + "end_byte": Value("int64"), + "start_token": Value("int64"), + "end_token": Value("int64"), + "start_byte": Value("int64"), + }, + "short_answers": Sequence( + { + "text": Value("string"), + "start_token": Value("int64"), + "end_token": Value("int64"), + "start_byte": Value("int64"), + "end_byte": Value("int64"), + } + ), + } + ), + } + ) + + reordered_features = features.reorder_fields_as(other) + self.assertDictEqual(reordered_features, expected) + self.assertEqual(reordered_features.type, other.type) + self.assertEqual(reordered_features.type, expected.type) + self.assertNotEqual(reordered_features.type, features.type) + def test_classlabel_init(tmp_path_factory): names = ["negative", "positive"] diff --git a/tests/test_table.py b/tests/test_table.py index 78f5807b339..4dd8c9b4425 100644 --- a/tests/test_table.py +++ b/tests/test_table.py @@ -220,9 +220,13 @@ def test_in_memory_table_from_buffer(in_memory_pa_table): def test_in_memory_table_from_pandas(in_memory_pa_table): df = in_memory_pa_table.to_pandas() with assert_arrow_memory_increases(): + # with no schema it might infer another order of the fields in the schema table = InMemoryTable.from_pandas(df) - assert table.table == in_memory_pa_table assert isinstance(table, InMemoryTable) + # by specifying schema we get the same order of features, and so the exact same table + table = InMemoryTable.from_pandas(df, schema=in_memory_pa_table.schema) + assert table.table == in_memory_pa_table + assert isinstance(table, InMemoryTable) def test_in_memory_table_from_arrays(in_memory_pa_table):