|
| 1 | +import json |
| 2 | +import logging |
| 3 | +from dataclasses import dataclass |
| 4 | +from pathlib import Path |
| 5 | +from typing import Any, Callable, Dict, List, Optional, Tuple |
| 6 | + |
| 7 | + |
| 8 | +# loading package files: https://stackoverflow.com/a/20885799 |
| 9 | +try: |
| 10 | + import importlib.resources as pkg_resources |
| 11 | +except ImportError: |
| 12 | + # Try backported to PY<37 `importlib_resources`. |
| 13 | + import importlib_resources as pkg_resources |
| 14 | + |
| 15 | +import yaml |
| 16 | + |
| 17 | +from . import resources |
| 18 | + |
| 19 | + |
# Canonical URL of this utils directory in the repository; validation error
# messages link back here so users can locate the reference tag definitions.
BASE_REF_URL = "https://github.com/huggingface/datasets/tree/master/src/datasets/utils"
# URL pointing at this very module (referenced in 'source_datasets' errors).
# NOTE(review): __file__ is a filesystem path, not just the module filename —
# the resulting URL may embed an absolute path; confirm intended.
this_url = f"{BASE_REF_URL}/{__file__}"
logger = logging.getLogger(__name__)
| 23 | + |
| 24 | + |
def load_json_resource(resource: str) -> Tuple[Any, str]:
    """Read and parse a JSON file bundled in the ``resources`` package.

    Returns:
        A ``(parsed_content, url)`` tuple, where ``url`` points at the
        resource file in the repository (used in validation error messages).
    """
    raw_text = pkg_resources.read_text(resources, resource)
    parsed = json.loads(raw_text)
    return parsed, f"{BASE_REF_URL}/resources/{resource}"
| 28 | + |
| 29 | + |
# Source of languages.json:
# https://datahub.io/core/language-codes/r/ietf-language-tags.csv
# Language names were obtained with langcodes: https://github.com/LuminosoInsight/langcodes
# Reference tag sets, loaded once at import time. Each call yields the parsed
# JSON plus the URL of the resource file (surfaced in validation errors).
known_language_codes, known_language_codes_url = load_json_resource("languages.json")
known_licenses, known_licenses_url = load_json_resource("licenses.json")
known_task_ids, known_task_ids_url = load_json_resource("tasks.json")
known_creators, known_creators_url = load_json_resource("creators.json")
known_size_categories, known_size_categories_url = load_json_resource("size_categories.json")
known_multilingualities, known_multilingualities_url = load_json_resource("multilingualities.json")
| 39 | + |
| 40 | + |
def yaml_block_from_readme(path: Path) -> Optional[str]:
    """Extract the YAML metadata block from the top of a README file.

    The block is expected to be delimited by a pair of ``---`` lines, the
    first of which must be the very first (stripped) line of the file.

    Args:
        path (:obj:`Path`): Path to the README (dataset card) file.

    Returns:
        :obj:`str`: The text between the two ``---`` delimiters (delimiters
        excluded), or :obj:`None` if the file is empty or has no such block.
    """
    with path.open(encoding="utf-8") as readme_file:
        content = [line.strip() for line in readme_file]

    # Guard: an empty README used to raise IndexError on content[0].
    if not content:
        return None

    if content[0] == "---" and "---" in content[1:]:
        # Position of the closing "---" within content[1:]; slicing up to
        # end + 1 in the full list keeps everything strictly between the
        # two delimiters.
        end = content[1:].index("---")
        return "\n".join(content[1 : end + 1])

    return None
| 50 | + |
| 51 | + |
def metadata_dict_from_readme(path: Path) -> Optional[Dict[str, List[str]]]:
    """Load a dataset's metadata from the dataset card (README.md), as a Python dict.

    Args:
        path (:obj:`Path`): Path to the dataset card (its README.md file).

    Returns:
        The parsed YAML metadata as a dict (an empty YAML block yields an
        empty dict), or :obj:`None` if the card has no YAML block at all.
    """
    yaml_block = yaml_block_from_readme(path=path)
    if yaml_block is None:
        return None
    metadata_dict = yaml.safe_load(yaml_block) or dict()
    return metadata_dict
| 59 | + |
| 60 | + |
# A validator returns the accepted values plus an optional error message.
ValidatorOutput = Tuple[List[str], Optional[str]]


def tagset_validator(values: List[str], reference_values: List[str], name: str, url: str) -> ValidatorOutput:
    """Check that every value belongs to the reference tag set ``reference_values``.

    Returns:
        ``(values, None)`` when every value is known; otherwise
        ``([], error_message)`` listing the unknown tags and pointing at the
        reference URL.
    """
    unknown = [value for value in values if value not in reference_values]
    if unknown:
        return [], f"{unknown} are not registered tags for '{name}', reference at {url}"
    return values, None
