diff --git a/samples/FunctionApp/Program.cs b/samples/FunctionApp/Program.cs index ddb965c1c..431ac149c 100644 --- a/samples/FunctionApp/Program.cs +++ b/samples/FunctionApp/Program.cs @@ -1,9 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the MIT License. See License.txt in the project root for license information. -using System; -using System.Diagnostics; -using System.Threading; using System.Threading.Tasks; using Microsoft.Azure.Functions.Worker; using Microsoft.Extensions.DependencyInjection; diff --git a/src/DotNetWorker.Core/Context/Features/IDictionaryExtensions.cs b/src/DotNetWorker.Core/Context/Features/IDictionaryExtensions.cs index 78696b9af..1796cd5c6 100644 --- a/src/DotNetWorker.Core/Context/Features/IDictionaryExtensions.cs +++ b/src/DotNetWorker.Core/Context/Features/IDictionaryExtensions.cs @@ -8,49 +8,59 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.ComponentModel; -namespace Microsoft.Azure.Functions.Worker +internal static class IDictionaryExtensions { - internal static class IDictionaryExtensions + internal static bool TryAdd(this IDictionary dictionary, TKey key, TValue value) { - internal static bool TryAdd(this IDictionary dictionary, TKey key, TValue value) + if (key == null) { - if (key == null) - { - throw new ArgumentNullException(nameof(key)); - } - - if (dictionary.ContainsKey(key)) - { - return false; - } - - dictionary.Add(key, value); - return true; + throw new ArgumentNullException(nameof(key)); } + + if (dictionary.ContainsKey(key)) + { + return false; + } + + dictionary.Add(key, value); + return true; } +} - internal static class ConcurrentDictionaryExtensions +internal static class ConcurrentDictionaryExtensions +{ + public static TValue GetOrAdd(this ConcurrentDictionary dictionary, TKey key, Func valueFactory, TArg factoryArgument) { - public static TValue GetOrAdd(this ConcurrentDictionary dictionary, TKey key, Func valueFactory, TArg factoryArgument) + if (dictionary == null) { - if (dictionary == null) - { - throw new ArgumentNullException(nameof(dictionary)); - } - - if (key == null) - { - throw new ArgumentNullException(nameof(key)); - } - - if (valueFactory == null) - { - throw new ArgumentNullException(nameof(valueFactory)); - } - - return dictionary.GetOrAdd(key, k => valueFactory(k, factoryArgument)); + throw new ArgumentNullException(nameof(dictionary)); } + + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + + if (valueFactory == null) + { + throw new ArgumentNullException(nameof(valueFactory)); + } + + return dictionary.GetOrAdd(key, k => valueFactory(k, factoryArgument)); } } + +internal static class KeyValuePairExtensions +{ + // Based on https://source.dot.net/#System.Private.CoreLib/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/KeyValuePair.cs,aa57b8e336bf7f59 + [EditorBrowsable(EditorBrowsableState.Never)] + public static void Deconstruct(this KeyValuePair pair, out TKey key, out TValue value) + { + key = pair.Key; + value = pair.Value; + } +} + #endif diff --git a/src/DotNetWorker.Core/Hosting/WorkerOptions.cs b/src/DotNetWorker.Core/Hosting/WorkerOptions.cs index 85044ba89..c9c3aead6 100644 --- a/src/DotNetWorker.Core/Hosting/WorkerOptions.cs +++ b/src/DotNetWorker.Core/Hosting/WorkerOptions.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the MIT License. See License.txt in the project root for license information. +using System.Collections.Generic; using System.Text.Json; using Azure.Core.Serialization; @@ -22,12 +23,26 @@ public class WorkerOptions /// public InputConverterCollection InputConverters { get; } = new InputConverterCollection(); + /// + /// Gets the optional worker capabilities. + /// + public IDictionary Capabilities { get; } = new Dictionary() + { + // Enable these by default, although they are not strictly required and can be removed + { "HandlesWorkerTerminateMessage", bool.TrueString }, + { "HandlesInvocationCancelMessage", bool.TrueString } + }; + /// /// Gets and sets the flag for opting in to unwrapping user-code-thrown /// exceptions when they are surfaced to the Host. /// - public bool EnableUserCodeException { get; set; } = false; - + public bool EnableUserCodeException + { + get => GetBoolCapability(nameof(EnableUserCodeException)); + set => SetBoolCapability(nameof(EnableUserCodeException), value); + } + /// /// Gets or sets a value that determines if empty entries should be included in the function trigger message payload. /// For example, if a set of entries were sent to a messaging service such as Service Bus or Event Hub and your function @@ -35,6 +50,29 @@ public class WorkerOptions /// function code as trigger data when this setting value is . When it is , /// All entries will be sent to the function code as it is. Default value for this setting is . /// - public bool IncludeEmptyEntriesInMessagePayload { get; set; } + public bool IncludeEmptyEntriesInMessagePayload + { + get => GetBoolCapability(nameof(IncludeEmptyEntriesInMessagePayload)); + set => SetBoolCapability(nameof(IncludeEmptyEntriesInMessagePayload), value); + } + + private bool GetBoolCapability(string name) + { + return Capabilities.TryGetValue(name, out string? value) && bool.TryParse(value, out bool b) && b; + } + + // For false values, the host does not expect the capability to exist; there are some cases where this + // will be interpreted as "true" just because the key is there. + private void SetBoolCapability(string name, bool value) + { + if (value) + { + Capabilities[name] = bool.TrueString; + } + else + { + Capabilities.Remove(name); + } + } } } diff --git a/src/DotNetWorker.Grpc/GrpcWorker.cs b/src/DotNetWorker.Grpc/GrpcWorker.cs index 781360377..f83f5a2b3 100644 --- a/src/DotNetWorker.Grpc/GrpcWorker.cs +++ b/src/DotNetWorker.Grpc/GrpcWorker.cs @@ -212,24 +212,20 @@ internal static WorkerInitResponse WorkerInitRequestHandler(WorkerInitRequest re response.WorkerMetadata.CustomProperties.Add("Worker.Grpc.Version", typeof(GrpcWorker).Assembly.GetName().Version?.ToString()); - response.Capabilities.Add("RpcHttpBodyOnly", bool.TrueString); - response.Capabilities.Add("RawHttpBodyBytes", bool.TrueString); - response.Capabilities.Add("RpcHttpTriggerMetadataRemoved", bool.TrueString); - response.Capabilities.Add("UseNullableValueDictionaryForHttp", bool.TrueString); - response.Capabilities.Add("TypedDataCollection", bool.TrueString); - response.Capabilities.Add("WorkerStatus", bool.TrueString); - response.Capabilities.Add("HandlesWorkerTerminateMessage", bool.TrueString); - response.Capabilities.Add("HandlesInvocationCancelMessage", bool.TrueString); - - if (workerOptions.EnableUserCodeException) + // Add additional capabilities defined by WorkerOptions + foreach ((string key, string value) in workerOptions.Capabilities) { - response.Capabilities.Add("EnableUserCodeException", bool.TrueString); - } - if (workerOptions.IncludeEmptyEntriesInMessagePayload) - { - response.Capabilities.Add("IncludeEmptyEntriesInMessagePayload", bool.TrueString); + response.Capabilities[key] = value; } + // Add required capabilities; these cannot be modified and will override anything from WorkerOptions + response.Capabilities["RpcHttpBodyOnly"] = bool.TrueString; + response.Capabilities["RawHttpBodyBytes"] = bool.TrueString; + response.Capabilities["RpcHttpTriggerMetadataRemoved"] = bool.TrueString; + response.Capabilities["UseNullableValueDictionaryForHttp"] = bool.TrueString; + response.Capabilities["TypedDataCollection"] = bool.TrueString; + response.Capabilities["WorkerStatus"] = bool.TrueString; + return response; } diff --git a/test/DotNetWorkerTests/GrpcWorkerTests.cs b/test/DotNetWorkerTests/GrpcWorkerTests.cs index db5f7b7f8..ba4d29bed 100644 --- a/test/DotNetWorkerTests/GrpcWorkerTests.cs +++ b/test/DotNetWorkerTests/GrpcWorkerTests.cs @@ -15,7 +15,10 @@ using Microsoft.Azure.Functions.Worker.Handlers; using Microsoft.Azure.Functions.Worker.Invocation; using Microsoft.Azure.Functions.Worker.OutputBindings; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; using Moq; using Xunit; @@ -40,7 +43,8 @@ public GrpcWorkerTests() _mockApplication .Setup(m => m.CreateContext(It.IsAny(), It.IsAny())) - .Returns((IInvocationFeatures f, CancellationToken ct) => { + .Returns((IInvocationFeatures f, CancellationToken ct) => + { _context = new TestFunctionContext(f, ct); return _context; }); @@ -176,6 +180,7 @@ public void InitRequest_ReturnsExpectedCapabilities_BasedOnWorkerOptions( string expectedCapabilityValue = null) { var workerOptions = new WorkerOptions(); + // Update boolean property values of workerOption based on test input parameters. workerOptions.GetType().GetProperty(booleanPropertyName)?.SetValue(workerOptions, booleanPropertyValue); @@ -202,6 +207,36 @@ public void InitRequest_ReturnsExpectedCapabilities_BasedOnWorkerOptions( } } + [Fact] + public void WorkerOptions_CanChangeOptionalCapabilities() + { + var host = new HostBuilder() + .ConfigureFunctionsWorkerDefaults((WorkerOptions options) => + { + options.Capabilities.Remove("HandlesWorkerTerminateMessage"); + options.Capabilities.Add("SomeNewCapability", bool.TrueString); + }).Build(); + + var workerOptions = host.Services.GetService>().Value; + var response = GrpcWorker.WorkerInitRequestHandler(new(), workerOptions); + + void AssertKeyAndValue(KeyValuePair kvp, string expectedKey, string expectedValue) + { + Assert.Same(expectedKey, kvp.Key); + Assert.Same(expectedValue, kvp.Value); + } + + Assert.Collection(response.Capabilities.OrderBy(p => p.Key), + c => AssertKeyAndValue(c, "HandlesInvocationCancelMessage", bool.TrueString), + c => AssertKeyAndValue(c, "RawHttpBodyBytes", bool.TrueString), + c => AssertKeyAndValue(c, "RpcHttpBodyOnly", bool.TrueString), + c => AssertKeyAndValue(c, "RpcHttpTriggerMetadataRemoved", bool.TrueString), + c => AssertKeyAndValue(c, "SomeNewCapability", bool.TrueString), + c => AssertKeyAndValue(c, "TypedDataCollection", bool.TrueString), + c => AssertKeyAndValue(c, "UseNullableValueDictionaryForHttp", bool.TrueString), + c => AssertKeyAndValue(c, "WorkerStatus", bool.TrueString)); + } + [Fact] public async Task Invoke_ReturnsSuccess() {