diff --git a/sdk/Sdk.Generators/Constants.cs b/sdk/Sdk.Generators/Constants.cs index 0d98a856a..8dabe2b19 100644 --- a/sdk/Sdk.Generators/Constants.cs +++ b/sdk/Sdk.Generators/Constants.cs @@ -5,35 +5,50 @@ namespace Microsoft.Azure.Functions.Worker.Sdk.Generators { internal static class Constants { - // Our types - internal const string BindingAttributeType = "Microsoft.Azure.Functions.Worker.Extensions.Abstractions.BindingAttribute"; - internal const string OutputBindingAttributeType = "Microsoft.Azure.Functions.Worker.Extensions.Abstractions.OutputBindingAttribute"; - internal const string FunctionNameType = "Microsoft.Azure.Functions.Worker.FunctionAttribute"; - internal const string HttpResponseType = "Microsoft.Azure.Functions.Worker.Http.HttpResponseData"; - internal const string EventHubsTriggerType = "Microsoft.Azure.Functions.Worker.EventHubTriggerAttribute"; - internal const string BindingPropertyNameAttributeType = "Microsoft.Azure.Functions.Worker.Extensions.Abstractions.BindingPropertyNameAttribute"; - internal const string DefaultValueType = "Microsoft.Azure.Functions.Worker.Extensions.Abstractions.DefaultValueAttribute"; + public static class BuildProperties + { + internal const string EnableSourceGenProp = "build_property.FunctionsMetadataSourceGen_Enabled"; + } - // System types - internal const string IEnumerableType = "System.Collections.IEnumerable"; - internal const string IEnumerableGenericType = "System.Collections.Generic.IEnumerable`1"; - internal const string IEnumerableOfStringType = "System.Collections.Generic.IEnumerable`1"; - internal const string IEnumerableOfBinaryType = "System.Collections.Generic.IEnumerable`1"; - internal const string IEnumerableOfT = "System.Collections.Generic.IEnumerable`1"; - internal const string IEnumerableOfKeyValuePair = "System.Collections.Generic.IEnumerable`1>"; - internal const string StringType = "System.String"; - internal const string ByteArrayType = "System.Byte[]"; - internal const string ByteStructType = "System.Byte"; - internal const string TaskGenericType = "System.Threading.Tasks.Task`1"; - internal const string TaskType = "System.Threading.Tasks.Task"; - internal const string VoidType = "System.Void"; - internal const string ReadOnlyMemoryOfBytes = "System.ReadOnlyMemory`1"; - internal const string LookupGenericType = "System.Linq.Lookup`2"; - internal const string DictionaryGenericType = "System.Collections.Generic.Dictionary`2"; + public static class FileNames + { + internal const string GeneratedFunctionMetadata = "GeneratedFunctionMetadataProvider.g.cs"; + } - internal const string ReturnBindingName = "$return"; - internal const string HttpResponseBindingName = "HttpResponse"; - internal const string HttpTriggerBindingType = "Microsoft.Azure.Functions.Worker.HttpTriggerAttribute"; - internal const string IsBatchedKey = "IsBatched"; + public static class Types + { + // Our types + internal const string BindingAttribute = "Microsoft.Azure.Functions.Worker.Extensions.Abstractions.BindingAttribute"; + internal const string OutputBindingAttribute = "Microsoft.Azure.Functions.Worker.Extensions.Abstractions.OutputBindingAttribute"; + internal const string FunctionName = "Microsoft.Azure.Functions.Worker.FunctionAttribute"; + internal const string HttpResponse = "Microsoft.Azure.Functions.Worker.Http.HttpResponseData"; + internal const string HttpTriggerBinding = "Microsoft.Azure.Functions.Worker.HttpTriggerAttribute"; + internal const string EventHubsTrigger = "Microsoft.Azure.Functions.Worker.EventHubTriggerAttribute"; + internal const string BindingPropertyNameAttribute = "Microsoft.Azure.Functions.Worker.Extensions.Abstractions.BindingPropertyNameAttribute"; + internal const string DefaultValue = "Microsoft.Azure.Functions.Worker.Extensions.Abstractions.DefaultValueAttribute"; + + // System types + internal const string IEnumerable = "System.Collections.IEnumerable"; + internal const string IEnumerableGeneric = "System.Collections.Generic.IEnumerable`1"; + internal const string IEnumerableOfString = "System.Collections.Generic.IEnumerable`1"; + internal const string IEnumerableOfBinary = "System.Collections.Generic.IEnumerable`1"; + internal const string IEnumerableOfT = "System.Collections.Generic.IEnumerable`1"; + internal const string IEnumerableOfKeyValuePair = "System.Collections.Generic.IEnumerable`1>"; + internal const string String = "System.String"; + internal const string ByteArray = "System.Byte[]"; + internal const string ByteStruct = "System.Byte"; + internal const string TaskGeneric = "System.Threading.Tasks.Task`1"; + internal const string Task = "System.Threading.Tasks.Task"; + internal const string Void = "System.Void"; + internal const string ReadOnlyMemoryOfBytes = "System.ReadOnlyMemory`1"; + internal const string LookupGeneric = "System.Linq.Lookup`2"; + internal const string DictionaryGeneric = "System.Collections.Generic.Dictionary`2"; + } + + public static class FunctionMetadataBindingProps { + internal const string ReturnBindingName = "$return"; + internal const string HttpResponseBindingName = "HttpResponse"; + internal const string IsBatchedKey = "IsBatched"; + } } } diff --git a/sdk/Sdk.Generators/FunctionMetadataProviderGenerator.Parser.cs b/sdk/Sdk.Generators/FunctionMetadataProviderGenerator.Parser.cs index 39e15ec02..9ce22ee5b 100644 --- a/sdk/Sdk.Generators/FunctionMetadataProviderGenerator.Parser.cs +++ b/sdk/Sdk.Generators/FunctionMetadataProviderGenerator.Parser.cs @@ -99,7 +99,7 @@ private bool IsValidMethodAzureFunction(SemanticModel model, MethodDeclarationSy foreach (var attr in methodSymbol.GetAttributes()) { if (attr.AttributeClass != null && - SymbolEqualityComparer.Default.Equals(attr.AttributeClass, Compilation.GetTypeByMetadataName(Constants.FunctionNameType))) + SymbolEqualityComparer.Default.Equals(attr.AttributeClass, Compilation.GetTypeByMetadataName(Constants.Types.FunctionName))) { functionName = (string)attr.ConstructorArguments.First().Value!; // If this is a function attribute this won't be null return true; @@ -147,7 +147,7 @@ private bool TryGetMethodOutputBinding(MethodDeclarationSyntax method, SemanticM foreach (var attribute in attributes) { - if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass?.BaseType, Compilation.GetTypeByMetadataName(Constants.OutputBindingAttributeType))) + if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass?.BaseType, Compilation.GetTypeByMetadataName(Constants.Types.OutputBindingAttribute))) { // There can only be one output binding associated with a function. If there is more than one, we return a diagnostic error here. if (hasOutputBinding) @@ -164,7 +164,7 @@ private bool TryGetMethodOutputBinding(MethodDeclarationSyntax method, SemanticM if (outputBindingAttribute != null) { - if (!TryCreateBindingDict(outputBindingAttribute, Constants.ReturnBindingName, bindingLocation, out IDictionary? bindingDict)) + if (!TryCreateBindingDict(outputBindingAttribute, Constants.FunctionMetadataBindingProps.ReturnBindingName, bindingLocation, out IDictionary? bindingDict)) { bindingsList = null; return false; @@ -207,18 +207,18 @@ private bool TryGetParameterInputAndTriggerBindings(MethodDeclarationSyntax meth // Check to see if any of the attributes associated with this parameter is a BindingAttribute foreach (var attribute in parameterSymbol.GetAttributes()) { - if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass?.BaseType?.BaseType, Compilation.GetTypeByMetadataName(Constants.BindingAttributeType))) + if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass?.BaseType?.BaseType, Compilation.GetTypeByMetadataName(Constants.Types.BindingAttribute))) { var validEventHubs = false; var cardinality = Cardinality.Undefined; var dataType = GetDataType(parameterSymbol.Type); // There are two special cases we need to handle: HttpTrigger and EventHubsTrigger. - if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, Compilation.GetTypeByMetadataName(Constants.HttpTriggerBindingType))) + if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, Compilation.GetTypeByMetadataName(Constants.Types.HttpTriggerBinding))) { hasHttpTrigger = true; } - else if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, Compilation.GetTypeByMetadataName(Constants.EventHubsTriggerType))) + else if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, Compilation.GetTypeByMetadataName(Constants.Types.EventHubsTrigger))) { // there are special rules for EventHubsTriggers that we will have to validate validEventHubs = IsEventHubsTriggerValid(parameterSymbol, parameter.Type, model, attribute, out dataType, out cardinality); @@ -276,11 +276,11 @@ private bool TryGetReturnTypeBindings(MethodDeclarationSyntax method, SemanticMo return false; } - if (!SymbolEqualityComparer.Default.Equals(returnTypeSymbol, Compilation.GetTypeByMetadataName(Constants.VoidType)) && - !SymbolEqualityComparer.Default.Equals(returnTypeSymbol, Compilation.GetTypeByMetadataName(Constants.TaskType))) + if (!SymbolEqualityComparer.Default.Equals(returnTypeSymbol, Compilation.GetTypeByMetadataName(Constants.Types.Void)) && + !SymbolEqualityComparer.Default.Equals(returnTypeSymbol, Compilation.GetTypeByMetadataName(Constants.Types.Task))) { // If there is a Task return type, inspect T, the inner type. - if (SymbolEqualityComparer.Default.Equals(returnTypeSymbol, Compilation.GetTypeByMetadataName(Constants.TaskGenericType))) + if (SymbolEqualityComparer.Default.Equals(returnTypeSymbol, Compilation.GetTypeByMetadataName(Constants.Types.TaskGeneric))) { GenericNameSyntax genericSyntax = (GenericNameSyntax)returnTypeSyntax; var innerTypeSyntax = genericSyntax.TypeArgumentList.Arguments.First(); // Generic task should only have one type argument @@ -294,9 +294,9 @@ private bool TryGetReturnTypeBindings(MethodDeclarationSyntax method, SemanticMo } } - if (SymbolEqualityComparer.Default.Equals(returnTypeSymbol, Compilation.GetTypeByMetadataName(Constants.HttpResponseType))) // If return type is HttpResponseData + if (SymbolEqualityComparer.Default.Equals(returnTypeSymbol, Compilation.GetTypeByMetadataName(Constants.Types.HttpResponse))) // If return type is HttpResponseData { - bindingsList.Add(GetHttpReturnBinding(Constants.ReturnBindingName)); + bindingsList.Add(GetHttpReturnBinding(Constants.FunctionMetadataBindingProps.ReturnBindingName)); } else { @@ -327,7 +327,7 @@ private bool TryGetReturnTypePropertyBindings(ITypeSymbol returnTypeSymbol, bool } // Check if this attribute is an HttpResponseData type attribute - if (prop is IPropertySymbol property && SymbolEqualityComparer.Default.Equals(property.Type, Compilation.GetTypeByMetadataName(Constants.HttpResponseType))) + if (prop is IPropertySymbol property && SymbolEqualityComparer.Default.Equals(property.Type, Compilation.GetTypeByMetadataName(Constants.Types.HttpResponse))) { if (foundHttpOutput) { @@ -345,7 +345,7 @@ private bool TryGetReturnTypePropertyBindings(ITypeSymbol returnTypeSymbol, bool foreach (var attr in prop.GetAttributes()) // now loop through and check if any of the attributes are Binding attributes { - if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.BaseType, Compilation.GetTypeByMetadataName(Constants.OutputBindingAttributeType))) + if (SymbolEqualityComparer.Default.Equals(attr.AttributeClass?.BaseType, Compilation.GetTypeByMetadataName(Constants.Types.OutputBindingAttribute))) { // validate that there's only one binding attribute per property if (foundPropertyOutputAttr) @@ -374,11 +374,11 @@ private bool TryGetReturnTypePropertyBindings(ITypeSymbol returnTypeSymbol, bool { if (!hasOutputBinding) { - bindingsList.Add(GetHttpReturnBinding(Constants.ReturnBindingName)); + bindingsList.Add(GetHttpReturnBinding(Constants.FunctionMetadataBindingProps.ReturnBindingName)); } else { - bindingsList.Add(GetHttpReturnBinding(Constants.HttpResponseBindingName)); + bindingsList.Add(GetHttpReturnBinding(Constants.FunctionMetadataBindingProps.HttpResponseBindingName)); } } @@ -413,7 +413,7 @@ private bool TryCreateBindingDict(AttributeData bindingAttrData, string bindingN string bindingType = attributeName.TrimStringsFromEnd(_functionsStringNamesToRemove); // Set binding direction - string bindingDirection = SymbolEqualityComparer.Default.Equals(bindingAttrData.AttributeClass?.BaseType, Compilation.GetTypeByMetadataName(Constants.OutputBindingAttributeType)) ? "Out" : "In"; + string bindingDirection = SymbolEqualityComparer.Default.Equals(bindingAttrData.AttributeClass?.BaseType, Compilation.GetTypeByMetadataName(Constants.Types.OutputBindingAttribute)) ? "Out" : "In"; var bindingCount = attributeProperties!.Count + 3; bindings = new Dictionary(capacity: bindingCount) @@ -458,7 +458,7 @@ private bool TryGetAttributeProperties(AttributeData attributeData, Location? at { if (namedArgument.Value.Value != null) { - if (string.Equals(namedArgument.Key, Constants.IsBatchedKey) && !attrProperties.ContainsKey("cardinality")) + if (string.Equals(namedArgument.Key, Constants.FunctionMetadataBindingProps.IsBatchedKey) && !attrProperties.ContainsKey("cardinality")) { var argValue = (bool)namedArgument.Value.Value; // isBatched only takes in booleans and the generator will parse it as a bool so we can type cast this to use in the next line @@ -474,13 +474,13 @@ private bool TryGetAttributeProperties(AttributeData attributeData, Location? at // some properties have default values, so if these properties were not already defined in constructor or named arguments, we will auto-add them here foreach (var member in attributeData.AttributeClass!.GetMembers().Where(a => a is IPropertySymbol)) { - var defaultValAttrList = member.GetAttributes().Where(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, Compilation.GetTypeByMetadataName(Constants.DefaultValueType))); + var defaultValAttrList = member.GetAttributes().Where(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, Compilation.GetTypeByMetadataName(Constants.Types.DefaultValue))); if (defaultValAttrList.SingleOrDefault() is { } defaultValAttr) // list will only be of size one b/c there cannot be duplicates of an attribute on one piece of syntax { var argName = member.Name; object arg = defaultValAttr.ConstructorArguments.SingleOrDefault().Value!; // only one constructor arg in DefaultValue attribute (the default value) - if (arg is bool b && string.Equals(argName, Constants.IsBatchedKey)) + if (arg is bool b && string.Equals(argName, Constants.FunctionMetadataBindingProps.IsBatchedKey)) { if (!attrProperties.Keys.Contains("cardinality")) { @@ -546,7 +546,7 @@ private bool TryLoadConstructorArguments(AttributeData attributeData, IDictionar private void OverrideBindingName(INamedTypeSymbol attributeClass, ref string argumentName) { - var bindingPropertyNameSymbol = Compilation.GetTypeByMetadataName(Constants.BindingPropertyNameAttributeType); + var bindingPropertyNameSymbol = Compilation.GetTypeByMetadataName(Constants.Types.BindingPropertyNameAttribute); foreach (var prop in attributeClass.GetMembers().Where(a => a is IPropertySymbol)) { @@ -574,7 +574,7 @@ private bool IsEventHubsTriggerValid(IParameterSymbol parameterSymbol, TypeSynta // check if IsBatched is defined in the NamedArguments foreach (var arg in attribute.NamedArguments) { - if (String.Equals(arg.Key, Constants.IsBatchedKey) && + if (String.Equals(arg.Key, Constants.FunctionMetadataBindingProps.IsBatchedKey) && arg.Value.Value != null) { var isBatched = (bool)arg.Value.Value; // isBatched takes in booleans so we can just type cast it here to use @@ -590,8 +590,8 @@ private bool IsEventHubsTriggerValid(IParameterSymbol parameterSymbol, TypeSynta // Check the default value of IsBatched var eventHubsAttr = attribute.AttributeClass; - var isBatchedProp = eventHubsAttr!.GetMembers().Where(m => string.Equals(m.Name, Constants.IsBatchedKey, StringComparison.OrdinalIgnoreCase)).SingleOrDefault(); - AttributeData defaultValAttr = isBatchedProp.GetAttributes().Where(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, Compilation.GetTypeByMetadataName(Constants.DefaultValueType))).SingleOrDefault(); + var isBatchedProp = eventHubsAttr!.GetMembers().Where(m => string.Equals(m.Name, Constants.FunctionMetadataBindingProps.IsBatchedKey, StringComparison.OrdinalIgnoreCase)).SingleOrDefault(); + AttributeData defaultValAttr = isBatchedProp.GetAttributes().Where(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, Compilation.GetTypeByMetadataName(Constants.Types.DefaultValue))).SingleOrDefault(); var defaultVal = defaultValAttr.ConstructorArguments.SingleOrDefault().Value!.ToString(); // there is only one constructor arg, the default value if (!bool.TryParse(defaultVal, out bool b) || !b) { @@ -602,20 +602,20 @@ private bool IsEventHubsTriggerValid(IParameterSymbol parameterSymbol, TypeSynta // we check if the param is an array type // we exclude byte arrays (byte[]) b/c we handle that as cardinality one (we handle this simliar to how a char[] is basically a string) - if (parameterSymbol.Type is IArrayTypeSymbol && !SymbolEqualityComparer.Default.Equals(parameterSymbol.Type, Compilation.GetTypeByMetadataName(Constants.ByteArrayType))) + if (parameterSymbol.Type is IArrayTypeSymbol && !SymbolEqualityComparer.Default.Equals(parameterSymbol.Type, Compilation.GetTypeByMetadataName(Constants.Types.ByteArray))) { dataType = GetDataType(parameterSymbol.Type); cardinality = Cardinality.Many; return true; } - var isGenericEnumerable = parameterSymbol.Type.IsOrImplementsOrDerivesFrom(Compilation.GetTypeByMetadataName(Constants.IEnumerableGenericType)); - var isEnumerable = parameterSymbol.Type.IsOrImplementsOrDerivesFrom(Compilation.GetTypeByMetadataName(Constants.IEnumerableType)); + var isGenericEnumerable = parameterSymbol.Type.IsOrImplementsOrDerivesFrom(Compilation.GetTypeByMetadataName(Constants.Types.IEnumerableGeneric)); + var isEnumerable = parameterSymbol.Type.IsOrImplementsOrDerivesFrom(Compilation.GetTypeByMetadataName(Constants.Types.IEnumerable)); // Check if mapping type - mapping enumerables are not valid types for EventHubParameters - if (parameterSymbol.Type.IsOrImplementsOrDerivesFrom(Compilation.GetTypeByMetadataName(Constants.IEnumerableOfKeyValuePair)) - || parameterSymbol.Type.IsOrImplementsOrDerivesFrom(Compilation.GetTypeByMetadataName(Constants.LookupGenericType)) - || parameterSymbol.Type.IsOrImplementsOrDerivesFrom(Compilation.GetTypeByMetadataName(Constants.DictionaryGenericType))) + if (parameterSymbol.Type.IsOrImplementsOrDerivesFrom(Compilation.GetTypeByMetadataName(Constants.Types.IEnumerableOfKeyValuePair)) + || parameterSymbol.Type.IsOrImplementsOrDerivesFrom(Compilation.GetTypeByMetadataName(Constants.Types.LookupGeneric)) + || parameterSymbol.Type.IsOrImplementsOrDerivesFrom(Compilation.GetTypeByMetadataName(Constants.Types.DictionaryGeneric))) { return false; } @@ -671,13 +671,13 @@ private DataType ResolveIEnumerableOfT(IParameterSymbol parameterSymbol, TypeSyn { INamedTypeSymbol? genericInterfaceSymbol = null; - if (currSymbol.IsOrDerivedFrom(Compilation.GetTypeByMetadataName(Constants.IEnumerableGenericType)) && currSymbol is INamedTypeSymbol currNamedSymbol) + if (currSymbol.IsOrDerivedFrom(Compilation.GetTypeByMetadataName(Constants.Types.IEnumerableGeneric)) && currSymbol is INamedTypeSymbol currNamedSymbol) { finalSymbol = currNamedSymbol; break; } - genericInterfaceSymbol = currSymbol.Interfaces.Where(i => i.IsOrDerivedFrom(Compilation.GetTypeByMetadataName(Constants.IEnumerableGenericType))).FirstOrDefault(); + genericInterfaceSymbol = currSymbol.Interfaces.Where(i => i.IsOrDerivedFrom(Compilation.GetTypeByMetadataName(Constants.Types.IEnumerableGeneric))).FirstOrDefault(); if (genericInterfaceSymbol != null) { finalSymbol = genericInterfaceSymbol; @@ -721,17 +721,17 @@ private DataType GetDataType(ITypeSymbol symbol) private bool IsStringType(ITypeSymbol symbol) { - return SymbolEqualityComparer.Default.Equals(symbol, Compilation.GetTypeByMetadataName(Constants.StringType)) - || (symbol is IArrayTypeSymbol arraySymbol && SymbolEqualityComparer.Default.Equals(arraySymbol.ElementType, Compilation.GetTypeByMetadataName(Constants.StringType))); + return SymbolEqualityComparer.Default.Equals(symbol, Compilation.GetTypeByMetadataName(Constants.Types.String)) + || (symbol is IArrayTypeSymbol arraySymbol && SymbolEqualityComparer.Default.Equals(arraySymbol.ElementType, Compilation.GetTypeByMetadataName(Constants.Types.String))); } private bool IsBinaryType(ITypeSymbol symbol) { - var isByteArray = SymbolEqualityComparer.Default.Equals(symbol, Compilation.GetTypeByMetadataName(Constants.ByteArrayType)) - || (symbol is IArrayTypeSymbol arraySymbol && SymbolEqualityComparer.Default.Equals(arraySymbol.ElementType, Compilation.GetTypeByMetadataName(Constants.ByteStructType))); - var isReadOnlyMemoryOfBytes = SymbolEqualityComparer.Default.Equals(symbol, Compilation.GetTypeByMetadataName(Constants.ReadOnlyMemoryOfBytes)); + var isByteArray = SymbolEqualityComparer.Default.Equals(symbol, Compilation.GetTypeByMetadataName(Constants.Types.ByteArray)) + || (symbol is IArrayTypeSymbol arraySymbol && SymbolEqualityComparer.Default.Equals(arraySymbol.ElementType, Compilation.GetTypeByMetadataName(Constants.Types.ByteStruct))); + var isReadOnlyMemoryOfBytes = SymbolEqualityComparer.Default.Equals(symbol, Compilation.GetTypeByMetadataName(Constants.Types.ReadOnlyMemoryOfBytes)); var isArrayOfByteArrays = symbol is IArrayTypeSymbol outerArray && - outerArray.ElementType is IArrayTypeSymbol innerArray && SymbolEqualityComparer.Default.Equals(innerArray.ElementType, Compilation.GetTypeByMetadataName(Constants.ByteStructType)); + outerArray.ElementType is IArrayTypeSymbol innerArray && SymbolEqualityComparer.Default.Equals(innerArray.ElementType, Compilation.GetTypeByMetadataName(Constants.Types.ByteStruct)); return isByteArray || isReadOnlyMemoryOfBytes || isArrayOfByteArrays; diff --git a/sdk/Sdk.Generators/FunctionMetadataProviderGenerator.cs b/sdk/Sdk.Generators/FunctionMetadataProviderGenerator.cs index 1bf803c49..70365a316 100644 --- a/sdk/Sdk.Generators/FunctionMetadataProviderGenerator.cs +++ b/sdk/Sdk.Generators/FunctionMetadataProviderGenerator.cs @@ -24,17 +24,27 @@ public void Execute(GeneratorExecutionContext context) return; } + context.AnalyzerConfigOptions.GlobalOptions.TryGetValue(Constants.BuildProperties.EnableSourceGenProp, out var sourceGenSwitch); + + bool.TryParse(sourceGenSwitch, out bool enableSourceGen); + + if (!enableSourceGen) + { + return; + } + // attempt to parse user compilation var p = new Parser(context); + IReadOnlyList functionMetadataInfo = p.GetFunctionMetadataInfo(receiver.CandidateMethods); // Proceed to generate the file if function metadata info was successfully returned if (functionMetadataInfo.Count > 0) { - var e = new Emitter(); + Emitter e = new(); string result = e.Emit(functionMetadataInfo, context.CancellationToken); - context.AddSource($"GeneratedFunctionMetadataProvider.g.cs", SourceText.From(result, Encoding.UTF8)); + context.AddSource(Constants.FileNames.GeneratedFunctionMetadata, SourceText.From(result, Encoding.UTF8)); } } diff --git a/sdk/Sdk.Generators/Properties/launchSettings.json b/sdk/Sdk.Generators/Properties/launchSettings.json index a90ca0b10..55769ef01 100644 --- a/sdk/Sdk.Generators/Properties/launchSettings.json +++ b/sdk/Sdk.Generators/Properties/launchSettings.json @@ -1,6 +1,6 @@ { "profiles": { - "SourceGenDebug": { + "SourceGenerator": { "commandName": "DebugRoslynComponent", "targetProject": "..\\..\\samples\\FunctionApp\\FunctionApp.csproj" } diff --git a/sdk/Sdk.Generators/Sdk.Generators.csproj b/sdk/Sdk.Generators/Sdk.Generators.csproj index 9e8211cc8..08c08cc40 100644 --- a/sdk/Sdk.Generators/Sdk.Generators.csproj +++ b/sdk/Sdk.Generators/Sdk.Generators.csproj @@ -34,5 +34,4 @@ - \ No newline at end of file diff --git a/sdk/Sdk/Targets/Microsoft.Azure.Functions.Worker.Sdk.props b/sdk/Sdk/Targets/Microsoft.Azure.Functions.Worker.Sdk.props index b60e95325..239d53a8f 100644 --- a/sdk/Sdk/Targets/Microsoft.Azure.Functions.Worker.Sdk.props +++ b/sdk/Sdk/Targets/Microsoft.Azure.Functions.Worker.Sdk.props @@ -20,6 +20,9 @@ WARNING: DO NOT MODIFY this file unless you are knowledgeable about MSBuild and + + + - Support sdk-type binding reference type (#1107) +- Add MS Build property to disable source generation of function metadata (it is enabled automatically) (#1200) + - Prop name: `FunctionsMetadataSourceGen_Enabled` \ No newline at end of file diff --git a/test/Sdk.Generator.Tests/TestHelpers.cs b/test/Sdk.Generator.Tests/TestHelpers.cs index 1666a983d..2d62e2669 100644 --- a/test/Sdk.Generator.Tests/TestHelpers.cs +++ b/test/Sdk.Generator.Tests/TestHelpers.cs @@ -34,9 +34,9 @@ public static Task RunTestAsync( Sources = { inputSource }, AdditionalReferences = { - typeof(WorkerExtensionStartupAttribute).Assembly, - }, - }, + typeof(WorkerExtensionStartupAttribute).Assembly + } + } }; if (expectedOutputSource != null && expectedFileName != null) @@ -44,6 +44,10 @@ public static Task RunTestAsync( test.TestState.GeneratedSources.Add((typeof(TSourceGenerator), expectedFileName, SourceText.From(expectedOutputSource, Encoding.UTF8))); } + // Enable SourceGen MSBuild Property for testing + string config = $"is_global = true{Environment.NewLine}build_property.FunctionsMetadataSourceGen_Enabled = {true}"; + test.TestState.AnalyzerConfigFiles.Add(("/.globalconfig", config)); + foreach (var item in extensionAssemblyReferences) { test.TestState.AdditionalReferences.Add(item);