Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions src/magicgui/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@
import time
from functools import wraps
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Iterable, overload
from typing import (
TYPE_CHECKING,
Callable,
Iterable,
get_args,
get_origin,
overload,
)

from docstring_parser import DocstringParam, parse

Expand Down Expand Up @@ -145,10 +152,69 @@ def user_cache_dir(
return path


def _safe_isinstance_tuple(obj: object, superclass: object) -> bool:
"""
Extracted from `safe_issubclass` to handle checking of generic tuple types.

It covers following cases:

1. obj is tuple with ellipsis and superclass is tuple with ellipsis
2. obj is tuple with ellipsis and superclass is Iterable with specific type
3. obj is tuple with same type elements and superclass is tuple with ellipsis

for other cases it fallback to simple compare types
"""
obj_args = get_args(obj)
superclass_args = get_args(superclass)
superclass_origin = get_origin(superclass)

if safe_issubclass(superclass_origin, tuple):
if len(superclass_args) == 2 and superclass_args[1] is Ellipsis:
if len(obj_args) == 2 and obj_args[1] is Ellipsis:
# case 3
return safe_issubclass(obj_args[0], superclass_args[0])
# case 2
return all(safe_issubclass(o, superclass_args[0]) for o in obj_args)
# fallback to simple compare
return (
len(obj_args) == len(superclass_args) and
all(safe_issubclass(o, s) for o, s in zip(obj_args, superclass_args))
)

if len(obj_args) == 2 and obj_args[1] is Ellipsis:
return safe_issubclass(obj_args[0], superclass_args[0])
return all(safe_issubclass(o, superclass_args[0]) for o in obj_args)


def safe_issubclass(obj: object, superclass: object) -> bool:
"""Safely check if obj is a subclass of superclass."""
if isinstance(superclass, tuple):
return any(safe_issubclass(obj, s) for s in superclass)
obj_origin = get_origin(obj)
superclass_origin = get_origin(superclass)
superclass_args = get_args(superclass)
try:
return issubclass(obj, superclass) # type: ignore
if obj_origin is None:
if superclass_origin is None:
return issubclass(obj, superclass) # type: ignore
if not superclass_args:
return issubclass(obj, superclass_origin) # type: ignore
# if obj is not generic type, but superclass is with
# we can't say anything about it
return False
if obj_origin is not None and superclass_origin is None:
return issubclass(obj_origin, superclass) # type: ignore
if not issubclass(obj_origin, superclass_origin): # type: ignore
return False
obj_args = get_args(obj)
if obj_origin is tuple and obj_args:
return _safe_isinstance_tuple(obj, superclass)

return (
issubclass(obj_origin, superclass_origin) and # type: ignore
(obj_args == superclass_args or not superclass_args)
)

except Exception:
return False

Expand Down
20 changes: 7 additions & 13 deletions src/magicgui/type_map/_type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,15 @@ class MissingWidget(RuntimeError):
datetime.datetime: widgets.DateTimeEdit,
range: widgets.RangeEdit,
slice: widgets.SliceEdit,
list: widgets.ListEdit,
Sequence[pathlib.Path]: widgets.FileEdit,
tuple: widgets.TupleEdit,
Sequence: widgets.ListEdit,
os.PathLike: widgets.FileEdit,
}

_ADDITIONAL_KWARGS: dict[type, dict[str, Any]] = {
Sequence[pathlib.Path]: {"mode": "rm"}
}

def match_type(type_: Any, default: Any | None = None) -> WidgetTuple | None:
"""Check simple type mappings."""
Expand All @@ -86,10 +90,10 @@ def match_type(type_: Any, default: Any | None = None) -> WidgetTuple | None:
return widgets.ProgressBar, {"bind": lambda widget: widget, "visible": True}

if type_ in _SIMPLE_TYPES:
return _SIMPLE_TYPES[type_], {}
return _SIMPLE_TYPES[type_], _ADDITIONAL_KWARGS.get(type_, {})
for key in _SIMPLE_TYPES.keys():
if safe_issubclass(type_, key):
return _SIMPLE_TYPES[key], {}
return _SIMPLE_TYPES[key], _ADDITIONAL_KWARGS.get(key, {})

if type_ in (types.FunctionType,):
return widgets.FunctionGui, {"function": default}
Expand All @@ -99,16 +103,6 @@ def match_type(type_: Any, default: Any | None = None) -> WidgetTuple | None:
if choices is not None: # it's a Literal type
return widgets.ComboBox, {"choices": choices, "nullable": nullable}

# sequence of paths
if safe_issubclass(origin, Sequence):
args = get_args(type_)
if len(args) == 1 and safe_issubclass(args[0], pathlib.Path):
return widgets.FileEdit, {"mode": "rm"}
elif safe_issubclass(origin, list):
return widgets.ListEdit, {}
elif safe_issubclass(origin, tuple):
return widgets.TupleEdit, {}

if safe_issubclass(origin, Set):
for arg in get_args(type_):
if get_origin(arg) is Literal:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Sequence, Union
from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -150,6 +150,7 @@ def test_type_registered():
with type_registered(Path, widget_type=widgets.LineEdit):
assert isinstance(widgets.create_widget(annotation=Path), widgets.LineEdit)
assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit)
assert isinstance(widgets.create_widget(annotation=Sequence[Path]), widgets.FileEdit)


