Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions lib/SPIRV/OCLToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,11 @@ class OCLToSPIRVBase : public InstVisitor<OCLToSPIRVBase> {
void visitCallLdexp(CallInst *CI, StringRef MangledName,
StringRef DemangledName);

/// For cl_intel_convert_bfloat16_as_ushort
void visitCallConvertBFloat16AsUshort(CallInst *CI, StringRef DemangledName);
/// For cl_intel_convert_as_bfloat16_float
void visitCallConvertAsBFloat16Float(CallInst *CI, StringRef DemangledName);

void setOCLTypeToSPIRV(OCLTypeToSPIRVBase *OCLTypeToSPIRV) {
OCLTypeToSPIRVPtr = OCLTypeToSPIRV;
}
Expand Down Expand Up @@ -574,6 +579,24 @@ void OCLToSPIRVBase::visitCallInst(CallInst &CI) {
visitCallLdexp(&CI, MangledName, DemangledName);
return;
}
if (DemangledName == kOCLBuiltinName::ConvertBFloat16AsUShort ||
DemangledName == kOCLBuiltinName::ConvertBFloat162AsUShort2 ||
DemangledName == kOCLBuiltinName::ConvertBFloat163AsUShort3 ||
DemangledName == kOCLBuiltinName::ConvertBFloat164AsUShort4 ||
DemangledName == kOCLBuiltinName::ConvertBFloat168AsUShort8 ||
DemangledName == kOCLBuiltinName::ConvertBFloat1616AsUShort16) {
visitCallConvertBFloat16AsUshort(&CI, DemangledName);
return;
}
if (DemangledName == kOCLBuiltinName::ConvertAsBFloat16Float ||
DemangledName == kOCLBuiltinName::ConvertAsBFloat162Float2 ||
DemangledName == kOCLBuiltinName::ConvertAsBFloat163Float3 ||
DemangledName == kOCLBuiltinName::ConvertAsBFloat164Float4 ||
DemangledName == kOCLBuiltinName::ConvertAsBFloat168Float8 ||
DemangledName == kOCLBuiltinName::ConvertAsBFloat1616Float16) {
visitCallConvertAsBFloat16Float(&CI, DemangledName);
return;
}
visitCallBuiltinSimple(&CI, MangledName, DemangledName);
}

Expand Down Expand Up @@ -1916,6 +1939,103 @@ void OCLToSPIRVBase::visitCallLdexp(CallInst *CI, StringRef MangledName,
visitCallBuiltinSimple(CI, MangledName, DemangledName);
}

