Skip to content

Commit 1668db9

Browse files
committed
1 parent dd7177a commit 1668db9

12 files changed

+469
-1
lines changed

lib/SPIRV/OCL20ToSPIRV.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,11 @@ class OCL20ToSPIRV : public ModulePass, public InstVisitor<OCL20ToSPIRV> {
282282
void visitCallLdexp(CallInst *CI, StringRef MangledName,
283283
StringRef DemangledName);
284284

285+
/// For cl_intel_convert_bfloat16_as_ushort
286+
void visitCallConvertBFloat16AsUshort(CallInst *CI, StringRef DemangledName);
287+
/// For cl_intel_convert_as_bfloat16_float
288+
void visitCallConvertAsBFloat16Float(CallInst *CI, StringRef DemangledName);
289+
285290
static char ID;
286291

287292
private:
@@ -560,6 +565,24 @@ void OCL20ToSPIRV::visitCallInst(CallInst &CI) {
560565
visitCallLdexp(&CI, MangledName, DemangledName);
561566
return;
562567
}
568+
if (DemangledName == kOCLBuiltinName::ConvertBFloat16AsUShort ||
569+
DemangledName == kOCLBuiltinName::ConvertBFloat162AsUShort2 ||
570+
DemangledName == kOCLBuiltinName::ConvertBFloat163AsUShort3 ||
571+
DemangledName == kOCLBuiltinName::ConvertBFloat164AsUShort4 ||
572+
DemangledName == kOCLBuiltinName::ConvertBFloat168AsUShort8 ||
573+
DemangledName == kOCLBuiltinName::ConvertBFloat1616AsUShort16) {
574+
visitCallConvertBFloat16AsUshort(&CI, DemangledName);
575+
return;
576+
}
577+
if (DemangledName == kOCLBuiltinName::ConvertAsBFloat16Float ||
578+
DemangledName == kOCLBuiltinName::ConvertAsBFloat162Float2 ||
579+
DemangledName == kOCLBuiltinName::ConvertAsBFloat163Float3 ||
580+
DemangledName == kOCLBuiltinName::ConvertAsBFloat164Float4 ||
581+
DemangledName == kOCLBuiltinName::ConvertAsBFloat168Float8 ||
582+
DemangledName == kOCLBuiltinName::ConvertAsBFloat1616Float16) {
583+
visitCallConvertAsBFloat16Float(&CI, DemangledName);
584+
return;
585+
}
563586
visitCallBuiltinSimple(&CI, MangledName, DemangledName);
564587
}
565588

@@ -1925,6 +1948,104 @@ void OCL20ToSPIRV::visitCallLdexp(CallInst *CI, StringRef MangledName,
19251948
visitCallBuiltinSimple(CI, MangledName, DemangledName);
19261949
}
19271950

