Skip to content

Commit f2c627e

Browse files
authored
codegen: explicitly handle Float16 intrinsics (#45249)
Fixes #44829, until LLVM fixes the support for these intrinsics itself. Also need to handle vectors, since the vectorizer may have introduced them. Also change our runtime emulation versions to f32 for consistency.
1 parent 2d40898 commit f2c627e

File tree

5 files changed

+294
-92
lines changed

5 files changed

+294
-92
lines changed

src/APInt-C.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
316316
void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
317317
double Val;
318318
if (numbits == 16)
319-
Val = __gnu_h2f_ieee(*(uint16_t*)pa);
319+
Val = julia__gnu_h2f_ieee(*(uint16_t*)pa);
320320
else if (numbits == 32)
321321
Val = *(float*)pa;
322322
else if (numbits == 64)
@@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
391391
val = a.roundToDouble(true);
392392
}
393393
if (onumbits == 16)
394-
*(uint16_t*)pr = __gnu_f2h_ieee(val);
394+
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
395395
else if (onumbits == 32)
396396
*(float*)pr = val;
397397
else if (onumbits == 64)
@@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
408408
val = a.roundToDouble(false);
409409
}
410410
if (onumbits == 16)
411-
*(uint16_t*)pr = __gnu_f2h_ieee(val);
411+
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
412412
else if (onumbits == 32)
413413
*(float*)pr = val;
414414
else if (onumbits == 64)

src/julia.expmap

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,6 @@
3737
environ;
3838
__progname;
3939

40-
/* compiler run-time intrinsics */
41-
__gnu_h2f_ieee;
42-
__extendhfsf2;
43-
__gnu_f2h_ieee;
44-
__truncdfhf2;
45-
4640
local:
4741
*;
4842
};

src/julia_internal.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1523,8 +1523,18 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
15231523
#define JL_GC_ASSERT_LIVE(x) (void)(x)
15241524
#endif
15251525

1526-
float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
1527-
uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
1526+
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
1527+
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
1528+
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
1529+
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
1530+
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
1531+
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
1532+
//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT;
1533+
//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT;
1534+
//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT;
1535+
//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT;
1536+
//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT;
1537+
//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT;
15281538

15291539
#ifdef __cplusplus
15301540
}

src/llvm-demote-float16.cpp

Lines changed: 242 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,194 @@ INST_STATISTIC(FCmp);
4545

