diff --git a/src/DotNetWorker.Grpc/Definition/GrpcFunctionDefinition.cs b/src/DotNetWorker.Grpc/Definition/GrpcFunctionDefinition.cs index 2cc109df7..08ba2d952 100644 --- a/src/DotNetWorker.Grpc/Definition/GrpcFunctionDefinition.cs +++ b/src/DotNetWorker.Grpc/Definition/GrpcFunctionDefinition.cs @@ -19,9 +19,22 @@ public GrpcFunctionDefinition(FunctionLoadRequest loadRequest, IMethodInfoLocato { EntryPoint = loadRequest.Metadata.EntryPoint; Name = loadRequest.Metadata.Name; - PathToAssembly = Path.GetFullPath(loadRequest.Metadata.ScriptFile); Id = loadRequest.FunctionId; + string? scriptRoot = Environment.GetEnvironmentVariable("AzureWebJobsScriptRoot"); + if (string.IsNullOrWhiteSpace(scriptRoot)) + { + throw new InvalidOperationException("The 'AzureWebJobsScriptRoot' environment variable value is not defined. This is a required environment variable that is automatically set by the Azure Functions runtime."); + } + + if (string.IsNullOrWhiteSpace(loadRequest.Metadata.ScriptFile)) + { + throw new InvalidOperationException($"Metadata for function '{loadRequest.Metadata.Name} ({loadRequest.Metadata.FunctionId})' does not specify a 'ScriptFile'."); + } + + string scriptFile = Path.Combine(scriptRoot, loadRequest.Metadata.ScriptFile); + PathToAssembly = Path.GetFullPath(scriptFile); + var grpcBindingsGroup = loadRequest.Metadata.Bindings.GroupBy(kv => kv.Value.Direction); var grpcInputBindings = grpcBindingsGroup.Where(kv => kv.Key == BindingInfo.Types.Direction.In).FirstOrDefault(); var grpcOutputBindings = grpcBindingsGroup.Where(kv => kv.Key != BindingInfo.Types.Direction.In).FirstOrDefault(); diff --git a/src/DotNetWorker.Grpc/DotNetWorker.Grpc.csproj b/src/DotNetWorker.Grpc/DotNetWorker.Grpc.csproj index 7fb6186de..680672fd5 100644 --- a/src/DotNetWorker.Grpc/DotNetWorker.Grpc.csproj +++ b/src/DotNetWorker.Grpc/DotNetWorker.Grpc.csproj @@ -8,8 +8,9 @@ Microsoft.Azure.Functions.Worker.Grpc Microsoft.Azure.Functions.Worker.Grpc true - 7 - -preview2 + 8 + -preview1 + true @@ -39,6 +40,7 @@ Grpc.Net.Client implementation--> + diff --git a/src/DotNetWorker.Grpc/GrpcServiceCollectionExtensions.cs b/src/DotNetWorker.Grpc/GrpcServiceCollectionExtensions.cs index cfce895d9..9ba5842db 100644 --- a/src/DotNetWorker.Grpc/GrpcServiceCollectionExtensions.cs +++ b/src/DotNetWorker.Grpc/GrpcServiceCollectionExtensions.cs @@ -3,22 +3,15 @@ using System; using System.Threading.Channels; -using Grpc.Core; using Microsoft.Azure.Functions.Worker; using Microsoft.Azure.Functions.Worker.Core.FunctionMetadata; using Microsoft.Azure.Functions.Worker.Grpc.Messages; -using static Microsoft.Azure.Functions.Worker.Grpc.Messages.FunctionRpc; using Microsoft.Azure.Functions.Worker.Logging; using Microsoft.Azure.Functions.Worker.Grpc; using Microsoft.Azure.Functions.Worker.Diagnostics; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Options; - -#if NET5_0_OR_GREATER -using Grpc.Net.Client; -#endif namespace Microsoft.Extensions.DependencyInjection { @@ -35,7 +28,7 @@ internal static IServiceCollection RegisterOutputChannel(this IServiceCollection AllowSynchronousContinuations = true }; - return new GrpcHostChannel(System.Threading.Channels.Channel.CreateUnbounded(outputOptions)); + return new GrpcHostChannel(Channel.CreateUnbounded(outputOptions)); }); } @@ -56,40 +49,21 @@ public static IServiceCollection AddGrpc(this IServiceCollection services) // gRPC Core services services.AddSingleton(); - services.AddSingleton(p => - { - IOptions argumentsOptions = p.GetService>() - ?? throw new InvalidOperationException("gRPC Services are not correctly registered."); - - GrpcWorkerStartupOptions arguments = argumentsOptions.Value; - - string uriString = $"http://{arguments.Host}:{arguments.Port}"; - if (!Uri.TryCreate(uriString, UriKind.Absolute, out Uri? grpcUri)) - { - throw new InvalidOperationException($"The gRPC channel URI '{uriString}' could not be parsed."); - } - #if NET5_0_OR_GREATER - GrpcChannel grpcChannel = GrpcChannel.ForAddress(grpcUri, new GrpcChannelOptions() - { - MaxReceiveMessageSize = arguments.GrpcMaxMessageLength, - MaxSendMessageSize = arguments.GrpcMaxMessageLength, - Credentials = ChannelCredentials.Insecure - }); + // If we are running in the native host process, use the native client + // for communication (interop). Otherwise; use the gRPC client. + if (AppContext.GetData("AZURE_FUNCTIONS_NATIVE_HOST") is not null) + { + services.AddSingleton(); + } + else + { + services.AddSingleton(); + } #else - - var options = new ChannelOption[] - { - new ChannelOption(Grpc.Core.ChannelOptions.MaxReceiveMessageLength, arguments.GrpcMaxMessageLength), - new ChannelOption(Grpc.Core.ChannelOptions.MaxSendMessageLength, arguments.GrpcMaxMessageLength) - }; - - Grpc.Core.Channel grpcChannel = new Grpc.Core.Channel(arguments.Host, arguments.Port, ChannelCredentials.Insecure, options); - + services.AddSingleton(); #endif - return new FunctionRpcClient(grpcChannel); - }); services.AddOptions() .Configure((arguments, config) => diff --git a/src/DotNetWorker.Grpc/GrpcWorker.cs b/src/DotNetWorker.Grpc/GrpcWorker.cs index 3989f4d5b..7956d2053 100644 --- a/src/DotNetWorker.Grpc/GrpcWorker.cs +++ b/src/DotNetWorker.Grpc/GrpcWorker.cs @@ -5,10 +5,8 @@ using System.Linq; using System.Runtime.InteropServices; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; using Azure.Core.Serialization; -using Grpc.Core; using Microsoft.Azure.Functions.Worker.Context.Features; using Microsoft.Azure.Functions.Worker.Core.FunctionMetadata; using Microsoft.Azure.Functions.Worker.Grpc; @@ -21,52 +19,43 @@ using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using static Microsoft.Azure.Functions.Worker.Grpc.Messages.FunctionRpc; using MsgType = Microsoft.Azure.Functions.Worker.Grpc.Messages.StreamingMessage.ContentOneofCase; namespace Microsoft.Azure.Functions.Worker { - internal class GrpcWorker : IWorker + internal class GrpcWorker : IWorker, IMessageProcessor { - private readonly ChannelReader _outputReader; - private readonly ChannelWriter _outputWriter; - private readonly IFunctionsApplication _application; - private readonly FunctionRpcClient _rpcClient; private readonly IInvocationFeaturesFactory _invocationFeaturesFactory; private readonly IOutputBindingsInfoProvider _outputBindingsInfoProvider; private readonly IInputConversionFeatureProvider _inputConversionFeatureProvider; private readonly IMethodInfoLocator _methodInfoLocator; - private readonly GrpcWorkerStartupOptions _startupOptions; private readonly WorkerOptions _workerOptions; private readonly ObjectSerializer _serializer; private readonly IHostApplicationLifetime _hostApplicationLifetime; + private readonly IWorkerClientFactory _workerClientFactory; private readonly IInvocationHandler _invocationHandler; private readonly IFunctionMetadataProvider _functionMetadataProvider; - - public GrpcWorker(IFunctionsApplication application, FunctionRpcClient rpcClient, GrpcHostChannel outputChannel, IInvocationFeaturesFactory invocationFeaturesFactory, - IOutputBindingsInfoProvider outputBindingsInfoProvider, IMethodInfoLocator methodInfoLocator, - IOptions startupOptions, IOptions workerOptions, - IInputConversionFeatureProvider inputConversionFeatureProvider, - IFunctionMetadataProvider functionMetadataProvider, - IHostApplicationLifetime hostApplicationLifetime, - ILogger logger) + private IWorkerClient? _workerClient; + + public GrpcWorker(IFunctionsApplication application, + IWorkerClientFactory workerClientFactory, + IInvocationFeaturesFactory invocationFeaturesFactory, + IOutputBindingsInfoProvider outputBindingsInfoProvider, + IMethodInfoLocator methodInfoLocator, + IOptions workerOptions, + IInputConversionFeatureProvider inputConversionFeatureProvider, + IFunctionMetadataProvider functionMetadataProvider, + IHostApplicationLifetime hostApplicationLifetime, + ILogger logger) { - if (outputChannel == null) - { - throw new ArgumentNullException(nameof(outputChannel)); - } - - _outputReader = outputChannel.Channel.Reader; - _outputWriter = outputChannel.Channel.Writer; - _hostApplicationLifetime = hostApplicationLifetime ?? throw new ArgumentNullException(nameof(hostApplicationLifetime)); + _workerClientFactory = workerClientFactory ?? throw new ArgumentNullException(nameof(workerClientFactory)); _application = application ?? throw new ArgumentNullException(nameof(application)); - _rpcClient = rpcClient ?? throw new ArgumentNullException(nameof(rpcClient)); _invocationFeaturesFactory = invocationFeaturesFactory ?? throw new ArgumentNullException(nameof(invocationFeaturesFactory)); _outputBindingsInfoProvider = outputBindingsInfoProvider ?? throw new ArgumentNullException(nameof(outputBindingsInfoProvider)); _methodInfoLocator = methodInfoLocator ?? throw new ArgumentNullException(nameof(methodInfoLocator)); - _startupOptions = startupOptions?.Value ?? throw new ArgumentNullException(nameof(startupOptions)); + _workerOptions = workerOptions?.Value ?? throw new ArgumentNullException(nameof(workerOptions)); _serializer = workerOptions.Value.Serializer ?? throw new InvalidOperationException(nameof(workerOptions.Value.Serializer)); _inputConversionFeatureProvider = inputConversionFeatureProvider ?? throw new ArgumentNullException(nameof(inputConversionFeatureProvider)); @@ -78,54 +67,16 @@ public GrpcWorker(IFunctionsApplication application, FunctionRpcClient rpcClient public async Task StartAsync(CancellationToken token) { - var eventStream = _rpcClient.EventStream(cancellationToken: token); - - await SendStartStreamMessageAsync(eventStream.RequestStream); - - _ = StartWriterAsync(eventStream.RequestStream); - _ = StartReaderAsync(eventStream.ResponseStream); + _workerClient = await _workerClientFactory.StartClientAsync(this, token); } - public Task StopAsync(CancellationToken token) - { - return Task.CompletedTask; - } + public Task StopAsync(CancellationToken token) => Task.CompletedTask; - private async Task SendStartStreamMessageAsync(IClientStreamWriter requestStream) - { - StartStream str = new StartStream() - { - WorkerId = _startupOptions.WorkerId - }; - - StreamingMessage startStream = new StreamingMessage() - { - StartStream = str - }; - - await requestStream.WriteAsync(startStream); - } - - private async Task StartWriterAsync(IClientStreamWriter requestStream) - { - await foreach (StreamingMessage rpcWriteMsg in _outputReader.ReadAllAsync()) - { - await requestStream.WriteAsync(rpcWriteMsg); - } - } - - private async Task StartReaderAsync(IAsyncStreamReader responseStream) - { - while (await responseStream.MoveNext()) - { - await ProcessRequestAsync(responseStream.Current); - } - } - - private Task ProcessRequestAsync(StreamingMessage request) + Task IMessageProcessor.ProcessMessageAsync(StreamingMessage message) { // Dispatch and return. - Task.Run(() => ProcessRequestCoreAsync(request)); + _ = ProcessRequestCoreAsync(message); + return Task.CompletedTask; } @@ -179,7 +130,7 @@ private async Task ProcessRequestCoreAsync(StreamingMessage request) return; } - await _outputWriter.WriteAsync(responseMessage); + await _workerClient!.SendMessageAsync(responseMessage); } internal Task InvocationRequestHandlerAsync(InvocationRequest request) diff --git a/src/DotNetWorker.Grpc/GrpcWorkerClientFactory.cs b/src/DotNetWorker.Grpc/GrpcWorkerClientFactory.cs new file mode 100644 index 000000000..1fafc5aaf --- /dev/null +++ b/src/DotNetWorker.Grpc/GrpcWorkerClientFactory.cs @@ -0,0 +1,144 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Grpc.Core; +using Microsoft.Azure.Functions.Worker.Grpc.Messages; +using Microsoft.Extensions.Options; +using static Microsoft.Azure.Functions.Worker.Grpc.Messages.FunctionRpc; + +#if NET5_0_OR_GREATER +using Grpc.Net.Client; +#else +using GrpcCore = Grpc.Core; +#endif + +namespace Microsoft.Azure.Functions.Worker.Grpc +{ + internal class GrpcWorkerClientFactory : IWorkerClientFactory + { + private readonly GrpcHostChannel _outputChannel; + private readonly GrpcWorkerStartupOptions _startupOptions; + + public GrpcWorkerClientFactory(GrpcHostChannel outputChannel, IOptions startupOptions) + { + _outputChannel = outputChannel ?? throw new ArgumentNullException(nameof(outputChannel)); + _startupOptions = startupOptions?.Value ?? throw new ArgumentNullException(nameof(startupOptions), "gRPC Services are not correctly registered."); + } + + public Task StartClientAsync(IMessageProcessor messageProcessor, CancellationToken token) + => GrpcWorkerClient.CreateAndStartAsync(_outputChannel, _startupOptions, messageProcessor, token); + + private class GrpcWorkerClient : IWorkerClient + { + private readonly FunctionRpcClient _grpcClient; + private readonly GrpcWorkerStartupOptions _startupOptions; + private readonly ChannelReader _outputReader; + private readonly ChannelWriter _outputWriter; + private bool _running; + private IMessageProcessor? _processor; + + public GrpcWorkerClient(GrpcHostChannel outputChannel, GrpcWorkerStartupOptions startupOptions) + { + _startupOptions = startupOptions ?? throw new ArgumentNullException(nameof(startupOptions)); + + _outputReader = outputChannel.Channel.Reader; + _outputWriter = outputChannel.Channel.Writer; + + _grpcClient = CreateClient(); + } + + internal static async Task CreateAndStartAsync(GrpcHostChannel outputChannel, GrpcWorkerStartupOptions startupOptions, IMessageProcessor processor, CancellationToken token) + { + var client = new GrpcWorkerClient(outputChannel, startupOptions); + + await client.StartAsync(processor, token); + + return client; + } + + private async Task StartAsync(IMessageProcessor processor, CancellationToken token) + { + if (_running) + { + throw new InvalidOperationException($"The client is already running. Multiple calls to {nameof(StartAsync)} are not supported."); + } + _running = true; + _processor = processor ?? throw new ArgumentNullException(nameof(processor)); + + var eventStream = _grpcClient.EventStream(cancellationToken: token); + + await SendStartStreamMessageAsync(eventStream.RequestStream); + + _ = StartWriterAsync(eventStream.RequestStream); + _ = StartReaderAsync(eventStream.ResponseStream); + } + + private async Task SendStartStreamMessageAsync(IClientStreamWriter requestStream) + { + StartStream str = new StartStream() + { + WorkerId = _startupOptions.WorkerId + }; + + StreamingMessage startStream = new StreamingMessage() + { + StartStream = str + }; + + await requestStream.WriteAsync(startStream); + } + + public ValueTask SendMessageAsync(StreamingMessage message) => _outputWriter.WriteAsync(message); + + private async Task StartWriterAsync(IClientStreamWriter requestStream) + { + await foreach (StreamingMessage rpcWriteMsg in _outputReader.ReadAllAsync()) + { + await requestStream.WriteAsync(rpcWriteMsg); + } + } + + private async Task StartReaderAsync(IAsyncStreamReader responseStream) + { + while (await responseStream.MoveNext()) + { + await _processor!.ProcessMessageAsync(responseStream.Current); + } + } + + private FunctionRpcClient CreateClient() + { + string uriString = $"http://{_startupOptions.Host}:{_startupOptions.Port}"; + if (!Uri.TryCreate(uriString, UriKind.Absolute, out Uri? grpcUri)) + { + throw new InvalidOperationException($"The gRPC channel URI '{uriString}' could not be parsed."); + } + + +#if NET5_0_OR_GREATER + GrpcChannel grpcChannel = GrpcChannel.ForAddress(grpcUri, new GrpcChannelOptions() + { + MaxReceiveMessageSize = _startupOptions.GrpcMaxMessageLength, + MaxSendMessageSize = _startupOptions.GrpcMaxMessageLength, + Credentials = ChannelCredentials.Insecure + }); +#else + + var options = new ChannelOption[] + { + new ChannelOption(GrpcCore.ChannelOptions.MaxReceiveMessageLength, _startupOptions.GrpcMaxMessageLength), + new ChannelOption(GrpcCore.ChannelOptions.MaxSendMessageLength, _startupOptions.GrpcMaxMessageLength) + }; + + GrpcCore.Channel grpcChannel = new GrpcCore.Channel(_startupOptions.Host, _startupOptions.Port, ChannelCredentials.Insecure, options); + +#endif + return new FunctionRpcClient(grpcChannel); + } + } + } +} diff --git a/src/DotNetWorker.Grpc/IGrpcRequestHandler.cs b/src/DotNetWorker.Grpc/IGrpcRequestHandler.cs new file mode 100644 index 000000000..162d2ae50 --- /dev/null +++ b/src/DotNetWorker.Grpc/IGrpcRequestHandler.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using Microsoft.Azure.Functions.Worker.Grpc.Messages; +using System.Threading.Tasks; + +namespace Microsoft.Azure.Functions.Worker.Grpc +{ + internal interface IMessageProcessor + { + Task ProcessMessageAsync(StreamingMessage request); + } +} diff --git a/src/DotNetWorker.Grpc/IWorkerClient.cs b/src/DotNetWorker.Grpc/IWorkerClient.cs new file mode 100644 index 000000000..a36e546e0 --- /dev/null +++ b/src/DotNetWorker.Grpc/IWorkerClient.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using Microsoft.Azure.Functions.Worker.Grpc.Messages; + +namespace Microsoft.Azure.Functions.Worker.Grpc +{ + internal interface IWorkerClient + { + ValueTask SendMessageAsync(StreamingMessage message); + } +} diff --git a/src/DotNetWorker.Grpc/IWorkerClientFactory.cs b/src/DotNetWorker.Grpc/IWorkerClientFactory.cs new file mode 100644 index 000000000..c6033e8ab --- /dev/null +++ b/src/DotNetWorker.Grpc/IWorkerClientFactory.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Azure.Functions.Worker.Grpc +{ + internal interface IWorkerClientFactory + { + Task StartClientAsync(IMessageProcessor messageProcessor, CancellationToken token); + } +} \ No newline at end of file diff --git a/src/DotNetWorker.Grpc/NativeHostIntegration/NativeHost.cs b/src/DotNetWorker.Grpc/NativeHostIntegration/NativeHost.cs new file mode 100644 index 000000000..2b16e7d6a --- /dev/null +++ b/src/DotNetWorker.Grpc/NativeHostIntegration/NativeHost.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.Azure.Functions.Worker.Grpc.NativeHostIntegration +{ + [StructLayout(LayoutKind.Sequential)] + internal struct NativeHost + { + public IntPtr pNativeApplication; + } +} diff --git a/src/DotNetWorker.Grpc/NativeHostIntegration/NativeMethods.cs b/src/DotNetWorker.Grpc/NativeHostIntegration/NativeMethods.cs new file mode 100644 index 000000000..7d0e03b5c --- /dev/null +++ b/src/DotNetWorker.Grpc/NativeHostIntegration/NativeMethods.cs @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Runtime.InteropServices; +using Google.Protobuf; +using Microsoft.Azure.Functions.Worker.Grpc.Messages; + +namespace Microsoft.Azure.Functions.Worker.Grpc.NativeHostIntegration +{ + internal static unsafe partial class NativeMethods + { + private const string NativeWorkerDll = "FunctionsNetHost.exe"; + + public static NativeHost GetNativeHostData() + { + _ = get_application_properties(out var hostData); + return hostData; + } + + public static void RegisterCallbacks(NativeSafeHandle nativeApplication, + delegate* unmanaged requestCallback, + IntPtr grpcHandler) + { + _ = register_callbacks(nativeApplication, requestCallback, grpcHandler); + } + + public static void SendStreamingMessage(NativeSafeHandle nativeApplication, StreamingMessage streamingMessage) + { + byte[] bytes = streamingMessage.ToByteArray(); + _ = send_streaming_message(nativeApplication, bytes, bytes.Length); + } + + [DllImport(NativeWorkerDll)] + private static extern int get_application_properties(out NativeHost hostData); + + [DllImport(NativeWorkerDll)] + private static extern int send_streaming_message(NativeSafeHandle pInProcessApplication, byte[] streamingMessage, int streamingMessageSize); + + [DllImport(NativeWorkerDll)] + private static extern unsafe int register_callbacks(NativeSafeHandle pInProcessApplication, + delegate* unmanaged requestCallback, + IntPtr grpcHandler); + } +} diff --git a/src/DotNetWorker.Grpc/NativeHostIntegration/NativeSafeHandle.cs b/src/DotNetWorker.Grpc/NativeHostIntegration/NativeSafeHandle.cs new file mode 100644 index 000000000..3a89776ee --- /dev/null +++ b/src/DotNetWorker.Grpc/NativeHostIntegration/NativeSafeHandle.cs @@ -0,0 +1,50 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Runtime.InteropServices; +using System.Threading.Tasks.Sources; + +namespace Microsoft.Azure.Functions.Worker.Grpc.NativeHostIntegration +{ + internal sealed class NativeSafeHandle : SafeHandle, IValueTaskSource + { + private ManualResetValueTaskSourceCore _core; // mutable struct; do not make this readonly + + public override bool IsInvalid => handle == IntPtr.Zero; + public short Version => _core.Version; + + public NativeSafeHandle(IntPtr handle) : base(IntPtr.Zero, ownsHandle: true) + { + this.handle = handle; + } + + protected override bool ReleaseHandle() + { + handle = IntPtr.Zero; + + // Complete the ManualResetValueTaskSourceCore + if (_core.GetStatus(_core.Version) == ValueTaskSourceStatus.Pending) + { + _core.SetResult(null); + } + + return true; + } + + public object? GetResult(short token) + { + return _core.GetResult(token); + } + + public ValueTaskSourceStatus GetStatus(short token) + { + return _core.GetStatus(token); + } + + public void OnCompleted(Action continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) + { + _core.OnCompleted(continuation, state, token, flags); + } + } +} diff --git a/src/DotNetWorker.Grpc/NativeHostIntegration/NativeWorkerClient.cs b/src/DotNetWorker.Grpc/NativeHostIntegration/NativeWorkerClient.cs new file mode 100644 index 000000000..3e357b765 --- /dev/null +++ b/src/DotNetWorker.Grpc/NativeHostIntegration/NativeWorkerClient.cs @@ -0,0 +1,75 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Runtime.InteropServices; +using System.Text.Json; +using System.Threading.Channels; +using System.Threading.Tasks; +using Google.Protobuf; +using Microsoft.Azure.Functions.Worker.Grpc.Messages; +using static System.Net.Mime.MediaTypeNames; + +namespace Microsoft.Azure.Functions.Worker.Grpc.NativeHostIntegration +{ + internal class NativeWorkerClient : IWorkerClient + { + private readonly IMessageProcessor _messageProcessor; + private readonly ChannelReader _outputChannelReader; + private readonly ChannelWriter _outputChannelWriter; + private readonly NativeSafeHandle _application; + private GCHandle _gcHandle; + + private readonly Channel _inbound = Channel.CreateUnbounded(); + + public NativeWorkerClient(IMessageProcessor messageProcessor, GrpcHostChannel outputChannel, NativeHost nativeHostData) + { + _messageProcessor = messageProcessor; + _outputChannelReader = outputChannel.Channel.Reader; + _outputChannelWriter = outputChannel.Channel.Writer; + _application = new NativeSafeHandle(nativeHostData.pNativeApplication); + } + + public unsafe void Start() + { + _gcHandle = GCHandle.Alloc(this); + NativeMethods.RegisterCallbacks(_application, &HandleRequest, (IntPtr)_gcHandle); + + _ = ProcessInbound(); + _ = ProcessOutbound(); + } + + private async Task ProcessInbound() + { + await foreach (StreamingMessage msg in _inbound.Reader.ReadAllAsync()) + { + await _messageProcessor.ProcessMessageAsync(msg); + } + } + + private async Task ProcessOutbound() + { + await foreach (StreamingMessage msg in _outputChannelReader.ReadAllAsync()) + { + NativeMethods.SendStreamingMessage(_application, msg); + } + } + + public ValueTask SendMessageAsync(StreamingMessage message) + { + return _outputChannelWriter.WriteAsync(message); + } + + [UnmanagedCallersOnly] + private static unsafe IntPtr HandleRequest(byte** nativeMessage, int nativeMessageSize, IntPtr grpcHandler) + { + var span = new ReadOnlySpan(*nativeMessage, nativeMessageSize); + var msg = StreamingMessage.Parser.ParseFrom(span); + + NativeWorkerClient handler = (NativeWorkerClient)GCHandle.FromIntPtr(grpcHandler).Target!; + handler._inbound.Writer.TryWrite(msg); + + return IntPtr.Zero; + } + } +} diff --git a/src/DotNetWorker.Grpc/NativeHostIntegration/NativeWorkerClientFactory.cs b/src/DotNetWorker.Grpc/NativeHostIntegration/NativeWorkerClientFactory.cs new file mode 100644 index 000000000..9ac45a37c --- /dev/null +++ b/src/DotNetWorker.Grpc/NativeHostIntegration/NativeWorkerClientFactory.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Azure.Functions.Worker.Grpc.NativeHostIntegration +{ + internal class NativeWorkerClientFactory : IWorkerClientFactory + { + private readonly GrpcHostChannel _hostChannel; + + public NativeWorkerClientFactory(GrpcHostChannel hostChannel) + { + _hostChannel = hostChannel; + } + + public Task StartClientAsync(IMessageProcessor messageProcessor, CancellationToken token) + { + var nativeHostData = NativeMethods.GetNativeHostData(); + + var client = new NativeWorkerClient(messageProcessor, _hostChannel, nativeHostData); + client.Start(); + + return Task.FromResult(client); + } + } +} diff --git a/src/DotNetWorker/DotNetWorker.csproj b/src/DotNetWorker/DotNetWorker.csproj index 69b84cd9a..565330412 100644 --- a/src/DotNetWorker/DotNetWorker.csproj +++ b/src/DotNetWorker/DotNetWorker.csproj @@ -8,8 +8,8 @@ Microsoft.Azure.Functions.Worker Microsoft.Azure.Functions.Worker true - 11 - -preview2 + 12 + -preview1 diff --git a/test/DotNetWorkerTests/GrpcFunctionDefinitionTests.cs b/test/DotNetWorkerTests/GrpcFunctionDefinitionTests.cs index 6a237cadd..009dd9800 100644 --- a/test/DotNetWorkerTests/GrpcFunctionDefinitionTests.cs +++ b/test/DotNetWorkerTests/GrpcFunctionDefinitionTests.cs @@ -3,6 +3,7 @@ using System.IO; using System.Threading; +using Microsoft.Azure.Functions.Tests; using Microsoft.Azure.Functions.Worker.Grpc.Messages; using Microsoft.Azure.Functions.Worker.Http; using Microsoft.Azure.Functions.Worker.Invocation; @@ -18,6 +19,8 @@ public class GrpcFunctionDefinitionTests [Fact] public void Creates() { + using var testVariables = new TestScopedEnvironmentVariable("AzureWebJobsScriptRoot", "."); + var bindingInfoProvider = new DefaultOutputBindingsInfoProvider(); var methodInfoLocator = new DefaultMethodInfoLocator(); diff --git a/test/DotNetWorkerTests/GrpcWorkerTests.cs b/test/DotNetWorkerTests/GrpcWorkerTests.cs index 1821b5693..e4e2993a4 100644 --- a/test/DotNetWorkerTests/GrpcWorkerTests.cs +++ b/test/DotNetWorkerTests/GrpcWorkerTests.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using Azure.Core.Serialization; +using Microsoft.Azure.Functions.Tests; using Microsoft.Azure.Functions.Worker.Context.Features; using Microsoft.Azure.Functions.Worker.Grpc.Messages; using Microsoft.Azure.Functions.Worker.Handlers; @@ -67,6 +68,8 @@ public GrpcWorkerTests() [Fact] public void LoadFunction_ReturnsSuccess() { + using var testVariables = new TestScopedEnvironmentVariable("AzureWebJobsScriptRoot", "test"); + FunctionLoadRequest request = CreateFunctionLoadRequest(); var response = GrpcWorker.FunctionLoadRequestHandler(request, _mockApplication.Object, _mockMethodInfoLocator.Object); @@ -89,6 +92,8 @@ public void LoadFunction_WithProxyMetadata_ReturnsSuccess() [Fact] public void LoadFunction_Throws_ReturnsFailure() { + using var testVariables = new TestScopedEnvironmentVariable("AzureWebJobsScriptRoot", "test"); + _mockApplication .Setup(m => m.LoadFunction(It.IsAny())) .Throws(new InvalidOperationException("whoops")); @@ -105,10 +110,13 @@ public void LoadFunction_Throws_ReturnsFailure() [Fact] public void MethodInfoLocator_Throws_ReturnsFailure() { + using var testVariables = new TestScopedEnvironmentVariable("AzureWebJobsScriptRoot", "test"); + _mockMethodInfoLocator .Setup(m => m.GetMethod(It.IsAny(), It.IsAny())) .Throws(new InvalidOperationException("whoops")); + FunctionLoadRequest request = CreateFunctionLoadRequest(); var response = GrpcWorker.FunctionLoadRequestHandler(request, _mockApplication.Object, _mockMethodInfoLocator.Object); diff --git a/test/TestUtility/TestScopedEnvironmentVariable.cs b/test/TestUtility/TestScopedEnvironmentVariable.cs new file mode 100644 index 000000000..0f112f16a --- /dev/null +++ b/test/TestUtility/TestScopedEnvironmentVariable.cs @@ -0,0 +1,60 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.Azure.Functions.Tests +{ + public class TestScopedEnvironmentVariable : IDisposable + { + private readonly IDictionary _variables; + private readonly IDictionary _existingVariables; + private bool _disposed = false; + + public TestScopedEnvironmentVariable(string name, string value) + : this(new Dictionary { { name, value } }) + { + } + + public TestScopedEnvironmentVariable(IDictionary variables) + { + _variables = variables; + _existingVariables = new Dictionary(variables.Count); + + SetVariables(); + } + + private void SetVariables() + { + foreach (var item in _variables) + { + _existingVariables.Add(item.Key, Environment.GetEnvironmentVariable(item.Key)); + + Environment.SetEnvironmentVariable(item.Key, item.Value); + } + } + + private void ClearVariables() + { + foreach (var item in _variables) + { + Environment.SetEnvironmentVariable(item.Key, _existingVariables[item.Key]); + } + + _existingVariables.Clear(); + } + + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + ClearVariables(); + + _disposed = true; + } + } + + public void Dispose() => Dispose(true); + } +}