Skip to content

Commit f4ad7c4

Browse files

12 files changed

+468
-1
lines changed

llvm-spirv/lib/SPIRV/OCLToSPIRV.cpp

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,11 @@ class OCLToSPIRVBase : public InstVisitor<OCLToSPIRVBase> {
263263
void visitCallLdexp(CallInst *CI, StringRef MangledName,
264264
StringRef DemangledName);
265265

266+
/// For cl_intel_convert_bfloat16_as_ushort
267+
void visitCallConvertBFloat16AsUshort(CallInst *CI, StringRef DemangledName);
268+
/// For cl_intel_convert_as_bfloat16_float
269+
void visitCallConvertAsBFloat16Float(CallInst *CI, StringRef DemangledName);
270+
266271
void setOCLTypeToSPIRV(OCLTypeToSPIRVBase *OCLTypeToSPIRV) {
267272
OCLTypeToSPIRVPtr = OCLTypeToSPIRV;
268273
}
@@ -574,6 +579,24 @@ void OCLToSPIRVBase::visitCallInst(CallInst &CI) {
574579
visitCallLdexp(&CI, MangledName, DemangledName);
575580
return;
576581
}
582+
if (DemangledName == kOCLBuiltinName::ConvertBFloat16AsUShort ||
583+
DemangledName == kOCLBuiltinName::ConvertBFloat162AsUShort2 ||
584+
DemangledName == kOCLBuiltinName::ConvertBFloat163AsUShort3 ||
585+
DemangledName == kOCLBuiltinName::ConvertBFloat164AsUShort4 ||
586+
DemangledName == kOCLBuiltinName::ConvertBFloat168AsUShort8 ||
587+
DemangledName == kOCLBuiltinName::ConvertBFloat1616AsUShort16) {
588+
visitCallConvertBFloat16AsUshort(&CI, DemangledName);
589+
return;
590+
}
591+
if (DemangledName == kOCLBuiltinName::ConvertAsBFloat16Float ||
592+
DemangledName == kOCLBuiltinName::ConvertAsBFloat162Float2 ||
593+
DemangledName == kOCLBuiltinName::ConvertAsBFloat163Float3 ||
594+
DemangledName == kOCLBuiltinName::ConvertAsBFloat164Float4 ||
595+
DemangledName == kOCLBuiltinName::ConvertAsBFloat168Float8 ||
596+
DemangledName == kOCLBuiltinName::ConvertAsBFloat1616Float16) {
597+
visitCallConvertAsBFloat16Float(&CI, DemangledName);
598+
return;
599+
}
577600
visitCallBuiltinSimple(&CI, MangledName, DemangledName);
578601
}
579602

@@ -1916,6 +1939,103 @@ void OCLToSPIRVBase::visitCallLdexp(CallInst *CI, StringRef MangledName,
19161939
visitCallBuiltinSimple(CI, MangledName, DemangledName);
19171940
}
19181941

