diff --git a/src/libraries/System.Collections/src/System/Collections/BitArray.cs b/src/libraries/System.Collections/src/System/Collections/BitArray.cs index db0f59e8ac7af2..d7434fd27df27f 100644 --- a/src/libraries/System.Collections/src/System/Collections/BitArray.cs +++ b/src/libraries/System.Collections/src/System/Collections/BitArray.cs @@ -115,10 +115,6 @@ public BitArray(byte[] bytes) _version = 0; } - private const uint Vector128ByteCount = 16; - private const uint Vector128IntCount = 4; - private const uint Vector256ByteCount = 32; - private const uint Vector256IntCount = 8; public unsafe BitArray(bool[] values) { ArgumentNullException.ThrowIfNull(values); @@ -138,10 +134,21 @@ public unsafe BitArray(bool[] values) // Instead, We compare with zeroes (== false) then negate the result to ensure compatibility. ref byte value = ref Unsafe.As(ref MemoryMarshal.GetArrayDataReference(values)); + if (Vector512.IsHardwareAccelerated) + { + for (; i <= (uint)values.Length - Vector512.Count; i += (uint)Vector512.Count) + { + Vector512 vector = Vector512.LoadUnsafe(ref value, i); + Vector512 isFalse = Vector512.Equals(vector, Vector512.Zero); - if (Vector256.IsHardwareAccelerated) + ulong result = isFalse.ExtractMostSignificantBits(); + m_array[i / 32u] = (int)(~result & 0x00000000FFFFFFFF); + m_array[(i / 32u) + 1] = (int)((~result >> 32) & 0x00000000FFFFFFFF); + } + } + else if (Vector256.IsHardwareAccelerated) { - for (; (i + Vector256ByteCount) <= (uint)values.Length; i += Vector256ByteCount) + for (; i <= (uint)values.Length - Vector256.Count; i += (uint)Vector256.Count) { Vector256 vector = Vector256.LoadUnsafe(ref value, i); Vector256 isFalse = Vector256.Equals(vector, Vector256.Zero); @@ -152,13 +159,13 @@ public unsafe BitArray(bool[] values) } else if (Vector128.IsHardwareAccelerated) { - for (; (i + Vector128ByteCount * 2u) <= (uint)values.Length; i += Vector128ByteCount * 2u) + for (; i <= (uint)values.Length - Vector128.Count * 2u; i += (uint)Vector128.Count * 2u) { Vector128 lowerVector = Vector128.LoadUnsafe(ref value, i); Vector128 lowerIsFalse = Vector128.Equals(lowerVector, Vector128.Zero); uint lowerResult = lowerIsFalse.ExtractMostSignificantBits(); - Vector128 upperVector = Vector128.LoadUnsafe(ref value, i + Vector128ByteCount); + Vector128 upperVector = Vector128.LoadUnsafe(ref value, i + (uint)Vector128.Count); Vector128 upperIsFalse = Vector128.Equals(upperVector, Vector128.Zero); uint upperResult = upperIsFalse.ExtractMostSignificantBits(); @@ -339,18 +346,25 @@ public unsafe BitArray And(BitArray value) ref int left = ref MemoryMarshal.GetArrayDataReference(thisArray); ref int right = ref MemoryMarshal.GetArrayDataReference(valueArray); - - if (Vector256.IsHardwareAccelerated) + if (Vector512.IsHardwareAccelerated && (uint)count >= Vector512.Count) { - for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount) + for (; i < (uint)count - (Vector512.Count - 1u); i += (uint)Vector512.Count) + { + Vector512 result = Vector512.LoadUnsafe(ref left, i) & Vector512.LoadUnsafe(ref right, i); + result.StoreUnsafe(ref left, i); + } + } + else if (Vector256.IsHardwareAccelerated && (uint)count >= Vector256.Count) + { + for (; i < (uint)count - (Vector256.Count - 1u); i += (uint)Vector256.Count) { Vector256 result = Vector256.LoadUnsafe(ref left, i) & Vector256.LoadUnsafe(ref right, i); result.StoreUnsafe(ref left, i); } } - else if (Vector128.IsHardwareAccelerated) + else if (Vector128.IsHardwareAccelerated && (uint)count >= Vector128.Count) { - for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount) + for (; i < (uint)count - (Vector128.Count - 1u); i += (uint)Vector128.Count) { Vector128 result = Vector128.LoadUnsafe(ref left, i) & Vector128.LoadUnsafe(ref right, i); result.StoreUnsafe(ref left, i); @@ -405,18 +419,25 @@ public unsafe BitArray Or(BitArray value) ref int left = ref MemoryMarshal.GetArrayDataReference(thisArray); ref int right = ref MemoryMarshal.GetArrayDataReference(valueArray); - - if (Vector256.IsHardwareAccelerated) + if (Vector512.IsHardwareAccelerated && (uint)count >= Vector512.Count) + { + for (; i < (uint)count - (Vector512.Count - 1u); i += (uint)Vector512.Count) + { + Vector512 result = Vector512.LoadUnsafe(ref left, i) | Vector512.LoadUnsafe(ref right, i); + result.StoreUnsafe(ref left, i); + } + } + else if (Vector256.IsHardwareAccelerated && (uint)count >= Vector256.Count) { - for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount) + for (; i < (uint)count - (Vector256.Count - 1u); i += (uint)Vector256.Count) { Vector256 result = Vector256.LoadUnsafe(ref left, i) | Vector256.LoadUnsafe(ref right, i); result.StoreUnsafe(ref left, i); } } - else if (Vector128.IsHardwareAccelerated) + else if (Vector128.IsHardwareAccelerated && (uint)count >= Vector128.Count) { - for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount) + for (; i < (uint)count - (Vector128.Count - 1u); i += (uint)Vector128.Count) { Vector128 result = Vector128.LoadUnsafe(ref left, i) | Vector128.LoadUnsafe(ref right, i); result.StoreUnsafe(ref left, i); @@ -472,17 +493,25 @@ public unsafe BitArray Xor(BitArray value) ref int left = ref MemoryMarshal.GetArrayDataReference(thisArray); ref int right = ref MemoryMarshal.GetArrayDataReference(valueArray); - if (Vector256.IsHardwareAccelerated) + if (Vector512.IsHardwareAccelerated && (uint)count >= Vector512.Count) + { + for (; i < (uint)count - (Vector512.Count - 1u); i += (uint)Vector512.Count) + { + Vector512 result = Vector512.LoadUnsafe(ref left, i) ^ Vector512.LoadUnsafe(ref right, i); + result.StoreUnsafe(ref left, i); + } + } + else if (Vector256.IsHardwareAccelerated && (uint)count >= Vector256.Count) { - for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount) + for (; i < (uint)count - (Vector256.Count - 1u); i += (uint)Vector256.Count) { Vector256 result = Vector256.LoadUnsafe(ref left, i) ^ Vector256.LoadUnsafe(ref right, i); result.StoreUnsafe(ref left, i); } } - else if (Vector128.IsHardwareAccelerated) + else if (Vector128.IsHardwareAccelerated && (uint)count >= Vector128.Count) { - for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount) + for (; i < (uint)count - (Vector128.Count - 1u); i += (uint)Vector128.Count) { Vector128 result = Vector128.LoadUnsafe(ref left, i) ^ Vector128.LoadUnsafe(ref right, i); result.StoreUnsafe(ref left, i); @@ -529,18 +558,25 @@ public unsafe BitArray Not() uint i = 0; ref int value = ref MemoryMarshal.GetArrayDataReference(thisArray); - - if (Vector256.IsHardwareAccelerated) + if (Vector512.IsHardwareAccelerated && (uint)count >= Vector512.Count) { - for (; i < (uint)count - (Vector256IntCount - 1u); i += Vector256IntCount) + for (; i < (uint)count - (Vector512.Count - 1u); i += (uint)Vector512.Count) + { + Vector512 result = ~Vector512.LoadUnsafe(ref value, i); + result.StoreUnsafe(ref value, i); + } + } + else if (Vector256.IsHardwareAccelerated && (uint)count >= Vector256.Count) + { + for (; i < (uint)count - (Vector256.Count - 1u); i += (uint)Vector256.Count) { Vector256 result = ~Vector256.LoadUnsafe(ref value, i); result.StoreUnsafe(ref value, i); } } - else if (Vector128.IsHardwareAccelerated) + else if (Vector128.IsHardwareAccelerated && (uint)count >= Vector128.Count) { - for (; i < (uint)count - (Vector128IntCount - 1u); i += Vector128IntCount) + for (; i < (uint)count - (Vector128.Count - 1u); i += (uint)Vector128.Count) { Vector128 result = ~Vector128.LoadUnsafe(ref value, i); result.StoreUnsafe(ref value, i); @@ -797,21 +833,47 @@ public unsafe void CopyTo(Array array, int index) if (m_length < BitsPerInt32) goto LessThan32; - // The mask used when shuffling a single int into Vector128/256. + // The mask used when shuffling a single int into Vector128/256/512. // On little endian machines, the lower 8 bits of int belong in the first byte, next lower 8 in the second and so on. // We place the bytes that contain the bits to its respective byte so that we can mask out only the relevant bits later. Vector128 lowerShuffleMask_CopyToBoolArray = Vector128.Create(0, 0x01010101_01010101).AsByte(); Vector128 upperShuffleMask_CopyToBoolArray = Vector128.Create(0x02020202_02020202, 0x03030303_03030303).AsByte(); - if (Avx2.IsSupported) + if (Avx512F.IsSupported && (uint)m_length >= Vector512.Count) + { + Vector256 upperShuffleMask_CopyToBoolArray256 = Vector256.Create(0x04040404_04040404, 0x05050505_05050505, + 0x06060606_06060606, 0x07070707_07070707).AsByte(); + Vector256 lowerShuffleMask_CopyToBoolArray256 = Vector256.Create(lowerShuffleMask_CopyToBoolArray, upperShuffleMask_CopyToBoolArray); + Vector512 shuffleMask = Vector512.Create(lowerShuffleMask_CopyToBoolArray256, upperShuffleMask_CopyToBoolArray256); + Vector512 bitMask = Vector512.Create(0x80402010_08040201).AsByte(); + Vector512 ones = Vector512.Create((byte)1); + + fixed (bool* destination = &boolArray[index]) + { + for (; (i + Vector512.Count) <= (uint)m_length; i += (uint)Vector512.Count) + { + ulong bits = (ulong)(uint)m_array[i / (uint)BitsPerInt32] + ((ulong)m_array[(i / (uint)BitsPerInt32) + 1] << BitsPerInt32); + Vector512 scalar = Vector512.Create(bits); + Vector512 shuffled = Avx512BW.Shuffle(scalar.AsByte(), shuffleMask); + Vector512 extracted = Avx512F.And(shuffled, bitMask); + + // The extracted bits can be anywhere between 0 and 255, so we normalise the value to either 0 or 1 + // to ensure compatibility with "C# bool" (0 for false, 1 for true, rest undefined) + Vector512 normalized = Avx512BW.Min(extracted, ones); + Avx512F.Store((byte*)destination + i, normalized); + } + } + } + else if (Avx2.IsSupported && (uint)m_length >= Vector256.Count) { Vector256 shuffleMask = Vector256.Create(lowerShuffleMask_CopyToBoolArray, upperShuffleMask_CopyToBoolArray); Vector256 bitMask = Vector256.Create(0x80402010_08040201).AsByte(); + //Internal.Console.WriteLine(bitMask); Vector256 ones = Vector256.Create((byte)1); fixed (bool* destination = &boolArray[index]) { - for (; (i + Vector256ByteCount) <= (uint)m_length; i += Vector256ByteCount) + for (; (i + Vector256.Count) <= (uint)m_length; i += (uint)Vector256.Count) { int bits = m_array[i / (uint)BitsPerInt32]; Vector256 scalar = Vector256.Create(bits); @@ -825,7 +887,7 @@ public unsafe void CopyTo(Array array, int index) } } } - else if (Ssse3.IsSupported) + else if (Ssse3.IsSupported && ((uint)m_length >= Vector512.Count * 2u)) { Vector128 lowerShuffleMask = lowerShuffleMask_CopyToBoolArray; Vector128 upperShuffleMask = upperShuffleMask_CopyToBoolArray; @@ -836,7 +898,7 @@ public unsafe void CopyTo(Array array, int index) fixed (bool* destination = &boolArray[index]) { - for (; (i + Vector128ByteCount * 2u) <= (uint)m_length; i += Vector128ByteCount * 2u) + for (; (i + Vector128.Count * 2u) <= (uint)m_length; i += (uint)Vector128.Count * 2u) { int bits = m_array[i / (uint)BitsPerInt32]; Vector128 scalar = Vector128.CreateScalarUnsafe(bits); @@ -862,7 +924,7 @@ public unsafe void CopyTo(Array array, int index) fixed (bool* destination = &boolArray[index]) { - for (; (i + Vector128ByteCount * 2u) <= (uint)m_length; i += Vector128ByteCount * 2u) + for (; (i + Vector128.Count * 2u) <= (uint)m_length; i += (uint)Vector128.Count * 2u) { int bits = m_array[i / (uint)BitsPerInt32]; // Same logic as SSSE3 path, except we do not have Shuffle instruction.