Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions Source/Mockolate.SourceGenerators/Entities/Class.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@ namespace Mockolate.SourceGenerators.Entities;

internal record Class
{
private readonly IAssemblySymbol _sourceAssembly;

public Class(ITypeSymbol type,
IAssemblySymbol sourceAssembly,
List<Method>? alreadyDefinedMethods = null,
List<Property>? alreadyDefinedProperties = null,
List<Event>? alreadyDefinedEvents = null,
List<Method>? exceptMethods = null,
List<Property>? exceptProperties = null,
List<Event>? exceptEvents = null)
{
_sourceAssembly = sourceAssembly;
Namespace = type.ContainingNamespace.ToString();
DisplayString = type.ToDisplayString();
ClassName = GetTypeName(type);
Expand All @@ -34,48 +38,71 @@ public Class(ITypeSymbol type,
// Exclude getter/setter methods
.Where(x => x.AssociatedSymbol is null && !x.IsSealed)
.Where(x => IsInterface || x.IsVirtual || x.IsAbstract)
.Where(x => ShouldIncludeMember(x))
.Select(x => new Method(x, alreadyDefinedMethods))
.Distinct(), exceptMethods, Method.ContainingTypeIndependentEqualityComparer);
Methods = new EquatableArray<Method>(methods.ToArray());

List<Property> properties = ToListExcept(type.GetMembers().OfType<IPropertySymbol>()
.Where(x => !x.IsSealed)
.Where(x => IsInterface || x.IsVirtual || x.IsAbstract)
.Where(x => ShouldIncludeMember(x))
.Select(x => new Property(x, alreadyDefinedProperties))
.Distinct(), exceptProperties, Property.ContainingTypeIndependentEqualityComparer);
Properties = new EquatableArray<Property>(properties.ToArray());

List<Event> events = ToListExcept(type.GetMembers().OfType<IEventSymbol>()
.Where(x => !x.IsSealed)
.Where(x => IsInterface || x.IsVirtual || x.IsAbstract)
.Where(x => ShouldIncludeMember(x))
.Select(x => (x, x.Type as INamedTypeSymbol))
.Where(x => x.Item2?.DelegateInvokeMethod is not null)
.Select(x => new Event(x.x, x.Item2!.DelegateInvokeMethod!, alreadyDefinedEvents))
.Distinct(), exceptEvents, Event.ContainingTypeIndependentEqualityComparer);
Events = new EquatableArray<Event>(events.ToArray());

exceptProperties ??= type.GetMembers().OfType<IPropertySymbol>()
exceptProperties ??= new List<Property>();
exceptProperties.AddRange(type.GetMembers().OfType<IPropertySymbol>()
.Where(x => x.IsSealed)
.Select(x => new Property(x, null))
.Distinct()
.ToList();
exceptMethods ??= type.GetMembers().OfType<IMethodSymbol>()
.Distinct());

exceptMethods ??= new List<Method>();
exceptMethods.AddRange(type.GetMembers().OfType<IMethodSymbol>()
.Where(x => x.IsSealed)
.Select(x => new Method(x, null))
.Distinct()
.ToList();
exceptEvents ??= type.GetMembers().OfType<IEventSymbol>()
.Distinct());

exceptEvents ??= new List<Event>();
exceptEvents.AddRange(type.GetMembers().OfType<IEventSymbol>()
.Where(x => x.IsSealed)
.Select(x => (x, x.Type as INamedTypeSymbol))
.Where(x => x.Item2?.DelegateInvokeMethod is not null)
.Select(x => new Event(x.x, x.Item2!.DelegateInvokeMethod!, null))
.Distinct()
.ToList();
.Distinct());

InheritedTypes = new EquatableArray<Class>(
GetInheritedTypes(type).Select(t
=> new Class(t, methods, properties, events, exceptMethods, exceptProperties, exceptEvents))
=> new Class(t, sourceAssembly, methods, properties, events, exceptMethods, exceptProperties,
exceptEvents))
.ToArray());

bool ShouldIncludeMember(ISymbol member)
{
if (IsInterface || member.IsAbstract)
{
return true;
}

if ((member.DeclaredAccessibility == Accessibility.Internal ||
member.DeclaredAccessibility == Accessibility.ProtectedOrInternal) &&
!SymbolEqualityComparer.Default.Equals(member.ContainingAssembly, _sourceAssembly))
{
return false;
}

return true;
}
}

