From 731ca73d961671ba34fa607988333f4640087555 Mon Sep 17 00:00:00 2001 From: Patrick Smyth Date: Sun, 10 Mar 2024 12:07:19 -0500 Subject: [PATCH 1/6] Add a registry instead of calling globals for fetching feature types --- src/datasets/features/features.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 6ebdb48741d..1547a492a2d 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -13,7 +13,8 @@ # limitations under the License. # Lint as: python3 -""" This class handle features definition in datasets and some utilities to display table type.""" +"""This class handle features definition in datasets and some utilities to display table type.""" + import copy import json import re @@ -1340,6 +1341,24 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni return obj +_FEATURE_TYPES: Dict[Optional[str], FeatureType] = {k: v for k, v in globals().items() if isinstance(v, FeatureType)} + + +def _register_feature( + feature_cls: type, + feature_type: Optional[str], +): + """ + Register a Feature object using a name and class. + This function must be used on a Feature class. + """ + if feature_type in _FEATURE_TYPES: + logger.warning( + f"Overwriting feature type '{feature_type}' ({_FEATURE_TYPES[feature_type].__name__} -> {feature_cls.__name__})" + ) + _FEATURE_TYPES[feature_type] = feature_cls + + def generate_from_dict(obj: Any): """Regenerate the nested feature object from a deserialized dict. We use the '_type' fields to get the dataclass name to load. @@ -1358,7 +1377,11 @@ def generate_from_dict(obj: Any): if "_type" not in obj or isinstance(obj["_type"], dict): return {key: generate_from_dict(value) for key, value in obj.items()} obj = dict(obj) - class_type = globals()[obj.pop("_type")] + _type = obj.pop("_type") + class_type = _FEATURE_TYPES.get(_type, None) + + if class_type is None: + raise ValueError(f"Feature type '{_type}' not found. Available feature types: {list(_FEATURE_TYPES.keys())}") if class_type == Sequence: return Sequence(feature=generate_from_dict(obj["feature"]), length=obj.get("length", -1)) From dd982cb39b7efa2e6f3690d05eb8d5b90b16994a Mon Sep 17 00:00:00 2001 From: Patrick Smyth Date: Sun, 10 Mar 2024 12:16:12 -0500 Subject: [PATCH 2/6] Refactor feature registration function --- src/datasets/features/features.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 1547a492a2d..740b162bdc0 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1344,7 +1344,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni _FEATURE_TYPES: Dict[Optional[str], FeatureType] = {k: v for k, v in globals().items() if isinstance(v, FeatureType)} -def _register_feature( +def register_feature( feature_cls: type, feature_type: Optional[str], ): From 8fd0c5bc483b0b37939d0ed195fcd69472a34208 Mon Sep 17 00:00:00 2001 From: Patrick Smyth Date: Sun, 10 Mar 2024 13:22:22 -0500 Subject: [PATCH 3/6] Add feature type registry lookup first before global lookup --- src/datasets/features/features.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 740b162bdc0..841284ec27d 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -1341,7 +1341,7 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni return obj -_FEATURE_TYPES: Dict[Optional[str], FeatureType] = {k: v for k, v in globals().items() if isinstance(v, FeatureType)} +_FEATURE_TYPES: Dict[Optional[str], FeatureType] = {} def register_feature( @@ -1378,7 +1378,7 @@ def generate_from_dict(obj: Any): return {key: generate_from_dict(value) for key, value in obj.items()} obj = dict(obj) _type = obj.pop("_type") - class_type = _FEATURE_TYPES.get(_type, None) + class_type = _FEATURE_TYPES.get(_type, None) or globals().get(_type, None) if class_type is None: raise ValueError(f"Feature type '{_type}' not found. Available feature types: {list(_FEATURE_TYPES.keys())}") From 3f6fd374baf4aad93971094e67b75b7d7c476a8e Mon Sep 17 00:00:00 2001 From: Patrick Smyth Date: Mon, 11 Mar 2024 12:16:25 -0500 Subject: [PATCH 4/6] Add experimental feature registration --- src/datasets/features/features.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index dbc4b2038f1..9ac4f894f1e 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -39,7 +39,7 @@ from .. import config from ..naming import camelcase_to_snakecase, snakecase_to_camelcase from ..table import array_cast -from ..utils import logging +from ..utils import logging, experimental from ..utils.py_utils import asdict, first_non_null_value, zip_dict from .audio import Audio from .image import Image, encode_pil_image @@ -1342,12 +1342,25 @@ def decode_nested_example(schema, obj, token_per_repo_id: Optional[Dict[str, Uni return obj -_FEATURE_TYPES: Dict[Optional[str], FeatureType] = {} +_FEATURE_TYPES: Dict[str, FeatureType] = { + Value.__name__: Value, + ClassLabel.__name__: ClassLabel, + Translation.__name__: Translation, + TranslationVariableLanguages.__name__: TranslationVariableLanguages, + Sequence.__name__: Sequence, + Array2D.__name__: Array2D, + Array3D.__name__: Array3D, + Array4D.__name__: Array4D, + Array5D.__name__: Array5D, + Audio.__name__: Audio, + Image.__name__: Image, +} +@experimental def register_feature( feature_cls: type, - feature_type: Optional[str], + feature_type: str, ): """ Register a Feature object using a name and class. From 67705da3c2ac317e255032c9a9e4a9186dcbf641 Mon Sep 17 00:00:00 2001 From: Patrick Smyth Date: Mon, 11 Mar 2024 12:18:53 -0500 Subject: [PATCH 5/6] Reorder imports in features.py --- src/datasets/features/features.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 9ac4f894f1e..5aedd42c4b2 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -39,7 +39,7 @@ from .. import config from ..naming import camelcase_to_snakecase, snakecase_to_camelcase from ..table import array_cast -from ..utils import logging, experimental +from ..utils import experimental, logging from ..utils.py_utils import asdict, first_non_null_value, zip_dict from .audio import Audio from .image import Image, encode_pil_image From 99cad86c992982b4ed57f4ea18c1dc3e4e9a3513 Mon Sep 17 00:00:00 2001 From: Patrick Smyth Date: Mon, 11 Mar 2024 12:18:53 -0500 Subject: [PATCH 6/6] Reorder imports in features.py --- src/datasets/features/features.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/datasets/features/features.py b/src/datasets/features/features.py index 9ac4f894f1e..5aedd42c4b2 100644 --- a/src/datasets/features/features.py +++ b/src/datasets/features/features.py @@ -39,7 +39,7 @@ from .. import config from ..naming import camelcase_to_snakecase, snakecase_to_camelcase from ..table import array_cast -from ..utils import logging, experimental +from ..utils import experimental, logging from ..utils.py_utils import asdict, first_non_null_value, zip_dict from .audio import Audio from .image import Image, encode_pil_image