Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,21 @@ def UndefAttr : CIR_Attr<"Undef", "undef", [TypedAttrInterface]> {
let assemblyFormat = [{}];
}

//===----------------------------------------------------------------------===//
// PoisonAttr
//===----------------------------------------------------------------------===//

def PoisonAttr : CIR_Attr<"Poison", "poison", [TypedAttrInterface]> {
let summary = "Represent an poison constant";
let description = [{
The PoisonAttr represents an poison constant, corresponding to LLVM's notion
of poison.
}];

let parameters = (ins AttributeSelfTypeParameter<"">:$type);
let assemblyFormat = [{}];
}

//===----------------------------------------------------------------------===//
// ConstArrayAttr
//===----------------------------------------------------------------------===//
Expand Down
39 changes: 37 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenBuiltinAArch64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2265,6 +2265,32 @@ static mlir::Value emitNeonRShiftImm(CIRGenFunction &cgf, mlir::Value shiftVec,
false /* right shift */);
}

/// Vectorize value, usually for argument of a neon SISD intrinsic call.
static void vecExtendIntValue(CIRGenFunction &cgf, cir::VectorType argVTy,
mlir::Value &arg, mlir::Location loc) {
CIRGenBuilderTy &builder = cgf.getBuilder();
cir::IntType eltTy = mlir::dyn_cast<cir::IntType>(argVTy.getEltType());
assert(mlir::isa<cir::IntType>(arg.getType()) && eltTy);
// The constant argument to an _n_ intrinsic always has Int32Ty, so truncate
// it before inserting.
arg = builder.createIntCast(arg, eltTy);
mlir::Value zero = builder.getConstInt(loc, cgf.SizeTy, 0);
mlir::Value poison = builder.create<cir::ConstantOp>(
loc, eltTy, builder.getAttr<cir::PoisonAttr>(eltTy));
arg = builder.create<cir::VecInsertOp>(
loc, builder.create<cir::VecSplatOp>(loc, argVTy, poison), arg, zero);
}

/// Reduce vector type value to scalar, usually for result of a
/// neon SISD intrinsic call
static mlir::Value vecReduceIntValue(CIRGenFunction &cgf, mlir::Value val,
mlir::Location loc) {
CIRGenBuilderTy &builder = cgf.getBuilder();
assert(mlir::isa<cir::VectorType>(val.getType()));
return builder.create<cir::VecExtractOp>(
loc, val, builder.getConstInt(loc, cgf.SizeTy, 0));
}

mlir::Value emitNeonCall(CIRGenBuilderTy &builder,
llvm::SmallVector<mlir::Type> argTypes,
llvm::SmallVectorImpl<mlir::Value> &args,
Expand Down Expand Up @@ -2853,8 +2879,17 @@ static mlir::Value emitCommonNeonSISDBuiltinExpr(
llvm_unreachable(" neon_vqmovnh_s16 NYI ");
case NEON::BI__builtin_neon_vqmovnh_u16:
llvm_unreachable(" neon_vqmovnh_u16 NYI ");
case NEON::BI__builtin_neon_vqmovns_s32:
llvm_unreachable(" neon_vqmovns_s32 NYI ");
case NEON::BI__builtin_neon_vqmovns_s32: {
mlir::Location loc = cgf.getLoc(expr->getExprLoc());
cir::VectorType argVecTy =
cir::VectorType::get(&(cgf.getMLIRContext()), cgf.SInt32Ty, 4);
cir::VectorType resVecTy =
cir::VectorType::get(&(cgf.getMLIRContext()), cgf.SInt16Ty, 4);
vecExtendIntValue(cgf, argVecTy, ops[0], loc);
mlir::Value result = emitNeonCall(builder, {argVecTy}, ops,
"aarch64.neon.sqxtn", resVecTy, loc);
return vecReduceIntValue(cgf, result, loc);
}
case NEON::BI__builtin_neon_vqmovns_u32:
llvm_unreachable(" neon_vqmovns_u32 NYI ");
case NEON::BI__builtin_neon_vqmovund_s64:
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,12 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
return op->emitOpError("undef expects non-void type");
}

if (isa<cir::PoisonAttr>(attrType)) {
if (!::mlir::isa<cir::VoidType>(opType))
return success();
return op->emitOpError("poison expects non-void type");
}

if (mlir::isa<cir::BoolAttr>(attrType)) {
if (!mlir::isa<cir::BoolType>(opType))
return op->emitOpError("result type (")
Expand Down
33 changes: 30 additions & 3 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,16 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::UndefAttr undefAttr,
loc, converter->convertType(undefAttr.getType()));
}

/// PoisonAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::PoisonAttr poisonAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
auto loc = parentOp->getLoc();
return rewriter.create<mlir::LLVM::PoisonOp>(
loc, converter->convertType(poisonAttr.getType()));
}