4646
namespace {
4747

48+
// Version shim: fetch the function-level attribute set out of `Attrs`.
// When JL_LLVM_VERSION is below 140000 the accessor is getFnAttributes();
// from 140000 onwards it is getFnAttrs().
inline AttributeSet getFnAttrs(const AttributeList &Attrs)
{
#if JL_LLVM_VERSION < 140000
    return Attrs.getFnAttributes();
#else
    return Attrs.getFnAttrs();
#endif
}
56+
57+
// Version shim: fetch the return-value attribute set out of `Attrs`.
// When JL_LLVM_VERSION is below 140000 the accessor is getRetAttributes();
// from 140000 onwards it is getRetAttrs().
inline AttributeSet getRetAttrs(const AttributeList &Attrs)
{
#if JL_LLVM_VERSION < 140000
    return Attrs.getRetAttributes();
#else
    return Attrs.getRetAttrs();
#endif
}
65+
66+
// Create a call to the same intrinsic as `call`, re-resolved for a new
// signature.
//
//   call  - the intrinsic call being replaced (it is not erased or rewired here)
//   RetTy - result type of the replacement call
//   args  - replacement operands in the original operand order; the final
//           entry must be the called function itself (asserted below)
//
// The parameter types of the new signature are taken from the leading entries
// of `args`. The new call is created with `call` as its insertion point and
// inherits its tail-call kind plus its function and return attributes.
static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetTy, ArrayRef<Value*> args)
{
    const Intrinsic::ID intrinsicID = call->getIntrinsicID();
    assert(intrinsicID);
    FunctionType *originalFT = call->getFunctionType();
    const unsigned numParams = originalFT->getNumParams();
    assert(args.size() > numParams);

    // Derive the replacement signature from the new operand types.
    SmallVector<Type*, 8> paramTys;
    paramTys.reserve(numParams);
    for (unsigned idx = 0; idx != numParams; ++idx)
        paramTys.push_back(args[idx]->getType());
    FunctionType *replacementFT = FunctionType::get(RetTy, paramTys, originalFT->isVarArg());

    // Accumulate an array of overloaded types for the given intrinsic
    // and compute the new name mangling schema
    SmallVector<Type*, 4> overloadTys;
    {
        SmallVector<Intrinsic::IITDescriptor, 8> descriptors;
        getIntrinsicInfoTableEntries(intrinsicID, descriptors);
        ArrayRef<Intrinsic::IITDescriptor> descriptorRef = descriptors;
        auto signatureMatch = Intrinsic::matchIntrinsicSignature(replacementFT, descriptorRef, overloadTys);
        assert(signatureMatch == Intrinsic::MatchIntrinsicTypes_Match);
        (void)signatureMatch;
        bool varargMatches = !Intrinsic::matchIntrinsicVarArg(replacementFT->isVarArg(), descriptorRef);
        assert(varargMatches);
        (void)varargMatches;
    }

    Function *replacementFn = Intrinsic::getDeclaration(call->getModule(), intrinsicID, overloadTys);
    assert(replacementFn->getFunctionType() == replacementFT);
    replacementFn->setCallingConv(call->getCallingConv());

    // The trailing operand is the callee, so it is not passed as an argument.
    assert(args.back() == call->getCalledFunction());
    CallInst *replacement = CallInst::Create(replacementFn, args.drop_back(), "", call);
    replacement->setTailCallKind(call->getTailCallKind());

    // Carry over function and return attributes only; parameter attributes
    // are dropped.
    const AttributeList originalAttrs = call->getAttributes();
    replacement->setAttributes(AttributeList::get(call->getContext(),
                                                  getFnAttrs(originalAttrs),
                                                  getRetAttrs(originalAttrs),
                                                  {}));
    return replacement;
}
103+
104+
105+
// Emit IR that performs the cast `opcode` from V to DestTy where one side is
// Float16, by calling a Julia runtime conversion routine instead of emitting
// the LLVM cast instruction directly.
//
//   opcode  - FPExt, FPTrunc, FPToSI, FPToUI, SIToFP or UIToFP
//   V       - value to convert (scalar or fixed-width vector)
//   DestTy  - requested result type (vector iff V is a vector)
//   builder - IRBuilder positioned where the new instructions should go
//
// Constants are folded when possible. Vectors are converted one element at a
// time. Scalars go through a call to julia__gnu_h2f_ieee, julia__gnu_f2h_ieee
// or julia__truncdfhf2, with half values passed and returned as i16.
static Value* CreateFPCast(Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder)
{
    Type *SrcTy = V->getType();
    Type *RetTy = DestTy;

    // The input IR often has things of the form
    //    fcmp olt half %0, 0xH7C00
    // and we would like to avoid turning that constant into a call here
    // if we can simply constant fold it to the new type.
    if (auto *constantV = dyn_cast<Constant>(V)) {
        if (Constant *folded = ConstantExpr::getCast(opcode, constantV, DestTy, true))
            return folded;
    }

    assert(SrcTy->isVectorTy() == DestTy->isVectorTy());
    if (SrcTy->isVectorTy()) {
        // Scalarize: convert each lane with a recursive call and rebuild the vector.
        const unsigned laneCount = cast<FixedVectorType>(SrcTy)->getNumElements();
        assert(cast<FixedVectorType>(DestTy)->getNumElements() == laneCount && "Mismatched cast");
        Value *result = UndefValue::get(DestTy);
        Type *laneDestTy = RetTy->getScalarType();
        for (unsigned lane = 0; lane != laneCount; ++lane) {
            Value *laneIdx = builder.getInt32(lane);
            Value *elem = builder.CreateExtractElement(V, laneIdx);
            elem = CreateFPCast(opcode, elem, laneDestTy, builder);
            result = builder.CreateInsertElement(result, elem, laneIdx);
        }
        return result;
    }

    auto &M = *builder.GetInsertBlock()->getModule();
    auto &ctx = M.getContext();

    // Shared prefix of the diagnostics printed before giving up on a cast.
    const auto describeCast = [&]() {
        errs() << Instruction::getOpcodeName(opcode) << ' ';
        V->getType()->print(errs());
        errs() << " to ";
        DestTy->print(errs());
        errs() << " is an ";
    };

    // Pick the Function to call in the Julia runtime
    StringRef Name;
    switch (opcode) {
    // FPExt: this is exact, so we only need one conversion
    // FPToSI / FPToUI: All F16 fit exactly in Int32 (-65504 to 65504)
    case Instruction::FPExt:
    case Instruction::FPToSI:
    case Instruction::FPToUI:
        assert(SrcTy->isHalfTy());
        Name = "julia__gnu_h2f_ieee";
        RetTy = Type::getFloatTy(ctx);
        break;
    case Instruction::FPTrunc:
        assert(DestTy->isHalfTy());
        if (SrcTy->isFloatTy())
            Name = "julia__gnu_f2h_ieee";
        else if (SrcTy->isDoubleTy())
            Name = "julia__truncdfhf2";
        break;
    case Instruction::SIToFP:
    case Instruction::UIToFP:
        assert(DestTy->isHalfTy());
        Name = "julia__gnu_f2h_ieee";
        SrcTy = Type::getFloatTy(ctx);
        break;
    default:
        describeCast();
        llvm_unreachable("invalid cast");
    }
    if (Name.empty()) {
        describeCast();
        llvm_unreachable("illegal cast");
    }

    // Coerce the source to the required size and type
    Type *i16Ty = Type::getInt16Ty(ctx);
    if (SrcTy->isHalfTy())
        SrcTy = i16Ty;
    switch (opcode) {
    case Instruction::SIToFP:
        V = builder.CreateSIToFP(V, SrcTy);
        break;
    case Instruction::UIToFP:
        V = builder.CreateUIToFP(V, SrcTy);
        break;
    default:
        V = builder.CreateBitCast(V, SrcTy);
        break;
    }

    // Call our intrinsic
    if (RetTy->isHalfTy())
        RetTy = i16Ty;
    FunctionType *calleeTy = FunctionType::get(RetTy, {SrcTy}, false);
    FunctionCallee callee = M.getOrInsertFunction(Name, calleeTy);
    Value *converted = builder.CreateCall(callee, {V});

    // Coerce the result to the expected type
    switch (opcode) {
    case Instruction::FPToSI:
        return builder.CreateFPToSI(converted, DestTy);
    case Instruction::FPToUI:
        return builder.CreateFPToUI(converted, DestTy);
    case Instruction::FPExt:
        return builder.CreateFPCast(converted, DestTy);
    default:
        return builder.CreateBitCast(converted, DestTy);
    }
}
206+
48207
static bool demoteFloat16(Function &F)
49208
{
50209
auto &ctx = F.getContext();
51-
auto T_float16 = Type::getHalfTy(ctx);
52210
auto T_float32 = Type::getFloatTy(ctx);
53211

54212
SmallVector<Instruction *, 0> erase;
55213
for (auto &BB : F) {
56214
for (auto &I : BB) {
215+
// extend Float16 operands to Float32
216+
bool Float16 = I.getType()->getScalarType()->isHalfTy();
217+
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
218+
Value *Op = I.getOperand(i);
219+
if (Op->getType()->getScalarType()->isHalfTy())
220+
Float16 = true;
221+
}
222+
if (!Float16)
223+
continue;
224+
225+
if (auto CI = dyn_cast<CastInst>(&I)) {
226+
if (CI->getOpcode() != Instruction::BitCast) { // aka !CI->isNoopCast(DL)
227+
++TotalChanged;
228+
IRBuilder<> builder(&I);
229+
Value *NewI = CreateFPCast(CI->getOpcode(), I.getOperand(0), I.getType(), builder);
230+
I.replaceAllUsesWith(NewI);
231+
erase.push_back(&I);
232+
}
233+
continue;
234+
}
235+
57236
switch (I.getOpcode()) {
58237
case Instruction::FNeg:
59238
case Instruction::FAdd:
@@ -64,6 +243,9 @@ static bool demoteFloat16(Function &F)
64243
case Instruction::FCmp:
65244
break;
66245
default:
246+
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I))
247+
if (intrinsic->getIntrinsicID())
248+
break;
67249
continue;
68250
}
69251