1951+
/// Translate a call to one of the intel_convert_bfloat16{,2,3,4,8,16}_as_ushort
/// builtins into a call to the SPIR-V friendly function for
/// internal::OpConvertFToBF16INTEL.
///
/// Before mutating the call, its signature is validated against the builtin
/// name: the scalar form must return i16 and take float, and the vector forms
/// must return <N x i16> and take <N x float>, with N matching the suffix in
/// the builtin name. Any mismatch is reported through report_fatal_error.
///
/// \param CI            The call instruction to translate.
/// \param DemangledName Demangled name of the called OpenCL builtin.
void OCL20ToSPIRV::visitCallConvertBFloat16AsUshort(CallInst *CI,
                                                    StringRef DemangledName) {
  Type *RetTy = CI->getType();
  Type *ArgTy = CI->getOperand(0)->getType();
  if (DemangledName == kOCLBuiltinName::ConvertBFloat16AsUShort) {
    if (!RetTy->isIntegerTy(16U) || !ArgTy->isFloatTy())
      report_fatal_error(
          "OpConvertBFloat16AsUShort must be of i16 and take float");
  } else {
    // dyn_cast, not cast: a vector-named builtin declared with scalar types
    // must yield nullptr here so that the diagnostic below is emitted.
    // cast<> never returns null and asserts on a type mismatch, which made
    // the null checks dead code.
    auto *RetTyVec = dyn_cast<VectorType>(RetTy);
    auto *ArgTyVec = dyn_cast<VectorType>(ArgTy);
    if (!RetTyVec || !RetTyVec->getElementType()->isIntegerTy(16U) ||
        !ArgTyVec || !ArgTyVec->getElementType()->isFloatTy())
      report_fatal_error("OpConvertBFloat16NAsUShortN must be of <N x i16> and "
                         "take <N x float>");
    unsigned RetTyVecSize = RetTyVec->getNumElements();
    unsigned ArgTyVecSize = ArgTyVec->getNumElements();
    // The element count of both the result and the argument must match the
    // vector width encoded in the builtin name.
    if (DemangledName == kOCLBuiltinName::ConvertBFloat162AsUShort2) {
      if (RetTyVecSize != 2 || ArgTyVecSize != 2)
        report_fatal_error("ConvertBFloat162AsUShort2 must be of <2 x i16> and "
                           "take <2 x float>");
    } else if (DemangledName == kOCLBuiltinName::ConvertBFloat163AsUShort3) {
      if (RetTyVecSize != 3 || ArgTyVecSize != 3)
        report_fatal_error("ConvertBFloat163AsUShort3 must be of <3 x i16> and "
                           "take <3 x float>");
    } else if (DemangledName == kOCLBuiltinName::ConvertBFloat164AsUShort4) {
      if (RetTyVecSize != 4 || ArgTyVecSize != 4)
        report_fatal_error("ConvertBFloat164AsUShort4 must be of <4 x i16> and "
                           "take <4 x float>");
    } else if (DemangledName == kOCLBuiltinName::ConvertBFloat168AsUShort8) {
      if (RetTyVecSize != 8 || ArgTyVecSize != 8)
        report_fatal_error("ConvertBFloat168AsUShort8 must be of <8 x i16> and "
                           "take <8 x float>");
    } else if (DemangledName == kOCLBuiltinName::ConvertBFloat1616AsUShort16) {
      if (RetTyVecSize != 16 || ArgTyVecSize != 16)
        report_fatal_error("ConvertBFloat1616AsUShort16 must be of <16 x i16> "
                           "and take <16 x float>");
    }
  }

  // Arguments are passed through unchanged; only the callee name is replaced.
  AttributeList Attrs = CI->getCalledFunction()->getAttributes();
  mutateCallInstSPIRV(
      M, CI,
      [=](CallInst *, std::vector<Value *> &Args) {
        return getSPIRVFuncName(internal::OpConvertFToBF16INTEL);
      },
      &Attrs);
}
1999+
2000+
/// Translate a call to one of the intel_convert_as_bfloat16{,2,3,4,8,16}_float
/// builtins into a call to the SPIR-V friendly function for
/// internal::OpConvertBF16ToFINTEL.
///
/// Before mutating the call, its signature is validated against the builtin
/// name: the scalar form must return float and take i16, and the vector forms
/// must return <N x float> and take <N x i16>, with N matching the suffix in
/// the builtin name. Any mismatch is reported through report_fatal_error.
///
/// \param CI            The call instruction to translate.
/// \param DemangledName Demangled name of the called OpenCL builtin.
void OCL20ToSPIRV::visitCallConvertAsBFloat16Float(CallInst *CI,
                                                   StringRef DemangledName) {
  Type *RetTy = CI->getType();
  Type *ArgTy = CI->getOperand(0)->getType();
  if (DemangledName == kOCLBuiltinName::ConvertAsBFloat16Float) {
    if (!RetTy->isFloatTy() || !ArgTy->isIntegerTy(16U))
      report_fatal_error(
          "OpConvertAsBFloat16Float must be of float and take i16");
  } else {
    // dyn_cast, not cast: a vector-named builtin declared with scalar types
    // must yield nullptr here so that the diagnostic below is emitted.
    // cast<> never returns null and asserts on a type mismatch, which made
    // the null checks dead code.
    auto *RetTyVec = dyn_cast<VectorType>(RetTy);
    auto *ArgTyVec = dyn_cast<VectorType>(ArgTy);
    if (!RetTyVec || !RetTyVec->getElementType()->isFloatTy() || !ArgTyVec ||
        !ArgTyVec->getElementType()->isIntegerTy(16U))
      report_fatal_error("OpConvertAsBFloat16NFloatN must be of <N x float> "
                         "and take <N x i16>");
    unsigned RetTyVecSize = RetTyVec->getNumElements();
    unsigned ArgTyVecSize = ArgTyVec->getNumElements();
    // The element count of both the result and the argument must match the
    // vector width encoded in the builtin name.
    if (DemangledName == kOCLBuiltinName::ConvertAsBFloat162Float2) {
      if (RetTyVecSize != 2 || ArgTyVecSize != 2)
        report_fatal_error("ConvertAsBFloat162Float2 must be of <2 x float> "
                           "and take <2 x i16>");
    } else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat163Float3) {
      if (RetTyVecSize != 3 || ArgTyVecSize != 3)
        report_fatal_error("ConvertAsBFloat163Float3 must be of <3 x float> "
                           "and take <3 x i16>");
    } else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat164Float4) {
      if (RetTyVecSize != 4 || ArgTyVecSize != 4)
        report_fatal_error("ConvertAsBFloat164Float4 must be of <4 x float> "
                           "and take <4 x i16>");
    } else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat168Float8) {
      if (RetTyVecSize != 8 || ArgTyVecSize != 8)
        report_fatal_error("ConvertAsBFloat168Float8 must be of <8 x float> "
                           "and take <8 x i16>");
    } else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat1616Float16) {
      if (RetTyVecSize != 16 || ArgTyVecSize != 16)
        report_fatal_error("ConvertAsBFloat1616Float16 must be of <16 x float> "
                           "and take <16 x i16>");
    }
  }

  // Arguments are passed through unchanged; only the callee name is replaced.
  AttributeList Attrs = CI->getCalledFunction()->getAttributes();
  mutateCallInstSPIRV(
      M, CI,
      [=](CallInst *, std::vector<Value *> &Args) {
        return getSPIRVFuncName(internal::OpConvertBF16ToFINTEL);
      },
      &Attrs);
}
2048+
19282049
} // namespace SPIRV
19292050