| 69 | + |
| 70 | + |
def escape_validation_for_predicate(
    values: List[Any], predicate_fn: Callable[[Any], bool]
) -> Tuple[List[Any], List[Any]]:
    """Partition ``values`` according to ``predicate_fn``.

    Values matching the predicate are exempted from validation (and a warning
    is logged for them); the remainder should be validated by the caller.

    Returns:
        A ``(matching, non_matching)`` tuple of lists, preserving input order.
    """
    escaped: List[Any] = []
    to_validate: List[Any] = []
    for value in values:
        # Single predicate call per value, like the original.
        (escaped if predicate_fn(value) else to_validate).append(value)
    if escaped:
        logger.warning(f"The following values will escape validation: {escaped}")
    return escaped, to_validate
| 83 | + |
| 84 | + |
def validate_metadata_type(metadata_dict: dict):
    """Ensure every metadata field is a non-empty list whose first item is a string.

    Raises:
        TypeError: if any field fails the check, listing the offending fields.
    """

    def _well_typed(value) -> bool:
        # Only the first element's type is checked, matching the original.
        return isinstance(value, list) and len(value) > 0 and isinstance(value[0], str)

    basic_typing_errors = {name: value for name, value in metadata_dict.items() if not _well_typed(value)}
    if basic_typing_errors:
        raise TypeError(f"Found fields that are not non-empty list of strings: {basic_typing_errors}")
