Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions src/datasets/utils/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union


try: # Python >= 3.8
from typing import get_args
except ImportError:

def get_args(tp):
return tp.__args__


# loading package files: https://stackoverflow.com/a/20885799
try:
import importlib.resources as pkg_resources
Expand Down Expand Up @@ -132,7 +140,7 @@ def validate_type(value: Any, expected_type: Type):
# Add more `elif` statements if primitive type checking is needed
else:
expected_type_name = str(expected_type).split(".", 1)[-1].split("[")[0] # typing.List[str] -> List
expected_type_args = expected_type.__args__
expected_type_args = get_args(expected_type)

if expected_type_name == "Union":
for type_arg in expected_type_args:
Expand Down Expand Up @@ -161,7 +169,7 @@ def validate_type(value: Any, expected_type: Type):
if expected_type_name == "Dict":
if not isinstance(value, dict):
return f"Expected `{expected_type}` with length > 0. Found value of type: `{type(value)}`, with length: {len(value)}.\n"
if expected_type_args != Dict.__args__: # if we specified types for keys and values
if expected_type_args != get_args(Dict): # if we specified types for keys and values
key_type, value_type = expected_type_args
key_error_string = ""
value_error_string = ""
Expand All @@ -174,7 +182,7 @@ def validate_type(value: Any, expected_type: Type):
else: # `List`/`Tuple`
if not isinstance(value, (list, tuple)):
return f"Expected `{expected_type}` with length > 0. Found value of type: `{type(value)}`, with length: {len(value)}.\n"
if expected_type_args != List.__args__: # if we specified types for the items in the list
if expected_type_args != get_args(List): # if we specified types for the items in the list
value_type = expected_type_args[0]
value_error_string = ""
for v in value:
Expand Down