diff --git a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackConnection.cs b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackConnection.cs index b34a64074120d3..31a0dd7893669a 100644 --- a/src/libraries/Common/tests/System/Net/Http/Http3LoopbackConnection.cs +++ b/src/libraries/Common/tests/System/Net/Http/Http3LoopbackConnection.cs @@ -49,6 +49,9 @@ internal sealed class Http3LoopbackConnection : GenericLoopbackConnection private Http3LoopbackStream _inboundControlStream; // Inbound control stream from client private Http3LoopbackStream _outboundControlStream; // Our outbound control stream + public Http3LoopbackStream OutboundControlStream => _outboundControlStream ?? throw new Exception("Control stream has not been opened yet"); + public Http3LoopbackStream InboundControlStream => _inboundControlStream ?? throw new Exception("Inbound control stream has not been accepted yet"); + public Http3LoopbackConnection(QuicConnection connection) { _connection = connection; diff --git a/src/libraries/System.Net.Http/src/System.Net.Http.csproj b/src/libraries/System.Net.Http/src/System.Net.Http.csproj index e84a7e1c466279..e1c24b7265fe42 100644 --- a/src/libraries/System.Net.Http/src/System.Net.Http.csproj +++ b/src/libraries/System.Net.Http/src/System.Net.Http.csproj @@ -446,6 +446,7 @@ + diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs index de79b6c3da418b..5493f3e2b699a1 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3Connection.cs @@ -7,6 +7,7 @@ using System.Runtime.Versioning; using System.Net.Quic; using System.IO; +using System.Linq; using System.Collections.Generic; using System.Diagnostics; using System.Globalization; @@ -368,6 +369,16 @@ private async Task SendSettingsAsync() try { _clientControl = await _connection!.OpenOutboundStreamAsync(QuicStreamType.Unidirectional).ConfigureAwait(false); + + // Server MUST NOT abort our control stream, setup a continuation which will react accordingly + _ = _clientControl.WritesClosed.ContinueWith(t => + { + if (t.Exception?.InnerException is QuicException ex && ex.QuicError == QuicError.StreamAborted) + { + Abort(HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.ClosedCriticalStream)); + } + }, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously, TaskScheduler.Current); + await _clientControl.WriteAsync(_pool.Settings.Http3SettingsFrame, CancellationToken.None).ConfigureAwait(false); } catch (Exception ex) @@ -571,70 +582,78 @@ private async Task ProcessServerStreamAsync(QuicStream stream) /// private async Task ProcessServerControlStreamAsync(QuicStream stream, ArrayBuffer buffer) { - using (buffer) + try { - // Read the first frame of the control stream. Per spec: - // A SETTINGS frame MUST be sent as the first frame of each control stream. - - (Http3FrameType? frameType, long payloadLength) = await ReadFrameEnvelopeAsync().ConfigureAwait(false); - - if (frameType == null) + using (buffer) { - // Connection closed prematurely, expected SETTINGS frame. - throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.ClosedCriticalStream); - } + // Read the first frame of the control stream. Per spec: + // A SETTINGS frame MUST be sent as the first frame of each control stream. - if (frameType != Http3FrameType.Settings) - { - throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.MissingSettings); - } + (Http3FrameType? frameType, long payloadLength) = await ReadFrameEnvelopeAsync().ConfigureAwait(false); - await ProcessSettingsFrameAsync(payloadLength).ConfigureAwait(false); + if (frameType == null) + { + // Connection closed prematurely, expected SETTINGS frame. + throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.ClosedCriticalStream); + } - // Read subsequent frames. + if (frameType != Http3FrameType.Settings) + { + throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.MissingSettings); + } - while (true) - { - (frameType, payloadLength) = await ReadFrameEnvelopeAsync().ConfigureAwait(false); + await ProcessSettingsFrameAsync(payloadLength).ConfigureAwait(false); + + // Read subsequent frames. - switch (frameType) + while (true) { - case Http3FrameType.GoAway: - await ProcessGoAwayFrameAsync(payloadLength).ConfigureAwait(false); - break; - case Http3FrameType.Settings: - // If an endpoint receives a second SETTINGS frame on the control stream, the endpoint MUST respond with a connection error of type H3_FRAME_UNEXPECTED. - throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.UnexpectedFrame); - case Http3FrameType.Headers: // Servers should not send these frames to a control stream. - case Http3FrameType.Data: - case Http3FrameType.MaxPushId: - case Http3FrameType.ReservedHttp2Priority: // These frames are explicitly reserved and must never be sent. - case Http3FrameType.ReservedHttp2Ping: - case Http3FrameType.ReservedHttp2WindowUpdate: - case Http3FrameType.ReservedHttp2Continuation: - throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.UnexpectedFrame); - case Http3FrameType.PushPromise: - case Http3FrameType.CancelPush: - // Because we haven't sent any MAX_PUSH_ID frame, it is invalid to receive any push-related frames as they will all reference a too-large ID. - throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.IdError); - case null: - // End of stream reached. If we're shutting down, stop looping. Otherwise, this is an error (this stream should not be closed for life of connection). - bool shuttingDown; - lock (SyncObj) - { - shuttingDown = ShuttingDown; - } - if (!shuttingDown) - { - throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.ClosedCriticalStream); - } - return; - default: - await SkipUnknownPayloadAsync(frameType.GetValueOrDefault(), payloadLength).ConfigureAwait(false); - break; + (frameType, payloadLength) = await ReadFrameEnvelopeAsync().ConfigureAwait(false); + + switch (frameType) + { + case Http3FrameType.GoAway: + await ProcessGoAwayFrameAsync(payloadLength).ConfigureAwait(false); + break; + case Http3FrameType.Settings: + // If an endpoint receives a second SETTINGS frame on the control stream, the endpoint MUST respond with a connection error of type H3_FRAME_UNEXPECTED. + throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.UnexpectedFrame); + case Http3FrameType.Headers: // Servers should not send these frames to a control stream. + case Http3FrameType.Data: + case Http3FrameType.MaxPushId: + case Http3FrameType.ReservedHttp2Priority: // These frames are explicitly reserved and must never be sent. + case Http3FrameType.ReservedHttp2Ping: + case Http3FrameType.ReservedHttp2WindowUpdate: + case Http3FrameType.ReservedHttp2Continuation: + throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.UnexpectedFrame); + case Http3FrameType.PushPromise: + case Http3FrameType.CancelPush: + // Because we haven't sent any MAX_PUSH_ID frame, it is invalid to receive any push-related frames as they will all reference a too-large ID. + throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.IdError); + case null: + // End of stream reached. If we're shutting down, stop looping. Otherwise, this is an error (this stream should not be closed for life of connection). + bool shuttingDown; + lock (SyncObj) + { + shuttingDown = ShuttingDown; + } + if (!shuttingDown) + { + throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.ClosedCriticalStream); + } + return; + default: + await SkipUnknownPayloadAsync(frameType.GetValueOrDefault(), payloadLength).ConfigureAwait(false); + break; + } } } } + catch (QuicException ex) when (ex.QuicError == QuicError.StreamAborted) + { + // Peers MUST NOT close the control stream + throw HttpProtocolException.CreateHttp3ConnectionException(Http3ErrorCode.ClosedCriticalStream); + } async ValueTask<(Http3FrameType? frameType, long payloadLength)> ReadFrameEnvelopeAsync() { diff --git a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs index a11f8b8e27f960..9f8f9bf036480c 100644 --- a/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs +++ b/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/Http3RequestStream.cs @@ -272,6 +272,11 @@ await Task.WhenAny(sendContentTask, readResponseTask).ConfigureAwait(false) == s Exception abortException = _connection.Abort(HttpProtocolException.CreateHttp3ConnectionException(code, SR.net_http_http3_connection_close)); throw new HttpRequestException(SR.net_http_client_execution_error, abortException); } + catch (QuicException ex) when (ex.QuicError == QuicError.OperationAborted && _connection.AbortException != null) + { + // we close the connection, propagate the AbortException + throw new HttpRequestException(SR.net_http_client_execution_error, _connection.AbortException); + } // It is possible for user's Content code to throw an unexpected OperationCanceledException. catch (OperationCanceledException ex) when (ex.CancellationToken == _requestBodyCancellationSource.Token || ex.CancellationToken == cancellationToken) { diff --git a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs index 2a2a5267653345..cd6056949cc220 100644 --- a/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs +++ b/src/libraries/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.Http3.cs @@ -1596,6 +1596,96 @@ public async Task ServerSendsTrailingHeaders_Success() } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ServerClosesOutboundControlStream_ClientClosesConnection(bool graceful) + { + using Http3LoopbackServer server = CreateHttp3LoopbackServer(); + + SemaphoreSlim semaphore = new SemaphoreSlim(0); + Task serverTask = Task.Run(async () => + { + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + + // wait for incoming request + await using Http3LoopbackStream requestStream = await connection.AcceptRequestStreamAsync(); + + // abort the control stream + if (graceful) + { + await connection.OutboundControlStream.SendResponseBodyAsync(Array.Empty(), isFinal: true); + } + else + { + connection.OutboundControlStream.Abort(Http3LoopbackConnection.H3_INTERNAL_ERROR); + } + + // wait for client task before tearing down the requestStream and connection + await semaphore.WaitAsync(); + }); + + Task clientTask = Task.Run(async () => + { + using HttpClient client = CreateHttpClient(); + + using HttpRequestMessage request = new() + { + Method = HttpMethod.Get, + RequestUri = server.Address, + Version = HttpVersion30, + VersionPolicy = HttpVersionPolicy.RequestVersionExact + }; + + await AssertProtocolErrorAsync(Http3LoopbackConnection.H3_CLOSED_CRITICAL_STREAM, () => client.SendAsync(request)); + semaphore.Release(); + }); + + await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(200_000); + } + + [Fact] + public async Task ServerClosesInboundControlStream_ClientClosesConnection() + { + using Http3LoopbackServer server = CreateHttp3LoopbackServer(); + + SemaphoreSlim semaphore = new SemaphoreSlim(0); + Task serverTask = Task.Run(async () => + { + await using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync(); + + // wait for incoming request + (Http3LoopbackStream controlStream, Http3LoopbackStream requestStream) = await connection.AcceptControlAndRequestStreamAsync(); + + await using (controlStream) + await using (requestStream) + { + controlStream.Abort(Http3LoopbackConnection.H3_INTERNAL_ERROR); + // wait for client task before tearing down the requestStream and connection + await semaphore.WaitAsync(); + } + + }); + + Task clientTask = Task.Run(async () => + { + using HttpClient client = CreateHttpClient(); + + using HttpRequestMessage request = new() + { + Method = HttpMethod.Get, + RequestUri = server.Address, + Version = HttpVersion30, + VersionPolicy = HttpVersionPolicy.RequestVersionExact + }; + + await AssertProtocolErrorAsync(Http3LoopbackConnection.H3_CLOSED_CRITICAL_STREAM, () => client.SendAsync(request)); + semaphore.Release(); + }); + + await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(200_000); + } + private static async Task AssertThrowsQuicExceptionAsync(QuicError expectedError, Func testCode) { QuicException ex = await Assert.ThrowsAsync(testCode);