diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py
index c2c7d8ff17e..5aedd42c4b2 100644
--- a/src/datasets/features/features.py
+++ b/src/datasets/features/features.py
@@ -39,7 +39,7 @@
 from .. import config
 from ..naming import camelcase_to_snakecase, snakecase_to_camelcase
 from ..table import array_cast
-from ..utils import logging
+from ..utils import experimental, logging
 from ..utils.py_utils import asdict, first_non_null_value, zip_dict
 from .audio import Audio
 from .image import Image, encode_pil_image
@@ -1342,6 +1342,50 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
     return obj
 
 
+# Maps the `_type` name of a serialized feature to the class used to rebuild it.
+_FEATURE_TYPES: Dict[str, FeatureType] = {
+    Value.__name__: Value,
+    ClassLabel.__name__: ClassLabel,
+    Translation.__name__: Translation,
+    TranslationVariableLanguages.__name__: TranslationVariableLanguages,
+    Sequence.__name__: Sequence,
+    Array2D.__name__: Array2D,
+    Array3D.__name__: Array3D,
+    Array4D.__name__: Array4D,
+    Array5D.__name__: Array5D,
+    Audio.__name__: Audio,
+    Image.__name__: Image,
+}
+
+
+@experimental
+def register_feature(
+    feature_cls: type,
+    feature_type: str,
+):
+    """
+    Register a Feature class under a name so that `generate_from_dict` can resolve it.
+
+    The name is matched against the `_type` field of a serialized feature. Registering
+    a name that is already present replaces the previous class and logs a warning.
+
+    Args:
+        feature_cls (`type`): The Feature class to register. Must be a class, not an instance.
+        feature_type (`str`): The name to register the class under.
+
+    Raises:
+        TypeError: If `feature_cls` is not a class.
+    """
+    # Fail at registration time instead of later, when a dataset is deserialized.
+    if not isinstance(feature_cls, type):
+        raise TypeError(f"feature_cls must be a class, but got {type(feature_cls).__name__}: {feature_cls!r}")
+    if feature_type in _FEATURE_TYPES:
+        logger.warning(
+            f"Overwriting feature type '{feature_type}' ({_FEATURE_TYPES[feature_type].__name__} -> {feature_cls.__name__})"
+        )
+    _FEATURE_TYPES[feature_type] = feature_cls
+
+
 def generate_from_dict(obj: Any):
     """Regenerate the nested feature object from a deserialized dict.
     We use the '_type' fields to get the dataclass name to load.
@@ -1360,7 +1404,11 @@ def generate_from_dict(obj: Any):
     if "_type" not in obj or isinstance(obj["_type"], dict):
         return {key: generate_from_dict(value) for key, value in obj.items()}
     obj = dict(obj)
-    class_type = globals()[obj.pop("_type")]
+    _type = obj.pop("_type")
+    class_type = _FEATURE_TYPES.get(_type, None) or globals().get(_type, None)
+
+    if class_type is None:
+        raise ValueError(f"Feature type '{_type}' not found. Available feature types: {list(_FEATURE_TYPES.keys())}")
 
     if class_type == Sequence:
         return Sequence(feature=generate_from_dict(obj["feature"]), length=obj.get("length", -1))