@@ -45,15 +45,194 @@ INST_STATISTIC(FCmp);
4545
4646namespace {
4747
48+ inline AttributeSet getFnAttrs (const AttributeList &Attrs)
49+ {
50+ #if JL_LLVM_VERSION >= 140000
51+ return Attrs.getFnAttrs ();
52+ #else
53+ return Attrs.getFnAttributes ();
54+ #endif
55+ }
56+
57+ inline AttributeSet getRetAttrs (const AttributeList &Attrs)
58+ {
59+ #if JL_LLVM_VERSION >= 140000
60+ return Attrs.getRetAttrs ();
61+ #else
62+ return Attrs.getRetAttributes ();
63+ #endif
64+ }
65+
66+ static Instruction *replaceIntrinsicWith (IntrinsicInst *call, Type *RetTy, ArrayRef<Value*> args)
67+ {
68+ Intrinsic::ID ID = call->getIntrinsicID ();
69+ assert (ID);
70+ auto oldfType = call->getFunctionType ();
71+ auto nargs = oldfType->getNumParams ();
72+ assert (args.size () > nargs);
73+ SmallVector<Type*, 8 > argTys (nargs);
74+ for (unsigned i = 0 ; i < nargs; i++)
75+ argTys[i] = args[i]->getType ();
76+ auto newfType = FunctionType::get (RetTy, argTys, oldfType->isVarArg ());
77+
78+ // Accumulate an array of overloaded types for the given intrinsic
79+ // and compute the new name mangling schema
80+ SmallVector<Type*, 4 > overloadTys;
81+ {
82+ SmallVector<Intrinsic::IITDescriptor, 8 > Table;
83+ getIntrinsicInfoTableEntries (ID, Table);
84+ ArrayRef<Intrinsic::IITDescriptor> TableRef = Table;
85+ auto res = Intrinsic::matchIntrinsicSignature (newfType, TableRef, overloadTys);
86+ assert (res == Intrinsic::MatchIntrinsicTypes_Match);
87+ (void )res;
88+ bool matchvararg = !Intrinsic::matchIntrinsicVarArg (newfType->isVarArg (), TableRef);
89+ assert (matchvararg);
90+ (void )matchvararg;
91+ }
92+ auto newF = Intrinsic::getDeclaration (call->getModule (), ID, overloadTys);
93+ assert (newF->getFunctionType () == newfType);
94+ newF->setCallingConv (call->getCallingConv ());
95+ assert (args.back () == call->getCalledFunction ());
96+ auto newCall = CallInst::Create (newF, args.drop_back (), " " , call);
97+ newCall->setTailCallKind (call->getTailCallKind ());
98+ auto old_attrs = call->getAttributes ();
99+ newCall->setAttributes (AttributeList::get (call->getContext (), getFnAttrs (old_attrs),
100+ getRetAttrs (old_attrs), {})); // drop parameter attributes
101+ return newCall;
102+ }
103+
104+
105+ static Value* CreateFPCast (Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder)
106+ {
107+ Type *SrcTy = V->getType ();
108+ Type *RetTy = DestTy;
109+ if (auto *VC = dyn_cast<Constant>(V)) {
110+ // The input IR often has things of the form
111+ // fcmp olt half %0, 0xH7C00
112+ // and we would like to avoid turning that constant into a call here
113+ // if we can simply constant fold it to the new type.
114+ VC = ConstantExpr::getCast (opcode, VC, DestTy, true );
115+ if (VC)
116+ return VC;
117+ }
118+ assert (SrcTy->isVectorTy () == DestTy->isVectorTy ());
119+ if (SrcTy->isVectorTy ()) {
120+ unsigned NumElems = cast<FixedVectorType>(SrcTy)->getNumElements ();
121+ assert (cast<FixedVectorType>(DestTy)->getNumElements () == NumElems && " Mismatched cast" );
122+ Value *NewV = UndefValue::get (DestTy);
123+ RetTy = RetTy->getScalarType ();
124+ for (unsigned i = 0 ; i < NumElems; ++i) {
125+ Value *I = builder.getInt32 (i);
126+ Value *Vi = builder.CreateExtractElement (V, I);
127+ Vi = CreateFPCast (opcode, Vi, RetTy, builder);
128+ NewV = builder.CreateInsertElement (NewV, Vi, I);
129+ }
130+ return NewV;
131+ }
132+ auto &M = *builder.GetInsertBlock ()->getModule ();
133+ auto &ctx = M.getContext ();
134+ // Pick the Function to call in the Julia runtime
135+ StringRef Name;
136+ switch (opcode) {
137+ case Instruction::FPExt:
138+ // this is exact, so we only need one conversion
139+ assert (SrcTy->isHalfTy ());
140+ Name = " julia__gnu_h2f_ieee" ;
141+ RetTy = Type::getFloatTy (ctx);
142+ break ;
143+ case Instruction::FPTrunc:
144+ assert (DestTy->isHalfTy ());
145+ if (SrcTy->isFloatTy ())
146+ Name = " julia__gnu_f2h_ieee" ;
147+ else if (SrcTy->isDoubleTy ())
148+ Name = " julia__truncdfhf2" ;
149+ break ;
150+ // All F16 fit exactly in Int32 (-65504 to 65504)
151+ case Instruction::FPToSI: JL_FALLTHROUGH;
152+ case Instruction::FPToUI:
153+ assert (SrcTy->isHalfTy ());
154+ Name = " julia__gnu_h2f_ieee" ;
155+ RetTy = Type::getFloatTy (ctx);
156+ break ;
157+ case Instruction::SIToFP: JL_FALLTHROUGH;
158+ case Instruction::UIToFP:
159+ assert (DestTy->isHalfTy ());
160+ Name = " julia__gnu_f2h_ieee" ;
161+ SrcTy = Type::getFloatTy (ctx);
162+ break ;
163+ default :
164+ errs () << Instruction::getOpcodeName (opcode) << ' ' ;
165+ V->getType ()->print (errs ());
166+ errs () << " to " ;
167+ DestTy->print (errs ());
168+ errs () << " is an " ;
169+ llvm_unreachable (" invalid cast" );
170+ }
171+ if (Name.empty ()) {
172+ errs () << Instruction::getOpcodeName (opcode) << ' ' ;
173+ V->getType ()->print (errs ());
174+ errs () << " to " ;
175+ DestTy->print (errs ());
176+ errs () << " is an " ;
177+ llvm_unreachable (" illegal cast" );
178+ }
179+ // Coerce the source to the required size and type
180+ auto T_int16 = Type::getInt16Ty (ctx);
181+ if (SrcTy->isHalfTy ())
182+ SrcTy = T_int16;
183+ if (opcode == Instruction::SIToFP)
184+ V = builder.CreateSIToFP (V, SrcTy);
185+ else if (opcode == Instruction::UIToFP)
186+ V = builder.CreateUIToFP (V, SrcTy);
187+ else
188+ V = builder.CreateBitCast (V, SrcTy);
189+ // Call our intrinsic
190+ if (RetTy->isHalfTy ())
191+ RetTy = T_int16;
192+ auto FT = FunctionType::get (RetTy, {SrcTy}, false );
193+ FunctionCallee F = M.getOrInsertFunction (Name, FT);
194+ Value *I = builder.CreateCall (F, {V});
195+ // Coerce the result to the expected type
196+ if (opcode == Instruction::FPToSI)
197+ I = builder.CreateFPToSI (I, DestTy);
198+ else if (opcode == Instruction::FPToUI)
199+ I = builder.CreateFPToUI (I, DestTy);
200+ else if (opcode == Instruction::FPExt)
201+ I = builder.CreateFPCast (I, DestTy);
202+ else
203+ I = builder.CreateBitCast (I, DestTy);
204+ return I;
205+ }
206+
48207static bool demoteFloat16 (Function &F)
49208{
50209 auto &ctx = F.getContext ();
51- auto T_float16 = Type::getHalfTy (ctx);
52210 auto T_float32 = Type::getFloatTy (ctx);
53211
54212 SmallVector<Instruction *, 0 > erase;
55213 for (auto &BB : F) {
56214 for (auto &I : BB) {
215+ // extend Float16 operands to Float32
216+ bool Float16 = I.getType ()->getScalarType ()->isHalfTy ();
217+ for (size_t i = 0 ; !Float16 && i < I.getNumOperands (); i++) {
218+ Value *Op = I.getOperand (i);
219+ if (Op->getType ()->getScalarType ()->isHalfTy ())
220+ Float16 = true ;
221+ }
222+ if (!Float16)
223+ continue ;
224+
225+ if (auto CI = dyn_cast<CastInst>(&I)) {
226+ if (CI->getOpcode () != Instruction::BitCast) { // aka !CI->isNoopCast(DL)
227+ ++TotalChanged;
228+ IRBuilder<> builder (&I);
229+ Value *NewI = CreateFPCast (CI->getOpcode (), I.getOperand (0 ), I.getType (), builder);
230+ I.replaceAllUsesWith (NewI);
231+ erase.push_back (&I);
232+ }
233+ continue ;
234+ }
235+
57236 switch (I.getOpcode ()) {
58237 case Instruction::FNeg:
59238 case Instruction::FAdd:
@@ -64,6 +243,9 @@ static bool demoteFloat16(Function &F)
64243 case Instruction::FCmp:
65244 break ;
66245 default :
246+ if (auto intrinsic = dyn_cast<IntrinsicInst>(&I))
247+ if (intrinsic->getIntrinsicID ())
248+ break ;
67249 continue ;
68250 }
69251
@@ -75,72 +257,78 @@ static bool demoteFloat16(Function &F)
75257 IRBuilder<> builder (&I);
76258
77259 // extend Float16 operands to Float32
78- bool OperandsChanged = false ;
260+ // XXX: Calls to llvm.fma.f16 may need to go to f64 to be correct?
79261 SmallVector<Value *, 2 > Operands (I.getNumOperands ());
80262 for (size_t i = 0 ; i < I.getNumOperands (); i++) {
81263 Value *Op = I.getOperand (i);
82- if (Op->getType () == T_float16 ) {
264+ if (Op->getType ()-> getScalarType ()-> isHalfTy () ) {
83265 ++TotalExt;
84- Op = builder.CreateFPExt (Op, T_float32);
85- OperandsChanged = true ;
266+ Op = CreateFPCast (Instruction::FPExt, Op, Op->getType ()->getWithNewType (T_float32), builder);
86267 }
87268 Operands[i] = (Op);
88269 }
89270
90271 // recreate the instruction if any operands changed,
91272 // truncating the result back to Float16
92- if (OperandsChanged) {
93- Value *NewI;
94- ++TotalChanged;
95- switch (I.getOpcode ()) {
96- case Instruction::FNeg:
97- assert (Operands.size () == 1 );
98- ++FNegChanged;
99- NewI = builder.CreateFNeg (Operands[0 ]);
100- break ;
101- case Instruction::FAdd:
102- assert (Operands.size () == 2 );
103- ++FAddChanged;
104- NewI = builder.CreateFAdd (Operands[0 ], Operands[1 ]);
105- break ;
106- case Instruction::FSub:
107- assert (Operands.size () == 2 );
108- ++FSubChanged;
109- NewI = builder.CreateFSub (Operands[0 ], Operands[1 ]);
110- break ;
111- case Instruction::FMul:
112- assert (Operands.size () == 2 );
113- ++FMulChanged;
114- NewI = builder.CreateFMul (Operands[0 ], Operands[1 ]);
115- break ;
116- case Instruction::FDiv:
117- assert (Operands.size () == 2 );
118- ++FDivChanged;
119- NewI = builder.CreateFDiv (Operands[0 ], Operands[1 ]);
120- break ;
121- case Instruction::FRem:
122- assert (Operands.size () == 2 );
123- ++FRemChanged;
124- NewI = builder.CreateFRem (Operands[0 ], Operands[1 ]);
125- break ;
126- case Instruction::FCmp:
127- assert (Operands.size () == 2 );
128- ++FCmpChanged;
129- NewI = builder.CreateFCmp (cast<FCmpInst>(&I)->getPredicate (),
130- Operands[0 ], Operands[1 ]);
273+ Value *NewI;
274+ ++TotalChanged;
275+ switch (I.getOpcode ()) {
276+ case Instruction::FNeg:
277+ assert (Operands.size () == 1 );
278+ ++FNegChanged;
279+ NewI = builder.CreateFNeg (Operands[0 ]);
280+ break ;
281+ case Instruction::FAdd:
282+ assert (Operands.size () == 2 );
283+ ++FAddChanged;
284+ NewI = builder.CreateFAdd (Operands[0 ], Operands[1 ]);
285+ break ;
286+ case Instruction::FSub:
287+ assert (Operands.size () == 2 );
288+ ++FSubChanged;
289+ NewI = builder.CreateFSub (Operands[0 ], Operands[1 ]);
290+ break ;
291+ case Instruction::FMul:
292+ assert (Operands.size () == 2 );
293+ ++FMulChanged;
294+ NewI = builder.CreateFMul (Operands[0 ], Operands[1 ]);
295+ break ;
296+ case Instruction::FDiv:
297+ assert (Operands.size () == 2 );
298+ ++FDivChanged;
299+ NewI = builder.CreateFDiv (Operands[0 ], Operands[1 ]);
300+ break ;
301+ case Instruction::FRem:
302+ assert (Operands.size () == 2 );
303+ ++FRemChanged;
304+ NewI = builder.CreateFRem (Operands[0 ], Operands[1 ]);
305+ break ;
306+ case Instruction::FCmp:
307+ assert (Operands.size () == 2 );
308+ ++FCmpChanged;
309+ NewI = builder.CreateFCmp (cast<FCmpInst>(&I)->getPredicate (),
310+ Operands[0 ], Operands[1 ]);
311+ break ;
312+ default :
313+ if (auto intrinsic = dyn_cast<IntrinsicInst>(&I)) {
314+ // XXX: this is not correct in general
315+ // some obvious failures include llvm.convert.to.fp16.*, llvm.vp.*to*, llvm.experimental.constrained.*to*, llvm.masked.*
316+ Type *RetTy = I.getType ();
317+ if (RetTy->getScalarType ()->isHalfTy ())
318+ RetTy = RetTy->getWithNewType (T_float32);
319+ NewI = replaceIntrinsicWith (intrinsic, RetTy, Operands);
131320 break ;
132- default :
133- abort ();
134- }
135- cast<Instruction>(NewI)->copyMetadata (I);
136- cast<Instruction>(NewI)->copyFastMathFlags (&I);
137- if (NewI->getType () != I.getType ()) {
138- ++TotalTrunc;
139- NewI = builder.CreateFPTrunc (NewI, I.getType ());
140321 }
141- I.replaceAllUsesWith (NewI);
142- erase.push_back (&I);
322+ abort ();
323+ }
324+ cast<Instruction>(NewI)->copyMetadata (I);
325+ cast<Instruction>(NewI)->copyFastMathFlags (&I);
326+ if (NewI->getType () != I.getType ()) {
327+ ++TotalTrunc;
328+ NewI = CreateFPCast (Instruction::FPTrunc, NewI, I.getType (), builder);
143329 }
330+ I.replaceAllUsesWith (NewI);
331+ erase.push_back (&I);
144332 }
145333 }
146334
0 commit comments