Skip to content

Commit 7aee1e7

Browse files
changing MPPI's SG filter to 9-point formulation (prev. 5) (ros-navigation#3444)
* changing filter to 9 * fix tests
1 parent 8d4f6f4 commit 7aee1e7

File tree

4 files changed

+112
-23
lines changed

4 files changed

+112
-23
lines changed

nav2_mppi_controller/include/nav2_mppi_controller/optimizer.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ class Optimizer
248248

249249
models::State state_;
250250
models::ControlSequence control_sequence_;
251-
std::array<mppi::models::Control, 2> control_history_;
251+
std::array<mppi::models::Control, 4> control_history_;
252252
models::Trajectories generated_trajectories_;
253253
models::Path path_;
254254
xt::xtensor<float, 1> costs_;

nav2_mppi_controller/include/nav2_mppi_controller/tools/utils.hpp

Lines changed: 96 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -436,17 +436,17 @@ inline double posePointAngle(const geometry_msgs::msg::Pose & pose, double point
436436
*/
437437
inline void savitskyGolayFilter(
438438
models::ControlSequence & control_sequence,
439-
std::array<mppi::models::Control, 2> & control_history,
439+
std::array<mppi::models::Control, 4> & control_history,
440440
const models::OptimizerSettings & settings)
441441
{
442-
// Savitzky-Golay Quadratic, 5-point Coefficients
443-
xt::xarray<float> filter = {-3.0, 12.0, 17.0, 12.0, -3.0};
444-
filter /= 35.0;
442+
// Savitzky-Golay Quadratic, 9-point Coefficients
443+
xt::xarray<float> filter = {-21.0, 14.0, 39.0, 54.0, 59.0, 54.0, 39.0, 14.0, -21.0};
444+
filter /= 231.0;
445445

446-
const unsigned int num_sequences = control_sequence.vx.shape(0);
446+
const unsigned int num_sequences = control_sequence.vx.shape(0) - 1;
447447

448448
// Too short to smooth meaningfully
449-
if (num_sequences < 10) {
449+
if (num_sequences < 20) {
450450
return;
451451
}
452452

@@ -455,64 +455,145 @@ inline void savitskyGolayFilter(
455455
};
456456

457457
auto applyFilterOverAxis =
458-
[&](xt::xtensor<float, 1> & sequence, const float hist_0, const float hist_1) -> void
458+
[&](xt::xtensor<float, 1> & sequence,
459+
const float hist_0, const float hist_1, const float hist_2, const float hist_3) -> void
459460
{
460461
unsigned int idx = 0;
461462
sequence(idx) = applyFilter(
462463
{
463464
hist_0,
464465
hist_1,
466+
hist_2,
467+
hist_3,
465468
sequence(idx),
466469
sequence(idx + 1),
467-
sequence(idx + 2)});
470+
sequence(idx + 2),
471+
sequence(idx + 3),
472+
sequence(idx + 4)});
468473

469474
idx++;
470475
sequence(idx) = applyFilter(
471476
{
472477
hist_1,
478+
hist_2,
479+
hist_3,
473480
sequence(idx - 1),
474481
sequence(idx),
475482
sequence(idx + 1),
476-
sequence(idx + 2)});
483+
sequence(idx + 2),
484+
sequence(idx + 3),
485+
sequence(idx + 4)});
486+
487+
idx++;
488+
sequence(idx) = applyFilter(
489+
{
490+
hist_2,
491+
hist_3,
492+
sequence(idx - 2),
493+
sequence(idx - 1),
494+
sequence(idx),
495+
sequence(idx + 1),
496+
sequence(idx + 2),
497+
sequence(idx + 3),
498+
sequence(idx + 4)});
477499

478-
for (idx = 2; idx != num_sequences - 3; idx++) {
500+
idx++;
501+
sequence(idx) = applyFilter(
502+
{
503+
hist_3,
504+
sequence(idx - 3),
505+
sequence(idx - 2),
506+
sequence(idx - 1),
507+
sequence(idx),
508+
sequence(idx + 1),
509+
sequence(idx + 2),
510+
sequence(idx + 3),
511+
sequence(idx + 4)});
512+
513+
for (idx = 4; idx != num_sequences - 4; idx++) {
479514
sequence(idx) = applyFilter(
480515
{
516+
sequence(idx - 4),
517+
sequence(idx - 3),
481518
sequence(idx - 2),
482519
sequence(idx - 1),
483520
sequence(idx),
484521
sequence(idx + 1),
485-
sequence(idx + 2)});
522+
sequence(idx + 2),
523+
sequence(idx + 3),
524+
sequence(idx + 4)});
486525
}
487526

