258 changes: 184 additions & 74 deletions src/datasets/utils/metadata.py
@@ -1,9 +1,9 @@
import json
import logging
from collections import Counter
from dataclasses import dataclass, fields
from dataclasses import asdict, dataclass, fields
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union


# loading package files: https://stackoverflow.com/a/20885799
@@ -66,7 +66,7 @@ def yaml_block_from_readme(path: Path) -> Optional[str]:


def metadata_dict_from_readme(path: Path) -> Optional[Dict[str, List[str]]]:
""" "Loads a dataset's metadata from the dataset card (REAMDE.md), as a Python dict"""
"""Loads a dataset's metadata from the dataset card (REAMDE.md), as a Python dict"""
yaml_block = yaml_block_from_readme(path=path)
if yaml_block is None:
return None
@@ -77,61 +77,126 @@ def metadata_dict_from_readme(path: Path) -> Optional[Dict[str, List[str]]]:
ValidatorOutput = Tuple[List[str], Optional[str]]


def tagset_validator(values: List[str], reference_values: List[str], name: str, url: str) -> ValidatorOutput:
invalid_values = [v for v in values if v not in reference_values]
def tagset_validator(
items: Union[List[str], Dict[str, List[str]]],
reference_values: List[str],
name: str,
url: str,
escape_validation_predicate_fn: Optional[Callable[[Any], bool]] = None,
) -> ValidatorOutput:
if isinstance(items, list):
if escape_validation_predicate_fn is not None:
invalid_values = [
v for v in items if v not in reference_values and escape_validation_predicate_fn(v) is False
]
else:
invalid_values = [v for v in items if v not in reference_values]

else:
invalid_values = []
if escape_validation_predicate_fn is not None:
for config_name, values in items.items():
invalid_values += [
v for v in values if v not in reference_values and escape_validation_predicate_fn(v) is False
]
else:
for config_name, values in items.items():
invalid_values += [v for v in values if v not in reference_values]

if len(invalid_values) > 0:
return [], f"{invalid_values} are not registered tags for '{name}', reference at {url}"
return values, None
return items, None
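
A quick sketch of how the reworked `tagset_validator` behaves with each input shape; the reference tags and URL below are made-up stand-ins, and the import assumes this module's path (`datasets.utils.metadata`):

from datasets.utils.metadata import tagset_validator

reference = ["found", "crowdsourced"]  # hypothetical reference tag set
url = "https://example.com/tags"  # hypothetical docs URL

# Flat list: unregistered values produce an error message.
validated, error = tagset_validator(["found", "machine-generated"], reference, "annotations_creators", url)
# validated == [], error mentions 'machine-generated'

# Per-config dict with an escape predicate: `other-*` values skip validation.
validated, error = tagset_validator(
    {"config_a": ["found"], "config_b": ["other-in-house"]},
    reference,
    "annotations_creators",
    url,
    escape_validation_predicate_fn=lambda v: v.startswith("other-"),
)
# validated is the original dict, error is None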


def escape_validation_for_predicate(
values: List[Any], predicate_fn: Callable[[Any], bool]
) -> Tuple[List[Any], List[Any]]:
trues, falses = list(), list()
for v in values:
if predicate_fn(v):
trues.append(v)
def validate_type(value: Any, expected_type: Type):
error_string = ""
NoneType = type(None)
if expected_type == NoneType:
if not isinstance(value, NoneType):
return f"Expected `{NoneType}`. Found value: `{value}` of type `{type(value)}`.\n"
else:
falses.append(v)
if len(trues) > 0:
logger.warning(f"The following values will escape validation: {trues}")
return trues, falses
return error_string
if expected_type == str:
if not isinstance(value, str):
return f"Expected `{str}`. Found value: `{value}` of type: `{type(value)}`.\n"

elif isinstance(value, str) and len(value) == 0:
return (
f"Expected `{str}` with length > 0. Found value: `{value}` of type: `{type(value)}` with length: 0.\n"
)
else:
return error_string
# Add more `elif` statements if primitive type checking is needed
else:
expected_type_origin = expected_type.__origin__
expected_type_args = expected_type.__args__

if expected_type_origin == Union:
for type_arg in expected_type_args:
temp_error_string = validate_type(value, type_arg)
if temp_error_string == "": # at least one type is successfully validated
return temp_error_string
else:
if error_string == "":
error_string = "(" + temp_error_string + ")"
else:
error_string += "\nOR\n" + "(" + temp_error_string + ")"

else:
# Assuming `List`/`Dict`/`Tuple`
if not isinstance(value, expected_type_origin) or len(value) == 0:
return f"Expected `{expected_type_origin}` with length > 0. Found value: `{value}` of type: `{type(value)}`.\n"

if expected_type_origin in (dict, Dict):  # `__origin__` is `dict` on Python 3.7+ and `typing.Dict` on 3.6
key_type, value_type = expected_type_args
key_error_string = ""
value_error_string = ""
for k, v in value.items():
key_error_string += validate_type(k, key_type)
value_error_string += validate_type(v, value_type)
if key_error_string != "" or value_error_string != "":
return f"Typing errors with keys:\n {key_error_string} and values:\n {value_error_string}"

else: # `List`/`Tuple`
value_type = expected_type_args[0]
value_error_string = ""
for v in value:
value_error_string += validate_type(v, value_type)
if value_error_string != "":
return f"Typing errors with values:\n {value_error_string}"

return error_string
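
A few illustrative calls against `validate_type`; it returns an empty string when the value matches the annotation and an error message otherwise (a sketch, assuming the helper is importable from `datasets.utils.metadata`):

from typing import Dict, List, Union

from datasets.utils.metadata import validate_type

assert validate_type("found", str) == ""
assert validate_type(["found"], List[str]) == ""
assert validate_type({"config_a": ["found"]}, Union[List[str], Dict[str, List[str]]]) == ""
assert validate_type([], List[str]) != ""  # empty containers are rejected
assert validate_type(42, str) != ""  # wrong primitive type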


def validate_metadata_type(metadata_dict: dict):
fields_types = {field.name: field.type for field in fields(DatasetMetadata)}
list_typing_errors = {
name: value
for name, value in metadata_dict.items()
if fields_types.get(name, List[str]) == List[str]
and (not isinstance(value, list) or len(value) == 0 or not isinstance(value[0], str))
}
if len(list_typing_errors) > 0:
raise TypeError(f"Found fields that are not non-empty list of strings: {list_typing_errors}")

other_typing_errors = {
name: value
for name, value in metadata_dict.items()
if fields_types.get(name, List[str]) != List[str] and isinstance(value, list)
}
if len(other_typing_errors) > 0:
raise TypeError(f"Found fields that are lists instead of single strings: {other_typing_errors}")
field_types = {field.name: field.type for field in fields(DatasetMetadata)}

typing_errors = {}
for field_name, field_value in metadata_dict.items():
field_type_error = validate_type(
metadata_dict[field_name], field_types.get(field_name, Union[List[str], Dict[str, List[str]]])
)
if field_type_error != "":
typing_errors[field_name] = field_type_error
if len(typing_errors) > 0:
raise TypeError(f"The following typing errors are found: {typing_errors}")
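
For instance, a field annotated as `Union[List[str], Dict[str, List[str]]]` rejects a bare scalar, and both failed Union arms end up in the raised message (the field value here is a hypothetical bad input):

from datasets.utils.metadata import validate_metadata_type

try:
    validate_metadata_type({"languages": 42})
except TypeError as e:
    print(e)  # reports the typing errors found for `languages`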


@dataclass
class DatasetMetadata:
annotations_creators: List[str]
language_creators: List[str]
languages: List[str]
licenses: List[str]
multilinguality: List[str]
size_categories: List[str]
source_datasets: List[str]
task_categories: List[str]
task_ids: List[str]
annotations_creators: Union[List[str], Dict[str, List[str]]]
language_creators: Union[List[str], Dict[str, List[str]]]
languages: Union[List[str], Dict[str, List[str]]]
licenses: Union[List[str], Dict[str, List[str]]]
multilinguality: Union[List[str], Dict[str, List[str]]]
pretty_names: Union[str, Dict[str, str]]
size_categories: Union[List[str], Dict[str, List[str]]]
source_datasets: Union[List[str], Dict[str, List[str]]]
task_categories: Union[List[str], Dict[str, List[str]]]
task_ids: Union[List[str], Dict[str, List[str]]]
paperswithcode_id: Optional[str] = None

def __post_init__(self):
def validate(self):
validate_metadata_type(metadata_dict=vars(self))

self.annotations_creators, annotations_creators_errors = self.validate_annotations_creators(
@@ -168,7 +233,7 @@ def __post_init__(self):
exception_msg_dict[field] = errs
if len(exception_msg_dict) > 0:
raise TypeError(
"Could not validate the metada, found the following errors:\n"
"Could not validate the metadata, found the following errors:\n"
+ "\n".join(f"* field '{fieldname}':\n\t{err}" for fieldname, err in exception_msg_dict.items())
)

@@ -190,7 +255,7 @@ def from_readme(cls, path: Path) -> "DatasetMetadata":
if yaml_string is not None:
return cls.from_yaml_string(yaml_string)
else:
raise TypeError(f"did not find a yaml block in '{path}'")
raise TypeError(f"Unable to find a yaml block in '{path}'")

@classmethod
def from_yaml_string(cls, string: str) -> "DatasetMetadata":
@@ -206,24 +271,20 @@ def from_yaml_string(cls, string: str) -> "DatasetMetadata":
:obj:`TypeError`: If the dataset's metadata is invalid
"""
metada_dict = yaml.load(string, Loader=NoDuplicateSafeLoader) or dict()
# flatten the metadata of each config
for key in metada_dict:
if isinstance(metada_dict[key], dict):
metada_dict[key] = list(set(sum(metada_dict[key].values(), [])))
return cls(**metada_dict)

@staticmethod
def validate_annotations_creators(annotations_creators: List[str]) -> ValidatorOutput:
def validate_annotations_creators(annotations_creators: Union[List[str], Dict[str, List[str]]]) -> ValidatorOutput:
return tagset_validator(
annotations_creators, known_creators["annotations"], "annotations_creators", known_creators_url
)

@staticmethod
def validate_language_creators(language_creators: List[str]) -> ValidatorOutput:
def validate_language_creators(language_creators: Union[List[str], Dict[str, List[str]]]) -> ValidatorOutput:
return tagset_validator(language_creators, known_creators["language"], "language_creators", known_creators_url)

@staticmethod
def validate_language_codes(languages: List[str]) -> ValidatorOutput:
def validate_language_codes(languages: Union[List[str], Dict[str, List[str]]]) -> ValidatorOutput:
return tagset_validator(
items=languages,
reference_values=known_language_codes.keys(),
@@ -232,47 +293,53 @@ def validate_language_codes(languages: List[str]) -> ValidatorOutput:
)

@staticmethod
def validate_licences(licenses: List[str]) -> ValidatorOutput:
others, to_validate = escape_validation_for_predicate(
licenses, lambda e: "-other-" in e or e.startswith("other-")
def validate_licences(licenses: Union[List[str], Dict[str, List[str]]]) -> ValidatorOutput:
validated, error = tagset_validator(
licenses,
list(known_licenses.keys()),
"licenses",
known_licenses_url,
lambda e: "-other-" in e or e.startswith("other-"),
)
validated, error = tagset_validator(to_validate, list(known_licenses.keys()), "licenses", known_licenses_url)
return [*validated, *others], error
return validated, error

@staticmethod
def validate_task_categories(task_categories: List[str]) -> ValidatorOutput:
def validate_task_categories(task_categories: Union[List[str], Dict[str, List[str]]]) -> ValidatorOutput:
# TODO: we're currently ignoring all values starting with 'other' as our task taxonomy is bound to change
# in the near future and we don't want to waste energy in tagging against a moving taxonomy.
known_set = list(known_task_ids.keys())
others, to_validate = escape_validation_for_predicate(task_categories, lambda e: e.startswith("other-"))
validated, error = tagset_validator(to_validate, known_set, "task_categories", known_task_ids_url)
return [*validated, *others], error
validated, error = tagset_validator(
task_categories, known_set, "task_categories", known_task_ids_url, lambda e: e.startswith("other-")
)
return validated, error

@staticmethod
def validate_task_ids(task_ids: List[str]) -> ValidatorOutput:
def validate_task_ids(task_ids: Union[List[str], Dict[str, List[str]]]) -> ValidatorOutput:
# TODO: we're currently ignoring all values starting with 'other' as our task taxonomy is bound to change
# in the near future and we don't want to waste energy in tagging against a moving taxonomy.
known_set = [tid for _cat, d in known_task_ids.items() for tid in d["options"]]
others, to_validate = escape_validation_for_predicate(
task_ids, lambda e: "-other-" in e or e.startswith("other-")
validated, error = tagset_validator(
task_ids, known_set, "task_ids", known_task_ids_url, lambda e: "-other-" in e or e.startswith("other-")
)
validated, error = tagset_validator(to_validate, known_set, "task_ids", known_task_ids_url)
return [*validated, *others], error
return validated, error

@staticmethod
def validate_mulitlinguality(multilinguality: List[str]) -> ValidatorOutput:
others, to_validate = escape_validation_for_predicate(multilinguality, lambda e: e.startswith("other-"))
def validate_mulitlinguality(multilinguality: Union[List[str], Dict[str, List[str]]]) -> ValidatorOutput:
validated, error = tagset_validator(
to_validate, list(known_multilingualities.keys()), "multilinguality", known_size_categories_url
multilinguality,
list(known_multilingualities.keys()),
"multilinguality",
known_size_categories_url,
lambda e: e.startswith("other-"),
)
return [*validated, *others], error
return validated, error

@staticmethod
def validate_size_catgeories(size_cats: List[str]) -> ValidatorOutput:
def validate_size_catgeories(size_cats: Union[List[str], Dict[str, List[str]]]) -> ValidatorOutput:
return tagset_validator(size_cats, known_size_categories, "size_categories", known_size_categories_url)

@staticmethod
def validate_source_datasets(sources: List[str]) -> ValidatorOutput:
def validate_source_datasets(sources: Union[List[str], Dict[str, List[str]]]) -> ValidatorOutput:
invalid_values = []
for src in sources:
is_ok = src in ["original", "extended"] or src.startswith("extended|")
Expand All @@ -299,6 +366,48 @@ def validate_paperswithcode_id_errors(paperswithcode_id: Optional[str]) -> Valid
else:
return paperswithcode_id, None

@staticmethod
def validate_pretty_names(pretty_names: Union[str, Dict[str, str]]) -> ValidatorOutput:
if isinstance(pretty_names, str):
if len(pretty_names) == 0:
return None, "The pretty name must have a length greater than 0 but got an empty string."
else:
return pretty_names, None
else:
error_string = ""
for key, value in pretty_names.items():
if len(value) == 0:
error_string += f"The pretty name must have a length greater than 0 but got an empty string for config: {key}.\n"
if error_string != "":
return None, error_string
else:
return pretty_names, None
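
Both accepted shapes of `validate_pretty_names` in one sketch (the config names are hypothetical):

from datasets.utils.metadata import DatasetMetadata

print(DatasetMetadata.validate_pretty_names("My Dataset"))
# -> ('My Dataset', None)
print(DatasetMetadata.validate_pretty_names({"config_a": "Dataset A", "config_b": ""}))
# -> (None, '... got an empty string for config: config_b.\n')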

def get_metadata_by_config_name(self, name: str) -> "DatasetMetadata":
metadata_dict = asdict(self)
config_name_hit = []
has_multi_configs = []
result_dict = {}
for tag_key, tag_value in metadata_dict.items():
if isinstance(tag_value, str) or isinstance(tag_value, list):
result_dict[tag_key] = tag_value
elif isinstance(tag_value, dict):
has_multi_configs.append(tag_key)
for config_name, value in tag_value.items():
if config_name == name:
result_dict[tag_key] = value
config_name_hit.append(tag_key)

if len(has_multi_configs) > 0 and has_multi_configs != config_name_hit:
raise TypeError(
f"The following tags have multiple configs: {has_multi_configs} but the config `{name}` was found only in: {config_name_hit}."
)
if len(config_name_hit) == 0:
logger.warning(
"No matching config names found in the metadata, using the common values to create metadata."
)

return DatasetMetadata(**result_dict)
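
An end-to-end sketch of the per-config flow, assuming a YAML block that tags two configs; every tag value below is illustrative rather than taken from a real dataset card:

from datasets.utils.metadata import DatasetMetadata

yaml_string = """\
annotations_creators:
  config_a: [found]
  config_b: [crowdsourced]
language_creators: [found]
languages: [en]
licenses: [unknown]
multilinguality: [monolingual]
pretty_names:
  config_a: Dataset A
  config_b: Dataset B
size_categories: [unknown]
source_datasets: [original]
task_categories: [text-classification]
task_ids: [sentiment-classification]
"""

metadata = DatasetMetadata.from_yaml_string(yaml_string)
metadata.validate()  # raises TypeError if any tag or shape is invalid

# Dict-valued tags collapse to the values registered for `config_a`,
# while flat tags (e.g. `languages`) are carried over unchanged.
config_a_metadata = metadata.get_metadata_by_config_name("config_a")
# config_a_metadata.annotations_creators == ['found']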


if __name__ == "__main__":
from argparse import ArgumentParser
@@ -308,4 +417,5 @@ def validate_paperswithcode_id_errors(paperswithcode_id: Optional[str]) -> ValidatorOutput:
args = ap.parse_args()

readme_filepath = Path(args.readme_filepath)
DatasetMetadata.from_readme(readme_filepath)
dataset_metadata = DatasetMetadata.from_readme(readme_filepath)
dataset_metadata.validate()