Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions mypyc/codegen/emitclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def emit_line() -> None:
if cl.is_trait:
generate_new_for_trait(cl, new_name, emitter)

generate_methods_table(cl, methods_name, emitter)
generate_methods_table(cl, methods_name, setup_name if generate_full else None, emitter)
emit_line()

flags = ["Py_TPFLAGS_DEFAULT", "Py_TPFLAGS_HEAPTYPE", "Py_TPFLAGS_BASETYPE"]
Expand Down Expand Up @@ -960,8 +960,17 @@ def generate_finalize_for_class(
emitter.emit_line("}")


def generate_methods_table(cl: ClassIR, name: str, emitter: Emitter) -> None:
def generate_methods_table(
cl: ClassIR, name: str, setup_name: str | None, emitter: Emitter
) -> None:
emitter.emit_line(f"static PyMethodDef {name}[] = {{")
if setup_name:
# Store pointer to the setup function so it can be resolved dynamically
# in case of instance creation in __new__.
# CPy_SetupObject expects this method to be the first one in tp_methods.
emitter.emit_line(
f'{{"__internal_mypyc_setup", (PyCFunction){setup_name}, METH_O, NULL}},'
)
for fn in cl.methods.values():
if fn.decl.is_prop_setter or fn.decl.is_prop_getter or fn.internal:
continue
Expand Down
11 changes: 9 additions & 2 deletions mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
isinstance_dict,
)
from mypyc.primitives.float_ops import isinstance_float
from mypyc.primitives.generic_ops import generic_setattr
from mypyc.primitives.generic_ops import generic_setattr, setup_object
from mypyc.primitives.int_ops import isinstance_int
from mypyc.primitives.list_ops import isinstance_list, new_list_set_item_op
from mypyc.primitives.misc_ops import isinstance_bool
Expand Down Expand Up @@ -1103,7 +1103,14 @@ def translate_object_new(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
method_args = fn.fitem.arg_names
if isinstance(typ_arg, NameExpr) and len(method_args) > 0 and method_args[0] == typ_arg.name:
subtype = builder.accept(expr.args[0])
return builder.add(Call(ir.setup, [subtype], expr.line))
subs = ir.subclasses()
if subs is not None and len(subs) == 0:
return builder.add(Call(ir.setup, [subtype], expr.line))
# Call a function that dynamically resolves the setup function of extension classes from the type object.
# This is necessary because the setup involves default attribute initialization and setting up
# the vtable which are specific to a given type and will not work if a subtype is created using
# the setup function of its base.
return builder.call_c(setup_object, [subtype], expr.line)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment here explaining why a dynamic method dispatch is needed here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a fast path if class is known to not have subclasses -- we would still be able to use the old direct call, right?


return None

Expand Down
2 changes: 2 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,8 @@ static inline int CPyObject_GenericSetAttr(PyObject *self, PyObject *name, PyObj
return _PyObject_GenericSetAttrWithDict(self, name, value, NULL);
}

PyObject *CPy_SetupObject(PyObject *type);

#if CPY_3_11_FEATURES
PyObject *CPy_GetName(PyObject *obj);
#endif
Expand Down
20 changes: 20 additions & 0 deletions mypyc/lib-rt/generic_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,23 @@ PyObject *CPyObject_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
Py_DECREF(slice);
return result;
}

typedef PyObject *(*SetupFunction)(PyObject *);

PyObject *CPy_SetupObject(PyObject *type) {
PyTypeObject *tp = (PyTypeObject *)type;
PyMethodDef *def = NULL;
for(; tp; tp = tp->tp_base) {
def = tp->tp_methods;
if (!def || !def->ml_name) {
continue;
}

if (!strcmp(def->ml_name, "__internal_mypyc_setup")) {
return ((SetupFunction)(void(*)(void))def->ml_meth)(type);
}
}

PyErr_SetString(PyExc_RuntimeError, "Internal mypyc error: Unable to find object setup function");
return NULL;
}
7 changes: 7 additions & 0 deletions mypyc/primitives/generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,10 @@
c_function_name="CPyObject_GenericSetAttr",
error_kind=ERR_NEG_INT,
)

setup_object = custom_op(
arg_types=[object_rprimitive],
return_type=object_rprimitive,
c_function_name="CPy_SetupObject",
error_kind=ERR_MAGIC,
)
28 changes: 28 additions & 0 deletions mypyc/test-data/irbuild-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,13 @@ class Test:
obj.val = val
return obj

class Test2:
def __new__(cls) -> Test2:
return super().__new__(cls)

class Sub(Test2):
pass

def fn() -> Test:
return Test.__new__(Test, 42)

Expand Down Expand Up @@ -1719,6 +1726,13 @@ L0:
obj = r0
obj.val = val; r1 = is_error
return obj
def Test2.__new__(cls):
cls, r0 :: object
r1 :: __main__.Test2
L0:
r0 = CPy_SetupObject(cls)
r1 = cast(__main__.Test2, r0)
return r1
def fn():
r0 :: object
r1 :: __main__.Test
Expand Down Expand Up @@ -1822,6 +1836,13 @@ class Test:
obj.val = val
return obj