def test_type_registered_callbacks():
Expand Down
80 changes: 80 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import sys
import typing
from collections.abc import Mapping, Sequence

import pytest

from magicgui._util import safe_issubclass


class TestSafeIsSubclass:
def test_basic(self):
assert safe_issubclass(int, int)
assert safe_issubclass(int, object)

def test_generic_base(self):
assert safe_issubclass(typing.List[int], list)
assert safe_issubclass(typing.List[int], typing.List)

def test_multiple_generic_base(self):
assert safe_issubclass(typing.List[int], (typing.List, typing.Dict))

def test_no_exception(self):
assert not safe_issubclass(int, 1)

def test_typing_inheritance(self):
assert safe_issubclass(typing.List, list)
assert safe_issubclass(list, typing.List)
assert safe_issubclass(typing.Tuple, tuple)
assert safe_issubclass(tuple, typing.Tuple)
assert safe_issubclass(typing.Dict, dict)
assert safe_issubclass(dict, typing.Dict)

def test_inheritance_generic_list(self):
assert safe_issubclass(list, typing.Sequence)
assert safe_issubclass(typing.List, typing.Sequence)
assert safe_issubclass(typing.List[int], typing.Sequence[int])
assert safe_issubclass(typing.List[int], typing.Sequence)
assert safe_issubclass(typing.List[int], Sequence)

def test_no_inheritance_generic_super(self):
assert not safe_issubclass(list, typing.List[int])

def test_inheritance_generic_mapping(self):
assert safe_issubclass(dict, typing.Mapping)
assert safe_issubclass(typing.Dict, typing.Mapping)
assert safe_issubclass(typing.Dict[int, str], typing.Mapping[int, str])
assert safe_issubclass(typing.Dict[int, str], typing.Mapping)
assert safe_issubclass(typing.Dict[int, str], Mapping)

@pytest.mark.skipif(sys.version_info < (3, 9), reason="PEP-585 is supported in 3.9+")
def test_typing_builtins_list(self):
assert safe_issubclass(list[int], list)
assert safe_issubclass(list[int], Sequence)
assert safe_issubclass(list[int], typing.Sequence)
assert safe_issubclass(list[int], typing.Sequence[int])
assert safe_issubclass(list[int], typing.List[int])
assert safe_issubclass(typing.List[int], list)
assert safe_issubclass(typing.List[int], list[int])

@pytest.mark.skipif(sys.version_info < (3, 9), reason="PEP-585 is supported in 3.9+")
def test_typing_builtins_dict(self):
assert safe_issubclass(dict[int, str], dict)
assert safe_issubclass(dict[int, str], Mapping)
assert safe_issubclass(dict[int, str], typing.Mapping)
assert safe_issubclass(dict[int, str], typing.Mapping[int, str])
assert safe_issubclass(dict[int, str], typing.Dict[int, str])
assert safe_issubclass(typing.Dict[int, str], dict)
assert safe_issubclass(typing.Dict[int, str], dict[int, str])

def test_tuple_check(self):
assert safe_issubclass(typing.Tuple[int, str], tuple)
assert safe_issubclass(typing.Tuple[int], typing.Sequence[int])
assert safe_issubclass(typing.Tuple[int, int], typing.Sequence[int])
assert safe_issubclass(typing.Tuple[int, ...], typing.Sequence[int])
assert safe_issubclass(typing.Tuple[int, ...], typing.Iterable[int])
assert not safe_issubclass(typing.Tuple[int, ...], typing.Dict[int, typing.Any])
assert safe_issubclass(typing.Tuple[int, ...], typing.Tuple[int, ...])
assert safe_issubclass(typing.Tuple[int, int], typing.Tuple[int, ...])
assert not safe_issubclass(typing.Tuple[int, int], typing.Tuple[int, str])
assert not safe_issubclass(typing.Tuple[int, int], typing.Tuple[int, int, int])
7 changes: 4 additions & 3 deletions tests/test_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import inspect
from enum import Enum
from pathlib import Path
from typing import Optional, Tuple
from typing import List, Optional, Tuple
from unittest.mock import MagicMock, patch

import pytest
Expand Down Expand Up @@ -847,8 +847,6 @@ def test_pushbutton_icon(backend: str):

def test_list_edit():
"""Test ListEdit."""
from typing import List

mock = MagicMock()

list_edit = widgets.ListEdit(value=[1, 2, 3])
Expand Down Expand Up @@ -900,6 +898,8 @@ def test_list_edit():
assert mock.call_count == 7
mock.assert_called_with([2, 1])


def test_list_edit_only_values():
@magicgui
def f1(x=[2, 4, 6]): # noqa: B006
pass
Expand All @@ -908,6 +908,7 @@ def f1(x=[2, 4, 6]): # noqa: B006
assert f1.x._args_type is int
assert f1.x.value == [2, 4, 6]

def test_list_edit_annotations():
@magicgui
def f2(x: List[int]):
pass
Expand Down