19302051
INITIALIZE_PASS_BEGIN(OCL20ToSPIRV, "cl20tospv", "Transform OCL 2.0 to SPIR-V",

lib/SPIRV/OCLUtil.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,26 @@ const static char SubgroupBlockWriteINTELPrefix[] =
239239
const static char SubgroupImageMediaBlockINTELPrefix[] =
240240
"intel_sub_group_media_block";
241241
const static char LDEXP[] = "ldexp";
242+
// Builtin-name constants ConvertBFloat16<x>AsUShort<x>, one per expansion
// below (no suffix, 2, 3, 4, 8 and 16). Each holds the string
// "intel_convert_bfloat16<x>_as_ushort<x>"; the suffix x is token-pasted into
// the identifier and stringized into the literal.
#define _SPIRV_OP(x) \
243+
const static char ConvertBFloat16##x##AsUShort##x[] = \
244+
"intel_convert_bfloat16" #x "_as_ushort" #x;
245+
_SPIRV_OP()
246+
_SPIRV_OP(2)
247+
_SPIRV_OP(3)
248+
_SPIRV_OP(4)
249+
_SPIRV_OP(8)
250+
_SPIRV_OP(16)
251+
#undef _SPIRV_OP
252+
// Builtin-name constants ConvertAsBFloat16<x>Float<x>, one per expansion
// below (no suffix, 2, 3, 4, 8 and 16). Each holds the string
// "intel_convert_as_bfloat16<x>_float<x>".
#define _SPIRV_OP(x) \
253+
const static char ConvertAsBFloat16##x##Float##x[] = \
254+
"intel_convert_as_bfloat16" #x "_float" #x;
255+
_SPIRV_OP()
256+
_SPIRV_OP(2)
257+
_SPIRV_OP(3)
258+
_SPIRV_OP(4)
259+
_SPIRV_OP(8)
260+
_SPIRV_OP(16)
261+
#undef _SPIRV_OP
242262
} // namespace kOCLBuiltinName
243263