void OCLToSPIRVBase::visitCallConvertBFloat16AsUshort(CallInst *CI,
StringRef DemangledName) {
Type *RetTy = CI->getType();
Type *ArgTy = CI->getOperand(0)->getType();
if (DemangledName == kOCLBuiltinName::ConvertBFloat16AsUShort) {
if (!RetTy->isIntegerTy(16U) || !ArgTy->isFloatTy())
report_fatal_error(
"OpConvertBFloat16AsUShort must be of i16 and take float");
} else {
FixedVectorType *RetTyVec = cast<FixedVectorType>(RetTy);
FixedVectorType *ArgTyVec = cast<FixedVectorType>(ArgTy);
if (!RetTyVec || !RetTyVec->getElementType()->isIntegerTy(16U) ||
!ArgTyVec || !ArgTyVec->getElementType()->isFloatTy())
report_fatal_error("OpConvertBFloat16NAsUShortN must be of <N x i16> and "
"take <N x float>");
unsigned RetTyVecSize = RetTyVec->getNumElements();
unsigned ArgTyVecSize = ArgTyVec->getNumElements();
if (DemangledName == kOCLBuiltinName::ConvertBFloat162AsUShort2) {
if (RetTyVecSize != 2 || ArgTyVecSize != 2)
report_fatal_error("ConvertBFloat162AsUShort2 must be of <2 x i16> and "
"take <2 x float>");
} else if (DemangledName == kOCLBuiltinName::ConvertBFloat163AsUShort3) {
if (RetTyVecSize != 3 || ArgTyVecSize != 3)
report_fatal_error("ConvertBFloat163AsUShort3 must be of <3 x i16> and "
"take <3 x float>");
} else if (DemangledName == kOCLBuiltinName::ConvertBFloat164AsUShort4) {
if (RetTyVecSize != 4 || ArgTyVecSize != 4)
report_fatal_error("ConvertBFloat164AsUShort4 must be of <4 x i16> and "
"take <4 x float>");
} else if (DemangledName == kOCLBuiltinName::ConvertBFloat168AsUShort8) {
if (RetTyVecSize != 8 || ArgTyVecSize != 8)
report_fatal_error("ConvertBFloat168AsUShort8 must be of <8 x i16> and "
"take <8 x float>");
} else if (DemangledName == kOCLBuiltinName::ConvertBFloat1616AsUShort16) {
if (RetTyVecSize != 16 || ArgTyVecSize != 16)
report_fatal_error("ConvertBFloat1616AsUShort16 must be of <16 x i16> "
"and take <16 x float>");
}
}

AttributeList Attrs = CI->getCalledFunction()->getAttributes();
mutateCallInstSPIRV(
M, CI,
[=](CallInst *, std::vector<Value *> &Args) {
return getSPIRVFuncName(internal::OpConvertFToBF16INTEL);
},
&Attrs);
}

void OCLToSPIRVBase::visitCallConvertAsBFloat16Float(CallInst *CI,
StringRef DemangledName) {
Type *RetTy = CI->getType();
Type *ArgTy = CI->getOperand(0)->getType();
if (DemangledName == kOCLBuiltinName::ConvertAsBFloat16Float) {
if (!RetTy->isFloatTy() || !ArgTy->isIntegerTy(16U))
report_fatal_error(
"OpConvertAsBFloat16Float must be of float and take i16");
} else {
FixedVectorType *RetTyVec = cast<FixedVectorType>(RetTy);
FixedVectorType *ArgTyVec = cast<FixedVectorType>(ArgTy);
if (!RetTyVec || !RetTyVec->getElementType()->isFloatTy() || !ArgTyVec ||
!ArgTyVec->getElementType()->isIntegerTy(16U))
report_fatal_error("OpConvertAsBFloat16NFloatN must be of <N x float> "
"and take <N x i16>");
unsigned RetTyVecSize = RetTyVec->getNumElements();
unsigned ArgTyVecSize = ArgTyVec->getNumElements();
if (DemangledName == kOCLBuiltinName::ConvertAsBFloat162Float2) {
if (RetTyVecSize != 2 || ArgTyVecSize != 2)
report_fatal_error("ConvertAsBFloat162Float2 must be of <2 x float> "
"and take <2 x i16>");
} else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat163Float3) {
if (RetTyVecSize != 3 || ArgTyVecSize != 3)
report_fatal_error("ConvertAsBFloat163Float3 must be of <3 x float> "
"and take <3 x i16>");
} else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat164Float4) {
if (RetTyVecSize != 4 || ArgTyVecSize != 4)
report_fatal_error("ConvertAsBFloat164Float4 must be of <4 x float> "
"and take <4 x i16>");
} else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat168Float8) {
if (RetTyVecSize != 8 || ArgTyVecSize != 8)
report_fatal_error("ConvertAsBFloat168Float8 must be of <8 x float> "
"and take <8 x i16>");
} else if (DemangledName == kOCLBuiltinName::ConvertAsBFloat1616Float16) {
if (RetTyVecSize != 16 || ArgTyVecSize != 16)
report_fatal_error("ConvertAsBFloat1616Float16 must be of <16 x float> "
"and take <16 x i16>");
}
}

