diff --git a/CHANGELOG.md b/CHANGELOG.md index d4bc0322..5b09065e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,10 @@ - Add `typing_extensions.Buffer`, a marker class for buffer types, as proposed by PEP 688. Equivalent to `collections.abc.Buffer` in Python 3.12. Patch by Jelle Zijlstra. +- Backport [CPython PR 26067](https://github.com/python/cpython/pull/26067) + (originally by Yurii Karabas), ensuring that `isinstance()` calls on + protocols raise `TypeError` when the protocol is not decorated with + `@runtime_checkable`. Patch by Alex Waygood. # Release 4.5.0 (February 14, 2023) diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 9092eae9..b8e48af8 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -1421,6 +1421,22 @@ class E(C, BP): pass self.assertNotIsInstance(D(), E) self.assertNotIsInstance(E(), D) + @skipUnless( + hasattr(typing, "Protocol"), + "Test is only relevant if typing.Protocol exists" + ) + def test_runtimecheckable_on_typing_dot_Protocol(self): + @runtime_checkable + class Foo(typing.Protocol): + x: int + + class Bar: + def __init__(self): + self.x = 42 + + self.assertIsInstance(Bar(), Foo) + self.assertNotIsInstance(object(), Foo) + def test_no_instantiation(self): class P(Protocol): pass with self.assertRaises(TypeError): @@ -1829,11 +1845,7 @@ def meth(self): self.assertTrue(P._is_protocol) self.assertTrue(PR._is_protocol) self.assertTrue(PG._is_protocol) - if hasattr(typing, 'Protocol'): - self.assertFalse(P._is_runtime_protocol) - else: - with self.assertRaises(AttributeError): - self.assertFalse(P._is_runtime_protocol) + self.assertFalse(P._is_runtime_protocol) self.assertTrue(PR._is_runtime_protocol) self.assertTrue(PG[int]._is_protocol) self.assertEqual(typing_extensions._get_protocol_attrs(P), {'meth'}) @@ -1929,6 +1941,13 @@ class CustomProtocol(TestCase, Protocol): class CustomContextManager(typing.ContextManager, Protocol): pass + def test_non_runtime_protocol_isinstance_check(self): + class P(Protocol): + x: int + + with self.assertRaisesRegex(TypeError, "@runtime_checkable"): + isinstance(1, P) + def test_no_init_same_for_different_protocol_implementations(self): class CustomProtocolWithoutInitA(Protocol): pass @@ -3314,7 +3333,7 @@ def test_typing_extensions_defers_when_possible(self): 'is_typeddict', } if sys.version_info < (3, 10): - exclude |= {'get_args', 'get_origin'} + exclude |= {'get_args', 'get_origin', 'Protocol', 'runtime_checkable'} if sys.version_info < (3, 11): exclude |= {'final', 'NamedTuple', 'Any'} for item in typing_extensions.__all__: diff --git a/src/typing_extensions.py b/src/typing_extensions.py index 9ef87b97..6527cdb6 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -398,6 +398,25 @@ def clear_overloads(): } +_EXCLUDED_ATTRS = { + "__abstractmethods__", "__annotations__", "__weakref__", "_is_protocol", + "_is_runtime_protocol", "__dict__", "__slots__", "__parameters__", + "__orig_bases__", "__module__", "_MutableMapping__marker", "__doc__", + "__subclasshook__", "__orig_class__", "__init__", "__new__", +} + +if sys.version_info < (3, 8): + _EXCLUDED_ATTRS |= { + "_gorg", "__next_in_mro__", "__extra__", "__tree_hash__", "__args__", + "__origin__" + } + +if sys.version_info >= (3, 9): + _EXCLUDED_ATTRS.add("__class_getitem__") + +_EXCLUDED_ATTRS = frozenset(_EXCLUDED_ATTRS) + + def _get_protocol_attrs(cls): attrs = set() for base in cls.__mro__[:-1]: # without object @@ -405,14 +424,7 @@ def _get_protocol_attrs(cls): continue annotations = getattr(base, '__annotations__', {}) for attr in list(base.__dict__.keys()) + list(annotations.keys()): - if (not attr.startswith('_abc_') and attr not in ( - '__abstractmethods__', '__annotations__', '__weakref__', - '_is_protocol', '_is_runtime_protocol', '__dict__', - '__args__', '__slots__', - '__next_in_mro__', '__parameters__', '__origin__', - '__orig_bases__', '__extra__', '__tree_hash__', - '__doc__', '__subclasshook__', '__init__', '__new__', - '__module__', '_MutableMapping__marker', '_gorg')): + if (not attr.startswith('_abc_') and attr not in _EXCLUDED_ATTRS): attrs.add(attr) return attrs @@ -468,11 +480,18 @@ def _caller(depth=2): return None -# 3.8+ -if hasattr(typing, 'Protocol'): +# A bug in runtime-checkable protocols was fixed in 3.10+, +# but we backport it to all versions +if sys.version_info >= (3, 10): Protocol = typing.Protocol -# 3.7 + runtime_checkable = typing.runtime_checkable else: + def _allow_reckless_class_checks(depth=4): + """Allow instance and class checks for special stdlib modules. + The abc and functools modules indiscriminately call isinstance() and + issubclass() on the whole MRO of a user class, which may contain protocols. + """ + return _caller(depth) in {'abc', 'functools', None} def _no_init(self, *args, **kwargs): if type(self)._is_protocol: @@ -484,11 +503,19 @@ class _ProtocolMeta(abc.ABCMeta): def __instancecheck__(cls, instance): # We need this method for situations where attributes are # assigned in __init__. - if ((not getattr(cls, '_is_protocol', False) or + is_protocol_cls = getattr(cls, "_is_protocol", False) + if ( + is_protocol_cls and + not getattr(cls, '_is_runtime_protocol', False) and + not _allow_reckless_class_checks(depth=2) + ): + raise TypeError("Instance and class checks can only be used with" + " @runtime_checkable protocols") + if ((not is_protocol_cls or _is_callable_members_only(cls)) and issubclass(instance.__class__, cls)): return True - if cls._is_protocol: + if is_protocol_cls: if all(hasattr(instance, attr) and (not callable(getattr(cls, attr, None)) or getattr(instance, attr) is not None) @@ -530,6 +557,7 @@ def meth(self) -> T: """ __slots__ = () _is_protocol = True + _is_runtime_protocol = False def __new__(cls, *args, **kwds): if cls is Protocol: @@ -581,12 +609,12 @@ def _proto_hook(other): if not cls.__dict__.get('_is_protocol', None): return NotImplemented if not getattr(cls, '_is_runtime_protocol', False): - if _caller(depth=3) in {'abc', 'functools'}: + if _allow_reckless_class_checks(): return NotImplemented raise TypeError("Instance and class checks can only be used with" " @runtime protocols") if not _is_callable_members_only(cls): - if _caller(depth=3) in {'abc', 'functools'}: + if _allow_reckless_class_checks(): return NotImplemented raise TypeError("Protocols with non-method members" " don't support issubclass()") @@ -625,12 +653,6 @@ def _proto_hook(other): f' protocols, got {repr(base)}') cls.__init__ = _no_init - -# 3.8+ -if hasattr(typing, 'runtime_checkable'): - runtime_checkable = typing.runtime_checkable -# 3.7 -else: def runtime_checkable(cls): """Mark a protocol class as a runtime protocol, so that it can be used with isinstance() and issubclass(). Raise TypeError @@ -639,7 +661,10 @@ def runtime_checkable(cls): This allows a simple-minded structural check very similar to the one-offs in collections.abc such as Hashable. """ - if not isinstance(cls, _ProtocolMeta) or not cls._is_protocol: + if not ( + (isinstance(cls, _ProtocolMeta) or issubclass(cls, typing.Generic)) + and getattr(cls, "_is_protocol", False) + ): raise TypeError('@runtime_checkable can be only applied to protocol classes,' f' got {cls!r}') cls._is_runtime_protocol = True