diff --git a/src/datasets/info.py b/src/datasets/info.py index bab49e5deae..3348da1c631 100644 --- a/src/datasets/info.py +++ b/src/datasets/info.py @@ -269,6 +269,11 @@ def _dump_license(self, file): @classmethod def from_merge(cls, dataset_infos: List["DatasetInfo"]): dataset_infos = [dset_info.copy() for dset_info in dataset_infos if dset_info is not None] + + if len(dataset_infos) > 0 and all(dataset_infos[0] == dset_info for dset_info in dataset_infos): + # if all dataset_infos are equal we don't need to merge. Just return the first. + return dataset_infos[0] + description = "\n\n".join(unique_values(info.description for info in dataset_infos)).strip() citation = "\n\n".join(unique_values(info.citation for info in dataset_infos)).strip() homepage = "\n\n".join(unique_values(info.homepage for info in dataset_infos)).strip() diff --git a/tests/test_info.py b/tests/test_info.py index f82c98fb161..e128011c136 100644 --- a/tests/test_info.py +++ b/tests/test_info.py @@ -134,3 +134,33 @@ def test_dataset_infos_dict_dump_and_reload(tmp_path, dataset_infos_dict: Datase if dataset_infos_dict: assert os.path.exists(os.path.join(tmp_path, "README.md")) + + +@pytest.mark.parametrize( + "dataset_info", + [ + None, + DatasetInfo(), + DatasetInfo( + description="foo", + features=Features({"a": Value("int32")}), + builder_name="builder", + config_name="config", + version="1.0.0", + splits=[{"name": "train"}], + download_size=42, + dataset_name="dataset_name", + ), + ], +) +def test_from_merge_same_dataset_infos(dataset_info): + num_elements = 3 + if dataset_info is not None: + dataset_info_list = [dataset_info.copy() for _ in range(num_elements)] + else: + dataset_info_list = [None] * num_elements + dataset_info_merged = DatasetInfo.from_merge(dataset_info_list) + if dataset_info is not None: + assert dataset_info == dataset_info_merged + else: + assert DatasetInfo() == dataset_info_merged