@@ -96,14 +96,16 @@ void IR2Vec_FA::collectWriteDefsMap(Module &M) {
9696 }
9797}
9898
99- Vector IR2Vec_FA::getValue (std::string key) {
100- Vector vec (DIM, 0 );
101- if (vocabulary.find (key) == vocabulary.end ()) {
102- IR2VEC_DEBUG (errs () << " cannot find key in map : " << key << " \n " );
103- dataMissCounter++;
104- } else
105- vec = vocabulary[key];
106- return vec;
99+ bool IR2Vec_FA::getValue (std::string key, IR2Vec::Vector &out) {
100+ if (auto it = vocabulary.find (std::string (key)); it != vocabulary.end ()) {
101+ out = it->second ;
102+ return true ;
103+ }
104+
105+ out.assign (DIM, 0 );
106+ dataMissCounter++;
107+ IR2VEC_DEBUG (errs () << " cannot find key in map : " << key << " \n " );
108+ return false ;
107109}
108110
109111// Function to update funcVecMap of function with vectors of it's callee list
@@ -169,9 +171,8 @@ void IR2Vec_FA::generateFlowAwareEncodings(std::ostream *o,
169171 res += std::to_string (cls) + " \t " ;
170172
171173 for (auto i : pgmVector) {
172- if ((i <= 0.0001 && i > 0 ) || (i < 0 && i >= -0.0001 )) {
173- i = 0 ;
174- }
174+ if (std::abs (i) <= 1e-4f )
175+ i = 0 .0f ;
175176 res += std::to_string (i) + " \t " ;
176177 }
177178 res += " \n " ;
@@ -537,11 +538,11 @@ Vector IR2Vec_FA::func2Vec(Function &F,
537538 }
538539 bbVecMap[b] = bbVector;
539540 IR2VEC_DEBUG (outs () << " -------------------------------------------\n " );
540- for ( auto i : bbVector) {
541- if ((i <= 0.0001 && i > 0 ) || (i < 0 && i >= - 0.0001 ) ) {
542- i = 0 ;
543- }
544- }
541+
542+ std::for_each (bbVector. begin (), bbVector. end (), []( double &x ) {
543+ if ( std::abs (x) <= 1e- 4f )
544+ x = 0 . 0f ;
545+ });
545546
546547 std::transform (funcVector.begin (), funcVector.end (), bbVector.begin (),
547548 funcVector.begin (), std::plus<double >());
@@ -864,11 +865,11 @@ void IR2Vec_FA::getPartialVec(
864865 return ;
865866 }
866867
867- Vector instVector (DIM, 0 );
868+ Vector instVector (DIM, 0 ), opcode_vec ;
868869 StringRef opcodeName = I.getOpcodeName ();
869- auto vec = getValue (opcodeName.str ());
870+ getValue (opcodeName.str (), opcode_vec );
870871 IR2VEC_DEBUG (I.print (outs ()); outs () << " \n " );
871- std::transform (instVector.begin (), instVector.end (), vec .begin (),
872+ std::transform (instVector.begin (), instVector.end (), opcode_vec .begin (),
872873 instVector.begin (), std::plus<double >());
873874 partialInstValMap[&I] = instVector;
874875
@@ -878,38 +879,39 @@ void IR2Vec_FA::getPartialVec(
878879 i.first ->print (outs ());
879880 outs () << " \n " ;
880881 });
881- auto type = I.getType ();
882882
883+ auto type = I.getType ();
884+ Vector type_vec;
883885 if (type->isVoidTy ()) {
884- vec = getValue (" voidTy" );
886+ getValue (" voidTy" , type_vec );
885887 } else if (type->isFloatingPointTy ()) {
886- vec = getValue (" floatTy" );
888+ getValue (" floatTy" , type_vec );
887889 } else if (type->isIntegerTy ()) {
888- vec = getValue (" integerTy" );
890+ getValue (" integerTy" , type_vec );
889891 } else if (type->isFunctionTy ()) {
890- vec = getValue (" functionTy" );
892+ getValue (" functionTy" , type_vec );
891893 } else if (type->isStructTy ()) {
892- vec = getValue (" structTy" );
894+ getValue (" structTy" , type_vec );
893895 } else if (type->isArrayTy ()) {
894- vec = getValue (" arrayTy" );
896+ getValue (" arrayTy" , type_vec );
895897 } else if (type->isPointerTy ()) {
896- vec = getValue (" pointerTy" );
898+ getValue (" pointerTy" , type_vec );
897899 } else if (type->isVectorTy ()) {
898- vec = getValue (" vectorTy" );
900+ getValue (" vectorTy" , type_vec );
899901 } else if (type->isEmptyTy ()) {
900- vec = getValue (" emptyTy" );
902+ getValue (" emptyTy" , type_vec );
901903 } else if (type->isLabelTy ()) {
902- vec = getValue (" labelTy" );
904+ getValue (" labelTy" , type_vec );
903905 } else if (type->isTokenTy ()) {
904- vec = getValue (" tokenTy" );
906+ getValue (" tokenTy" , type_vec );
905907 } else if (type->isMetadataTy ()) {
906- vec = getValue (" metadataTy" );
908+ getValue (" metadataTy" , type_vec );
907909 } else {
908- vec = getValue (" unknownTy" );
910+ getValue (" unknownTy" , type_vec );
909911 }
910912
911- scaleVector (vec , WT);
912- std::transform (instVector.begin (), instVector.end (), vec .begin (),
913+ scaleVector (type_vec , WT);
914+ std::transform (instVector.begin (), instVector.end (), type_vec .begin (),
913915 instVector.begin (), std::plus<double >());
914916
915917 partialInstValMap[&I] = instVector;
@@ -940,7 +942,8 @@ void IR2Vec_FA::solveInsts(
940942 B.push_back (tmp);
941943 for (unsigned i = 0 ; i < inst->getNumOperands (); i++) {
942944 if (isa<Function>(inst->getOperand (i))) {
943- auto f = getValue (" function" );
945+ Vector f;
946+ getValue (" function" , f);
944947 if (isa<CallInst>(inst)) {
945948 auto ci = dyn_cast<CallInst>(inst);
946949 Function *func = ci->getCalledFunction ();
@@ -965,7 +968,8 @@ void IR2Vec_FA::solveInsts(
965968 B.push_back (vec);
966969 } else if (isa<Constant>(inst->getOperand (i)) &&
967970 !isa<PointerType>(inst->getOperand (i)->getType ())) {
968- auto c = getValue (" constant" );
971+ Vector c;
972+ getValue (" constant" , c);
969973 auto svtmp = c;
970974 scaleVector (svtmp, WA);
971975 std::vector<double > vtmp (svtmp.begin (), svtmp.end ());
@@ -978,7 +982,8 @@ void IR2Vec_FA::solveInsts(
978982 IR2VEC_DEBUG (outs () << vec.back () << " \n " );
979983 B.push_back (vec);
980984 } else if (isa<BasicBlock>(inst->getOperand (i))) {
981- auto l = getValue (" label" );
985+ Vector l;
986+ getValue (" label" , l);
982987 auto svtmp = l;
983988 scaleVector (svtmp, WA);
984989 std::vector<double > vtmp (svtmp.begin (), svtmp.end ());
@@ -1022,7 +1027,8 @@ void IR2Vec_FA::solveInsts(
10221027 }
10231028 }
10241029 } else if (isa<PointerType>(inst->getOperand (i)->getType ())) {
1025- auto l = getValue (" pointer" );
1030+ Vector l;
1031+ getValue (" pointer" , l);
10261032 auto svtmp = l;
10271033 scaleVector (svtmp, WA);
10281034 std::vector<double > vtmp (svtmp.begin (), svtmp.end ());
@@ -1035,7 +1041,8 @@ void IR2Vec_FA::solveInsts(
10351041 IR2VEC_DEBUG (outs () << vec.back () << " \n " );
10361042 B.push_back (vec);
10371043 } else {
1038- auto l = getValue (" variable" );
1044+ Vector l;
1045+ getValue (" variable" , l);
10391046 auto svtmp = l;
10401047 scaleVector (svtmp, WA);
10411048 std::vector<double > vtmp (svtmp.begin (), svtmp.end ());
@@ -1137,9 +1144,9 @@ void IR2Vec_FA::solveSingleComponent(
11371144 RDList.clear ();
11381145
11391146 for (unsigned i = 0 ; i < I.getNumOperands () /* && !isCyclic*/ ; i++) {
1140- Vector vecOp (DIM, 0 ) ;
1147+ Vector vecOp;
11411148 if (isa<Function>(I.getOperand (i))) {
1142- vecOp = getValue (" function" );
1149+ getValue (" function" , vecOp );
11431150 if (isa<CallInst>(I)) {
11441151 auto ci = dyn_cast<CallInst>(&I);
11451152 Function *func = ci->getCalledFunction ();
@@ -1156,17 +1163,17 @@ void IR2Vec_FA::solveSingleComponent(
11561163 // non-numeric/alphabetic constants are also caught as pointer types
11571164 else if (isa<Constant>(I.getOperand (i)) &&
11581165 !isa<PointerType>(I.getOperand (i)->getType ())) {
1159- vecOp = getValue (" constant" );
1166+ getValue (" constant" , vecOp );
11601167 } else if (isa<BasicBlock>(I.getOperand (i))) {
1161- vecOp = getValue (" label" );
1168+ getValue (" label" , vecOp );
11621169 } else {
11631170 if (isa<Instruction>(I.getOperand (i))) {
11641171 auto RD = getReachingDefs (&I, i);
11651172 RDList.insert (RDList.end (), RD.begin (), RD.end ());
11661173 } else if (isa<PointerType>(I.getOperand (i)->getType ())) {
1167- vecOp = getValue (" pointer" );
1174+ getValue (" pointer" , vecOp );
11681175 } else
1169- vecOp = getValue (" variable" );
1176+ getValue (" variable" , vecOp );
11701177 }
11711178
11721179 std::transform (VecArgs.begin (), VecArgs.end (), vecOp.begin (),
@@ -1237,11 +1244,11 @@ void IR2Vec_FA::inst2Vec(
12371244 return ;
12381245 }
12391246
1240- Vector instVector (DIM, 0 );
1247+ Vector instVector (DIM, 0 ), opcode_vec ;
12411248 StringRef opcodeName = I.getOpcodeName ();
1242- auto vec = getValue (opcodeName.str ());
1249+ getValue (opcodeName.str (), opcode_vec );
12431250 IR2VEC_DEBUG (I.print (outs ()); outs () << " \n " );
1244- std::transform (instVector.begin (), instVector.end (), vec .begin (),
1251+ std::transform (instVector.begin (), instVector.end (), opcode_vec .begin (),
12451252 instVector.begin (), std::plus<double >());
12461253 partialInstValMap[&I] = instVector;
12471254
@@ -1253,36 +1260,36 @@ void IR2Vec_FA::inst2Vec(
12531260 });
12541261
12551262 auto type = I.getType ();
1256-
1263+ Vector type_vec;
12571264 if (type->isVoidTy ()) {
1258- vec = getValue (" voidTy" );
1265+ getValue (" voidTy" , type_vec );
12591266 } else if (type->isFloatingPointTy ()) {
1260- vec = getValue (" floatTy" );
1267+ getValue (" floatTy" , type_vec );
12611268 } else if (type->isIntegerTy ()) {
1262- vec = getValue (" integerTy" );
1269+ getValue (" integerTy" , type_vec );
12631270 } else if (type->isFunctionTy ()) {
1264- vec = getValue (" functionTy" );
1271+ getValue (" functionTy" , type_vec );
12651272 } else if (type->isStructTy ()) {
1266- vec = getValue (" structTy" );
1273+ getValue (" structTy" , type_vec );
12671274 } else if (type->isArrayTy ()) {
1268- vec = getValue (" arrayTy" );
1275+ getValue (" arrayTy" , type_vec );
12691276 } else if (type->isPointerTy ()) {
1270- vec = getValue (" pointerTy" );
1277+ getValue (" pointerTy" , type_vec );
12711278 } else if (type->isVectorTy ()) {
1272- vec = getValue (" vectorTy" );
1279+ getValue (" vectorTy" , type_vec );
12731280 } else if (type->isEmptyTy ()) {
1274- vec = getValue (" emptyTy" );
1281+ getValue (" emptyTy" , type_vec );
12751282 } else if (type->isLabelTy ()) {
1276- vec = getValue (" labelTy" );
1283+ getValue (" labelTy" , type_vec );
12771284 } else if (type->isTokenTy ()) {
1278- vec = getValue (" tokenTy" );
1285+ getValue (" tokenTy" , type_vec );
12791286 } else if (type->isMetadataTy ()) {
1280- vec = getValue (" metadataTy" );
1287+ getValue (" metadataTy" , type_vec );
12811288 } else {
1282- vec = getValue (" unknownTy" );
1289+ getValue (" unknownTy" , type_vec );
12831290 }
1284- scaleVector (vec , WT);
1285- std::transform (instVector.begin (), instVector.end (), vec .begin (),
1291+ scaleVector (type_vec , WT);
1292+ std::transform (instVector.begin (), instVector.end (), type_vec .begin (),
12861293 instVector.begin (), std::plus<double >());
12871294 partialInstValMap[&I] = instVector;
12881295
@@ -1295,9 +1302,9 @@ void IR2Vec_FA::inst2Vec(
12951302 RDList.clear ();
12961303
12971304 for (unsigned i = 0 ; i < I.getNumOperands () /* && !isCyclic*/ ; i++) {
1298- Vector vecOp (DIM, 0 ) ;
1305+ Vector vecOp;
12991306 if (isa<Function>(I.getOperand (i))) {
1300- vecOp = getValue (" function" );
1307+ getValue (" function" , vecOp );
13011308 if (isa<CallInst>(I)) {
13021309 auto ci = dyn_cast<CallInst>(&I);
13031310 Function *func = ci->getCalledFunction ();
@@ -1314,17 +1321,17 @@ void IR2Vec_FA::inst2Vec(
13141321 // non-numeric/alphabetic constants are also caught as pointer types
13151322 else if (isa<Constant>(I.getOperand (i)) &&
13161323 !isa<PointerType>(I.getOperand (i)->getType ())) {
1317- vecOp = getValue (" constant" );
1324+ getValue (" constant" , vecOp );
13181325 } else if (isa<BasicBlock>(I.getOperand (i))) {
1319- vecOp = getValue (" label" );
1326+ getValue (" label" , vecOp );
13201327 } else {
13211328 if (isa<Instruction>(I.getOperand (i))) {
13221329 auto RD = getReachingDefs (&I, i);
13231330 RDList.insert (RDList.end (), RD.begin (), RD.end ());
13241331 } else if (isa<PointerType>(I.getOperand (i)->getType ()))
1325- vecOp = getValue (" pointer" );
1332+ getValue (" pointer" , vecOp );
13261333 else
1327- vecOp = getValue (" variable" );
1334+ getValue (" variable" , vecOp );
13281335 }
13291336
13301337 std::transform (VecArgs.begin (), VecArgs.end (), vecOp.begin (),
0 commit comments