class Test2:
def __new__(cls) -> Test2:
return object.__new__(cls)

class Sub(Test2):
pass

def fn() -> Test:
return Test.__new__(Test, 42)

Expand Down Expand Up @@ -1874,6 +1895,13 @@ L0:
obj = r0
obj.val = val; r1 = is_error
return obj
def Test2.__new__(cls):
cls, r0 :: object
r1 :: __main__.Test2
L0:
r0 = CPy_SetupObject(cls)
r1 = cast(__main__.Test2, r0)
return r1
def fn():
r0 :: object
r1 :: __main__.Test
Expand Down
55 changes: 55 additions & 0 deletions mypyc/test-data/run-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -3859,6 +3859,7 @@ Add(1, 0)=1
[case testInheritedDunderNew]
from __future__ import annotations
from mypy_extensions import mypyc_attr
from testutil import assertRaises
from typing_extensions import Self

from m import interpreted_subclass
Expand All @@ -3875,19 +3876,32 @@ class Base:
def __init__(self, val: int) -> None:
self.init_val = val

def method(self) -> int:
raise NotImplementedError

class Sub(Base):

def __new__(cls, val: int) -> Self:
return super().__new__(cls, val + 1)

def __init__(self, val: int) -> None:
super().__init__(val)
self.init_val = self.init_val * 2

def method(self) -> int:
return 0

class SubWithoutNew(Base):
sub_only_str = ""
sub_only_int: int

def __init__(self, val: int) -> None:
super().__init__(val)
self.init_val = self.init_val * 2

def method(self) -> int:
return 1

class BaseWithoutInterpretedSubclasses:
val: int

Expand All @@ -3899,6 +3913,9 @@ class BaseWithoutInterpretedSubclasses:
def __init__(self, val: int) -> None:
self.init_val = val

def method(self) -> int:
raise NotImplementedError

class SubNoInterpreted(BaseWithoutInterpretedSubclasses):
def __new__(cls, val: int) -> Self:
return super().__new__(cls, val + 1)
Expand All @@ -3907,55 +3924,77 @@ class SubNoInterpreted(BaseWithoutInterpretedSubclasses):
super().__init__(val)
self.init_val = self.init_val * 2

def method(self) -> int:
return 0

class SubNoInterpretedWithoutNew(BaseWithoutInterpretedSubclasses):
def __init__(self, val: int) -> None:
super().__init__(val)
self.init_val = self.init_val * 2

def method(self) -> int:
return 1

def test_inherited_dunder_new() -> None:
b = Base(42)
assert type(b) == Base
assert b.val == 43
assert b.init_val == 42
with assertRaises(NotImplementedError):
b.method()

s = Sub(42)
assert type(s) == Sub
assert s.val == 44
assert s.init_val == 84
assert s.method() == 0

s2 = SubWithoutNew(42)
assert type(s2) == SubWithoutNew
assert s2.val == 43
assert s2.init_val == 84
assert s2.method() == 1
assert s2.sub_only_str == ""
with assertRaises(AttributeError):
s2.sub_only_int
s2.sub_only_int = 11
assert s2.sub_only_int == 11

def test_inherited_dunder_new_without_interpreted_subclasses() -> None:
b = BaseWithoutInterpretedSubclasses(42)
assert type(b) == BaseWithoutInterpretedSubclasses
assert b.val == 43
assert b.init_val == 42
with assertRaises(NotImplementedError):
b.method()

s = SubNoInterpreted(42)
assert type(s) == SubNoInterpreted
assert s.val == 44
assert s.init_val == 84
assert s.method() == 0

s2 = SubNoInterpretedWithoutNew(42)
assert type(s2) == SubNoInterpretedWithoutNew
assert s2.val == 43
assert s2.init_val == 84
assert s2.method() == 1

def test_interpreted_subclass() -> None:
interpreted_subclass(Base)

[file m.py]
from __future__ import annotations
from testutil import assertRaises
from typing_extensions import Self

def interpreted_subclass(base) -> None:
b = base(42)
assert type(b) == base
assert b.val == 43
assert b.init_val == 42
with assertRaises(NotImplementedError):
b.method()

class InterpretedSub(base):
def __new__(cls, val: int) -> Self:
Expand All @@ -3965,20 +4004,36 @@ def interpreted_subclass(base) -> None:
super().__init__(val)
self.init_val : int = self.init_val * 2

def method(self) -> int:
return 3

s = InterpretedSub(42)
assert type(s) == InterpretedSub
assert s.val == 44
assert s.init_val == 84
assert s.method() == 3

class InterpretedSubWithoutNew(base):
sub_only_str = ""
sub_only_int: int

def __init__(self, val: int) -> None:
super().__init__(val)
self.init_val : int = self.init_val * 2

def method(self) -> int:
return 4

s2 = InterpretedSubWithoutNew(42)
assert type(s2) == InterpretedSubWithoutNew
assert s2.val == 43
assert s2.init_val == 84
assert s2.method() == 4
assert s2.sub_only_str == ""
with assertRaises(AttributeError):
s2.sub_only_int
s2.sub_only_int = 11
assert s2.sub_only_int == 11

[typing fixtures/typing-full.pyi]

Expand Down