1942+
void OCLToSPIRVBase::visitCallConvertBFloat16AsUshort(CallInst *CI,
1943+
StringRef DemangledName) {
1944+
Type *RetTy = CI->getType();
1945+
Type *ArgTy = CI->getOperand(0)->getType();
1946+
if (DemangledName == kOCLBuiltinName::ConvertBFloat16AsUShort) {
1947+
if (!RetTy->isIntegerTy(16U) || !ArgTy->isFloatTy())
1948+
report_fatal_error(
1949+
"OpConvertBFloat16AsUShort must be of i16 and take float");
1950+
} else {
1951+
FixedVectorType *RetTyVec = cast<FixedVectorType>(RetTy);
1952+
FixedVectorType *ArgTyVec = cast<FixedVectorType>(ArgTy);
1953+
if (!RetTyVec || !RetTyVec->getElementType()->isIntegerTy(16U) ||
1954+
!ArgTyVec || !ArgTyVec->getElementType()->isFloatTy())
1955+
report_fatal_error("OpConvertBFloat16NAsUShortN must be of <N x i16> and "
1956+
"take <N x float>");
1957+
unsigned RetTyVecSize = RetTyVec->getNumElements();
1958+
unsigned ArgTyVecSize = ArgTyVec->getNumElements();
1959+
if (DemangledName == kOCLBuiltinName::ConvertBFloat162AsUShort2) {
1960+
if (RetTyVecSize != 2 || ArgTyVecSize != 2)
1961+
report_fatal_error("ConvertBFloat162AsUShort2 must be of <2 x i16> and "
1962+
"take <2 x float>");
1963+
} else if (DemangledName == kOCLBuiltinName::ConvertBFloat163AsUShort3) {
1964+
if (RetTyVecSize != 3 || ArgTyVecSize != 3)
1965+
report_fatal_error("ConvertBFloat163AsUShort3 must be of <3 x i16> and "
1966+
"take <3 x float>");
1967+
} else if (DemangledName == kOCLBuiltinName::ConvertBFloat164AsUShort4) {
1968+
if (RetTyVecSize != 4 || ArgTyVecSize != 4)
1969+
report_fatal_error("ConvertBFloat164AsUShort4 must be of <4 x i16> and "
1970+
"take <4 x float>");
1971+
} else if (DemangledName == kOCLBuiltinName::ConvertBFloat168AsUShort8) {
1972+
if (RetTyVecSize != 8 || ArgTyVecSize != 8)
1973+
report_fatal_error("ConvertBFloat168AsUShort8 must be of <8 x i16> and "
1974+
"take <8 x float>");
1975+
} else if (DemangledName == kOCLBuiltinName::ConvertBFloat1616AsUShort16) {
1976+
if (RetTyVecSize != 16 || ArgTyVecSize != 16)
1977+
report_fatal_error("ConvertBFloat1616AsUShort16 must be of <16 x i16> "
1978+
"and take <16 x float>");
1979+
}
1980+
}
1981+
1982+
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
1983+
mutateCallInstSPIRV(
1984+
M, CI,
1985+
[=](CallInst *, std::vector<Value *> &Args) {
1986+
return getSPIRVFuncName(internal::OpConvertFToBF16INTEL);
1987+
},
1988+
&Attrs);
1989+
}
1990+
1991+
void OCLToSPIRVBase::visitCallConvertAsBFloat16Float(CallInst *CI,
1992+
StringRef DemangledName) {
1993+
Type *RetTy = CI->getType();
1994+
Type *ArgTy = CI->getOperand(0)->getType();
1995+
if (DemangledName == kOCLBuiltinName::ConvertAsBFloat16Float) {
1996+
if (!RetTy->isFloatTy() || !ArgTy->isIntegerTy(16U))
1997+
report_fatal_error(
1998+
"OpConvertAsBFloat16Float must be of float and take i16");
1999+
} else {
2000+
FixedVectorType *RetTyVec = cast<FixedVectorType>(RetTy);
2001+
FixedVectorType *ArgTyVec = cast<FixedVectorType>(ArgTy);
2002+
if (!RetTyVec || !RetTyVec->getElementType()->isFloatTy() || !ArgTyVec ||
2003+
!ArgTyVec->getElementType()->isIntegerTy(16U))
2004+
report_fatal_error("OpConvertAsBFloat16NFloatN must be of <N x float> "
2005+
"and take <N x i16>");
2006+
unsigned RetTyVecSize = RetTyVec->getNumElements();
2007+
unsigned ArgTyVecSize = ArgTyVec->getNumElements();
2008+
if (DemangledName == kOCLBuiltinName::ConvertAsBFloat162Float2) {
2009+
if (RetTyVecSize != 2 || ArgTyVecSize != 2)
2010+
report_fatal_error("ConvertAsBFloat162Float2 must be of <2 x float> "
2011+
"and take <2 x i16>");
2012+
} else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat163Float3) {
2013+
if (RetTyVecSize != 3 || ArgTyVecSize != 3)
2014+
report_fatal_error("ConvertAsBFloat163Float3 must be of <3 x float> "
2015+
"and take <3 x i16>");
2016+
} else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat164Float4) {
2017+
if (RetTyVecSize != 4 || ArgTyVecSize != 4)
2018+
report_fatal_error("ConvertAsBFloat164Float4 must be of <4 x float> "
2019+
"and take <4 x i16>");
2020+
} else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat168Float8) {
2021+
if (RetTyVecSize != 8 || ArgTyVecSize != 8)
2022+
report_fatal_error("ConvertAsBFloat168Float8 must be of <8 x float> "
2023+
"and take <8 x i16>");
2024+
} else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat1616Float16) {
2025+
if (RetTyVecSize != 16 || ArgTyVecSize != 16)
2026+
report_fatal_error("ConvertAsBFloat1616Float16 must be of <16 x float> "
2027+
"and take <16 x i16>");
2028+
}
2029+
}
2030+
2031+
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
2032+
mutateCallInstSPIRV(
2033+
M, CI,
2034+
[=](CallInst *, std::vector<Value *> &Args) {
2035+
return getSPIRVFuncName(internal::OpConvertBF16ToFINTEL);
2036+
},
2037+
&Attrs);
2038+
}
19192039
} // namespace SPIRV
19202040