public EquatableArray<Method> Methods { get; }
Expand Down
62 changes: 57 additions & 5 deletions Source/Mockolate.SourceGenerators/Entities/Method.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,63 @@ public bool Equals(Method? x, Method? y)
private sealed class ContainingTypeIndependentMethodEqualityComparer : IEqualityComparer<Method>
{
public bool Equals(Method? x, Method? y)
=> (x is null && y is null) ||
(x is not null && y is not null &&
x.Name.Equals(y.Name) &&
x.Parameters.Count == y.Parameters.Count &&
x.Parameters.SequenceEqual(y.Parameters));
{
if (x is null && y is null)
{
return true;
}

if (x is null || y is null)
{
return false;
}

if (!x.Name.Equals(y.Name) || x.Parameters.Count != y.Parameters.Count)
{
return false;
}

// Compare parameters ignoring nullability annotations
MethodParameter[]? xParams = x.Parameters.AsArray();
MethodParameter[]? yParams = y.Parameters.AsArray();

if (xParams is null || yParams is null)
{
return xParams is null && yParams is null;
}

for (int i = 0; i < xParams.Length; i++)
{
MethodParameter xParam = xParams[i];
MethodParameter yParam = yParams[i];

if (xParam.RefKind != yParam.RefKind)
{
return false;
}

// Normalize type names by removing nullable annotation
string xTypeName = xParam.Type.Fullname;
string yTypeName = yParam.Type.Fullname;

if (xTypeName.EndsWith("?"))
{
xTypeName = xTypeName.Substring(0, xTypeName.Length - 1);
}

if (yTypeName.EndsWith("?"))
{
yTypeName = yTypeName.Substring(0, yTypeName.Length - 1);
}

if (!xTypeName.Equals(yTypeName))
{
return false;
}
}

return true;
}

public int GetHashCode(Method obj) => obj.Name.GetHashCode();
}
Expand Down
4 changes: 2 additions & 2 deletions Source/Mockolate.SourceGenerators/Entities/MockClass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ namespace Mockolate.SourceGenerators.Entities;

internal record MockClass : Class
{
public MockClass(ITypeSymbol[] types) : base(types[0])
public MockClass(ITypeSymbol[] types, IAssemblySymbol sourceAssembly) : base(types[0], sourceAssembly)
{
AdditionalImplementations = new EquatableArray<Class>(
types.Skip(1).Select(x => new Class(x)).ToArray());
types.Skip(1).Select(x => new Class(x, sourceAssembly)).ToArray());

if (!IsInterface && types[0] is INamedTypeSymbol namedTypeSymbol)
{
Expand Down
21 changes: 11 additions & 10 deletions Source/Mockolate.SourceGenerators/MockGeneratorHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ internal static IEnumerable<MockClass> ExtractMockOrMockFactoryCreateSyntaxOrDef
MemberAccessExpressionSyntax memberAccessExpressionSyntax =
(MemberAccessExpressionSyntax)invocationSyntax.Expression;
GenericNameSyntax genericNameSyntax = (GenericNameSyntax)memberAccessExpressionSyntax.Name;
IAssemblySymbol sourceAssembly = semanticModel.Compilation.Assembly;
if (semanticModel.GetSymbolInfo(syntaxNode).IsCreateInvocationOnMockOrMockFactory())
{
ITypeSymbol[] genericTypes = genericNameSyntax.TypeArgumentList.Arguments
Expand All @@ -50,17 +51,17 @@ internal static IEnumerable<MockClass> ExtractMockOrMockFactoryCreateSyntaxOrDef
// Ignore types from the global namespace, as they are not generated correctly.
genericTypes.All(x => !x.ContainingNamespace.IsGlobalNamespace))
{
yield return new MockClass(genericTypes);
yield return new MockClass(genericTypes, sourceAssembly);

foreach (MockClass? additionalMockClass in DiscoverMockableTypes(genericTypes))
foreach (MockClass? additionalMockClass in DiscoverMockableTypes(genericTypes, sourceAssembly))
{
yield return additionalMockClass;
}
}
}
}

