Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/datasets/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,11 @@ def _dump_license(self, file):
@classmethod
def from_merge(cls, dataset_infos: List["DatasetInfo"]):
dataset_infos = [dset_info.copy() for dset_info in dataset_infos if dset_info is not None]

if len(dataset_infos) > 0 and all(dataset_infos[0] == dset_info for dset_info in dataset_infos):
# if all dataset_infos are equal we don't need to merge. Just return the first.
return dataset_infos[0]

description = "\n\n".join(unique_values(info.description for info in dataset_infos)).strip()
citation = "\n\n".join(unique_values(info.citation for info in dataset_infos)).strip()
homepage = "\n\n".join(unique_values(info.homepage for info in dataset_infos)).strip()
Expand Down
30 changes: 30 additions & 0 deletions tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,33 @@ def test_dataset_infos_dict_dump_and_reload(tmp_path, dataset_infos_dict: Datase

if dataset_infos_dict:
assert os.path.exists(os.path.join(tmp_path, "README.md"))


@pytest.mark.parametrize(
"dataset_info",
[
None,
DatasetInfo(),
DatasetInfo(
description="foo",
features=Features({"a": Value("int32")}),
builder_name="builder",
config_name="config",
version="1.0.0",
splits=[{"name": "train"}],
download_size=42,
dataset_name="dataset_name",
),
],
)
def test_from_merge_same_dataset_infos(dataset_info):
num_elements = 3
if dataset_info is not None:
dataset_info_list = [dataset_info.copy() for _ in range(num_elements)]
else:
dataset_info_list = [None] * num_elements
dataset_info_merged = DatasetInfo.from_merge(dataset_info_list)
if dataset_info is not None:
assert dataset_info == dataset_info_merged
else:
assert DatasetInfo() == dataset_info_merged