diff --git a/src/magicgui/_util.py b/src/magicgui/_util.py index 85c8dde93..4bfa65a9a 100644 --- a/src/magicgui/_util.py +++ b/src/magicgui/_util.py @@ -6,7 +6,14 @@ import time from functools import wraps from pathlib import Path -from typing import TYPE_CHECKING, Callable, Iterable, overload +from typing import ( + TYPE_CHECKING, + Callable, + Iterable, + get_args, + get_origin, + overload, +) from docstring_parser import DocstringParam, parse @@ -145,10 +152,69 @@ def user_cache_dir( return path +def _safe_isinstance_tuple(obj: object, superclass: object) -> bool: + """ + Extracted from `safe_issubclass` to handle checking of generic tuple types. + + It covers following cases: + + 1. obj is tuple with ellipsis and superclass is tuple with ellipsis + 2. obj is tuple with ellipsis and superclass is Iterable with specific type + 3. obj is tuple with same type elements and superclass is tuple with ellipsis + + for other cases it fallback to simple compare types + """ + obj_args = get_args(obj) + superclass_args = get_args(superclass) + superclass_origin = get_origin(superclass) + + if safe_issubclass(superclass_origin, tuple): + if len(superclass_args) == 2 and superclass_args[1] is Ellipsis: + if len(obj_args) == 2 and obj_args[1] is Ellipsis: + # case 3 + return safe_issubclass(obj_args[0], superclass_args[0]) + # case 2 + return all(safe_issubclass(o, superclass_args[0]) for o in obj_args) + # fallback to simple compare + return ( + len(obj_args) == len(superclass_args) and + all(safe_issubclass(o, s) for o, s in zip(obj_args, superclass_args)) + ) + + if len(obj_args) == 2 and obj_args[1] is Ellipsis: + return safe_issubclass(obj_args[0], superclass_args[0]) + return all(safe_issubclass(o, superclass_args[0]) for o in obj_args) + + def safe_issubclass(obj: object, superclass: object) -> bool: """Safely check if obj is a subclass of superclass.""" + if isinstance(superclass, tuple): + return any(safe_issubclass(obj, s) for s in superclass) + obj_origin = get_origin(obj) + superclass_origin = get_origin(superclass) + superclass_args = get_args(superclass) try: - return issubclass(obj, superclass) # type: ignore + if obj_origin is None: + if superclass_origin is None: + return issubclass(obj, superclass) # type: ignore + if not superclass_args: + return issubclass(obj, superclass_origin) # type: ignore + # if obj is not generic type, but superclass is with + # we can't say anything about it + return False + if obj_origin is not None and superclass_origin is None: + return issubclass(obj_origin, superclass) # type: ignore + if not issubclass(obj_origin, superclass_origin): # type: ignore + return False + obj_args = get_args(obj) + if obj_origin is tuple and obj_args: + return _safe_isinstance_tuple(obj, superclass) + + return ( + issubclass(obj_origin, superclass_origin) and # type: ignore + (obj_args == superclass_args or not superclass_args) + ) + except Exception: return False diff --git a/src/magicgui/type_map/_type_map.py b/src/magicgui/type_map/_type_map.py index 308a13660..c4882839d 100644 --- a/src/magicgui/type_map/_type_map.py +++ b/src/magicgui/type_map/_type_map.py @@ -71,11 +71,15 @@ class MissingWidget(RuntimeError): datetime.datetime: widgets.DateTimeEdit, range: widgets.RangeEdit, slice: widgets.SliceEdit, - list: widgets.ListEdit, + Sequence[pathlib.Path]: widgets.FileEdit, tuple: widgets.TupleEdit, + Sequence: widgets.ListEdit, os.PathLike: widgets.FileEdit, } +_ADDITIONAL_KWARGS: dict[type, dict[str, Any]] = { + Sequence[pathlib.Path]: {"mode": "rm"} +} def match_type(type_: Any, default: Any | None = None) -> WidgetTuple | None: """Check simple type mappings.""" @@ -86,10 +90,10 @@ def match_type(type_: Any, default: Any | None = None) -> WidgetTuple | None: return widgets.ProgressBar, {"bind": lambda widget: widget, "visible": True} if type_ in _SIMPLE_TYPES: - return _SIMPLE_TYPES[type_], {} + return _SIMPLE_TYPES[type_], _ADDITIONAL_KWARGS.get(type_, {}) for key in _SIMPLE_TYPES.keys(): if safe_issubclass(type_, key): - return _SIMPLE_TYPES[key], {} + return _SIMPLE_TYPES[key], _ADDITIONAL_KWARGS.get(key, {}) if type_ in (types.FunctionType,): return widgets.FunctionGui, {"function": default} @@ -99,16 +103,6 @@ def match_type(type_: Any, default: Any | None = None) -> WidgetTuple | None: if choices is not None: # it's a Literal type return widgets.ComboBox, {"choices": choices, "nullable": nullable} - # sequence of paths - if safe_issubclass(origin, Sequence): - args = get_args(type_) - if len(args) == 1 and safe_issubclass(args[0], pathlib.Path): - return widgets.FileEdit, {"mode": "rm"} - elif safe_issubclass(origin, list): - return widgets.ListEdit, {} - elif safe_issubclass(origin, tuple): - return widgets.TupleEdit, {} - if safe_issubclass(origin, Set): for arg in get_args(type_): if get_origin(arg) is Literal: diff --git a/tests/test_types.py b/tests/test_types.py index 242f7c1a1..7c538d226 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,6 +1,6 @@ from enum import Enum from pathlib import Path -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Sequence, Union from unittest.mock import Mock import pytest @@ -150,6 +150,7 @@ def test_type_registered(): with type_registered(Path, widget_type=widgets.LineEdit): assert isinstance(widgets.create_widget(annotation=Path), widgets.LineEdit) assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit) + assert isinstance(widgets.create_widget(annotation=Sequence[Path]), widgets.FileEdit) def test_type_registered_callbacks(): diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 000000000..de00bf8f8 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,80 @@ +import sys +import typing +from collections.abc import Mapping, Sequence + +import pytest + +from magicgui._util import safe_issubclass + + +class TestSafeIsSubclass: + def test_basic(self): + assert safe_issubclass(int, int) + assert safe_issubclass(int, object) + + def test_generic_base(self): + assert safe_issubclass(typing.List[int], list) + assert safe_issubclass(typing.List[int], typing.List) + + def test_multiple_generic_base(self): + assert safe_issubclass(typing.List[int], (typing.List, typing.Dict)) + + def test_no_exception(self): + assert not safe_issubclass(int, 1) + + def test_typing_inheritance(self): + assert safe_issubclass(typing.List, list) + assert safe_issubclass(list, typing.List) + assert safe_issubclass(typing.Tuple, tuple) + assert safe_issubclass(tuple, typing.Tuple) + assert safe_issubclass(typing.Dict, dict) + assert safe_issubclass(dict, typing.Dict) + + def test_inheritance_generic_list(self): + assert safe_issubclass(list, typing.Sequence) + assert safe_issubclass(typing.List, typing.Sequence) + assert safe_issubclass(typing.List[int], typing.Sequence[int]) + assert safe_issubclass(typing.List[int], typing.Sequence) + assert safe_issubclass(typing.List[int], Sequence) + + def test_no_inheritance_generic_super(self): + assert not safe_issubclass(list, typing.List[int]) + + def test_inheritance_generic_mapping(self): + assert safe_issubclass(dict, typing.Mapping) + assert safe_issubclass(typing.Dict, typing.Mapping) + assert safe_issubclass(typing.Dict[int, str], typing.Mapping[int, str]) + assert safe_issubclass(typing.Dict[int, str], typing.Mapping) + assert safe_issubclass(typing.Dict[int, str], Mapping) + + @pytest.mark.skipif(sys.version_info < (3, 9), reason="PEP-585 is supported in 3.9+") + def test_typing_builtins_list(self): + assert safe_issubclass(list[int], list) + assert safe_issubclass(list[int], Sequence) + assert safe_issubclass(list[int], typing.Sequence) + assert safe_issubclass(list[int], typing.Sequence[int]) + assert safe_issubclass(list[int], typing.List[int]) + assert safe_issubclass(typing.List[int], list) + assert safe_issubclass(typing.List[int], list[int]) + + @pytest.mark.skipif(sys.version_info < (3, 9), reason="PEP-585 is supported in 3.9+") + def test_typing_builtins_dict(self): + assert safe_issubclass(dict[int, str], dict) + assert safe_issubclass(dict[int, str], Mapping) + assert safe_issubclass(dict[int, str], typing.Mapping) + assert safe_issubclass(dict[int, str], typing.Mapping[int, str]) + assert safe_issubclass(dict[int, str], typing.Dict[int, str]) + assert safe_issubclass(typing.Dict[int, str], dict) + assert safe_issubclass(typing.Dict[int, str], dict[int, str]) + + def test_tuple_check(self): + assert safe_issubclass(typing.Tuple[int, str], tuple) + assert safe_issubclass(typing.Tuple[int], typing.Sequence[int]) + assert safe_issubclass(typing.Tuple[int, int], typing.Sequence[int]) + assert safe_issubclass(typing.Tuple[int, ...], typing.Sequence[int]) + assert safe_issubclass(typing.Tuple[int, ...], typing.Iterable[int]) + assert not safe_issubclass(typing.Tuple[int, ...], typing.Dict[int, typing.Any]) + assert safe_issubclass(typing.Tuple[int, ...], typing.Tuple[int, ...]) + assert safe_issubclass(typing.Tuple[int, int], typing.Tuple[int, ...]) + assert not safe_issubclass(typing.Tuple[int, int], typing.Tuple[int, str]) + assert not safe_issubclass(typing.Tuple[int, int], typing.Tuple[int, int, int]) diff --git a/tests/test_widgets.py b/tests/test_widgets.py index 53ab9c392..0917daff7 100644 --- a/tests/test_widgets.py +++ b/tests/test_widgets.py @@ -2,7 +2,7 @@ import inspect from enum import Enum from pathlib import Path -from typing import Optional, Tuple +from typing import List, Optional, Tuple from unittest.mock import MagicMock, patch import pytest @@ -847,8 +847,6 @@ def test_pushbutton_icon(backend: str): def test_list_edit(): """Test ListEdit.""" - from typing import List - mock = MagicMock() list_edit = widgets.ListEdit(value=[1, 2, 3]) @@ -900,6 +898,8 @@ def test_list_edit(): assert mock.call_count == 7 mock.assert_called_with([2, 1]) + +def test_list_edit_only_values(): @magicgui def f1(x=[2, 4, 6]): # noqa: B006 pass @@ -908,6 +908,7 @@ def f1(x=[2, 4, 6]): # noqa: B006 assert f1.x._args_type is int assert f1.x.value == [2, 4, 6] +def test_list_edit_annotations(): @magicgui def f2(x: List[int]): pass