private static IEnumerable<MockClass> DiscoverMockableTypes(IEnumerable<ITypeSymbol> initialTypes)
private static IEnumerable<MockClass> DiscoverMockableTypes(IEnumerable<ITypeSymbol> initialTypes, IAssemblySymbol sourceAssembly)
{
Queue<ITypeSymbol> typesToProcess = new(initialTypes);
HashSet<ITypeSymbol> processedTypes = new(SymbolEqualityComparer.Default);
Expand All @@ -70,28 +71,28 @@ private static IEnumerable<MockClass> DiscoverMockableTypes(IEnumerable<ITypeSym
ITypeSymbol currentType = typesToProcess.Dequeue();

foreach (ITypeSymbol propertyType in currentType.GetMembers()
.OfType<IPropertySymbol>()
.Select(p => p.Type))
.OfType<IPropertySymbol>()
.Select(p => p.Type))
{
if (propertyType.TypeKind == TypeKind.Interface &&
IsMockable(propertyType) &&
processedTypes.Add(propertyType))
{
yield return new MockClass([propertyType,]);
yield return new MockClass([propertyType,], sourceAssembly);
typesToProcess.Enqueue(propertyType);
}
}

foreach (ITypeSymbol methodType in currentType.GetMembers()
.OfType<IMethodSymbol>()
.Where(m => !m.ReturnsVoid)
.Select(m => m.ReturnType))
.OfType<IMethodSymbol>()
.Where(m => !m.ReturnsVoid)
.Select(m => m.ReturnType))
{
if (methodType.TypeKind == TypeKind.Interface &&
IsMockable(methodType) &&
processedTypes.Add(methodType))
{
yield return new MockClass([methodType,]);
yield return new MockClass([methodType,], sourceAssembly);
typesToProcess.Enqueue(methodType);
}
}
Expand Down
123 changes: 123 additions & 0 deletions Tests/Mockolate.SourceGenerators.Tests/Sources/ForMockTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -326,4 +326,127 @@ await That(result.Sources).ContainsKey("MockForIMyServiceListMyData.g.cs").Whose
.DoesNotContain("using MyCode.Services;").And
.DoesNotContain("using MyCode.Models;");
}

[Fact]
public async Task ShouldHandleComplexInheritanceWithSealedAndInternalMembers()
{
GeneratorResult result = Generator
.Run("""
using Mockolate;

namespace MyCode;

public class Program
{
public static void Main(string[] args)
{
_ = Mock.Create<MyDerivedClass>();
}
}

public class MyDerivedClass : MyMiddleClass
{
}

public class MyMiddleClass : MyBaseClass
{
public sealed override void SealedMethod() { }
protected internal override void ProtectedInternalMethod() { }
}

public class MyBaseClass
{
public virtual void SealedMethod() { }
public virtual void NormalMethod() { }
protected internal virtual void ProtectedInternalMethod() { }
internal virtual void InternalMethod() { }
protected virtual void ProtectedMethod() { }
}
""");

await That(result.Sources).ContainsKey("MockForMyDerivedClass.g.cs").WhoseValue
.DoesNotContain("override void SealedMethod").And
.Contains("ProtectedInternalMethod").And
.Contains("InternalMethod").And
.Contains("override void NormalMethod").And
.Contains("override void ProtectedMethod");
}

[Fact]
public async Task ShouldNotIncludeSealedOverrideSpecialMethods()
{
GeneratorResult result = Generator
.Run("""
using Mockolate;

namespace MyCode;

public class Program
{
public static void Main(string[] args)
{
_ = Mock.Create<MyDerivedClass>();
}
}

public class MyDerivedClass : MyMiddleClass
{
}

public class MyMiddleClass : MyBaseClass
{
public sealed override bool Equals(object? obj) => base.Equals(obj);
public sealed override int GetHashCode() => base.GetHashCode();
public sealed override string? ToString() => base.ToString();
}

public class MyBaseClass
{
public virtual void SomeMethod() { }
}
""");

await That(result.Sources).ContainsKey("MockForMyDerivedClass.g.cs").WhoseValue
.DoesNotContain("override bool Equals").And
.DoesNotContain("override int GetHashCode").And
.DoesNotContain("override string ToString");
}

[Fact]
public async Task ShouldNotIncludeSealedOverrideSpecialMethodsWithNonNullableParameters()
{
GeneratorResult result = Generator
.Run("""
using Mockolate;

namespace MyCode;

public class Program
{
public static void Main(string[] args)
{
_ = Mock.Create<MyDerivedClass>();
}
}

public class MyDerivedClass : MyMiddleClass
{
}

public class MyMiddleClass : MyBaseClass
{
public sealed override bool Equals(object obj) => base.Equals(obj);
}

public class MyBaseClass
{
public virtual void SomeMethod() { }
}
""");

// Even though MyMiddleClass.Equals has non-nullable parameter (object),
// it should still match and filter out object.Equals with nullable parameter (object?)
await That(result.Sources).ContainsKey("MockForMyDerivedClass.g.cs").WhoseValue
.DoesNotContain("override bool Equals");
}
}
Loading