AttributeList Attrs = CI->getCalledFunction()->getAttributes();
mutateCallInstSPIRV(
M, CI,
[=](CallInst *, std::vector<Value *> &Args) {
return getSPIRVFuncName(internal::OpConvertBF16ToFINTEL);
},
&Attrs);
}
} // namespace SPIRV

INITIALIZE_PASS_BEGIN(OCLToSPIRVLegacy, "ocl-to-spv",
Expand Down
20 changes: 20 additions & 0 deletions lib/SPIRV/OCLUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,26 @@ const static char SubgroupBlockWriteINTELPrefix[] =
const static char SubgroupImageMediaBlockINTELPrefix[] =
"intel_sub_group_media_block";
const static char LDEXP[] = "ldexp";
#define _SPIRV_OP(x) \
const static char ConvertBFloat16##x##AsUShort##x[] = \
"intel_convert_bfloat16" #x "_as_ushort" #x;
_SPIRV_OP()
_SPIRV_OP(2)
_SPIRV_OP(3)
_SPIRV_OP(4)
_SPIRV_OP(8)
_SPIRV_OP(16)
#undef _SPIRV_OP
#define _SPIRV_OP(x) \
const static char ConvertAsBFloat16##x##Float##x[] = \
"intel_convert_as_bfloat16" #x "_float" #x;
_SPIRV_OP()
_SPIRV_OP(2)
_SPIRV_OP(3)
_SPIRV_OP(4)
_SPIRV_OP(8)
_SPIRV_OP(16)
#undef _SPIRV_OP
} // namespace kOCLBuiltinName

/// Offset for OpenCL image channel order enumeration values.
Expand Down
31 changes: 31 additions & 0 deletions lib/SPIRV/SPIRVToOCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ void SPIRVToOCLBase::visitCallInst(CallInst &CI) {
visitCallSPIRVRelational(&CI, OC);
return;
}
if (OC == internal::OpConvertFToBF16INTEL ||
OC == internal::OpConvertBF16ToFINTEL) {
visitCallSPIRVBFloat16Conversions(&CI, OC);
return;
}
if (OCLSPIRVBuiltinMap::rfind(OC))
visitCallSPIRVBuiltin(&CI, OC);
}
Expand Down Expand Up @@ -970,6 +975,32 @@ void SPIRVToOCLBase::visitCallSPIRVGenericPtrMemSemantics(CallInst *CI) {
&Attrs);
}

void SPIRVToOCLBase::visitCallSPIRVBFloat16Conversions(CallInst *CI, Op OC) {
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
mutateCallInstOCL(
M, CI,
[=](CallInst *, std::vector<Value *> &Args) {
Type *ArgTy = CI->getOperand(0)->getType();
std::string N =
ArgTy->isVectorTy()
? std::to_string(cast<FixedVectorType>(ArgTy)->getNumElements())
: "";
std::string Name;
switch (static_cast<uint32_t>(OC)) {
case internal::OpConvertFToBF16INTEL:
Name = "intel_convert_bfloat16" + N + "_as_ushort" + N;
break;
case internal::OpConvertBF16ToFINTEL:
Name = "intel_convert_as_bfloat16" + N + "_float" + N;
break;
default:
break; // do nothing
}
return Name;
},
&Attrs);
}

