Skip to content

Commit 247e3cf

Browse files
committed
allow dataset_infos to be struct or list in YAML
1 parent 5e583ef commit 247e3cf

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

src/datasets/info.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,11 @@ def write_to_directory(self, dataset_infos_dir, overwrite=False, pretty_print=Fa
340340
dataset_metadata = DatasetMetadata.from_readme(Path(dataset_readme_path))
341341
else:
342342
dataset_metadata = {}
343-
dataset_metadata["dataset_infos"] = [dset_info._to_yaml_dict() for dset_info in total_dataset_infos.values()]
344-
dataset_metadata.to_readme(Path(dataset_readme_path))
343+
if total_dataset_infos:
344+
dataset_metadata["dataset_infos"] = [dset_info._to_yaml_dict() for dset_info in total_dataset_infos.values()]
345+
if len(dataset_metadata["dataset_infos"]) == 1:
346+
dataset_metadata["dataset_infos"] = dataset_metadata["dataset_infos"][0]
347+
dataset_metadata.to_readme(Path(dataset_readme_path))
345348

346349
@classmethod
347350
def from_directory(cls, dataset_infos_dir):
@@ -357,15 +360,20 @@ def from_directory(cls, dataset_infos_dir):
357360
)
358361
if os.path.exists(os.path.join(dataset_infos_dir, "README.md")):
359362
dataset_metadata = DatasetMetadata.from_readme(Path(dataset_infos_dir) / "README.md")
360-
if isinstance(dataset_metadata.get("dataset_infos"), list) and dataset_metadata["dataset_infos"]:
361-
dataset_infos_dict.update(
362-
{
363-
dataset_info_yaml_dict.get("config_name", "default"): DatasetInfo._from_yaml_dict(
364-
dataset_info_yaml_dict
365-
)
366-
for dataset_info_yaml_dict in dataset_metadata["dataset_infos"]
367-
}
368-
)
363+
if isinstance(dataset_metadata.get("dataset_infos"), (list, dict)) and dataset_metadata["dataset_infos"]:
364+
if isinstance(dataset_metadata["dataset_infos"], list):
365+
dataset_infos_dict.update(
366+
{
367+
dataset_info_yaml_dict.get("config_name", "default"): DatasetInfo._from_yaml_dict(
368+
dataset_info_yaml_dict
369+
)
370+
for dataset_info_yaml_dict in dataset_metadata["dataset_infos"]
371+
}
372+
)
373+
else:
374+
dataset_infos_dict[
375+
dataset_metadata["dataset_infos"].get("config_name", "default")
376+
] = dataset_metadata["dataset_infos"]
369377
return cls(**dataset_infos_dict)
370378

371379

0 commit comments

Comments
 (0)