Skip to content
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/datasets/features/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -1342,6 +1342,24 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni
return obj


_FEATURE_TYPES: Dict[Optional[str], FeatureType] = {}


def register_feature(
feature_cls: type,
feature_type: Optional[str],
):
"""
Register a Feature object using a name and class.
This function must be used on a Feature class.
"""
if feature_type in _FEATURE_TYPES:
logger.warning(
f"Overwriting feature type '{feature_type}' ({_FEATURE_TYPES[feature_type].__name__} -> {feature_cls.__name__})"
)
_FEATURE_TYPES[feature_type] = feature_cls


def generate_from_dict(obj: Any):
"""Regenerate the nested feature object from a deserialized dict.
We use the '_type' fields to get the dataclass name to load.
Expand All @@ -1360,7 +1378,11 @@ def generate_from_dict(obj: Any):
if "_type" not in obj or isinstance(obj["_type"], dict):
return {key: generate_from_dict(value) for key, value in obj.items()}
obj = dict(obj)
class_type = globals()[obj.pop("_type")]
_type = obj.pop("_type")
class_type = _FEATURE_TYPES.get(_type, None) or globals().get(_type, None)

if class_type is None:
raise ValueError(f"Feature type '{_type}' not found. Available feature types: {list(_FEATURE_TYPES.keys())}")

if class_type == Sequence:
return Sequence(feature=generate_from_dict(obj["feature"]), length=obj.get("length", -1))
Expand Down