diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/AotConversionHelper.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/AotConversionHelper.cs deleted file mode 100644 index c0b7e1e445..0000000000 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/AotConversionHelper.cs +++ /dev/null @@ -1,102 +0,0 @@ -using Microsoft.CodeAnalysis; -using TUnit.Core.SourceGenerator.Extensions; - -namespace TUnit.Core.SourceGenerator.CodeGenerators.Helpers; - -/// -/// Helper for generating AOT-compatible type conversions -/// -public static class AotConversionHelper -{ - /// - /// Generates an AOT-compatible conversion expression - /// - /// The source type - /// The target type - /// The expression to convert - /// An AOT-compatible conversion expression or null if no special handling is needed - public static string? GenerateAotConversion(ITypeSymbol sourceType, ITypeSymbol targetType, string sourceExpression) - { - // Check if direct assignment is possible - if (sourceType.Equals(targetType, SymbolEqualityComparer.Default)) - { - return sourceExpression; - } - - // Look for implicit conversion operators - var implicitConversion = FindConversionOperator(sourceType, targetType, "op_Implicit"); - if (implicitConversion != null) - { - // For implicit conversions, we can use a simple cast - return $"({targetType.GloballyQualified()})(({sourceType.GloballyQualified()}){sourceExpression})"; - } - - // Look for explicit conversion operators - var explicitConversion = FindConversionOperator(sourceType, targetType, "op_Explicit"); - if (explicitConversion != null) - { - // For explicit conversions, we also use a cast - return $"({targetType.GloballyQualified()})(({sourceType.GloballyQualified()}){sourceExpression})"; - } - - // No special AOT conversion needed, let CastHelper handle it - return null; - } - - /// - /// Checks if a type has conversion operators that might not work in AOT - /// - public static bool HasConversionOperators(ITypeSymbol type) - { - var members = type.GetMembers(); - return members.Any(m => m is IMethodSymbol { Name: "op_Implicit" or "op_Explicit", IsStatic: true }); - } - - /// - /// Gets all conversion operators for a type - /// - public static IEnumerable<(IMethodSymbol method, ITypeSymbol targetType)> GetConversionOperators(ITypeSymbol type) - { - var members = type.GetMembers(); - foreach (var member in members) - { - if (member is IMethodSymbol { Name: "op_Implicit" or "op_Explicit" } method and { IsStatic: true, Parameters.Length: 1 }) - { - yield return (method, method.ReturnType); - } - } - } - - private static IMethodSymbol? FindConversionOperator(ITypeSymbol sourceType, ITypeSymbol targetType, string operatorName) - { - // Check operators in source type - var sourceOperators = sourceType.GetMembers(operatorName) - .OfType() - .Where(m => m.IsStatic && m.Parameters.Length == 1); - - foreach (var op in sourceOperators) - { - if (op.ReturnType.Equals(targetType, SymbolEqualityComparer.Default) && - op.Parameters[0].Type.Equals(sourceType, SymbolEqualityComparer.Default)) - { - return op; - } - } - - // Check operators in target type - var targetOperators = targetType.GetMembers(operatorName) - .OfType() - .Where(m => m.IsStatic && m.Parameters.Length == 1); - - foreach (var op in targetOperators) - { - if (op.ReturnType.Equals(targetType, SymbolEqualityComparer.Default) && - op.Parameters[0].Type.Equals(sourceType, SymbolEqualityComparer.Default)) - { - return op; - } - } - - return null; - } -} \ No newline at end of file diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/DynamicTestSourceDataModelRetriever.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/DynamicTestSourceDataModelRetriever.cs deleted file mode 100644 index 0668fdeba5..0000000000 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/DynamicTestSourceDataModelRetriever.cs +++ /dev/null @@ -1,21 +0,0 @@ -using Microsoft.CodeAnalysis; -using TUnit.Core.SourceGenerator.Extensions; -using TUnit.Core.SourceGenerator.Models; - -namespace TUnit.Core.SourceGenerator.CodeGenerators.Helpers; - -public static class DynamicTestSourceDataModelRetriever -{ - public static DynamicTestSourceDataModel ParseDynamicTestBuilders(this IMethodSymbol methodSymbol) - { - var testAttribute = methodSymbol.GetRequiredTestAttribute(); - - return new DynamicTestSourceDataModel - { - Class = methodSymbol.ContainingType, - Method = methodSymbol, - FilePath = testAttribute.ConstructorArguments[0].Value?.ToString() ?? string.Empty, - LineNumber = testAttribute.ConstructorArguments[1].Value as int? ?? 0, - }; - } -} diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/RequiredPropertyHelper.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/RequiredPropertyHelper.cs index e36ca5a2a5..ee2c822c51 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/RequiredPropertyHelper.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/RequiredPropertyHelper.cs @@ -26,24 +26,6 @@ public static IEnumerable GetAllRequiredProperties(ITypeSymbol return requiredProperties; } - /// - /// Gets required properties that have data source attributes - /// - public static IEnumerable GetRequiredPropertiesWithDataSource(ITypeSymbol typeSymbol) - { - return GetAllRequiredProperties(typeSymbol) - .Where(p => HasDataSourceAttribute(p)); - } - - /// - /// Gets required properties that don't have data source attributes - /// - public static IEnumerable GetRequiredPropertiesWithoutDataSource(ITypeSymbol typeSymbol) - { - return GetAllRequiredProperties(typeSymbol) - .Where(p => !HasDataSourceAttribute(p)); - } - private static bool HasDataSourceAttribute(IPropertySymbol property) { return property.GetAttributes().Any(attr => diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedConstantParser.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedConstantParser.cs index d8f985f8a1..571c87a76c 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedConstantParser.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedConstantParser.cs @@ -9,41 +9,6 @@ namespace TUnit.Core.SourceGenerator.CodeGenerators.Helpers; public static class TypedConstantParser { private static readonly TypedConstantFormatter _formatter = new(); - - public static string GetTypedConstantValue(SemanticModel semanticModel, - (TypedConstant typedConstant, AttributeArgumentSyntax a) element, ITypeSymbol? parameterType) - { - // For constant values, use the formatter which handles type conversions properly - if (element.typedConstant.Kind == TypedConstantKind.Primitive) - { - return _formatter.FormatForCode(element.typedConstant, parameterType); - } - - var argumentExpression = element.a.Expression; - - var newExpression = argumentExpression.Accept(new FullyQualifiedWithGlobalPrefixRewriter(semanticModel))!; - - if (parameterType?.TypeKind == TypeKind.Enum && - (newExpression.IsKind(SyntaxKind.UnaryMinusExpression) || newExpression.IsKind(SyntaxKind.UnaryPlusExpression))) - { - return $"({parameterType.GloballyQualified()})({newExpression})"; - } - - if (parameterType?.SpecialType == SpecialType.System_Decimal) - { - return $"{newExpression.ToString().TrimEnd('d')}m"; - } - - if (parameterType is not null - && element.typedConstant.Type is not null - && semanticModel.Compilation.ClassifyConversion(element.typedConstant.Type, parameterType) is - { IsExplicit: true, IsImplicit: false }) - { - return $"({parameterType.GloballyQualified()})({newExpression})"; - } - - return newExpression.ToString(); - } public static string GetFullyQualifiedTypeNameFromTypedConstantValue(TypedConstant typedConstant) { @@ -72,11 +37,6 @@ public static string GetRawTypedConstantValue(TypedConstant typedConstant, IType return _formatter.FormatForCode(typedConstant, targetType); } - private static string FormatPrimitive(TypedConstant typedConstant) - { - return FormatPrimitive(typedConstant.Value); - } - public static string FormatPrimitive(object? value) { // Check for special floating-point values first diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedDataSourceOptimizer.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedDataSourceOptimizer.cs deleted file mode 100644 index 48f1c2a16b..0000000000 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Helpers/TypedDataSourceOptimizer.cs +++ /dev/null @@ -1,114 +0,0 @@ -using Microsoft.CodeAnalysis; -using TUnit.Core.SourceGenerator.Extensions; - -namespace TUnit.Core.SourceGenerator.CodeGenerators.Helpers; - -internal static class TypedDataSourceOptimizer -{ - /// - /// Determines if a typed data source can be optimized for the given parameter types - /// - public static bool CanOptimizeTypedDataSource(AttributeData dataSourceAttribute, IMethodSymbol testMethod) - { - // GetTypedDataSourceType already checks if it's a typed data source (returns null if not) - // This avoids enumerating AllInterfaces twice - var dataSourceType = dataSourceAttribute.GetTypedDataSourceType(); - if (dataSourceType == null) - { - return false; - } - - // For single parameter tests, check if types match directly - if (testMethod.Parameters.Length == 1) - { - return SymbolEqualityComparer.Default.Equals(dataSourceType, testMethod.Parameters[0].Type); - } - - // For multiple parameters, check if data source provides a matching tuple - if (dataSourceType is INamedTypeSymbol { IsTupleType: true } namedType && - namedType.TupleElements.Length == testMethod.Parameters.Length) - { - for (var i = 0; i < testMethod.Parameters.Length; i++) - { - if (!SymbolEqualityComparer.Default.Equals(namedType.TupleElements[i].Type, testMethod.Parameters[i].Type)) - { - return false; - } - } - return true; - } - - return false; - } - - /// - /// Generates optimized code for accessing typed data source values - /// - public static void GenerateOptimizedDataSourceAccess( - ICodeWriter writer, - AttributeData dataSourceAttribute, - string dataSourceVariableName, - string metadataVariableName, - IMethodSymbol testMethod) - { - var dataSourceType = dataSourceAttribute.GetTypedDataSourceType(); - if (dataSourceType == null) - { - // Fallback to standard implementation - GenerateStandardDataSourceAccess(writer, dataSourceVariableName, metadataVariableName); - return; - } - - var typedInterfaceName = $"global::TUnit.Core.ITypedDataSourceAttribute<{dataSourceType.GloballyQualified()}>"; - - writer.AppendLine($"// Optimized typed data source access for {dataSourceType.Name}"); - writer.AppendLine($"var typedDataSource = ({typedInterfaceName}){dataSourceVariableName};"); - writer.AppendLine($"await foreach (var dataFunc in typedDataSource.GetTypedDataRowsAsync({metadataVariableName}))"); - writer.AppendLine("{"); - writer.Indent(); - - if (testMethod.Parameters.Length == 1) - { - // Single parameter - direct assignment - writer.AppendLine("var value = await dataFunc();"); - writer.AppendLine($"var args = new object?[] {{ value }};"); - } - else if (dataSourceType is INamedTypeSymbol { IsTupleType: true } namedType) - { - // Tuple - decompose without boxing - writer.AppendLine("var tuple = await dataFunc();"); - writer.Append("var args = new object?[] { "); - for (var i = 0; i < namedType.TupleElements.Length; i++) - { - if (i > 0) - { - writer.Append(", "); - } - writer.Append($"tuple.Item{i + 1}"); - } - writer.AppendLine(" };"); - } - else - { - // Other types - use ToObjectArray if available - writer.AppendLine("var value = await dataFunc();"); - writer.AppendLine("var args = value.ToObjectArray();"); - } - - writer.Unindent(); - writer.AppendLine("}"); - } - - private static void GenerateStandardDataSourceAccess( - ICodeWriter writer, - string dataSourceVariableName, - string metadataVariableName) - { - writer.AppendLine($"await foreach (var dataFunc in {dataSourceVariableName}.GetDataRowsAsync({metadataVariableName}))"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("var args = await dataFunc();"); - writer.Unindent(); - writer.AppendLine("}"); - } -} \ No newline at end of file diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/TestHooksWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/TestHooksWriter.cs deleted file mode 100644 index 166641c073..0000000000 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/Hooks/TestHooksWriter.cs +++ /dev/null @@ -1,73 +0,0 @@ -using TUnit.Core.SourceGenerator.CodeGenerators.Helpers; -using TUnit.Core.SourceGenerator.Enums; -using TUnit.Core.SourceGenerator.Extensions; -using TUnit.Core.SourceGenerator.Models; - -namespace TUnit.Core.SourceGenerator.CodeGenerators.Writers.Hooks; - -public class TestHooksWriter : BaseHookWriter -{ - public static void Execute(ICodeWriter sourceBuilder, HooksDataModel model) - { - if (model.IsEveryHook) - { - if (model.HookLocationType == HookLocationType.Before) - { - sourceBuilder.Append("new global::TUnit.Core.Hooks.BeforeTestHookMethod"); - } - else - { - sourceBuilder.Append("new global::TUnit.Core.Hooks.AfterTestHookMethod"); - } - - sourceBuilder.Append("{"); - sourceBuilder.Append("MethodInfo = "); - SourceInformationWriter.GenerateMethodInformation(sourceBuilder, model.Context.SemanticModel.Compilation, model.ClassType, model.Method, null, ','); - - sourceBuilder.Append($"Body = (context, cancellationToken) => AsyncConvert.Convert(() => {model.FullyQualifiedTypeName}.{model.MethodName}({GetArgs(model)})),"); - - sourceBuilder.Append($"HookExecutor = {HookExecutorHelper.GetHookExecutor(model.HookExecutor)},"); - sourceBuilder.Append($"Order = {model.Order},"); - sourceBuilder.Append($"RegistrationIndex = global::TUnit.Core.HookRegistrationIndices.GetNext{(model.HookLocationType == HookLocationType.Before ? "Before" : "After")}EveryTestHookIndex(),"); - sourceBuilder.Append($"""FilePath = @"{model.FilePath}","""); - sourceBuilder.Append($"LineNumber = {model.LineNumber},"); - - sourceBuilder.Append("},"); - - return; - } - - sourceBuilder.Append("new global::TUnit.Core.Hooks.InstanceHookMethod"); - sourceBuilder.Append("{"); - sourceBuilder.Append($"InitClassType = typeof({model.FullyQualifiedTypeName}),"); - sourceBuilder.Append("MethodInfo = "); - SourceInformationWriter.GenerateMethodInformation(sourceBuilder, model.Context.SemanticModel.Compilation, model.ClassType, model.Method, null, ','); - - - if (model.ClassType.IsGenericDefinition()) - { - sourceBuilder.Append( - $"Body = (classInstance, context, cancellationToken) => AsyncConvert.ConvertObject(() => classInstance.GetType().GetMethod(\"{model.MethodName}\", [{string.Join(", ", model.ParameterTypes.Select(x => $"typeof({x})"))}]).Invoke(classInstance, {GetArgsOrEmptyArray(model)})),"); - } - else - { - sourceBuilder.Append($"Body = (classInstance, context, cancellationToken) => AsyncConvert.Convert(() => (({model.FullyQualifiedTypeName})classInstance).{model.MethodName}({GetArgs(model)})),"); - } - - sourceBuilder.Append($"HookExecutor = {HookExecutorHelper.GetHookExecutor(model.HookExecutor)},"); - sourceBuilder.Append($"Order = {model.Order},"); - sourceBuilder.Append($"RegistrationIndex = global::TUnit.Core.HookRegistrationIndices.GetNext{(model.HookLocationType == HookLocationType.Before ? "Before" : "After")}TestHookIndex(),"); - - sourceBuilder.Append("},"); - } - - private static string GetArgsOrEmptyArray(HooksDataModel model) - { - if (!model.ParameterTypes.Any()) - { - return "[]"; - } - - return GetArgs(model); - } -} diff --git a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/SourceInformationWriter.cs b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/SourceInformationWriter.cs index 0003bdc679..80fdd3135a 100644 --- a/TUnit.Core.SourceGenerator/CodeGenerators/Writers/SourceInformationWriter.cs +++ b/TUnit.Core.SourceGenerator/CodeGenerators/Writers/SourceInformationWriter.cs @@ -13,30 +13,30 @@ public static void GenerateClassInformation(ICodeWriter sourceCodeWriter, Compil var parent = namedTypeSymbol.ContainingType; var parentExpression = parent != null ? MetadataGenerationHelper.GenerateClassMetadataGetOrAdd(parent, null, sourceCodeWriter.IndentLevel) : null; var classMetadata = MetadataGenerationHelper.GenerateClassMetadataGetOrAdd(namedTypeSymbol, parentExpression, sourceCodeWriter.IndentLevel); - + // Handle multi-line class metadata similar to method metadata var lines = classMetadata.Split(new[] { "\r\n", "\r", "\n" }, StringSplitOptions.None); - + if (lines.Length > 0) { sourceCodeWriter.Append(lines[0].TrimStart()); - + if (lines.Length > 1) { var secondLine = lines[1]; var baseIndentSpaces = secondLine.Length - secondLine.TrimStart().Length; - + for (var i = 1; i < lines.Length; i++) { if (!string.IsNullOrWhiteSpace(lines[i]) || i < lines.Length - 1) { sourceCodeWriter.AppendLine(); - + var line = lines[i]; var lineIndentSpaces = line.Length - line.TrimStart().Length; var relativeIndent = Math.Max(0, lineIndentSpaces - baseIndentSpaces); var extraIndentLevels = relativeIndent / 4; - + var trimmedLine = line.TrimStart(); for (var j = 0; j < extraIndentLevels; j++) { @@ -47,14 +47,7 @@ public static void GenerateClassInformation(ICodeWriter sourceCodeWriter, Compil } } } - - sourceCodeWriter.Append(","); - } - private static void GenerateAssemblyInformation(ICodeWriter sourceCodeWriter, Compilation compilation, IAssemblySymbol assembly) - { - var assemblyMetadata = MetadataGenerationHelper.GenerateAssemblyMetadataGetOrAdd(assembly); - sourceCodeWriter.Append(assemblyMetadata); sourceCodeWriter.Append(","); } diff --git a/TUnit.Core.SourceGenerator/Extensions/EnumerableExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/EnumerableExtensions.cs deleted file mode 100644 index dc776532ab..0000000000 --- a/TUnit.Core.SourceGenerator/Extensions/EnumerableExtensions.cs +++ /dev/null @@ -1,9 +0,0 @@ -namespace TUnit.Core.SourceGenerator.Extensions; - -public static class EnumerableExtensions -{ - public static string ToCommaSeparatedString(this IEnumerable enumerable) - { - return string.Join(", ", enumerable); - } -} diff --git a/TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs index f4d4c41c77..c82b1199cd 100644 --- a/TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs +++ b/TUnit.Core.SourceGenerator/Extensions/MethodExtensions.cs @@ -33,47 +33,4 @@ public static bool IsHook(this IMethodSymbol methodSymbol, Compilation compilati { return methodSymbol.GetAttributes().Any(x => x.IsNonGlobalHook(compilation) || x.IsGlobalHook(compilation)); } - - public static AttributeData[] GetAttributesIncludingClass(this IMethodSymbol methodSymbol, INamedTypeSymbol namedTypeSymbol) - { - return GetAttributesIncludingClassEnumerable(methodSymbol, namedTypeSymbol).ToArray(); - } - - public static IEnumerable GetAttributesIncludingClassEnumerable(this IMethodSymbol methodSymbol, INamedTypeSymbol namedTypeSymbol) - { - foreach (var attributeData in methodSymbol.GetAttributes()) - { - yield return attributeData; - } - - var type = namedTypeSymbol; - - while (type != null) - { - foreach (var attributeData in type.GetAttributes()) - { - yield return attributeData; - } - - type = type.BaseType; - } - } - - public static IEnumerable ParametersWithoutTimeoutCancellationToken( - this IMethodSymbol methodSymbol) - { - if (methodSymbol.Parameters.IsDefaultOrEmpty) - { - return []; - } - - if (methodSymbol.Parameters.Last().Type - .GloballyQualifiedNonGeneric() == - WellKnownFullyQualifiedClassNames.CancellationToken.WithGlobalPrefix) - { - return methodSymbol.Parameters.Take(methodSymbol.Parameters.Length - 1); - } - - return methodSymbol.Parameters; - } } diff --git a/TUnit.Core.SourceGenerator/Extensions/ParameterExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/ParameterExtensions.cs deleted file mode 100644 index 9120503e97..0000000000 --- a/TUnit.Core.SourceGenerator/Extensions/ParameterExtensions.cs +++ /dev/null @@ -1,23 +0,0 @@ -using System.Collections.Immutable; -using Microsoft.CodeAnalysis; - -namespace TUnit.Core.SourceGenerator.Extensions; - -public static class ParameterExtensions -{ - public static ImmutableArray WithoutCancellationTokenParameter(this ImmutableArray parameterSymbols) - { - if (parameterSymbols.IsDefaultOrEmpty) - { - return parameterSymbols; - } - - if (parameterSymbols.Last().Type.GloballyQualified() == - WellKnownFullyQualifiedClassNames.CancellationToken.WithGlobalPrefix) - { - return ImmutableArray.Create(parameterSymbols, 0, parameterSymbols.Length - 1); - } - - return parameterSymbols; - } -} diff --git a/TUnit.Core.SourceGenerator/Extensions/StringExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/StringExtensions.cs deleted file mode 100644 index 3eee610ddc..0000000000 --- a/TUnit.Core.SourceGenerator/Extensions/StringExtensions.cs +++ /dev/null @@ -1,18 +0,0 @@ -namespace TUnit.Core.SourceGenerator.Extensions; - -public static class StringExtensions -{ - public static string ReplaceFirstOccurrence(this string source, string find, string replace) - { - var place = source.IndexOf(find, StringComparison.Ordinal); - - return place == -1 ? source : source.Remove(place, find.Length).Insert(place, replace); - } - - public static string ReplaceLastOccurrence(this string source, string find, string replace) - { - var place = source.LastIndexOf(find, StringComparison.Ordinal); - - return place == -1 ? source : source.Remove(place, find.Length).Insert(place, replace); - } -} diff --git a/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs b/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs index d3c10a781c..d1671aba73 100644 --- a/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs +++ b/TUnit.Core.SourceGenerator/Extensions/TypeExtensions.cs @@ -145,91 +145,6 @@ public static bool IsIEnumerable(this ITypeSymbol namedTypeSymbol, Compilation c return false; } - public static string GloballyQualifiedOrFallback(this ITypeSymbol? typeSymbol, TypedConstant? typedConstant = null) - { - if (typeSymbol is not null and not ITypeParameterSymbol) - { - return typeSymbol.GloballyQualified(); - } - - if (typedConstant is not null) - { - return TypedConstantParser.GetFullyQualifiedTypeNameFromTypedConstantValue(typedConstant.Value); - } - - return "var"; - } - - public static bool EnumerableGenericTypeIs(this ITypeSymbol enumerable, GeneratorAttributeSyntaxContext context, - ImmutableArray parameterTypes, [NotNullWhen(true)] out ITypeSymbol? enumerableInnerType) - { - if (parameterTypes.IsDefaultOrEmpty) - { - enumerableInnerType = null; - return false; - } - - var genericEnumerableType = - context.SemanticModel.Compilation.GetSpecialType(SpecialType.System_Collections_Generic_IEnumerable_T).ConstructUnboundGenericType(); - - if (enumerable is INamedTypeSymbol { IsGenericType: true } namedTypeSymbol && namedTypeSymbol - .ConstructUnboundGenericType().Equals(genericEnumerableType, SymbolEqualityComparer.Default)) - { - enumerableInnerType = namedTypeSymbol.TypeArguments.First(); - } - else - { - var enumerableInterface = enumerable.AllInterfaces.FirstOrDefault(x => - x.IsGenericType && x.ConstructUnboundGenericType() - .Equals(genericEnumerableType, SymbolEqualityComparer.Default)); - - enumerableInnerType = enumerableInterface?.TypeArguments.FirstOrDefault(); - } - - if (enumerableInnerType is null) - { - enumerableInnerType = null; - return false; - } - - var firstParameterType = parameterTypes.FirstOrDefault(); - - if (context.SemanticModel.Compilation.HasImplicitConversionOrGenericParameter(enumerableInnerType, firstParameterType)) - { - return true; - } - - if (!enumerableInnerType.IsTupleType && firstParameterType is INamedTypeSymbol { IsGenericType: true }) - { - return true; - } - - if (enumerableInnerType.IsTupleType && enumerableInnerType is INamedTypeSymbol namedInnerType) - { - var tupleTypes = namedInnerType.TupleElements.Select(x => x.Type).ToImmutableArray(); - - for (var index = 0; index < tupleTypes.Length; index++) - { - var tupleType = tupleTypes.ElementAtOrDefault(index); - var parameterType = parameterTypes.ElementAtOrDefault(index); - - if (parameterType?.IsGenericDefinition() == true) - { - continue; - } - - if (!context.SemanticModel.Compilation.HasImplicitConversionOrGenericParameter(tupleType, parameterType)) - { - return false; - } - } - - return true; - } - - return false; - } - public static string GloballyQualified(this ISymbol typeSymbol) { // Handle open generic types where type arguments are type parameters @@ -252,7 +167,7 @@ public static string GloballyQualified(this ISymbol typeSymbol) { return "global::System.Nullable<>"; } - + // General case for other open generic types var typeBuilder = new StringBuilder(typeSymbol.ToDisplayString(DisplayFormats.FullyQualifiedNonGenericWithGlobalPrefix)); typeBuilder.Append('<'); @@ -265,7 +180,7 @@ public static string GloballyQualified(this ISymbol typeSymbol) return typeSymbol.ToDisplayString(DisplayFormats.FullyQualifiedGenericWithGlobalPrefix); } - + /// /// Determines if a type is compiler-generated (e.g., async state machines, lambda closures). /// These types typically contain angle brackets in their names and cannot be represented in source code. @@ -276,12 +191,12 @@ public static bool IsCompilerGeneratedType(this ITypeSymbol? typeSymbol) { return false; } - + // Check the type name directly, not the display string // Compiler-generated types have names that start with '<' or contain '<>' // Examples: d__0, <>c__DisplayClass0_0, <>f__AnonymousType0 var typeName = typeSymbol.Name; - + // Compiler-generated types typically: // 1. Start with '<' (like d__0 for async state machines) // 2. Contain '<>' (like <>c for compiler-generated classes) diff --git a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs index 3e99cce971..14186b5fd4 100644 --- a/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs +++ b/TUnit.Core.SourceGenerator/Generators/TestMetadataGenerator.cs @@ -365,44 +365,6 @@ private static void GenerateTestMetadata(CodeWriter writer, TestMethodMetadata t GenerateModuleInitializer(writer, testMethod, uniqueClassName); } - private static ITypeSymbol ReplaceTypeParametersWithConcreteTypes( - ITypeSymbol type, - ImmutableArray typeParameters, - ImmutableArray typeArguments) - { - if (type is ITypeParameterSymbol typeParam) - { - // Find the index of this type parameter - var index = -1; - for (var j = 0; j < typeParameters.Length; j++) - { - if (typeParameters[j].Name == typeParam.Name) - { - index = j; - break; - } - } - - if (index >= 0 && index < typeArguments.Length) - { - return typeArguments[index]; - } - return type; - } - - if (type is INamedTypeSymbol { IsGenericType: true } namedType) - { - // Replace type arguments in generic types like IEnumerable, Func, etc. - var newTypeArgs = namedType.TypeArguments - .Select(ta => ReplaceTypeParametersWithConcreteTypes(ta, typeParameters, typeArguments)) - .ToImmutableArray(); - - return namedType.ConstructedFrom.Construct(newTypeArgs.ToArray()); - } - - return type; - } - private static void GenerateTestMetadataInstance(CodeWriter writer, TestMethodMetadata testMethod, string className, string combinationGuid) { var methodName = testMethod.MethodSymbol.Name; diff --git a/TUnit.Core.SourceGenerator/Helpers/GenericTypeInference.cs b/TUnit.Core.SourceGenerator/Helpers/GenericTypeInference.cs index 6acba46801..091a767e0d 100644 --- a/TUnit.Core.SourceGenerator/Helpers/GenericTypeInference.cs +++ b/TUnit.Core.SourceGenerator/Helpers/GenericTypeInference.cs @@ -10,47 +10,6 @@ namespace TUnit.Core.SourceGenerator.Helpers; /// internal static class GenericTypeInference { - /// - /// Infers generic type arguments for a test method based on its data source attributes - /// - public static ImmutableArray? InferGenericTypes(IMethodSymbol method, ImmutableArray attributes) - { - if (!method.IsGenericMethod || method.TypeParameters.Length == 0) - { - return null; - } - - // Try to infer from typed data sources first - var inferredTypes = TryInferFromTypedDataSources(method, attributes); - if (inferredTypes != null) - { - return inferredTypes; - } - - // Try to infer from Arguments attributes - inferredTypes = TryInferFromArguments(method, attributes); - if (inferredTypes != null) - { - return inferredTypes; - } - - // Try to infer from parameter attributes that implement IInfersType - inferredTypes = TryInferFromTypeInferringAttributes(method); - if (inferredTypes != null) - { - return inferredTypes; - } - - // Try to infer from MethodDataSource - inferredTypes = TryInferFromMethodDataSource(method, attributes); - if (inferredTypes != null) - { - return inferredTypes; - } - - return null; - } - private static ImmutableArray? TryInferFromTypedDataSources(IMethodSymbol method, ImmutableArray attributes) { foreach (var attribute in attributes) @@ -90,7 +49,7 @@ internal static class GenericTypeInference if (current.IsGenericType) { var name = current.Name; - if (name.Contains("DataSourceGeneratorAttribute") || + if (name.Contains("DataSourceGeneratorAttribute") || name.Contains("AsyncDataSourceGeneratorAttribute")) { return current; @@ -103,52 +62,6 @@ internal static class GenericTypeInference return null; } - private static ImmutableArray? TryInferFromArguments(IMethodSymbol method, ImmutableArray attributes) - { - var argumentsAttributes = attributes - .Where(a => a.AttributeClass?.Name == "ArgumentsAttribute") - .ToList(); - - if (argumentsAttributes.Count == 0) - { - return null; - } - - // Get the first Arguments attribute to infer types - var firstArgs = argumentsAttributes[0]; - if (firstArgs.ConstructorArguments.Length == 0) - { - return null; - } - - var inferredTypes = new List(); - - // Match type parameters with method parameters - for (var i = 0; i < method.TypeParameters.Length && i < method.Parameters.Length; i++) - { - var parameter = method.Parameters[i]; - - if (parameter.Type is ITypeParameterSymbol typeParam) - { - // Get the corresponding argument value from the attribute - if (i < firstArgs.ConstructorArguments.Length) - { - var argValue = firstArgs.ConstructorArguments[i]; - var inferredType = InferTypeFromValue(argValue); - - if (inferredType != null) - { - inferredTypes.Add(inferredType); - } - } - } - } - - return inferredTypes.Count == method.TypeParameters.Length - ? inferredTypes.ToImmutableArray() - : null; - } - private static ImmutableArray? TryInferFromTypeInferringAttributes(IMethodSymbol method) { var inferredTypes = new List(); @@ -165,15 +78,15 @@ internal static class GenericTypeInference { // Look for IInfersType in the attribute's interfaces var infersTypeInterface = attr.AttributeClass.AllInterfaces - .FirstOrDefault(i => i.GloballyQualifiedNonGeneric() == "global::TUnit.Core.Interfaces.IInfersType" && - i.IsGenericType && + .FirstOrDefault(i => i.GloballyQualifiedNonGeneric() == "global::TUnit.Core.Interfaces.IInfersType" && + i.IsGenericType && i.TypeArguments.Length == 1); - + if (infersTypeInterface != null) { // Get the type argument from IInfersType var inferredType = infersTypeInterface.TypeArguments[0]; - + // Find the index of this type parameter var typeParamIndex = -1; for (var i = 0; i < method.TypeParameters.Length; i++) @@ -202,9 +115,9 @@ internal static class GenericTypeInference // Remove any null entries and check if we have all types inferredTypes.RemoveAll(t => t == null); - - return inferredTypes.Count == method.TypeParameters.Length - ? inferredTypes.ToImmutableArray() + + return inferredTypes.Count == method.TypeParameters.Length + ? inferredTypes.ToImmutableArray() : null; } @@ -350,7 +263,7 @@ internal static class GenericTypeInference /// Gets all unique generic type combinations for a method based on its data sources /// public static ImmutableArray> GetAllGenericTypeCombinations( - IMethodSymbol method, + IMethodSymbol method, ImmutableArray attributes) { var combinations = new List>(); @@ -405,14 +318,14 @@ public static ImmutableArray> GetAllGenericTypeCombi for (var i = 0; i < method.TypeParameters.Length && i < method.Parameters.Length; i++) { var parameter = method.Parameters[i]; - + if (parameter.Type is ITypeParameterSymbol) { if (i < args.ConstructorArguments.Length) { var argValue = args.ConstructorArguments[i]; var inferredType = InferTypeFromValue(argValue); - + if (inferredType != null) { inferredTypes.Add(inferredType); @@ -421,8 +334,8 @@ public static ImmutableArray> GetAllGenericTypeCombi } } - return inferredTypes.Count == method.TypeParameters.Length - ? inferredTypes.ToImmutableArray() + return inferredTypes.Count == method.TypeParameters.Length + ? inferredTypes.ToImmutableArray() : null; } @@ -443,4 +356,4 @@ private static bool TypeArraysEqual(ImmutableArray a, ImmutableArra return true; } -} \ No newline at end of file +} diff --git a/TUnit.Core.SourceGenerator/Models/Extracted/ExtractedAttribute.cs b/TUnit.Core.SourceGenerator/Models/Extracted/ExtractedAttribute.cs index ada371254a..e0cf766920 100644 --- a/TUnit.Core.SourceGenerator/Models/Extracted/ExtractedAttribute.cs +++ b/TUnit.Core.SourceGenerator/Models/Extracted/ExtractedAttribute.cs @@ -97,14 +97,6 @@ public static EquatableArray ExtractAll(IEnumerable - /// Gets a named argument value by name, or null if not found. - /// - public TypedConstantModel? GetNamedArgument(string name) - { - return NamedArguments.FirstOrDefault(a => a.Name == name)?.Value; - } - /// /// Gets a constructor argument by index, or null if index is out of range. /// diff --git a/TUnit.Core.SourceGenerator/Utilities/AsyncDataSourceHelper.cs b/TUnit.Core.SourceGenerator/Utilities/AsyncDataSourceHelper.cs index a17b05a7d4..a6206f365e 100644 --- a/TUnit.Core.SourceGenerator/Utilities/AsyncDataSourceHelper.cs +++ b/TUnit.Core.SourceGenerator/Utilities/AsyncDataSourceHelper.cs @@ -7,81 +7,6 @@ namespace TUnit.Core.SourceGenerator; /// internal static class AsyncDataSourceHelper { - /// - /// Generates the ConvertToSync helper method for async data sources - /// - public static void GenerateConvertToSyncMethod(CodeWriter writer) - { - writer.AppendLine("private static global::System.Collections.Generic.IEnumerable ConvertToSync(global::System.Func> asyncFactory)"); - writer.AppendLine("{"); - writer.Indent(); - - writer.AppendLine("var cts = new global::System.Threading.CancellationTokenSource();"); - writer.AppendLine("var enumerator = asyncFactory(cts.Token).GetAsyncEnumerator(cts.Token);"); - writer.AppendLine("try"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("while (true)"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("try"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("var moveNextTask = enumerator.MoveNextAsync().AsTask();"); - writer.AppendLine("using (var timeoutCts = new global::System.Threading.CancellationTokenSource(global::System.TimeSpan.FromSeconds(30)))"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("var completedTask = global::System.Threading.Tasks.Task.WhenAny(moveNextTask, global::System.Threading.Tasks.Task.Delay(global::System.Threading.Timeout.Infinite, timeoutCts.Token)).ConfigureAwait(false).GetAwaiter().GetResult();"); - writer.AppendLine("if (completedTask != moveNextTask)"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("throw new global::System.TimeoutException(\"Async data source timed out after 30 seconds\");"); - writer.Unindent(); - writer.AppendLine("}"); - writer.AppendLine("if (!moveNextTask.ConfigureAwait(false).GetAwaiter().GetResult())"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("break;"); - writer.Unindent(); - writer.AppendLine("}"); - writer.Unindent(); - writer.AppendLine("}"); - writer.Unindent(); - writer.AppendLine("}"); - writer.AppendLine("catch (AggregateException ae) when (ae.InnerException is OperationCanceledException)"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("break;"); - writer.Unindent(); - writer.AppendLine("}"); - writer.AppendLine("yield return enumerator.Current;"); - writer.Unindent(); - writer.AppendLine("}"); - writer.Unindent(); - writer.AppendLine("}"); - writer.AppendLine("finally"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("try"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("enumerator.DisposeAsync().AsTask().ConfigureAwait(false).GetAwaiter().GetResult();"); - writer.Unindent(); - writer.AppendLine("}"); - writer.AppendLine("catch"); - writer.AppendLine("{"); - writer.Indent(); - writer.AppendLine("// Ignore disposal errors for async enumerator"); - writer.Unindent(); - writer.AppendLine("}"); - writer.AppendLine("cts.Dispose();"); - writer.Unindent(); - writer.AppendLine("}"); - - writer.Unindent(); - writer.AppendLine("}"); - } - /// /// Determines if a method represents an async data source /// diff --git a/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs b/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs index 1baaabf3a2..e5332c94c2 100644 --- a/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs +++ b/TUnit.Core.SourceGenerator/Utilities/MetadataGenerationHelper.cs @@ -10,53 +10,6 @@ namespace TUnit.Core.SourceGenerator.Utilities; /// internal static class MetadataGenerationHelper { - /// - /// Writes a multi-line string with proper indentation - /// - private static void WriteIndentedString(ICodeWriter writer, string multiLineString, bool firstLineIsInline = true) - { - var lines = multiLineString.Split(new[] { "\r\n", "\r", "\n" }, StringSplitOptions.None); - - // Find the base indentation level from the content (skip first line as it's usually inline) - var baseIndent = 0; - if (lines.Length > 1) - { - // Find first non-empty line after the first to determine base indentation - for (var i = 1; i < lines.Length; i++) - { - if (!string.IsNullOrWhiteSpace(lines[i])) - { - baseIndent = lines[i].Length - lines[i].TrimStart().Length; - break; - } - } - } - - for (var i = 0; i < lines.Length; i++) - { - if (i > 0) - { - writer.AppendLine(); - } - - var line = lines[i]; - if (!string.IsNullOrWhiteSpace(line)) - { - // Calculate how much indentation this line has beyond the base - var currentIndent = line.Length - line.TrimStart().Length; - var relativeIndent = Math.Max(0, currentIndent - baseIndent); - - // Add relative indentation - for (var j = 0; j < relativeIndent; j++) - { - writer.Append(" "); - } - - writer.Append(line.TrimStart()); - } - } - } - /// /// Writes code for creating a MethodMetadata instance ///