From 14ccba3e353a56fe9439771f08bb6974baf5c263 Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Fri, 25 Oct 2019 19:38:17 +0000 Subject: [PATCH 01/22] Basic modref analysis, few more steps needed --- enzyme/Enzyme/EnzymeLogic.cpp | 73 ++++++++++++++++++++++++++++------- 1 file changed, 59 insertions(+), 14 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 62af08b003b2..b1ea02ef09ef 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -47,7 +47,7 @@ llvm::cl::opt enzyme_print("enzyme_print", cl::init(false), cl::Hidden, cl::desc("Print before and after fns for autodiff")); cl::opt cachereads( - "enzyme_cachereads", cl::init(false), cl::Hidden, + "enzyme_cachereads", cl::init(true), cl::Hidden, cl::desc("Force caching of all reads")); //! return structtype if recursive function @@ -55,6 +55,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul static std::map, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair> cachedfunctions; static std::map, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished; auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), differentialReturn, returnUsed); + llvm::errs() << "TFKDEBUG: called CreateAugmentedPrimal\n"; if (cachedfunctions.find(tup) != cachedfunctions.end()) { return cachedfunctions[tup]; } @@ -426,7 +427,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul gutils->erase(op); } else if(LoadInst* li = dyn_cast(inst)) { - if (gutils->isConstantInstruction(inst) || gutils->isConstantValue(inst)) continue; + if (/*gutils->isConstantInstruction(inst) ||*/ gutils->isConstantValue(inst)) continue; if (cachereads) { llvm::errs() << "Forcibly caching reads " << *li << "\n"; IRBuilder<> BuilderZ(li); @@ -1443,6 +1444,15 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co return cachedfunctions[tup]; } + + + + + + + + bool hasTape = false; + if (constant_args.size() == 0 && !topLevel && !returnValue && hasMetadata(todiff, "enzyme_gradient")) { auto md = todiff->getMetadata("enzyme_gradient"); @@ -1458,7 +1468,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co auto res = getDefaultFunctionTypeForGradient(todiff->getFunctionType(), /*has return value*/!todiff->getReturnType()->isVoidTy(), differentialReturn); - bool hasTape = false; if (foundcalled->arg_size() == res.first.size() + 1 /*tape*/) { auto lastarg = foundcalled->arg_end(); @@ -1530,6 +1539,37 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(todiff, AA, TLI, constant_args, returnValue ? ReturnType::ArgsWithReturn : ReturnType::Args, differentialReturn, additionalArg); cachedfunctions[tup] = gutils->newFunc; + + std::map can_modref_map; + if (!additionalArg && !topLevel) { + for(BasicBlock* BB: gutils->originalBlocks) { + for (BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend(); I != E;) { + Instruction* inst = &*I; + if (auto op = dyn_cast(inst)) { + if (gutils->isConstantValue(inst)) { + I++; + continue; + } + auto op_operand = op->getPointerOperand(); + auto op_type = op->getType(); + bool can_modref = false; + llvm::errs() << "TFKDEBUG: looking at modref status of inst<"<<*inst << ">\n"; + for (int k = 0; k < gutils->originalBlocks.size(); k++) { + if (AA.canBasicBlockModify(*(gutils->originalBlocks[k]), MemoryLocation::get(op))) { + can_modref = true; + break; + } + } + llvm::errs() << "TFKDEBUG: modref status of inst<"<<*inst << "> is: " << can_modref << "\n"; + can_modref_map[inst] = can_modref; + } + I++; + } + } + } + + + gutils->forceContexts(true); gutils->forceAugmentedReturns(); @@ -1602,7 +1642,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co } } - for(BasicBlock* BB: gutils->originalBlocks) { auto BB2 = gutils->reverseBlocks[BB]; assert(BB2); @@ -1648,6 +1687,8 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co assert(0 && "unknown terminator inst"); } + + for (BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend(); I != E;) { Instruction* inst = &*I; assert(inst); @@ -1957,16 +1998,20 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co auto op_type = op->getType(); if (cachereads) { - llvm::errs() << "Forcibly loading cached reads " << *op << "\n"; - IRBuilder<> BuilderZ(op->getNextNode()); - inst = cast(gutils->addMalloc(BuilderZ, inst)); - if (inst != op) { - // Set to nullptr since op should never be used after invalidated through addMalloc. - op = nullptr; - gutils->nonconstant_values.insert(inst); - gutils->nonconstant.insert(inst); - gutils->originalInstructions.insert(inst); - assert(inst->getType() == op_type); + + bool can_modref = can_modref_map[inst]; + //can_modref = true; + if (can_modref || additionalArg) { llvm::errs() << "Forcibly loading cached reads " << *op << "\n"; + IRBuilder<> BuilderZ(op->getNextNode()); + inst = cast(gutils->addMalloc(BuilderZ, inst)); + if (inst != op) { + // Set to nullptr since op should never be used after invalidated through addMalloc. + op = nullptr; + gutils->nonconstant_values.insert(inst); + gutils->nonconstant.insert(inst); + gutils->originalInstructions.insert(inst); + assert(inst->getType() == op_type); + } } } From 29842496c2646cef761153a7ddfe12af755219c1 Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Wed, 30 Oct 2019 23:15:00 +0000 Subject: [PATCH 02/22] cache almost all loads --- enzyme/Enzyme/EnzymeLogic.cpp | 69 +++++++++++++++++++++++++++--- enzyme/Enzyme/FunctionUtils.cpp | 21 ++++++--- enzyme/Enzyme/GradientUtils.cpp | 1 + enzyme/Enzyme/GradientUtils.h | 2 +- enzyme/functional_tests_c/Makefile | 4 +- 5 files changed, 81 insertions(+), 16 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index b1ea02ef09ef..131fdbfb0b14 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -55,12 +55,21 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul static std::map, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair> cachedfunctions; static std::map, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished; auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), differentialReturn, returnUsed); - llvm::errs() << "TFKDEBUG: called CreateAugmentedPrimal\n"; + llvm::errs() << "TFKDEBUG: called CreateAugmentedPrimal " << todiff->getName() << "\n"; + llvm::errs() << "TFKDEBUG: called CreateAugmentedPrimal content: " << *todiff << "\n"; if (cachedfunctions.find(tup) != cachedfunctions.end()) { return cachedfunctions[tup]; } if (differentialReturn) assert(returnUsed); + + + + + + + + if (constant_args.size() == 0 && hasMetadata(todiff, "enzyme_augment")) { auto md = todiff->getMetadata("enzyme_augment"); if (!isa(md)) { @@ -105,6 +114,11 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul //assert(st->getNumElements() > 0); return cachedfunctions[tup] = std::pair(foundcalled, nullptr); //dyn_cast(st->getElementType(0))); } + + + + + if (todiff->empty()) { llvm::errs() << *todiff << "\n"; } @@ -114,6 +128,45 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul cachedfunctions[tup] = std::pair(gutils->newFunc, nullptr); cachedfinished[tup] = false; + llvm::errs() << "Old func: " << *gutils->oldFunc << "\n"; + llvm::errs() << "New func: " << *gutils->newFunc << "\n"; + + std::map can_modref_map; + if (true) { //!additionalArg && !topLevel) { + for(BasicBlock* BB: gutils->originalBlocks) { + llvm::errs() << "BB: " << *BB << "\n"; + //for (BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend(); I != E;) { + for (auto I = BB->begin(), E = BB->end(); I != E; I++) { + Instruction* inst = &*I; + if (auto op = dyn_cast(inst)) { + if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) { + //I++; + continue; + } + auto op_operand = op->getPointerOperand(); + auto op_type = op->getType(); + bool can_modref = false; + llvm::errs() << "TFKDEBUG: looking at modref status of inst<"<<*inst << ">\n"; + for (int k = 0; k < gutils->originalBlocks.size(); k++) { + llvm::errs() << "TFKDEBUG: in BB: <"<<*(gutils->originalBlocks[k]) << ">\n"; + if (AA.canBasicBlockModify(*(gutils->originalBlocks[k]), MemoryLocation::get(op))) { + can_modref = true; + break; + } + } + llvm::errs() << "TFKDEBUG: modref status of inst<"<<*inst << "> is: " << can_modref << "\n"; + can_modref_map[inst] = can_modref; + } + //I++; + } + } + } + + + + + + gutils->forceContexts(); gutils->forceAugmentedReturns(); @@ -427,8 +480,8 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul gutils->erase(op); } else if(LoadInst* li = dyn_cast(inst)) { - if (/*gutils->isConstantInstruction(inst) ||*/ gutils->isConstantValue(inst)) continue; - if (cachereads) { + //if (/*gutils->isConstantInstruction(inst) ||*/ gutils->isConstantValue(inst)) continue; + if (true || (cachereads && can_modref_map[inst])) { llvm::errs() << "Forcibly caching reads " << *li << "\n"; IRBuilder<> BuilderZ(li); gutils->addMalloc(BuilderZ, li); @@ -1541,7 +1594,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co std::map can_modref_map; - if (!additionalArg && !topLevel) { + if (/*!additionalArg && */!topLevel) { for(BasicBlock* BB: gutils->originalBlocks) { for (BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend(); I != E;) { Instruction* inst = &*I; @@ -1990,9 +2043,9 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co if (dif1) addToDiffe(op->getOperand(1), dif1); if (dif2) addToDiffe(op->getOperand(2), dif2); } else if(auto op = dyn_cast(inst)) { - if (gutils->isConstantValue(inst)) continue; - + //if (gutils->isConstantValue(inst)) continue; + llvm::errs() << "TFKDEBUG Saw load instruction: " << *inst << "\n"; auto op_operand = op->getPointerOperand(); auto op_type = op->getType(); @@ -2001,9 +2054,11 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co bool can_modref = can_modref_map[inst]; //can_modref = true; - if (can_modref || additionalArg) { llvm::errs() << "Forcibly loading cached reads " << *op << "\n"; + if ( (!topLevel) || can_modref /*|| additionalArg*/) { llvm::errs() << "Forcibly loading cached reads " << *op << "\n"; IRBuilder<> BuilderZ(op->getNextNode()); inst = cast(gutils->addMalloc(BuilderZ, inst)); + llvm::errs() << "Instruction after force load cache reads: " << *inst << "\n"; + llvm::errs() << "Parent after force load cache reads: " << *(inst->getFunction()) << "\n"; if (inst != op) { // Set to nullptr since op should never be used after invalidated through addMalloc. op = nullptr; diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 38d9a177a4d4..0ab44aeafe52 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -439,7 +439,7 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI) FunctionAnalysisManager AM; AM.registerPass([] { return AAManager(); }); AM.registerPass([] { return ScalarEvolutionAnalysis(); }); - AM.registerPass([] { return AssumptionAnalysis(); }); + //AM.registerPass([] { return AssumptionAnalysis(); }); AM.registerPass([] { return TargetLibraryAnalysis(); }); AM.registerPass([] { return TargetIRAnalysis(); }); AM.registerPass([] { return LoopAnalysis(); }); @@ -458,14 +458,23 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI) MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(AM); }); //Alias analysis is necessary to ensure can query whether we can move a forward pass function - BasicAA ba; - auto baa = new BasicAAResult(ba.run(*NewF, AM)); + //BasicAA ba; + //auto baa = new BasicAAResult(ba.run(*NewF, AM)); + AssumptionCache* AC = new AssumptionCache(*NewF); + auto baa = new BasicAAResult(NewF->getParent()->getDataLayout(), + *NewF, + AM.getResult(*NewF), + *AC, + &AM.getResult(*NewF), + AM.getCachedResult(*NewF), + AM.getCachedResult(*NewF)); AA.addAAResult(*baa); - ScopedNoAliasAA sa; - auto saa = new ScopedNoAliasAAResult(sa.run(*NewF, AM)); - AA.addAAResult(*saa); + //ScopedNoAliasAA sa; + //auto saa = new ScopedNoAliasAAResult(sa.run(*NewF, AM)); + //AA.addAAResult(*saa); + llvm::errs() << "ran alias analysis on function " << NewF->getName() << "\n"; } if (enzyme_print) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 6157e2601ec7..3d4d14f72e32 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -223,6 +223,7 @@ bool shouldRecompute(Value* val, const ValueToValueMapTy& available) { } else if (auto op = dyn_cast(val)) { return shouldRecompute(op->getOperand(0), available) || shouldRecompute(op->getOperand(1), available) || shouldRecompute(op->getOperand(2), available); } else if (auto load = dyn_cast(val)) { + return true; // NOTE(TFK): Remove this. Value* idx = load->getOperand(0); while (!isa(idx)) { if (auto gep = dyn_cast(idx)) { diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index 919fb73df8e4..f3f2e78a780e 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -507,7 +507,7 @@ class GradientUtils { } assert(lastScopeAlloc.find(malloc) == lastScopeAlloc.end()); cast(malloc)->replaceAllUsesWith(ret); - auto n = malloc->getName(); + std::string n = malloc->getName().str(); erase(cast(malloc)); ret->setName(n); } diff --git a/enzyme/functional_tests_c/Makefile b/enzyme/functional_tests_c/Makefile index 8d1c98051e0d..939b330914d9 100644 --- a/enzyme/functional_tests_c/Makefile +++ b/enzyme/functional_tests_c/Makefile @@ -18,7 +18,7 @@ OBJ := $(wildcard *.c) all: $(patsubst %.c,build/%-enzyme0,$(OBJ)) $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ)) -POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true +POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true -enzyme_print=true #all: $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ)) #clean: @@ -31,7 +31,7 @@ POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true #EXTRA_FLAGS = -indvars -loop-simplify -loop-rotate -# NOTE(TFK): Optimization level 0 is broken right now. +# /efs/home/tfk/valgrind-3.12.0/vg-in-place build/%-enzyme0: %.c @./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 -O1 $(patsubst %.c,%,$<).c -S -emit-llvm -o $@.ll @./setup.sh $(CLANG_BIN_PATH)/opt $@.ll $(EXTRA_FLAGS) -load=$(ENZYME_PLUGIN) -enzyme $(POST_ENZYME_FLAGS) -o $@.bc From 96ab2b1650eb0f2fd888f680e15164cab0411ce8 Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Thu, 31 Oct 2019 05:04:03 +0000 Subject: [PATCH 03/22] Messy, but working, selective caching of reads --- enzyme/Enzyme/Enzyme.cpp | 3 +- enzyme/Enzyme/EnzymeLogic.cpp | 241 +++++++++++++++++++++++++++++--- enzyme/Enzyme/EnzymeLogic.h | 2 +- enzyme/Enzyme/FunctionUtils.cpp | 3 +- enzyme/Enzyme/GradientUtils.cpp | 5 +- 5 files changed, 231 insertions(+), 23 deletions(-) diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 1bd358db06aa..19fdeb5fb37c 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -155,7 +155,8 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo bool differentialReturn = cast(fn)->getReturnType()->isFPOrFPVectorTy(); - auto newFunc = CreatePrimalAndGradient(cast(fn), constants, TLI, AA, /*should return*/false, differentialReturn, /*topLevel*/true, /*addedType*/nullptr);//, LI, DT); + std::set volatile_args; + auto newFunc = CreatePrimalAndGradient(cast(fn), constants, TLI, AA, /*should return*/false, differentialReturn, /*topLevel*/true, /*addedType*/nullptr, volatile_args);//, LI, DT); if (differentialReturn) args.push_back(ConstantFP::get(cast(fn)->getReturnType(), 1.0)); diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 131fdbfb0b14..803c75d419b5 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -32,6 +32,7 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" @@ -51,10 +52,10 @@ cl::opt cachereads( cl::desc("Force caching of all reads")); //! return structtype if recursive function -std::pair CreateAugmentedPrimal(Function* todiff, AAResults &AA, const std::set& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed) { - static std::map, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair> cachedfunctions; - static std::map, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished; - auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), differentialReturn, returnUsed); +std::pair CreateAugmentedPrimal(Function* todiff, AAResults &AA, const std::set& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed, const std::set _volatile_args) { + static std::map, std::set, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair> cachedfunctions; + static std::map, std::set, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished; + auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), std::set(_volatile_args.begin(), _volatile_args.end()), differentialReturn, returnUsed); llvm::errs() << "TFKDEBUG: called CreateAugmentedPrimal " << todiff->getName() << "\n"; llvm::errs() << "TFKDEBUG: called CreateAugmentedPrimal content: " << *todiff << "\n"; if (cachedfunctions.find(tup) != cachedfunctions.end()) { @@ -131,6 +132,99 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul llvm::errs() << "Old func: " << *gutils->oldFunc << "\n"; llvm::errs() << "New func: " << *gutils->newFunc << "\n"; + + + llvm::errs() << "TFKDEBUG Testing original to new for function:" << *gutils->oldFunc << "\n"; + llvm::errs() << "Arg size is " << gutils->oldFunc->arg_size() << "\n"; + int count = 0; + for (auto i=gutils->oldFunc->arg_begin(); i != gutils->oldFunc->arg_end(); i++) { + bool is_volatile = false; + if (_volatile_args.find(count) != _volatile_args.end()) is_volatile = true; + llvm::errs() << "arg " << count++ << " is " << *i << " volatile: " << is_volatile << "\n"; + } + + std::map > volatile_args_map; + //DominatorTree DT(*gutils->oldFunc); + + llvm::errs() << "Old function content is " << *gutils->oldFunc << "\n"; + + for(BasicBlock* _BB: gutils->originalBlocks) { + for (auto _I = _BB->begin(), _E = _BB->end(); _I != _E; _I++) { + Instruction* _inst = &*_I; + if (auto _op = dyn_cast(_inst)) { + std::set volatile_args; + std::vector args; + llvm::errs() << "args are: "; + std::vector args_safe; + for (int i = 0; i < _op->getNumArgOperands(); i++) { + //if (_op->getArgOperand(i)->getType()->isPointerTy()) { + args.push_back(_op->getArgOperand(i)); + bool init_safe = true; + // If the UnderlyingObject is from one of this function's arguments, then we need to propagate the volatility. + Value* obj = GetUnderlyingObject(_op->getArgOperand(i),_BB->getModule()->getDataLayout(),100); + if (auto arg = dyn_cast(obj)) { + if (_volatile_args.find(arg->getArgNo()) != _volatile_args.end()) { + init_safe = false; + } + } + + args_safe.push_back(init_safe); + llvm::errs() << " "<< *_op->getArgOperand(i) <<" "; + //} + } + llvm::errs() << "\n"; + + for(BasicBlock* BB: gutils->originalBlocks) { + for (auto I = BB->begin(), E = BB->end(); I != E; I++) { + Instruction* inst = &*I; + if (inst == _inst) continue; + if (gutils->DT.dominates(inst, _inst)) { + llvm::errs() << inst->getParent() << "\n"; + llvm::errs() << _inst->getParent() << "\n"; + llvm::errs() << "callinst: " << *_op <<"DOES dominate " << *I << "\n"; + } else { + llvm::errs() << "callinst: " << *_op <<"does not dominate " << *I << "\n"; + // In this case "inst" may occur after the call instruction (_inst). If "inst" is a store, it might necessitate caching a load inside the call. + + if (auto op = dyn_cast(inst)) { + for (int i = 0; i < args.size(); i++) { + if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(_op, i, TLI)/*, MemoryLocation::UnknownSize)*/))) { + llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; + } else { + llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; + args_safe[i] = false; + } + } + } + if (auto op = dyn_cast(inst)) { + for (int i = 0; i < args.size(); i++) { + if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(_op, i, TLI)))) { + llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; + } else { + llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; + args_safe[i] = false; + } + } + } + + + } + } + } + llvm::errs() << "CallInst: " << *_op<< "CALL ARGUMENT INFO: \n"; + for (int i = 0; i < args.size(); i++) { + if (!args_safe[i]) { + volatile_args.insert(i); + } + llvm::errs() << "Arg: " << *args[i] << " STATUS: " << args_safe[i] << "\n"; + } + volatile_args_map[_op] = volatile_args; + } + } + } + + + std::map can_modref_map; if (true) { //!additionalArg && !topLevel) { for(BasicBlock* BB: gutils->originalBlocks) { @@ -146,6 +240,15 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul auto op_operand = op->getPointerOperand(); auto op_type = op->getType(); bool can_modref = false; + + auto obj = GetUnderlyingObject(op->getPointerOperand(), BB->getModule()->getDataLayout(), 100); + if (auto arg = dyn_cast(obj)) { + if (_volatile_args.find(arg->getArgNo()) != _volatile_args.end()) { + can_modref = true; + } + } + + llvm::errs() << "TFKDEBUG: looking at modref status of inst<"<<*inst << ">\n"; for (int k = 0; k < gutils->originalBlocks.size(); k++) { llvm::errs() << "TFKDEBUG: in BB: <"<<*(gutils->originalBlocks[k]) << ">\n"; @@ -418,7 +521,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul } } - auto newcalled = CreateAugmentedPrimal(dyn_cast(called), AA, subconstant_args, TLI, /*differentialReturn*/subdifferentialreturn, /*return is used*/subretused).first; + auto newcalled = CreateAugmentedPrimal(dyn_cast(called), AA, subconstant_args, TLI, /*differentialReturn*/subdifferentialreturn, /*return is used*/subretused, volatile_args_map[op]).first; auto augmentcall = BuilderZ.CreateCall(newcalled, args); assert(augmentcall->getType()->isStructTy()); augmentcall->setCallingConv(op->getCallingConv()); @@ -481,7 +584,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul gutils->erase(op); } else if(LoadInst* li = dyn_cast(inst)) { //if (/*gutils->isConstantInstruction(inst) ||*/ gutils->isConstantValue(inst)) continue; - if (true || (cachereads && can_modref_map[inst])) { + if (/*true || */(cachereads && can_modref_map[inst])) { llvm::errs() << "Forcibly caching reads " << *li << "\n"; IRBuilder<> BuilderZ(li); gutils->addMalloc(BuilderZ, li); @@ -955,7 +1058,7 @@ std::pair,SmallVector> getDefaultFunctionTypeForGr return std::pair,SmallVector>(args, outs); } -void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::reverse_iterator &E, IRBuilder <>& Builder2, CallInst* op, DiffeGradientUtils* const gutils, TargetLibraryInfo &TLI, AAResults &AA, const bool topLevel, const std::map &replacedReturns) { +void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::reverse_iterator &E, IRBuilder <>& Builder2, CallInst* op, DiffeGradientUtils* const gutils, TargetLibraryInfo &TLI, AAResults &AA, const bool topLevel, const std::map &replacedReturns, std::set volatile_args) { Function *called = op->getCalledFunction(); if (auto castinst = dyn_cast(op->getCalledValue())) { @@ -1296,7 +1399,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r if (modifyPrimal && called) { bool subretused = op->getNumUses() != 0; bool subdifferentialreturn = (!gutils->isConstantValue(op)) && subretused; - auto fnandtapetype = CreateAugmentedPrimal(cast(called), AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/subretused); + auto fnandtapetype = CreateAugmentedPrimal(cast(called), AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/subretused, volatile_args); if (topLevel) { Function* newcalled = fnandtapetype.first; augmentcall = BuilderZ.CreateCall(newcalled, pre_args); @@ -1368,7 +1471,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r bool subdiffereturn = (!gutils->isConstantValue(op)) && !( op->getType()->isPointerTy() || op->getType()->isIntegerTy() || op->getType()->isEmptyTy() ); llvm::errs() << "subdifferet:" << subdiffereturn << " " << *op << "\n"; if (called) { - newcalled = CreatePrimalAndGradient(cast(called), subconstant_args, TLI, AA, /*returnValue*/retUsed, /*subdiffereturn*/subdiffereturn, /*topLevel*/replaceFunction, tape ? tape->getType() : nullptr);//, LI, DT); + newcalled = CreatePrimalAndGradient(cast(called), subconstant_args, TLI, AA, /*returnValue*/retUsed, /*subdiffereturn*/subdiffereturn, /*topLevel*/replaceFunction, tape ? tape->getType() : nullptr, volatile_args);//, LI, DT); } else { newcalled = gutils->invertPointerM(op->getCalledValue(), Builder2); auto ft = cast(cast(op->getCalledValue()->getType())->getElementType()); @@ -1478,7 +1581,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r } } -Function* CreatePrimalAndGradient(Function* todiff, const std::set& constant_args, TargetLibraryInfo &TLI, AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg) { +Function* CreatePrimalAndGradient(Function* todiff, const std::set& constant_args, TargetLibraryInfo &TLI, AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg, std::set _volatile_args) { if (differentialReturn) { if(!todiff->getReturnType()->isFPOrFPVectorTy()) { llvm::errs() << *todiff << "\n"; @@ -1490,9 +1593,8 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co llvm::errs() << "addl arg: " << *additionalArg << "\n"; } if (additionalArg) assert(additionalArg->isStructTy()); - - static std::map, bool/*retval*/, bool/*differentialReturn*/, bool/*topLevel*/, llvm::Type*>, Function*> cachedfunctions; - auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), returnValue, differentialReturn, topLevel, additionalArg); + static std::map, std::set, bool/*retval*/, bool/*differentialReturn*/, bool/*topLevel*/, llvm::Type*>, Function*> cachedfunctions; + auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), std::set(_volatile_args.begin(), _volatile_args.end()), returnValue, differentialReturn, topLevel, additionalArg); if (cachedfunctions.find(tup) != cachedfunctions.end()) { return cachedfunctions[tup]; } @@ -1500,10 +1602,6 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co - - - - bool hasTape = false; if (constant_args.size() == 0 && !topLevel && !returnValue && hasMetadata(todiff, "enzyme_gradient")) { @@ -1593,6 +1691,102 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co cachedfunctions[tup] = gutils->newFunc; + + + std::set finalized_underlying_objects; + std::set distinct_underlying_objects; + + + std::map > volatile_args_map; + //DominatorTree DT(*gutils->oldFunc); + + llvm::errs() << "Old function content is " << *gutils->oldFunc << "\n"; + + for(BasicBlock* _BB: gutils->originalBlocks) { + for (auto _I = _BB->begin(), _E = _BB->end(); _I != _E; _I++) { + Instruction* _inst = &*_I; + if (auto _op = dyn_cast(_inst)) { + std::set volatile_args; + std::vector args; + llvm::errs() << "args are: "; + std::vector args_safe; + for (int i = 0; i < _op->getNumArgOperands(); i++) { + //if (_op->getArgOperand(i)->getType()->isPointerTy()) { + args.push_back(_op->getArgOperand(i)); + bool init_safe = true; + // If the UnderlyingObject is from one of this function's arguments, then we need to propagate the volatility. + Value* obj = GetUnderlyingObject(_op->getArgOperand(i),_BB->getModule()->getDataLayout(),100); + if (auto arg = dyn_cast(obj)) { + if (_volatile_args.find(arg->getArgNo()) != _volatile_args.end()) { + init_safe = false; + } + } + + args_safe.push_back(init_safe); + llvm::errs() << " "<< *_op->getArgOperand(i) <<" "; + //} + } + llvm::errs() << "\n"; + + for(BasicBlock* BB: gutils->originalBlocks) { + for (auto I = BB->begin(), E = BB->end(); I != E; I++) { + Instruction* inst = &*I; + if (inst == _inst) continue; + if (gutils->DT.dominates(inst, _inst)) { + llvm::errs() << inst->getParent() << "\n"; + llvm::errs() << _inst->getParent() << "\n"; + llvm::errs() << "callinst: " << *_op <<"DOES dominate " << *I << "\n"; + } else { + llvm::errs() << "callinst: " << *_op <<"does not dominate " << *I << "\n"; + // In this case "inst" may occur after the call instruction (_inst). If "inst" is a store, it might necessitate caching a load inside the call. + + if (auto op = dyn_cast(inst)) { + for (int i = 0; i < args.size(); i++) { + if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(_op, i, TLI)/*, MemoryLocation::UnknownSize)*/))) { + llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; + } else { + llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; + args_safe[i] = false; + } + } + } + if (auto op = dyn_cast(inst)) { + for (int i = 0; i < args.size(); i++) { + if (!args[i]->getType()->isPointerTy()) continue; + llvm::errs() << "TFKDEBUG: " << *args[i] << "\n"; + if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(_op, i, TLI)))) { + llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; + } else { + llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; + args_safe[i] = false; + } + } + } + + + } + } + } + llvm::errs() << "CallInst: " << *_op<< "CALL ARGUMENT INFO: \n"; + for (int i = 0; i < args.size(); i++) { + if (!args_safe[i]) { + volatile_args.insert(i); + } + llvm::errs() << "Arg: " << *args[i] << " STATUS: " << args_safe[i] << "\n"; + } + volatile_args_map[_op] = volatile_args; + } + } + } + + + + + + + + + std::map can_modref_map; if (/*!additionalArg && */!topLevel) { for(BasicBlock* BB: gutils->originalBlocks) { @@ -1607,6 +1801,15 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co auto op_type = op->getType(); bool can_modref = false; llvm::errs() << "TFKDEBUG: looking at modref status of inst<"<<*inst << ">\n"; + + auto obj = GetUnderlyingObject(op->getPointerOperand(), BB->getModule()->getDataLayout(), 100); + if (auto arg = dyn_cast(obj)) { + if (_volatile_args.find(arg->getArgNo()) != _volatile_args.end()) { + can_modref = true; + } + } + + for (int k = 0; k < gutils->originalBlocks.size(); k++) { if (AA.canBasicBlockModify(*(gutils->originalBlocks[k]), MemoryLocation::get(op))) { can_modref = true; @@ -2026,7 +2229,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co if (dif0) addToDiffe(op->getOperand(0), dif0); if (dif1) addToDiffe(op->getOperand(1), dif1); } else if(auto op = dyn_cast_or_null(inst)) { - handleGradientCallInst(I, E, Builder2, op, gutils, TLI, AA, topLevel, replacedReturns); + handleGradientCallInst(I, E, Builder2, op, gutils, TLI, AA, topLevel, replacedReturns, volatile_args_map[op]); } else if(auto op = dyn_cast_or_null(inst)) { if (gutils->isConstantValue(inst)) continue; if (op->getType()->isPointerTy()) continue; @@ -2054,7 +2257,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co bool can_modref = can_modref_map[inst]; //can_modref = true; - if ( (!topLevel) || can_modref /*|| additionalArg*/) { llvm::errs() << "Forcibly loading cached reads " << *op << "\n"; + if ( /*(!topLevel) ||*/ can_modref /*|| additionalArg*/) { llvm::errs() << "Forcibly loading cached reads " << *op << "\n"; IRBuilder<> BuilderZ(op->getNextNode()); inst = cast(gutils->addMalloc(BuilderZ, inst)); llvm::errs() << "Instruction after force load cache reads: " << *inst << "\n"; diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index ac65e7734432..ec54b19e4b77 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -36,6 +36,6 @@ extern llvm::cl::opt enzyme_print; //! return structtype if recursive function std::pair CreateAugmentedPrimal(llvm::Function* todiff, llvm::AAResults &AA, const std::set& constant_args, llvm::TargetLibraryInfo &TLI, bool differentialReturn); -llvm::Function* CreatePrimalAndGradient(llvm::Function* todiff, const std::set& constant_args, llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg); +llvm::Function* CreatePrimalAndGradient(llvm::Function* todiff, const std::set& constant_args, llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg, std::set volatile_args); #endif diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 0ab44aeafe52..33c12a605bf7 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -461,9 +461,10 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI) //BasicAA ba; //auto baa = new BasicAAResult(ba.run(*NewF, AM)); AssumptionCache* AC = new AssumptionCache(*NewF); + TargetLibraryInfo* TLI = new TargetLibraryInfo(AM.getResult(*NewF)); auto baa = new BasicAAResult(NewF->getParent()->getDataLayout(), *NewF, - AM.getResult(*NewF), + *TLI, *AC, &AM.getResult(*NewF), AM.getCachedResult(*NewF), diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 3d4d14f72e32..c9ac6c259bc7 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -351,8 +351,11 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) { auto cs = gvemd->getValue(); return invertedPointers[val] = cs; } else if (auto fn = dyn_cast(val)) { + llvm::errs() << "Note(TFK): Need to disable function pointer casts for now.\n"; + assert(false); //! Todo allow tape propagation - auto newf = CreatePrimalAndGradient(fn, /*constant_args*/{}, TLI, AA, /*returnValue*/false, /*differentialReturn*/fn->getReturnType()->isFPOrFPVectorTy(), /*topLevel*/false, /*additionalArg*/nullptr); + std::set volatile_args; + auto newf = CreatePrimalAndGradient(fn, /*constant_args*/{}, TLI, AA, /*returnValue*/false, /*differentialReturn*/fn->getReturnType()->isFPOrFPVectorTy(), /*topLevel*/false, /*additionalArg*/nullptr, volatile_args); return BuilderM.CreatePointerCast(newf, fn->getType()); } else if (auto arg = dyn_cast(val)) { auto result = BuilderM.CreateCast(arg->getOpcode(), invertPointerM(arg->getOperand(0), BuilderM), arg->getDestTy(), arg->getName()+"'ipc"); From b118be2f966366088671a2f473377077141380ad Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Thu, 31 Oct 2019 05:19:59 +0000 Subject: [PATCH 04/22] add missing files and fix minor bugs --- enzyme/Enzyme/EnzymeLogic.cpp | 12 +++- enzyme/Enzyme/GradientUtils.cpp | 4 +- enzyme/functional_tests_c/insertsort_sum.c | 10 +++ enzyme/functional_tests_c/readwriteread.c | 67 +++++++++++++++++++ .../testfiles/readwriteread-enzyme0.test | 6 ++ .../testfiles/readwriteread-enzyme1.test | 6 ++ .../testfiles/readwriteread-enzyme2.test | 6 ++ .../testfiles/readwriteread-enzyme3.test | 6 ++ 8 files changed, 112 insertions(+), 5 deletions(-) create mode 100644 enzyme/functional_tests_c/readwriteread.c create mode 100644 enzyme/functional_tests_c/testfiles/readwriteread-enzyme0.test create mode 100644 enzyme/functional_tests_c/testfiles/readwriteread-enzyme1.test create mode 100644 enzyme/functional_tests_c/testfiles/readwriteread-enzyme2.test create mode 100644 enzyme/functional_tests_c/testfiles/readwriteread-enzyme3.test diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 803c75d419b5..d833719c87a3 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -48,7 +48,7 @@ llvm::cl::opt enzyme_print("enzyme_print", cl::init(false), cl::Hidden, cl::desc("Print before and after fns for autodiff")); cl::opt cachereads( - "enzyme_cachereads", cl::init(true), cl::Hidden, + "enzyme_cachereads", cl::init(false), cl::Hidden, cl::desc("Force caching of all reads")); //! return structtype if recursive function @@ -152,6 +152,9 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul for (auto _I = _BB->begin(), _E = _BB->end(); _I != _E; _I++) { Instruction* _inst = &*_I; if (auto _op = dyn_cast(_inst)) { + if(auto intrinsic = dyn_cast(_inst)) { + continue; + } std::set volatile_args; std::vector args; llvm::errs() << "args are: "; @@ -583,7 +586,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul gutils->erase(op); } else if(LoadInst* li = dyn_cast(inst)) { - //if (/*gutils->isConstantInstruction(inst) ||*/ gutils->isConstantValue(inst)) continue; + if (/*gutils->isConstantInstruction(inst) ||*/ gutils->isConstantValue(inst)) continue; if (/*true || */(cachereads && can_modref_map[inst])) { llvm::errs() << "Forcibly caching reads " << *li << "\n"; IRBuilder<> BuilderZ(li); @@ -1706,6 +1709,9 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co for (auto _I = _BB->begin(), _E = _BB->end(); _I != _E; _I++) { Instruction* _inst = &*_I; if (auto _op = dyn_cast(_inst)) { + if(auto intrinsic = dyn_cast(_inst)) { + continue; + } std::set volatile_args; std::vector args; llvm::errs() << "args are: "; @@ -2246,7 +2252,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co if (dif1) addToDiffe(op->getOperand(1), dif1); if (dif2) addToDiffe(op->getOperand(2), dif2); } else if(auto op = dyn_cast(inst)) { - //if (gutils->isConstantValue(inst)) continue; + if (gutils->isConstantValue(inst)) continue; llvm::errs() << "TFKDEBUG Saw load instruction: " << *inst << "\n"; diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index c9ac6c259bc7..2d2398ca6c48 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -351,8 +351,8 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) { auto cs = gvemd->getValue(); return invertedPointers[val] = cs; } else if (auto fn = dyn_cast(val)) { - llvm::errs() << "Note(TFK): Need to disable function pointer casts for now.\n"; - assert(false); + //llvm::errs() << "Note(TFK): Need to disable function pointer casts for now.\n"; + //assert(false); //! Todo allow tape propagation std::set volatile_args; auto newf = CreatePrimalAndGradient(fn, /*constant_args*/{}, TLI, AA, /*returnValue*/false, /*differentialReturn*/fn->getReturnType()->isFPOrFPVectorTy(), /*topLevel*/false, /*additionalArg*/nullptr, volatile_args); diff --git a/enzyme/functional_tests_c/insertsort_sum.c b/enzyme/functional_tests_c/insertsort_sum.c index 875bf620077c..e8ae9249bb8b 100644 --- a/enzyme/functional_tests_c/insertsort_sum.c +++ b/enzyme/functional_tests_c/insertsort_sum.c @@ -17,6 +17,10 @@ float* unsorted_array_init(int N) { } // sums the first half of a sorted array. +<<<<<<< HEAD +======= +//__attribute__((noinline)) +>>>>>>> add missing files and fix minor bugs void insertsort_sum (float* array, int N, float* ret) { float sum = 0; //qsort(array, N, sizeof(float), cmp); @@ -39,7 +43,13 @@ void insertsort_sum (float* array, int N, float* ret) { *ret = sum; } +<<<<<<< HEAD +======= +//void insertsort_sum (float* array, int N, float* ret) { +// insertsort_sum_subcall(array, N, ret); +//} +>>>>>>> add missing files and fix minor bugs int main(int argc, char** argv) { diff --git a/enzyme/functional_tests_c/readwriteread.c b/enzyme/functional_tests_c/readwriteread.c new file mode 100644 index 000000000000..06dfafd54381 --- /dev/null +++ b/enzyme/functional_tests_c/readwriteread.c @@ -0,0 +1,67 @@ +#include +#include +#include +#include +#define __builtin_autodiff __enzyme_autodiff +double __enzyme_autodiff(void*, ...); +int counter = 0; +double recurse_max_helper(float* a, float* b, int N) { + if (N <= 0) { + return *a + *b; + } + return recurse_max_helper(a,b,N-1) + recurse_max_helper(a,b,N-2); +} + + +double f_read(double* x) { + double product = (*x) * (*x); + return product; +} + + +void g_write(double* x, double product) { + *x = (*x) * product; +} + +double h_read(double* x) { + return *x; +} + + +double readwriteread_helper(double* x) { + double product = f_read(x); + g_write(x, product); + double ret = h_read(x); + return ret; +} + +void readwriteread(double*__restrict x, double*__restrict ret) { + *ret = readwriteread_helper(x); + //*ret = (*x) * (*x) * (*x); +} + + + +int main(int argc, char** argv) { + + + + double ret = 0; + double dret = 1.0; + double* x = (double*) malloc(sizeof(double)); + double* dx = (double*) malloc(sizeof(double)); + *x = 2.0; + *dx = 0.0; + + __builtin_autodiff(readwriteread, x, dx, &ret, &dret); + + + printf("dx is %f ret is %f\n", *dx, ret); + assert(*dx == 3*2.0*2.0); + //assert(db == 17711.0*2); + + + + //printf("hello! %f, res2 %f, da: %f, db: %f\n", ret, ret, da,db); + return 0; +} diff --git a/enzyme/functional_tests_c/testfiles/readwriteread-enzyme0.test b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme0.test new file mode 100644 index 000000000000..14a037d8426b --- /dev/null +++ b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme0.test @@ -0,0 +1,6 @@ +; RUN: cd %desired_wd +; RUN: make clean-readwriteread-enzyme0 ENZYME_PLUGIN=%loadEnzyme +; RUN: make build/readwriteread-enzyme0 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath +; RUN: build/readwriteread-enzyme0 +; RUN: make clean-readwriteread-enzyme0 ENZYME_PLUGIN=%loadEnzyme + diff --git a/enzyme/functional_tests_c/testfiles/readwriteread-enzyme1.test b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme1.test new file mode 100644 index 000000000000..9dc3174b8435 --- /dev/null +++ b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme1.test @@ -0,0 +1,6 @@ +; RUN: cd %desired_wd +; RUN: make clean-readwriteread-enzyme1 ENZYME_PLUGIN=%loadEnzyme +; RUN: make build/readwriteread-enzyme1 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath +; RUN: build/readwriteread-enzyme1 +; RUN: make clean-readwriteread-enzyme1 ENZYME_PLUGIN=%loadEnzyme + diff --git a/enzyme/functional_tests_c/testfiles/readwriteread-enzyme2.test b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme2.test new file mode 100644 index 000000000000..e03f5242726c --- /dev/null +++ b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme2.test @@ -0,0 +1,6 @@ +; RUN: cd %desired_wd +; RUN: make clean-readwriteread-enzyme2 ENZYME_PLUGIN=%loadEnzyme +; RUN: make build/readwriteread-enzyme2 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath +; RUN: build/readwriteread-enzyme2 +; RUN: make clean-readwriteread-enzyme2 ENZYME_PLUGIN=%loadEnzyme + diff --git a/enzyme/functional_tests_c/testfiles/readwriteread-enzyme3.test b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme3.test new file mode 100644 index 000000000000..40efc5f2c7e7 --- /dev/null +++ b/enzyme/functional_tests_c/testfiles/readwriteread-enzyme3.test @@ -0,0 +1,6 @@ +; RUN: cd %desired_wd +; RUN: make clean-readwriteread-enzyme3 ENZYME_PLUGIN=%loadEnzyme +; RUN: make build/readwriteread-enzyme3 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath +; RUN: build/readwriteread-enzyme3 +; RUN: make clean-readwriteread-enzyme3 ENZYME_PLUGIN=%loadEnzyme + From 2d020bd7d4a35cb2699f4b6952d52398bfbcedd8 Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Thu, 31 Oct 2019 17:37:10 +0000 Subject: [PATCH 05/22] all enzyme-check tests work except the badcall tests --- enzyme/Enzyme/EnzymeLogic.cpp | 88 +++++++++++++++++++++++++++++++-- enzyme/Enzyme/GradientUtils.cpp | 12 +++-- enzyme/Enzyme/GradientUtils.h | 3 ++ 3 files changed, 95 insertions(+), 8 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index d833719c87a3..57eb64e291b8 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -48,7 +48,7 @@ llvm::cl::opt enzyme_print("enzyme_print", cl::init(false), cl::Hidden, cl::desc("Print before and after fns for autodiff")); cl::opt cachereads( - "enzyme_cachereads", cl::init(false), cl::Hidden, + "enzyme_cachereads", cl::init(true), cl::Hidden, cl::desc("Force caching of all reads")); //! return structtype if recursive function @@ -155,6 +155,22 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul if(auto intrinsic = dyn_cast(_inst)) { continue; } + Function* called = _op->getCalledFunction(); + if (auto castinst = dyn_cast(_op->getCalledValue())) { + if (castinst->isCast()) { + if (auto fn = dyn_cast(castinst->getOperand(0))) { + if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) { + called = fn; + } + } + } + } + + + //if (_op->getCalledFunction()->getName()=="free") { + if (isCertainMallocOrFree(called)) { + continue; + } std::set volatile_args; std::vector args; llvm::errs() << "args are: "; @@ -181,6 +197,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul for (auto I = BB->begin(), E = BB->end(); I != E; I++) { Instruction* inst = &*I; if (inst == _inst) continue; + if (gutils->DT.dominates(inst, _inst)) { llvm::errs() << inst->getParent() << "\n"; llvm::errs() << _inst->getParent() << "\n"; @@ -200,6 +217,27 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul } } if (auto op = dyn_cast(inst)) { + + + Function* called = op->getCalledFunction(); + if (auto castinst = dyn_cast(op->getCalledValue())) { + if (castinst->isCast()) { + if (auto fn = dyn_cast(castinst->getOperand(0))) { + if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) { + called = fn; + } + } + } + } + + if (op->getCalledFunction()) { + llvm::errs() << "Called Function name is " << op->getCalledFunction()->getName() << "\n"; + } else { + llvm::errs() << "Called Function is null \n";// << op->getCalledFunction()->getName() << "\n"; + } + if (isCertainMallocOrFree(called)) { + continue; + } for (int i = 0; i < args.size(); i++) { if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(_op, i, TLI)))) { llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; @@ -229,6 +267,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul std::map can_modref_map; + gutils->can_modref_map = &can_modref_map; if (true) { //!additionalArg && !topLevel) { for(BasicBlock* BB: gutils->originalBlocks) { llvm::errs() << "BB: " << *BB << "\n"; @@ -586,7 +625,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul gutils->erase(op); } else if(LoadInst* li = dyn_cast(inst)) { - if (/*gutils->isConstantInstruction(inst) ||*/ gutils->isConstantValue(inst)) continue; + if (gutils->isConstantInstruction(inst) || gutils->isConstantValue(inst)) continue; if (/*true || */(cachereads && can_modref_map[inst])) { llvm::errs() << "Forcibly caching reads " << *li << "\n"; IRBuilder<> BuilderZ(li); @@ -1712,6 +1751,25 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co if(auto intrinsic = dyn_cast(_inst)) { continue; } + + Function* called = _op->getCalledFunction(); + if (auto castinst = dyn_cast(_op->getCalledValue())) { + if (castinst->isCast()) { + if (auto fn = dyn_cast(castinst->getOperand(0))) { + if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) { + called = fn; + } + } + } + } + + + //if (_op->getCalledFunction()) { + //if (_op->getCalledFunction()->getName()=="free") { + if (isCertainMallocOrFree(called)) { + continue; + } + //} std::set volatile_args; std::vector args; llvm::errs() << "args are: "; @@ -1738,6 +1796,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co for (auto I = BB->begin(), E = BB->end(); I != E; I++) { Instruction* inst = &*I; if (inst == _inst) continue; + if (gutils->DT.dominates(inst, _inst)) { llvm::errs() << inst->getParent() << "\n"; llvm::errs() << _inst->getParent() << "\n"; @@ -1757,6 +1816,28 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co } } if (auto op = dyn_cast(inst)) { + + Function* called = op->getCalledFunction(); + if (auto castinst = dyn_cast(op->getCalledValue())) { + if (castinst->isCast()) { + if (auto fn = dyn_cast(castinst->getOperand(0))) { + if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) { + called = fn; + } + } + } + } + + if (op->getCalledFunction()) { + llvm::errs() << "Called Function name is " << op->getCalledFunction()->getName() << "\n"; + } else { + llvm::errs() << "Called Function is null \n";// << op->getCalledFunction()->getName() << "\n"; + } + + if (isCertainMallocOrFree(called)) { + continue; + } + for (int i = 0; i < args.size(); i++) { if (!args[i]->getType()->isPointerTy()) continue; llvm::errs() << "TFKDEBUG: " << *args[i] << "\n"; @@ -1794,6 +1875,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co std::map can_modref_map; + gutils->can_modref_map = &can_modref_map; if (/*!additionalArg && */!topLevel) { for(BasicBlock* BB: gutils->originalBlocks) { for (BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend(); I != E;) { @@ -2252,7 +2334,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co if (dif1) addToDiffe(op->getOperand(1), dif1); if (dif2) addToDiffe(op->getOperand(2), dif2); } else if(auto op = dyn_cast(inst)) { - if (gutils->isConstantValue(inst)) continue; + if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) continue; llvm::errs() << "TFKDEBUG Saw load instruction: " << *inst << "\n"; diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 2d2398ca6c48..db0df980bd24 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -223,7 +223,7 @@ bool shouldRecompute(Value* val, const ValueToValueMapTy& available) { } else if (auto op = dyn_cast(val)) { return shouldRecompute(op->getOperand(0), available) || shouldRecompute(op->getOperand(1), available) || shouldRecompute(op->getOperand(2), available); } else if (auto load = dyn_cast(val)) { - return true; // NOTE(TFK): Remove this. + //return true; // NOTE(TFK): Remove this. Value* idx = load->getOperand(0); while (!isa(idx)) { if (auto gep = dyn_cast(idx)) { @@ -828,10 +828,12 @@ Value* GradientUtils::lookupM(Value* val, IRBuilder<>& BuilderM) { } } - if (!shouldRecompute(inst, available)) { - auto op = unwrapM(inst, BuilderM, available, /*lookupIfAble*/true); - assert(op); - return op; + if (!(*(this->can_modref_map))[inst]) { + if (!shouldRecompute(inst, available)) { + auto op = unwrapM(inst, BuilderM, available, /*lookupIfAble*/true); + assert(op); + return op; + } } /* if (!inLoop) { diff --git a/enzyme/Enzyme/GradientUtils.h b/enzyme/Enzyme/GradientUtils.h index f3f2e78a780e..ffa9207c80b0 100644 --- a/enzyme/Enzyme/GradientUtils.h +++ b/enzyme/Enzyme/GradientUtils.h @@ -89,6 +89,9 @@ class GradientUtils { ValueToValueMapTy scopeFrees; ValueToValueMapTy originalToNewFn; + std::map* can_modref_map; + + Value* getNewFromOriginal(Value* originst) { assert(originst); auto f = originalToNewFn.find(originst); From 79aab0c9700c084b9ec3fd2e0781880f64ef60f9 Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Thu, 31 Oct 2019 20:31:03 +0000 Subject: [PATCH 06/22] modify the badcall tests so that they pass --- enzyme/test/Enzyme/badcall.ll | 30 ++++++++++++-------- enzyme/test/Enzyme/badcall2.ll | 25 +++++++++++------ enzyme/test/Enzyme/badcall3.ll | 27 +++++++++++------- enzyme/test/Enzyme/badcall4.ll | 17 ++++++------ enzyme/test/Enzyme/badcallused.ll | 44 +++++++++++++++++------------- enzyme/test/Enzyme/badcallused2.ll | 42 +++++++++++++++------------- 6 files changed, 108 insertions(+), 77 deletions(-) diff --git a/enzyme/test/Enzyme/badcall.ll b/enzyme/test/Enzyme/badcall.ll index 9672654917b2..15518f2ebd1d 100644 --- a/enzyme/test/Enzyme/badcall.ll +++ b/enzyme/test/Enzyme/badcall.ll @@ -42,11 +42,12 @@ attributes #1 = { noinline nounwind uwtable } ; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call { { {} } } @augmented_subf(double* %x, double* %"x'") -; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8 -; CHECK-NEXT: store double 0.000000e+00, double* %"x'" -; CHECK-NEXT: %1 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {} } undef) -; CHECK-NEXT: ret {} undef +; CHECK-NEXT: %0 = call { { {}, double } } @augmented_subf(double* %x, double* %"x'") +; CHECK-NEXT: %1 = extractvalue { { {}, double } } %0, 0 +; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8 +; CHECK-NEXT: store double 0.000000e+00, double* %"x'" +; CHECK-NEXT: %2 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, double } %1) +; CHECK-NEXT: ret {} undef ; CHECK-NEXT: } ; CHECK: define internal {{(dso_local )?}}{ {} } @augmented_metasubf(double* nocapture %x, double* %"x'") @@ -56,16 +57,21 @@ attributes #1 = { noinline nounwind uwtable } ; CHECK-NEXT: ret { {} } undef ; CHECK-NEXT: } -; CHECK: define internal {{(dso_local )?}}{ { {} } } @augmented_subf(double* nocapture %x, double* %"x'") +; CHECK: define internal {{(dso_local )?}}{ { {}, double } } @augmented_subf(double* nocapture %x, double* %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = load double, double* %x, align 8 -; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00 -; CHECK-NEXT: store double %mul, double* %x, align 8 -; CHECK-NEXT: %1 = call { {} } @augmented_metasubf(double* %x, double* %"x'") -; CHECK-NEXT: ret { { {} } } undef +; CHECK-NEXT: %0 = alloca { { {}, double } } +; CHECK-NEXT: %1 = getelementptr { { {}, double } }, { { {}, double } }* %0, i32 0, i32 0 +; CHECK-NEXT: %2 = load double, double* %x, align 8 +; CHECK-NEXT: %3 = getelementptr { {}, double }, { {}, double }* %1, i32 0, i32 1 +; CHECK-NEXT: store double %2, double* %3 +; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00 +; CHECK-NEXT: store double %mul, double* %x, align 8 +; CHECK-NEXT: %4 = call { {} } @augmented_metasubf(double* %x, double* %"x'") +; CHECK-NEXT: %5 = load { { {}, double } }, { { {}, double } }* %0 +; CHECK-NEXT: ret { { {}, double } } %5 ; CHECK-NEXT: } -; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {} } %tapeArg) +; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, double } %tapeArg) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef) ; CHECK-NEXT: %1 = load double, double* %"x'" diff --git a/enzyme/test/Enzyme/badcall2.ll b/enzyme/test/Enzyme/badcall2.ll index 10a46708f25f..0f47f7f1435e 100644 --- a/enzyme/test/Enzyme/badcall2.ll +++ b/enzyme/test/Enzyme/badcall2.ll @@ -50,10 +50,11 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed ; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call { { {}, {} } } @augmented_subf(double* %x, double* %"x'") +; CHECK-NEXT: %0 = call { { {}, {}, double } } @augmented_subf(double* %x, double* %"x'") +; CHECK-NEXT: %1 = extractvalue { { {}, {}, double } } %0, 0 ; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8 ; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8 -; CHECK-NEXT: %1 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {} } undef) +; CHECK-NEXT: %2 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {}, double } %1) ; CHECK-NEXT: ret {} undef ; CHECK-NEXT: } @@ -71,17 +72,23 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed ; CHECK-NEXT: ret { {} } undef ; CHECK-NEXT: } -; CHECK: define internal {{(dso_local )?}}{ { {}, {} } } @augmented_subf(double* nocapture %x, double* %"x'") +; CHECK: define internal {{(dso_local )?}}{ { {}, {}, double } } @augmented_subf(double* nocapture %x, double* %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = load double, double* %x, align 8 -; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00 +; CHECK-NEXT: %0 = alloca { { {}, {}, double } } +; CHECK-NEXT: %1 = getelementptr { { {}, {}, double } }, { { {}, {}, double } }* %0, i32 0, i32 0 +; CHECK-NEXT: %2 = load double, double* %x, align 8 +; CHECK-NEXT: %3 = getelementptr { {}, {}, double }, { {}, {}, double }* %1, i32 0, i32 2 +; CHECK-NEXT: store double %2, double* %3 +; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00 ; CHECK-NEXT: store double %mul, double* %x, align 8 -; CHECK-NEXT: %1 = call { {} } @augmented_metasubf(double* %x, double* %"x'") -; CHECK-NEXT: %2 = call { {} } @augmented_othermetasubf(double* %x, double* %"x'") -; CHECK-NEXT: ret { { {}, {} } } undef +; CHECK-NEXT: %4 = call { {} } @augmented_metasubf(double* %x, double* %"x'") +; CHECK-NEXT: %5 = call { {} } @augmented_othermetasubf(double* %x, double* %"x'") +; CHECK-NEXT: %6 = load { { {}, {}, double } }, { { {}, {}, double } }* %0 +; CHECK-NEXT: ret { { {}, {}, double } } %6 ; CHECK-NEXT: } -; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, {} } %tapeArg) + +; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, {}, double } %tapeArg) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = call {} @diffeothermetasubf(double* %x, double* %"x'", {} undef) ; CHECK-NEXT: %1 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef) diff --git a/enzyme/test/Enzyme/badcall3.ll b/enzyme/test/Enzyme/badcall3.ll index 86fb9083359b..0d0b936da2fe 100644 --- a/enzyme/test/Enzyme/badcall3.ll +++ b/enzyme/test/Enzyme/badcall3.ll @@ -50,10 +50,11 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed ; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call { { {}, {} } } @augmented_subf(double* %x, double* %"x'") +; CHECK-NEXT: %0 = call { { {}, {}, double } } @augmented_subf(double* %x, double* %"x'") +; CHECK-NEXT: %1 = extractvalue { { {}, {}, double } } %0, 0 ; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8 ; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8 -; CHECK-NEXT: %1 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {} } undef) +; CHECK-NEXT: %2 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {}, double } %1) ; CHECK-NEXT: ret {} undef ; CHECK-NEXT: } @@ -71,17 +72,23 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed ; CHECK-NEXT: ret { {} } undef ; CHECK-NEXT: } -; CHECK: define internal {{(dso_local )?}}{ { {}, {} } } @augmented_subf(double* nocapture %x, double* %"x'") +; CHECK: define internal {{(dso_local )?}}{ { {}, {}, double } } @augmented_subf(double* nocapture %x, double* %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = load double, double* %x, align 8 -; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00 -; CHECK-NEXT: store double %mul, double* %x, align 8 -; CHECK-NEXT: %1 = call { {} } @augmented_metasubf(double* %x, double* %"x'") -; CHECK-NEXT: %2 = call { {} } @augmented_othermetasubf(double* %x, double* %"x'") -; CHECK-NEXT: ret { { {}, {} } } undef +; CHECK-NEXT: %0 = alloca { { {}, {}, double } } +; CHECK-NEXT: %1 = getelementptr { { {}, {}, double } }, { { {}, {}, double } }* %0, i32 0, i32 0 +; CHECK-NEXT: %2 = load double, double* %x, align 8 +; CHECK-NEXT: %3 = getelementptr { {}, {}, double }, { {}, {}, double }* %1, i32 0, i32 2 +; CHECK-NEXT: store double %2, double* %3 +; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00 +; CHECK-NEXT: store double %mul, double* %x, align 8 +; CHECK-NEXT: %4 = call { {} } @augmented_metasubf(double* %x, double* %"x'") +; CHECK-NEXT: %5 = call { {} } @augmented_othermetasubf(double* %x, double* %"x'") +; CHECK-NEXT: %6 = load { { {}, {}, double } }, { { {}, {}, double } }* %0 +; CHECK-NEXT: ret { { {}, {}, double } } %6 ; CHECK-NEXT: } -; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, {} } %tapeArg) + +; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, {}, double } %tapeArg) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = call {} @diffeothermetasubf(double* %x, double* %"x'", {} undef) ; CHECK-NEXT: %1 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef) diff --git a/enzyme/test/Enzyme/badcall4.ll b/enzyme/test/Enzyme/badcall4.ll index b7183c501717..b099fac3c2e9 100644 --- a/enzyme/test/Enzyme/badcall4.ll +++ b/enzyme/test/Enzyme/badcall4.ll @@ -51,11 +51,11 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed ; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call { { {}, i1, {}, i1 } } @augmented_subf(double* %x, double* %"x'") -; CHECK-NEXT: %1 = extractvalue { { {}, i1, {}, i1 } } %0, 0 +; CHECK-NEXT: %0 = call { { {}, i1, {}, i1, double } } @augmented_subf(double* %x, double* %"x'") +; CHECK-NEXT: %1 = extractvalue { { {}, i1, {}, i1, double } } %0, 0 ; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8 ; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8 -; CHECK-NEXT: %2 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, i1, {}, i1 } %1) +; CHECK-NEXT: %2 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, i1, {}, i1, double } %1) ; CHECK-NEXT: ret {} undef ; CHECK-NEXT: } @@ -63,7 +63,7 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed ; CHECK: define internal {{(dso_local )?}}{ {}, i1 } @augmented_metasubf(double* nocapture %x, double* %"x'") -; CHECK: define internal {{(dso_local )?}}{ { {}, i1, {}, i1 } } @augmented_subf(double* nocapture %x, double* %"x'") +; CHECK: define internal {{(dso_local )?}}{ { {}, i1, {}, i1, double } } @augmented_subf(double* nocapture %x, double* %"x'") ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = load double, double* %x, align 8 ; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00 @@ -72,12 +72,13 @@ declare dso_local double @__enzyme_autodiff(i8*, double*, double*) local_unnamed ; CHECK-NEXT: %2 = extractvalue { {}, i1 } %1, 1 ; CHECK-NEXT: %3 = call { {}, i1 } @augmented_othermetasubf(double* %x, double* %"x'") ; CHECK-NEXT: %4 = extractvalue { {}, i1 } %3, 1 -; CHECK-NEXT: %[[iv1:.+]] = insertvalue { { {}, i1, {}, i1 } } undef, i1 %4, 0, 1 -; CHECK-NEXT: %[[iv2:.+]] = insertvalue { { {}, i1, {}, i1 } } %[[iv1]], i1 %2, 0, 3 -; CHECK-NEXT: ret { { {}, i1, {}, i1 } } %[[iv2]] +; CHECK-NEXT: %.fca.0.1.insert = insertvalue { { {}, i1, {}, i1, double } } undef, i1 %4, 0, 1 +; CHECK-NEXT: %.fca.0.3.insert = insertvalue { { {}, i1, {}, i1, double } } %.fca.0.1.insert, i1 %2, 0, 3 +; CHECK-NEXT: %.fca.0.4.insert = insertvalue { { {}, i1, {}, i1, double } } %.fca.0.3.insert, double %0, 0, 4 +; CHECK-NEXT: ret { { {}, i1, {}, i1, double } } %.fca.0.4.insert ; CHECK-NEXT: } -; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, i1, {}, i1 } %tapeArg) +; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, i1, {}, i1, double } %tapeArg) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = call {} @diffeothermetasubf(double* %x, double* %"x'", {} undef) ; CHECK-NEXT: %1 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef) diff --git a/enzyme/test/Enzyme/badcallused.ll b/enzyme/test/Enzyme/badcallused.ll index 51f8b2b915ed..e39062a1751e 100644 --- a/enzyme/test/Enzyme/badcallused.ll +++ b/enzyme/test/Enzyme/badcallused.ll @@ -43,12 +43,13 @@ attributes #1 = { noinline nounwind uwtable } ; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call { { {} }, i1, i1 } @augmented_subf(double* %x, double* %"x'") -; CHECK-NEXT: %1 = extractvalue { { {} }, i1, i1 } %0, 1 -; CHECK-NEXT: %sel = select i1 %1, double 2.000000e+00, double 3.000000e+00 +; CHECK-NEXT: %0 = call { { {}, double }, i1, i1 } @augmented_subf(double* %x, double* %"x'") +; CHECK-NEXT: %1 = extractvalue { { {}, double }, i1, i1 } %0, 0 +; CHECK-NEXT: %2 = extractvalue { { {}, double }, i1, i1 } %0, 1 +; CHECK-NEXT: %sel = select i1 %2, double 2.000000e+00, double 3.000000e+00 ; CHECK-NEXT: store double %sel, double* %x, align 8 -; CHECK-NEXT: store double 0.000000e+00, double* %"x'" -; CHECK-NEXT: %[[dsubf:.+]] = call {} @diffesubf(double* nonnull %x, double* %"x'", { {} } undef) +; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8 +; CHECK-NEXT: %3 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, double } %1) ; CHECK-NEXT: ret {} undef ; CHECK-NEXT: } @@ -65,24 +66,29 @@ attributes #1 = { noinline nounwind uwtable } ; CHECK-NEXT: ret { {}, i1, i1 } %3 ; CHECK-NEXT: } -; CHECK: define internal {{(dso_local )?}}{ { {} }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'") +; CHECK: define internal {{(dso_local )?}}{ { {}, double }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = alloca { { {} }, i1, i1 } -; CHECK-NEXT: %1 = load double, double* %x, align 8 -; CHECK-NEXT: %mul = fmul fast double %1, 2.000000e+00 +; CHECK-NEXT: %0 = alloca { { {}, double }, i1, i1 } +; CHECK-NEXT: %1 = getelementptr { { {}, double }, i1, i1 } +; CHECK-NEXT: %2 = load double, double* %x, align 8 +; CHECK-NEXT: %3 = getelementptr { {}, double }, { {}, double }* %1, i32 0, i32 1 +; CHECK-NEXT: store double %2, double* %3 +; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00 ; CHECK-NEXT: store double %mul, double* %x, align 8 -; CHECK-NEXT: %2 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'") -; CHECK-NEXT: %3 = extractvalue { {}, i1, i1 } %2, 1 -; CHECK-NEXT: %antiptr_call = extractvalue { {}, i1, i1 } %2, 2 -; CHECK-NEXT: %4 = getelementptr { { {} }, i1, i1 }, { { {} }, i1, i1 }* %0, i32 0, i32 1 -; CHECK-NEXT: store i1 %3, i1* %4 -; CHECK-NEXT: %5 = getelementptr { { {} }, i1, i1 }, { { {} }, i1, i1 }* %0, i32 0, i32 2 -; CHECK-NEXT: store i1 %antiptr_call, i1* %5 -; CHECK-NEXT: %[[toret:.+]] = load { { {} }, i1, i1 }, { { {} }, i1, i1 }* %0 -; CHECK-NEXT: ret { { {} }, i1, i1 } %[[toret]] +; CHECK-NEXT: %4 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'") +; CHECK-NEXT: %5 = extractvalue { {}, i1, i1 } %4, 1 +; CHECK-NEXT: %antiptr_call = extractvalue { {}, i1, i1 } %4, 2 + + +; CHECK-NEXT: %6 = getelementptr { { {}, double }, i1, i1 }, { { {}, double }, i1, i1 }* %0, i32 0, i32 1 +; CHECK-NEXT: store i1 %5, i1* %6 +; CHECK-NEXT: %7 = getelementptr { { {}, double }, i1, i1 }, { { {}, double }, i1, i1 }* %0, i32 0, i32 2 +; CHECK-NEXT: store i1 %antiptr_call, i1* %7 +; CHECK-NEXT: %[[toret:.+]] = load { { {}, double }, i1, i1 }, { { {}, double }, i1, i1 }* %0 +; CHECK-NEXT: ret { { {}, double }, i1, i1 } %[[toret]] ; CHECK-NEXT: } -; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {} } %tapeArg) +; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, double } %tapeArg) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef) ; CHECK-NEXT: %1 = load double, double* %"x'" diff --git a/enzyme/test/Enzyme/badcallused2.ll b/enzyme/test/Enzyme/badcallused2.ll index 92069b003948..0513dde7ad9f 100644 --- a/enzyme/test/Enzyme/badcallused2.ll +++ b/enzyme/test/Enzyme/badcallused2.ll @@ -53,12 +53,13 @@ attributes #1 = { noinline nounwind uwtable } ; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call { { {}, {} }, i1, i1 } @augmented_subf(double* %x, double* %"x'") -; CHECK-NEXT: %1 = extractvalue { { {}, {} }, i1, i1 } %0, 1 -; CHECK-NEXT: %sel = select i1 %1, double 2.000000e+00, double 3.000000e+00 +; CHECK-NEXT: %0 = call { { {}, {}, double }, i1, i1 } @augmented_subf(double* %x, double* %"x'") +; CHECK-NEXT: %1 = extractvalue { { {}, {}, double }, i1, i1 } %0, 0 +; CHECK-NEXT: %2 = extractvalue { { {}, {}, double }, i1, i1 } %0, 1 +; CHECK-NEXT: %sel = select i1 %2, double 2.000000e+00, double 3.000000e+00 ; CHECK-NEXT: store double %sel, double* %x, align 8 ; CHECK-NEXT: store double 0.000000e+00, double* %"x'" -; CHECK-NEXT: %[[dsubf:.+]] = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {} } undef) +; CHECK-NEXT: %[[dsubf:.+]] = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, {}, double } %1) ; CHECK-NEXT: ret {} undef ; CHECK-NEXT: } @@ -82,25 +83,28 @@ attributes #1 = { noinline nounwind uwtable } ; CHECK-NEXT: ret { {} } undef ; CHECK-NEXT: } -; CHECK: define internal {{(dso_local )?}}{ { {}, {} }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'") +; CHECK: define internal {{(dso_local )?}}{ { {}, {}, double }, i1, i1 } @augmented_subf(double* nocapture %x, double* %"x'") ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = alloca { { {}, {} }, i1, i1 } -; CHECK-NEXT: %1 = load double, double* %x, align 8 -; CHECK-NEXT: %mul = fmul fast double %1, 2.000000e+00 +; CHECK-NEXT: %0 = alloca { { {}, {}, double }, i1, i1 } +; CHECK-NEXT: %1 = getelementptr { { {}, {}, double }, i1, i1 }, { { {}, {}, double }, i1, i1 }* %0, i32 0, i32 0 +; CHECK-NEXT: %2 = load double, double* %x, align 8 +; CHECK-NEXT: %3 = getelementptr { {}, {}, double }, { {}, {}, double }* %1, i32 0, i32 2 +; CHECK-NEXT: store double %2, double* %3 +; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00 ; CHECK-NEXT: store double %mul, double* %x, align 8 -; CHECK-NEXT: %2 = call { {} } @augmented_omegasubf(double* %x, double* %"x'") -; CHECK-NEXT: %3 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'") -; CHECK-NEXT: %4 = extractvalue { {}, i1, i1 } %3, 1 -; CHECK-NEXT: %antiptr_call2 = extractvalue { {}, i1, i1 } %3, 2 -; CHECK-NEXT: %5 = getelementptr { { {}, {} }, i1, i1 }, { { {}, {} }, i1, i1 }* %0, i32 0, i32 1 -; CHECK-NEXT: store i1 %4, i1* %5 -; CHECK-NEXT: %6 = getelementptr { { {}, {} }, i1, i1 }, { { {}, {} }, i1, i1 }* %0, i32 0, i32 2 -; CHECK-NEXT: store i1 %antiptr_call2, i1* %6 -; CHECK-NEXT: %[[toret:.+]] = load { { {}, {} }, i1, i1 }, { { {}, {} }, i1, i1 }* %0 -; CHECK-NEXT: ret { { {}, {} }, i1, i1 } %[[toret]] +; CHECK-NEXT: %4 = call { {} } @augmented_omegasubf(double* %x, double* %"x'") +; CHECK-NEXT: %5 = call { {}, i1, i1 } @augmented_metasubf(double* %x, double* %"x'") +; CHECK-NEXT: %6 = extractvalue { {}, i1, i1 } %5, 1 +; CHECK-NEXT: %antiptr_call2 = extractvalue { {}, i1, i1 } %5, 2 +; CHECK-NEXT: %7 = getelementptr { { {}, {}, double }, i1, i1 }, { { {}, {}, double }, i1, i1 }* %0, i32 0, i32 1 +; CHECK-NEXT: store i1 %6, i1* %7 +; CHECK-NEXT: %8 = getelementptr { { {}, {}, double }, i1, i1 }, { { {}, {}, double }, i1, i1 }* %0, i32 0, i32 2 +; CHECK-NEXT: store i1 %antiptr_call2, i1* %8 +; CHECK-NEXT: %[[toret:.+]] = load { { {}, {}, double }, i1, i1 }, { { {}, {}, double }, i1, i1 }* %0 +; CHECK-NEXT: ret { { {}, {}, double }, i1, i1 } %[[toret]] ; CHECK-NEXT: } -; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, {} } %tapeArg) +; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, {}, double } %tapeArg) ; CHECK-NEXT: entry: ; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef) ; CHECK-NEXT: %1 = call {} @diffeomegasubf(double* %x, double* %"x'", {} undef) From 5d6669d3352f39ca5cf3bdaed85fa320d4e00d4c Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Thu, 31 Oct 2019 22:05:36 +0000 Subject: [PATCH 07/22] cleanup --- enzyme/Enzyme/EnzymeLogic.cpp | 524 +++++++++++-------------------- enzyme/test/Enzyme/insertsort.ll | 1 - 2 files changed, 187 insertions(+), 338 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 57eb64e291b8..5cf1fcaa7089 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -51,6 +51,183 @@ cl::opt cachereads( "enzyme_cachereads", cl::init(true), cl::Hidden, cl::desc("Force caching of all reads")); + +std::map compute_volatile_load_map(GradientUtils* gutils, AAResults& AA, + std::set volatile_args) { + std::map can_modref_map; + // NOTE(TFK): Want to construct a test case where this causes an issue. + for(BasicBlock* BB: gutils->originalBlocks) { + for (auto I = BB->begin(), E = BB->end(); I != E; I++) { + Instruction* inst = &*I; + if (auto op = dyn_cast(inst)) { + if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) { + continue; + } + auto op_operand = op->getPointerOperand(); + auto op_type = op->getType(); + bool can_modref = false; + + auto obj = GetUnderlyingObject(op->getPointerOperand(), BB->getModule()->getDataLayout(), 100); + if (auto arg = dyn_cast(obj)) { + if (volatile_args.find(arg->getArgNo()) != volatile_args.end()) { + can_modref = true; + } + } + + for (int k = 0; k < gutils->originalBlocks.size(); k++) { + if (AA.canBasicBlockModify(*(gutils->originalBlocks[k]), MemoryLocation::get(op))) { + can_modref = true; + break; + } + } + can_modref_map[inst] = can_modref; + } + } + } + return can_modref_map; +} + + +std::set compute_volatile_args_for_one_callsite(Instruction* callsite_inst, DominatorTree &DT, + TargetLibraryInfo &TLI, AAResults& AA, GradientUtils* gutils, std::set parent_volatile_args) { + CallInst* callsite_op = dyn_cast(callsite_inst); + assert(callsite_op != nullptr); + + std::set volatile_args; + std::vector args; + std::vector args_safe; + + // First, we need to propagate the volatile status from the parent function to the callee. + // because memory location x modified after parent returns => x modified after callee returns. + for (int i = 0; i < callsite_op->getNumArgOperands(); i++) { + args.push_back(callsite_op->getArgOperand(i)); + bool init_safe = true; + + // If the UnderlyingObject is from one of this function's arguments, then we need to propagate the volatility. + Value* obj = GetUnderlyingObject(callsite_op->getArgOperand(i), + callsite_inst->getParent()->getModule()->getDataLayout(), + 100); + // If underlying object is an Argument, check parent volatility status. + if (auto arg = dyn_cast(obj)) { + if (parent_volatile_args.find(arg->getArgNo()) != parent_volatile_args.end()) { + init_safe = false; + } + } + // TODO(TFK): Also need to check whether underlying object is traced to load / non-allocating-call instruction. + args_safe.push_back(init_safe); + } + + // Second, we check for memory modifications that can occur in the continuation of the + // callee inside the parent function. + for(BasicBlock* BB: gutils->originalBlocks) { + for (auto I = BB->begin(), E = BB->end(); I != E; I++) { + Instruction* inst = &*I; + if (inst == callsite_inst) continue; + + // If the "inst" does not dominate "callsite_inst" then we cannot prove that + // "inst" happens before "callsite_inst". If "inst" modifies an argument of the call, + // then that call needs to consider the argument volatile. + if (!gutils->DT.dominates(inst, callsite_inst)) { + // Consider Store Instructions. + if (auto op = dyn_cast(inst)) { + for (int i = 0; i < args.size(); i++) { + // If the modification flag is set, then this instruction may modify the $i$th argument of the call. + if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) { + //llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; + } else { + //llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; + args_safe[i] = false; + } + } + } + + // Consider Call Instructions. + if (auto op = dyn_cast(inst)) { + // Ignore memory allocation functions. + Function* called = op->getCalledFunction(); + if (auto castinst = dyn_cast(op->getCalledValue())) { + if (castinst->isCast()) { + if (auto fn = dyn_cast(castinst->getOperand(0))) { + if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) { + called = fn; + } + } + } + } + if (isCertainMallocOrFree(called)) { + continue; + } + + // For all the arguments, perform same check as for Stores, but ignore non-pointer arguments. + for (int i = 0; i < args.size(); i++) { + if (!args[i]->getType()->isPointerTy()) continue; // Ignore non-pointer arguments. + if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) { + //llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; + } else { + //llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; + args_safe[i] = false; + } + } + } + } + } + } + + //llvm::errs() << "CallInst: " << *callsite_op<< "CALL ARGUMENT INFO: \n"; + for (int i = 0; i < args.size(); i++) { + if (!args_safe[i]) { + volatile_args.insert(i); + } + //llvm::errs() << "Arg: " << *args[i] << " STATUS: " << args_safe[i] << "\n"; + } + return volatile_args; +} + +// Given a function and the arguments passed to it by its caller that are volatile (_volatile_args) compute +// the set of volatile arguments for each callsite inside the function. A pointer argument is volatile at +// a callsite if the memory pointed to might be modified after that callsite. +std::map > compute_volatile_args_for_callsites( + Function* F, DominatorTree &DT, TargetLibraryInfo &TLI, AAResults& AA, GradientUtils* gutils, + std::set const volatile_args) { + std::map > volatile_args_map; + for(BasicBlock* BB: gutils->originalBlocks) { + for (auto I = BB->begin(), E = BB->end(); I != E; I++) { + Instruction* inst = &*I; + if (auto op = dyn_cast(inst)) { + + // We do not need volatile args for intrinsic functions. So skip such callsites. + if(auto intrinsic = dyn_cast(inst)) { + continue; + } + + // We do not need volatile args for memory allocation functions. So skip such callsites. + Function* called = op->getCalledFunction(); + if (auto castinst = dyn_cast(op->getCalledValue())) { + if (castinst->isCast()) { + if (auto fn = dyn_cast(castinst->getOperand(0))) { + if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) { + called = fn; + } + } + } + } + if (isCertainMallocOrFree(called)) { + continue; + } + + // For all other calls, we compute the volatile args for this callsite. + volatile_args_map[op] = compute_volatile_args_for_one_callsite(inst, + DT, TLI, AA, gutils, volatile_args); + } + } + } + return volatile_args_map; +} + + + + + //! return structtype if recursive function std::pair CreateAugmentedPrimal(Function* todiff, AAResults &AA, const std::set& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed, const std::set _volatile_args) { static std::map, std::set, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair> cachedfunctions; @@ -143,174 +320,14 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul llvm::errs() << "arg " << count++ << " is " << *i << " volatile: " << is_volatile << "\n"; } - std::map > volatile_args_map; - //DominatorTree DT(*gutils->oldFunc); + std::map > volatile_args_map = + compute_volatile_args_for_callsites(gutils->oldFunc, gutils->DT, TLI, AA, gutils, _volatile_args); llvm::errs() << "Old function content is " << *gutils->oldFunc << "\n"; - for(BasicBlock* _BB: gutils->originalBlocks) { - for (auto _I = _BB->begin(), _E = _BB->end(); _I != _E; _I++) { - Instruction* _inst = &*_I; - if (auto _op = dyn_cast(_inst)) { - if(auto intrinsic = dyn_cast(_inst)) { - continue; - } - Function* called = _op->getCalledFunction(); - if (auto castinst = dyn_cast(_op->getCalledValue())) { - if (castinst->isCast()) { - if (auto fn = dyn_cast(castinst->getOperand(0))) { - if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) { - called = fn; - } - } - } - } - - //if (_op->getCalledFunction()->getName()=="free") { - if (isCertainMallocOrFree(called)) { - continue; - } - std::set volatile_args; - std::vector args; - llvm::errs() << "args are: "; - std::vector args_safe; - for (int i = 0; i < _op->getNumArgOperands(); i++) { - //if (_op->getArgOperand(i)->getType()->isPointerTy()) { - args.push_back(_op->getArgOperand(i)); - bool init_safe = true; - // If the UnderlyingObject is from one of this function's arguments, then we need to propagate the volatility. - Value* obj = GetUnderlyingObject(_op->getArgOperand(i),_BB->getModule()->getDataLayout(),100); - if (auto arg = dyn_cast(obj)) { - if (_volatile_args.find(arg->getArgNo()) != _volatile_args.end()) { - init_safe = false; - } - } - - args_safe.push_back(init_safe); - llvm::errs() << " "<< *_op->getArgOperand(i) <<" "; - //} - } - llvm::errs() << "\n"; - - for(BasicBlock* BB: gutils->originalBlocks) { - for (auto I = BB->begin(), E = BB->end(); I != E; I++) { - Instruction* inst = &*I; - if (inst == _inst) continue; - - if (gutils->DT.dominates(inst, _inst)) { - llvm::errs() << inst->getParent() << "\n"; - llvm::errs() << _inst->getParent() << "\n"; - llvm::errs() << "callinst: " << *_op <<"DOES dominate " << *I << "\n"; - } else { - llvm::errs() << "callinst: " << *_op <<"does not dominate " << *I << "\n"; - // In this case "inst" may occur after the call instruction (_inst). If "inst" is a store, it might necessitate caching a load inside the call. - - if (auto op = dyn_cast(inst)) { - for (int i = 0; i < args.size(); i++) { - if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(_op, i, TLI)/*, MemoryLocation::UnknownSize)*/))) { - llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; - } else { - llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; - args_safe[i] = false; - } - } - } - if (auto op = dyn_cast(inst)) { - - - Function* called = op->getCalledFunction(); - if (auto castinst = dyn_cast(op->getCalledValue())) { - if (castinst->isCast()) { - if (auto fn = dyn_cast(castinst->getOperand(0))) { - if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) { - called = fn; - } - } - } - } - - if (op->getCalledFunction()) { - llvm::errs() << "Called Function name is " << op->getCalledFunction()->getName() << "\n"; - } else { - llvm::errs() << "Called Function is null \n";// << op->getCalledFunction()->getName() << "\n"; - } - if (isCertainMallocOrFree(called)) { - continue; - } - for (int i = 0; i < args.size(); i++) { - if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(_op, i, TLI)))) { - llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; - } else { - llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; - args_safe[i] = false; - } - } - } - - - } - } - } - llvm::errs() << "CallInst: " << *_op<< "CALL ARGUMENT INFO: \n"; - for (int i = 0; i < args.size(); i++) { - if (!args_safe[i]) { - volatile_args.insert(i); - } - llvm::errs() << "Arg: " << *args[i] << " STATUS: " << args_safe[i] << "\n"; - } - volatile_args_map[_op] = volatile_args; - } - } - } - - - - std::map can_modref_map; + std::map can_modref_map = compute_volatile_load_map(gutils, AA, _volatile_args); gutils->can_modref_map = &can_modref_map; - if (true) { //!additionalArg && !topLevel) { - for(BasicBlock* BB: gutils->originalBlocks) { - llvm::errs() << "BB: " << *BB << "\n"; - //for (BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend(); I != E;) { - for (auto I = BB->begin(), E = BB->end(); I != E; I++) { - Instruction* inst = &*I; - if (auto op = dyn_cast(inst)) { - if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) { - //I++; - continue; - } - auto op_operand = op->getPointerOperand(); - auto op_type = op->getType(); - bool can_modref = false; - - auto obj = GetUnderlyingObject(op->getPointerOperand(), BB->getModule()->getDataLayout(), 100); - if (auto arg = dyn_cast(obj)) { - if (_volatile_args.find(arg->getArgNo()) != _volatile_args.end()) { - can_modref = true; - } - } - - - llvm::errs() << "TFKDEBUG: looking at modref status of inst<"<<*inst << ">\n"; - for (int k = 0; k < gutils->originalBlocks.size(); k++) { - llvm::errs() << "TFKDEBUG: in BB: <"<<*(gutils->originalBlocks[k]) << ">\n"; - if (AA.canBasicBlockModify(*(gutils->originalBlocks[k]), MemoryLocation::get(op))) { - can_modref = true; - break; - } - } - llvm::errs() << "TFKDEBUG: modref status of inst<"<<*inst << "> is: " << can_modref << "\n"; - can_modref_map[inst] = can_modref; - } - //I++; - } - } - } - - - - - gutils->forceContexts(); gutils->forceAugmentedReturns(); @@ -1733,186 +1750,19 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co cachedfunctions[tup] = gutils->newFunc; - - - std::set finalized_underlying_objects; - std::set distinct_underlying_objects; - - - std::map > volatile_args_map; - //DominatorTree DT(*gutils->oldFunc); + std::map > volatile_args_map = + compute_volatile_args_for_callsites(gutils->oldFunc, gutils->DT, TLI, AA, gutils, _volatile_args); llvm::errs() << "Old function content is " << *gutils->oldFunc << "\n"; - for(BasicBlock* _BB: gutils->originalBlocks) { - for (auto _I = _BB->begin(), _E = _BB->end(); _I != _E; _I++) { - Instruction* _inst = &*_I; - if (auto _op = dyn_cast(_inst)) { - if(auto intrinsic = dyn_cast(_inst)) { - continue; - } - - Function* called = _op->getCalledFunction(); - if (auto castinst = dyn_cast(_op->getCalledValue())) { - if (castinst->isCast()) { - if (auto fn = dyn_cast(castinst->getOperand(0))) { - if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) { - called = fn; - } - } - } - } - - - //if (_op->getCalledFunction()) { - //if (_op->getCalledFunction()->getName()=="free") { - if (isCertainMallocOrFree(called)) { - continue; - } - //} - std::set volatile_args; - std::vector args; - llvm::errs() << "args are: "; - std::vector args_safe; - for (int i = 0; i < _op->getNumArgOperands(); i++) { - //if (_op->getArgOperand(i)->getType()->isPointerTy()) { - args.push_back(_op->getArgOperand(i)); - bool init_safe = true; - // If the UnderlyingObject is from one of this function's arguments, then we need to propagate the volatility. - Value* obj = GetUnderlyingObject(_op->getArgOperand(i),_BB->getModule()->getDataLayout(),100); - if (auto arg = dyn_cast(obj)) { - if (_volatile_args.find(arg->getArgNo()) != _volatile_args.end()) { - init_safe = false; - } - } - - args_safe.push_back(init_safe); - llvm::errs() << " "<< *_op->getArgOperand(i) <<" "; - //} - } - llvm::errs() << "\n"; - - for(BasicBlock* BB: gutils->originalBlocks) { - for (auto I = BB->begin(), E = BB->end(); I != E; I++) { - Instruction* inst = &*I; - if (inst == _inst) continue; - - if (gutils->DT.dominates(inst, _inst)) { - llvm::errs() << inst->getParent() << "\n"; - llvm::errs() << _inst->getParent() << "\n"; - llvm::errs() << "callinst: " << *_op <<"DOES dominate " << *I << "\n"; - } else { - llvm::errs() << "callinst: " << *_op <<"does not dominate " << *I << "\n"; - // In this case "inst" may occur after the call instruction (_inst). If "inst" is a store, it might necessitate caching a load inside the call. - - if (auto op = dyn_cast(inst)) { - for (int i = 0; i < args.size(); i++) { - if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(_op, i, TLI)/*, MemoryLocation::UnknownSize)*/))) { - llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; - } else { - llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; - args_safe[i] = false; - } - } - } - if (auto op = dyn_cast(inst)) { - - Function* called = op->getCalledFunction(); - if (auto castinst = dyn_cast(op->getCalledValue())) { - if (castinst->isCast()) { - if (auto fn = dyn_cast(castinst->getOperand(0))) { - if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) { - called = fn; - } - } - } - } - - if (op->getCalledFunction()) { - llvm::errs() << "Called Function name is " << op->getCalledFunction()->getName() << "\n"; - } else { - llvm::errs() << "Called Function is null \n";// << op->getCalledFunction()->getName() << "\n"; - } - - if (isCertainMallocOrFree(called)) { - continue; - } - - for (int i = 0; i < args.size(); i++) { - if (!args[i]->getType()->isPointerTy()) continue; - llvm::errs() << "TFKDEBUG: " << *args[i] << "\n"; - if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(_op, i, TLI)))) { - llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; - } else { - llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; - args_safe[i] = false; - } - } - } - - - } - } - } - llvm::errs() << "CallInst: " << *_op<< "CALL ARGUMENT INFO: \n"; - for (int i = 0; i < args.size(); i++) { - if (!args_safe[i]) { - volatile_args.insert(i); - } - llvm::errs() << "Arg: " << *args[i] << " STATUS: " << args_safe[i] << "\n"; - } - volatile_args_map[_op] = volatile_args; - } - } - } - - - - - - - - std::map can_modref_map; - gutils->can_modref_map = &can_modref_map; - if (/*!additionalArg && */!topLevel) { - for(BasicBlock* BB: gutils->originalBlocks) { - for (BasicBlock::reverse_iterator I = BB->rbegin(), E = BB->rend(); I != E;) { - Instruction* inst = &*I; - if (auto op = dyn_cast(inst)) { - if (gutils->isConstantValue(inst)) { - I++; - continue; - } - auto op_operand = op->getPointerOperand(); - auto op_type = op->getType(); - bool can_modref = false; - llvm::errs() << "TFKDEBUG: looking at modref status of inst<"<<*inst << ">\n"; - - auto obj = GetUnderlyingObject(op->getPointerOperand(), BB->getModule()->getDataLayout(), 100); - if (auto arg = dyn_cast(obj)) { - if (_volatile_args.find(arg->getArgNo()) != _volatile_args.end()) { - can_modref = true; - } - } - - - for (int k = 0; k < gutils->originalBlocks.size(); k++) { - if (AA.canBasicBlockModify(*(gutils->originalBlocks[k]), MemoryLocation::get(op))) { - can_modref = true; - break; - } - } - llvm::errs() << "TFKDEBUG: modref status of inst<"<<*inst << "> is: " << can_modref << "\n"; - can_modref_map[inst] = can_modref; - } - I++; - } - } + // NOTE(TFK): Sanity check this decision. + // Is it always possibly to recompute the result of loads at top level? + if (!topLevel) { + can_modref_map = compute_volatile_load_map(gutils, AA, _volatile_args); } - - + gutils->can_modref_map = &can_modref_map; gutils->forceContexts(true); gutils->forceAugmentedReturns(); diff --git a/enzyme/test/Enzyme/insertsort.ll b/enzyme/test/Enzyme/insertsort.ll index 7eeef4c35ee3..20fa8d67897c 100644 --- a/enzyme/test/Enzyme/insertsort.ll +++ b/enzyme/test/Enzyme/insertsort.ll @@ -1,5 +1,4 @@ ; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -instcombine -correlated-propagation -adce -instcombine -simplifycfg -early-cse -simplifycfg -S | FileCheck %s -; XFAIL: * ; Function Attrs: noinline norecurse nounwind uwtable define dso_local void @insertion_sort_inner(float* nocapture %array, i32 %i) local_unnamed_addr #0 { From 15a533fa71d8a7af4af6310ce083d6a353f6f5a5 Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Thu, 31 Oct 2019 22:49:09 +0000 Subject: [PATCH 08/22] cleanup --- enzyme/Enzyme/EnzymeLogic.cpp | 44 +-------------------------------- enzyme/Enzyme/FunctionUtils.cpp | 1 - enzyme/Enzyme/GradientUtils.cpp | 3 --- 3 files changed, 1 insertion(+), 47 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 5cf1fcaa7089..11ad5f157bed 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -224,30 +224,16 @@ std::map > compute_volatile_args_for_callsites( return volatile_args_map; } - - - - //! return structtype if recursive function std::pair CreateAugmentedPrimal(Function* todiff, AAResults &AA, const std::set& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed, const std::set _volatile_args) { static std::map, std::set, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair> cachedfunctions; static std::map, std::set, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished; auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), std::set(_volatile_args.begin(), _volatile_args.end()), differentialReturn, returnUsed); - llvm::errs() << "TFKDEBUG: called CreateAugmentedPrimal " << todiff->getName() << "\n"; - llvm::errs() << "TFKDEBUG: called CreateAugmentedPrimal content: " << *todiff << "\n"; if (cachedfunctions.find(tup) != cachedfunctions.end()) { return cachedfunctions[tup]; } if (differentialReturn) assert(returnUsed); - - - - - - - - if (constant_args.size() == 0 && hasMetadata(todiff, "enzyme_augment")) { auto md = todiff->getMetadata("enzyme_augment"); if (!isa(md)) { @@ -306,26 +292,9 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul cachedfunctions[tup] = std::pair(gutils->newFunc, nullptr); cachedfinished[tup] = false; - llvm::errs() << "Old func: " << *gutils->oldFunc << "\n"; - llvm::errs() << "New func: " << *gutils->newFunc << "\n"; - - - - llvm::errs() << "TFKDEBUG Testing original to new for function:" << *gutils->oldFunc << "\n"; - llvm::errs() << "Arg size is " << gutils->oldFunc->arg_size() << "\n"; - int count = 0; - for (auto i=gutils->oldFunc->arg_begin(); i != gutils->oldFunc->arg_end(); i++) { - bool is_volatile = false; - if (_volatile_args.find(count) != _volatile_args.end()) is_volatile = true; - llvm::errs() << "arg " << count++ << " is " << *i << " volatile: " << is_volatile << "\n"; - } - std::map > volatile_args_map = compute_volatile_args_for_callsites(gutils->oldFunc, gutils->DT, TLI, AA, gutils, _volatile_args); - llvm::errs() << "Old function content is " << *gutils->oldFunc << "\n"; - - std::map can_modref_map = compute_volatile_load_map(gutils, AA, _volatile_args); gutils->can_modref_map = &can_modref_map; @@ -1749,13 +1718,9 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(todiff, AA, TLI, constant_args, returnValue ? ReturnType::ArgsWithReturn : ReturnType::Args, differentialReturn, additionalArg); cachedfunctions[tup] = gutils->newFunc; - std::map > volatile_args_map = compute_volatile_args_for_callsites(gutils->oldFunc, gutils->DT, TLI, AA, gutils, _volatile_args); - llvm::errs() << "Old function content is " << *gutils->oldFunc << "\n"; - - std::map can_modref_map; // NOTE(TFK): Sanity check this decision. // Is it always possibly to recompute the result of loads at top level? @@ -2186,20 +2151,13 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co } else if(auto op = dyn_cast(inst)) { if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) continue; - llvm::errs() << "TFKDEBUG Saw load instruction: " << *inst << "\n"; - auto op_operand = op->getPointerOperand(); auto op_type = op->getType(); if (cachereads) { - - bool can_modref = can_modref_map[inst]; - //can_modref = true; - if ( /*(!topLevel) ||*/ can_modref /*|| additionalArg*/) { llvm::errs() << "Forcibly loading cached reads " << *op << "\n"; + if (can_modref_map[inst]) { IRBuilder<> BuilderZ(op->getNextNode()); inst = cast(gutils->addMalloc(BuilderZ, inst)); - llvm::errs() << "Instruction after force load cache reads: " << *inst << "\n"; - llvm::errs() << "Parent after force load cache reads: " << *(inst->getFunction()) << "\n"; if (inst != op) { // Set to nullptr since op should never be used after invalidated through addMalloc. op = nullptr; diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index 33c12a605bf7..c8dd75d0b4ee 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -475,7 +475,6 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI) //auto saa = new ScopedNoAliasAAResult(sa.run(*NewF, AM)); //AA.addAAResult(*saa); - llvm::errs() << "ran alias analysis on function " << NewF->getName() << "\n"; } if (enzyme_print) diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index db0df980bd24..0b8bb9a570f1 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -223,7 +223,6 @@ bool shouldRecompute(Value* val, const ValueToValueMapTy& available) { } else if (auto op = dyn_cast(val)) { return shouldRecompute(op->getOperand(0), available) || shouldRecompute(op->getOperand(1), available) || shouldRecompute(op->getOperand(2), available); } else if (auto load = dyn_cast(val)) { - //return true; // NOTE(TFK): Remove this. Value* idx = load->getOperand(0); while (!isa(idx)) { if (auto gep = dyn_cast(idx)) { @@ -351,8 +350,6 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) { auto cs = gvemd->getValue(); return invertedPointers[val] = cs; } else if (auto fn = dyn_cast(val)) { - //llvm::errs() << "Note(TFK): Need to disable function pointer casts for now.\n"; - //assert(false); //! Todo allow tape propagation std::set volatile_args; auto newf = CreatePrimalAndGradient(fn, /*constant_args*/{}, TLI, AA, /*returnValue*/false, /*differentialReturn*/fn->getReturnType()->isFPOrFPVectorTy(), /*topLevel*/false, /*additionalArg*/nullptr, volatile_args); From 98710febe5d24b980444ca8ae7c8bac4778be48b Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Thu, 31 Oct 2019 22:55:00 +0000 Subject: [PATCH 09/22] fix compiler warnings --- enzyme/Enzyme/EnzymeLogic.cpp | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 11ad5f157bed..8af5398207db 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -63,8 +63,6 @@ std::map compute_volatile_load_map(GradientUtils* gutils, AA if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) { continue; } - auto op_operand = op->getPointerOperand(); - auto op_type = op->getType(); bool can_modref = false; auto obj = GetUnderlyingObject(op->getPointerOperand(), BB->getModule()->getDataLayout(), 100); @@ -74,7 +72,7 @@ std::map compute_volatile_load_map(GradientUtils* gutils, AA } } - for (int k = 0; k < gutils->originalBlocks.size(); k++) { + for (unsigned k = 0; k < gutils->originalBlocks.size(); k++) { if (AA.canBasicBlockModify(*(gutils->originalBlocks[k]), MemoryLocation::get(op))) { can_modref = true; break; @@ -99,7 +97,7 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ // First, we need to propagate the volatile status from the parent function to the callee. // because memory location x modified after parent returns => x modified after callee returns. - for (int i = 0; i < callsite_op->getNumArgOperands(); i++) { + for (unsigned i = 0; i < callsite_op->getNumArgOperands(); i++) { args.push_back(callsite_op->getArgOperand(i)); bool init_safe = true; @@ -130,7 +128,7 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ if (!gutils->DT.dominates(inst, callsite_inst)) { // Consider Store Instructions. if (auto op = dyn_cast(inst)) { - for (int i = 0; i < args.size(); i++) { + for (unsigned i = 0; i < args.size(); i++) { // If the modification flag is set, then this instruction may modify the $i$th argument of the call. if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) { //llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; @@ -159,7 +157,7 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ } // For all the arguments, perform same check as for Stores, but ignore non-pointer arguments. - for (int i = 0; i < args.size(); i++) { + for (unsigned i = 0; i < args.size(); i++) { if (!args[i]->getType()->isPointerTy()) continue; // Ignore non-pointer arguments. if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) { //llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; @@ -174,7 +172,7 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ } //llvm::errs() << "CallInst: " << *callsite_op<< "CALL ARGUMENT INFO: \n"; - for (int i = 0; i < args.size(); i++) { + for (unsigned i = 0; i < args.size(); i++) { if (!args_safe[i]) { volatile_args.insert(i); } @@ -196,7 +194,7 @@ std::map > compute_volatile_args_for_callsites( if (auto op = dyn_cast(inst)) { // We do not need volatile args for intrinsic functions. So skip such callsites. - if(auto intrinsic = dyn_cast(inst)) { + if(isa(inst)) { continue; } From f54dda29ed75a788c91600bc953178aa34a268bb Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Thu, 31 Oct 2019 23:01:43 +0000 Subject: [PATCH 10/22] make insertsort.ll expected fail again --- enzyme/test/Enzyme/insertsort.ll | 1 + 1 file changed, 1 insertion(+) diff --git a/enzyme/test/Enzyme/insertsort.ll b/enzyme/test/Enzyme/insertsort.ll index 20fa8d67897c..7eeef4c35ee3 100644 --- a/enzyme/test/Enzyme/insertsort.ll +++ b/enzyme/test/Enzyme/insertsort.ll @@ -1,4 +1,5 @@ ; RUN: opt < %s %loadEnzyme -enzyme -enzyme_preopt=false -inline -mem2reg -instcombine -correlated-propagation -adce -instcombine -simplifycfg -early-cse -simplifycfg -S | FileCheck %s +; XFAIL: * ; Function Attrs: noinline norecurse nounwind uwtable define dso_local void @insertion_sort_inner(float* nocapture %array, i32 %i) local_unnamed_addr #0 { From 10f3eab2c759d8aa12b5bea6bc7f8be476e611ca Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Fri, 1 Nov 2019 02:03:54 +0000 Subject: [PATCH 11/22] put in the more strict/correct logic for ordering instructions in single function when checking for modrefs --- enzyme/Enzyme/EnzymeLogic.cpp | 25 ++++++++++++++++------ enzyme/functional_tests_c/insertsort_sum.c | 3 +++ enzyme/functional_tests_c/readwriteread.c | 21 ------------------ 3 files changed, 21 insertions(+), 28 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 8af5398207db..18c31286ad21 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -51,7 +51,6 @@ cl::opt cachereads( "enzyme_cachereads", cl::init(true), cl::Hidden, cl::desc("Force caching of all reads")); - std::map compute_volatile_load_map(GradientUtils* gutils, AAResults& AA, std::set volatile_args) { std::map can_modref_map; @@ -63,21 +62,33 @@ std::map compute_volatile_load_map(GradientUtils* gutils, AA if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) { continue; } - bool can_modref = false; + bool can_modref = false; auto obj = GetUnderlyingObject(op->getPointerOperand(), BB->getModule()->getDataLayout(), 100); if (auto arg = dyn_cast(obj)) { if (volatile_args.find(arg->getArgNo()) != volatile_args.end()) { can_modref = true; } } - - for (unsigned k = 0; k < gutils->originalBlocks.size(); k++) { - if (AA.canBasicBlockModify(*(gutils->originalBlocks[k]), MemoryLocation::get(op))) { - can_modref = true; - break; + for (BasicBlock* BB2 : gutils->originalBlocks) { + for (auto I2 = BB2->begin(), E2 = BB2->end(); I2 != E2; I2++) { + Instruction* inst2 = &*I2; + if (inst == inst2) continue; + if (!gutils->DT.dominates(inst2, inst)) { + if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(op)))) { + can_modref = true; + break; + } + } } } + // NOTE(TFK): I need a testcase where this logic below fails to test correctness of logic above. + //for (unsigned k = 0; k < gutils->originalBlocks.size(); k++) { + // if (AA.canBasicBlockModify(*(gutils->originalBlocks[k]), MemoryLocation::get(op))) { + // can_modref = true; + // break; + // } + //} can_modref_map[inst] = can_modref; } } diff --git a/enzyme/functional_tests_c/insertsort_sum.c b/enzyme/functional_tests_c/insertsort_sum.c index e8ae9249bb8b..0bd3bfd4f4a5 100644 --- a/enzyme/functional_tests_c/insertsort_sum.c +++ b/enzyme/functional_tests_c/insertsort_sum.c @@ -43,6 +43,7 @@ void insertsort_sum (float* array, int N, float* ret) { *ret = sum; } +<<<<<<< HEAD <<<<<<< HEAD ======= @@ -52,6 +53,8 @@ void insertsort_sum (float* array, int N, float* ret) { >>>>>>> add missing files and fix minor bugs +======= +>>>>>>> put in the more strict/correct logic for ordering instructions in single function when checking for modrefs int main(int argc, char** argv) { diff --git a/enzyme/functional_tests_c/readwriteread.c b/enzyme/functional_tests_c/readwriteread.c index 06dfafd54381..355c632190a2 100644 --- a/enzyme/functional_tests_c/readwriteread.c +++ b/enzyme/functional_tests_c/readwriteread.c @@ -4,21 +4,12 @@ #include #define __builtin_autodiff __enzyme_autodiff double __enzyme_autodiff(void*, ...); -int counter = 0; -double recurse_max_helper(float* a, float* b, int N) { - if (N <= 0) { - return *a + *b; - } - return recurse_max_helper(a,b,N-1) + recurse_max_helper(a,b,N-2); -} - double f_read(double* x) { double product = (*x) * (*x); return product; } - void g_write(double* x, double product) { *x = (*x) * product; } @@ -27,7 +18,6 @@ double h_read(double* x) { return *x; } - double readwriteread_helper(double* x) { double product = f_read(x); g_write(x, product); @@ -37,15 +27,9 @@ double readwriteread_helper(double* x) { void readwriteread(double*__restrict x, double*__restrict ret) { *ret = readwriteread_helper(x); - //*ret = (*x) * (*x) * (*x); } - - int main(int argc, char** argv) { - - - double ret = 0; double dret = 1.0; double* x = (double*) malloc(sizeof(double)); @@ -58,10 +42,5 @@ int main(int argc, char** argv) { printf("dx is %f ret is %f\n", *dx, ret); assert(*dx == 3*2.0*2.0); - //assert(db == 17711.0*2); - - - - //printf("hello! %f, res2 %f, da: %f, db: %f\n", ret, ret, da,db); return 0; } From 7528b5d5a18ab95ea27e257dfa6fc899e7f9b7fc Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Fri, 1 Nov 2019 03:02:25 +0000 Subject: [PATCH 12/22] Check whether a loaded value is needed --- only at the top level for now. --- enzyme/Enzyme/EnzymeLogic.cpp | 57 ++++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 18c31286ad21..e203ba1860e3 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -233,6 +233,55 @@ std::map > compute_volatile_args_for_callsites( return volatile_args_map; } +// Determine if a load is needed in the reverse pass. We only use this logic in the top level function right now. +bool is_load_needed_in_reverse(GradientUtils* gutils, AAResults& AA, Instruction* inst) { + + std::vector uses_list; + std::set uses_set; + uses_list.push_back(inst); + uses_set.insert(inst); + + while (true) { + bool new_user_added = false; + for (unsigned i = 0; i < uses_list.size(); i++) { + for (auto use = uses_list[i]->use_begin(); use != uses_list[i]->use_end(); use++) { + Value* v = (*use); + if (uses_set.find(v) == uses_set.end()) { + uses_set.insert(v); + uses_list.push_back(v); + new_user_added = true; + } + } + } + if (!new_user_added) break; + } + + for (unsigned i = 0; i < uses_list.size(); i++) { + if (uses_list[i] == dyn_cast(inst)) continue; + if (auto op = dyn_cast(uses_list[i])) { + if (op->getOpcode() == Instruction::FAdd || op->getOpcode() == Instruction::FSub) { + continue; + } else { + llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; + return true; + } + } + + if (auto op = dyn_cast(uses_list[i])) { + llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; + return true; + } + + if (auto op = dyn_cast(uses_list[i])) { + llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; + return true; + } + return true; + } + return false; +} + + //! return structtype if recursive function std::pair CreateAugmentedPrimal(Function* todiff, AAResults &AA, const std::set& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed, const std::set _volatile_args) { static std::map, std::set, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair> cachedfunctions; @@ -1733,8 +1782,14 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co std::map can_modref_map; // NOTE(TFK): Sanity check this decision. // Is it always possibly to recompute the result of loads at top level? - if (!topLevel) { can_modref_map = compute_volatile_load_map(gutils, AA, _volatile_args); + if (topLevel) { + for (auto iter = can_modref_map.begin(); iter != can_modref_map.end(); iter++) { + if (iter->second) { + bool is_needed = is_load_needed_in_reverse(gutils, AA, iter->first); + can_modref_map[iter->first] = is_needed; + } + } } gutils->can_modref_map = &can_modref_map; From 2ac9d86ffb213e532bf3c9e169533c18ff2cdc4b Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Fri, 1 Nov 2019 04:34:04 +0000 Subject: [PATCH 13/22] bugfix. still unsure if the logic used at topLevel for detecting when we can avoid caching loads is correct though --- enzyme/Enzyme/EnzymeLogic.cpp | 28 +++++++++++++--------- enzyme/functional_tests_c/insertsort_sum.c | 8 +++++-- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index e203ba1860e3..d78f439867e2 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -77,6 +77,7 @@ std::map compute_volatile_load_map(GradientUtils* gutils, AA if (!gutils->DT.dominates(inst2, inst)) { if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(op)))) { can_modref = true; + //llvm::errs() << *inst << " needs to be cached due to: " << *inst2 << "\n"; break; } } @@ -244,8 +245,9 @@ bool is_load_needed_in_reverse(GradientUtils* gutils, AAResults& AA, Instruction while (true) { bool new_user_added = false; for (unsigned i = 0; i < uses_list.size(); i++) { - for (auto use = uses_list[i]->use_begin(); use != uses_list[i]->use_end(); use++) { + for (auto use = uses_list[i]->user_begin(), end = uses_list[i]->user_end(); use != end; ++use) { Value* v = (*use); + //llvm::errs() << "Use list: " << *v << "\n"; if (uses_set.find(v) == uses_set.end()) { uses_set.insert(v); uses_list.push_back(v); @@ -255,27 +257,31 @@ bool is_load_needed_in_reverse(GradientUtils* gutils, AAResults& AA, Instruction } if (!new_user_added) break; } - + //llvm::errs() << "Analysis for load " << *inst << " which has nuses: " << inst->getNumUses() << "\n"; for (unsigned i = 0; i < uses_list.size(); i++) { + //llvm::errs() << "Considering use " << *uses_list[i] << "\n"; if (uses_list[i] == dyn_cast(inst)) continue; + + if (isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || + isa(uses_list[i]) || isa(uses_list[i])){ + continue; + } + if (auto op = dyn_cast(uses_list[i])) { if (op->getOpcode() == Instruction::FAdd || op->getOpcode() == Instruction::FSub) { continue; } else { - llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; + //llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; return true; } } - if (auto op = dyn_cast(uses_list[i])) { - llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; - return true; - } + //if (auto op = dyn_cast(uses_list[i])) { + // //llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; + // return true; + //} - if (auto op = dyn_cast(uses_list[i])) { - llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; - return true; - } + //llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *uses_list[i] << "\n"; return true; } return false; diff --git a/enzyme/functional_tests_c/insertsort_sum.c b/enzyme/functional_tests_c/insertsort_sum.c index 0bd3bfd4f4a5..e68ef37721d0 100644 --- a/enzyme/functional_tests_c/insertsort_sum.c +++ b/enzyme/functional_tests_c/insertsort_sum.c @@ -20,8 +20,12 @@ float* unsorted_array_init(int N) { <<<<<<< HEAD ======= //__attribute__((noinline)) +<<<<<<< HEAD >>>>>>> add missing files and fix minor bugs void insertsort_sum (float* array, int N, float* ret) { +======= +void insertsort_sum (float*__restrict array, int N, float*__restrict ret) { +>>>>>>> bugfix. still unsure if the logic used at topLevel for detecting when we can avoid caching loads is correct though float sum = 0; //qsort(array, N, sizeof(float), cmp); @@ -35,11 +39,11 @@ void insertsort_sum (float* array, int N, float* ret) { } } - for (int i = 0; i < N/2; i++) { - printf("Val: %f\n", array[i]); + //printf("Val: %f\n", array[i]); sum += array[i]; } + *ret = sum; } From 95ef7599831eae07e4807fe37c6e9c19d73bb310 Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Tue, 5 Nov 2019 06:33:11 +0000 Subject: [PATCH 14/22] intermediate commit --- enzyme/Enzyme/ActiveVariable.cpp | 3 +- enzyme/Enzyme/EnzymeLogic.cpp | 84 +++++++++++++++++++++----------- enzyme/Enzyme/FunctionUtils.cpp | 11 +++-- enzyme/Enzyme/GradientUtils.cpp | 1 + 4 files changed, 67 insertions(+), 32 deletions(-) diff --git a/enzyme/Enzyme/ActiveVariable.cpp b/enzyme/Enzyme/ActiveVariable.cpp index 1bafa198c96f..dae69a858cdc 100644 --- a/enzyme/Enzyme/ActiveVariable.cpp +++ b/enzyme/Enzyme/ActiveVariable.cpp @@ -152,7 +152,8 @@ bool isIntASecretFloat(Value* val) { if (!pointerUse && floatingUse) return true; llvm::errs() << *inst->getParent()->getParent() << "\n"; llvm::errs() << " val:" << *val << " pointer:" << pointerUse << " floating:" << floatingUse << "\n"; - assert(0 && "ambiguous unsure if constant or not"); + //assert(0 && "ambiguous unsure if constant or not"); + return false; // NOTE(TFK): Return false instead of asset. } llvm::errs() << *val << "\n"; diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index d78f439867e2..94f2e84fc98c 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -77,7 +77,7 @@ std::map compute_volatile_load_map(GradientUtils* gutils, AA if (!gutils->DT.dominates(inst2, inst)) { if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(op)))) { can_modref = true; - //llvm::errs() << *inst << " needs to be cached due to: " << *inst2 << "\n"; + llvm::errs() << *inst << " needs to be cached due to: " << *inst2 << "\n"; break; } } @@ -132,20 +132,21 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ for(BasicBlock* BB: gutils->originalBlocks) { for (auto I = BB->begin(), E = BB->end(); I != E; I++) { Instruction* inst = &*I; - if (inst == callsite_inst) continue; + //if (inst == callsite_inst) continue; // If the "inst" does not dominate "callsite_inst" then we cannot prove that // "inst" happens before "callsite_inst". If "inst" modifies an argument of the call, // then that call needs to consider the argument volatile. if (!gutils->DT.dominates(inst, callsite_inst)) { + llvm::errs() << "Instruction " << *inst << " DOES NOT dominates " << *callsite_inst << "\n"; // Consider Store Instructions. if (auto op = dyn_cast(inst)) { for (unsigned i = 0; i < args.size(); i++) { // If the modification flag is set, then this instruction may modify the $i$th argument of the call. if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) { - //llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; + llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; } else { - //llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; + llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; args_safe[i] = false; } } @@ -153,6 +154,7 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ // Consider Call Instructions. if (auto op = dyn_cast(inst)) { + llvm::errs() << "OP is call inst: " << *op << "\n"; // Ignore memory allocation functions. Function* called = op->getCalledFunction(); if (auto castinst = dyn_cast(op->getCalledValue())) { @@ -165,6 +167,7 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ } } if (isCertainMallocOrFree(called)) { + llvm::errs() << "OP is certain malloc or free: " << *op << "\n"; continue; } @@ -172,23 +175,25 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ for (unsigned i = 0; i < args.size(); i++) { if (!args[i]->getType()->isPointerTy()) continue; // Ignore non-pointer arguments. if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) { - //llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; + llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; } else { - //llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; + llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; args_safe[i] = false; } } } + } else { + llvm::errs() << "Instruction " << *inst << " DOES dominates " << *callsite_inst << "\n"; } } } - //llvm::errs() << "CallInst: " << *callsite_op<< "CALL ARGUMENT INFO: \n"; + llvm::errs() << "CallInst: " << *callsite_op<< "CALL ARGUMENT INFO: \n"; for (unsigned i = 0; i < args.size(); i++) { if (!args_safe[i]) { volatile_args.insert(i); } - //llvm::errs() << "Arg: " << *args[i] << " STATUS: " << args_safe[i] << "\n"; + llvm::errs() << "Arg: " << *args[i] << " STATUS: " << args_safe[i] << "\n"; } return volatile_args; } @@ -289,10 +294,11 @@ bool is_load_needed_in_reverse(GradientUtils* gutils, AAResults& AA, Instruction //! return structtype if recursive function -std::pair CreateAugmentedPrimal(Function* todiff, AAResults &AA, const std::set& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed, const std::set _volatile_args) { +std::pair CreateAugmentedPrimal(Function* todiff, AAResults &_AA, const std::set& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed, const std::set _volatile_args) { static std::map, std::set, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair> cachedfunctions; static std::map, std::set, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished; auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), std::set(_volatile_args.begin(), _volatile_args.end()), differentialReturn, returnUsed); + llvm::errs() << "Create Augmented Primal " << todiff->getName() << "\n"; if (cachedfunctions.find(tup) != cachedfunctions.end()) { return cachedfunctions[tup]; } @@ -351,7 +357,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul llvm::errs() << *todiff << "\n"; } assert(!todiff->empty()); - + AAResults AA(TLI); GradientUtils *gutils = GradientUtils::CreateFromClone(todiff, AA, TLI, constant_args, /*returnValue*/returnUsed ? ReturnType::TapeAndReturns : ReturnType::Tape, /*differentialReturn*/differentialReturn); cachedfunctions[tup] = std::pair(gutils->newFunc, nullptr); cachedfinished[tup] = false; @@ -362,6 +368,14 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul std::map can_modref_map = compute_volatile_load_map(gutils, AA, _volatile_args); gutils->can_modref_map = &can_modref_map; + for (auto iter = can_modref_map.begin(); iter != can_modref_map.end(); iter++) { + if (iter->second) { + bool is_needed = is_load_needed_in_reverse(gutils, AA, iter->first); + //can_modref_map[iter->first] = is_needed; + } + } + + gutils->forceContexts(); gutils->forceAugmentedReturns(); @@ -613,7 +627,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul } } - auto newcalled = CreateAugmentedPrimal(dyn_cast(called), AA, subconstant_args, TLI, /*differentialReturn*/subdifferentialreturn, /*return is used*/subretused, volatile_args_map[op]).first; + auto newcalled = CreateAugmentedPrimal(dyn_cast(called), _AA, subconstant_args, TLI, /*differentialReturn*/subdifferentialreturn, /*return is used*/subretused, volatile_args_map[op]).first; auto augmentcall = BuilderZ.CreateCall(newcalled, args); assert(augmentcall->getType()->isStructTy()); augmentcall->setCallingConv(op->getCallingConv()); @@ -644,20 +658,20 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul gutils->addMalloc(BuilderZ, rv); } - if ((op->getType()->isPointerTy() || op->getType()->isIntegerTy()) && subdifferentialreturn) { + if ((op->getType()->isPointerTy() || op->getType()->isIntegerTy()) && gutils->invertedPointers.count(op) != 0) { auto placeholder = cast(gutils->invertedPointers[op]); if (I != E && placeholder == &*I) I++; gutils->invertedPointers.erase(op); - - assert(cast(augmentcall->getType())->getNumElements() == 3); - auto antiptr = cast(BuilderZ.CreateExtractValue(augmentcall, {2}, "antiptr_" + op->getName() )); - gutils->invertedPointers[rv] = antiptr; - placeholder->replaceAllUsesWith(antiptr); - - if (shouldCache) { - gutils->addMalloc(BuilderZ, antiptr); + if (subdifferentialreturn) { + assert(cast(augmentcall->getType())->getNumElements() == 3); + auto antiptr = cast(BuilderZ.CreateExtractValue(augmentcall, {2}, "antiptr_" + op->getName() )); + gutils->invertedPointers[rv] = antiptr; + placeholder->replaceAllUsesWith(antiptr); + + if (shouldCache) { + gutils->addMalloc(BuilderZ, antiptr); + } } - gutils->erase(placeholder); } else { if (cast(augmentcall->getType())->getNumElements() != 2) { @@ -671,6 +685,14 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul } gutils->replaceAWithB(op,rv); + } else { + if ((op->getType()->isPointerTy() || op->getType()->isIntegerTy()) && gutils->invertedPointers.count(op) != 0) { + auto placeholder = cast(gutils->invertedPointers[op]); + if (I != E && placeholder == &*I) I++; + gutils->invertedPointers.erase(op); + gutils->erase(placeholder); + } + } gutils->erase(op); @@ -1150,7 +1172,8 @@ std::pair,SmallVector> getDefaultFunctionTypeForGr return std::pair,SmallVector>(args, outs); } -void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::reverse_iterator &E, IRBuilder <>& Builder2, CallInst* op, DiffeGradientUtils* const gutils, TargetLibraryInfo &TLI, AAResults &AA, const bool topLevel, const std::map &replacedReturns, std::set volatile_args) { +void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::reverse_iterator &E, IRBuilder <>& Builder2, CallInst* op, DiffeGradientUtils* const gutils, TargetLibraryInfo &TLI, AAResults &AA, AAResults & _AA, const bool topLevel, const std::map &replacedReturns, std::set volatile_args) { + llvm::errs() << "HandleGradientCall " << *op << "\n"; Function *called = op->getCalledFunction(); if (auto castinst = dyn_cast(op->getCalledValue())) { @@ -1360,7 +1383,9 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r ModRefInfo mri = ModRefInfo::NoModRef; if (iter->mayReadOrWriteMemory()) { - mri = AA.getModRefInfo(&*iter, origop); + llvm::errs() << "Iter is at " << *iter << "\n"; + llvm::errs() << "origop is at " << *origop << "\n"; + mri = _AA.getModRefInfo(&*iter, origop); } if (mri == ModRefInfo::NoModRef && !usesInst) { @@ -1491,7 +1516,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r if (modifyPrimal && called) { bool subretused = op->getNumUses() != 0; bool subdifferentialreturn = (!gutils->isConstantValue(op)) && subretused; - auto fnandtapetype = CreateAugmentedPrimal(cast(called), AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/subretused, volatile_args); + auto fnandtapetype = CreateAugmentedPrimal(cast(called), _AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/subretused, volatile_args); if (topLevel) { Function* newcalled = fnandtapetype.first; augmentcall = BuilderZ.CreateCall(newcalled, pre_args); @@ -1563,7 +1588,8 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r bool subdiffereturn = (!gutils->isConstantValue(op)) && !( op->getType()->isPointerTy() || op->getType()->isIntegerTy() || op->getType()->isEmptyTy() ); llvm::errs() << "subdifferet:" << subdiffereturn << " " << *op << "\n"; if (called) { - newcalled = CreatePrimalAndGradient(cast(called), subconstant_args, TLI, AA, /*returnValue*/retUsed, /*subdiffereturn*/subdiffereturn, /*topLevel*/replaceFunction, tape ? tape->getType() : nullptr, volatile_args);//, LI, DT); + llvm::errs() << "Before create primal and gradient instruction is " << *op << "\n"; + newcalled = CreatePrimalAndGradient(cast(called), subconstant_args, TLI, _AA, /*returnValue*/retUsed, /*subdiffereturn*/subdiffereturn, /*topLevel*/replaceFunction, tape ? tape->getType() : nullptr, volatile_args);//, LI, DT); } else { newcalled = gutils->invertPointerM(op->getCalledValue(), Builder2); auto ft = cast(cast(op->getCalledValue()->getType())->getElementType()); @@ -1673,7 +1699,8 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r } } -Function* CreatePrimalAndGradient(Function* todiff, const std::set& constant_args, TargetLibraryInfo &TLI, AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg, std::set _volatile_args) { +Function* CreatePrimalAndGradient(Function* todiff, const std::set& constant_args, TargetLibraryInfo &TLI, AAResults &_AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg, std::set _volatile_args) { + llvm::errs() << "Create Primal And Gradient " << todiff->getName() << "\n"; if (differentialReturn) { if(!todiff->getReturnType()->isFPOrFPVectorTy()) { llvm::errs() << *todiff << "\n"; @@ -1778,7 +1805,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co auto M = todiff->getParent(); auto& Context = M->getContext(); - + AAResults AA(TLI); DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(todiff, AA, TLI, constant_args, returnValue ? ReturnType::ArgsWithReturn : ReturnType::Args, differentialReturn, additionalArg); cachedfunctions[tup] = gutils->newFunc; @@ -1966,6 +1993,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co break; } default: + continue; // NOTE(TFK) added this. assert(op); llvm::errs() << *gutils->newFunc << "\n"; llvm::errs() << "cannot handle unknown binary operator: " << *op << "\n"; @@ -2202,7 +2230,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co if (dif0) addToDiffe(op->getOperand(0), dif0); if (dif1) addToDiffe(op->getOperand(1), dif1); } else if(auto op = dyn_cast_or_null(inst)) { - handleGradientCallInst(I, E, Builder2, op, gutils, TLI, AA, topLevel, replacedReturns, volatile_args_map[op]); + handleGradientCallInst(I, E, Builder2, op, gutils, TLI, _AA, _AA, topLevel, replacedReturns, volatile_args_map[op]); } else if(auto op = dyn_cast_or_null(inst)) { if (gutils->isConstantValue(inst)) continue; if (op->getType()->isPointerTy()) continue; diff --git a/enzyme/Enzyme/FunctionUtils.cpp b/enzyme/Enzyme/FunctionUtils.cpp index c8dd75d0b4ee..cf233f0f3247 100644 --- a/enzyme/Enzyme/FunctionUtils.cpp +++ b/enzyme/Enzyme/FunctionUtils.cpp @@ -164,8 +164,13 @@ PHINode* canonicalizeIVs(fake::SCEVExpander &e, Type *Ty, Loop *L, DominatorTree Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI) { static std::map cache; - if (cache.find(F) != cache.end()) return cache[F]; - + static std::map cache_AA; + llvm::errs() << "Before cache lookup for " << F->getName() << "\n"; + if (cache.find(F) != cache.end()) { + AA.addAAResult(*(cache_AA[F])); + return cache[F]; + } + llvm::errs() << "Did not do cache lookup for " << F->getName() << "\n"; Function *NewF = Function::Create(F->getFunctionType(), F->getLinkage(), "preprocess_" + F->getName(), F->getParent()); ValueToValueMapTy VMap; @@ -469,8 +474,8 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI) &AM.getResult(*NewF), AM.getCachedResult(*NewF), AM.getCachedResult(*NewF)); + cache_AA[F] = baa; AA.addAAResult(*baa); - //ScopedNoAliasAA sa; //auto saa = new ScopedNoAliasAAResult(sa.run(*NewF, AM)); //AA.addAAResult(*saa); diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 0b8bb9a570f1..23ff5015022c 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -306,6 +306,7 @@ DiffeGradientUtils* DiffeGradientUtils::CreateFromClone(Function *todiff, AAResu } Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) { + if (isa(val)) return val; if (isa(val)) { return val; } else if (isa(val)) { From a166347669fac1d1a8211030dd02a9c91dd70b9a Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Tue, 5 Nov 2019 18:40:28 +0000 Subject: [PATCH 15/22] separate global and local aa results for speed on large programs; add logic for more conservative detection of uncacheable loads --- enzyme/Enzyme/EnzymeLogic.cpp | 116 ++++++++++++++---- enzyme/Enzyme/GradientUtils.cpp | 2 +- .../functional_tests_c/insertsort_sum_alt.c | 4 +- 3 files changed, 92 insertions(+), 30 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 94f2e84fc98c..9259e647e3ee 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -51,25 +51,69 @@ cl::opt cachereads( "enzyme_cachereads", cl::init(true), cl::Hidden, cl::desc("Force caching of all reads")); -std::map compute_volatile_load_map(GradientUtils* gutils, AAResults& AA, + + + + + +// Computes a map of LoadInst -> boolean for a function indicating whether that load is "volatile". +// A load is considered "volatile" if the data at the loaded memory location can be modified after +// the load instruction. +std::map compute_volatile_load_map(GradientUtils* gutils, AAResults& AA, TargetLibraryInfo& TLI, std::set volatile_args) { std::map can_modref_map; - // NOTE(TFK): Want to construct a test case where this causes an issue. for(BasicBlock* BB: gutils->originalBlocks) { for (auto I = BB->begin(), E = BB->end(); I != E; I++) { Instruction* inst = &*I; + // For each load instruction, determine if it is volatile. if (auto op = dyn_cast(inst)) { + // NOTE(TFK): The reasoning behind skipping ConstantValues and ConstantInstructions needs to be fleshed out. if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) { continue; } bool can_modref = false; + // Find the underlying object for the pointer operand of the load instruction. auto obj = GetUnderlyingObject(op->getPointerOperand(), BB->getModule()->getDataLayout(), 100); + // If the pointer operand is from an argument to the function, we need to check if the argument + // received from the caller is volatile. if (auto arg = dyn_cast(obj)) { if (volatile_args.find(arg->getArgNo()) != volatile_args.end()) { can_modref = true; } + } else { + // NOTE(TFK): In the case where the underlying object for the pointer operand is from a Load or Call we need + // to check if we need to cache. Likely, we need to play it safe in this case and cache. + // NOTE(TFK): The logic below is an attempt at a conservative handling of the case mentioned above, but it + // needs to be verified. + + // Pointer operands originating from call instructions that are not malloc/free are conservatively considered volatile. + if (auto obj_op = dyn_cast(obj)) { + Function* called = obj_op->getCalledFunction(); + if (auto castinst = dyn_cast(obj_op->getCalledValue())) { + if (castinst->isCast()) { + if (auto fn = dyn_cast(castinst->getOperand(0))) { + if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) { + called = fn; + } + } + } + } + if (isCertainMallocOrFree(called)) { + llvm::errs() << "OP is certain malloc or free: " << *op << "\n"; + } else { + llvm::errs() << "OP is a non malloc/free call so we need to cache " << *op << "\n"; + can_modref = true; + } + } else if (auto obj_op = dyn_cast(obj)) { + // If obj is from a load instruction conservatively consider it volatile. + can_modref = true; + } else { + // In absence of more information, assume that the underlying object for pointer operand is volatile in caller. + can_modref = true; + } } + for (BasicBlock* BB2 : gutils->originalBlocks) { for (auto I2 = BB2->begin(), E2 = BB2->end(); I2 != E2; I2++) { Instruction* inst2 = &*I2; @@ -77,19 +121,12 @@ std::map compute_volatile_load_map(GradientUtils* gutils, AA if (!gutils->DT.dominates(inst2, inst)) { if (llvm::isModSet(AA.getModRefInfo(inst2, MemoryLocation::get(op)))) { can_modref = true; - llvm::errs() << *inst << " needs to be cached due to: " << *inst2 << "\n"; + //llvm::errs() << *inst << " needs to be cached due to: " << *inst2 << "\n"; break; } } } } - // NOTE(TFK): I need a testcase where this logic below fails to test correctness of logic above. - //for (unsigned k = 0; k < gutils->originalBlocks.size(); k++) { - // if (AA.canBasicBlockModify(*(gutils->originalBlocks[k]), MemoryLocation::get(op))) { - // can_modref = true; - // break; - // } - //} can_modref_map[inst] = can_modref; } } @@ -97,7 +134,6 @@ std::map compute_volatile_load_map(GradientUtils* gutils, AA return can_modref_map; } - std::set compute_volatile_args_for_one_callsite(Instruction* callsite_inst, DominatorTree &DT, TargetLibraryInfo &TLI, AAResults& AA, GradientUtils* gutils, std::set parent_volatile_args) { CallInst* callsite_op = dyn_cast(callsite_inst); @@ -122,6 +158,32 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ if (parent_volatile_args.find(arg->getArgNo()) != parent_volatile_args.end()) { init_safe = false; } + } else { + // Pointer operands originating from call instructions that are not malloc/free are conservatively considered volatile. + if (auto obj_op = dyn_cast(obj)) { + Function* called = obj_op->getCalledFunction(); + if (auto castinst = dyn_cast(obj_op->getCalledValue())) { + if (castinst->isCast()) { + if (auto fn = dyn_cast(castinst->getOperand(0))) { + if (isAllocationFunction(*fn, TLI) || isDeallocationFunction(*fn, TLI)) { + called = fn; + } + } + } + } + if (isCertainMallocOrFree(called)) { + //llvm::errs() << "OP is certain malloc or free: " << *op << "\n"; + } else { + //llvm::errs() << "OP is a non malloc/free call so we need to cache " << *op << "\n"; + init_safe = false; + } + } else if (auto obj_op = dyn_cast(obj)) { + // If obj is from a load instruction conservatively consider it volatile. + init_safe = false; + } else { + // In absence of more information, assume that the underlying object for pointer operand is volatile in caller. + init_safe = false; + } } // TODO(TFK): Also need to check whether underlying object is traced to load / non-allocating-call instruction. args_safe.push_back(init_safe); @@ -262,13 +324,13 @@ bool is_load_needed_in_reverse(GradientUtils* gutils, AAResults& AA, Instruction } if (!new_user_added) break; } - //llvm::errs() << "Analysis for load " << *inst << " which has nuses: " << inst->getNumUses() << "\n"; + llvm::errs() << "Analysis for load " << *inst << " which has nuses: " << inst->getNumUses() << "\n"; for (unsigned i = 0; i < uses_list.size(); i++) { - //llvm::errs() << "Considering use " << *uses_list[i] << "\n"; + llvm::errs() << "Considering use " << *uses_list[i] << "\n"; if (uses_list[i] == dyn_cast(inst)) continue; - if (isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || - isa(uses_list[i]) || isa(uses_list[i])){ + if (isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || + isa(uses_list[i]) /*|| isa(uses_list[i])*/){ continue; } @@ -276,18 +338,18 @@ bool is_load_needed_in_reverse(GradientUtils* gutils, AAResults& AA, Instruction if (op->getOpcode() == Instruction::FAdd || op->getOpcode() == Instruction::FSub) { continue; } else { - //llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; + llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; return true; } } //if (auto op = dyn_cast(uses_list[i])) { - // //llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; + // llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; // return true; //} //llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *uses_list[i] << "\n"; - return true; + //return true; } return false; } @@ -365,15 +427,15 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul std::map > volatile_args_map = compute_volatile_args_for_callsites(gutils->oldFunc, gutils->DT, TLI, AA, gutils, _volatile_args); - std::map can_modref_map = compute_volatile_load_map(gutils, AA, _volatile_args); + std::map can_modref_map = compute_volatile_load_map(gutils, AA, TLI, _volatile_args); gutils->can_modref_map = &can_modref_map; - for (auto iter = can_modref_map.begin(); iter != can_modref_map.end(); iter++) { - if (iter->second) { - bool is_needed = is_load_needed_in_reverse(gutils, AA, iter->first); - //can_modref_map[iter->first] = is_needed; - } - } + //for (auto iter = can_modref_map.begin(); iter != can_modref_map.end(); iter++) { + // if (iter->second) { + // bool is_needed = is_load_needed_in_reverse(gutils, AA, iter->first); + // can_modref_map[iter->first] = is_needed; + // } + //} gutils->forceContexts(); @@ -1815,7 +1877,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co std::map can_modref_map; // NOTE(TFK): Sanity check this decision. // Is it always possibly to recompute the result of loads at top level? - can_modref_map = compute_volatile_load_map(gutils, AA, _volatile_args); + can_modref_map = compute_volatile_load_map(gutils, AA, TLI, _volatile_args); if (topLevel) { for (auto iter = can_modref_map.begin(); iter != can_modref_map.end(); iter++) { if (iter->second) { @@ -1993,7 +2055,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co break; } default: - continue; // NOTE(TFK) added this. + //continue; // NOTE(TFK) added this. assert(op); llvm::errs() << *gutils->newFunc << "\n"; llvm::errs() << "cannot handle unknown binary operator: " << *op << "\n"; diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 23ff5015022c..bfc74aceea34 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -306,7 +306,7 @@ DiffeGradientUtils* DiffeGradientUtils::CreateFromClone(Function *todiff, AAResu } Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) { - if (isa(val)) return val; + //if (isa(val)) return val; if (isa(val)) { return val; } else if (isa(val)) { diff --git a/enzyme/functional_tests_c/insertsort_sum_alt.c b/enzyme/functional_tests_c/insertsort_sum_alt.c index 10cee35434ba..944804b6b271 100644 --- a/enzyme/functional_tests_c/insertsort_sum_alt.c +++ b/enzyme/functional_tests_c/insertsort_sum_alt.c @@ -35,7 +35,7 @@ void insertion_sort_inner(float* array, int i) { } // sums the first half of a sorted array. -void insertsort_sum (float* array, int N, float* ret) { +void insertsort_sum (float*__restrict array, int N, float*__restrict ret) { float sum = 0; //qsort(array, N, sizeof(float), cmp); @@ -45,7 +45,7 @@ void insertsort_sum (float* array, int N, float* ret) { for (int i = 0; i < N/2; i++) { - printf("Val: %f\n", array[i]); + //printf("Val: %f\n", array[i]); sum += array[i]; } *ret = sum; From 9484ab9f70ee8fbfcecc3900a6ad56676081505f Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Tue, 5 Nov 2019 20:54:56 +0000 Subject: [PATCH 16/22] cleanup --- enzyme/Enzyme/EnzymeLogic.cpp | 190 ++++++++++++++++------------- enzyme/Enzyme/GradientUtils.cpp | 4 +- enzyme/functional_tests_c/Makefile | 2 +- 3 files changed, 106 insertions(+), 90 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 9259e647e3ee..cd3a08324c69 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -47,25 +47,26 @@ using namespace llvm; llvm::cl::opt enzyme_print("enzyme_print", cl::init(false), cl::Hidden, cl::desc("Print before and after fns for autodiff")); -cl::opt cachereads( - "enzyme_cachereads", cl::init(true), cl::Hidden, +cl::opt cache_reads_always( + "enzyme_always_cache_reads", cl::init(false), cl::Hidden, cl::desc("Force caching of all reads")); +cl::opt cache_reads_never( + "enzyme_never_cache_reads", cl::init(false), cl::Hidden, + cl::desc("Force caching of all reads")); - - -// Computes a map of LoadInst -> boolean for a function indicating whether that load is "volatile". -// A load is considered "volatile" if the data at the loaded memory location can be modified after +// Computes a map of LoadInst -> boolean for a function indicating whether that load is "uncacheable". +// A load is considered "uncacheable" if the data at the loaded memory location can be modified after // the load instruction. -std::map compute_volatile_load_map(GradientUtils* gutils, AAResults& AA, TargetLibraryInfo& TLI, - std::set volatile_args) { +std::map compute_uncacheable_load_map(GradientUtils* gutils, AAResults& AA, TargetLibraryInfo& TLI, + const std::set uncacheable_args) { std::map can_modref_map; for(BasicBlock* BB: gutils->originalBlocks) { for (auto I = BB->begin(), E = BB->end(); I != E; I++) { Instruction* inst = &*I; - // For each load instruction, determine if it is volatile. + // For each load instruction, determine if it is uncacheable. if (auto op = dyn_cast(inst)) { // NOTE(TFK): The reasoning behind skipping ConstantValues and ConstantInstructions needs to be fleshed out. if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) { @@ -76,9 +77,9 @@ std::map compute_volatile_load_map(GradientUtils* gutils, AA // Find the underlying object for the pointer operand of the load instruction. auto obj = GetUnderlyingObject(op->getPointerOperand(), BB->getModule()->getDataLayout(), 100); // If the pointer operand is from an argument to the function, we need to check if the argument - // received from the caller is volatile. + // received from the caller is uncacheable. if (auto arg = dyn_cast(obj)) { - if (volatile_args.find(arg->getArgNo()) != volatile_args.end()) { + if (uncacheable_args.find(arg->getArgNo()) != uncacheable_args.end()) { can_modref = true; } } else { @@ -87,7 +88,7 @@ std::map compute_volatile_load_map(GradientUtils* gutils, AA // NOTE(TFK): The logic below is an attempt at a conservative handling of the case mentioned above, but it // needs to be verified. - // Pointer operands originating from call instructions that are not malloc/free are conservatively considered volatile. + // Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable. if (auto obj_op = dyn_cast(obj)) { Function* called = obj_op->getCalledFunction(); if (auto castinst = dyn_cast(obj_op->getCalledValue())) { @@ -100,16 +101,16 @@ std::map compute_volatile_load_map(GradientUtils* gutils, AA } } if (isCertainMallocOrFree(called)) { - llvm::errs() << "OP is certain malloc or free: " << *op << "\n"; + //llvm::errs() << "OP is certain malloc or free: " << *op << "\n"; } else { - llvm::errs() << "OP is a non malloc/free call so we need to cache " << *op << "\n"; + //llvm::errs() << "OP is a non malloc/free call so we need to cache " << *op << "\n"; can_modref = true; } - } else if (auto obj_op = dyn_cast(obj)) { - // If obj is from a load instruction conservatively consider it volatile. + } else if (isa(obj)) { + // If obj is from a load instruction conservatively consider it uncacheable. can_modref = true; } else { - // In absence of more information, assume that the underlying object for pointer operand is volatile in caller. + // In absence of more information, assume that the underlying object for pointer operand is uncacheable in caller. can_modref = true; } } @@ -134,16 +135,16 @@ std::map compute_volatile_load_map(GradientUtils* gutils, AA return can_modref_map; } -std::set compute_volatile_args_for_one_callsite(Instruction* callsite_inst, DominatorTree &DT, - TargetLibraryInfo &TLI, AAResults& AA, GradientUtils* gutils, std::set parent_volatile_args) { +std::set compute_uncacheable_args_for_one_callsite(Instruction* callsite_inst, DominatorTree &DT, + TargetLibraryInfo &TLI, AAResults& AA, GradientUtils* gutils, const std::set parent_uncacheable_args) { CallInst* callsite_op = dyn_cast(callsite_inst); assert(callsite_op != nullptr); - std::set volatile_args; + std::set uncacheable_args; std::vector args; std::vector args_safe; - // First, we need to propagate the volatile status from the parent function to the callee. + // First, we need to propagate the uncacheable status from the parent function to the callee. // because memory location x modified after parent returns => x modified after callee returns. for (unsigned i = 0; i < callsite_op->getNumArgOperands(); i++) { args.push_back(callsite_op->getArgOperand(i)); @@ -155,11 +156,11 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ 100); // If underlying object is an Argument, check parent volatility status. if (auto arg = dyn_cast(obj)) { - if (parent_volatile_args.find(arg->getArgNo()) != parent_volatile_args.end()) { + if (parent_uncacheable_args.find(arg->getArgNo()) != parent_uncacheable_args.end()) { init_safe = false; } } else { - // Pointer operands originating from call instructions that are not malloc/free are conservatively considered volatile. + // Pointer operands originating from call instructions that are not malloc/free are conservatively considered uncacheable. if (auto obj_op = dyn_cast(obj)) { Function* called = obj_op->getCalledFunction(); if (auto castinst = dyn_cast(obj_op->getCalledValue())) { @@ -177,11 +178,11 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ //llvm::errs() << "OP is a non malloc/free call so we need to cache " << *op << "\n"; init_safe = false; } - } else if (auto obj_op = dyn_cast(obj)) { - // If obj is from a load instruction conservatively consider it volatile. + } else if (isa(obj)) { + // If obj is from a load instruction conservatively consider it uncacheable. init_safe = false; } else { - // In absence of more information, assume that the underlying object for pointer operand is volatile in caller. + // In absence of more information, assume that the underlying object for pointer operand is uncacheable in caller. init_safe = false; } } @@ -198,17 +199,17 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ // If the "inst" does not dominate "callsite_inst" then we cannot prove that // "inst" happens before "callsite_inst". If "inst" modifies an argument of the call, - // then that call needs to consider the argument volatile. + // then that call needs to consider the argument uncacheable. if (!gutils->DT.dominates(inst, callsite_inst)) { - llvm::errs() << "Instruction " << *inst << " DOES NOT dominates " << *callsite_inst << "\n"; + //llvm::errs() << "Instruction " << *inst << " DOES NOT dominates " << *callsite_inst << "\n"; // Consider Store Instructions. if (auto op = dyn_cast(inst)) { for (unsigned i = 0; i < args.size(); i++) { // If the modification flag is set, then this instruction may modify the $i$th argument of the call. if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) { - llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; + //llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; } else { - llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; + //llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; args_safe[i] = false; } } @@ -216,7 +217,7 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ // Consider Call Instructions. if (auto op = dyn_cast(inst)) { - llvm::errs() << "OP is call inst: " << *op << "\n"; + //llvm::errs() << "OP is call inst: " << *op << "\n"; // Ignore memory allocation functions. Function* called = op->getCalledFunction(); if (auto castinst = dyn_cast(op->getCalledValue())) { @@ -229,7 +230,7 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ } } if (isCertainMallocOrFree(called)) { - llvm::errs() << "OP is certain malloc or free: " << *op << "\n"; + //llvm::errs() << "OP is certain malloc or free: " << *op << "\n"; continue; } @@ -237,47 +238,47 @@ std::set compute_volatile_args_for_one_callsite(Instruction* callsite_ for (unsigned i = 0; i < args.size(); i++) { if (!args[i]->getType()->isPointerTy()) continue; // Ignore non-pointer arguments. if (!llvm::isModSet(AA.getModRefInfo(op, MemoryLocation::getForArgument(callsite_op, i, TLI)))) { - llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; + //llvm::errs() << "Instruction " << *op << " is NoModRef with call argument " << *args[i] << "\n"; } else { - llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; + //llvm::errs() << "Instruction " << *op << " is maybe ModRef with call argument " << *args[i] << "\n"; args_safe[i] = false; } } } } else { - llvm::errs() << "Instruction " << *inst << " DOES dominates " << *callsite_inst << "\n"; + //llvm::errs() << "Instruction " << *inst << " DOES dominates " << *callsite_inst << "\n"; } } } - llvm::errs() << "CallInst: " << *callsite_op<< "CALL ARGUMENT INFO: \n"; + //llvm::errs() << "CallInst: " << *callsite_op<< "CALL ARGUMENT INFO: \n"; for (unsigned i = 0; i < args.size(); i++) { if (!args_safe[i]) { - volatile_args.insert(i); + uncacheable_args.insert(i); } - llvm::errs() << "Arg: " << *args[i] << " STATUS: " << args_safe[i] << "\n"; + //llvm::errs() << "Arg: " << *args[i] << " STATUS: " << args_safe[i] << "\n"; } - return volatile_args; + return uncacheable_args; } -// Given a function and the arguments passed to it by its caller that are volatile (_volatile_args) compute -// the set of volatile arguments for each callsite inside the function. A pointer argument is volatile at +// Given a function and the arguments passed to it by its caller that are uncacheable (_uncacheable_args) compute +// the set of uncacheable arguments for each callsite inside the function. A pointer argument is uncacheable at // a callsite if the memory pointed to might be modified after that callsite. -std::map > compute_volatile_args_for_callsites( +std::map > compute_uncacheable_args_for_callsites( Function* F, DominatorTree &DT, TargetLibraryInfo &TLI, AAResults& AA, GradientUtils* gutils, - std::set const volatile_args) { - std::map > volatile_args_map; + const std::set uncacheable_args) { + std::map > uncacheable_args_map; for(BasicBlock* BB: gutils->originalBlocks) { for (auto I = BB->begin(), E = BB->end(); I != E; I++) { Instruction* inst = &*I; if (auto op = dyn_cast(inst)) { - // We do not need volatile args for intrinsic functions. So skip such callsites. + // We do not need uncacheable args for intrinsic functions. So skip such callsites. if(isa(inst)) { continue; } - // We do not need volatile args for memory allocation functions. So skip such callsites. + // We do not need uncacheable args for memory allocation functions. So skip such callsites. Function* called = op->getCalledFunction(); if (auto castinst = dyn_cast(op->getCalledValue())) { if (castinst->isCast()) { @@ -292,13 +293,13 @@ std::map > compute_volatile_args_for_callsites( continue; } - // For all other calls, we compute the volatile args for this callsite. - volatile_args_map[op] = compute_volatile_args_for_one_callsite(inst, - DT, TLI, AA, gutils, volatile_args); + // For all other calls, we compute the uncacheable args for this callsite. + uncacheable_args_map[op] = compute_uncacheable_args_for_one_callsite(inst, + DT, TLI, AA, gutils, uncacheable_args); } } } - return volatile_args_map; + return uncacheable_args_map; } // Determine if a load is needed in the reverse pass. We only use this logic in the top level function right now. @@ -324,9 +325,9 @@ bool is_load_needed_in_reverse(GradientUtils* gutils, AAResults& AA, Instruction } if (!new_user_added) break; } - llvm::errs() << "Analysis for load " << *inst << " which has nuses: " << inst->getNumUses() << "\n"; + //llvm::errs() << "Analysis for load " << *inst << " which has nuses: " << inst->getNumUses() << "\n"; for (unsigned i = 0; i < uses_list.size(); i++) { - llvm::errs() << "Considering use " << *uses_list[i] << "\n"; + //llvm::errs() << "Considering use " << *uses_list[i] << "\n"; if (uses_list[i] == dyn_cast(inst)) continue; if (isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || isa(uses_list[i]) || @@ -338,7 +339,7 @@ bool is_load_needed_in_reverse(GradientUtils* gutils, AAResults& AA, Instruction if (op->getOpcode() == Instruction::FAdd || op->getOpcode() == Instruction::FSub) { continue; } else { - llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; + //llvm::errs() << "Need value of " << *inst << "\n" << "\t Due to " << *op << "\n"; return true; } } @@ -356,11 +357,10 @@ bool is_load_needed_in_reverse(GradientUtils* gutils, AAResults& AA, Instruction //! return structtype if recursive function -std::pair CreateAugmentedPrimal(Function* todiff, AAResults &_AA, const std::set& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed, const std::set _volatile_args) { - static std::map, std::set, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair> cachedfunctions; - static std::map, std::set, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished; - auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), std::set(_volatile_args.begin(), _volatile_args.end()), differentialReturn, returnUsed); - llvm::errs() << "Create Augmented Primal " << todiff->getName() << "\n"; +std::pair CreateAugmentedPrimal(Function* todiff, AAResults &_AA, const std::set& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed, const std::set _uncacheable_args) { + static std::map/*constant_args*/, std::set/*uncacheable_args*/, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair> cachedfunctions; + static std::map/*constant_args*/, std::set/*uncacheable_args*/, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished; + auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), std::set(_uncacheable_args.begin(), _uncacheable_args.end()), differentialReturn, returnUsed); if (cachedfunctions.find(tup) != cachedfunctions.end()) { return cachedfunctions[tup]; } @@ -424,12 +424,22 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul cachedfunctions[tup] = std::pair(gutils->newFunc, nullptr); cachedfinished[tup] = false; - std::map > volatile_args_map = - compute_volatile_args_for_callsites(gutils->oldFunc, gutils->DT, TLI, AA, gutils, _volatile_args); + std::map > uncacheable_args_map = + compute_uncacheable_args_for_callsites(gutils->oldFunc, gutils->DT, TLI, AA, gutils, _uncacheable_args); - std::map can_modref_map = compute_volatile_load_map(gutils, AA, TLI, _volatile_args); + std::map can_modref_map = compute_uncacheable_load_map(gutils, AA, TLI, _uncacheable_args); gutils->can_modref_map = &can_modref_map; + // Allow forcing cache reads to be on or off using flags. + assert(!(cache_reads_always && cache_reads_never) && "Both cache_reads_always and cache_reads_never are true. This doesn't make sense."); + if (cache_reads_always || cache_reads_never) { + bool is_needed = cache_reads_always ? true : false; + for (auto iter = can_modref_map.begin(); iter != can_modref_map.end(); iter++) { + can_modref_map[iter->first] = is_needed; + } + } + + //for (auto iter = can_modref_map.begin(); iter != can_modref_map.end(); iter++) { // if (iter->second) { // bool is_needed = is_load_needed_in_reverse(gutils, AA, iter->first); @@ -689,7 +699,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul } } - auto newcalled = CreateAugmentedPrimal(dyn_cast(called), _AA, subconstant_args, TLI, /*differentialReturn*/subdifferentialreturn, /*return is used*/subretused, volatile_args_map[op]).first; + auto newcalled = CreateAugmentedPrimal(dyn_cast(called), _AA, subconstant_args, TLI, /*differentialReturn*/subdifferentialreturn, /*return is used*/subretused, uncacheable_args_map[op]).first; auto augmentcall = BuilderZ.CreateCall(newcalled, args); assert(augmentcall->getType()->isStructTy()); augmentcall->setCallingConv(op->getCallingConv()); @@ -760,7 +770,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul gutils->erase(op); } else if(LoadInst* li = dyn_cast(inst)) { if (gutils->isConstantInstruction(inst) || gutils->isConstantValue(inst)) continue; - if (/*true || */(cachereads && can_modref_map[inst])) { + if (can_modref_map[inst]) { llvm::errs() << "Forcibly caching reads " << *li << "\n"; IRBuilder<> BuilderZ(li); gutils->addMalloc(BuilderZ, li); @@ -1234,7 +1244,7 @@ std::pair,SmallVector> getDefaultFunctionTypeForGr return std::pair,SmallVector>(args, outs); } -void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::reverse_iterator &E, IRBuilder <>& Builder2, CallInst* op, DiffeGradientUtils* const gutils, TargetLibraryInfo &TLI, AAResults &AA, AAResults & _AA, const bool topLevel, const std::map &replacedReturns, std::set volatile_args) { +void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::reverse_iterator &E, IRBuilder <>& Builder2, CallInst* op, DiffeGradientUtils* const gutils, TargetLibraryInfo &TLI, AAResults &AA, AAResults & _AA, const bool topLevel, const std::map &replacedReturns, std::set uncacheable_args) { llvm::errs() << "HandleGradientCall " << *op << "\n"; Function *called = op->getCalledFunction(); @@ -1578,7 +1588,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r if (modifyPrimal && called) { bool subretused = op->getNumUses() != 0; bool subdifferentialreturn = (!gutils->isConstantValue(op)) && subretused; - auto fnandtapetype = CreateAugmentedPrimal(cast(called), _AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/subretused, volatile_args); + auto fnandtapetype = CreateAugmentedPrimal(cast(called), _AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/subretused, uncacheable_args); if (topLevel) { Function* newcalled = fnandtapetype.first; augmentcall = BuilderZ.CreateCall(newcalled, pre_args); @@ -1650,8 +1660,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r bool subdiffereturn = (!gutils->isConstantValue(op)) && !( op->getType()->isPointerTy() || op->getType()->isIntegerTy() || op->getType()->isEmptyTy() ); llvm::errs() << "subdifferet:" << subdiffereturn << " " << *op << "\n"; if (called) { - llvm::errs() << "Before create primal and gradient instruction is " << *op << "\n"; - newcalled = CreatePrimalAndGradient(cast(called), subconstant_args, TLI, _AA, /*returnValue*/retUsed, /*subdiffereturn*/subdiffereturn, /*topLevel*/replaceFunction, tape ? tape->getType() : nullptr, volatile_args);//, LI, DT); + newcalled = CreatePrimalAndGradient(cast(called), subconstant_args, TLI, _AA, /*returnValue*/retUsed, /*subdiffereturn*/subdiffereturn, /*topLevel*/replaceFunction, tape ? tape->getType() : nullptr, uncacheable_args);//, LI, DT); } else { newcalled = gutils->invertPointerM(op->getCalledValue(), Builder2); auto ft = cast(cast(op->getCalledValue()->getType())->getElementType()); @@ -1761,8 +1770,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r } } -Function* CreatePrimalAndGradient(Function* todiff, const std::set& constant_args, TargetLibraryInfo &TLI, AAResults &_AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg, std::set _volatile_args) { - llvm::errs() << "Create Primal And Gradient " << todiff->getName() << "\n"; +Function* CreatePrimalAndGradient(Function* todiff, const std::set& constant_args, TargetLibraryInfo &TLI, AAResults &_AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg, std::set _uncacheable_args) { if (differentialReturn) { if(!todiff->getReturnType()->isFPOrFPVectorTy()) { llvm::errs() << *todiff << "\n"; @@ -1774,8 +1782,8 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co llvm::errs() << "addl arg: " << *additionalArg << "\n"; } if (additionalArg) assert(additionalArg->isStructTy()); - static std::map, std::set, bool/*retval*/, bool/*differentialReturn*/, bool/*topLevel*/, llvm::Type*>, Function*> cachedfunctions; - auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), std::set(_volatile_args.begin(), _volatile_args.end()), returnValue, differentialReturn, topLevel, additionalArg); + static std::map/*constant_args*/, std::set/*uncacheable_args*/, bool/*retval*/, bool/*differentialReturn*/, bool/*topLevel*/, llvm::Type*>, Function*> cachedfunctions; + auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), std::set(_uncacheable_args.begin(), _uncacheable_args.end()), returnValue, differentialReturn, topLevel, additionalArg); if (cachedfunctions.find(tup) != cachedfunctions.end()) { return cachedfunctions[tup]; } @@ -1871,13 +1879,13 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co DiffeGradientUtils *gutils = DiffeGradientUtils::CreateFromClone(todiff, AA, TLI, constant_args, returnValue ? ReturnType::ArgsWithReturn : ReturnType::Args, differentialReturn, additionalArg); cachedfunctions[tup] = gutils->newFunc; - std::map > volatile_args_map = - compute_volatile_args_for_callsites(gutils->oldFunc, gutils->DT, TLI, AA, gutils, _volatile_args); + std::map > uncacheable_args_map = + compute_uncacheable_args_for_callsites(gutils->oldFunc, gutils->DT, TLI, AA, gutils, _uncacheable_args); std::map can_modref_map; // NOTE(TFK): Sanity check this decision. // Is it always possibly to recompute the result of loads at top level? - can_modref_map = compute_volatile_load_map(gutils, AA, TLI, _volatile_args); + can_modref_map = compute_uncacheable_load_map(gutils, AA, TLI, _uncacheable_args); if (topLevel) { for (auto iter = can_modref_map.begin(); iter != can_modref_map.end(); iter++) { if (iter->second) { @@ -1886,6 +1894,16 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co } } } + + // Allow forcing cache reads to be on or off using flags. + assert(!(cache_reads_always && cache_reads_never) && "Both cache_reads_always and cache_reads_never are true. This doesn't make sense."); + if (cache_reads_always || cache_reads_never) { + bool is_needed = cache_reads_always ? true : false; + for (auto iter = can_modref_map.begin(); iter != can_modref_map.end(); iter++) { + can_modref_map[iter->first] = is_needed; + } + } + gutils->can_modref_map = &can_modref_map; gutils->forceContexts(true); @@ -2292,7 +2310,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co if (dif0) addToDiffe(op->getOperand(0), dif0); if (dif1) addToDiffe(op->getOperand(1), dif1); } else if(auto op = dyn_cast_or_null(inst)) { - handleGradientCallInst(I, E, Builder2, op, gutils, TLI, _AA, _AA, topLevel, replacedReturns, volatile_args_map[op]); + handleGradientCallInst(I, E, Builder2, op, gutils, TLI, _AA, _AA, topLevel, replacedReturns, uncacheable_args_map[op]); } else if(auto op = dyn_cast_or_null(inst)) { if (gutils->isConstantValue(inst)) continue; if (op->getType()->isPointerTy()) continue; @@ -2314,18 +2332,16 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co auto op_operand = op->getPointerOperand(); auto op_type = op->getType(); - if (cachereads) { - if (can_modref_map[inst]) { - IRBuilder<> BuilderZ(op->getNextNode()); - inst = cast(gutils->addMalloc(BuilderZ, inst)); - if (inst != op) { - // Set to nullptr since op should never be used after invalidated through addMalloc. - op = nullptr; - gutils->nonconstant_values.insert(inst); - gutils->nonconstant.insert(inst); - gutils->originalInstructions.insert(inst); - assert(inst->getType() == op_type); - } + if (can_modref_map[inst]) { + IRBuilder<> BuilderZ(op->getNextNode()); + inst = cast(gutils->addMalloc(BuilderZ, inst)); + if (inst != op) { + // Set to nullptr since op should never be used after invalidated through addMalloc. + op = nullptr; + gutils->nonconstant_values.insert(inst); + gutils->nonconstant.insert(inst); + gutils->originalInstructions.insert(inst); + assert(inst->getType() == op_type); } } diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index bfc74aceea34..3f1db3950eae 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -352,8 +352,8 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) { return invertedPointers[val] = cs; } else if (auto fn = dyn_cast(val)) { //! Todo allow tape propagation - std::set volatile_args; - auto newf = CreatePrimalAndGradient(fn, /*constant_args*/{}, TLI, AA, /*returnValue*/false, /*differentialReturn*/fn->getReturnType()->isFPOrFPVectorTy(), /*topLevel*/false, /*additionalArg*/nullptr, volatile_args); + std::set uncacheable_args; + auto newf = CreatePrimalAndGradient(fn, /*constant_args*/{}, TLI, AA, /*returnValue*/false, /*differentialReturn*/fn->getReturnType()->isFPOrFPVectorTy(), /*topLevel*/false, /*additionalArg*/nullptr, uncacheable_args); return BuilderM.CreatePointerCast(newf, fn->getType()); } else if (auto arg = dyn_cast(val)) { auto result = BuilderM.CreateCast(arg->getOpcode(), invertPointerM(arg->getOperand(0), BuilderM), arg->getDestTy(), arg->getName()+"'ipc"); diff --git a/enzyme/functional_tests_c/Makefile b/enzyme/functional_tests_c/Makefile index 939b330914d9..aa40e0c670df 100644 --- a/enzyme/functional_tests_c/Makefile +++ b/enzyme/functional_tests_c/Makefile @@ -18,7 +18,7 @@ OBJ := $(wildcard *.c) all: $(patsubst %.c,build/%-enzyme0,$(OBJ)) $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ)) -POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true -enzyme_print=true +POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_print=true #all: $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ)) #clean: From 532cf59fbbd50bb68d35c6f8374cc730ca16bbd2 Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Tue, 5 Nov 2019 21:09:00 +0000 Subject: [PATCH 17/22] after rebase --- enzyme/functional_tests_c/insertsort_sum.c | 50 ++-------------------- 1 file changed, 3 insertions(+), 47 deletions(-) diff --git a/enzyme/functional_tests_c/insertsort_sum.c b/enzyme/functional_tests_c/insertsort_sum.c index e68ef37721d0..c5e7cd33d3a0 100644 --- a/enzyme/functional_tests_c/insertsort_sum.c +++ b/enzyme/functional_tests_c/insertsort_sum.c @@ -16,18 +16,8 @@ float* unsorted_array_init(int N) { return arr; } -// sums the first half of a sorted array. -<<<<<<< HEAD -======= -//__attribute__((noinline)) -<<<<<<< HEAD ->>>>>>> add missing files and fix minor bugs -void insertsort_sum (float* array, int N, float* ret) { -======= void insertsort_sum (float*__restrict array, int N, float*__restrict ret) { ->>>>>>> bugfix. still unsure if the logic used at topLevel for detecting when we can avoid caching loads is correct though float sum = 0; - //qsort(array, N, sizeof(float), cmp); for (int i = 1; i < N; i++) { int j = i; @@ -47,31 +37,8 @@ void insertsort_sum (float*__restrict array, int N, float*__restrict ret) { *ret = sum; } -<<<<<<< HEAD -<<<<<<< HEAD -======= -//void insertsort_sum (float* array, int N, float* ret) { -// insertsort_sum_subcall(array, N, ret); -//} ->>>>>>> add missing files and fix minor bugs - - -======= ->>>>>>> put in the more strict/correct logic for ordering instructions in single function when checking for modrefs int main(int argc, char** argv) { - - - - float a = 2.0; - float b = 3.0; - - - - float da = 0; - float db = 0; - - float ret = 0; float dret = 1.0; @@ -88,18 +55,15 @@ int main(int argc, char** argv) { printf("%d:%f\n", i, array[i]); } - //insertsort_sum(array, N, &ret); + __builtin_autodiff(insertsort_sum, array, d_array, N, &ret, &dret); + + printf("The total sum is %f\n", ret); printf("Array after sorting:\n"); for (int i = 0; i < N; i++) { printf("%d:%f\n", i, array[i]); } - - printf("The total sum is %f\n", ret); - - __builtin_autodiff(insertsort_sum, array, d_array, N, &ret, &dret); - for (int i = 0; i < N; i++) { printf("Diffe for index %d is %f\n", i, d_array[i]); if (i%2 == 0) { @@ -108,13 +72,5 @@ int main(int argc, char** argv) { assert(d_array[i] == 1.0); } } - - //__builtin_autodiff(compute_loops, &a, &da, &b, &db, &ret, &dret); - - - //assert(da == 100*1.0f); - //assert(db == 100*1.0f); - - //printf("hello! %f, res2 %f, da: %f, db: %f\n", ret, ret, da,db); return 0; } From 1352eb4934af63edc0bf97368385a242827c365f Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Tue, 5 Nov 2019 21:11:06 +0000 Subject: [PATCH 18/22] remove enzyme print from functional_c_tests --- enzyme/functional_tests_c/Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/enzyme/functional_tests_c/Makefile b/enzyme/functional_tests_c/Makefile index aa40e0c670df..310affbe3342 100644 --- a/enzyme/functional_tests_c/Makefile +++ b/enzyme/functional_tests_c/Makefile @@ -18,7 +18,7 @@ OBJ := $(wildcard *.c) all: $(patsubst %.c,build/%-enzyme0,$(OBJ)) $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ)) -POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_print=true +POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg #all: $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ)) #clean: From 94fddb911ab85c47cf5e1cc9ad97969e4f2bf702 Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Tue, 5 Nov 2019 21:20:42 +0000 Subject: [PATCH 19/22] remove mistaken commit --- enzyme/Enzyme/ActiveVariable.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/enzyme/Enzyme/ActiveVariable.cpp b/enzyme/Enzyme/ActiveVariable.cpp index dae69a858cdc..1bafa198c96f 100644 --- a/enzyme/Enzyme/ActiveVariable.cpp +++ b/enzyme/Enzyme/ActiveVariable.cpp @@ -152,8 +152,7 @@ bool isIntASecretFloat(Value* val) { if (!pointerUse && floatingUse) return true; llvm::errs() << *inst->getParent()->getParent() << "\n"; llvm::errs() << " val:" << *val << " pointer:" << pointerUse << " floating:" << floatingUse << "\n"; - //assert(0 && "ambiguous unsure if constant or not"); - return false; // NOTE(TFK): Return false instead of asset. + assert(0 && "ambiguous unsure if constant or not"); } llvm::errs() << *val << "\n"; From ac1857b876e5202497da6b0e5695a38200753cc7 Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Thu, 7 Nov 2019 19:08:46 +0000 Subject: [PATCH 20/22] few changes --- enzyme/Enzyme/EnzymeLogic.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index cd3a08324c69..e4e38d3a89e0 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -49,11 +49,11 @@ llvm::cl::opt enzyme_print("enzyme_print", cl::init(false), cl::Hidden, cl::opt cache_reads_always( "enzyme_always_cache_reads", cl::init(false), cl::Hidden, - cl::desc("Force caching of all reads")); + cl::desc("Force always caching of all reads")); cl::opt cache_reads_never( "enzyme_never_cache_reads", cl::init(false), cl::Hidden, - cl::desc("Force caching of all reads")); + cl::desc("Force never caching of all reads")); @@ -69,9 +69,9 @@ std::map compute_uncacheable_load_map(GradientUtils* gutils, // For each load instruction, determine if it is uncacheable. if (auto op = dyn_cast(inst)) { // NOTE(TFK): The reasoning behind skipping ConstantValues and ConstantInstructions needs to be fleshed out. - if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) { - continue; - } + //if (gutils->isConstantValue(inst) || gutils->isConstantInstruction(inst)) { + // continue; + //} bool can_modref = false; // Find the underlying object for the pointer operand of the load instruction. From 9591d4153800a8c52984e26f743b7e51ed1117c5 Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Thu, 7 Nov 2019 21:22:39 +0000 Subject: [PATCH 21/22] few minor changes/fixes --- enzyme/Enzyme/EnzymeLogic.cpp | 6 +++--- enzyme/Enzyme/GradientUtils.cpp | 1 - enzyme/functional_tests_c/setup.sh | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index e4e38d3a89e0..36dd16c0177a 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -195,12 +195,12 @@ std::set compute_uncacheable_args_for_one_callsite(Instruction* callsi for(BasicBlock* BB: gutils->originalBlocks) { for (auto I = BB->begin(), E = BB->end(); I != E; I++) { Instruction* inst = &*I; - //if (inst == callsite_inst) continue; - + // If the "inst" does not dominate "callsite_inst" then we cannot prove that // "inst" happens before "callsite_inst". If "inst" modifies an argument of the call, // then that call needs to consider the argument uncacheable. - if (!gutils->DT.dominates(inst, callsite_inst)) { + // To correctly handle case where inst == callsite_inst, we need to look at next instruction after callsite_inst. + if (!gutils->DT.dominates(inst, callsite_inst->getNextNonDebugInstruction())) { //llvm::errs() << "Instruction " << *inst << " DOES NOT dominates " << *callsite_inst << "\n"; // Consider Store Instructions. if (auto op = dyn_cast(inst)) { diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 3f1db3950eae..49d5bb73ee08 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -306,7 +306,6 @@ DiffeGradientUtils* DiffeGradientUtils::CreateFromClone(Function *todiff, AAResu } Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) { - //if (isa(val)) return val; if (isa(val)) { return val; } else if (isa(val)) { diff --git a/enzyme/functional_tests_c/setup.sh b/enzyme/functional_tests_c/setup.sh index 98be63a09e0d..c5c86df2fbb9 100755 --- a/enzyme/functional_tests_c/setup.sh +++ b/enzyme/functional_tests_c/setup.sh @@ -1,8 +1,8 @@ #!/bin/bash # NOTE(TFK): Uncomment for local testing. -export CLANG_BIN_PATH=./../../build-dbg/bin -export ENZYME_PLUGIN=./../mkdebug/Enzyme/LLVMEnzyme-7.so +export CLANG_BIN_PATH=./../../llvm/build/bin/ +export ENZYME_PLUGIN=./../build/Enzyme/LLVMEnzyme-7.so mkdir -p build $@ From a6845db7e81798e1d99939f0febe9c67608b58d4 Mon Sep 17 00:00:00 2001 From: Tim Kaler Date: Thu, 7 Nov 2019 21:33:11 +0000 Subject: [PATCH 22/22] rename AA/_AA arguments --- enzyme/Enzyme/EnzymeLogic.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 36dd16c0177a..e06991589792 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -357,7 +357,7 @@ bool is_load_needed_in_reverse(GradientUtils* gutils, AAResults& AA, Instruction //! return structtype if recursive function -std::pair CreateAugmentedPrimal(Function* todiff, AAResults &_AA, const std::set& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed, const std::set _uncacheable_args) { +std::pair CreateAugmentedPrimal(Function* todiff, AAResults &global_AA, const std::set& constant_args, TargetLibraryInfo &TLI, bool differentialReturn, bool returnUsed, const std::set _uncacheable_args) { static std::map/*constant_args*/, std::set/*uncacheable_args*/, bool/*differentialReturn*/, bool/*returnUsed*/>, std::pair> cachedfunctions; static std::map/*constant_args*/, std::set/*uncacheable_args*/, bool/*differentialReturn*/, bool/*returnUsed*/>, bool> cachedfinished; auto tup = std::make_tuple(todiff, std::set(constant_args.begin(), constant_args.end()), std::set(_uncacheable_args.begin(), _uncacheable_args.end()), differentialReturn, returnUsed); @@ -699,7 +699,7 @@ std::pair CreateAugmentedPrimal(Function* todiff, AAResul } } - auto newcalled = CreateAugmentedPrimal(dyn_cast(called), _AA, subconstant_args, TLI, /*differentialReturn*/subdifferentialreturn, /*return is used*/subretused, uncacheable_args_map[op]).first; + auto newcalled = CreateAugmentedPrimal(dyn_cast(called), global_AA, subconstant_args, TLI, /*differentialReturn*/subdifferentialreturn, /*return is used*/subretused, uncacheable_args_map[op]).first; auto augmentcall = BuilderZ.CreateCall(newcalled, args); assert(augmentcall->getType()->isStructTy()); augmentcall->setCallingConv(op->getCallingConv()); @@ -1244,7 +1244,7 @@ std::pair,SmallVector> getDefaultFunctionTypeForGr return std::pair,SmallVector>(args, outs); } -void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::reverse_iterator &E, IRBuilder <>& Builder2, CallInst* op, DiffeGradientUtils* const gutils, TargetLibraryInfo &TLI, AAResults &AA, AAResults & _AA, const bool topLevel, const std::map &replacedReturns, std::set uncacheable_args) { +void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::reverse_iterator &E, IRBuilder <>& Builder2, CallInst* op, DiffeGradientUtils* const gutils, TargetLibraryInfo &TLI, AAResults &AA, AAResults & global_AA, const bool topLevel, const std::map &replacedReturns, std::set uncacheable_args) { llvm::errs() << "HandleGradientCall " << *op << "\n"; Function *called = op->getCalledFunction(); @@ -1457,7 +1457,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r if (iter->mayReadOrWriteMemory()) { llvm::errs() << "Iter is at " << *iter << "\n"; llvm::errs() << "origop is at " << *origop << "\n"; - mri = _AA.getModRefInfo(&*iter, origop); + mri = AA.getModRefInfo(&*iter, origop); } if (mri == ModRefInfo::NoModRef && !usesInst) { @@ -1588,7 +1588,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r if (modifyPrimal && called) { bool subretused = op->getNumUses() != 0; bool subdifferentialreturn = (!gutils->isConstantValue(op)) && subretused; - auto fnandtapetype = CreateAugmentedPrimal(cast(called), _AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/subretused, uncacheable_args); + auto fnandtapetype = CreateAugmentedPrimal(cast(called), global_AA, subconstant_args, TLI, /*differentialReturns*/subdifferentialreturn, /*return is used*/subretused, uncacheable_args); if (topLevel) { Function* newcalled = fnandtapetype.first; augmentcall = BuilderZ.CreateCall(newcalled, pre_args); @@ -1660,7 +1660,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r bool subdiffereturn = (!gutils->isConstantValue(op)) && !( op->getType()->isPointerTy() || op->getType()->isIntegerTy() || op->getType()->isEmptyTy() ); llvm::errs() << "subdifferet:" << subdiffereturn << " " << *op << "\n"; if (called) { - newcalled = CreatePrimalAndGradient(cast(called), subconstant_args, TLI, _AA, /*returnValue*/retUsed, /*subdiffereturn*/subdiffereturn, /*topLevel*/replaceFunction, tape ? tape->getType() : nullptr, uncacheable_args);//, LI, DT); + newcalled = CreatePrimalAndGradient(cast(called), subconstant_args, TLI, global_AA, /*returnValue*/retUsed, /*subdiffereturn*/subdiffereturn, /*topLevel*/replaceFunction, tape ? tape->getType() : nullptr, uncacheable_args);//, LI, DT); } else { newcalled = gutils->invertPointerM(op->getCalledValue(), Builder2); auto ft = cast(cast(op->getCalledValue()->getType())->getElementType()); @@ -1770,7 +1770,7 @@ void handleGradientCallInst(BasicBlock::reverse_iterator &I, const BasicBlock::r } } -Function* CreatePrimalAndGradient(Function* todiff, const std::set& constant_args, TargetLibraryInfo &TLI, AAResults &_AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg, std::set _uncacheable_args) { +Function* CreatePrimalAndGradient(Function* todiff, const std::set& constant_args, TargetLibraryInfo &TLI, AAResults &global_AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg, std::set _uncacheable_args) { if (differentialReturn) { if(!todiff->getReturnType()->isFPOrFPVectorTy()) { llvm::errs() << *todiff << "\n"; @@ -2310,7 +2310,7 @@ Function* CreatePrimalAndGradient(Function* todiff, const std::set& co if (dif0) addToDiffe(op->getOperand(0), dif0); if (dif1) addToDiffe(op->getOperand(1), dif1); } else if(auto op = dyn_cast_or_null(inst)) { - handleGradientCallInst(I, E, Builder2, op, gutils, TLI, _AA, _AA, topLevel, replacedReturns, uncacheable_args_map[op]); + handleGradientCallInst(I, E, Builder2, op, gutils, TLI, global_AA, global_AA, topLevel, replacedReturns, uncacheable_args_map[op]); } else if(auto op = dyn_cast_or_null(inst)) { if (gutils->isConstantValue(inst)) continue; if (op->getType()->isPointerTy()) continue;