@@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
707707 UnresolvedMaterializationRewrite (
708708 ConversionPatternRewriterImpl &rewriterImpl,
709709 UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr ,
710- MaterializationKind kind = MaterializationKind::Target,
711- Type origOutputType = nullptr )
710+ MaterializationKind kind = MaterializationKind::Target)
712711 : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
713- converterAndKind (converter, kind), origOutputType(origOutputType) {}
712+ converterAndKind (converter, kind) {}
714713
715714 static bool classof (const IRRewrite *rewrite) {
716715 return rewrite->getKind () == Kind::UnresolvedMaterialization;
@@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
734733 return converterAndKind.getInt ();
735734 }
736735
737- // / Return the original illegal output type of the input values.
738- Type getOrigOutputType () const { return origOutputType; }
739-
740736private:
741737 // / The corresponding type converter to use when resolving this
742738 // / materialization, and the kind of this materialization.
743739 llvm::PointerIntPair<const TypeConverter *, 1 , MaterializationKind>
744740 converterAndKind;
745-
746- // / The original output type. This is only used for argument conversions.
747- Type origOutputType;
748741};
749742} // namespace
750743
@@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
860853 Block *insertBlock,
861854 Block::iterator insertPt, Location loc,
862855 ValueRange inputs, Type outputType,
863- Type origOutputType,
864856 const TypeConverter *converter);
865857
866858 Value buildUnresolvedArgumentMaterialization (Block *block, Location loc,
867859 ValueRange inputs,
868- Type origOutputType,
869860 Type outputType,
870861 const TypeConverter *converter);
871862
@@ -1388,20 +1379,27 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13881379 if (replArgs.size () == 1 &&
13891380 (!converter || replArgs[0 ].getType () == origArg.getType ())) {
13901381 newArg = replArgs.front ();
1382+ mapping.map (origArg, newArg);
13911383 } else {
1392- Type origOutputType = origArg.getType ();
1393-
1394- // Legalize the argument output type.
1395- Type outputType = origOutputType;
1396- if (Type legalOutputType = converter->convertType (outputType))
1397- outputType = legalOutputType;
1398-
1399- newArg = buildUnresolvedArgumentMaterialization (
1400- newBlock, origArg.getLoc (), replArgs, origOutputType, outputType,
1401- converter);
1384+ // Build argument materialization: new block arguments -> old block
1385+ // argument type.
1386+ Value argMat = buildUnresolvedArgumentMaterialization (
1387+ newBlock, origArg.getLoc (), replArgs, origArg.getType (), converter);
1388+ mapping.map (origArg, argMat);
1389+
1390+ // Build target materialization: old block argument type -> legal type.
1391+ // Note: This function returns an "empty" type if no valid conversion to
1392+ // a legal type exists. In that case, we continue the conversion with the
1393+ // original block argument type.
1394+ if (Type legalOutputType = converter->convertType (origArg.getType ())) {
1395+ newArg = buildUnresolvedTargetMaterialization (
1396+ origArg.getLoc (), argMat, legalOutputType, converter);
1397+ mapping.map (argMat, newArg);
1398+ } else {
1399+ newArg = argMat;
1400+ }
14021401 }
14031402
1404- mapping.map (origArg, newArg);
14051403 appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
14061404 argInfo[i] = ConvertedArgInfo (inputMap->inputNo , inputMap->size , newArg);
14071405 }
@@ -1424,7 +1422,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
14241422// / of input operands.
14251423Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization (
14261424 MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
1427- Location loc, ValueRange inputs, Type outputType, Type origOutputType,
1425+ Location loc, ValueRange inputs, Type outputType,
14281426 const TypeConverter *converter) {
14291427 // Avoid materializing an unnecessary cast.
14301428 if (inputs.size () == 1 && inputs.front ().getType () == outputType)
@@ -1435,16 +1433,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
14351433 OpBuilder builder (insertBlock, insertPt);
14361434 auto convertOp =
14371435 builder.create <UnrealizedConversionCastOp>(loc, outputType, inputs);
1438- appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1439- origOutputType);
1436+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
14401437 return convertOp.getResult (0 );
14411438}
14421439Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization (
1443- Block *block, Location loc, ValueRange inputs, Type origOutputType ,
1444- Type outputType, const TypeConverter *converter) {
1440+ Block *block, Location loc, ValueRange inputs, Type outputType ,
1441+ const TypeConverter *converter) {
14451442 return buildUnresolvedMaterialization (MaterializationKind::Argument, block,
14461443 block->begin (), loc, inputs, outputType,
1447- origOutputType, converter);
1444+ converter);
14481445}
14491446Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization (
14501447 Location loc, Value input, Type outputType,
@@ -1456,7 +1453,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
14561453
14571454 return buildUnresolvedMaterialization (MaterializationKind::Target,
14581455 insertBlock, insertPt, loc, input,
1459- outputType, outputType, converter);
1456+ outputType, converter);
14601457}
14611458
14621459// ===----------------------------------------------------------------------===//
@@ -2672,19 +2669,28 @@ static void computeNecessaryMaterializations(
26722669 ConversionPatternRewriterImpl &rewriterImpl,
26732670 DenseMap<Value, SmallVector<Value>> &inverseMapping,
26742671 SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
2672+ // Helper function to check if the given value or a not yet materialized
2673+ // replacement of the given value is live.
2674+ // Note: `inverseMapping` maps from replaced values to original values.
26752675 auto isLive = [&](Value value) {
26762676 auto findFn = [&](Operation *user) {
26772677 auto matIt = materializationOps.find (user);
26782678 if (matIt != materializationOps.end ())
26792679 return !necessaryMaterializations.count (matIt->second );
26802680 return rewriterImpl.isOpIgnored (user);
26812681 };
2682- // This value may be replacing another value that has a live user.
2683- for (Value inv : inverseMapping.lookup (value))
2684- if (llvm::find_if_not (inv.getUsers (), findFn) != inv.user_end ())
2682+ // A worklist is needed because a value may have gone through a chain of
2683+ // replacements and each of the replaced values may have live users.
2684+ SmallVector<Value> worklist;
2685+ worklist.push_back (value);
2686+ while (!worklist.empty ()) {
2687+ Value next = worklist.pop_back_val ();
2688+ if (llvm::find_if_not (next.getUsers (), findFn) != next.user_end ())
26852689 return true ;
2686- // Or have live users itself.
2687- return llvm::find_if_not (value.getUsers (), findFn) != value.user_end ();
2690+ // This value may be replacing another value that has a live user.
2691+ llvm::append_range (worklist, inverseMapping.lookup (next));
2692+ }
2693+ return false ;
26882694 };
26892695
26902696 llvm::unique_function<Value (Value, Value, Type)> lookupRemappedValue =
@@ -2844,18 +2850,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
28442850 switch (mat.getMaterializationKind ()) {
28452851 case MaterializationKind::Argument:
28462852 // Try to materialize an argument conversion.
2847- // FIXME: The current argument materialization hook expects the original
2848- // output type, even though it doesn't use that as the actual output type
2849- // of the generated IR. The output type is just used as an indicator of
2850- // the type of materialization to do. This behavior is really awkward in
2851- // that it diverges from the behavior of the other hooks, and can be
2852- // easily misunderstood. We should clean up the argument hooks to better
2853- // represent the desired invariants we actually care about.
28542853 newMaterialization = converter->materializeArgumentConversion (
2855- rewriter, op->getLoc (), mat. getOrigOutputType () , inputOperands);
2854+ rewriter, op->getLoc (), outputType , inputOperands);
28562855 if (newMaterialization)
28572856 break ;
2858-
28592857 // If an argument materialization failed, fallback to trying a target
28602858 // materialization.
28612859 [[fallthrough]];
@@ -2865,6 +2863,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
28652863 break ;
28662864 }
28672865 if (newMaterialization) {
2866+ assert (newMaterialization.getType () == outputType &&
2867+ " materialization callback produced value of incorrect type" );
28682868 replaceMaterialization (rewriterImpl, opResult, newMaterialization,
28692869 inverseMapping);
28702870 return success ();
0 commit comments