488527
idx++;
489528
sequence(idx) = applyFilter(
490529
{
530+
sequence(idx - 4),
531+
sequence(idx - 3),
491532
sequence(idx - 2),
492533
sequence(idx - 1),
493534
sequence(idx),
494535
sequence(idx + 1),
536+
sequence(idx + 2),
537+
sequence(idx + 3),
538+
sequence(idx + 3)});
539+
540+
idx++;
541+
sequence(idx) = applyFilter(
542+
{
543+
sequence(idx - 4),
544+
sequence(idx - 3),
545+
sequence(idx - 2),
546+
sequence(idx - 1),
547+
sequence(idx),
548+
sequence(idx + 1),
549+
sequence(idx + 2),
550+
sequence(idx + 2),
551+
sequence(idx + 2)});
552+
553+
idx++;
554+
sequence(idx) = applyFilter(
555+
{
556+
sequence(idx - 4),
557+
sequence(idx - 3),
558+
sequence(idx - 2),
559+
sequence(idx - 1),
560+
sequence(idx),
561+
sequence(idx + 1),
562+
sequence(idx + 1),
563+
sequence(idx + 1),
495564
sequence(idx + 1)});
496565

497566
idx++;
498567
sequence(idx) = applyFilter(
499568
{
569+
sequence(idx - 4),
570+
sequence(idx - 3),
500571
sequence(idx - 2),
501572
sequence(idx - 1),
502573
sequence(idx),
503574
sequence(idx),
575+
sequence(idx),
576+
sequence(idx),
504577
sequence(idx)});
505578
};
506579

507580
// Filter trajectories
508-
applyFilterOverAxis(control_sequence.vx, control_history[0].vx, control_history[1].vx);
509-
applyFilterOverAxis(control_sequence.vy, control_history[0].vy, control_history[1].vy);
510-
applyFilterOverAxis(control_sequence.wz, control_history[0].wz, control_history[1].wz);
581+
applyFilterOverAxis(
582+
control_sequence.vx, control_history[0].vx,
583+
control_history[1].vx, control_history[2].vx, control_history[3].vx);
584+
applyFilterOverAxis(
585+
control_sequence.vy, control_history[0].vy,
586+
control_history[1].vy, control_history[2].vy, control_history[3].vy);
587+
applyFilterOverAxis(
588+
control_sequence.wz, control_history[0].wz,
589+
control_history[1].wz, control_history[2].wz, control_history[3].wz);
511590

512591
// Update control history
513592
unsigned int offset = settings.shift_control_sequence ? 1 : 0;
514593
control_history[0] = control_history[1];
515-
control_history[1] = {
594+
control_history[1] = control_history[2];
595+
control_history[2] = control_history[3];
596+
control_history[3] = {
516597
control_sequence.vx(offset),
517598
control_sequence.vy(offset),
518599
control_sequence.wz(offset)};

nav2_mppi_controller/src/optimizer.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ void Optimizer::reset()
119119
control_sequence_.reset(settings_.time_steps);
120120
control_history_[0] = {0.0, 0.0, 0.0};
121121
control_history_[1] = {0.0, 0.0, 0.0};
122+
control_history_[2] = {0.0, 0.0, 0.0};
123+
control_history_[3] = {0.0, 0.0, 0.0};
122124

123125
costs_ = xt::zeros<float>({settings_.batch_size});
124126
generated_trajectories_.reset(settings_.batch_size, settings_.time_steps);

nav2_mppi_controller/test/utils_test.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,13 @@ TEST(UtilsTests, SmootherTest)
317317
noisey_sequence.wz += noises;
318318
sequence_init = noisey_sequence;
319319

320-
std::array<mppi::models::Control, 2> history, history_init;
320+
std::array<mppi::models::Control, 4> history, history_init;
321+
history[3].vx = 0.1;
322+
history[3].vy = 0.0;
323+
history[3].wz = 0.3;
324+
history[2].vx = 0.1;
325+
history[2].vy = 0.0;
326+
history[2].wz = 0.3;
321327
history[1].vx = 0.1;
322328
history[1].vy = 0.0;
323329
history[1].wz = 0.3;
@@ -332,14 +338,14 @@ TEST(UtilsTests, SmootherTest)
332338
savitskyGolayFilter(noisey_sequence, history, settings);
333339

334340
// Check history is propogated backward
335-
EXPECT_NEAR(history_init[1].vx, history[0].vx, 0.02);
336-
EXPECT_NEAR(history_init[1].vy, history[0].vy, 0.02);
337-
EXPECT_NEAR(history_init[1].wz, history[0].wz, 0.02);
341+
EXPECT_NEAR(history_init[3].vx, history[2].vx, 0.02);
342+
EXPECT_NEAR(history_init[3].vy, history[2].vy, 0.02);
343+
EXPECT_NEAR(history_init[3].wz, history[2].wz, 0.02);
338344

339345
// Check history element is updated for first command
340-
EXPECT_NEAR(history[1].vx, 0.2, 0.05);
341-
EXPECT_NEAR(history[1].vy, 0.0, 0.02);
342-
EXPECT_NEAR(history[1].wz, 0.23, 0.02);
346+
EXPECT_NEAR(history[3].vx, 0.2, 0.05);
347+
EXPECT_NEAR(history[3].vy, 0.0, 0.035);
348+
EXPECT_NEAR(history[3].wz, 0.23, 0.02);
343349

344350
// Check that path is smoother
345351
float smoothed_val, original_val;

0 commit comments

Comments
 (0)