Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions include/pybind11/stl.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,14 @@ struct set_caster {
public:
bool load(handle src, bool convert) {
if (!PyObjectTypeIsConvertibleToStdSet(src.ptr())) {
return false;
if (!convert) {
return false;
}
if (!(isinstance(src, module_::import("collections.abc").attr("Set"))
&& hasattr(src, "__contains__") && hasattr(src, "__iter__")
&& hasattr(src, "__len__"))) {
return false;
}
}
if (isinstance<anyset>(src)) {
value.clear();
Expand Down Expand Up @@ -203,7 +210,9 @@ struct set_caster {
return s.release();
}

PYBIND11_TYPE_CASTER(type, const_name("set[") + key_conv::name + const_name("]"));
PYBIND11_TYPE_CASTER(type,
io_name("collections.abc.Set", "set") + const_name("[") + key_conv::name
+ const_name("]"));
};

template <typename Type, typename Key, typename Value>
Expand Down Expand Up @@ -235,7 +244,14 @@ struct map_caster {
public:
bool load(handle src, bool convert) {
if (!PyObjectTypeIsConvertibleToStdMap(src.ptr())) {
return false;
if (!convert) {
return false;
}
if (!(isinstance(src, module_::import("collections.abc").attr("Mapping"))
&& hasattr(src, "__getitem__") && hasattr(src, "__iter__")
&& hasattr(src, "__len__"))) {
return false;
}
}
if (isinstance<dict>(src)) {
return convert_elements(reinterpret_borrow<dict>(src), convert);
Expand Down Expand Up @@ -274,7 +290,8 @@ struct map_caster {
}

PYBIND11_TYPE_CASTER(Type,
const_name("dict[") + key_conv::name + const_name(", ") + value_conv::name
io_name("collections.abc.Mapping", "dict") + const_name("[")
+ key_conv::name + const_name(", ") + value_conv::name
+ const_name("]"));
};

Expand Down Expand Up @@ -340,7 +357,9 @@ struct list_caster {
return l.release();
}

PYBIND11_TYPE_CASTER(Type, const_name("list[") + value_conv::name + const_name("]"));
PYBIND11_TYPE_CASTER(Type,
io_name("collections.abc.Sequence", "list") + const_name("[")
+ value_conv::name + const_name("]"));
};

template <typename Type, typename Alloc>
Expand Down Expand Up @@ -474,10 +493,12 @@ struct array_caster {
using cast_op_type = movable_cast_op_type<T_>;

static constexpr auto name
= const_name<Resizable>(const_name(""), const_name("Annotated[")) + const_name("list[")
+ value_conv::name + const_name("]")
+ const_name<Resizable>(
const_name(""), const_name(", FixedSize(") + const_name<Size>() + const_name(")]"));
= const_name<Resizable>(const_name(""), const_name("typing.Annotated["))
+ io_name("collections.abc.Sequence", "list") + const_name("[") + value_conv::name
+ const_name("]")
+ const_name<Resizable>(const_name(""),
const_name(", \"FixedSize(") + const_name<Size>()
+ const_name(")\"]"));
};

template <typename Type, size_t Size>
Expand Down
2 changes: 1 addition & 1 deletion tests/test_kwargs_and_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_function_signatures(doc):
assert doc(m.kw_func3) == "kw_func3(data: str = 'Hello world!') -> None"
assert (
doc(m.kw_func4)
== "kw_func4(myList: list[typing.SupportsInt] = [13, 17]) -> str"
== "kw_func4(myList: collections.abc.Sequence[typing.SupportsInt] = [13, 17]) -> str"
)
assert (
doc(m.kw_func_udl)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,7 @@ def test_arg_return_type_hints(doc):
# std::vector<T>
assert (
doc(m.half_of_number_vector)
== "half_of_number_vector(arg0: list[Union[float, int]]) -> list[float]"
== "half_of_number_vector(arg0: collections.abc.Sequence[Union[float, int]]) -> list[float]"
)
# Tuple<T, T>
assert (
Expand Down
3 changes: 3 additions & 0 deletions tests/test_stl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,4 +648,7 @@ TEST_SUBMODULE(stl, m) {
}
return zum;
});
m.def("roundtrip_std_vector_int", [](const std::vector<int> &v) { return v; });
m.def("roundtrip_std_map_str_int", [](const std::map<std::string, int> &m) { return m; });
m.def("roundtrip_std_set_int", [](const std::set<int> &s) { return s; });
}
153 changes: 144 additions & 9 deletions tests/test_stl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def test_vector(doc):
assert m.load_bool_vector((True, False))

assert doc(m.cast_vector) == "cast_vector() -> list[int]"
assert doc(m.load_vector) == "load_vector(arg0: list[typing.SupportsInt]) -> bool"
assert (
doc(m.load_vector)
== "load_vector(arg0: collections.abc.Sequence[typing.SupportsInt]) -> bool"
)

# Test regression caused by 936: pointers to stl containers weren't castable
assert m.cast_ptr_vector() == ["lvalue", "lvalue"]
Expand All @@ -42,10 +45,13 @@ def test_array(doc):
assert m.load_array(lst)
assert m.load_array(tuple(lst))

assert doc(m.cast_array) == "cast_array() -> Annotated[list[int], FixedSize(2)]"
assert (
doc(m.cast_array)
== 'cast_array() -> typing.Annotated[list[int], "FixedSize(2)"]'
)
assert (
doc(m.load_array)
== "load_array(arg0: Annotated[list[typing.SupportsInt], FixedSize(2)]) -> bool"
== 'load_array(arg0: typing.Annotated[collections.abc.Sequence[typing.SupportsInt], "FixedSize(2)"]) -> bool'
)


Expand All @@ -65,7 +71,8 @@ def test_valarray(doc):

assert doc(m.cast_valarray) == "cast_valarray() -> list[int]"
assert (
doc(m.load_valarray) == "load_valarray(arg0: list[typing.SupportsInt]) -> bool"
doc(m.load_valarray)
== "load_valarray(arg0: collections.abc.Sequence[typing.SupportsInt]) -> bool"
)


Expand All @@ -79,7 +86,9 @@ def test_map(doc):
assert m.load_map(d)

assert doc(m.cast_map) == "cast_map() -> dict[str, str]"
assert doc(m.load_map) == "load_map(arg0: dict[str, str]) -> bool"
assert (
doc(m.load_map) == "load_map(arg0: collections.abc.Mapping[str, str]) -> bool"
)


def test_set(doc):
Expand All @@ -91,7 +100,7 @@ def test_set(doc):
assert m.load_set(frozenset(s))

assert doc(m.cast_set) == "cast_set() -> set[str]"
assert doc(m.load_set) == "load_set(arg0: set[str]) -> bool"
assert doc(m.load_set) == "load_set(arg0: collections.abc.Set[str]) -> bool"


def test_recursive_casting():
Expand Down Expand Up @@ -273,7 +282,7 @@ def __fspath__(self):
assert m.parent_paths(["foo/bar", "foo/baz"]) == [Path("foo"), Path("foo")]
assert (
doc(m.parent_paths)
== "parent_paths(arg0: list[Union[os.PathLike, str, bytes]]) -> list[pathlib.Path]"
== "parent_paths(arg0: collections.abc.Sequence[Union[os.PathLike, str, bytes]]) -> list[pathlib.Path]"
)
# py::typing::List
assert m.parent_paths_list(["foo/bar", "foo/baz"]) == [Path("foo"), Path("foo")]
Expand Down Expand Up @@ -364,7 +373,7 @@ def test_stl_pass_by_pointer(msg):
msg(excinfo.value)
== """
stl_pass_by_pointer(): incompatible function arguments. The following argument types are supported:
1. (v: list[typing.SupportsInt] = None) -> list[int]
1. (v: collections.abc.Sequence[typing.SupportsInt] = None) -> list[int]

Invoked with:
"""
Expand All @@ -376,7 +385,7 @@ def test_stl_pass_by_pointer(msg):
msg(excinfo.value)
== """
stl_pass_by_pointer(): incompatible function arguments. The following argument types are supported:
1. (v: list[typing.SupportsInt] = None) -> list[int]
1. (v: collections.abc.Sequence[typing.SupportsInt] = None) -> list[int]

Invoked with: None
"""
Expand Down Expand Up @@ -567,3 +576,129 @@ def gen_invalid():
with pytest.raises(expected_exception):
m.pass_std_map_int(FakePyMappingGenObj(gen_obj))
assert not tuple(gen_obj)


def test_sequence_caster_protocol(doc):
from collections.abc import Sequence

class SequenceLike(Sequence):
def __init__(self, *args):
self.data = tuple(args)

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index]

class FakeSequenceLike:
def __init__(self, *args):
self.data = tuple(args)

def __len__(self):
return len(self.data)

def __getitem__(self, index):
return self.data[index]

assert (
doc(m.roundtrip_std_vector_int)
== "roundtrip_std_vector_int(arg0: collections.abc.Sequence[typing.SupportsInt]) -> list[int]"
)
assert m.roundtrip_std_vector_int([1, 2, 3]) == [1, 2, 3]
assert m.roundtrip_std_vector_int((1, 2, 3)) == [1, 2, 3]
assert m.roundtrip_std_vector_int(SequenceLike(1, 2, 3)) == [1, 2, 3]
assert m.roundtrip_std_vector_int(FakeSequenceLike(1, 2, 3)) == [1, 2, 3]
assert m.roundtrip_std_vector_int([]) == []
assert m.roundtrip_std_vector_int(()) == []
assert m.roundtrip_std_vector_int(FakeSequenceLike()) == []


def test_mapping_caster_protocol(doc):
from collections.abc import Mapping

class MappingLike(Mapping):
def __init__(self, **kwargs):
self.data = dict(kwargs)

def __len__(self):
return len(self.data)

def __getitem__(self, key):
return self.data[key]

def __iter__(self):
yield from self.data

class FakeMappingLike:
def __init__(self, **kwargs):
self.data = dict(kwargs)

def __len__(self):
return len(self.data)

def __getitem__(self, key):
return self.data[key]

def __iter__(self):
yield from self.data

assert (
doc(m.roundtrip_std_map_str_int)
== "roundtrip_std_map_str_int(arg0: collections.abc.Mapping[str, typing.SupportsInt]) -> dict[str, int]"
)
assert m.roundtrip_std_map_str_int({"a": 1, "b": 2, "c": 3}) == {
"a": 1,
"b": 2,
"c": 3,
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nice to keep this more compact via

    a1b2c3 = {"a": 1, "b": 2, "c": 3}

and then reuse three times.

Copy link
Contributor Author

@timohl timohl Apr 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implemented in 6fc2ab3

I just made this change for the mapping test.
In the sequence test, there is now:

pybind11/tests/test_stl.py

Lines 619 to 623 in 383dcb5

assert m.roundtrip_std_vector_int_noconvert(FormalSequenceLike(1, 2, 3)) == [
1,
2,
3,
]

Should I change this to use a variable as well (like list123)?

assert m.roundtrip_std_map_str_int(MappingLike(a=1, b=2, c=3)) == {
"a": 1,
"b": 2,
"c": 3,
}
assert m.roundtrip_std_map_str_int({}) == {}
assert m.roundtrip_std_map_str_int(MappingLike()) == {}
with pytest.raises(TypeError):
m.roundtrip_std_map_str_int(FakeMappingLike(a=1, b=2, c=3))


def test_set_caster_protocol(doc):
from collections.abc import Set

class SetLike(Set):
def __init__(self, *args):
self.data = set(args)

def __len__(self):
return len(self.data)

def __contains__(self, item):
return item in self.data

def __iter__(self):
yield from self.data

class FakeSetLike:
def __init__(self, *args):
self.data = set(args)

def __len__(self):
return len(self.data)

def __contains__(self, item):
return item in self.data

def __iter__(self):
yield from self.data

assert (
doc(m.roundtrip_std_set_int)
== "roundtrip_std_set_int(arg0: collections.abc.Set[typing.SupportsInt]) -> set[int]"
)
assert m.roundtrip_std_set_int({1, 2, 3}) == {1, 2, 3}
assert m.roundtrip_std_set_int(SetLike(1, 2, 3)) == {1, 2, 3}
assert m.roundtrip_std_set_int(set()) == set()
assert m.roundtrip_std_set_int(SetLike()) == set()
with pytest.raises(TypeError):
m.roundtrip_std_set_int(FakeSetLike(1, 2, 3))
Loading