@@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
707707 UnresolvedMaterializationRewrite (
708708 ConversionPatternRewriterImpl &rewriterImpl,
709709 UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr ,
710- MaterializationKind kind = MaterializationKind::Target,
711- Type origOutputType = nullptr )
710+ MaterializationKind kind = MaterializationKind::Target)
712711 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
713- converterAndKind (converter, kind), origOutputType(origOutputType) {}
712+ converterAndKind (converter, kind) {}
714713
715714 static bool classof (const IRRewrite *rewrite) {
716715 return rewrite->getKind () == Kind::UnresolvedMaterialization;
@@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
734733 return converterAndKind.getInt ();
735734 }
736735
737- // / Return the original illegal output type of the input values.
738- Type getOrigOutputType () const { return origOutputType; }
739-
740736private:
741737 // / The corresponding type converter to use when resolving this
742738 // / materialization, and the kind of this materialization.
743739 llvm::PointerIntPair<const TypeConverter *, 1 , MaterializationKind>
744740 converterAndKind;
745-
746- // / The original output type. This is only used for argument conversions.
747- Type origOutputType;
748741};
749742} // namespace
750743
@@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
860853 Block *insertBlock,
861854 Block::iterator insertPt, Location loc,
862855 ValueRange inputs, Type outputType,
863- Type origOutputType,
864856 const TypeConverter *converter);
865857
866858 Value buildUnresolvedArgumentMaterialization (Block *block, Location loc,
867859 ValueRange inputs,
868- Type origOutputType,
869860 Type outputType,
870861 const TypeConverter *converter);
871862
@@ -1388,20 +1379,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13881379 if (replArgs.size () == 1 &&
13891380 (!converter || replArgs[0 ].getType () == origArg.getType ())) {
13901381 newArg = replArgs.front ();
1382+ mapping.map (origArg, newArg);
13911383 } else {
1392- Type origOutputType = origArg.getType ();
1393-
1394- // Legalize the argument output type.
1395- Type outputType = origOutputType;
1396- if (Type legalOutputType = converter->convertType (outputType))
1397- outputType = legalOutputType;
1398-
1399- newArg = buildUnresolvedArgumentMaterialization (
1400- newBlock, origArg.getLoc (), replArgs, origOutputType, outputType,
1401- converter);
1384+ // Build argument materialization: new block arguments -> old block
1385+ // argument type.
1386+ Value argMat = buildUnresolvedArgumentMaterialization (
1387+ newBlock, origArg.getLoc (), replArgs, origArg.getType (), converter);
1388+ mapping.map (origArg, argMat);
1389+
1390+ // Build target materialization: old block argument type -> legal type.
1391+ // Note: This function returns an "empty" type if no valid conversion to
1392+ // a legal type exists. In that case, we continue the conversion with the
1393+ // original block argument type.
1394+ Type legalOutputType = converter->convertType (origArg.getType ());
1395+ if (legalOutputType && legalOutputType != origArg.getType ()) {
1396+ newArg = buildUnresolvedTargetMaterialization (
1397+ origArg.getLoc (), argMat, legalOutputType, converter);
1398+ mapping.map (argMat, newArg);
1399+ } else {
1400+ newArg = argMat;
1401+ }
14021402 }
14031403
1404- mapping.map (origArg, newArg);
14051404 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
14061405 argInfo[i] = ConvertedArgInfo (inputMap->inputNo , inputMap->size , newArg);
14071406 }
@@ -1424,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14241423// / of input operands.
14251424Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization (
14261425 MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1427- Location loc, ValueRange inputs, Type outputType, Type origOutputType,
1426+ Location loc, ValueRange inputs, Type outputType,
14281427 const TypeConverter *converter) {
14291428 // Avoid materializing an unnecessary cast.
14301429 if (inputs.size () == 1 && inputs.front ().getType () == outputType)
@@ -1435,16 +1434,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
14351434 OpBuilder builder (insertBlock, insertPt);
14361435 auto convertOp =
14371436 builder.create <UnrealizedConversionCastOp>(loc, outputType, inputs);
1438- appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1439- origOutputType);
1437+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
14401438 return convertOp.getResult (0 );
14411439}
14421440Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization (
1443- Block *block, Location loc, ValueRange inputs, Type origOutputType ,
1444- Type outputType, const TypeConverter *converter) {
1441+ Block *block, Location loc, ValueRange inputs, Type outputType ,
1442+ const TypeConverter *converter) {
14451443 return buildUnresolvedMaterialization (MaterializationKind::Argument, block,
14461444 block->begin (), loc, inputs, outputType,
1447- origOutputType, converter);
1445+ converter);
14481446}
14491447Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization (
14501448 Location loc, Value input, Type outputType,
@@ -1456,7 +1454,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
14561454
14571455 return buildUnresolvedMaterialization (MaterializationKind::Target,
14581456 insertBlock, insertPt, loc, input,
1459- outputType, outputType, converter);
1457+ outputType, converter);
14601458}
14611459
14621460// ===----------------------------------------------------------------------===//
@@ -2672,19 +2670,28 @@ static void computeNecessaryMaterializations(
26722670 ConversionPatternRewriterImpl &rewriterImpl,
26732671 DenseMap<Value, SmallVector<Value>> &inverseMapping,
26742672 SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
2673+ // Helper function to check if the given value or a not yet materialized
2674+ // replacement of the given value is live.
2675+ // Note: `inverseMapping` maps from replaced values to original values.
26752676 auto isLive = [&](Value value) {
26762677 auto findFn = [&](Operation *user) {
26772678 auto matIt = materializationOps.find (user);
26782679 if (matIt != materializationOps.end ())
26792680 return !necessaryMaterializations.count (matIt->second );
26802681 return rewriterImpl.isOpIgnored (user);
26812682 };
2682- // This value may be replacing another value that has a live user.
2683- for (Value inv : inverseMapping.lookup (value))
2684- if (llvm::find_if_not (inv.getUsers (), findFn) != inv.user_end ())
2683+ // A worklist is needed because a value may have gone through a chain of
2684+ // replacements and each of the replaced values may have live users.
2685+ SmallVector<Value> worklist;
2686+ worklist.push_back (value);
2687+ while (!worklist.empty ()) {
2688+ Value next = worklist.pop_back_val ();
2689+ if (llvm::find_if_not (next.getUsers (), findFn) != next.user_end ())
26852690 return true ;
2686- // Or have live users itself.
2687- return llvm::find_if_not (value.getUsers (), findFn) != value.user_end ();
2691+ // This value may be replacing another value that has a live user.
2692+ llvm::append_range (worklist, inverseMapping.lookup (next));
2693+ }
2694+ return false ;
26882695 };
26892696
26902697 llvm::unique_function<Value (Value, Value, Type)> lookupRemappedValue =
@@ -2844,18 +2851,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
28442851 switch (mat.getMaterializationKind ()) {
28452852 case MaterializationKind::Argument:
28462853 // Try to materialize an argument conversion.
2847- // FIXME: The current argument materialization hook expects the original
2848- // output type, even though it doesn't use that as the actual output type
2849- // of the generated IR. The output type is just used as an indicator of
2850- // the type of materialization to do. This behavior is really awkward in
2851- // that it diverges from the behavior of the other hooks, and can be
2852- // easily misunderstood. We should clean up the argument hooks to better
2853- // represent the desired invariants we actually care about.
28542854 newMaterialization = converter->materializeArgumentConversion (
2855- rewriter, op->getLoc (), mat. getOrigOutputType () , inputOperands);
2855+ rewriter, op->getLoc (), outputType , inputOperands);
28562856 if (newMaterialization)
28572857 break ;
2858-
28592858 // If an argument materialization failed, fallback to trying a target
28602859 // materialization.
28612860 [[fallthrough]];
@@ -2865,6 +2864,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
28652864 break ;
28662865 }
28672866 if (newMaterialization) {
2867+ assert (newMaterialization.getType () == outputType &&
2868+ " materialization callback produced value of incorrect type" );
28682869 replaceMaterialization (rewriterImpl, opResult, newMaterialization,
28692870 inverseMapping);
28702871 return success ();
0 commit comments