19212041
INITIALIZE_PASS_BEGIN(OCLToSPIRVLegacy, "ocl-to-spv",

llvm-spirv/lib/SPIRV/OCLUtil.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,26 @@ const static char SubgroupBlockWriteINTELPrefix[] =
305305
const static char SubgroupImageMediaBlockINTELPrefix[] =
306306
"intel_sub_group_media_block";
307307
const static char LDEXP[] = "ldexp";
308+
#define _SPIRV_OP(x) \
309+
const static char ConvertBFloat16##x##AsUShort##x[] = \
310+
"intel_convert_bfloat16" #x "_as_ushort" #x;
311+
_SPIRV_OP()
312+
_SPIRV_OP(2)
313+
_SPIRV_OP(3)
314+
_SPIRV_OP(4)
315+
_SPIRV_OP(8)
316+
_SPIRV_OP(16)
317+
#undef _SPIRV_OP
318+
#define _SPIRV_OP(x) \
319+
const static char ConvertAsBFloat16##x##Float##x[] = \
320+
"intel_convert_as_bfloat16" #x "_float" #x;
321+
_SPIRV_OP()
322+
_SPIRV_OP(2)
323+
_SPIRV_OP(3)
324+
_SPIRV_OP(4)
325+
_SPIRV_OP(8)
326+
_SPIRV_OP(16)
327+
#undef _SPIRV_OP
308328
} // namespace kOCLBuiltinName
309329

310330
/// Offset for OpenCL image channel order enumeration values.

llvm-spirv/lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,11 @@ void SPIRVToOCLBase::visitCallInst(CallInst &CI) {
205205
visitCallSPIRVRelational(&CI, OC);
206206
return;
207207
}
208+
if (OC == internal::OpConvertFToBF16INTEL ||
209+
OC == internal::OpConvertBF16ToFINTEL) {
210+
visitCallSPIRVBFloat16Conversions(&CI, OC);
211+
return;
212+
}
208213
if (OCLSPIRVBuiltinMap::rfind(OC))
209214
visitCallSPIRVBuiltin(&CI, OC);
210215
}
@@ -986,6 +991,32 @@ void SPIRVToOCLBase::visitCallSPIRVGenericPtrMemSemantics(CallInst *CI) {
986991
&Attrs);
987992
}
988993