@@ -75,72 +257,78 @@ static bool demoteFloat16(Function &F)
75257
IRBuilder<> builder(&I);
76258

77259
// extend Float16 operands to Float32
78-
bool OperandsChanged = false;
260+
// XXX: Calls to llvm.fma.f16 may need to go to f64 to be correct?
79261
SmallVector<Value *, 2> Operands(I.getNumOperands());
80262
for (size_t i = 0; i < I.getNumOperands(); i++) {
81263
Value *Op = I.getOperand(i);
82-
if (Op->getType() == T_float16) {
264+
if (Op->getType()->getScalarType()->isHalfTy()) {
83265
++TotalExt;
84-
Op = builder.CreateFPExt(Op, T_float32);
85-
OperandsChanged = true;
266+
Op = CreateFPCast(Instruction::FPExt, Op, Op->getType()->getWithNewType(T_float32), builder);
86267
}
87268
Operands[i] = (Op);
88269
}
89270

90271
// recreate the instruction if any operands changed,
91272
// truncating the result back to Float16
92-
if (OperandsChanged) {
93-
Value *NewI;
94-
++TotalChanged;
95-
switch (I.getOpcode()) {
96-
case Instruction::FNeg:
97-
assert(Operands.size() == 1);
98-
++FNegChanged;
99-
NewI = builder.CreateFNeg(Operands[0]);
100-
break;
101-
case Instruction::FAdd:
102-
assert(Operands.size() == 2);
103-
++FAddChanged;
104-
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
105-
break;
106-
case Instruction::FSub:
107-
assert(Operands.size() == 2);
108-
++FSubChanged;
109-
NewI = builder.CreateFSub(Operands[0], Operands[1]);
110-
break;
111-
case Instruction::FMul:
112-
assert(Operands.size() == 2);
113-
++FMulChanged;
114-
NewI = builder.CreateFMul(Operands[0], Operands[1]);
115-
break;
116-
case Instruction::FDiv:
117-
assert(Operands.size() == 2);
118-
++FDivChanged;
119-
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
120-
break;
121-
case Instruction::FRem:
122-
assert(Operands.size() == 2);
123-
++FRemChanged;
124-
NewI = builder.CreateFRem(Operands[0], Operands[1]);
125-
break;
126-
case Instruction::FCmp:
127-
assert(Operands.size() == 2);
128-
++FCmpChanged;
129-
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
130-
Operands[0], Operands[1]);
273+
Value *NewI;
274+
++TotalChanged;
275+
switch (I.getOpcode()) {
276+
case Instruction::FNeg:
277+
assert(Operands.size() == 1);
278+
++FNegChanged;
279+
NewI = builder.CreateFNeg(Operands[0]);
280+
break;
281+
case Instruction::FAdd:
282+
assert(Operands.size() == 2);
283+
++FAddChanged;
284+
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
285+
break;
286+
case Instruction::FSub:
287+
assert(Operands.size() == 2);
288+
++FSubChanged;
289+
NewI = builder.CreateFSub(Operands[0], Operands[1]);
290+
break;
291+
case Instruction::FMul:
292+
assert(Operands.size() == 2);
293+
++FMulChanged;
294+
NewI = builder.CreateFMul(Operands[0], Operands[1]);
295+
break;
296+
case Instruction::FDiv:
297+
assert(Operands.size() == 2);
298+
++FDivChanged;
299+
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
300+
break;
301+
case Instruction::FRem:
302+
assert(Operands.size() == 2);
303+
++FRemChanged;
304+
NewI = builder.CreateFRem(Operands[0], Operands[1]);
305+
break;
306+
case Instruction::FCmp:
307+
assert(Operands.size() == 2);
308+
++FCmpChanged;
309+
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
310+
Operands[0], Operands[1]);
311+
break;
312+
default:
313+
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I)) {
314+
// XXX: this is not correct in general
315+
// some obvious failures include llvm.convert.to.fp16.*, llvm.vp.*to*, llvm.experimental.constrained.*to*, llvm.masked.*
316+
Type *RetTy = I.getType();
317+
if (RetTy->getScalarType()->isHalfTy())
318+
RetTy = RetTy->getWithNewType(T_float32);
319+
NewI = replaceIntrinsicWith(intrinsic, RetTy, Operands);
131320
break;
132-
default:
133-
abort();
134-
}
135-
cast<Instruction>(NewI)->copyMetadata(I);
136-
cast<Instruction>(NewI)->copyFastMathFlags(&I);
137-
if (NewI->getType() != I.getType()) {
138-
++TotalTrunc;
139-
NewI = builder.CreateFPTrunc(NewI, I.getType());
140321
}
141-
I.replaceAllUsesWith(NewI);
142-
erase.push_back(&I);
322+
abort();
323+
}
324+
cast<Instruction>(NewI)->copyMetadata(I);
325+
cast<Instruction>(NewI)->copyFastMathFlags(&I);
326+
if (NewI->getType() != I.getType()) {
327+
++TotalTrunc;
328+
NewI = CreateFPCast(Instruction::FPTrunc, NewI, I.getType(), builder);
143329
}
330+
I.replaceAllUsesWith(NewI);
331+
erase.push_back(&I);
144332
}
145333
}
146334

0 commit comments

Comments
 (0)