Skip to content

Commit 9b415a3

Browse files
committed
great fun exposing boost::unordered_map for CostStack :D
1 parent ac60bd1 commit 9b415a3

File tree

2 files changed

+48
-34
lines changed

2 files changed

+48
-34
lines changed

bindings/python/src/modelling/expose-cost-stack.cpp

Lines changed: 44 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "aligator/modelling/costs/sum-of-costs.hpp"
55

66
#include <eigenpy/std-pair.hpp>
7+
#include <eigenpy/std-map.hpp>
78
#include <eigenpy/variant.hpp>
89

910
namespace aligator {
@@ -23,40 +24,50 @@ void exposeCostStack() {
2324
eigenpy::StdPairConverter<CostItem>::registration();
2425
eigenpy::VariantConverter<CostKey>::registration();
2526

26-
bp::class_<CostStack, bp::bases<CostAbstract>>(
27-
"CostStack", "A weighted sum of other cost functions.", bp::no_init)
28-
.def(bp::init<xyz::polymorphic<Manifold>, const int,
29-
const std::vector<PolyCost> &, const std::vector<Scalar> &>(
30-
("self"_a, "space", "nu", "components"_a = bp::list(),
31-
"weights"_a = bp::list())))
32-
.def(bp::init<const PolyCost &>(("self"_a, "cost")))
33-
// .def_readwrite("components", &CostStack::components_,
34-
// "Components of this cost stack.")
35-
.def(
36-
"addCost",
37-
+[](CostStack &self, const PolyCost &cost, const Scalar weight) {
38-
// return
39-
self.addCost(cost, weight);
40-
},
41-
("self"_a, "cost", "weight"_a = 1.),
42-
bp::return_internal_reference<>())
43-
.def(
44-
"addCost",
45-
+[](CostStack &self, CostKey key, const PolyCost &cost,
46-
const Scalar weight) {
47-
// return
48-
self.addCost(key, cost, weight);
49-
},
50-
("self"_a, "key", "cost", "weight"_a = 1.),
51-
bp::return_internal_reference<>())
52-
.def("size", &CostStack::size, "Get the number of cost components.")
53-
.def(CopyableVisitor<CostStack>())
54-
.def(PolymorphicMultiBaseVisitor<CostAbstract>());
27+
{
28+
bp::scope scope =
29+
bp::class_<CostStack, bp::bases<CostAbstract>>(
30+
"CostStack", "A weighted sum of other cost functions.", bp::no_init)
31+
.def(bp::init<xyz::polymorphic<Manifold>, const int,
32+
const std::vector<PolyCost> &,
33+
const std::vector<Scalar> &>(
34+
("self"_a, "space", "nu", "components"_a = bp::list(),
35+
"weights"_a = bp::list())))
36+
.def(bp::init<const PolyCost &>(("self"_a, "cost")))
37+
.def_readwrite("components", &CostStack::components_,
38+
"Components of this cost stack.")
39+
.def(
40+
"addCost",
41+
+[](CostStack &self, const PolyCost &cost,
42+
const Scalar weight) {
43+
// return
44+
self.addCost(cost, weight);
45+
},
46+
("self"_a, "cost", "weight"_a = 1.),
47+
bp::return_internal_reference<>())
48+
.def(
49+
"addCost",
50+
+[](CostStack &self, CostKey key, const PolyCost &cost,
51+
const Scalar weight) {
52+
// return
53+
self.addCost(key, cost, weight);
54+
},
55+
("self"_a, "key", "cost", "weight"_a = 1.),
56+
bp::return_internal_reference<>())
57+
.def("size", &CostStack::size, "Get the number of cost components.")
58+
.def(CopyableVisitor<CostStack>())
59+
.def(PolymorphicMultiBaseVisitor<CostAbstract>());
60+
eigenpy::GenericMapVisitor<CostMap, true>::expose("CostMap");
61+
}
5562

56-
bp::register_ptr_to_python<shared_ptr<CostStackData>>();
57-
bp::class_<CostStackData, bp::bases<CostData>>(
58-
"CostStackData", "Data struct for CostStack.", bp::no_init)
59-
.def_readonly("sub_cost_data", &CostStackData::sub_cost_data);
63+
{
64+
bp::register_ptr_to_python<shared_ptr<CostStackData>>();
65+
bp::scope scope =
66+
bp::class_<CostStackData, bp::bases<CostData>>(
67+
"CostStackData", "Data struct for CostStack.", bp::no_init)
68+
.def_readonly("sub_cost_data", &CostStackData::sub_cost_data);
69+
eigenpy::GenericMapVisitor<CostStackData::DataMap, true>::expose("DataMap");
70+
}
6071
}
6172

6273
} // namespace python

tests/python/test_costs.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,14 @@ def test_stack_error():
145145
rcost = QuadraticCost(Q, R)
146146
cost_stack.addCost(rcost) # optional
147147

148+
cost_stack.components
149+
print(cost_stack.components.todict())
150+
148151
rc2 = QuadraticCost(np.eye(3), np.eye(nu))
149152
rc3 = QuadraticCost(np.eye(nx), np.eye(nu * 2))
150153

151154
cost_data = cost_stack.createData()
152-
print(cost_data.sub_cost_data.tolist())
155+
print(cost_data.sub_cost_data.todict())
153156

154157
with pytest.raises(Exception) as e_info:
155158
cost_stack.addCost(rc2)

0 commit comments

Comments
 (0)