diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs index 6e247ef9937376..2385f3b5ca8ae3 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/Internal/MsQuicSafeHandle.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Diagnostics; using System.Runtime.InteropServices; using Microsoft.Quic; @@ -25,7 +26,14 @@ internal unsafe class MsQuicSafeHandle : SafeHandle public override bool IsInvalid => handle == IntPtr.Zero; - public QUIC_HANDLE* QuicHandle => (QUIC_HANDLE*)DangerousGetHandle(); + public QUIC_HANDLE* QuicHandle + { + get + { + ObjectDisposedException.ThrowIf(IsInvalid, this); + return (QUIC_HANDLE*)DangerousGetHandle(); + } + } public MsQuicSafeHandle(QUIC_HANDLE* handle, delegate* unmanaged[Cdecl] releaseAction, SafeHandleType safeHandleType) : base((IntPtr)handle, ownsHandle: true) @@ -41,8 +49,9 @@ public MsQuicSafeHandle(QUIC_HANDLE* handle, delegate* unmanaged[Cdecl](Func call) + { + ObjectDisposedException.ThrowIf(IsInvalid, this); + bool success = false; + try + { + DangerousAddRef(ref success); + Debug.Assert(success); + return call(this); + } + finally + { + if (success) + { + DangerousRelease(); + } + } + } + + public void SafeCall(Action call) + { + ObjectDisposedException.ThrowIf(IsInvalid, this); + bool success = false; + try + { + DangerousAddRef(ref success); + Debug.Assert(success); + call(this); + } + finally + { + if (success) + { + DangerousRelease(); + } + } + } + public override string ToString() => _traceId ??= $"[{s_typeName[(int)_type]}][0x{DangerousGetHandle():X11}]"; } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs index b03d3652e758e2..2cf5799923f56b 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs @@ -242,7 +242,7 @@ private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options, if (address is not null) { QuicAddr quicAddress = new IPEndPoint(address, port).ToQuicAddr(); - MsQuicHelpers.SetMsQuicParameter(_handle, QUIC_PARAM_CONN_REMOTE_ADDRESS, quicAddress); + _handle.SafeCall(handle => MsQuicHelpers.SetMsQuicParameter(handle, QUIC_PARAM_CONN_REMOTE_ADDRESS, quicAddress)); } // RemoteEndPoint is DnsEndPoint containing hostname that is different from requested SNI. // --> Resolve the hostname and set the IP directly, use requested SNI in ConnectionStart. @@ -257,7 +257,7 @@ private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options, } QuicAddr quicAddress = new IPEndPoint(addresses[0], port).ToQuicAddr(); - MsQuicHelpers.SetMsQuicParameter(_handle, QUIC_PARAM_CONN_REMOTE_ADDRESS, quicAddress); + _handle.SafeCall(handle => MsQuicHelpers.SetMsQuicParameter(handle, QUIC_PARAM_CONN_REMOTE_ADDRESS, quicAddress)); } // RemoteEndPoint is DnsEndPoint containing hostname that is the same as the requested SNI. // --> Let MsQuic resolve the hostname/SNI, give address family hint is specified in DnsEndPoint. @@ -276,7 +276,7 @@ private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options, if (options.LocalEndPoint is not null) { QuicAddr quicAddress = options.LocalEndPoint.ToQuicAddr(); - MsQuicHelpers.SetMsQuicParameter(_handle, QUIC_PARAM_CONN_LOCAL_ADDRESS, quicAddress); + _handle.SafeCall(handle => MsQuicHelpers.SetMsQuicParameter(handle, QUIC_PARAM_CONN_LOCAL_ADDRESS, quicAddress)); } _sslConnectionOptions = new SslConnectionOptions( @@ -294,13 +294,13 @@ private async ValueTask FinishConnectAsync(QuicClientConnectionOptions options, { unsafe { - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionStart( - _handle.QuicHandle, + _handle.SafeCall(handle => ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionStart( + handle.QuicHandle, _configuration.QuicHandle, (ushort)addressFamily, (sbyte*)targetHostPtr, (ushort)port), - "ConnectionStart failed"); + "ConnectionStart failed")); } } finally @@ -334,10 +334,10 @@ internal ValueTask FinishHandshakeAsync(QuicServerConnectionOptions options, str unsafe { - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionSetConfiguration( - _handle.QuicHandle, + _handle.SafeCall(handle => ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->ConnectionSetConfiguration( + handle.QuicHandle, _configuration.QuicHandle), - "ConnectionSetConfiguration failed"); + "ConnectionSetConfiguration failed")); } } @@ -359,7 +359,7 @@ public async ValueTask OpenOutboundStreamAsync(QuicStreamType type, QuicStream? stream = null; try { - stream = new QuicStream(_handle, type, _defaultStreamErrorCode); + stream = _handle.SafeCall(handle => new QuicStream((MsQuicContextSafeHandle)handle, type, _defaultStreamErrorCode)); await stream.StartAsync(cancellationToken).ConfigureAwait(false); } catch @@ -392,6 +392,7 @@ public async ValueTask AcceptInboundStreamAsync(CancellationToken ca throw new InvalidOperationException(SR.net_quic_accept_not_allowed); } + GCHandle keepObject = GCHandle.Alloc(this); try { return await _acceptQueue.Reader.ReadAsync(cancellationToken).ConfigureAwait(false); @@ -401,6 +402,10 @@ public async ValueTask AcceptInboundStreamAsync(CancellationToken ca ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); throw; } + finally + { + keepObject.Free(); + } } /// @@ -425,10 +430,10 @@ public ValueTask CloseAsync(long errorCode, CancellationToken cancellationToken { unsafe { - MsQuicApi.Api.ApiTable->ConnectionShutdown( - _handle.QuicHandle, + _handle.SafeCall(handle => MsQuicApi.Api.ApiTable->ConnectionShutdown( + handle.QuicHandle, QUIC_CONNECTION_SHUTDOWN_FLAGS.NONE, - (ulong)errorCode); + (ulong)errorCode)); } } @@ -469,8 +474,8 @@ private unsafe int HandleEventShutdownInitiatedByPeer(ref SHUTDOWN_INITIATED_BY_ } private unsafe int HandleEventShutdownComplete(ref SHUTDOWN_COMPLETE_DATA data) { - _shutdownTcs.TrySetResult(); _acceptQueue.Writer.TryComplete(ExceptionDispatchInfo.SetCurrentStackTrace(ThrowHelper.GetOperationAbortedException())); + _shutdownTcs.TrySetResult(); return QUIC_STATUS_SUCCESS; } private unsafe int HandleEventLocalAddressChanged(ref LOCAL_ADDRESS_CHANGED_DATA data) @@ -577,10 +582,10 @@ public async ValueTask DisposeAsync() { unsafe { - MsQuicApi.Api.ApiTable->ConnectionShutdown( - _handle.QuicHandle, + _handle.SafeCall(handle => MsQuicApi.Api.ApiTable->ConnectionShutdown( + handle.QuicHandle, QUIC_CONNECTION_SHUTDOWN_FLAGS.NONE, - (ulong)_defaultCloseErrorCode); + (ulong)_defaultCloseErrorCode)); } } diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs index 11c4d77731e033..37d0e0d2079d7c 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicListener.cs @@ -162,6 +162,7 @@ public async ValueTask AcceptConnectionAsync(CancellationToken c { ObjectDisposedException.ThrowIf(_disposed == 1, this); + GCHandle keepObject = GCHandle.Alloc(this); try { PendingConnection pendingConnection = await _acceptQueue.Reader.ReadAsync(cancellationToken).ConfigureAwait(false); @@ -175,6 +176,10 @@ public async ValueTask AcceptConnectionAsync(CancellationToken c ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); throw; } + finally + { + keepObject.Free(); + } } private unsafe int HandleEventNewConnection(ref NEW_CONNECTION_DATA data) diff --git a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs index c11f3029a2136c..21879961694c94 100644 --- a/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs +++ b/src/libraries/System.Net.Quic/src/System/Net/Quic/QuicStream.cs @@ -90,6 +90,7 @@ public sealed partial class QuicStream } }; private MsQuicBuffers _sendBuffers = new MsQuicBuffers(); + private object _sendBuffersLock = new object(); private readonly long _defaultErrorCode; @@ -220,9 +221,9 @@ internal ValueTask StartAsync(CancellationToken cancellationToken = default) { unsafe { - int status = MsQuicApi.Api.ApiTable->StreamStart( - _handle.QuicHandle, - QUIC_STREAM_START_FLAGS.SHUTDOWN_ON_FAIL | QUIC_STREAM_START_FLAGS.INDICATE_PEER_ACCEPT); + int status = _handle.SafeCall(handle => MsQuicApi.Api.ApiTable->StreamStart( + handle.QuicHandle, + QUIC_STREAM_START_FLAGS.SHUTDOWN_ON_FAIL | QUIC_STREAM_START_FLAGS.INDICATE_PEER_ACCEPT)); if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? exception)) { _startedTcs.TrySetException(exception); @@ -297,9 +298,9 @@ public override async ValueTask ReadAsync(Memory buffer, Cancellation { unsafe { - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamReceiveSetEnabled( - _handle.QuicHandle, - 1), + ThrowHelper.ThrowIfMsQuicError(_handle.SafeCall(handle => MsQuicApi.Api.ApiTable->StreamReceiveSetEnabled( + handle.QuicHandle, + 1)), "StreamReceivedSetEnabled failed"); } } @@ -360,19 +361,34 @@ public ValueTask WriteAsync(ReadOnlyMemory buffer, bool completeWrites, Ca return valueTask; } - _sendBuffers.Initialize(buffer); - unsafe + lock (_sendBuffersLock) { - int status = MsQuicApi.Api.ApiTable->StreamSend( - _handle.QuicHandle, - _sendBuffers.Buffers, - (uint)_sendBuffers.Count, - completeWrites ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE, - null); - if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? exception)) + ObjectDisposedException.ThrowIf(_disposed == 1, this); // TODO: valueTask is left unobserved + unsafe { - _sendBuffers.Reset(); - _sendTcs.TrySetException(exception, final: true); + if (_sendBuffers.Count > 0 && _sendBuffers.Buffers[0].Buffer != null) + { + // _sendBuffers are not reset, meaning SendComplete for the previous WriteAsync call didn't arrive yet. + // In case of cancellation, the task from _sendTcs is finished before the aborting. It is technically possible for subsequent + // WriteAsync to grab the next task from _sendTcs and start executing before SendComplete event occurs for the previous (canceled) write. + // This is not an "invalid nested call", because the previous task has finished. Best guess is to mimic OperationAborted as it will be from Abort + // that would execute soon enough, if not already. Not final, because Abort should be the one to set final exception. + _sendTcs.TrySetException(ThrowHelper.GetOperationAbortedException(SR.net_quic_writing_aborted), final: false); + return valueTask; + } + + _sendBuffers.Initialize(buffer); + int status = _handle.SafeCall(handle => MsQuicApi.Api.ApiTable->StreamSend( + handle.QuicHandle, + _sendBuffers.Buffers, + (uint)_sendBuffers.Count, + completeWrites ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE, + null)); + if (ThrowHelper.TryGetStreamExceptionForMsQuicStatus(status, out Exception? exception)) + { + _sendBuffers.Reset(); + _sendTcs.TrySetException(exception, final: true); + } } } @@ -419,10 +435,10 @@ public void Abort(QuicAbortDirection abortDirection, long errorCode) unsafe { - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamShutdown( - _handle.QuicHandle, + ThrowHelper.ThrowIfMsQuicError(_handle.SafeCall(handle => MsQuicApi.Api.ApiTable->StreamShutdown( + handle.QuicHandle, flags, - (ulong)errorCode), + (ulong)errorCode)), "StreamShutdown failed"); } } @@ -442,10 +458,10 @@ public void CompleteWrites() { unsafe { - ThrowHelper.ThrowIfMsQuicError(MsQuicApi.Api.ApiTable->StreamShutdown( - _handle.QuicHandle, + ThrowHelper.ThrowIfMsQuicError(_handle.SafeCall(handle => MsQuicApi.Api.ApiTable->StreamShutdown( + handle.QuicHandle, QUIC_STREAM_SHUTDOWN_FLAGS.GRACEFUL, - default), + default)), "StreamShutdown failed"); } } @@ -490,7 +506,12 @@ private unsafe int HandleEventReceive(ref RECEIVE data) } private unsafe int HandleEventSendComplete(ref SEND_COMPLETE data) { - _sendBuffers.Reset(); + // In case of cancellation, the task from _sendTcs is finished before the aborting. It is technically possible for subsequent WriteAsync to grab the next task + // from _sendTcs and start executing before SendComplete event occurs for the previous (canceled) write + lock (_sendBuffersLock) + { + _sendBuffers.Reset(); + } if (data.Canceled == 0) { _sendTcs.TrySetResult(); @@ -653,15 +674,18 @@ public override async ValueTask DisposeAsync() await valueTask.ConfigureAwait(false); _handle.Dispose(); - // TODO: memory leak if not disposed - _sendBuffers.Dispose(); + lock (_sendBuffersLock) + { + // TODO: memory leak if not disposed + _sendBuffers.Dispose(); + } unsafe void StreamShutdown(QUIC_STREAM_SHUTDOWN_FLAGS flags, long errorCode) { - int status = MsQuicApi.Api.ApiTable->StreamShutdown( - _handle.QuicHandle, + int status = _handle.SafeCall(handle => MsQuicApi.Api.ApiTable->StreamShutdown( + handle.QuicHandle, flags, - (ulong)errorCode); + (ulong)errorCode)); if (StatusFailed(status)) { if (NetEventSource.Log.IsEnabled())