diff --git a/src/libraries/System.Memory/src/System/Buffers/Text/Base64Decoder.cs b/src/libraries/System.Memory/src/System/Buffers/Text/Base64Decoder.cs
index 36a7afa5369566..ee67a360ca7cfe 100644
--- a/src/libraries/System.Memory/src/System/Buffers/Text/Base64Decoder.cs
+++ b/src/libraries/System.Memory/src/System/Buffers/Text/Base64Decoder.cs
@@ -5,12 +5,14 @@
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
 using System.Runtime.Intrinsics;
+using System.Runtime.Intrinsics.Arm;
 using System.Runtime.Intrinsics.X86;
 
 namespace System.Buffers.Text
 {
     // AVX2 version based on https://github.com/aklomp/base64/tree/e516d769a2a432c08404f1981e73b431566057be/lib/arch/avx2
     // SSSE3 version based on https://github.com/aklomp/base64/tree/e516d769a2a432c08404f1981e73b431566057be/lib/arch/ssse3
+    // AdvSimd version based on https://github.com/aklomp/base64/blob/e516d769a2a432c08404f1981e73b431566057be/lib/arch/neon64
 
     public static partial class Base64
     {
@@ -81,6 +83,15 @@ public static unsafe OperationStatus DecodeFromUtf8(ReadOnlySpan<byte> utf8, Spa
                         if (src == srcEnd)
                             goto DoneExit;
                     }
+
+                    end = srcMax - 96;
+                    if (BitConverter.IsLittleEndian && AdvSimd.Arm64.IsSupported && (end >= src))
+                    {
+                        AdvSimdDecode(ref src, ref dest, end, maxSrcLength, destLength, srcBytes, destBytes);
+
+                        if (src == srcEnd)
+                            goto DoneExit;
+                    }
                 }
 
                 // Last bytes could have padding characters, so process them separately and treat them as valid only if isFinalBlock is true
@@ -644,6 +655,146 @@ private static unsafe void Ssse3Decode(ref byte* srcBytes, ref byte* destBytes,
             destBytes = dest;
         }
 
+        // Looks up every byte of 'indicies' in the 128-entry table formed by table0..table7 (16 entries each).
+        // Each TBX only replaces lanes whose current index is in 0..15, so 'offset' must hold 16 in every lane to
+        // step through the tables; lanes whose index is outside 0..127 keep their value from 'defaults'.
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static unsafe Vector128<byte> AdvSimdTbx8Byte(Vector128<byte> defaults, Vector128<byte> table0, Vector128<byte> table1, Vector128<byte> table2, Vector128<byte> table3, Vector128<byte> table4, Vector128<byte> table5, Vector128<byte> table6, Vector128<byte> table7, Vector128<byte> indicies, Vector128<byte> offset)
+        {
+            // Implement an 8 way table lookup.
+            // This could be reduced by using two NEON TBX4 instructions.
+
+            Debug.Assert(AdvSimd.Arm64.IsSupported && BitConverter.IsLittleEndian);
+
+            Vector128<byte> dest = defaults;
+            Vector128<byte> indicies_sub = indicies;
+
+            dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table0, indicies_sub);
+            indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
+            dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table1, indicies_sub);
+            indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
+            dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table2, indicies_sub);
+            indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
+            dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table3, indicies_sub);
+            indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
+            dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table4, indicies_sub);
+            indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
+            dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table5, indicies_sub);
+            indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
+            dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table6, indicies_sub);
+            indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
+            dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table7, indicies_sub);
+
+            return dest;
+        }
+
+        // Three-table variant of the chained TBX lookup: indices 0..47 select from table0..table2 (16 entries each,
+        // stepped by 'offset', which must hold 16 in every lane); any other index keeps the lane from 'defaults'.
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static unsafe Vector128<byte> AdvSimdTbx3Byte(Vector128<byte> defaults, Vector128<byte> table0, Vector128<byte> table1, Vector128<byte> table2, Vector128<byte> indicies, Vector128<byte> offset)
+        {
+            // Implement a 3 way table lookup.
+
+            Debug.Assert(AdvSimd.Arm64.IsSupported && BitConverter.IsLittleEndian);
+
+            Vector128<byte> dest = defaults;
+            Vector128<byte> indicies_sub = indicies;
+
+            dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table0, indicies_sub);
+            indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
+            dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table1, indicies_sub);
+            indicies_sub = AdvSimd.Subtract(indicies_sub, offset);
+            dest = AdvSimd.Arm64.VectorTableLookupExtension(dest, table2, indicies_sub);
+
+            return dest;
+        }
+
+        // Decodes 64 base64 characters into 48 bytes per iteration until 'src' passes 'srcEnd' or a block contains
+        // a byte that is not part of the base64 alphabet. That block is left unconsumed: 'srcBytes' and 'destBytes'
+        // are written back pointing at it, leaving it for the caller to handle.
+        // The body is a do/while, so at least one 64-byte block is always read; the caller must guarantee
+        // srcBytes <= srcEnd and that 64 readable and 48 writable bytes exist at every position up to 'srcEnd'.
+        // 'sourceLength', 'destLength', 'srcStart' and 'destStart' are not used by this implementation.
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private static unsafe void AdvSimdDecode(ref byte* srcBytes, ref byte* destBytes, byte* srcEnd, int sourceLength, int destLength, byte* srcStart, byte* destStart)
+        {
+            Debug.Assert(AdvSimd.Arm64.IsSupported && BitConverter.IsLittleEndian);
+
+            // Complete lookup table - similar to that used in the SSSE3 decode.
+            Vector128<byte> dec_lut0 = Vector128.Create(255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255);
+            Vector128<byte> dec_lut1 = Vector128.Create(255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255);
+            Vector128<byte> dec_lut2 = Vector128.Create(255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, 255, 255, 63);
+            Vector128<byte> dec_lut3 = Vector128.Create(52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, 255, 255, 255, 255);
+            Vector128<byte> dec_lut4 = Vector128.Create(255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14);
+            Vector128<byte> dec_lut5 = Vector128.Create(15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 255, 255, 255, 255, 255);
+            Vector128<byte> dec_lut6 = Vector128.Create(255, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40);
+            Vector128<byte> dec_lut7 = Vector128.Create(41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 255, 255, 255, 255, 255);
+
+            // Interleave pattern for the ST3.
+            Vector128<byte> st3_interleave_index0 = Vector128.Create((byte)0, 16, 32, 1, 17, 33, 2, 18, 34, 3, 19, 35, 4, 20, 36, 5);
+            Vector128<byte> st3_interleave_index1 = Vector128.Create((byte)21, 37, 6, 22, 38, 7, 23, 39, 8, 24, 40, 9, 25, 41, 10, 26);
+            Vector128<byte> st3_interleave_index2 = Vector128.Create((byte)42, 11, 27, 43, 12, 28, 44, 13, 29, 45, 14, 30, 46, 15, 31, 47);
+
+            // Some constants.
+            Vector128<byte> vzero = Vector128.Create((byte)0);
+            Vector128<byte> v255 = Vector128.Create((byte)255U);
+            Vector128<byte> v16 = Vector128.Create((byte)16U);
+
+            byte* src = srcBytes;
+            byte* dest = destBytes;
+
+            do
+            {
+                // Load 64 bytes of data and deinterleave the result.
+                // This is equivalent to a NEON LD4 instruction.
+                Vector128<byte> str0 = Vector128.LoadUnsafe(ref *src);
+                Vector128<byte> str1 = Vector128.LoadUnsafe(ref *src, 16);
+                Vector128<byte> str2 = Vector128.LoadUnsafe(ref *src, 32);
+                Vector128<byte> str3 = Vector128.LoadUnsafe(ref *src, 48);
+                Vector128<short> tmp0 = AdvSimd.Arm64.UnzipEven(str0.AsInt16(), str1.AsInt16());
+                Vector128<short> tmp1 = AdvSimd.Arm64.UnzipOdd(str0.AsInt16(), str1.AsInt16());
+                Vector128<short> tmp2 = AdvSimd.Arm64.UnzipEven(str2.AsInt16(), str3.AsInt16());
+                Vector128<short> tmp3 = AdvSimd.Arm64.UnzipOdd(str2.AsInt16(), str3.AsInt16());
+                str0 = AdvSimd.Arm64.UnzipEven(tmp0.AsByte(), tmp2.AsByte());
+                str1 = AdvSimd.Arm64.UnzipOdd(tmp0.AsByte(), tmp2.AsByte());
+                str2 = AdvSimd.Arm64.UnzipEven(tmp1.AsByte(), tmp3.AsByte());
+                str3 = AdvSimd.Arm64.UnzipOdd(tmp1.AsByte(), tmp3.AsByte());
+
+                // Table lookup on each 16 bytes.
+                str0 = AdvSimdTbx8Byte(v255, dec_lut0, dec_lut1, dec_lut2, dec_lut3, dec_lut4, dec_lut5, dec_lut6, dec_lut7, str0, v16);
+                str1 = AdvSimdTbx8Byte(v255, dec_lut0, dec_lut1, dec_lut2, dec_lut3, dec_lut4, dec_lut5, dec_lut6, dec_lut7, str1, v16);
+                str2 = AdvSimdTbx8Byte(v255, dec_lut0, dec_lut1, dec_lut2, dec_lut3, dec_lut4, dec_lut5, dec_lut6, dec_lut7, str2, v16);
+                str3 = AdvSimdTbx8Byte(v255, dec_lut0, dec_lut1, dec_lut2, dec_lut3, dec_lut4, dec_lut5, dec_lut6, dec_lut7, str3, v16);
+
+                // Check for invalid input, any value larger than 63.
+                // Element-wise maxima fold all 64 lookup results into 16 lanes and MaxAcross reduces those to a
+                // single byte, so an invalid character in any of the four vectors ends the loop.
+                Vector128<byte> classified0 = AdvSimd.Max(str0, str1);
+                Vector128<byte> classified1 = AdvSimd.Max(str2, str3);
+                Vector128<byte> maxChars = AdvSimd.Max(classified0, classified1);
+                if (AdvSimd.Arm64.MaxAcross(maxChars).ToScalar() > 63)
+                    break;
+
+                // Compress each four bytes into three.
+                Vector128<byte> dec0 = Vector128.BitwiseOr(Vector128.ShiftLeft(str0, 2), Vector128.ShiftRightLogical(str1, 4));
+                Vector128<byte> dec1 = Vector128.BitwiseOr(Vector128.ShiftLeft(str1, 4), Vector128.ShiftRightLogical(str2, 2));
+                Vector128<byte> dec2 = Vector128.BitwiseOr(Vector128.ShiftLeft(str2, 6), str3);
+
+                // Interleave the decoded result and store out.
+                // This is equivalent to a NEON ST3 instruction.
+                AdvSimdTbx3Byte(vzero, dec0, dec1, dec2, st3_interleave_index0, v16).Store(dest);
+                AdvSimdTbx3Byte(vzero, dec0, dec1, dec2, st3_interleave_index1, v16).Store(dest + 16);
+                AdvSimdTbx3Byte(vzero, dec0, dec1, dec2, st3_interleave_index2, v16).Store(dest + 32);
+
+                src += 64;
+                dest += 48;
+            }
+            while (src <= srcEnd);
+
+            srcBytes = src;
+            destBytes = dest;
+        }
+
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
         private static unsafe int Decode(byte* encodedBytes, ref sbyte decodingMap)
         {