Skip to content

Commit 1b94fbb

Browse files
[mypyc] Fix vtable pointer with inherited dunder new (#20302)
Fixes an issue where a subclass would have its vtable pointer set to the base class' vtable when there is a `__new__` method defined in the base class. This resulted in the subclass constructor calling the setup function of the base class because mypyc transforms `object.__new__` into the setup function. The fix is to store the pointers to the setup functions in `tp_methods` of type objects and look them up dynamically when instantiating new objects. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 13369cb commit 1b94fbb

File tree

7 files changed

+132
-4
lines changed

7 files changed

+132
-4
lines changed

mypyc/codegen/emitclass.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ def emit_line() -> None:
359359
if cl.is_trait:
360360
generate_new_for_trait(cl, new_name, emitter)
361361

362-
generate_methods_table(cl, methods_name, emitter)
362+
generate_methods_table(cl, methods_name, setup_name if generate_full else None, emitter)
363363
emit_line()
364364

365365
flags = ["Py_TPFLAGS_DEFAULT", "Py_TPFLAGS_HEAPTYPE", "Py_TPFLAGS_BASETYPE"]
@@ -960,8 +960,17 @@ def generate_finalize_for_class(
960960
emitter.emit_line("}")
961961

962962

963-
def generate_methods_table(cl: ClassIR, name: str, emitter: Emitter) -> None:
963+
def generate_methods_table(
964+
cl: ClassIR, name: str, setup_name: str | None, emitter: Emitter
965+
) -> None:
964966
emitter.emit_line(f"static PyMethodDef {name}[] = {{")
967+
if setup_name:
968+
# Store pointer to the setup function so it can be resolved dynamically
969+
# in case of instance creation in __new__.
970+
# CPy_SetupObject expects this method to be the first one in tp_methods.
971+
emitter.emit_line(
972+
f'{{"__internal_mypyc_setup", (PyCFunction){setup_name}, METH_O, NULL}},'
973+
)
965974
for fn in cl.methods.values():
966975
if fn.decl.is_prop_setter or fn.decl.is_prop_getter or fn.internal:
967976
continue

mypyc/irbuild/specialize.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999
isinstance_dict,
100100
)
101101
from mypyc.primitives.float_ops import isinstance_float
102-
from mypyc.primitives.generic_ops import generic_setattr
102+
from mypyc.primitives.generic_ops import generic_setattr, setup_object
103103
from mypyc.primitives.int_ops import isinstance_int
104104
from mypyc.primitives.list_ops import isinstance_list, new_list_set_item_op
105105
from mypyc.primitives.misc_ops import isinstance_bool
@@ -1103,7 +1103,14 @@ def translate_object_new(builder: IRBuilder, expr: CallExpr, callee: RefExpr) ->
11031103
method_args = fn.fitem.arg_names
11041104
if isinstance(typ_arg, NameExpr) and len(method_args) > 0 and method_args[0] == typ_arg.name:
11051105
subtype = builder.accept(expr.args[0])
1106-
return builder.add(Call(ir.setup, [subtype], expr.line))
1106+
subs = ir.subclasses()
1107+
if subs is not None and len(subs) == 0:
1108+
return builder.add(Call(ir.setup, [subtype], expr.line))
1109+
# Call a function that dynamically resolves the setup function of extension classes from the type object.
1110+
# This is necessary because the setup involves default attribute initialization and setting up
1111+
# the vtable which are specific to a given type and will not work if a subtype is created using
1112+
# the setup function of its base.
1113+
return builder.call_c(setup_object, [subtype], expr.line)
11071114

11081115
return None
11091116

mypyc/lib-rt/CPy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,8 @@ static inline int CPyObject_GenericSetAttr(PyObject *self, PyObject *name, PyObj
958958
return _PyObject_GenericSetAttrWithDict(self, name, value, NULL);
959959
}
960960

961+
PyObject *CPy_SetupObject(PyObject *type);
962+
961963
#if CPY_3_11_FEATURES
962964
PyObject *CPy_GetName(PyObject *obj);
963965
#endif

mypyc/lib-rt/generic_ops.c

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,23 @@ PyObject *CPyObject_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end) {
6262
Py_DECREF(slice);
6363
return result;
6464
}
65+
66+
typedef PyObject *(*SetupFunction)(PyObject *);
67+
68+
PyObject *CPy_SetupObject(PyObject *type) {
69+
PyTypeObject *tp = (PyTypeObject *)type;
70+
PyMethodDef *def = NULL;
71+
for(; tp; tp = tp->tp_base) {
72+
def = tp->tp_methods;
73+
if (!def || !def->ml_name) {
74+
continue;
75+
}
76+
77+
if (!strcmp(def->ml_name, "__internal_mypyc_setup")) {
78+
return ((SetupFunction)(void(*)(void))def->ml_meth)(type);
79+
}
80+
}
81+
82+
PyErr_SetString(PyExc_RuntimeError, "Internal mypyc error: Unable to find object setup function");
83+
return NULL;
84+
}

