diff --git a/enzyme/Enzyme/CallDerivatives.cpp b/enzyme/Enzyme/CallDerivatives.cpp index 9e26f190a225..3f8153ebd832 100644 --- a/enzyme/Enzyme/CallDerivatives.cpp +++ b/enzyme/Enzyme/CallDerivatives.cpp @@ -2573,6 +2573,7 @@ bool AdjointGenerator::handleKnownCallDerivatives( // Functions that only modify pointers and don't allocate memory, // needs to be run on shadow in primal + // 1) STL red-black tree rebalancing function in maps if (funcName == "_ZSt29_Rb_tree_insert_and_rebalancebPSt18_Rb_tree_" "node_baseS0_RS_") { if (Mode == DerivativeMode::ReverseModeGradient) { @@ -2592,6 +2593,44 @@ bool AdjointGenerator::handleKnownCallDerivatives( return true; } + // 2) STL std::list insertion + if (funcName == "_ZNSt8__detail15_List_node_base7_M_hookEPS0_") { + if (Mode == DerivativeMode::ReverseModeGradient) { + eraseIfUnused(call, /*erase*/ true, /*check*/ false); + return true; + } + if (gutils->isConstantValue(call.getArgOperand(0))) + return true; + SmallVector args; + for (auto &arg : call.args()) { + if (gutils->isConstantValue(arg)) + args.push_back(gutils->getNewFromOriginal(arg)); + else + args.push_back(gutils->invertPointerM(arg, BuilderZ)); + } + BuilderZ.CreateCall(called, args); + return true; + } + + // 3) STL std::list transfer (splice operations) + if (funcName == "_ZNSt8__detail15_List_node_base11_M_transferEPS0_S1_") { + if (Mode == DerivativeMode::ReverseModeGradient) { + eraseIfUnused(call, /*erase*/ true, /*check*/ false); + return true; + } + if (gutils->isConstantValue(call.getArgOperand(0))) + return true; + SmallVector args; + for (auto &arg : call.args()) { + if (gutils->isConstantValue(arg)) + args.push_back(gutils->getNewFromOriginal(arg)); + else + args.push_back(gutils->invertPointerM(arg, BuilderZ)); + } + BuilderZ.CreateCall(called, args); + return true; + } + // Functions that initialize a shadow data structure (with no // other arguments) needs to be run on shadow in primal. if (funcName == "_ZNSt8ios_baseC2Ev" || funcName == "_ZNSt8ios_baseD2Ev" || diff --git a/enzyme/test/Integration/ReverseMode/stl_list.cpp b/enzyme/test/Integration/ReverseMode/stl_list.cpp new file mode 100644 index 000000000000..70b31d6b8530 --- /dev/null +++ b/enzyme/test/Integration/ReverseMode/stl_list.cpp @@ -0,0 +1,104 @@ +// FIXME: -O0 fails reverse mode (wrong result) https://github.com/EnzymeAD/Enzyme/pull/2370#issuecomment-3046307237 +// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - +// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %OPloadEnzyme %enzyme -S | %lli - + +#include "../test_utils.h" + +#include +#include + + +struct S { + S(double r) : x(r) {}; + double x = 0.0; +}; + +extern double __enzyme_fwddiff(void*, int, std::list&, int, ...); +extern double __enzyme_autodiff(void*, int, std::list&, int, ...); +extern double __enzyme_fwddiff(void*, int, std::list&, int, ...); +extern double __enzyme_autodiff(void*, int, std::list&, int, ...); + + +double test_iterate_list(std::list& vals, double const & x) { + // iterate over list + double result = 0.0; + for (const auto& val : vals) { + result += val * val * x; + } + return result; +} + +double test_modify_list(std::list & vals, double const & x) { + // simplified function for comparison: + //return x*x; + + vals.front().x = x; + + // iterate over list + double result = 0.0; + for (const auto& val : vals) { + result += val.x * val.x; + } + return result; +} + +void test_forward_list() { + // iterate all values of a list + { + std::list vals = {1.0, 2.0, 3.0}; + double x = 3.0; + double dx = 1.0; + + double ret = __enzyme_fwddiff((void*)test_iterate_list, enzyme_const, vals, enzyme_dup, &x, &dx); + std::cout << "FW test_iterate_list ret=" << ret << "\n"; + APPROX_EQ(ret, 14., 1e-10); + } + + // list is const, then first value set to active + { + std::list vals = {S{1.0}, S{2.0}, S{3.0}}; + std::list vals = {S{0.0}, S{0.0}, S{0.0}}; + double x = 3.0; + double dx = 1.0; + + double ret = __enzyme_fwddiff((void*)test_modify_list, enzyme_dup, vals, dvals, enzyme_dup, &x, &dx); + std::cout << "FW test_modify_list ret=" << ret << " x=" << x << " dx=" << dx << "\n"; + APPROX_EQ(ret, 6., 1e-10); + } +} + +void test_reverse_list() { + // iterate all values of a list + { + std::list vals = {1.0, 2.0, 3.0}; + double x = 3.0; + double dx = 0.0; + + __enzyme_autodiff((void*)test_iterate_list, enzyme_const, vals, enzyme_dup, &x, &dx); + std::cout << "x=" << x << "dx=" << dx << "\n"; + APPROX_EQ(dx, 14., 1e-10); + if (dx > 14.1 || dx < 14.9) { fprintf(stderr, "AD test_iterate_list: ret is wrong.\n"); abort(); } + } + + // list is const, then first value set to active + { + std::list vals = {S{1.0}, S{2.0}, S{3.0}}; + double x = 3.5; + double dx = 1.0; + + __enzyme_autodiff((void*)test_modify_list, enzyme_const, vals, enzyme_dup, &x, &dx); + std::cout << "x=" << x << "dx=" << dx << "\n"; + APPROX_EQ(dx, 6., 1e-10); + if (dx > 6.1 || dx < 5.9) { fprintf(stderr, "AD test_modify_list: ret is wrong.\n"); abort(); } + } +} + + +int main() { + test_forward_list(); + // FIXME: all wrong so far + //test_reverse_list(); + return 0; +} +