Skip to content
50 changes: 50 additions & 0 deletions src/Renci.SshNet/Abstractions/SocketAbstraction.Async.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#if NET6_0_OR_GREATER

using System;
using System.Diagnostics;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;

namespace Renci.SshNet.Abstractions
{
internal static partial class SocketAbstraction
{
public static ValueTask<int> ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken)
{
return socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken);
}

public static ValueTask SendAsync(Socket socket, ReadOnlyMemory<byte> data, CancellationToken cancellationToken = default)
{
Debug.Assert(socket != null);
Debug.Assert(data.Length > 0);

if (cancellationToken.IsCancellationRequested)
{
return ValueTask.FromCanceled(cancellationToken);
}

return SendAsyncCore(socket, data, cancellationToken);

static async ValueTask SendAsyncCore(Socket socket, ReadOnlyMemory<byte> data, CancellationToken cancellationToken)
{
do
{
try
{
var bytesSent = await socket.SendAsync(data, SocketFlags.None, cancellationToken).ConfigureAwait(false);
data = data.Slice(bytesSent);
}
catch (SocketException ex) when (IsErrorResumable(ex.SocketErrorCode))
{
// Buffer may be full; attempt a short delay and retry
await Task.Delay(30, cancellationToken).ConfigureAwait(false);
}
}
while (data.Length > 0);
}
}
}
}
#endif // NET6_0_OR_GREATER
9 changes: 2 additions & 7 deletions src/Renci.SshNet/Abstractions/SocketAbstraction.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace Renci.SshNet.Abstractions
{
internal static class SocketAbstraction
internal static partial class SocketAbstraction
{
public static bool CanRead(Socket socket)
{
Expand Down Expand Up @@ -325,12 +325,7 @@ public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeS
return totalBytesRead;
}

#if NET6_0_OR_GREATER
public static async Task<int> ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken)
{
return await socket.ReceiveAsync(buffer, SocketFlags.None, cancellationToken).ConfigureAwait(false);
}
#else
#if NET6_0_OR_GREATER == false
public static Task<int> ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken)
{
return socket.ReceiveAsync(buffer, 0, buffer.Length, cancellationToken);
Expand Down
4 changes: 4 additions & 0 deletions src/Renci.SshNet/Connection/ProtocolVersionExchange.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ public async Task<SshIdentification> StartAsync(string clientVersion, Socket soc
{
// Immediately send the identification string since the spec states both sides MUST send an identification string
// when the connection has been established
#if NET6_0_OR_GREATER
await SocketAbstraction.SendAsync(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A"), cancellationToken).ConfigureAwait(false);
#else
SocketAbstraction.Send(socket, Encoding.UTF8.GetBytes(clientVersion + "\x0D\x0A"));
#endif // NET6_0_OR_GREATER

var bytesReceived = new List<byte>();

Expand Down
32 changes: 25 additions & 7 deletions src/Renci.SshNet/Session.cs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public class Session : ISession
/// This is also used to ensure that <see cref="_socket"/> will not be disposed
/// while performing a given operation or set of operations on <see cref="_socket"/>.
/// </remarks>
private readonly object _socketDisposeLock = new object();
private readonly SemaphoreSlim _socketDisposeLock = new SemaphoreSlim(1, 1);

/// <summary>
/// Holds an object that is used to ensure only a single thread can connect
Expand Down Expand Up @@ -1127,12 +1127,14 @@ internal void SendMessage(Message message)
/// </para>
/// <para>
/// This method is only to be used when the connection is established, as the locking
/// overhead is not required while establising the connection.
/// overhead is not required while establishing the connection.
/// </para>
/// </remarks>
private void SendPacket(byte[] packet, int offset, int length)
{
lock (_socketDisposeLock)
_socketDisposeLock.Wait();

try
{
if (!_socket.IsConnected())
{
Expand All @@ -1141,6 +1143,10 @@ private void SendPacket(byte[] packet, int offset, int length)

SocketAbstraction.Send(_socket, packet, offset, length);
}
finally
{
_ = _socketDisposeLock.Release();
}
}

/// <summary>
Expand Down Expand Up @@ -1798,8 +1804,9 @@ internal static string ToHex(byte[] bytes)
/// </remarks>
private bool IsSocketConnected()
{
#pragma warning disable S2222 // Locks should be released on all paths
lock (_socketDisposeLock)
_socketDisposeLock.Wait();

try
{
if (!_socket.IsConnected())
{
Expand All @@ -1821,7 +1828,10 @@ private bool IsSocketConnected()
Monitor.Exit(_socketReadLock);
}
}
#pragma warning restore S2222 // Locks should be released on all paths
finally
{
_ = _socketDisposeLock.Release();
}
}

/// <summary>
Expand All @@ -1848,9 +1858,13 @@ private void SocketDisconnectAndDispose()
{
if (_socket != null)
{
lock (_socketDisposeLock)
_socketDisposeLock.Wait();

try
{
#pragma warning disable CA1508 // Avoid dead conditional code; Value could have been changed by another thread.
if (_socket != null)
#pragma warning restore CA1508 // Avoid dead conditional code
{
if (_socket.Connected)
{
Expand Down Expand Up @@ -1879,6 +1893,10 @@ private void SocketDisconnectAndDispose()
_socket = null;
}
}
finally
{
_ = _socketDisposeLock.Release();
}
}
}

Expand Down