| 93 | + |
| 94 | + |
@dataclass
class DatasetMetadata:
    """Validated metadata of a dataset, as declared in its dataset card.

    Every field is a non-empty list of string tags. On construction,
    ``__post_init__`` checks each field against the reference tag sets shipped
    in the ``resources`` package and raises :obj:`TypeError` on any violation.
    """

    annotations_creators: List[str]
    language_creators: List[str]
    languages: List[str]
    licenses: List[str]
    multilinguality: List[str]
    size_categories: List[str]
    source_datasets: List[str]
    task_categories: List[str]
    task_ids: List[str]

    def __post_init__(self):
        """Validate all fields, raising :obj:`TypeError` with every error found."""
        # First ensure every field is a non-empty list of strings...
        validate_metadata_type(metadata_dict=vars(self))

        # ...then validate each field's values against its reference tag set.
        self.annotations_creators, annotations_creators_errors = self.validate_annotations_creators(
            self.annotations_creators
        )
        self.language_creators, language_creators_errors = self.validate_language_creators(self.language_creators)
        self.languages, languages_errors = self.validate_language_codes(self.languages)
        self.licenses, licenses_errors = self.validate_licences(self.licenses)
        self.multilinguality, multilinguality_errors = self.validate_mulitlinguality(self.multilinguality)
        self.size_categories, size_categories_errors = self.validate_size_catgeories(self.size_categories)
        self.source_datasets, source_datasets_errors = self.validate_source_datasets(self.source_datasets)
        self.task_categories, task_categories_errors = self.validate_task_categories(self.task_categories)
        self.task_ids, task_ids_errors = self.validate_task_ids(self.task_ids)

        errors = {
            "annotations_creators": annotations_creators_errors,
            "language_creators": language_creators_errors,
            "licenses": licenses_errors,
            "multilinguality": multilinguality_errors,
            "size_categories": size_categories_errors,
            "source_datasets": source_datasets_errors,
            "task_categories": task_categories_errors,
            "task_ids": task_ids_errors,
            "languages": languages_errors,
        }

        # Collect every field that produced an error and report them all at once.
        exception_msg_dict = {field: errs for field, errs in errors.items() if errs is not None}
        if len(exception_msg_dict) > 0:
            raise TypeError(
                "Could not validate the metadata, found the following errors:\n"
                + "\n".join(f"* field '{fieldname}':\n\t{err}" for fieldname, err in exception_msg_dict.items())
            )

    @classmethod
    def from_readme(cls, path: Path) -> "DatasetMetadata":
        """Loads and validates the dataset metadata from its dataset card (README.md)

        Args:
            path (:obj:`Path`): Path to the dataset card (its README.md file)

        Returns:
            :class:`DatasetMetadata`: The dataset's metadata

        Raises:
            :obj:`TypeError`: If the dataset card has no metadata (no YAML header)
            :obj:`TypeError`: If the dataset's metadata is invalid
        """
        yaml_string = yaml_block_from_readme(path)
        if yaml_string is not None:
            return cls.from_yaml_string(yaml_string)
        else:
            raise TypeError(f"did not find a yaml block in '{path}'")

    @classmethod
    def from_yaml_string(cls, string: str) -> "DatasetMetadata":
        """Loads and validates the dataset metadata from a YAML string

        Args:
            string (:obj:`str`): The YAML string

        Returns:
            :class:`DatasetMetadata`: The dataset's metadata

        Raises:
            :obj:`TypeError`: If the dataset's metadata is invalid
        """
        metadata_dict = yaml.safe_load(string) or dict()
        return cls(**metadata_dict)

    @staticmethod
    def validate_annotations_creators(annotations_creators: List[str]) -> ValidatorOutput:
        """Validate tags against the "annotations" section of creators.json."""
        return tagset_validator(
            annotations_creators, known_creators["annotations"], "annotations_creators", known_creators_url
        )

    @staticmethod
    def validate_language_creators(language_creators: List[str]) -> ValidatorOutput:
        """Validate tags against the "language" section of creators.json."""
        return tagset_validator(language_creators, known_creators["language"], "language_creators", known_creators_url)

    @staticmethod
    def validate_language_codes(languages: List[str]) -> ValidatorOutput:
        """Validate language tags against the known IETF language codes."""
        return tagset_validator(
            values=languages,
            reference_values=known_language_codes.keys(),
            name="languages",
            url=known_language_codes_url,
        )

    @staticmethod
    def validate_licences(licenses: List[str]) -> ValidatorOutput:
        # NOTE: method name spelling ("licences") kept for backward compatibility.
        # Values containing "-other-" escape validation (custom licenses).
        others, to_validate = escape_validation_for_predicate(licenses, lambda e: "-other-" in e)
        validated, error = tagset_validator(to_validate, list(known_licenses.keys()), "licenses", known_licenses_url)
        return [*validated, *others], error

    @staticmethod
    def validate_task_categories(task_categories: List[str]) -> ValidatorOutput:
        # TODO: we're currently ignoring all values starting with 'other' as our task taxonomy is bound to change
        # in the near future and we don't want to waste energy in tagging against a moving taxonomy.
        known_set = list(known_task_ids.keys())
        others, to_validate = escape_validation_for_predicate(task_categories, lambda e: e.startswith("other"))
        validated, error = tagset_validator(to_validate, known_set, "task_categories", known_task_ids_url)
        return [*validated, *others], error

    @staticmethod
    def validate_task_ids(task_ids: List[str]) -> ValidatorOutput:
        # TODO: we're currently ignoring all values starting with 'other' as our task taxonomy is bound to change
        # in the near future and we don't want to waste energy in tagging against a moving taxonomy.
        known_set = [tid for _cat, d in known_task_ids.items() for tid in d["options"]]
        others, to_validate = escape_validation_for_predicate(task_ids, lambda e: "-other-" in e)
        validated, error = tagset_validator(to_validate, known_set, "task_ids", known_task_ids_url)
        return [*validated, *others], error

    @staticmethod
    def validate_mulitlinguality(multilinguality: List[str]) -> ValidatorOutput:
        # NOTE: method name spelling ("mulitlinguality") kept for backward compatibility.
        others, to_validate = escape_validation_for_predicate(multilinguality, lambda e: e.startswith("other"))
        # Fix: error messages previously pointed at known_size_categories_url,
        # giving a misleading reference link for invalid multilinguality tags.
        validated, error = tagset_validator(
            to_validate, list(known_multilingualities.keys()), "multilinguality", known_multilingualities_url
        )
        return [*validated, *others], error

    @staticmethod
    def validate_size_catgeories(size_cats: List[str]) -> ValidatorOutput:
        # NOTE: method name spelling ("catgeories") kept for backward compatibility.
        return tagset_validator(size_cats, known_size_categories, "size_categories", known_size_categories_url)

    @staticmethod
    def validate_source_datasets(sources: List[str]) -> ValidatorOutput:
        """Validate source_datasets: allowed values are "original", "extended",
        or any value starting with "extended|"."""
        invalid_values = []
        for src in sources:
            is_ok = src in ["original", "extended"] or src.startswith("extended|")
            if not is_ok:
                invalid_values.append(src)
        if len(invalid_values) > 0:
            return (
                [],
                f"'source_datasets' has invalid values: {invalid_values}, refer to source code to understand {this_url}",
            )

        return sources, None
| 249 | + |
| 250 | + |
if __name__ == "__main__":
    from argparse import ArgumentParser

    # CLI entry point: validate the YAML metadata block of a README.md file.
    parser = ArgumentParser(usage="Validate the yaml metadata block of a README.md file.")
    parser.add_argument("readme_filepath")
    arguments = parser.parse_args()

    DatasetMetadata.from_readme(Path(arguments.readme_filepath))
0 commit comments