Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 45 additions & 19 deletions src/coreclr/tools/aot/ILCompiler.Compiler/Compiler/ILScanner.cs
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,8 @@ private sealed class ScannedDevirtualizationManager : DevirtualizationManager
private HashSet<TypeDesc> _constructedTypes = new HashSet<TypeDesc>();
private HashSet<TypeDesc> _canonConstructedTypes = new HashSet<TypeDesc>();
private HashSet<TypeDesc> _unsealedTypes = new HashSet<TypeDesc>();
private Dictionary<TypeDesc, HashSet<TypeDesc>> _interfaceImplementators = new();
private HashSet<TypeDesc> _disqualifiedInterfaces = new();
private Dictionary<TypeDesc, HashSet<TypeDesc>> _implementators = new();
private HashSet<TypeDesc> _disqualifiedTypes = new();

public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<DependencyNodeCore<NodeFactory>> markedNodes)
{
Expand All @@ -437,8 +437,8 @@ public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<Depend
{
// If the interface is implemented through IDynamicInterfaceCastable, there might be
// no real upper bound on the number of actual classes implementing it.
if (CanAssumeWholeProgramViewOnInterfaceUse(factory, type, baseInterface))
_disqualifiedInterfaces.Add(baseInterface);
if (CanAssumeWholeProgramViewOnTypeUse(factory, type, baseInterface))
_disqualifiedTypes.Add(baseInterface);
}
}
}
Expand All @@ -457,14 +457,23 @@ public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<Depend

if (type is not MetadataType { IsAbstract: true })
{
// Record all interfaces this class implements to _interfaceImplementators
// Record all interfaces this class implements to _implementators
foreach (DefType baseInterface in type.RuntimeInterfaces)
{
if (CanAssumeWholeProgramViewOnInterfaceUse(factory, type, baseInterface))
if (CanAssumeWholeProgramViewOnTypeUse(factory, type, baseInterface))
{
RecordImplementation(baseInterface, type);
}
}

// Record all base types of this class
for (DefType @base = type.BaseType; @base != null; @base = @base.BaseType)
{
if (CanAssumeWholeProgramViewOnTypeUse(factory, type, @base))
{
RecordImplementation(@base, type);
}
}
}

if (type.IsCanonicalSubtype(CanonicalFormKind.Any))
Expand All @@ -474,7 +483,13 @@ public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<Depend
// due to MakeGenericType.
foreach (DefType baseInterface in type.RuntimeInterfaces)
{
_disqualifiedInterfaces.Add(baseInterface);
_disqualifiedTypes.Add(baseInterface);
}

// Same for base classes
for (DefType @base = type.BaseType; @base != null; @base = @base.BaseType)
{
_disqualifiedTypes.Add(@base);
}
}
else if (type.IsArray || type.GetTypeDefinition() == factory.ArrayOfTEnumeratorType)
Expand All @@ -490,7 +505,7 @@ public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<Depend
{
// Limit to the generic ones - ICollection<T>, etc.
if (baseInterface.HasInstantiation)
_disqualifiedInterfaces.Add(baseInterface);
_disqualifiedTypes.Add(baseInterface);
}
}
}
Expand All @@ -513,22 +528,23 @@ public ScannedDevirtualizationManager(NodeFactory factory, ImmutableArray<Depend
}
}

private static bool CanAssumeWholeProgramViewOnInterfaceUse(NodeFactory factory, TypeDesc implementingType, DefType interfaceType)
private static bool CanAssumeWholeProgramViewOnTypeUse(NodeFactory factory, TypeDesc implementingType, DefType baseType)
{
if (!interfaceType.HasInstantiation)
if (!baseType.HasInstantiation)
{
return true;
}

// If there are variance considerations, bail
if (VariantInterfaceMethodUseNode.IsVariantInterfaceImplementation(factory, implementingType, interfaceType))
if (baseType.IsInterface
&& VariantInterfaceMethodUseNode.IsVariantInterfaceImplementation(factory, implementingType, baseType))
{
return false;
}

if (interfaceType.IsCanonicalSubtype(CanonicalFormKind.Any)
|| interfaceType.ConvertToCanonForm(CanonicalFormKind.Specific) != interfaceType
|| interfaceType.Context.SupportsUniversalCanon)
if (baseType.IsCanonicalSubtype(CanonicalFormKind.Any)
|| baseType.ConvertToCanonForm(CanonicalFormKind.Specific) != baseType
|| baseType.Context.SupportsUniversalCanon)
{
// If the interface has a canonical form, we might not have a full view of all implementers.
// E.g. if we have:
Expand All @@ -549,10 +565,10 @@ private void RecordImplementation(TypeDesc type, TypeDesc implType)
Debug.Assert(!implType.IsInterface);

HashSet<TypeDesc> implList;
if (!_interfaceImplementators.TryGetValue(type, out implList))
if (!_implementators.TryGetValue(type, out implList))
{
implList = new();
_interfaceImplementators[type] = implList;
_implementators[type] = implList;
}
implList.Add(implType);
}
Expand Down Expand Up @@ -604,13 +620,23 @@ protected override MethodDesc ResolveVirtualMethod(MethodDesc declMethod, DefTyp

public override TypeDesc[] GetImplementingClasses(TypeDesc type)
{
if (_disqualifiedInterfaces.Contains(type))
if (_disqualifiedTypes.Contains(type))
return null;

if (type.IsInterface && _interfaceImplementators.TryGetValue(type, out HashSet<TypeDesc> implementations))
if (_implementators.TryGetValue(type, out HashSet<TypeDesc> implementations))
{
var types = new TypeDesc[implementations.Count];
TypeDesc[] types;
int index = 0;
if (!type.IsInterface && type is not MetadataType { IsAbstract: true })
{
types = new TypeDesc[implementations.Count + 1];
types[index++] = type;
}
else
{
types = new TypeDesc[implementations.Count];
}

foreach (TypeDesc implementation in implementations)
{
types[index++] = implementation;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2196,12 +2196,6 @@ private int getExactClasses(CORINFO_CLASS_STRUCT_* baseType, int maxExactClasses
return 1;
}

if (!type.IsInterface)
{
// TODO: handle classes
return 0;
}

TypeDesc[] implClasses = _compilation.GetImplementingClasses(type);
if (implClasses == null || implClasses.Length > maxExactClasses)
{
Expand Down