Skip to content
56 changes: 56 additions & 0 deletions src/Renci.SshNet/Common/SshOperationCancelledException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using System;
#if NETFRAMEWORK
using System.Runtime.Serialization;
#endif // NETFRAMEWORK

namespace Renci.SshNet.Common
{
/// <summary>
/// The exception that is thrown when operation is timed out.
/// </summary>
#if NETFRAMEWORK
[Serializable]
#endif // NETFRAMEWORK
public class SshOperationCancelledException : SshException
{
/// <summary>
/// Initializes a new instance of the <see cref="SshOperationCancelledException"/> class.
/// </summary>
public SshOperationCancelledException()
{
}

/// <summary>
/// Initializes a new instance of the <see cref="SshOperationCancelledException"/> class.
/// </summary>
/// <param name="message">The message.</param>
public SshOperationCancelledException(string message)
: base(message)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="SshOperationCancelledException"/> class.
/// </summary>
/// <param name="message">The message.</param>
/// <param name="innerException">The inner exception.</param>
public SshOperationCancelledException(string message, Exception innerException)
: base(message, innerException)
{
}

#if NETFRAMEWORK
/// <summary>
/// Initializes a new instance of the <see cref="SshOperationCancelledException"/> class.
/// </summary>
/// <param name="info">The <see cref="SerializationInfo"/> that holds the serialized object data about the exception being thrown.</param>
/// <param name="context">The <see cref="StreamingContext"/> that contains contextual information about the source or destination.</param>
/// <exception cref="ArgumentNullException">The <paramref name="info"/> parameter is <see langword="null"/>.</exception>
/// <exception cref="SerializationException">The class name is <see langword="null"/> or <see cref="Exception.HResult"/> is zero (0). </exception>
protected SshOperationCancelledException(SerializationInfo info, StreamingContext context)
: base(info, context)
{
}
#endif // NETFRAMEWORK
}
}
72 changes: 58 additions & 14 deletions src/Renci.SshNet/SshCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Runtime.ExceptionServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

using Renci.SshNet.Abstractions;
using Renci.SshNet.Channels;
Expand All @@ -26,11 +27,13 @@ public class SshCommand : IDisposable
private CommandAsyncResult _asyncResult;
private AsyncCallback _callback;
private EventWaitHandle _sessionErrorOccuredWaitHandle;
private EventWaitHandle _commandCancelledWaitHandle;
private Exception _exception;
private StringBuilder _result;
private StringBuilder _error;
private bool _hasError;
private bool _isDisposed;
private bool _isCancelled;
private ChannelInputStream _inputStream;
private TimeSpan _commandTimeout;

Expand Down Expand Up @@ -84,7 +87,7 @@ public TimeSpan CommandTimeout
/// <returns>
/// The stream that can be used to transfer data to the command's input stream.
/// </returns>
#pragma warning disable CA1859 // Use concrete types when possible for improved performance
#pragma warning disable CA1859 // Use concrete types when possible for improved performance
public Stream CreateInputStream()
#pragma warning restore CA1859 // Use concrete types when possible for improved performance
{
Expand Down Expand Up @@ -186,7 +189,7 @@ internal SshCommand(ISession session, string commandText, Encoding encoding)
_encoding = encoding;
CommandTimeout = Timeout.InfiniteTimeSpan;
_sessionErrorOccuredWaitHandle = new AutoResetEvent(initialState: false);

_commandCancelledWaitHandle = new AutoResetEvent(initialState: false);
_session.Disconnected += Session_Disconnected;
_session.ErrorOccured += Session_ErrorOccured;
}
Expand Down Expand Up @@ -249,11 +252,11 @@ public IAsyncResult BeginExecute(AsyncCallback callback, object state)

// Create new AsyncResult object
_asyncResult = new CommandAsyncResult
{
AsyncWaitHandle = new ManualResetEvent(initialState: false),
IsCompleted = false,
AsyncState = state,
};
{
AsyncWaitHandle = new ManualResetEvent(initialState: false),
IsCompleted = false,
AsyncState = state,
};

if (_channel is not null)
{
Expand Down Expand Up @@ -349,20 +352,49 @@ public string EndExecute(IAsyncResult asyncResult)

commandAsyncResult.EndCalled = true;

return Result;
if (!_isCancelled)
{
return Result;
}

SetAsyncComplete();
throw new SshOperationCancelledException();
}
}