244264
/// Offset for OpenCL image channel order enumeration values.

lib/SPIRV/SPIRVToOCL.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,11 @@ void SPIRVToOCL::visitCallInst(CallInst &CI) {
194194
visitCallSPIRVRelational(&CI, OC);
195195
return;
196196
}
197+
if (OC == internal::OpConvertFToBF16INTEL ||
198+
OC == internal::OpConvertBF16ToFINTEL) {
199+
visitCallSPIRVBFloat16Conversions(&CI, OC);
200+
return;
201+
}
197202
if (OCLSPIRVBuiltinMap::rfind(OC))
198203
visitCallSPIRVBuiltin(&CI, OC);
199204
}
@@ -967,6 +972,32 @@ void SPIRVToOCL::visitCallSPIRVGenericPtrMemSemantics(CallInst *CI) {
967972
&Attrs);
968973
}
969974

975+
/// Rename a call to the SPIR-V bfloat16 conversion functions back to the
/// matching OpenCL builtin name. The vector width of the first operand, if
/// it is a vector, is appended to both halves of the builtin name; scalar
/// operands get no suffix. Arguments are left untouched.
void SPIRVToOCL::visitCallSPIRVBFloat16Conversions(CallInst *CI, Op OC) {
  AttributeList Attrs = CI->getCalledFunction()->getAttributes();
  mutateCallInstOCL(
      M, CI,
      [=](CallInst *, std::vector<Value *> &Args) {
        // Width suffix: element count for vectors, empty for scalars.
        std::string Width;
        Type *SrcTy = CI->getOperand(0)->getType();
        if (SrcTy->isVectorTy())
          Width = std::to_string(cast<VectorType>(SrcTy)->getNumElements());

        const uint32_t Code = static_cast<uint32_t>(OC);
        if (Code == internal::OpConvertFToBF16INTEL)
          return "intel_convert_bfloat16" + Width + "_as_ushort" + Width;
        if (Code == internal::OpConvertBF16ToFINTEL)
          return "intel_convert_as_bfloat16" + Width + "_float" + Width;
        // Any other opcode yields an empty name.
        return std::string();
      },
      &Attrs);
}
1000+
9701001
void SPIRVToOCL::visitCallSPIRVBuiltin(CallInst *CI, Op OC) {
9711002
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
9721003
mutateCallInstOCL(

lib/SPIRV/SPIRVToOCL.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,13 @@ class SPIRVToOCL : public ModulePass, public InstVisitor<SPIRVToOCL> {
157157
/// %1 = shl i31 %0, 8
158158
void visitCallSPIRVGenericPtrMemSemantics(CallInst *CI);
159159

160+
/// Transform __spirv_ConvertFToBF16INTELDv(N)_f to:
161+
/// intel_convert_bfloat16(N)_as_ushort(N)Dv(N)_f;
162+
/// and transform __spirv_ConvertBF16ToFINTELDv(N)_s to:
163+
/// intel_convert_as_bfloat16(N)_float(N)Dv(N)_t;
164+
/// where N is vector size
165+
void visitCallSPIRVBFloat16Conversions(CallInst *CI, Op OC);
166+
160167
/// Transform __spirv_* builtins to OCL 2.0 builtins.
161168
/// No change with arguments.
162169
void visitCallSPIRVBuiltin(CallInst *CI, Op OC);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
; CHECK-ERROR: OpConvertAsBFloat16Float must be of float and take i16
5+
6+
; ModuleID = 'kernel.cl'
7+
source_filename = "kernel.cl"
8+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir"
10+
11+
; Function Attrs: convergent noinline norecurse nounwind optnone
12+
define dso_local spir_kernel void @f() {
13+
entry:
14+
%call = call spir_func double @_Z31intel_convert_as_bfloat16_floatt(i32 zeroext 0)
15+
ret void
16+
}
17+
18+
; Function Attrs: convergent
19+
declare spir_func double @_Z31intel_convert_as_bfloat16_floatt(i32 zeroext)
20+
21+
!opencl.ocl.version = !{!0}
22+
23+
!0 = !{i32 2, i32 0}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
; CHECK-ERROR: OpConvertAsBFloat16NFloatN must be of <N x float> and take <N x i16>
5+
6+
; ModuleID = 'kernel.cl'
7+
source_filename = "kernel.cl"
8+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir"
10+
11+
; Function Attrs: convergent noinline norecurse nounwind optnone
12+
define dso_local spir_kernel void @f() {
13+
entry:
14+
%call = call spir_func <2 x double> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<2 x i32> zeroinitializer)
15+
ret void
16+
}
17+
18+
; Function Attrs: convergent
19+
declare spir_func <2 x double> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<2 x i32>)
20+
21+
!opencl.ocl.version = !{!0}
22+
23+
!0 = !{i32 2, i32 0}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
; CHECK-ERROR: ConvertAsBFloat162Float2 must be of <2 x float> and take <2 x i16>
5+
6+
; ModuleID = 'kernel.cl'
7+
source_filename = "kernel.cl"
8+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir"
10+
11+
; Function Attrs: convergent noinline norecurse nounwind optnone
12+
define dso_local spir_kernel void @f() {
13+
entry:
14+
%call = call spir_func <8 x float> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<4 x i16> zeroinitializer)
15+
ret void
16+
}
17+
18+
; Function Attrs: convergent
19+
declare spir_func <8 x float> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<4 x i16>)
20+
21+
!opencl.ocl.version = !{!0}
22+
23+
!0 = !{i32 2, i32 0}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
; CHECK-ERROR: OpConvertBFloat16AsUShort must be of i16 and take float
5+
6+
; ModuleID = 'kernel.cl'
7+
source_filename = "kernel.cl"
8+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir"
10+
11+
; Function Attrs: convergent noinline norecurse nounwind optnone
12+
define dso_local spir_kernel void @f() {
13+
entry:
14+
%call = call spir_func zeroext i16 @_Z32intel_convert_bfloat16_as_ushortf(double 0.000000e+00)
15+
ret void
16+
}
17+
18+
; Function Attrs: convergent
19+
declare spir_func zeroext i16 @_Z32intel_convert_bfloat16_as_ushortf(double)
20+
21+
!opencl.ocl.version = !{!0}
22+
23+
!0 = !{i32 2, i32 0}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
; CHECK-ERROR: OpConvertBFloat16NAsUShortN must be of <N x i16> and take <N x float>
5+
6+
; ModuleID = 'kernel.cl'
7+
source_filename = "kernel.cl"
8+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir"
10+
11+
; Function Attrs: convergent noinline norecurse nounwind optnone
12+
define dso_local spir_kernel void @f() {
13+
entry:
14+
%call = call spir_func <2 x i32> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<2 x double> zeroinitializer)
15+
ret void
16+
}
17+
18+
; Function Attrs: convergent
19+
declare spir_func <2 x i32> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<2 x double>)
20+
21+
!opencl.ocl.version = !{!0}
22+
23+
!0 = !{i32 2, i32 0}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: not llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR
3+
4+
; CHECK-ERROR: ConvertBFloat162AsUShort2 must be of <2 x i16> and take <2 x float>
5+
6+
; ModuleID = 'kernel.cl'
7+
source_filename = "kernel.cl"
8+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
9+
target triple = "spir"
10+
11+
; Function Attrs: convergent noinline norecurse nounwind optnone
12+
define dso_local spir_kernel void @f() {
13+
entry:
14+
%call = call spir_func <8 x i16> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<4 x float> zeroinitializer)
15+
ret void
16+
}
17+
18+
; Function Attrs: convergent
19+
declare spir_func <8 x i16> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<4 x float>)
20+
21+
!opencl.ocl.version = !{!0}
22+
23+
!0 = !{i32 2, i32 0}

0 commit comments

Comments
 (0)