From d586e0645b2e62a20fc8a26f4dd45eea2cc70e35 Mon Sep 17 00:00:00 2001 From: Robert Hague Date: Sat, 3 May 2025 19:11:02 +0200 Subject: [PATCH] Read the underlying buffer in SshDataStream SshDataStream is a MemoryStream, so we can access the buffer directly. Also simplify some usage in PrivateKeyFile. --- src/Renci.SshNet/Common/Extensions.cs | 24 +++++ src/Renci.SshNet/Common/SshDataStream.cs | 111 +++++++++++---------- src/Renci.SshNet/PrivateKeyFile.OpenSSH.cs | 39 ++++---- src/Renci.SshNet/PrivateKeyFile.PuTTY.cs | 22 ++-- src/Renci.SshNet/PrivateKeyFile.SSHCOM.cs | 31 +++--- src/Renci.SshNet/PrivateKeyFile.cs | 61 ----------- 6 files changed, 129 insertions(+), 159 deletions(-) diff --git a/src/Renci.SshNet/Common/Extensions.cs b/src/Renci.SshNet/Common/Extensions.cs index 00dc27b90..bc72c70ba 100644 --- a/src/Renci.SshNet/Common/Extensions.cs +++ b/src/Renci.SshNet/Common/Extensions.cs @@ -19,7 +19,9 @@ namespace Renci.SshNet.Common /// internal static class Extensions { +#pragma warning disable S4136 // Method overloads should be grouped together internal static byte[] ToArray(this ServiceName serviceName) +#pragma warning restore S4136 // Method overloads should be grouped together { switch (serviceName) { @@ -382,6 +384,28 @@ internal static bool Remove(this Dictionary dictiona value = default; return false; } + + internal static ArraySegment Slice(this ArraySegment arraySegment, int index) + { + return new ArraySegment(arraySegment.Array, arraySegment.Offset + index, arraySegment.Count - index); + } + + internal static ArraySegment Slice(this ArraySegment arraySegment, int index, int count) + { + return new ArraySegment(arraySegment.Array, arraySegment.Offset + index, count); + } + + internal static T[] ToArray(this ArraySegment arraySegment) + { + if (arraySegment.Count == 0) + { + return Array.Empty(); + } + + var array = new T[arraySegment.Count]; + Array.Copy(arraySegment.Array, arraySegment.Offset, array, 0, arraySegment.Count); + return array; + } #endif } } diff --git a/src/Renci.SshNet/Common/SshDataStream.cs b/src/Renci.SshNet/Common/SshDataStream.cs index 653eefd5b..f2ff82f70 100644 --- a/src/Renci.SshNet/Common/SshDataStream.cs +++ b/src/Renci.SshNet/Common/SshDataStream.cs @@ -1,4 +1,6 @@ using System; +using System.Buffers.Binary; +using System.Diagnostics; using System.Globalization; using System.IO; using System.Numerics; @@ -27,7 +29,7 @@ public SshDataStream(int capacity) /// The array of unsigned bytes from which to create the current stream. /// is . public SshDataStream(byte[] buffer) - : base(buffer) + : base(buffer ?? throw new ArgumentNullException(nameof(buffer)), 0, buffer.Length, writable: true, publiclyVisible: true) { } @@ -39,7 +41,7 @@ public SshDataStream(byte[] buffer) /// The number of bytes to load. /// is . public SshDataStream(byte[] buffer, int offset, int count) - : base(buffer, offset, count) + : base(buffer, offset, count, writable: true, publiclyVisible: true) { } @@ -58,19 +60,6 @@ public bool IsEndOfData } #if NETFRAMEWORK || NETSTANDARD2_0 - private int Read(Span buffer) - { - var sharedBuffer = System.Buffers.ArrayPool.Shared.Rent(buffer.Length); - - var numRead = Read(sharedBuffer, 0, buffer.Length); - - sharedBuffer.AsSpan(0, numRead).CopyTo(buffer); - - System.Buffers.ArrayPool.Shared.Return(sharedBuffer); - - return numRead; - } - private void Write(ReadOnlySpan buffer) { var sharedBuffer = System.Buffers.ArrayPool.Shared.Rent(buffer.Length); @@ -90,7 +79,7 @@ private void Write(ReadOnlySpan buffer) public void Write(uint value) { Span bytes = stackalloc byte[4]; - System.Buffers.Binary.BinaryPrimitives.WriteUInt32BigEndian(bytes, value); + BinaryPrimitives.WriteUInt32BigEndian(bytes, value); Write(bytes); } @@ -101,7 +90,7 @@ public void Write(uint value) public void Write(ulong value) { Span bytes = stackalloc byte[8]; - System.Buffers.Binary.BinaryPrimitives.WriteUInt64BigEndian(bytes, value); + BinaryPrimitives.WriteUInt64BigEndian(bytes, value); Write(bytes); } @@ -137,6 +126,7 @@ public void Write(byte[] data) /// is . public void Write(string s, Encoding encoding) { + ThrowHelper.ThrowIfNull(s); ThrowHelper.ThrowIfNull(encoding); #if NETSTANDARD2_1 || NET @@ -153,12 +143,21 @@ public void Write(string s, Encoding encoding) } /// - /// Reads a byte array from the SSH data stream. + /// Reads a length-prefixed byte array from the SSH data stream. /// /// /// The byte array read from the SSH data stream. /// public byte[] ReadBinary() + { + return ReadBinarySegment().ToArray(); + } + + /// + /// Reads a length-prefixed byte array from the SSH data stream, + /// returned as a view over the underlying buffer. + /// + internal ArraySegment ReadBinarySegment() { var length = ReadUInt32(); @@ -167,7 +166,23 @@ public byte[] ReadBinary() throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Data longer than {0} is not supported.", int.MaxValue)); } - return ReadBytes((int)length); + var buffer = GetRemainingBuffer().Slice(0, (int)length); + + Position += length; + + return buffer; + } + + /// + /// Gets a view over the remaining data in the underlying buffer. + /// + private ArraySegment GetRemainingBuffer() + { + var success = TryGetBuffer(out var buffer); + + Debug.Assert(success, "Expected buffer to be publicly visible"); + + return buffer.Slice((int)Position); } /// @@ -205,11 +220,11 @@ public void WriteBinary(byte[] buffer, int offset, int count) /// public BigInteger ReadBigInt() { - var data = ReadBinary(); - #if NETSTANDARD2_1 || NET + var data = ReadBinarySegment(); return new BigInteger(data, isBigEndian: true); #else + var data = ReadBinary(); Array.Reverse(data); return new BigInteger(data); #endif @@ -223,9 +238,9 @@ public BigInteger ReadBigInt() /// public ushort ReadUInt16() { - Span bytes = stackalloc byte[2]; - ReadBytes(bytes); - return System.Buffers.Binary.BinaryPrimitives.ReadUInt16BigEndian(bytes); + var ret = BinaryPrimitives.ReadUInt16BigEndian(GetRemainingBuffer()); + Position += sizeof(ushort); + return ret; } /// @@ -236,9 +251,9 @@ public ushort ReadUInt16() /// public uint ReadUInt32() { - Span span = stackalloc byte[4]; - ReadBytes(span); - return System.Buffers.Binary.BinaryPrimitives.ReadUInt32BigEndian(span); + var ret = BinaryPrimitives.ReadUInt32BigEndian(GetRemainingBuffer()); + Position += sizeof(uint); + return ret; } /// @@ -249,9 +264,9 @@ public uint ReadUInt32() /// public ulong ReadUInt64() { - Span span = stackalloc byte[8]; - ReadBytes(span); - return System.Buffers.Binary.BinaryPrimitives.ReadUInt64BigEndian(span); + var ret = BinaryPrimitives.ReadUInt64BigEndian(GetRemainingBuffer()); + Position += sizeof(ulong); + return ret; } /// @@ -265,19 +280,13 @@ public string ReadString(Encoding encoding = null) { encoding ??= Encoding.UTF8; - var length = ReadUInt32(); - - if (length > int.MaxValue) - { - throw new NotSupportedException(string.Format(CultureInfo.CurrentCulture, "Strings longer than {0} is not supported.", int.MaxValue)); - } + var bytes = ReadBinarySegment(); - var bytes = ReadBytes((int)length); - return encoding.GetString(bytes, 0, bytes.Length); + return encoding.GetString(bytes.Array, bytes.Offset, bytes.Count); } /// - /// Writes the stream contents to a byte array, regardless of the . + /// Retrieves the stream contents as a byte array, regardless of the . /// /// /// This method returns the contents of the as a byte array. @@ -288,9 +297,15 @@ public string ReadString(Encoding encoding = null) /// public override byte[] ToArray() { - if (Capacity == Length) + var success = TryGetBuffer(out var buffer); + + Debug.Assert(success, "Expected buffer to be publicly visible"); + + if (buffer.Offset == 0 && + buffer.Count == buffer.Array.Length && + buffer.Count == Length) { - return GetBuffer(); + return buffer.Array; } return base.ToArray(); @@ -315,19 +330,5 @@ internal byte[] ReadBytes(int length) return data; } - - /// - /// Reads data into the specified . - /// - /// The buffer to read into. - /// is larger than the total of bytes available. - private void ReadBytes(Span buffer) - { - var bytesRead = Read(buffer); - if (bytesRead < buffer.Length) - { - throw new ArgumentOutOfRangeException(nameof(buffer), string.Format(CultureInfo.InvariantCulture, "The requested length ({0}) is greater than the actual number of bytes read ({1}).", buffer.Length, bytesRead)); - } - } } } diff --git a/src/Renci.SshNet/PrivateKeyFile.OpenSSH.cs b/src/Renci.SshNet/PrivateKeyFile.OpenSSH.cs index a91635146..379b51cb3 100644 --- a/src/Renci.SshNet/PrivateKeyFile.OpenSSH.cs +++ b/src/Renci.SshNet/PrivateKeyFile.OpenSSH.cs @@ -32,7 +32,7 @@ public OpenSSH(byte[] data, string? passPhrase) /// public Key Parse() { - var keyReader = new SshDataReader(_data); + var keyReader = new SshDataStream(_data); // check magic header var authMagic = "openssh-key-v1\0"u8; @@ -171,7 +171,7 @@ public Key Parse() // now parse the data we called the private key, it actually contains the public key again // so we need to parse through it to get the private key bytes, plus there's some // validation we need to do. - var privateKeyReader = new SshDataReader(privateKeyBytes); + var privateKeyReader = new SshDataStream(privateKeyBytes); // check ints should match, they wouldn't match for example if the wrong passphrase was supplied var checkInt1 = (int)privateKeyReader.ReadUInt32(); @@ -196,33 +196,29 @@ public Key Parse() // https://datatracker.ietf.org/doc/html/draft-miller-ssh-agent-11#section-3.2.3 // ENC(A) - _ = privateKeyReader.ReadBignum2(); + _ = privateKeyReader.ReadBinarySegment(); // k || ENC(A) - unencryptedPrivateKey = privateKeyReader.ReadBignum2(); + unencryptedPrivateKey = privateKeyReader.ReadBinary(); parsedKey = new ED25519Key(unencryptedPrivateKey); break; case "ecdsa-sha2-nistp256": case "ecdsa-sha2-nistp384": case "ecdsa-sha2-nistp521": - // curve - var len = (int)privateKeyReader.ReadUInt32(); - var curve = Encoding.ASCII.GetString(privateKeyReader.ReadBytes(len)); + var curve = privateKeyReader.ReadString(Encoding.ASCII); - // public key - publicKey = privateKeyReader.ReadBignum2(); + publicKey = privateKeyReader.ReadBinary(); - // private key - unencryptedPrivateKey = privateKeyReader.ReadBignum2(); + unencryptedPrivateKey = privateKeyReader.ReadBinary(); parsedKey = new EcdsaKey(curve, publicKey, unencryptedPrivateKey.TrimLeadingZeros()); break; case "ssh-rsa": - var modulus = privateKeyReader.ReadBignum(); // n - var exponent = privateKeyReader.ReadBignum(); // e - var d = privateKeyReader.ReadBignum(); // d - var inverseQ = privateKeyReader.ReadBignum(); // iqmp - var p = privateKeyReader.ReadBignum(); // p - var q = privateKeyReader.ReadBignum(); // q + var modulus = privateKeyReader.ReadBigInt(); + var exponent = privateKeyReader.ReadBigInt(); + var d = privateKeyReader.ReadBigInt(); + var inverseQ = privateKeyReader.ReadBigInt(); + var p = privateKeyReader.ReadBigInt(); + var q = privateKeyReader.ReadBigInt(); parsedKey = new RsaKey(modulus, exponent, d, p, q, inverseQ); break; default: @@ -233,14 +229,17 @@ public Key Parse() // The list of privatekey/comment pairs is padded with the bytes 1, 2, 3, ... // until the total length is a multiple of the cipher block size. - var padding = privateKeyReader.ReadBytes(); - for (var i = 0; i < padding.Length; i++) + int b, i = 0; + + while ((b = privateKeyReader.ReadByte()) != -1) { - if ((int)padding[i] != i + 1) + if (b != i + 1) { throw new SshException("Padding of openssh key format contained wrong byte at position: " + i.ToString(CultureInfo.InvariantCulture)); } + + i++; } return parsedKey; diff --git a/src/Renci.SshNet/PrivateKeyFile.PuTTY.cs b/src/Renci.SshNet/PrivateKeyFile.PuTTY.cs index 6b1f0ea82..3ac40242b 100644 --- a/src/Renci.SshNet/PrivateKeyFile.PuTTY.cs +++ b/src/Renci.SshNet/PrivateKeyFile.PuTTY.cs @@ -163,34 +163,34 @@ public Key Parse() throw new SshException("MAC verification failed for PuTTY key file"); } - var publicKeyReader = new SshDataReader(_publicKey); + var publicKeyReader = new SshDataStream(_publicKey); var keyType = publicKeyReader.ReadString(Encoding.UTF8); Debug.Assert(keyType == _algorithmName, $"{nameof(keyType)} is not the same as {nameof(_algorithmName)}"); - var privateKeyReader = new SshDataReader(privateKey); + var privateKeyReader = new SshDataStream(privateKey); Key parsedKey; switch (keyType) { case "ssh-ed25519": - parsedKey = new ED25519Key(privateKeyReader.ReadBignum2()); + parsedKey = new ED25519Key(privateKeyReader.ReadBinary()); break; case "ecdsa-sha2-nistp256": case "ecdsa-sha2-nistp384": case "ecdsa-sha2-nistp521": var curve = publicKeyReader.ReadString(Encoding.ASCII); - var pub = publicKeyReader.ReadBignum2(); - var prv = privateKeyReader.ReadBignum2(); + var pub = publicKeyReader.ReadBinary(); + var prv = privateKeyReader.ReadBinary(); parsedKey = new EcdsaKey(curve, pub, prv); break; case "ssh-rsa": - var exponent = publicKeyReader.ReadBignum(); // e - var modulus = publicKeyReader.ReadBignum(); // n - var d = privateKeyReader.ReadBignum(); // d - var p = privateKeyReader.ReadBignum(); // p - var q = privateKeyReader.ReadBignum(); // q - var inverseQ = privateKeyReader.ReadBignum(); // iqmp + var exponent = publicKeyReader.ReadBigInt(); + var modulus = publicKeyReader.ReadBigInt(); + var d = privateKeyReader.ReadBigInt(); + var p = privateKeyReader.ReadBigInt(); + var q = privateKeyReader.ReadBigInt(); + var inverseQ = privateKeyReader.ReadBigInt(); parsedKey = new RsaKey(modulus, exponent, d, p, q, inverseQ); break; default: diff --git a/src/Renci.SshNet/PrivateKeyFile.SSHCOM.cs b/src/Renci.SshNet/PrivateKeyFile.SSHCOM.cs index da879c3c3..5be439608 100644 --- a/src/Renci.SshNet/PrivateKeyFile.SSHCOM.cs +++ b/src/Renci.SshNet/PrivateKeyFile.SSHCOM.cs @@ -1,6 +1,7 @@ #nullable enable using System; using System.Collections.Generic; +using System.Numerics; using System.Security.Cryptography; using System.Text; @@ -27,7 +28,7 @@ public SSHCOM(byte[] data, string? passPhrase) public Key Parse() { - var reader = new SshDataReader(_data); + var reader = new SshDataStream(_data); var magicNumber = reader.ReadUInt32(); if (magicNumber != 0x3f6ff9eb) { @@ -60,11 +61,7 @@ public Key Parse() throw new SshException(string.Format("Cipher method '{0}' is not supported.", ssh2CipherName)); } - /* - * TODO: Create two specific data types to avoid using SshDataReader class. - */ - - reader = new SshDataReader(keyData); + reader = new SshDataStream(keyData); var decryptedLength = reader.ReadUInt32(); @@ -75,16 +72,26 @@ public Key Parse() if (keyType.Contains("rsa")) { - var exponent = reader.ReadBigIntWithBits(); // e - var d = reader.ReadBigIntWithBits(); // d - var modulus = reader.ReadBigIntWithBits(); // n - var inverseQ = reader.ReadBigIntWithBits(); // u - var q = reader.ReadBigIntWithBits(); // p - var p = reader.ReadBigIntWithBits(); // q + var exponent = ReadBigIntWithBits(reader); + var d = ReadBigIntWithBits(reader); + var modulus = ReadBigIntWithBits(reader); + var inverseQ = ReadBigIntWithBits(reader); + var q = ReadBigIntWithBits(reader); + var p = ReadBigIntWithBits(reader); return new RsaKey(modulus, exponent, d, p, q, inverseQ); } throw new NotSupportedException(string.Format("Key type '{0}' is not supported.", keyType)); + + // Reads next mpint where length is specified in bits. + static BigInteger ReadBigIntWithBits(SshDataStream reader) + { + var numBits = (int)reader.ReadUInt32(); + + var numBytes = (numBits + 7) / 8; + + return reader.ReadBytes(numBytes).ToBigInteger2(); + } } private static byte[] GetCipherKey(string passphrase, int length) diff --git a/src/Renci.SshNet/PrivateKeyFile.cs b/src/Renci.SshNet/PrivateKeyFile.cs index fe19ddfe5..589d9a774 100644 --- a/src/Renci.SshNet/PrivateKeyFile.cs +++ b/src/Renci.SshNet/PrivateKeyFile.cs @@ -6,9 +6,7 @@ using System.Globalization; using System.IO; using System.Linq; -using System.Numerics; using System.Security.Cryptography; -using System.Text; using System.Text.RegularExpressions; using Renci.SshNet.Common; @@ -474,65 +472,6 @@ protected virtual void Dispose(bool disposing) } } - private sealed class SshDataReader : SshData - { - public SshDataReader(byte[] data) - { - Load(data); - } - - public new uint ReadUInt32() - { - return base.ReadUInt32(); - } - - public new string ReadString(Encoding encoding) - { - return base.ReadString(encoding); - } - - public new byte[] ReadBytes(int length) - { - return base.ReadBytes(length); - } - - public new byte[] ReadBytes() - { - return base.ReadBytes(); - } - - /// - /// Reads next mpint data type from internal buffer where length specified in bits. - /// - /// mpint read. - public BigInteger ReadBigIntWithBits() - { - var length = (int)base.ReadUInt32(); - - length = (length + 7) / 8; - - return base.ReadBytes(length).ToBigInteger2(); - } - - public BigInteger ReadBignum() - { - return DataStream.ReadBigInt(); - } - - public byte[] ReadBignum2() - { - return ReadBinary(); - } - - protected override void LoadData() - { - } - - protected override void SaveData() - { - } - } - /// /// Represents private key parser. ///