void SPIRVToOCLBase::visitCallSPIRVBuiltin(CallInst *CI, Op OC) {
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
mutateCallInstOCL(
Expand Down
7 changes: 7 additions & 0 deletions lib/SPIRV/SPIRVToOCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,13 @@ class SPIRVToOCLBase : public InstVisitor<SPIRVToOCLBase> {
/// %1 = shl i31 %0, 8
void visitCallSPIRVGenericPtrMemSemantics(CallInst *CI);

/// Transform __spirv_ConvertFToBF16INTELDv(N)_f to:
/// intel_convert_bfloat16(N)_as_ushort(N)Dv(N)_f;
/// and transform __spirv_ConvertBF16ToFINTELDv(N)_s to:
/// intel_convert_as_bfloat16(N)_float(N)Dv(N)_t;
/// where N is vector size
void visitCallSPIRVBFloat16Conversions(CallInst *CI, Op OC);

/// Transform __spirv_* builtins to OCL 2.0 builtins.
/// No change with arguments.
void visitCallSPIRVBuiltin(CallInst *CI, Op OC);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
; RUN: llvm-as %s -o %t.bc
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: OpConvertAsBFloat16Float must be of float and take i16

; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent noinline norecurse nounwind optnone
define dso_local spir_kernel void @f() {
entry:
%call = call spir_func double @_Z31intel_convert_as_bfloat16_floatt(i32 zeroext 0)
ret void
}

; Function Attrs: convergent
declare spir_func double @_Z31intel_convert_as_bfloat16_floatt(i32 zeroext)

!opencl.ocl.version = !{!0}

!0 = !{i32 2, i32 0}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
; RUN: llvm-as %s -o %t.bc
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: OpConvertAsBFloat16NFloatN must be of <N x float> and take <N x i16>

; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent noinline norecurse nounwind optnone
define dso_local spir_kernel void @f() {
entry:
%call = call spir_func <2 x double> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<2 x i32> zeroinitializer)
ret void
}

; ; Function Attrs: convergent
declare spir_func <2 x double> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<2 x i32>)

!opencl.ocl.version = !{!0}

!0 = !{i32 2, i32 0}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
; RUN: llvm-as %s -o %t.bc
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: ConvertAsBFloat162Float2 must be of <2 x float> and take <2 x i16>

; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent noinline norecurse nounwind optnone
define dso_local spir_kernel void @f() {
entry:
%call = call spir_func <8 x float> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<4 x i16> zeroinitializer)
ret void
}

; Function Attrs: convergent
declare spir_func <8 x float> @_Z33intel_convert_as_bfloat162_float2Dv2_t(<4 x i16>)

!opencl.ocl.version = !{!0}

!0 = !{i32 2, i32 0}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
; RUN: llvm-as %s -o %t.bc
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: OpConvertBFloat16AsUShort must be of i16 and take float

; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent noinline norecurse nounwind optnone
define dso_local spir_kernel void @f() {
entry:
%call = call spir_func zeroext i16 @_Z32intel_convert_bfloat16_as_ushortf(double 0.000000e+00)
ret void
}

; Function Attrs: convergent
declare spir_func zeroext i16 @_Z32intel_convert_bfloat16_as_ushortf(double)

!opencl.ocl.version = !{!0}

!0 = !{i32 2, i32 0}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
; RUN: llvm-as %s -o %t.bc
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: OpConvertBFloat16NAsUShortN must be of <N x i16> and take <N x float>

; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent noinline norecurse nounwind optnone
define dso_local spir_kernel void @f() {
entry:
%call = call spir_func <2 x i32> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<2 x double> zeroinitializer)
ret void
}

; Function Attrs: convergent
declare spir_func <2 x i32> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<2 x double>)

!opencl.ocl.version = !{!0}

!0 = !{i32 2, i32 0}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
; RUN: llvm-as %s -o %t.bc
; RUN: not --crash llvm-spirv %t.bc -o %t.spv 2>&1 | FileCheck %s --check-prefix=CHECK-ERROR

; CHECK-ERROR: ConvertBFloat162AsUShort2 must be of <2 x i16> and take <2 x float>

; ModuleID = 'kernel.cl'
source_filename = "kernel.cl"
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir"

; Function Attrs: convergent noinline norecurse nounwind optnone
define dso_local spir_kernel void @f() {
entry:
%call = call spir_func <8 x i16> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<4 x float> zeroinitializer)
ret void
}

; Function Attrs: convergent
declare spir_func <8 x i16> @_Z34intel_convert_bfloat162_as_ushort2Dv2_f(<4 x float>)

!opencl.ocl.version = !{!0}

!0 = !{i32 2, i32 0}
Loading