/// ConstStruct visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstStructAttr constStruct,
Expand Down Expand Up @@ -644,6 +654,8 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
return lowerCirAttrAsValue(parentOp, zeroAttr, rewriter, converter);
if (const auto undefAttr = mlir::dyn_cast<cir::UndefAttr>(attr))
return lowerCirAttrAsValue(parentOp, undefAttr, rewriter, converter);
if (const auto poisonAttr = mlir::dyn_cast<cir::PoisonAttr>(attr))
return lowerCirAttrAsValue(parentOp, poisonAttr, rewriter, converter);
if (const auto globalAttr = mlir::dyn_cast<cir::GlobalViewAttr>(attr))
return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter);
if (const auto vtableAttr = mlir::dyn_cast<cir::VTableAttr>(attr))
Expand Down Expand Up @@ -1555,6 +1567,14 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
mlir::ConversionPatternRewriter &rewriter) const {
mlir::Attribute attr = op.getValue();

// Regardless of the type, we should lower the constant of poison value
// into PoisonOp.
if (mlir::isa<cir::PoisonAttr>(attr)) {
rewriter.replaceOp(
op, lowerCirAttrAsValue(op, attr, rewriter, getTypeConverter()));
return mlir::success();
}

if (mlir::isa<mlir::IntegerType>(op.getType())) {
// Verified cir.const operations cannot actually be of these types, but the
// lowering pass may generate temporary cir.const operations with these
Expand Down Expand Up @@ -1695,6 +1715,7 @@ mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
mlir::Value result = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
assert(vecTy.getSize() == op.getElements().size() &&
"cir.vec.create op count doesn't match vector type elements count");

for (uint64_t i = 0; i < vecTy.getSize(); ++i) {
mlir::Value indexValue =
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
Expand Down Expand Up @@ -1745,15 +1766,21 @@ mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
assert(vecTy && "result type of cir.vec.splat op is not VectorType");
auto llvmTy = typeConverter->convertType(vecTy);
auto loc = op.getLoc();
mlir::Value undef = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
mlir::Value indexValue =
rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
mlir::Value elementValue = adaptor.getValue();
if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp())) {
// If the splat value is poison, then we can just use poison value
// for the entire vector.
rewriter.replaceOp(op, poison);
return mlir::success();
}
mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
loc, undef, elementValue, indexValue);
loc, poison, elementValue, indexValue);
SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
mlir::Value shuffled = rewriter.create<mlir::LLVM::ShuffleVectorOp>(
loc, oneElement, undef, zeroValues);
loc, oneElement, poison, zeroValues);
rewriter.replaceOp(op, shuffled);
return mlir::success();
}
Expand Down
27 changes: 19 additions & 8 deletions clang/test/CIR/CodeGen/AArch64/neon.c
Original file line number Diff line number Diff line change
Expand Up @@ -14611,14 +14611,25 @@ void test_vst1q_s64(int64_t *a, int64x2_t b) {
// return (int8_t)vqmovnh_s16(a);
// }

// NYI-LABEL: @test_vqmovns_s32(
// NYI: [[TMP0:%.*]] = insertelement <4 x i32> poison, i32 %a, i64 0
// NYI: [[VQMOVNS_S32_I:%.*]] = call <4 x i16> @llvm.aarch64.neon.sqxtn.v4i16(<4 x i32> [[TMP0]])
// NYI: [[TMP1:%.*]] = extractelement <4 x i16> [[VQMOVNS_S32_I]], i64 0
// NYI: ret i16 [[TMP1]]
// int16_t test_vqmovns_s32(int32_t a) {
// return (int16_t)vqmovns_s32(a);
// }
int16_t test_vqmovns_s32(int32_t a) {
return (int16_t)vqmovns_s32(a);

// CIR-LABEL: vqmovns_s32
// CIR: [[A:%.*]] = cir.load {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: [[VQMOVNS_S32_ZERO1:%.*]] = cir.const #cir.int<0> : !u64i
// CIR: [[POISON:%.*]] = cir.const #cir.poison : !s32i
// CIR: [[POISON_VEC:%.*]] = cir.vec.splat [[POISON]] : !s32i, !cir.vector<!s32i x 4>
// CIR: [[TMP0:%.*]] = cir.vec.insert [[A]], [[POISON_VEC]][[[VQMOVNS_S32_ZERO1]] : !u64i] : !cir.vector<!s32i x 4>
// CIR: [[VQMOVNS_S32_I:%.*]] = cir.llvm.intrinsic "aarch64.neon.sqxtn" [[TMP0]] : (!cir.vector<!s32i x 4>) -> !cir.vector<!s16i x 4>
// CIR: [[VQMOVNS_S32_ZERO2:%.*]] = cir.const #cir.int<0> : !u64i
// CIR: [[TMP1:%.*]] = cir.vec.extract [[VQMOVNS_S32_I]][[[VQMOVNS_S32_ZERO2]] : !u64i] : !cir.vector<!s16i x 4>

// LLVM: {{.*}}@test_vqmovns_s32(i32{{.*}}[[a:%.*]])
// LLVM: [[TMP0:%.*]] = insertelement <4 x i32> poison, i32 [[a]], i64 0
// LLVM: [[VQMOVNS_S32_I:%.*]] = call <4 x i16> @llvm.aarch64.neon.sqxtn.v4i16(<4 x i32> [[TMP0]])
// LLVM: [[TMP1:%.*]] = extractelement <4 x i16> [[VQMOVNS_S32_I]], i64 0
// LLVM: ret i16 [[TMP1]]
}

// NYI-LABEL: @test_vqmovnd_s64(
// NYI: [[VQMOVND_S64_I:%.*]] = call i32 @llvm.aarch64.neon.scalar.sqxtn.i32.i64(i64 %a)
Expand Down
6 changes: 6 additions & 0 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,12 @@ module {

// -----

module {
cir.global external @v = #cir.poison : !cir.void // expected-error {{poison expects non-void type}}
}

// -----

!s32i = !cir.int<s, 32>
cir.func @vec_op_size() {
%0 = cir.const #cir.int<1> : !s32i
Expand Down
2 changes: 2 additions & 0 deletions clang/test/CIR/Lowering/const.cir
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ module {
// CHECK: llvm.mlir.zero : !llvm.array<3 x i32>
%5 = cir.const #cir.undef : !cir.array<!s32i x 3>
// CHECK: llvm.mlir.undef : !llvm.array<3 x i32>
%6 = cir.const #cir.poison : !s32i
// CHECK: llvm.mlir.poison : i32
cir.return
}

Expand Down
Loading