994+
void SPIRVToOCLBase::visitCallSPIRVBFloat16Conversions(CallInst *CI, Op OC) {
995+
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
996+
mutateCallInstOCL(
997+
M, CI,
998+
[=](CallInst *, std::vector<Value *> &Args) {
999+
Type *ArgTy = CI->getOperand(0)->getType();
1000+
std::string N =
1001+
ArgTy->isVectorTy()
1002+
? std::to_string(cast<FixedVectorType>(ArgTy)->getNumElements())
1003+
: "";
1004+
std::string Name;
1005+
switch (static_cast<uint32_t>(OC)) {
1006+
case internal::OpConvertFToBF16INTEL:
1007+
Name = "intel_convert_bfloat16" + N + "_as_ushort" + N;
1008+
break;
1009+
case internal::OpConvertBF16ToFINTEL:
1010+
Name = "intel_convert_as_bfloat16" + N + "_float" + N;
1011+
break;
1012+
default:
1013+
break; // do nothing
1014+
}
1015+
return Name;
1016+
},
1017+
&Attrs);
1018+
}
1019+
9891020
void SPIRVToOCLBase::visitCallSPIRVBuiltin(CallInst *CI, Op OC) {
9901021
assert(CI->getCalledFunction() && "Unexpected indirect call");
9911022
AttributeList Attrs = CI->getCalledFunction()->getAttributes();

llvm-spirv/lib/SPIRV/SPIRVToOCL.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,13 @@ class SPIRVToOCLBase : public InstVisitor<SPIRVToOCLBase> {
161161
/// %1 = shl i31 %0, 8
162162
void visitCallSPIRVGenericPtrMemSemantics(CallInst *CI);
163163

164+
/// Transform __spirv_ConvertFToBF16INTELDv(N)_f to:
165+
/// intel_convert_bfloat16(N)_as_ushort(N)Dv(N)_f;
166+
/// and transform __spirv_ConvertBF16ToFINTELDv(N)_s to:
167+
/// intel_convert_as_bfloat16(N)_float(N)Dv(N)_t;
168+
/// where N is vector size
169+
void visitCallSPIRVBFloat16Conversions(CallInst *CI, Op OC);
170+
164171
/// Transform __spirv_* builtins to OCL 2.0 builtins.
165172
/// No change with arguments.
166173
void visitCallSPIRVBuiltin(CallInst *CI, Op OC);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
; CHECK-ERROR: OpConvertAsBFloat16Float must be of float and take i16
5+
6+
; ModuleID = 'kernel.cl'
7+
source_filename = "kernel.cl"
8+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir"
10+
11+
; Function Attrs: convergent noinline norecurse nounwind optnone
12+
define dso_local spir_kernel void @f() {
13+
entry:
14+
%call = call spir_func double @_Z31intel_convert_as_bfloat16_floatt(i32 zeroext 0)
15+
ret void
16+
}
17+
18+
; Function Attrs: convergent
19+
declare spir_func double @_Z31intel_convert_as_bfloat16_floatt(i32 zeroext)
20+
21+
!opencl.ocl.version = !{!0}
22+
23+
!0 = !{i32 2, i32 0}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
; CHECK-ERROR: OpConvertAsBFloat16NFloatN must be of <N x float> and take <N x i16>
5+
6+
; ModuleID = 'kernel.cl'
7+
source_filename = "kernel.cl"
8+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir"
10+
11+
; Function Attrs: convergent noinline norecurse nounwind optnone
12+
define dso_local spir_kernel void @f() {
13+
entry:
14+
%call = call spir_func <2 x double> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<2 x i32> zeroinitializer)
15+
ret void
16+
}
17+
18+
; ; Function Attrs: convergent
19+
declare spir_func <2 x double> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<2 x i32>)
20+
21+
!opencl.ocl.version = !{!0}
22+
23+
!0 = !{i32 2, i32 0}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
; CHECK-ERROR: ConvertAsBFloat162Float2 must be of <2 x float> and take <2 x i16>
5+
6+
; ModuleID = 'kernel.cl'
7+
source_filename = "kernel.cl"
8+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir"
10+
11+
; Function Attrs: convergent noinline norecurse nounwind optnone
12+
define dso_local spir_kernel void @f() {
13+
entry:
14+
%call = call spir_func <8 x float> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<4 x i16> zeroinitializer)
15+
ret void
16+
}
17+
18+
; Function Attrs: convergent
19+
declare spir_func <8 x float> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<4 x i16>)
20+
21+
!opencl.ocl.version = !{!0}
22+
23+
!0 = !{i32 2, i32 0}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
; CHECK-ERROR: OpConvertBFloat16AsUShort must be of i16 and take float
5+
6+
; ModuleID = 'kernel.cl'
7+
source_filename = "kernel.cl"
8+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir"
10+
11+
; Function Attrs: convergent noinline norecurse nounwind optnone
12+
define dso_local spir_kernel void @f() {
13+
entry:
14+
%call = call spir_func zeroext i16 @_Z32intel_convert_bfloat16_as_ushortf(double 0.000000e+00)
15+
ret void
16+
}
17+
18+
; Function Attrs: convergent
19+
declare spir_func zeroext i16 @_Z32intel_convert_bfloat16_as_ushortf(double)
20+
21+
!opencl.ocl.version = !{!0}
22+
23+
!0 = !{i32 2, i32 0}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
; CHECK-ERROR: OpConvertBFloat16NAsUShortN must be of <N x i16> and take <N x float>
5+
6+
; ModuleID = 'kernel.cl'
7+
source_filename = "kernel.cl"
8+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir"
10+
11+
; Function Attrs: convergent noinline norecurse nounwind optnone
12+
define dso_local spir_kernel void @f() {
13+
entry:
14+
%call = call spir_func <2 x i32> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<2 x double> zeroinitializer)
15+
ret void
16+
}
17+
18+
; Function Attrs: convergent
19+
declare spir_func <2 x i32> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<2 x double>)
20+
21+
!opencl.ocl.version = !{!0}
22+
23+
!0 = !{i32 2, i32 0}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
; CHECK-ERROR: ConvertBFloat162AsUShort2 must be of <2 x i16> and take <2 x float>
5+
6+
; ModuleID = 'kernel.cl'
7+
source_filename = "kernel.cl"
8+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir"
10+
11+
; Function Attrs: convergent noinline norecurse nounwind optnone
12+
define dso_local spir_kernel void @f() {
13+
entry:
14+
%call = call spir_func <8 x i16> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<4 x float> zeroinitializer)
15+
ret void
16+
}
17+
18+
; Function Attrs: convergent
19+
declare spir_func <8 x i16> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<4 x float>)
20+
21+
!opencl.ocl.version = !{!0}
22+
23+
!0 = !{i32 2, i32 0}

0 commit comments

Comments
 (0)