Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 42 additions & 7 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,42 @@ static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
return nullptr;
}

// If the function has pointer arguments, we are forced to re-create this
// function type from the very beginning, changing PointerType by
// TypedPointerType for each pointer argument. Otherwise, the same `Type*`
// potentially corresponds to different SPIR-V function type, effectively
// invalidating logic behind global registry and duplicates tracker.
static FunctionType *
fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
FunctionType *FTy, const SPIRVType *SRetTy,
const SmallVector<SPIRVType *, 4> &SArgTys) {
if (F.getParent()->getNamedMetadata("spv.cloned_funcs"))
return FTy;

bool hasArgPtrs = false;
for (auto &Arg : F.args()) {
// check if it's an instance of a non-typed PointerType
if (Arg.getType()->isPointerTy()) {
hasArgPtrs = true;
break;
}
}
if (!hasArgPtrs) {
Type *RetTy = FTy->getReturnType();
// check if it's an instance of a non-typed PointerType
if (!RetTy->isPointerTy())
return FTy;
}

// re-create function type, using TypedPointerType instead of PointerType to
// properly trace argument types
const Type *RetTy = GR->getTypeForSPIRVType(SRetTy);
SmallVector<Type *, 4> ArgTys;
for (auto SArgTy : SArgTys)
ArgTys.push_back(const_cast<Type *>(GR->getTypeForSPIRVType(SArgTy)));
return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
}

// This code restores function args/retvalue types for composite cases
// because the final types should still be aggregate whereas they're i32
// during the translation to cope with aggregate flattening etc.
Expand Down Expand Up @@ -162,7 +198,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,

// If OriginalArgType is non-pointer, use the OriginalArgType (the type cannot
// be legally reassigned later).
if (!OriginalArgType->isPointerTy())
if (!isPointerTy(OriginalArgType))
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);

// In case OriginalArgType is of pointer type, there are three possibilities:
Expand All @@ -179,8 +215,7 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
SPIRVType *ElementType = GR->getOrCreateSPIRVType(ByValRefType, MIRBuilder);
return GR->getOrCreateSPIRVPointerType(
ElementType, MIRBuilder,
addressSpaceToStorageClass(Arg->getType()->getPointerAddressSpace(),
ST));
addressSpaceToStorageClass(getPointerAddressSpace(Arg->getType()), ST));
}

for (auto User : Arg->users()) {
Expand Down Expand Up @@ -240,7 +275,6 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());

