Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 189 additions & 0 deletions src/Analyzers/MSTest.Analyzers/UseProperAssertMethodsAnalyzer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,15 @@ private enum CountCheckStatus
HasCount,
}

private enum LinqPredicateCheckStatus
{
Unknown,
Any,
Count,
WhereAny,
WhereCount,
}

internal const string ProperAssertMethodNameKey = nameof(ProperAssertMethodNameKey);

/// <summary>
Expand Down Expand Up @@ -519,6 +528,74 @@ private static ComparisonCheckStatus RecognizeComparisonCheck(
return ComparisonCheckStatus.Unknown;
}

private static LinqPredicateCheckStatus RecognizeLinqPredicateCheck(
IOperation operation,
out SyntaxNode? collectionExpression,
out SyntaxNode? predicateExpression,
out IOperation? countOperation)
{
collectionExpression = null;
predicateExpression = null;
countOperation = null;

// Check for enumerable.Any(predicate)
// Extension methods appear as: Instance=null, Arguments[0]=collection, Arguments[1]=predicate
if (operation is IInvocationOperation anyInvocation &&
anyInvocation.TargetMethod.Name == "Any" &&
anyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
anyInvocation.Arguments.Length == 2)
{
collectionExpression = anyInvocation.Arguments[0].Value.Syntax;
predicateExpression = anyInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.Any;
}

// Check for enumerable.Count(predicate)
if (operation is IInvocationOperation countInvocation &&
countInvocation.TargetMethod.Name == "Count" &&
countInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
countInvocation.Arguments.Length == 2)
{
collectionExpression = countInvocation.Arguments[0].Value.Syntax;
predicateExpression = countInvocation.Arguments[1].Value.Syntax;
countOperation = operation;
return LinqPredicateCheckStatus.Count;
}

// Check for enumerable.Where(predicate).Any()
if (operation is IInvocationOperation whereAnyInvocation &&
whereAnyInvocation.TargetMethod.Name == "Any" &&
whereAnyInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereAnyInvocation.Arguments.Length == 1 &&
whereAnyInvocation.Arguments[0].Value is IInvocationOperation whereInvocation &&
whereInvocation.TargetMethod.Name == "Where" &&
whereInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation.Arguments.Length == 2)
{
collectionExpression = whereInvocation.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation.Arguments[1].Value.Syntax;
return LinqPredicateCheckStatus.WhereAny;
}

// Check for enumerable.Where(predicate).Count()
if (operation is IInvocationOperation whereCountInvocation &&
whereCountInvocation.TargetMethod.Name == "Count" &&
whereCountInvocation.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereCountInvocation.Arguments.Length == 1 &&
whereCountInvocation.Arguments[0].Value is IInvocationOperation whereInvocation2 &&
whereInvocation2.TargetMethod.Name == "Where" &&
whereInvocation2.TargetMethod.ContainingType?.ToDisplayString() == "System.Linq.Enumerable" &&
whereInvocation2.Arguments.Length == 2)
{
collectionExpression = whereInvocation2.Arguments[0].Value.Syntax;
predicateExpression = whereInvocation2.Arguments[1].Value.Syntax;
countOperation = operation;
return LinqPredicateCheckStatus.WhereCount;
}

return LinqPredicateCheckStatus.Unknown;
}

private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext context, IOperation conditionArgument, bool isTrueInvocation, INamedTypeSymbol objectTypeSymbol)
{
RoslynDebug.Assert(context.Operation is IInvocationOperation, "Expected IInvocationOperation.");
Expand Down Expand Up @@ -555,6 +632,36 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co
return;
}

// Check for LINQ predicate patterns that suggest Contains/DoesNotContain
LinqPredicateCheckStatus linqStatus = RecognizeLinqPredicateCheck(
conditionArgument,
out SyntaxNode? linqCollectionExpr,
out SyntaxNode? predicateExpr,
out _);

if (linqStatus != LinqPredicateCheckStatus.Unknown && linqCollectionExpr != null && predicateExpr != null)
{
// For Any() and Where().Any() patterns
if (linqStatus is LinqPredicateCheckStatus.Any or LinqPredicateCheckStatus.WhereAny)
{
string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, CodeFixModeAddArgument);
context.ReportDiagnostic(context.Operation.CreateDiagnostic(
Rule,
additionalLocations: ImmutableArray.Create(
conditionArgument.Syntax.GetLocation(),
predicateExpr.GetLocation(),
linqCollectionExpr.GetLocation()),
properties: properties.ToImmutable(),
properAssertMethod,
isTrueInvocation ? "IsTrue" : "IsFalse"));
return;
}
}

// Check for string method patterns: myString.StartsWith/EndsWith/Contains(...)
StringMethodCheckStatus stringMethodStatus = RecognizeStringMethodCheck(conditionArgument, out SyntaxNode? stringExpr, out SyntaxNode? substringExpr);
if (stringMethodStatus != StringMethodCheckStatus.Unknown)
Expand Down Expand Up @@ -624,6 +731,54 @@ private static void AnalyzeIsTrueOrIsFalseInvocation(OperationAnalysisContext co
return;
}

