Skip to content

Commit 92aacfe

Browse files
authored
Fix NQ features loading: reorder fields of features to match nested fields order in arrow data (#2438)
* reorder fields of features to match nested fields order in arrow data * fix test * docstring * style
1 parent 43cbc50 commit 92aacfe

File tree

4 files changed

+184
-2
lines changed

4 files changed

+184
-2
lines changed

src/datasets/arrow_dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@ def __init__(
265265
inferred_features = Features.from_arrow_schema(arrow_table.schema)
266266
if self.info.features is None:
267267
self.info.features = inferred_features
268+
else: # make sure the nested columns are in the right order
269+
self.info.features = self.info.features.reorder_fields_as(inferred_features)
268270

269271
# Infer fingerprint if None
270272

src/datasets/features.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,7 @@ def get_nested_type(schema: FeatureType) -> pa.DataType:
831831
value_type = get_nested_type(schema.feature)
832832
# We allow reversing list of dict => dict of list for compatibility with tfds
833833
if isinstance(value_type, pa.StructType):
834-
return pa.struct(dict(sorted((f.name, pa.list_(f.type, schema.length)) for f in value_type)))
834+
return pa.struct({f.name: pa.list_(f.type, schema.length) for f in value_type})
835835
return pa.list_(value_type, schema.length)
836836

837837
# Other objects are callables which return their data type (ClassLabel, Array2D, Translation, Arrow datatype creation methods)
@@ -963,3 +963,59 @@ def encode_batch(self, batch):
963963

964964
def copy(self) -> "Features":
965965
return copy.deepcopy(self)
966+
967+
def reorder_fields_as(self, other: "Features") -> "Features":
968+
"""
969+
The order of the fields is important since it matters for the underlying arrow data.
970+
This method re-orders the fields of your features to match the field order of other features.
971+
972+
Re-ordering the fields makes the underlying arrow data types match.
973+
974+
Example::
975+
976+
>>> from datasets import Features, Sequence, Value
977+
>>> # let's say we have two features with a different order of nested fields (for a and b for example)
978+
>>> f1 = Features({"root": Sequence({"a": Value("string"), "b": Value("string")})})
979+
>>> f2 = Features({"root": {"b": Sequence(Value("string")), "a": Sequence(Value("string"))}})
980+
>>> assert f1.type != f2.type
981+
>>> # re-ordering keeps the base structure (here Sequence is defined at the root level), but makes the field order match
982+
>>> f1.reorder_fields_as(f2)
983+
{'root': Sequence(feature={'b': Value(dtype='string', id=None), 'a': Value(dtype='string', id=None)}, length=-1, id=None)}
984+
>>> assert f1.reorder_fields_as(f2).type == f2.type
985+
986+
"""
987+
988+
def recursive_reorder(source, target, stack=""):
989+
stack_position = " at " + stack[1:] if stack else ""
990+
if isinstance(target, Sequence):
991+
target = target.feature
992+
if isinstance(target, dict):
993+
target = {k: [v] for k, v in target.items()}
994+
else:
995+
target = [target]
996+
if isinstance(source, Sequence):
997+
source, id_, length = source.feature, source.id, source.length
998+
if isinstance(source, dict):
999+
source = {k: [v] for k, v in source.items()}
1000+
reordered = recursive_reorder(source, target, stack)
1001+
return Sequence({k: v[0] for k, v in reordered.items()}, id=id_, length=length)
1002+
else:
1003+
source = [source]
1004+
reordered = recursive_reorder(source, target, stack)
1005+
return Sequence(reordered[0], id=id_, length=length)
1006+
elif isinstance(source, dict):
1007+
if not isinstance(target, dict):
1008+
raise ValueError(f"Type mismatch: between {source} and {target}" + stack_position)
1009+
if sorted(source) != sorted(target):
1010+
raise ValueError(f"Keys mismatch: between {source} and {target}" + stack_position)
1011+
return {key: recursive_reorder(source[key], target[key], stack + f".{key}") for key in target}
1012+
elif isinstance(source, list):
1013+
if not isinstance(target, list):
1014+
raise ValueError(f"Type mismatch: between {source} and {target}" + stack_position)
1015+
if len(source) != len(target):
1016+
raise ValueError(f"Length mismatch: between {source} and {target}" + stack_position)
1017+
return [recursive_reorder(source[i], target[i], stack + f".<list>") for i in range(len(target))]
1018+
else:
1019+
return source
1020+
1021+
return Features(recursive_reorder(self, other))

tests/test_features.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,126 @@ def test_feature_named_type(self):
7676
reloaded_features = Features.from_dict(asdict(ds_info)["features"])
7777
assert features == reloaded_features
7878

79+
def test_reorder_fields_as(self):
80+
features = Features(
81+
{
82+
"id": Value("string"),
83+
"document": {
84+
"title": Value("string"),
85+
"url": Value("string"),
86+
"html": Value("string"),
87+
"tokens": Sequence({"token": Value("string"), "is_html": Value("bool")}),
88+
},
89+
"question": {
90+
"text": Value("string"),
91+
"tokens": Sequence(Value("string")),
92+
},
93+
"annotations": Sequence(
94+
{
95+
"id": Value("string"),
96+
"long_answer": {
97+
"start_token": Value("int64"),
98+
"end_token": Value("int64"),
99+
"start_byte": Value("int64"),
100+
"end_byte": Value("int64"),
101+
},
102+
"short_answers": Sequence(
103+
{
104+
"start_token": Value("int64"),
105+
"end_token": Value("int64"),
106+
"start_byte": Value("int64"),
107+
"end_byte": Value("int64"),
108+
"text": Value("string"),
109+
}
110+
),
111+
"yes_no_answer": ClassLabel(names=["NO", "YES"]),
112+
}
113+
),
114+
}
115+
)
116+
117+
other = Features( # same but with [] instead of sequences, and with a shuffled fields order
118+
{
119+
"id": Value("string"),
120+
"document": {
121+
"tokens": Sequence({"token": Value("string"), "is_html": Value("bool")}),
122+
"title": Value("string"),
123+
"url": Value("string"),
124+
"html": Value("string"),
125+
},
126+
"question": {
127+
"text": Value("string"),
128+
"tokens": [Value("string")],
129+
},
130+
"annotations": {
131+
"yes_no_answer": [ClassLabel(names=["NO", "YES"])],
132+
"id": [Value("string")],
133+
"long_answer": [
134+
{
135+
"end_byte": Value("int64"),
136+
"start_token": Value("int64"),
137+
"end_token": Value("int64"),
138+
"start_byte": Value("int64"),
139+
}
140+
],
141+
"short_answers": [
142+
Sequence(
143+
{
144+
"text": Value("string"),
145+
"start_token": Value("int64"),
146+
"end_token": Value("int64"),
147+
"start_byte": Value("int64"),
148+
"end_byte": Value("int64"),
149+
}
150+
)
151+
],
152+
},
153+
}
154+
)
155+
156+
expected = Features(
157+
{
158+
"id": Value("string"),
159+
"document": {
160+
"tokens": Sequence({"token": Value("string"), "is_html": Value("bool")}),
161+
"title": Value("string"),
162+
"url": Value("string"),
163+
"html": Value("string"),
164+
},
165+
"question": {
166+
"text": Value("string"),
167+
"tokens": Sequence(Value("string")),
168+
},
169+
"annotations": Sequence(
170+
{
171+
"yes_no_answer": ClassLabel(names=["NO", "YES"]),
172+
"id": Value("string"),
173+
"long_answer": {
174+
"end_byte": Value("int64"),
175+
"start_token": Value("int64"),
176+
"end_token": Value("int64"),
177+
"start_byte": Value("int64"),
178+
},
179+
"short_answers": Sequence(
180+
{
181+
"text": Value("string"),
182+
"start_token": Value("int64"),
183+
"end_token": Value("int64"),
184+
"start_byte": Value("int64"),
185+
"end_byte": Value("int64"),
186+
}
187+
),
188+
}
189+
),
190+
}
191+
)
192+
193+
reordered_features = features.reorder_fields_as(other)
194+
self.assertDictEqual(reordered_features, expected)
195+
self.assertEqual(reordered_features.type, other.type)
196+
self.assertEqual(reordered_features.type, expected.type)
197+
self.assertNotEqual(reordered_features.type, features.type)
198+
79199

80200
def test_classlabel_init(tmp_path_factory):
81201
names = ["negative", "positive"]

tests/test_table.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,13 @@ def test_in_memory_table_from_buffer(in_memory_pa_table):
220220
def test_in_memory_table_from_pandas(in_memory_pa_table):
221221
df = in_memory_pa_table.to_pandas()
222222
with assert_arrow_memory_increases():
223+
# with no schema it might infer another order of the fields in the schema
223224
table = InMemoryTable.from_pandas(df)
224-
assert table.table == in_memory_pa_table
225225
assert isinstance(table, InMemoryTable)
226+
# by specifying schema we get the same order of features, and so the exact same table
227+
table = InMemoryTable.from_pandas(df, schema=in_memory_pa_table.schema)
228+
assert table.table == in_memory_pa_table
229+
assert isinstance(table, InMemoryTable)
226230

227231

228232
def test_in_memory_table_from_arrays(in_memory_pa_table):

0 commit comments

Comments
 (0)