diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py
index 5773d26e045..6ebdb48741d 100644
--- a/src/datasets/features/features.py
+++ b/src/datasets/features/features.py
@@ -1776,10 +1776,20 @@ def to_yaml_inner(obj: Union[dict, list]) -> dict:
                 return {"struct": [{"name": name, **to_yaml_inner(_feature)} for name, _feature in obj.items()]}
             elif isinstance(obj, list):
                 return simplify({"list": simplify(to_yaml_inner(obj[0]))})
+            elif isinstance(obj, tuple):
+                return to_yaml_inner(list(obj))
             else:
                 raise TypeError(f"Expected a dict or a list but got {type(obj)}: {obj}")
 
-        return to_yaml_inner(yaml_data)["struct"]
+        def to_yaml_types(obj: object) -> object:
+            """Recursively copy nested dicts/lists, turning every tuple into a list; other values pass through."""
+            if isinstance(obj, dict):
+                return {k: to_yaml_types(v) for k, v in obj.items()}
+            if isinstance(obj, (list, tuple)):
+                return [to_yaml_types(v) for v in obj]
+            return obj
+
+        return to_yaml_types(to_yaml_inner(yaml_data)["struct"])
 
     @classmethod
     def _from_yaml_list(cls, yaml_data: list) -> "Features":
@@ -1837,7 +1847,7 @@ def from_yaml_inner(obj: Union[dict, list]) -> Union[dict, list]:
                             Value(obj["dtype"])
                             return {**obj, "_type": "Value"}
                         except ValueError:
-                            # for audio and image that are Audio and Image types, not Value
+                            # not a Value dtype but another feature type, e.g. Audio, Image, ArrayXD
                             return {"_type": snakecase_to_camelcase(obj["dtype"])}
                     else:
                         return from_yaml_inner(obj["dtype"])