Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ public Socket(System.Net.Sockets.SocketType socketType, System.Net.Sockets.Proto
public bool UseOnlyOverlappedIO { get { throw null; } set { } }
public System.Net.Sockets.Socket Accept() { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptAsync() { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket> AcceptAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
public System.Threading.Tasks.Task<System.Net.Sockets.Socket> AcceptAsync(System.Net.Sockets.Socket? acceptSocket) { throw null; }
public System.Threading.Tasks.ValueTask<System.Net.Sockets.Socket> AcceptAsync(System.Net.Sockets.Socket? acceptSocket, System.Threading.CancellationToken cancellationToken) { throw null; }
public bool AcceptAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public System.IAsyncResult BeginAccept(System.AsyncCallback? callback, object? state) { throw null; }
public System.IAsyncResult BeginAccept(int receiveSize, System.AsyncCallback? callback, object? state) { throw null; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,9 @@ namespace System.Net.Sockets
{
public partial class Socket
{
/// <summary>Cached instance for accept operations.</summary>
private TaskSocketAsyncEventArgs<Socket>? _acceptEventArgs;

/// <summary>Cached instance for receive operations that return <see cref="ValueTask{Int32}"/>. Also used for ConnectAsync operations.</summary>
private AwaitableSocketAsyncEventArgs? _singleBufferReceiveEventArgs;
/// <summary>Cached instance for send operations that return <see cref="ValueTask{Int32}"/>.</summary>
/// <summary>Cached instance for send operations that return <see cref="ValueTask{Int32}"/>. Also used for AcceptAsync operations.</summary>
private AwaitableSocketAsyncEventArgs? _singleBufferSendEventArgs;

/// <summary>Cached instance for receive operations that return <see cref="Task{Int32}"/>.</summary>
Expand All @@ -32,54 +29,44 @@ public partial class Socket
/// Accepts an incoming connection.
/// </summary>
/// <returns>An asynchronous task that completes with the accepted Socket.</returns>
public Task<Socket> AcceptAsync() => AcceptAsync((Socket?)null);
public Task<Socket> AcceptAsync() => AcceptAsync((Socket?)null, CancellationToken.None).AsTask();

/// <summary>
/// Accepts an incoming connection.
/// </summary>
/// <param name="acceptSocket">The socket to use for accepting the connection.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
/// <returns>An asynchronous task that completes with the accepted Socket.</returns>
public Task<Socket> AcceptAsync(Socket? acceptSocket)
{
// Get any cached SocketAsyncEventArg we may have.
TaskSocketAsyncEventArgs<Socket>? saea = Interlocked.Exchange(ref _acceptEventArgs, null);
if (saea is null)
{
saea = new TaskSocketAsyncEventArgs<Socket>();
saea.Completed += (s, e) => CompleteAccept((Socket)s!, (TaskSocketAsyncEventArgs<Socket>)e);
}
public ValueTask<Socket> AcceptAsync(CancellationToken cancellationToken) => AcceptAsync((Socket?)null, cancellationToken);

// Configure the SAEA.
saea.AcceptSocket = acceptSocket;
/// <summary>
/// Accepts an incoming connection.
/// </summary>
/// <param name="acceptSocket">The socket to use for accepting the connection.</param>
/// <returns>An asynchronous task that completes with the accepted Socket.</returns>
public Task<Socket> AcceptAsync(Socket? acceptSocket) => AcceptAsync(acceptSocket, CancellationToken.None).AsTask();

// Initiate the accept operation.
Task<Socket> t;
if (AcceptAsync(saea))
/// <summary>
/// Accepts an incoming connection.
/// </summary>
/// <param name="acceptSocket">The socket to use for accepting the connection.</param>
/// <param name="cancellationToken">A cancellation token that can be used to cancel the asynchronous operation.</param>
/// <returns>An asynchronous task that completes with the accepted Socket.</returns>
public ValueTask<Socket> AcceptAsync(Socket? acceptSocket, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
// The operation is completing asynchronously (it may have already completed).
// Get the task for the operation, with appropriate synchronization to coordinate
// with the async callback that'll be completing the task.
bool responsibleForReturningToPool;
t = saea.GetCompletionResponsibility(out responsibleForReturningToPool).Task;
if (responsibleForReturningToPool)
{
// We're responsible for returning it only if the callback has already been invoked
// and gotten what it needs from the SAEA; otherwise, the callback will return it.
ReturnSocketAsyncEventArgs(saea);
}
return ValueTask.FromCanceled<Socket>(cancellationToken);
}
else
{
// The operation completed synchronously. Get a task for it.
t = saea.SocketError == SocketError.Success ?
Task.FromResult(saea.AcceptSocket!) :
Task.FromException<Socket>(GetException(saea.SocketError));

// There won't be a callback, and we're done with the SAEA, so return it to the pool.
ReturnSocketAsyncEventArgs(saea);
}
AwaitableSocketAsyncEventArgs saea =
Interlocked.Exchange(ref _singleBufferSendEventArgs, null) ??
new AwaitableSocketAsyncEventArgs(this, isReceiveForCaching: false);

return t;
Debug.Assert(saea.BufferList == null);
saea.SetBuffer(null, 0, 0);
saea.AcceptSocket = acceptSocket;
saea.WrapExceptionsForNetworkStream = false;
return saea.AcceptAsync(this, cancellationToken);
}

/// <summary>
Expand Down Expand Up @@ -738,34 +725,6 @@ private Task<int> GetTaskForSendReceive(bool pending, TaskSocketAsyncEventArgs<i
return t;
}

/// <summary>Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool.</summary>
private static void CompleteAccept(Socket s, TaskSocketAsyncEventArgs<Socket> saea)
{
// Pull the relevant state off of the SAEA
SocketError error = saea.SocketError;
Socket? acceptSocket = saea.AcceptSocket;

// Synchronize with the initiating thread. If the synchronous caller already got what
// it needs from the SAEA, then we can return it to the pool now. Otherwise, it'll be
// responsible for returning it once it's gotten what it needs from it.
bool responsibleForReturningToPool;
AsyncTaskMethodBuilder<Socket> builder = saea.GetCompletionResponsibility(out responsibleForReturningToPool);
if (responsibleForReturningToPool)
{
s.ReturnSocketAsyncEventArgs(saea);
}

// Complete the builder/task with the results.
if (error == SocketError.Success)
{
builder.SetResult(acceptSocket!);
}
else
{
builder.SetException(GetException(error));
}
}

/// <summary>Completes the SocketAsyncEventArg's Task with the result of the send or receive, and returns it to the specified pool.</summary>
private static void CompleteSendReceive(Socket s, TaskSocketAsyncEventArgs<int> saea, bool isReceive)
{
Expand Down Expand Up @@ -824,29 +783,9 @@ private void ReturnSocketAsyncEventArgs(TaskSocketAsyncEventArgs<int> saea, bool
}
}

/// <summary>Returns a <see cref="TaskSocketAsyncEventArgs{TResult}"/> instance for reuse.</summary>
/// <param name="saea">The instance to return.</param>
private void ReturnSocketAsyncEventArgs(TaskSocketAsyncEventArgs<Socket> saea)
{
// Reset state on the SAEA before returning it. But do not reset buffer state. That'll be done
// if necessary by the consumer, but we want to keep the buffers due to likely subsequent reuse
// and the costs associated with changing them.
saea.AcceptSocket = null;
saea._accessed = false;
saea._builder = default;

// Write this instance back as a cached instance, only if there isn't currently one cached.
if (Interlocked.CompareExchange(ref _acceptEventArgs, saea, null) != null)
{
// Couldn't return it, so dispose it.
saea.Dispose();
}
}

/// <summary>Dispose of any cached <see cref="TaskSocketAsyncEventArgs{TResult}"/> instances.</summary>
private void DisposeCachedTaskSocketAsyncEventArgs()
{
Interlocked.Exchange(ref _acceptEventArgs, null)?.Dispose();
Interlocked.Exchange(ref _multiBufferReceiveEventArgs, null)?.Dispose();
Interlocked.Exchange(ref _multiBufferSendEventArgs, null)?.Dispose();
Interlocked.Exchange(ref _singleBufferReceiveEventArgs, null)?.Dispose();
Expand Down Expand Up @@ -907,7 +846,7 @@ internal AsyncTaskMethodBuilder<TResult> GetCompletionResponsibility(out bool re
}

/// <summary>A SocketAsyncEventArgs that can be awaited to get the result of an operation.</summary>
internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource<int>, IValueTaskSource<SocketReceiveFromResult>, IValueTaskSource<SocketReceiveMessageFromResult>
internal sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, IValueTaskSource, IValueTaskSource<int>, IValueTaskSource<Socket>, IValueTaskSource<SocketReceiveFromResult>, IValueTaskSource<SocketReceiveMessageFromResult>
{
private static readonly Action<object?> s_completedSentinel = new Action<object?>(state => throw new InvalidOperationException(SR.Format(SR.net_sockets_valuetaskmisuse, nameof(s_completedSentinel))));
/// <summary>The owning socket.</summary>
Expand Down Expand Up @@ -987,6 +926,28 @@ protected override void OnCompleted(SocketAsyncEventArgs _)
}
}

/// <summary>Initiates an accept operation on the associated socket.</summary>
/// <returns>This instance.</returns>
public ValueTask<Socket> AcceptAsync(Socket socket, CancellationToken cancellationToken)
{
Debug.Assert(Volatile.Read(ref _continuation) == null, "Expected null continuation to indicate reserved for use");

if (socket.AcceptAsync(this, cancellationToken))
{
_cancellationToken = cancellationToken;
return new ValueTask<Socket>(this, _token);
}

Socket acceptSocket = AcceptSocket!;
SocketError error = SocketError;

Release();

return error == SocketError.Success ?
new ValueTask<Socket>(acceptSocket) :
ValueTask.FromException<Socket>(CreateException(error));
}

/// <summary>Initiates a receive operation on the associated socket.</summary>
/// <returns>This instance.</returns>
public ValueTask<int> ReceiveAsync(Socket socket, CancellationToken cancellationToken)
Expand Down Expand Up @@ -1288,7 +1249,7 @@ private void InvokeContinuation(Action<object?> continuation, object? state, boo
/// Unlike TaskAwaiter's GetResult, this does not block until the operation completes: it must only
/// be used once the operation has completed. This is handled implicitly by await.
/// </remarks>
public int GetResult(short token)
int IValueTaskSource<int>.GetResult(short token)
{
if (token != _token)
{
Expand Down Expand Up @@ -1326,6 +1287,26 @@ void IValueTaskSource.GetResult(short token)
}
}

Socket IValueTaskSource<Socket>.GetResult(short token)
{
if (token != _token)
{
ThrowIncorrectTokenException();
}

SocketError error = SocketError;
Socket acceptSocket = AcceptSocket!;
CancellationToken cancellationToken = _cancellationToken;

Release();

if (error != SocketError.Success)
{
ThrowException(error, cancellationToken);
}
return acceptSocket;
}

SocketReceiveFromResult IValueTaskSource<SocketReceiveFromResult>.GetResult(short token)
{
if (token != _token)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2656,7 +2656,9 @@ public void Shutdown(SocketShutdown how)
// Async methods
//

public bool AcceptAsync(SocketAsyncEventArgs e)
public bool AcceptAsync(SocketAsyncEventArgs e) => AcceptAsync(e, CancellationToken.None);

private bool AcceptAsync(SocketAsyncEventArgs e, CancellationToken cancellationToken)
{
ThrowIfDisposed();

Expand Down Expand Up @@ -2689,7 +2691,7 @@ public bool AcceptAsync(SocketAsyncEventArgs e)
SocketError socketError;
try
{
socketError = e.DoOperationAccept(this, _handle, acceptHandle);
socketError = e.DoOperationAccept(this, _handle, acceptHandle, cancellationToken);
}
catch (Exception ex)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1433,7 +1433,7 @@ public SocketError Accept(byte[] socketAddress, ref int socketAddressLen, out In
return operation.ErrorCode;
}

public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action<IntPtr, byte[], int, SocketError> callback)
public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, out IntPtr acceptedFd, Action<IntPtr, byte[], int, SocketError> callback, CancellationToken cancellationToken)
{
Debug.Assert(socketAddress != null, "Expected non-null socketAddress");
Debug.Assert(socketAddressLen > 0, $"Unexpected socketAddressLen: {socketAddressLen}");
Expand All @@ -1456,7 +1456,7 @@ public SocketError AcceptAsync(byte[] socketAddress, ref int socketAddressLen, o
operation.SocketAddress = socketAddress;
operation.SocketAddressLen = socketAddressLen;

if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber))
if (!_receiveQueue.StartAsyncOperation(this, operation, observedSequenceNumber, cancellationToken))
{
socketAddressLen = operation.SocketAddressLen;
acceptedFd = operation.AcceptedFileDescriptor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ private void CompleteAcceptOperation(IntPtr acceptedFileDescriptor, byte[] socke
_acceptAddressBufferCount = socketAddressSize;
}

internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle? acceptHandle)
internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle handle, SafeSocketHandle? acceptHandle, CancellationToken cancellationToken)
{
if (!_buffer.Equals(default))
{
Expand All @@ -64,7 +64,7 @@ internal unsafe SocketError DoOperationAccept(Socket socket, SafeSocketHandle ha

IntPtr acceptedFd;
int socketAddressLen = _acceptAddressBufferCount / 2;
SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback);
SocketError socketError = handle.AsyncContext.AcceptAsync(_acceptBuffer!, ref socketAddressLen, out acceptedFd, AcceptCompletionCallback, cancellationToken);

if (socketError != SocketError.IOPending)
{
Expand Down
Loading