@@ -425,6 +425,16 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::UndefAttr undefAttr,
425425 loc, converter->convertType (undefAttr.getType ()));
426426}
427427
428+ // / PoisonAttr visitor.
429+ static mlir::Value
430+ lowerCirAttrAsValue (mlir::Operation *parentOp, cir::PoisonAttr poisonAttr,
431+ mlir::ConversionPatternRewriter &rewriter,
432+ const mlir::TypeConverter *converter) {
433+ auto loc = parentOp->getLoc ();
434+ return rewriter.create <mlir::LLVM::PoisonOp>(
435+ loc, converter->convertType (poisonAttr.getType ()));
436+ }
437+
428438// / ConstStruct visitor.
429439static mlir::Value
430440lowerCirAttrAsValue (mlir::Operation *parentOp, cir::ConstStructAttr constStruct,
@@ -644,6 +654,8 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
644654 return lowerCirAttrAsValue (parentOp, zeroAttr, rewriter, converter);
645655 if (const auto undefAttr = mlir::dyn_cast<cir::UndefAttr>(attr))
646656 return lowerCirAttrAsValue (parentOp, undefAttr, rewriter, converter);
657+ if (const auto poisonAttr = mlir::dyn_cast<cir::PoisonAttr>(attr))
658+ return lowerCirAttrAsValue (parentOp, poisonAttr, rewriter, converter);
647659 if (const auto globalAttr = mlir::dyn_cast<cir::GlobalViewAttr>(attr))
648660 return lowerCirAttrAsValue (parentOp, globalAttr, rewriter, converter);
649661 if (const auto vtableAttr = mlir::dyn_cast<cir::VTableAttr>(attr))
@@ -1555,6 +1567,14 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
15551567 mlir::ConversionPatternRewriter &rewriter) const {
15561568 mlir::Attribute attr = op.getValue ();
15571569
1570+ // Regardless of the type, we should lower the constant of poison value
1571+ // into PoisonOp.
1572+ if (mlir::isa<cir::PoisonAttr>(attr)) {
1573+ rewriter.replaceOp (
1574+ op, lowerCirAttrAsValue (op, attr, rewriter, getTypeConverter ()));
1575+ return mlir::success ();
1576+ }
1577+
15581578 if (mlir::isa<mlir::IntegerType>(op.getType ())) {
15591579 // Verified cir.const operations cannot actually be of these types, but the
15601580 // lowering pass may generate temporary cir.const operations with these
@@ -1695,6 +1715,7 @@ mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
16951715 mlir::Value result = rewriter.create <mlir::LLVM::PoisonOp>(loc, llvmTy);
16961716 assert (vecTy.getSize () == op.getElements ().size () &&
16971717 " cir.vec.create op count doesn't match vector type elements count" );
1718+
16981719 for (uint64_t i = 0 ; i < vecTy.getSize (); ++i) {
16991720 mlir::Value indexValue =
17001721 rewriter.create <mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type (), i);
@@ -1745,15 +1766,21 @@ mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
17451766 assert (vecTy && " result type of cir.vec.splat op is not VectorType" );
17461767 auto llvmTy = typeConverter->convertType (vecTy);
17471768 auto loc = op.getLoc ();
1748- mlir::Value undef = rewriter.create <mlir::LLVM::PoisonOp>(loc, llvmTy);
1769+ mlir::Value poison = rewriter.create <mlir::LLVM::PoisonOp>(loc, llvmTy);
17491770 mlir::Value indexValue =
17501771 rewriter.create <mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type (), 0 );
17511772 mlir::Value elementValue = adaptor.getValue ();
1773+ if (mlir::isa<mlir::LLVM::PoisonOp>(elementValue.getDefiningOp ())) {
1774+ // If the splat value is poison, then we can just use poison value
1775+ // for the entire vector.
1776+ rewriter.replaceOp (op, poison);
1777+ return mlir::success ();
1778+ }
17521779 mlir::Value oneElement = rewriter.create <mlir::LLVM::InsertElementOp>(
1753- loc, undef , elementValue, indexValue);
1780+ loc, poison , elementValue, indexValue);
17541781 SmallVector<int32_t > zeroValues (vecTy.getSize (), 0 );
17551782 mlir::Value shuffled = rewriter.create <mlir::LLVM::ShuffleVectorOp>(
1756- loc, oneElement, undef , zeroValues);
1783+ loc, oneElement, poison , zeroValues);
17571784 rewriter.replaceOp (op, shuffled);
17581785 return mlir::success ();
17591786}
0 commit comments