/// <summary>
/// Cancels command execution in asynchronous scenarios.
/// </summary>
public void CancelAsync()
/// <param name="signalBeforeClose">should exit-signal be sent before attempting to close channel.</param>
/// <param name="forceKill">if true send SIGKILL instead of SIGTERM.</param>
/// <param name="timeout">how long to wait before stop waiting for command and close the channel.</param>
/// <returns>
/// Command Cancellation Task.
/// </returns>
/// <remarks>
/// <para>
/// After sending the exit-signal to the recipient, wait until either <paramref name="timeout"/> is exceeded
/// or the async result is signaled before signaling command cancellation.
/// If the exit-signal always results in the command being cancelled by the recipient, then <paramref name="timeout"/>
/// can be set to <see cref="Timeout.InfiniteTimeSpan"/> to wait until the async result is signaled.
/// </para>
/// </remarks>
public Task CancelAsync(bool signalBeforeClose = true, bool forceKill = false, TimeSpan timeout = default)
{
if (_channel is not null && _channel.IsOpen && _asyncResult is not null)
if (signalBeforeClose)
{
// TODO: check with Oleg if we shouldn't dispose the channel and uninitialize it ?
_channel.Dispose();
var signal = forceKill ? "KILL" : "TERM";
_ = _channel?.SendExitSignalRequest(signal, coreDumped: false, "Command execution has been cancelled.", "en");
}

return Task.Run(() =>
{
var signaledElement = WaitHandle.WaitAny(new[] { _asyncResult.AsyncWaitHandle }, timeout);
if (signaledElement == WaitHandle.WaitTimeout)
{
_ = _commandCancelledWaitHandle?.Set();
}
});
}

/// <summary>
Expand Down Expand Up @@ -430,7 +462,7 @@ private void Session_ErrorOccured(object sender, ExceptionEventArgs e)
_ = _sessionErrorOccuredWaitHandle.Set();
}

private void Channel_Closed(object sender, ChannelEventArgs e)
private void SetAsyncComplete()
{
OutputStream?.Flush();
ExtendedOutputStream?.Flush();
Expand All @@ -446,6 +478,11 @@ private void Channel_Closed(object sender, ChannelEventArgs e)
_ = ((EventWaitHandle) _asyncResult.AsyncWaitHandle).Set();
}

private void Channel_Closed(object sender, ChannelEventArgs e)
{
SetAsyncComplete();
}

private void Channel_RequestReceived(object sender, ChannelRequestEventArgs e)
{
if (e.Info is ExitStatusRequestInfo exitStatusInfo)
Expand Down Expand Up @@ -506,7 +543,8 @@ private void WaitOnHandle(WaitHandle waitHandle)
var waitHandles = new[]
{
_sessionErrorOccuredWaitHandle,
waitHandle
waitHandle,
_commandCancelledWaitHandle
};

var signaledElement = WaitHandle.WaitAny(waitHandles, CommandTimeout);
Expand All @@ -518,6 +556,9 @@ private void WaitOnHandle(WaitHandle waitHandle)
case 1:
// Specified waithandle was signaled
break;
case 2:
_isCancelled = true;
break;
case WaitHandle.WaitTimeout:
throw new SshOperationTimeoutException(string.Format(CultureInfo.CurrentCulture, "Command '{0}' has timed out.", CommandText));
default:
Expand Down Expand Up @@ -620,6 +661,9 @@ protected virtual void Dispose(bool disposing)
_sessionErrorOccuredWaitHandle = null;
}

_commandCancelledWaitHandle?.Dispose();
_commandCancelledWaitHandle = null;

_isDisposed = true;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,49 @@ public void Test_Execute_SingleCommand()
}
}

[TestMethod]
[Timeout(5000)]
public void Test_CancelAsync_Unfinished_Command()
{
using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password);
#region Example SshCommand CancelAsync Unfinished Command Without Sending exit-signal
client.Connect();
var testValue = Guid.NewGuid().ToString();
var command = $"sleep 15s; echo {testValue}";
using var cmd = client.CreateCommand(command);
var asyncResult = cmd.BeginExecute();
_ = cmd.CancelAsync(signalBeforeClose: false);
Assert.ThrowsException<SshOperationCancelledException>(() => cmd.EndExecute(asyncResult));
Assert.IsTrue(asyncResult.IsCompleted);
client.Disconnect();
Assert.AreEqual(string.Empty, cmd.Result.Trim());
#endregion
}

[TestMethod]
public async Task Test_CancelAsync_Finished_Command()
{
using var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password);
#region Example SshCommand CancelAsync Finished Command Without Sending exit-signal
client.Connect();
var testValue = Guid.NewGuid().ToString();
var command = $"echo {testValue}";
using var cmd = client.CreateCommand(command);
var asyncResult = cmd.BeginExecute();
while (!asyncResult.IsCompleted)
{
await Task.Delay(200);
}

_ = cmd.CancelAsync(signalBeforeClose: false);
cmd.EndExecute(asyncResult);
client.Disconnect();

Assert.IsTrue(asyncResult.IsCompleted);
Assert.AreEqual(testValue, cmd.Result.Trim());
#endregion
}

[TestMethod]
public void Test_Execute_OutputStream()
{
Expand Down Expand Up @@ -222,7 +265,7 @@ public void Test_Execute_Command_ExitStatus()
client.Connect();

var cmd = client.RunCommand("exit 128");

Console.WriteLine(cmd.ExitStatus);

client.Disconnect();
Expand Down Expand Up @@ -443,7 +486,7 @@ public void Test_Execute_Invalid_Command()
}

[TestMethod]

public void Test_MultipleThread_100_MultipleConnections()
{
try
Expand Down