mypyc/primitives/generic_ops.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,3 +417,10 @@
417417
c_function_name="CPyObject_GenericSetAttr",
418418
error_kind=ERR_NEG_INT,
419419
)
420+
421+
setup_object = custom_op(
422+
arg_types=[object_rprimitive],
423+
return_type=object_rprimitive,
424+
c_function_name="CPy_SetupObject",
425+
error_kind=ERR_MAGIC,
426+
)

mypyc/test-data/irbuild-classes.test

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1685,6 +1685,13 @@ class Test:
16851685
obj.val = val
16861686
return obj
16871687

1688+
class Test2:
1689+
def __new__(cls) -> Test2:
1690+
return super().__new__(cls)
1691+
1692+
class Sub(Test2):
1693+
pass
1694+
16881695
def fn() -> Test:
16891696
return Test.__new__(Test, 42)
16901697

@@ -1719,6 +1726,13 @@ L0:
17191726
obj = r0
17201727
obj.val = val; r1 = is_error
17211728
return obj
1729+
def Test2.__new__(cls):
1730+
cls, r0 :: object
1731+
r1 :: __main__.Test2
1732+
L0:
1733+
r0 = CPy_SetupObject(cls)
1734+
r1 = cast(__main__.Test2, r0)
1735+
return r1
17221736
def fn():
17231737
r0 :: object
17241738
r1 :: __main__.Test
@@ -1822,6 +1836,13 @@ class Test:
18221836
obj.val = val
18231837
return obj
18241838

1839+
class Test2:
1840+
def __new__(cls) -> Test2:
1841+
return object.__new__(cls)
1842+
1843+
class Sub(Test2):
1844+
pass
1845+
18251846
def fn() -> Test:
18261847
return Test.__new__(Test, 42)
18271848

@@ -1874,6 +1895,13 @@ L0:
18741895
obj = r0
18751896
obj.val = val; r1 = is_error
18761897
return obj
1898+
def Test2.__new__(cls):
1899+
cls, r0 :: object
1900+
r1 :: __main__.Test2
1901+
L0:
1902+
r0 = CPy_SetupObject(cls)
1903+
r1 = cast(__main__.Test2, r0)
1904+
return r1
18771905
def fn():
18781906
r0 :: object
18791907
r1 :: __main__.Test

mypyc/test-data/run-classes.test

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3859,6 +3859,7 @@ Add(1, 0)=1
38593859
[case testInheritedDunderNew]
38603860
from __future__ import annotations
38613861
from mypy_extensions import mypyc_attr
3862+
from testutil import assertRaises
38623863
from typing_extensions import Self
38633864

38643865
from m import interpreted_subclass
@@ -3875,19 +3876,32 @@ class Base:
38753876
def __init__(self, val: int) -> None:
38763877
self.init_val = val
38773878

3879+
def method(self) -> int:
3880+
raise NotImplementedError
3881+
38783882
class Sub(Base):
3883+
38793884
def __new__(cls, val: int) -> Self:
38803885
return super().__new__(cls, val + 1)
38813886

38823887
def __init__(self, val: int) -> None:
38833888
super().__init__(val)
38843889
self.init_val = self.init_val * 2
38853890

3891+
def method(self) -> int:
3892+
return 0
3893+
38863894
class SubWithoutNew(Base):
3895+
sub_only_str = ""
3896+
sub_only_int: int
3897+
38873898
def __init__(self, val: int) -> None:
38883899
super().__init__(val)
38893900
self.init_val = self.init_val * 2
38903901

3902+
def method(self) -> int:
3903+
return 1
3904+
38913905
class BaseWithoutInterpretedSubclasses:
38923906
val: int
38933907

@@ -3899,6 +3913,9 @@ class BaseWithoutInterpretedSubclasses:
38993913
def __init__(self, val: int) -> None:
39003914
self.init_val = val
39013915