// Special-case: enumerable.Count(predicate) > 0 → Assert.Contains(predicate, enumerable)
if (conditionArgument is IBinaryOperation binaryOp &&
binaryOp.OperatorKind == BinaryOperatorKind.GreaterThan)
{
if (binaryOp.LeftOperand is IInvocationOperation countInvocation &&
binaryOp.RightOperand.ConstantValue.HasValue &&
binaryOp.RightOperand.ConstantValue.Value is int intValue &&
intValue == 0 &&
countInvocation.TargetMethod.Name == "Count")
{
SyntaxNode? countCollectionExpr = null;
SyntaxNode? countPredicateExpr = null;

if (countInvocation.Instance != null && countInvocation.Arguments.Length == 1)
{
countCollectionExpr = countInvocation.Instance.Syntax;
countPredicateExpr = countInvocation.Arguments[0].Value.Syntax;
}
else if (countInvocation.Instance == null && countInvocation.Arguments.Length == 2)
{
countCollectionExpr = countInvocation.Arguments[0].Value.Syntax;
countPredicateExpr = countInvocation.Arguments[1].Value.Syntax;
}

if (countCollectionExpr != null && countPredicateExpr != null)
{
string properAssertMethod = isTrueInvocation ? "Contains" : "DoesNotContain";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, CodeFixModeAddArgument);

context.ReportDiagnostic(
context.Operation.CreateDiagnostic(
Rule,
additionalLocations: ImmutableArray.Create(
conditionArgument.Syntax.GetLocation(),
countPredicateExpr.GetLocation(),
countCollectionExpr.GetLocation()),
properties: properties.ToImmutable(),
properAssertMethod,
isTrueInvocation ? "IsTrue" : "IsFalse"));

return;
}
}
}

// Check for comparison patterns: a > b, a >= b, a < b, a <= b
ComparisonCheckStatus comparisonStatus = RecognizeComparisonCheck(conditionArgument, out SyntaxNode? leftExpr, out SyntaxNode? rightExpr);
if (comparisonStatus != ComparisonCheckStatus.Unknown)
Expand Down Expand Up @@ -722,6 +877,40 @@ private static void AnalyzeAreEqualOrAreNotEqualInvocation(OperationAnalysisCont
{
if (TryGetSecondArgumentValue((IInvocationOperation)context.Operation, out IOperation? actualArgumentValue))
{
// Check for LINQ predicate patterns that suggest ContainsSingle
LinqPredicateCheckStatus linqStatus2 = RecognizeLinqPredicateCheck(
actualArgumentValue!,
out SyntaxNode? linqCollectionExpr2,
out SyntaxNode? predicateExpr2,
out _);

if (isAreEqualInvocation &&
linqStatus2 is LinqPredicateCheckStatus.Count or LinqPredicateCheckStatus.WhereCount &&
linqCollectionExpr2 != null &&
predicateExpr2 != null &&
expectedArgument.ConstantValue.HasValue &&
expectedArgument.ConstantValue.Value is int expectedCountValue &&
expectedCountValue == 1)
{
// We have Assert.AreEqual(1, enumerable.Count(predicate))
// We want Assert.ContainsSingle(predicate, enumerable)
string properAssertMethod = "ContainsSingle";

ImmutableDictionary<string, string?>.Builder properties = ImmutableDictionary.CreateBuilder<string, string?>();
properties.Add(ProperAssertMethodNameKey, properAssertMethod);
properties.Add(CodeFixModeKey, CodeFixModeAddArgument);
context.ReportDiagnostic(context.Operation.CreateDiagnostic(
Rule,
additionalLocations: ImmutableArray.Create(
actualArgumentValue.Syntax.GetLocation(),
predicateExpr2.GetLocation(),
linqCollectionExpr2.GetLocation()),
properties: properties.ToImmutable(),
properAssertMethod,
"AreEqual"));
return;
}

// Check if we're comparing a count/length property
CountCheckStatus countStatus = RecognizeCountCheck(
expectedArgument,
Expand Down
3 changes: 2 additions & 1 deletion test/IntegrationTests/MSTest.IntegrationTests/OutputTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ private static void ValidateOutputIsNotMixed(IEnumerable<TestResult> testResults
Assert.Contains(methodName, message.Text);
Assert.Contains("TestInitialize", message.Text);
Assert.Contains("TestCleanup", message.Text);
Assert.IsFalse(shouldNotContain.Any(message.Text.Contains));
// Assert.IsFalse(shouldNotContain.Any(message.Text.Contains));
Assert.DoesNotContain(message.Text.Contains, shouldNotContain);
}

private static void ValidateInitializeAndCleanup(IEnumerable<TestResult> testResults, Func<TestResultMessage, bool> messageFilter)
Expand Down
Loading