diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index a8bb707ae..8f7ffae6d 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -472,11 +472,6 @@ public string ClientVersion /// internal event EventHandler> UserAuthenticationSuccessReceived; - /// - /// Occurs when message received - /// - internal event EventHandler> GlobalRequestReceived; - /// /// Occurs when message received /// @@ -1681,7 +1676,10 @@ internal void OnUserAuthenticationPublicKeyReceived(PublicKeyMessage message) /// message. internal void OnGlobalRequestReceived(GlobalRequestMessage message) { - GlobalRequestReceived?.Invoke(this, new MessageEventArgs(message)); + if (message.WantReply) + { + SendMessage(new RequestFailureMessage()); + } } /// diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs index db0f10f25..daffc6bdf 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected.cs @@ -1,5 +1,7 @@ using System; using System.Linq; +using System.Net.Sockets; +using System.Text; using System.Text.RegularExpressions; using System.Threading; @@ -7,6 +9,7 @@ using Moq; +using Renci.SshNet.Messages.Connection; using Renci.SshNet.Messages.Transport; namespace Renci.SshNet.Tests.Classes @@ -59,7 +62,7 @@ public void ShouldNotIncludeStrictKexPseudoAlgorithmInSubsequentKex() ServerListener.BytesReceived += ServerListener_BytesReceived; - void ServerListener_BytesReceived(byte[] bytesReceived, System.Net.Sockets.Socket socket) + void ServerListener_BytesReceived(byte[] bytesReceived, Socket socket) { if (bytesReceived.Length > 5 && bytesReceived[5] == 20) { @@ -106,6 +109,37 @@ public void SendMessageShouldSendPacketToServer() Assert.AreEqual(1, ServerBytesReceivedRegister.Count); } + [TestMethod] + [DataRow(true)] + [DataRow(false)] + public void UnknownGlobalRequestWithWantReply(bool wantReply) + { + Thread.Sleep(100); + + ServerBytesReceivedRegister.Clear(); + + var globalRequest = + new GlobalRequestMessage(Encoding.ASCII.GetBytes("unknown-request"), wantReply).GetPacket(8, null); + + ServerSocket.Send(globalRequest, 4, globalRequest.Length - 4, SocketFlags.None); + + Thread.Sleep(100); + + if (wantReply) + { + // Should have sent a failure reply. + Assert.AreEqual(1, ServerBytesReceivedRegister.Count); + Assert.AreEqual(82, ServerBytesReceivedRegister[0][5], "Expected to have sent SSH_MSG_REQUEST_FAILURE(82)"); + } + else + { + // Should not have sent any reply. + Assert.AreEqual(0, ServerBytesReceivedRegister.Count); + } + + Assert.AreEqual(0, ErrorOccurredRegister.Count); + } + [TestMethod] public void SessionIdShouldReturnExchangeHashCalculatedFromKeyExchangeInitMessage() {