diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs new file mode 100644 index 000000000..6cc6918ea --- /dev/null +++ b/src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs @@ -0,0 +1,50 @@ +#if NET6_0_OR_GREATER + +using System; +using System.Diagnostics; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +namespace Renci.SshNet.Abstractions +{ + internal static partial class SocketAbstraction + { + public static ValueTask ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken) + { + return socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken); + } + + public static ValueTask SendAsync(Socket socket, ReadOnlyMemory data, CancellationToken cancellationToken = default) + { + Debug.Assert(socket != null); + Debug.Assert(data.Length > 0); + + if (cancellationToken.IsCancellationRequested) + { + return ValueTask.FromCanceled(cancellationToken); + } + + return SendAsyncCore(socket, data, cancellationToken); + + static async ValueTask SendAsyncCore(Socket socket, ReadOnlyMemory data, CancellationToken cancellationToken) + { + do + { + try + { + var bytesSent = await socket.SendAsync(data, SocketFlags.None, cancellationToken).ConfigureAwait(false); + data = data.Slice(bytesSent); + } + catch (SocketException ex) when (IsErrorResumable(ex.SocketErrorCode)) + { + // Buffer may be full; attempt a short delay and retry + await Task.Delay(30, cancellationToken).ConfigureAwait(false); + } + } + while (data.Length > 0); + } + } + } +} +#endif // NET6_0_OR_GREATER diff --git a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs index f5f840336..25c72de29 100644 --- a/src/Renci.SshNet/Abstractions/SocketAbstraction.cs +++ b/src/Renci.SshNet/Abstractions/SocketAbstraction.cs @@ -10,7 +10,7 @@ namespace Renci.SshNet.Abstractions { - internal static class SocketAbstraction + internal static partial class SocketAbstraction { public static bool CanRead(Socket socket) { @@ -325,12 +325,7 @@ public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeS return totalBytesRead; } -#if NET6_0_OR_GREATER - public static async Task ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken) - { - return await socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken).ConfigureAwait(false); - } -#else +#if NET6_0_OR_GREATER == false public static Task ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken) { return socket.ReceiveAsync(buffer, 0, buffer.Length, cancellationToken); diff --git a/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs b/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs index bde732c06..b14da93c0 100644 --- a/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs +++ b/src/Renci.SshNet/Connection/ProtocolVersionExchange.cs @@ -81,7 +81,11 @@ public async Task StartAsync(string clientVersion, Socket soc { // Immediately send the identification string since the spec states both sides MUST send an identification string // when the connection has been established +#if NET6_0_OR_GREATER + await SocketAbstraction.SendAsync(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A"), cancellationToken).ConfigureAwait(false); +#else SocketAbstraction.Send(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A")); +#endif // NET6_0_OR_GREATER var bytesReceived = new List(); diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index dab919552..122e28d3d 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -119,7 +119,7 @@ public class Session : ISession /// This is also used to ensure that will not be disposed /// while performing a given operation or set of operations on . /// - private readonly object _socketDisposeLock = new object(); + private readonly SemaphoreSlim _socketDisposeLock = new SemaphoreSlim(1, 1); /// /// Holds an object that is used to ensure only a single thread can connect @@ -1127,12 +1127,14 @@ internal void SendMessage(Message message) /// /// /// This method is only to be used when the connection is established, as the locking - /// overhead is not required while establising the connection. + /// overhead is not required while establishing the connection. /// /// private void SendPacket(byte[] packet, int offset, int length) { - lock (_socketDisposeLock) + _socketDisposeLock.Wait(); + + try { if (!_socket.IsConnected()) { @@ -1141,6 +1143,10 @@ private void SendPacket(byte[] packet, int offset, int length) SocketAbstraction.Send(_socket, packet, offset, length); } + finally + { + _ = _socketDisposeLock.Release(); + } } /// @@ -1798,8 +1804,9 @@ internal static string ToHex(byte[] bytes) /// private bool IsSocketConnected() { -#pragma warning disable S2222 // Locks should be released on all paths - lock (_socketDisposeLock) + _socketDisposeLock.Wait(); + + try { if (!_socket.IsConnected()) { @@ -1821,7 +1828,10 @@ private bool IsSocketConnected() Monitor.Exit(_socketReadLock); } } -#pragma warning restore S2222 // Locks should be released on all paths + finally + { + _ = _socketDisposeLock.Release(); + } } /// @@ -1848,9 +1858,13 @@ private void SocketDisconnectAndDispose() { if (_socket != null) { - lock (_socketDisposeLock) + _socketDisposeLock.Wait(); + + try { +#pragma warning disable CA1508 // Avoid dead conditional code; Value could have been changed by another thread. if (_socket != null) +#pragma warning restore CA1508 // Avoid dead conditional code { if (_socket.Connected) { @@ -1879,6 +1893,10 @@ private void SocketDisconnectAndDispose() _socket = null; } } + finally + { + _ = _socketDisposeLock.Release(); + } } }