Skip to content

Commit 4a2bea8

Browse files
Similar optimizations and refactoring for Flowaware
1 parent 4ae2365 commit 4a2bea8

File tree

3 files changed

+79
-73
lines changed

3 files changed

+79
-73
lines changed

src/FlowAware.cpp

Lines changed: 78 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,16 @@ void IR2Vec_FA::collectWriteDefsMap(Module &M) {
9696
}
9797
}
9898

99-
Vector IR2Vec_FA::getValue(std::string key) {
100-
Vector vec(DIM, 0);
101-
if (vocabulary.find(key) == vocabulary.end()) {
102-
IR2VEC_DEBUG(errs() << "cannot find key in map : " << key << "\n");
103-
dataMissCounter++;
104-
} else
105-
vec = vocabulary[key];
106-
return vec;
99+
bool IR2Vec_FA::getValue(std::string key, IR2Vec::Vector &out) {
100+
if (auto it = vocabulary.find(std::string(key)); it != vocabulary.end()) {
101+
out = it->second;
102+
return true;
103+
}
104+
105+
out.assign(DIM, 0);
106+
dataMissCounter++;
107+
IR2VEC_DEBUG(errs() << "cannot find key in map : " << key << "\n");
108+
return false;
107109
}
108110

109111
// Function to update funcVecMap of function with vectors of it's callee list
@@ -169,9 +171,8 @@ void IR2Vec_FA::generateFlowAwareEncodings(std::ostream *o,
169171
res += std::to_string(cls) + "\t";
170172

171173
for (auto i : pgmVector) {
172-
if ((i <= 0.0001 && i > 0) || (i < 0 && i >= -0.0001)) {
173-
i = 0;
174-
}
174+
if (std::abs(i) <= 1e-4f)
175+
i = 0.0f;
175176
res += std::to_string(i) + "\t";
176177
}
177178
res += "\n";
@@ -537,11 +538,11 @@ Vector IR2Vec_FA::func2Vec(Function &F,
537538
}
538539
bbVecMap[b] = bbVector;
539540
IR2VEC_DEBUG(outs() << "-------------------------------------------\n");
540-
for (auto i : bbVector) {
541-
if ((i <= 0.0001 && i > 0) || (i < 0 && i >= -0.0001)) {
542-
i = 0;
543-
}
544-
}
541+
542+
std::for_each(bbVector.begin(), bbVector.end(), [](double &x) {
543+
if (std::abs(x) <= 1e-4f)
544+
x = 0.0f;
545+
});
545546

546547
std::transform(funcVector.begin(), funcVector.end(), bbVector.begin(),
547548
funcVector.begin(), std::plus<double>());
@@ -864,11 +865,11 @@ void IR2Vec_FA::getPartialVec(
864865
return;
865866
}
866867

867-
Vector instVector(DIM, 0);
868+
Vector instVector(DIM, 0), opcode_vec;
868869
StringRef opcodeName = I.getOpcodeName();
869-
auto vec = getValue(opcodeName.str());
870+
getValue(opcodeName.str(), opcode_vec);
870871
IR2VEC_DEBUG(I.print(outs()); outs() << "\n");
871-
std::transform(instVector.begin(), instVector.end(), vec.begin(),
872+
std::transform(instVector.begin(), instVector.end(), opcode_vec.begin(),
872873
instVector.begin(), std::plus<double>());
873874
partialInstValMap[&I] = instVector;
874875

@@ -878,38 +879,39 @@ void IR2Vec_FA::getPartialVec(
878879
i.first->print(outs());
879880
outs() << "\n";
880881
});
881-
auto type = I.getType();
882882

883+
auto type = I.getType();
884+
Vector type_vec;
883885
if (type->isVoidTy()) {
884-
vec = getValue("voidTy");
886+
getValue("voidTy", type_vec);
885887
} else if (type->isFloatingPointTy()) {
886-
vec = getValue("floatTy");
888+
getValue("floatTy", type_vec);
887889
} else if (type->isIntegerTy()) {
888-
vec = getValue("integerTy");
890+
getValue("integerTy", type_vec);
889891
} else if (type->isFunctionTy()) {
890-
vec = getValue("functionTy");
892+
getValue("functionTy", type_vec);
891893
} else if (type->isStructTy()) {
892-
vec = getValue("structTy");
894+
getValue("structTy", type_vec);
893895
} else if (type->isArrayTy()) {
894-
vec = getValue("arrayTy");
896+
getValue("arrayTy", type_vec);
895897
} else if (type->isPointerTy()) {
896-
vec = getValue("pointerTy");
898+
getValue("pointerTy", type_vec);
897899
} else if (type->isVectorTy()) {
898-
vec = getValue("vectorTy");
900+
getValue("vectorTy", type_vec);
899901
} else if (type->isEmptyTy()) {
900-
vec = getValue("emptyTy");
902+
getValue("emptyTy", type_vec);
901903
} else if (type->isLabelTy()) {
902-
vec = getValue("labelTy");
904+
getValue("labelTy", type_vec);
903905
} else if (type->isTokenTy()) {
904-
vec = getValue("tokenTy");
906+
getValue("tokenTy", type_vec);
905907
} else if (type->isMetadataTy()) {
906-
vec = getValue("metadataTy");
908+
getValue("metadataTy", type_vec);
907909
} else {
908-
vec = getValue("unknownTy");
910+
getValue("unknownTy", type_vec);
909911
}
910912

911-
scaleVector(vec, WT);
912-
std::transform(instVector.begin(), instVector.end(), vec.begin(),
913+
scaleVector(type_vec, WT);
914+
std::transform(instVector.begin(), instVector.end(), type_vec.begin(),
913915
instVector.begin(), std::plus<double>());
914916

915917
partialInstValMap[&I] = instVector;
@@ -940,7 +942,8 @@ void IR2Vec_FA::solveInsts(
940942
B.push_back(tmp);
941943
for (unsigned i = 0; i < inst->getNumOperands(); i++) {
942944
if (isa<Function>(inst->getOperand(i))) {
943-
auto f = getValue("function");
945+
Vector f;
946+
getValue("function", f);
944947
if (isa<CallInst>(inst)) {
945948
auto ci = dyn_cast<CallInst>(inst);
946949
Function *func = ci->getCalledFunction();
@@ -965,7 +968,8 @@ void IR2Vec_FA::solveInsts(
965968
B.push_back(vec);
966969
} else if (isa<Constant>(inst->getOperand(i)) &&
967970
!isa<PointerType>(inst->getOperand(i)->getType())) {
968-
auto c = getValue("constant");
971+
Vector c;
972+
getValue("constant", c);
969973
auto svtmp = c;
970974
scaleVector(svtmp, WA);
971975
std::vector<double> vtmp(svtmp.begin(), svtmp.end());
@@ -978,7 +982,8 @@ void IR2Vec_FA::solveInsts(
978982
IR2VEC_DEBUG(outs() << vec.back() << "\n");
979983
B.push_back(vec);
980984
} else if (isa<BasicBlock>(inst->getOperand(i))) {
981-
auto l = getValue("label");
985+
Vector l;
986+
getValue("label", l);
982987
auto svtmp = l;
983988
scaleVector(svtmp, WA);
984989
std::vector<double> vtmp(svtmp.begin(), svtmp.end());
@@ -1022,7 +1027,8 @@ void IR2Vec_FA::solveInsts(
10221027
}
10231028
}
10241029
} else if (isa<PointerType>(inst->getOperand(i)->getType())) {
1025-
auto l = getValue("pointer");
1030+
Vector l;
1031+
getValue("pointer", l);
10261032
auto svtmp = l;
10271033
scaleVector(svtmp, WA);
10281034
std::vector<double> vtmp(svtmp.begin(), svtmp.end());
@@ -1035,7 +1041,8 @@ void IR2Vec_FA::solveInsts(
10351041
IR2VEC_DEBUG(outs() << vec.back() << "\n");
10361042
B.push_back(vec);
10371043
} else {
1038-
auto l = getValue("variable");
1044+
Vector l;
1045+
getValue("variable", l);
10391046
auto svtmp = l;
10401047
scaleVector(svtmp, WA);
10411048
std::vector<double> vtmp(svtmp.begin(), svtmp.end());
@@ -1137,9 +1144,9 @@ void IR2Vec_FA::solveSingleComponent(
11371144
RDList.clear();
11381145

11391146
for (unsigned i = 0; i < I.getNumOperands() /*&& !isCyclic*/; i++) {
1140-
Vector vecOp(DIM, 0);
1147+
Vector vecOp;
11411148
if (isa<Function>(I.getOperand(i))) {
1142-
vecOp = getValue("function");
1149+
getValue("function", vecOp);
11431150
if (isa<CallInst>(I)) {
11441151
auto ci = dyn_cast<CallInst>(&I);
11451152
Function *func = ci->getCalledFunction();
@@ -1156,17 +1163,17 @@ void IR2Vec_FA::solveSingleComponent(
11561163
// non-numeric/alphabetic constants are also caught as pointer types
11571164
else if (isa<Constant>(I.getOperand(i)) &&
11581165
!isa<PointerType>(I.getOperand(i)->getType())) {
1159-
vecOp = getValue("constant");
1166+
getValue("constant", vecOp);
11601167
} else if (isa<BasicBlock>(I.getOperand(i))) {
1161-
vecOp = getValue("label");
1168+
getValue("label", vecOp);
11621169
} else {
11631170
if (isa<Instruction>(I.getOperand(i))) {
11641171
auto RD = getReachingDefs(&I, i);
11651172
RDList.insert(RDList.end(), RD.begin(), RD.end());
11661173
} else if (isa<PointerType>(I.getOperand(i)->getType())) {
1167-
vecOp = getValue("pointer");
1174+
getValue("pointer", vecOp);
11681175
} else
1169-
vecOp = getValue("variable");
1176+
getValue("variable", vecOp);
11701177
}
11711178

11721179
std::transform(VecArgs.begin(), VecArgs.end(), vecOp.begin(),
@@ -1237,11 +1244,11 @@ void IR2Vec_FA::inst2Vec(
12371244
return;
12381245
}
12391246

1240-
Vector instVector(DIM, 0);
1247+
Vector instVector(DIM, 0), opcode_vec;
12411248
StringRef opcodeName = I.getOpcodeName();
1242-
auto vec = getValue(opcodeName.str());
1249+
getValue(opcodeName.str(), opcode_vec);
12431250
IR2VEC_DEBUG(I.print(outs()); outs() << "\n");
1244-
std::transform(instVector.begin(), instVector.end(), vec.begin(),
1251+
std::transform(instVector.begin(), instVector.end(), opcode_vec.begin(),
12451252
instVector.begin(), std::plus<double>());
12461253
partialInstValMap[&I] = instVector;
12471254

@@ -1253,36 +1260,36 @@ void IR2Vec_FA::inst2Vec(
12531260
});
12541261

12551262
auto type = I.getType();
1256-
1263+
Vector type_vec;
12571264
if (type->isVoidTy()) {
1258-
vec = getValue("voidTy");
1265+
getValue("voidTy", type_vec);
12591266
} else if (type->isFloatingPointTy()) {
1260-
vec = getValue("floatTy");
1267+
getValue("floatTy", type_vec);
12611268
} else if (type->isIntegerTy()) {
1262-
vec = getValue("integerTy");
1269+
getValue("integerTy", type_vec);
12631270
} else if (type->isFunctionTy()) {
1264-
vec = getValue("functionTy");
1271+
getValue("functionTy", type_vec);
12651272
} else if (type->isStructTy()) {
1266-
vec = getValue("structTy");
1273+
getValue("structTy", type_vec);
12671274
} else if (type->isArrayTy()) {
1268-
vec = getValue("arrayTy");
1275+
getValue("arrayTy", type_vec);
12691276
} else if (type->isPointerTy()) {
1270-
vec = getValue("pointerTy");
1277+
getValue("pointerTy", type_vec);
12711278
} else if (type->isVectorTy()) {
1272-
vec = getValue("vectorTy");
1279+
getValue("vectorTy", type_vec);
12731280
} else if (type->isEmptyTy()) {
1274-
vec = getValue("emptyTy");
1281+
getValue("emptyTy", type_vec);
12751282
} else if (type->isLabelTy()) {
1276-
vec = getValue("labelTy");
1283+
getValue("labelTy", type_vec);
12771284
} else if (type->isTokenTy()) {
1278-
vec = getValue("tokenTy");
1285+
getValue("tokenTy", type_vec);
12791286
} else if (type->isMetadataTy()) {
1280-
vec = getValue("metadataTy");
1287+
getValue("metadataTy", type_vec);
12811288
} else {
1282-
vec = getValue("unknownTy");
1289+
getValue("unknownTy", type_vec);
12831290
}
1284-
scaleVector(vec, WT);
1285-
std::transform(instVector.begin(), instVector.end(), vec.begin(),
1291+
scaleVector(type_vec, WT);
1292+
std::transform(instVector.begin(), instVector.end(), type_vec.begin(),
12861293
instVector.begin(), std::plus<double>());
12871294
partialInstValMap[&I] = instVector;
12881295

@@ -1295,9 +1302,9 @@ void IR2Vec_FA::inst2Vec(
12951302
RDList.clear();
12961303

12971304
for (unsigned i = 0; i < I.getNumOperands() /*&& !isCyclic*/; i++) {
1298-
Vector vecOp(DIM, 0);
1305+
Vector vecOp;
12991306
if (isa<Function>(I.getOperand(i))) {
1300-
vecOp = getValue("function");
1307+
getValue("function", vecOp);
13011308
if (isa<CallInst>(I)) {
13021309
auto ci = dyn_cast<CallInst>(&I);
13031310
Function *func = ci->getCalledFunction();
@@ -1314,17 +1321,17 @@ void IR2Vec_FA::inst2Vec(
13141321
// non-numeric/alphabetic constants are also caught as pointer types
13151322
else if (isa<Constant>(I.getOperand(i)) &&
13161323
!isa<PointerType>(I.getOperand(i)->getType())) {
1317-
vecOp = getValue("constant");
1324+
getValue("constant", vecOp);
13181325
} else if (isa<BasicBlock>(I.getOperand(i))) {
1319-
vecOp = getValue("label");
1326+
getValue("label", vecOp);
13201327
} else {
13211328
if (isa<Instruction>(I.getOperand(i))) {
13221329
auto RD = getReachingDefs(&I, i);
13231330
RDList.insert(RDList.end(), RD.begin(), RD.end());
13241331
} else if (isa<PointerType>(I.getOperand(i)->getType()))
1325-
vecOp = getValue("pointer");
1332+
getValue("pointer", vecOp);
13261333
else
1327-
vecOp = getValue("variable");
1334+
getValue("variable", vecOp);
13281335
}
13291336

13301337
std::transform(VecArgs.begin(), VecArgs.end(), vecOp.begin(),

src/Symbolic.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ using namespace IR2Vec;
2828
using abi::__cxa_demangle;
2929

3030
bool IR2Vec_Symbolic::getValue(std::string key, IR2Vec::Vector &out) {
31-
Vector vec(DIM, 0);
3231
if (auto it = vocabulary.find(std::string(key)); it != vocabulary.end()) {
3332
out = it->second;
3433
return true;

src/include/FlowAware.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class IR2Vec_FA {
7272

7373
void getAllSCC();
7474

75-
IR2Vec::Vector getValue(std::string key);
75+
bool getValue(std::string key, IR2Vec::Vector &out);
7676
void collectWriteDefsMap(llvm::Module &M);
7777
void getTransitiveUse(
7878
const llvm::Instruction *root, const llvm::Instruction *def,

0 commit comments

Comments
 (0)