// Assign types and names to all args, and store their types for later.
FunctionType *FTy = getOriginalFunctionType(F);
SmallVector<SPIRVType *, 4> ArgTypeVRegs;
if (VRegs.size() > 0) {
unsigned i = 0;
Expand All @@ -255,7 +289,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,

if (Arg.hasName())
buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
if (Arg.getType()->isPointerTy()) {
if (isPointerTy(Arg.getType())) {
auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
if (DerefBytes != 0)
buildOpDecorate(VRegs[i][0], MIRBuilder,
Expand Down Expand Up @@ -322,7 +356,9 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
if (F.isDeclaration())
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
FunctionType *FTy = getOriginalFunctionType(F);
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
FTy, RetTy, ArgTypeVRegs, MIRBuilder);
uint32_t FuncControl = getFunctionControl(F);
Expand Down Expand Up @@ -429,7 +465,6 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
return false;
MachineFunction &MF = MIRBuilder.getMF();
GR->setCurrentFunc(MF);
FunctionType *FTy = nullptr;
const Function *CF = nullptr;
std::string DemangledName;
const Type *OrigRetTy = Info.OrigRet.Ty;
Expand All @@ -444,7 +479,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
// TODO: support constexpr casts and indirect calls.
if (CF == nullptr)
return false;
if ((FTy = getOriginalFunctionType(*CF)) != nullptr)
if (FunctionType *FTy = getOriginalFunctionType(*CF))
OrigRetTy = FTy->getReturnType();
}

Expand Down
12 changes: 6 additions & 6 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ Instruction *SPIRVEmitIntrinsics::visitBitCastInst(BitCastInst &I) {
// varying element types. In case of IR coming from older versions of LLVM
// such bitcasts do not provide sufficient information, should be just skipped
// here, and handled in insertPtrCastOrAssignTypeInstr.
if (I.getType()->isPointerTy()) {
if (isPointerTy(I.getType())) {
I.replaceAllUsesWith(Source);
I.eraseFromParent();
return nullptr;
Expand Down Expand Up @@ -356,7 +356,7 @@ void SPIRVEmitIntrinsics::replacePointerOperandWithPtrCast(
ValueAsMetadata::getConstant(ExpectedElementTypeConst);
MDTuple *TyMD = MDNode::get(F->getContext(), CM);
MetadataAsValue *VMD = MetadataAsValue::get(F->getContext(), TyMD);
unsigned AddressSpace = Pointer->getType()->getPointerAddressSpace();
unsigned AddressSpace = getPointerAddressSpace(Pointer->getType());
bool FirstPtrCastOrAssignPtrType = true;

// Do not emit new spv_ptrcast if equivalent one already exists or when
Expand Down Expand Up @@ -419,7 +419,7 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
// Handle basic instructions:
StoreInst *SI = dyn_cast<StoreInst>(I);
if (SI && F->getCallingConv() == CallingConv::SPIR_KERNEL &&
SI->getValueOperand()->getType()->isPointerTy() &&
isPointerTy(SI->getValueOperand()->getType()) &&
isa<Argument>(SI->getValueOperand())) {
return replacePointerOperandWithPtrCast(
I, SI->getValueOperand(), IntegerType::getInt8Ty(F->getContext()), 0,
Expand Down Expand Up @@ -639,14 +639,14 @@ void SPIRVEmitIntrinsics::processGlobalValue(GlobalVariable &GV,
void SPIRVEmitIntrinsics::insertAssignPtrTypeIntrs(Instruction *I,
IRBuilder<> &B) {
reportFatalOnTokenType(I);
if (!I->getType()->isPointerTy() || !requireAssignType(I) ||
if (!isPointerTy(I->getType()) || !requireAssignType(I) ||
isa<BitCastInst>(I))
return;

setInsertPointSkippingPhis(B, I->getNextNode());

Constant *EltTyConst;
unsigned AddressSpace = I->getType()->getPointerAddressSpace();
unsigned AddressSpace = getPointerAddressSpace(I->getType());
if (auto *AI = dyn_cast<AllocaInst>(I))
EltTyConst = UndefValue::get(AI->getAllocatedType());
else if (auto *GEP = dyn_cast<GetElementPtrInst>(I))
Expand All @@ -662,7 +662,7 @@ void SPIRVEmitIntrinsics::insertAssignTypeIntrs(Instruction *I,
IRBuilder<> &B) {
reportFatalOnTokenType(I);
Type *Ty = I->getType();
if (!Ty->isVoidTy() && !Ty->isPointerTy() && requireAssignType(I)) {
if (!Ty->isVoidTy() && !isPointerTy(Ty) && requireAssignType(I)) {
setInsertPointSkippingPhis(B, I->getNextNode());
Type *TypeToAssign = Ty;
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
Expand Down
25 changes: 15 additions & 10 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
if (TypesInProcessing.count(Ty) && !Ty->isPointerTy())
if (TypesInProcessing.count(Ty) && !isPointerTy(Ty))
return nullptr;
TypesInProcessing.insert(Ty);
SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
Expand All @@ -762,11 +762,11 @@ SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
// will be added later. For special types it is already added to DT.
if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
!isSpecialOpaqueType(Ty)) {
if (!Ty->isPointerTy())
if (!isPointerTy(Ty))
DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
else
DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
Ty->getPointerAddressSpace(), &MIRBuilder.getMF(),
getPointerAddressSpace(Ty), &MIRBuilder.getMF(),
getSPIRVTypeID(SpirvType));
}

Expand All @@ -787,12 +787,12 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
const Type *Ty, MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
Register Reg;
if (!Ty->isPointerTy())
if (!isPointerTy(Ty))
Reg = DT.find(Ty, &MIRBuilder.getMF());
else
Reg =
DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
Ty->getPointerAddressSpace(), &MIRBuilder.getMF());
getPointerAddressSpace(Ty), &MIRBuilder.getMF());

if (Reg.isValid() && !isSpecialOpaqueType(Ty))
return getSPIRVTypeForVReg(Reg);
Expand Down Expand Up @@ -836,11 +836,16 @@ bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,

unsigned
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(Register VReg) const {
if (SPIRVType *Type = getSPIRVTypeForVReg(VReg))
return Type->getOpcode() == SPIRV::OpTypeVector
? static_cast<unsigned>(Type->getOperand(2).getImm())
: 1;
return 0;
return getScalarOrVectorComponentCount(getSPIRVTypeForVReg(VReg));
}

unsigned
SPIRVGlobalRegistry::getScalarOrVectorComponentCount(SPIRVType *Type) const {
if (!Type)
return 0;
return Type->getOpcode() == SPIRV::OpTypeVector
? static_cast<unsigned>(Type->getOperand(2).getImm())
: 1;
}

unsigned
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,10 @@ class SPIRVGlobalRegistry {
// opcode (e.g. OpTypeBool, or OpTypeVector %x 4, where %x is OpTypeBool).
bool isScalarOrVectorOfType(Register VReg, unsigned TypeOpcode) const;

// Return number of elements in a vector if the given VReg is associated with
// Return number of elements in a vector if the argument is associated with
// a vector type. Return 1 for a scalar type, and 0 for a missing type.
unsigned getScalarOrVectorComponentCount(Register VReg) const;
unsigned getScalarOrVectorComponentCount(SPIRVType *Type) const;

// For vectors or scalars of booleans, integers and floats, return the scalar
// type's bitwidth. Otherwise calls llvm_unreachable().
Expand Down
41 changes: 41 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,8 @@ class SPIRVInstructionSelector : public InstructionSelector {

bool selectConstVector(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;
bool selectSplatVector(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectCmp(Register ResVReg, const SPIRVType *ResType,
unsigned comparisonOpcode, MachineInstr &I) const;
Expand Down Expand Up @@ -313,6 +315,8 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,

case TargetOpcode::G_BUILD_VECTOR:
return selectConstVector(ResVReg, ResType, I);
case TargetOpcode::G_SPLAT_VECTOR:
return selectSplatVector(ResVReg, ResType, I);

case TargetOpcode::G_SHUFFLE_VECTOR: {
MachineBasicBlock &BB = *I.getParent();
Expand Down Expand Up @@ -1185,6 +1189,43 @@ bool SPIRVInstructionSelector::selectConstVector(Register ResVReg,
return MIB.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
if (ResType->getOpcode() != SPIRV::OpTypeVector)
report_fatal_error("Cannot select G_SPLAT_VECTOR with a non-vector result");
unsigned N = GR.getScalarOrVectorComponentCount(ResType);
unsigned OpIdx = I.getNumExplicitDefs();
if (!I.getOperand(OpIdx).isReg())
report_fatal_error("Unexpected argument in G_SPLAT_VECTOR");

// check if we may construct a constant vector
Register OpReg = I.getOperand(OpIdx).getReg();
bool IsConst = false;
if (SPIRVType *OpDef = MRI->getVRegDef(OpReg)) {
if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE &&
OpDef->getOperand(1).isReg()) {
if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg()))
OpDef = RefDef;
}
IsConst = OpDef->getOpcode() == TargetOpcode::G_CONSTANT ||
OpDef->getOpcode() == TargetOpcode::G_FCONSTANT;
}

if (!IsConst && N < 2)
report_fatal_error(
"There must be at least two constituent operands in a vector");

auto MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(),
TII.get(IsConst ? SPIRV::OpConstantComposite
: SPIRV::OpCompositeConstruct))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType));
for (unsigned i = 0; i < N; ++i)
MIB.addUse(OpReg);
return MIB.constrainAllUses(TII, TRI, RBI);
}

bool SPIRVInstructionSelector::selectCmp(Register ResVReg,
const SPIRVType *ResType,
unsigned CmpOpc,
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();

// TODO: add proper rules for vectors legalization.
getActionDefinitionsBuilder({G_BUILD_VECTOR, G_SHUFFLE_VECTOR}).alwaysLegal();
getActionDefinitionsBuilder(
{G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
.alwaysLegal();

// Vector Reduction Operations
getActionDefinitionsBuilder(
Expand Down
17 changes: 17 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "MCTargetDesc/SPIRVBaseInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/TypedPointerType.h"
#include <string>

namespace llvm {
Expand Down Expand Up @@ -100,5 +101,21 @@ bool isEntryPoint(const Function &F);

// Parse basic scalar type name, substring TypeName, and return LLVM type.
Type *parseBasicTypeName(StringRef TypeName, LLVMContext &Ctx);

// True if this is an instance of PointerType or TypedPointerType.
inline bool isPointerTy(const Type *T) {
return T->getTypeID() == Type::PointerTyID ||
T->getTypeID() == Type::TypedPointerTyID;
}

// Get the address space of this pointer or pointer vector type for instances of
// PointerType or TypedPointerType.
inline unsigned getPointerAddressSpace(const Type *T) {
Type *SubT = T->getScalarType();
return SubT->getTypeID() == Type::PointerTyID
? cast<PointerType>(SubT)->getAddressSpace()
: cast<TypedPointerType>(SubT)->getAddressSpace();
}

} // namespace llvm
#endif // LLVM_LIB_TARGET_SPIRV_SPIRVUTILS_H
1 change: 1 addition & 0 deletions llvm/test/CodeGen/SPIRV/capability-kernel.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: OpCapability Addresses

Expand Down
29 changes: 29 additions & 0 deletions llvm/test/CodeGen/SPIRV/pointers/typeof-ptr-int.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
; This test is to check that two functions have different SPIR-V type
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for tackling this one! This was a big (repeating) issue appearing when validating OpenCL CTS kernels. Surprisingly, the same function type definitions did not cause any runtime problems at least in OpenCL CTS.

; definitions, even though their LLVM function types are identical.

; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: OpName %[[Fun32:.*]] "tp_arg_i32"
; CHECK-DAG: OpName %[[Fun64:.*]] "tp_arg_i64"
; CHECK-DAG: %[[TyI32:.*]] = OpTypeInt 32 0
; CHECK-DAG: %[[TyVoid:.*]] = OpTypeVoid
; CHECK-DAG: %[[TyPtr32:.*]] = OpTypePointer Function %[[TyI32]]
; CHECK-DAG: %[[TyFun32:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtr32]]
; CHECK-DAG: %[[TyI64:.*]] = OpTypeInt 64 0
; CHECK-DAG: %[[TyPtr64:.*]] = OpTypePointer Function %[[TyI64]]
; CHECK-DAG: %[[TyFun64:.*]] = OpTypeFunction %[[TyVoid]] %[[TyPtr64]]
; CHECK-DAG: %[[Fun32]] = OpFunction %[[TyVoid]] None %[[TyFun32]]
; CHECK-DAG: %[[Fun64]] = OpFunction %[[TyVoid]] None %[[TyFun64]]

define spir_kernel void @tp_arg_i32(ptr %ptr) {
entry:
store i32 1, ptr %ptr
ret void
}

define spir_kernel void @tp_arg_i64(ptr %ptr) {
entry:
store i64 1, ptr %ptr
ret void
}
1 change: 1 addition & 0 deletions llvm/test/CodeGen/SPIRV/relationals.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}

declare dso_local spir_func <4 x i8> @_Z13__spirv_IsNanIDv4_aDv4_fET_T0_(<4 x float>)
declare dso_local spir_func <4 x i8> @_Z13__spirv_IsInfIDv4_aDv4_fET_T0_(<4 x float>)
Expand Down
1 change: 1 addition & 0 deletions llvm/test/CodeGen/SPIRV/simple.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

;; Support of doubles is required.
; CHECK: OpCapability Float64
Expand Down
1 change: 1 addition & 0 deletions llvm/test/CodeGen/SPIRV/transcoding/isequal.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-SPIRV-NOT: OpSConvert

Expand Down
1 change: 1 addition & 0 deletions llvm/test/CodeGen/SPIRV/transcoding/relationals_double.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}

;; This test checks following SYCL relational builtins with double and double2
;; types:
Expand Down
1 change: 1 addition & 0 deletions llvm/test/CodeGen/SPIRV/transcoding/relationals_float.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}

;; This test checks following SYCL relational builtins with float and float2
;; types:
Expand Down
Loading