Skip to content

Commit 9d8bf36

Browse files
authored
Fix save_to_disk nested features order in dataset_info.json (#2422)
* fix save_to_disk nested features order in dataset_info.json * add test
1 parent 0d34edd commit 9d8bf36

File tree

2 files changed

+39
-20
lines changed

2 files changed

+39
-20
lines changed

src/datasets/arrow_dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,9 @@ def save_to_disk(self, dataset_path: str, fs=None):
640640
with fs.open(
641641
Path(dataset_path, config.DATASET_INFO_FILENAME).as_posix(), "w", encoding="utf-8"
642642
) as dataset_info_file:
643-
json.dump(dataset_info, dataset_info_file, indent=2, sort_keys=True)
643+
# Sort only the first level of keys, or we might shuffle fields of nested features if we use sort_keys=True
644+
sorted_keys_dataset_info = {key: dataset_info[key] for key in sorted(dataset_info)}
645+
json.dump(sorted_keys_dataset_info, dataset_info_file, indent=2)
644646
logger.info("Dataset saved in {}".format(dataset_path))
645647

646648
@staticmethod

tests/test_arrow_dataset.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -83,25 +83,29 @@ def inject_fixtures(self, caplog):
8383
self._caplog = caplog
8484

8585
def _create_dummy_dataset(
86-
self, in_memory: bool, tmp_dir: str, multiple_columns=False, array_features=False
86+
self, in_memory: bool, tmp_dir: str, multiple_columns=False, array_features=False, nested_features=False
8787
) -> Dataset:
88+
assert int(multiple_columns) + int(array_features) + int(nested_features) < 2
8889
if multiple_columns:
89-
if array_features:
90-
data = {
91-
"col_1": [[[True, False], [False, True]]] * 4, # 2D
92-
"col_2": [[[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]]] * 4, # 3D array
93-
"col_3": [[3, 2, 1, 0]] * 4, # Sequence
90+
data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"], "col_3": [False, True, False, True]}
91+
dset = Dataset.from_dict(data)
92+
elif array_features:
93+
data = {
94+
"col_1": [[[True, False], [False, True]]] * 4, # 2D
95+
"col_2": [[[["a", "b"], ["c", "d"]], [["e", "f"], ["g", "h"]]]] * 4, # 3D array
96+
"col_3": [[3, 2, 1, 0]] * 4, # Sequence
97+
}
98+
features = Features(
99+
{
100+
"col_1": Array2D(shape=(2, 2), dtype="bool"),
101+
"col_2": Array3D(shape=(2, 2, 2), dtype="string"),
102+
"col_3": Sequence(feature=Value("int64")),
94103
}
95-
features = Features(
96-
{
97-
"col_1": Array2D(shape=(2, 2), dtype="bool"),
98-
"col_2": Array3D(shape=(2, 2, 2), dtype="string"),
99-
"col_3": Sequence(feature=Value("int64")),
100-
}
101-
)
102-
else:
103-
data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"], "col_3": [False, True, False, True]}
104-
features = None
104+
)
105+
dset = Dataset.from_dict(data, features=features)
106+
elif nested_features:
107+
data = {"nested": [{"a": i, "x": i * 10, "c": i * 100} for i in range(1, 11)]}
108+
features = Features({"nested": {"a": Value("int64"), "x": Value("int64"), "c": Value("int64")}})
105109
dset = Dataset.from_dict(data, features=features)
106110
else:
107111
dset = Dataset.from_dict({"filename": ["my_name-train" + "_" + str(x) for x in np.arange(30).tolist()]})
@@ -139,7 +143,7 @@ def test_dummy_dataset(self, in_memory):
139143
self.assertEqual(dset["col_1"][0], 3)
140144

141145
with tempfile.TemporaryDirectory() as tmp_dir:
142-
with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True, array_features=True) as dset:
146+
with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
143147
self.assertDictEqual(
144148
dset.features,
145149
Features(
@@ -249,6 +253,19 @@ def test_dummy_dataset_serialize(self, in_memory):
249253
self.assertEqual(dset[0]["filename"], "my_name-train_0")
250254
self.assertEqual(dset["filename"][0], "my_name-train_0")
251255

256+
with self._create_dummy_dataset(in_memory, tmp_dir, nested_features=True) as dset:
257+
with assert_arrow_memory_doesnt_increase():
258+
dset.save_to_disk(dataset_path)
259+
260+
with Dataset.load_from_disk(dataset_path) as dset:
261+
self.assertEqual(len(dset), 10)
262+
self.assertDictEqual(
263+
dset.features,
264+
Features({"nested": {"a": Value("int64"), "x": Value("int64"), "c": Value("int64")}}),
265+
)
266+
self.assertDictEqual(dset[0]["nested"], {"a": 1, "c": 100, "x": 10})
267+
self.assertDictEqual(dset["nested"][0], {"a": 1, "c": 100, "x": 10})
268+
252269
def test_dummy_dataset_load_from_disk(self, in_memory):
253270
with tempfile.TemporaryDirectory() as tmp_dir:
254271

@@ -453,7 +470,7 @@ def test_class_encode_column(self, in_memory):
453470
assert_arrow_metadata_are_synced_with_dataset_features(casted_dset)
454471

455472
# Test raises if feature is an array / sequence
456-
with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True, array_features=True) as dset:
473+
with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
457474
for column in dset.column_names:
458475
with self.assertRaises(ValueError):
459476
dset.class_encode_column(column)
@@ -1597,7 +1614,7 @@ def test_to_csv(self, in_memory):
15971614
self.assertListEqual(list(csv_dset.columns), list(dset.column_names))
15981615

15991616
# With array features
1600-
with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True, array_features=True) as dset:
1617+
with self._create_dummy_dataset(in_memory, tmp_dir, array_features=True) as dset:
16011618
file_path = os.path.join(tmp_dir, "test_path.csv")
16021619
bytes_written = dset.to_csv(path_or_buf=file_path)
16031620

0 commit comments

Comments
 (0)