3916+
def method(self) -> int:
3917+
raise NotImplementedError
3918+
39023919
class SubNoInterpreted(BaseWithoutInterpretedSubclasses):
39033920
def __new__(cls, val: int) -> Self:
39043921
return super().__new__(cls, val + 1)
@@ -3907,55 +3924,77 @@ class SubNoInterpreted(BaseWithoutInterpretedSubclasses):
39073924
super().__init__(val)
39083925
self.init_val = self.init_val * 2
39093926

3927+
def method(self) -> int:
3928+
return 0
3929+
39103930
class SubNoInterpretedWithoutNew(BaseWithoutInterpretedSubclasses):
39113931
def __init__(self, val: int) -> None:
39123932
super().__init__(val)
39133933
self.init_val = self.init_val * 2
39143934

3935+
def method(self) -> int:
3936+
return 1
3937+
39153938
def test_inherited_dunder_new() -> None:
39163939
b = Base(42)
39173940
assert type(b) == Base
39183941
assert b.val == 43
39193942
assert b.init_val == 42
3943+
with assertRaises(NotImplementedError):
3944+
b.method()
39203945

39213946
s = Sub(42)
39223947
assert type(s) == Sub
39233948
assert s.val == 44
39243949
assert s.init_val == 84
3950+
assert s.method() == 0
39253951

39263952
s2 = SubWithoutNew(42)
39273953
assert type(s2) == SubWithoutNew
39283954
assert s2.val == 43
39293955
assert s2.init_val == 84
3956+
assert s2.method() == 1
3957+
assert s2.sub_only_str == ""
3958+
with assertRaises(AttributeError):
3959+
s2.sub_only_int
3960+
s2.sub_only_int = 11
3961+
assert s2.sub_only_int == 11
39303962

39313963
def test_inherited_dunder_new_without_interpreted_subclasses() -> None:
39323964
b = BaseWithoutInterpretedSubclasses(42)
39333965
assert type(b) == BaseWithoutInterpretedSubclasses
39343966
assert b.val == 43
39353967
assert b.init_val == 42
3968+
with assertRaises(NotImplementedError):
3969+
b.method()
39363970

39373971
s = SubNoInterpreted(42)
39383972
assert type(s) == SubNoInterpreted
39393973
assert s.val == 44
39403974
assert s.init_val == 84
3975+
assert s.method() == 0
39413976

39423977
s2 = SubNoInterpretedWithoutNew(42)
39433978
assert type(s2) == SubNoInterpretedWithoutNew
39443979
assert s2.val == 43
39453980
assert s2.init_val == 84
3981+
assert s2.method() == 1
39463982

39473983
def test_interpreted_subclass() -> None:
39483984
interpreted_subclass(Base)
39493985

39503986
[file m.py]
39513987
from __future__ import annotations
3988+
from testutil import assertRaises
39523989
from typing_extensions import Self
39533990

39543991
def interpreted_subclass(base) -> None:
39553992
b = base(42)
39563993
assert type(b) == base
39573994
assert b.val == 43
39583995
assert b.init_val == 42
3996+
with assertRaises(NotImplementedError):
3997+
b.method()
39593998

39603999
class InterpretedSub(base):
39614000
def __new__(cls, val: int) -> Self:
@@ -3965,20 +4004,36 @@ def interpreted_subclass(base) -> None:
39654004
super().__init__(val)
39664005
self.init_val : int = self.init_val * 2
39674006

4007+
def method(self) -> int:
4008+
return 3
4009+
39684010
s = InterpretedSub(42)
39694011
assert type(s) == InterpretedSub
39704012
assert s.val == 44
39714013
assert s.init_val == 84
4014+
assert s.method() == 3
39724015

39734016
class InterpretedSubWithoutNew(base):
4017+
sub_only_str = ""
4018+
sub_only_int: int
4019+
39744020
def __init__(self, val: int) -> None:
39754021
super().__init__(val)
39764022
self.init_val : int = self.init_val * 2
39774023

4024+
def method(self) -> int:
4025+
return 4
4026+
39784027
s2 = InterpretedSubWithoutNew(42)
39794028
assert type(s2) == InterpretedSubWithoutNew
39804029
assert s2.val == 43
39814030
assert s2.init_val == 84
4031+
assert s2.method() == 4
4032+
assert s2.sub_only_str == ""
4033+
with assertRaises(AttributeError):
4034+
s2.sub_only_int
4035+
s2.sub_only_int = 11
4036+
assert s2.sub_only_int == 11
39824037

39834038
[typing fixtures/typing-full.pyi]
39844039

0 commit comments

Comments
 (0)