diff --git a/CMakeLists.txt b/CMakeLists.txt index cc071e6e8c..6f0a924aa3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -67,6 +67,7 @@ option(FAISS_ENABLE_PYTHON "Build Python extension." ON) option(FAISS_ENABLE_C_API "Build C API." OFF) option(FAISS_ENABLE_EXTRAS "Build extras like benchmarks and demos" ON) option(FAISS_USE_LTO "Enable Link-Time optimization" OFF) +option(FAISS_ENABLE_AVX512_FP16 "Enable AVX512-FP16 arithmetic (for avx512_spr opt level)." OFF) if(FAISS_ENABLE_GPU) if(FAISS_ENABLE_ROCM) diff --git a/faiss/CMakeLists.txt b/faiss/CMakeLists.txt index 4ebcaf59e9..674d745d40 100644 --- a/faiss/CMakeLists.txt +++ b/faiss/CMakeLists.txt @@ -360,6 +360,14 @@ target_compile_definitions(faiss_avx512 PRIVATE FINTEGER=int) target_compile_definitions(faiss_avx512_spr PRIVATE FINTEGER=int) target_compile_definitions(faiss_sve PRIVATE FINTEGER=int) +if(FAISS_ENABLE_AVX512_FP16) + if (FAISS_OPT_LEVEL STREQUAL "avx512_spr") + target_compile_definitions(faiss_avx512_spr PRIVATE ENABLE_AVX512_FP16) + else() + message(STATUS "AVX512_FP16 not supported: requires FAISS_OPT_LEVEL=avx512_spr.") + endif() +endif() + if(FAISS_USE_LTO) include(CheckIPOSupported) check_ipo_supported(RESULT ipo_supported OUTPUT ipo_error) diff --git a/faiss/impl/ScalarQuantizer.cpp b/faiss/impl/ScalarQuantizer.cpp index e05f3a1f25..ebfd5bef9d 100644 --- a/faiss/impl/ScalarQuantizer.cpp +++ b/faiss/impl/ScalarQuantizer.cpp @@ -44,7 +44,10 @@ namespace faiss { * that hides the template mess. ********************************************************************/ -#if defined(__AVX512F__) && defined(__F16C__) +#if defined(ENABLE_AVX512_FP16) && defined(__FLT16_MANT_DIG__) && \ + defined(__AVX512FP16__) +#define USE_AVX512_FP16 +#elif defined(__AVX512F__) && defined(__F16C__) #define USE_AVX512_F16C #elif defined(__AVX2__) #ifdef __F16C__ @@ -70,6 +73,10 @@ typedef ScalarQuantizer::QuantizerType QuantizerType; typedef ScalarQuantizer::RangeStat RangeStat; using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer; +#if defined(USE_AVX512_FP16) +using fp16 = _Float16; +#endif + /******************************************************************* * Codec: converts between values in [0, 1] and an index in a code * array. The "i" parameter is the vector component index (not byte @@ -90,6 +97,19 @@ struct Codec8bit { return (code[i] + 0.5f) / 255.0f; } +#if defined(USE_AVX512_FP16) + static FAISS_ALWAYS_INLINE __m512h + decode_32_components(const uint8_t* code, int i) { + const __m256i c32 = _mm256_loadu_si256((__m256i*)(code + i)); + const __m512i i16_32 = _mm512_cvtepu8_epi16(c32); + const __m512h fp16_32 = _mm512_cvtepi16_ph(i16_32); + const __m512h half_one_255 = + _mm512_set1_ph(static_cast(0.5f / 255.f)); + const __m512h one_255 = _mm512_set1_ph(static_cast(1.f / 255.f)); + return _mm512_fmadd_ph(fp16_32, one_255, half_one_255); + } +#endif + #if defined(__AVX512F__) static FAISS_ALWAYS_INLINE __m512 decode_16_components(const uint8_t* code, int i) { @@ -142,6 +162,29 @@ struct Codec4bit { return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f; } +#if defined(USE_AVX512_FP16) + static FAISS_ALWAYS_INLINE __m512h + decode_32_components(const uint8_t* code, int i) { + const __m128i c16 = _mm_loadu_si128((__m128i*)(code + (i >> 1))); + const __m128i even_nibbles = _mm_and_si128(c16, _mm_set1_epi8(0x0F)); + const __m128i odd_nibbles = + _mm_srli_epi16(_mm_and_si128(c16, _mm_set1_epi8(0xF0)), 4); + + const __m128i interleaved = + _mm_unpacklo_epi8(even_nibbles, odd_nibbles); + const __m128i interleaved_high = + _mm_unpackhi_epi8(even_nibbles, odd_nibbles); + const __m256i c32 = _mm256_set_m128i(interleaved_high, interleaved); + + const __m512i i16_32 = _mm512_cvtepu8_epi16(c32); + const __m512h fp16_32 = _mm512_cvtepi16_ph(i16_32); + const __m512h half_one_255 = + _mm512_set1_ph(static_cast(0.5f / 15.f)); + const __m512h one_255 = _mm512_set1_ph(static_cast(1.f / 15.f)); + return _mm512_fmadd_ph(fp16_32, one_255, half_one_255); + } +#endif + #if defined(__AVX512F__) static FAISS_ALWAYS_INLINE __m512 decode_16_components(const uint8_t* code, int i) { @@ -247,6 +290,40 @@ struct Codec6bit { return (bits + 0.5f) / 63.0f; } +#if defined(USE_AVX512_FP16) + + static FAISS_ALWAYS_INLINE __m512h + decode_32_components(const uint8_t* code, int i) { + const uint16_t* data16_0 = (const uint16_t*)(code + (i >> 2) * 3); + const uint64_t* data64_0 = (const uint64_t*)data16_0; + const uint64_t val_0 = *data64_0; + const uint64_t vext_0 = _pdep_u64(val_0, 0x3F3F3F3F3F3F3F3FULL); + + const uint16_t* data16_1 = data16_0 + 3; + const uint32_t* data32_1 = (const uint32_t*)data16_1; + const uint64_t val_1 = *data32_1 + ((uint64_t)data16_1[2] << 32); + const uint64_t vext_1 = _pdep_u64(val_1, 0x3F3F3F3F3F3F3F3FULL); + + const uint16_t* data16_2 = data16_1 + 3; + const uint64_t* data64_2 = (const uint64_t*)data16_2; + const uint64_t val_2 = *data64_2; + const uint64_t vext_2 = _pdep_u64(val_2, 0x3F3F3F3F3F3F3F3FULL); + + const uint16_t* data16_3 = data16_2 + 3; + const uint32_t* data32_3 = (const uint32_t*)data16_3; + const uint64_t val_3 = *data32_3 + ((uint64_t)data16_3[2] << 32); + const uint64_t vext_3 = _pdep_u64(val_3, 0x3F3F3F3F3F3F3F3FULL); + + const __m256i c32 = _mm256_set_epi64x(vext_3, vext_2, vext_1, vext_0); + const __m512i i16_32 = _mm512_cvtepu8_epi16(c32); + const __m512h fp16_32 = _mm512_cvtepi16_ph(i16_32); + const __m512h half_one_255 = + _mm512_set1_ph(static_cast(0.5f / 63.f)); + const __m512h one_255 = _mm512_set1_ph(static_cast(1.f / 63.f)); + return _mm512_fmadd_ph(fp16_32, one_255, half_one_255); + } +#endif + #if defined(__AVX512F__) static FAISS_ALWAYS_INLINE __m512 @@ -366,28 +443,28 @@ struct Codec6bit { enum class QuantizerTemplateScaling { UNIFORM = 0, NON_UNIFORM = 1 }; -template +template struct QuantizerTemplate {}; -template -struct QuantizerTemplate +template +struct QuantizerTemplate : ScalarQuantizer::SQuantizer { const size_t d; - const float vmin, vdiff; + const T vmin, vdiff; - QuantizerTemplate(size_t d, const std::vector& trained) + QuantizerTemplate(size_t d, const std::vector& trained) : d(d), vmin(trained[0]), vdiff(trained[1]) {} void encode_vector(const float* x, uint8_t* code) const final { for (size_t i = 0; i < d; i++) { - float xi = 0; + T xi = 0; if (vdiff != 0) { - xi = (x[i] - vmin) / vdiff; + xi = static_cast((x[i] - vmin) / vdiff); if (xi < 0) { xi = 0; } if (xi > 1.0) { - xi = 1.0; + xi = static_cast(1.0); } } Codec::encode_component(xi, code, i); @@ -396,25 +473,45 @@ struct QuantizerTemplate void decode_vector(const uint8_t* code, float* x) const final { for (size_t i = 0; i < d; i++) { - float xi = Codec::decode_component(code, i); + T xi = static_cast(Codec::decode_component(code, i)); x[i] = vmin + xi * vdiff; } } - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - float xi = Codec::decode_component(code, i); + FAISS_ALWAYS_INLINE T + reconstruct_component(const uint8_t* code, int i) const { + T xi = static_cast(Codec::decode_component(code, i)); return vmin + xi * vdiff; } }; +#if defined(USE_AVX512_FP16) + +template +struct QuantizerTemplate + : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate( + d, + trained) {} + + FAISS_ALWAYS_INLINE __m512h + reconstruct_32_components(const uint8_t* code, int i) const { + __m512h xi = Codec::decode_32_components(code, i); + return _mm512_fmadd_ph( + xi, _mm512_set1_ph(this->vdiff), _mm512_set1_ph(this->vmin)); + } +}; + +#endif + #if defined(__AVX512F__) -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate( +template +struct QuantizerTemplate + : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate( d, trained) {} @@ -428,11 +525,11 @@ struct QuantizerTemplate #elif defined(__AVX2__) -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate( +template +struct QuantizerTemplate + : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate( d, trained) {} @@ -448,11 +545,11 @@ struct QuantizerTemplate #ifdef USE_NEON -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) - : QuantizerTemplate( +template +struct QuantizerTemplate + : QuantizerTemplate { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate( d, trained) {} @@ -472,25 +569,25 @@ struct QuantizerTemplate #endif -template -struct QuantizerTemplate +template +struct QuantizerTemplate : ScalarQuantizer::SQuantizer { const size_t d; - const float *vmin, *vdiff; + const T *vmin, *vdiff; - QuantizerTemplate(size_t d, const std::vector& trained) + QuantizerTemplate(size_t d, const std::vector& trained) : d(d), vmin(trained.data()), vdiff(trained.data() + d) {} void encode_vector(const float* x, uint8_t* code) const final { for (size_t i = 0; i < d; i++) { - float xi = 0; + T xi = 0; if (vdiff[i] != 0) { - xi = (x[i] - vmin[i]) / vdiff[i]; + xi = static_cast((x[i] - vmin[i]) / vdiff[i]); if (xi < 0) { xi = 0; } if (xi > 1.0) { - xi = 1.0; + xi = static_cast(1.0); } } Codec::encode_component(xi, code, i); @@ -499,25 +596,58 @@ struct QuantizerTemplate void decode_vector(const uint8_t* code, float* x) const final { for (size_t i = 0; i < d; i++) { - float xi = Codec::decode_component(code, i); + T xi = static_cast(Codec::decode_component(code, i)); x[i] = vmin[i] + xi * vdiff[i]; } } - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - float xi = Codec::decode_component(code, i); + FAISS_ALWAYS_INLINE T + reconstruct_component(const uint8_t* code, int i) const { + T xi = static_cast(Codec::decode_component(code, i)); return vmin[i] + xi * vdiff[i]; } }; +#if defined(USE_AVX512_FP16) + +template +struct QuantizerTemplate + : QuantizerTemplate< + T, + Codec, + QuantizerTemplateScaling::NON_UNIFORM, + 1> { + QuantizerTemplate(size_t d, const std::vector& trained) + : QuantizerTemplate< + T, + Codec, + QuantizerTemplateScaling::NON_UNIFORM, + 1>(d, trained) {} + + FAISS_ALWAYS_INLINE __m512h + reconstruct_32_components(const uint8_t* code, int i) const { + __m512h xi = Codec::decode_32_components(code, i); + return _mm512_fmadd_ph( + xi, + _mm512_loadu_ph(this->vdiff + i), + _mm512_loadu_ph(this->vmin + i)); + } +}; + +#endif + #if defined(__AVX512F__) -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) +template +struct QuantizerTemplate + : QuantizerTemplate< + T, + Codec, + QuantizerTemplateScaling::NON_UNIFORM, + 1> { + QuantizerTemplate(size_t d, const std::vector& trained) : QuantizerTemplate< + T, Codec, QuantizerTemplateScaling::NON_UNIFORM, 1>(d, trained) {} @@ -534,11 +664,16 @@ struct QuantizerTemplate #elif defined(__AVX2__) -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) +template +struct QuantizerTemplate + : QuantizerTemplate< + T, + Codec, + QuantizerTemplateScaling::NON_UNIFORM, + 1> { + QuantizerTemplate(size_t d, const std::vector& trained) : QuantizerTemplate< + T, Codec, QuantizerTemplateScaling::NON_UNIFORM, 1>(d, trained) {} @@ -557,11 +692,16 @@ struct QuantizerTemplate #ifdef USE_NEON -template -struct QuantizerTemplate - : QuantizerTemplate { - QuantizerTemplate(size_t d, const std::vector& trained) +template +struct QuantizerTemplate + : QuantizerTemplate< + T, + Codec, + QuantizerTemplateScaling::NON_UNIFORM, + 1> { + QuantizerTemplate(size_t d, const std::vector& trained) : QuantizerTemplate< + T, Codec, QuantizerTemplateScaling::NON_UNIFORM, 1>(d, trained) {} @@ -584,14 +724,14 @@ struct QuantizerTemplate * FP16 quantizer *******************************************************************/ -template +template struct QuantizerFP16 {}; -template <> -struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer { +template +struct QuantizerFP16 : ScalarQuantizer::SQuantizer { const size_t d; - QuantizerFP16(size_t d, const std::vector& /* unused */) : d(d) {} + QuantizerFP16(size_t d, const std::vector& /* unused */) : d(d) {} void encode_vector(const float* x, uint8_t* code) const final { for (size_t i = 0; i < d; i++) { @@ -605,18 +745,34 @@ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer { } } - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { - return decode_fp16(((uint16_t*)code)[i]); + FAISS_ALWAYS_INLINE T + reconstruct_component(const uint8_t* code, int i) const { + return static_cast(decode_fp16(((uint16_t*)code)[i])); } }; +#if defined(USE_AVX512_FP16) + +template +struct QuantizerFP16 : QuantizerFP16 { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16(d, trained) {} + + FAISS_ALWAYS_INLINE __m512h + reconstruct_32_components(const uint8_t* code, int i) const { + __m512i codei = _mm512_loadu_si512((const __m512i*)(code + 2 * i)); + return _mm512_castsi512_ph(codei); + } +}; + +#endif + #if defined(USE_AVX512_F16C) -template <> -struct QuantizerFP16<16> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} +template +struct QuantizerFP16 : QuantizerFP16 { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16(d, trained) {} FAISS_ALWAYS_INLINE __m512 reconstruct_16_components(const uint8_t* code, int i) const { @@ -629,10 +785,10 @@ struct QuantizerFP16<16> : QuantizerFP16<1> { #if defined(USE_F16C) -template <> -struct QuantizerFP16<8> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} +template +struct QuantizerFP16 : QuantizerFP16 { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16(d, trained) {} FAISS_ALWAYS_INLINE __m256 reconstruct_8_components(const uint8_t* code, int i) const { @@ -645,10 +801,10 @@ struct QuantizerFP16<8> : QuantizerFP16<1> { #ifdef USE_NEON -template <> -struct QuantizerFP16<8> : QuantizerFP16<1> { - QuantizerFP16(size_t d, const std::vector& trained) - : QuantizerFP16<1>(d, trained) {} +template +struct QuantizerFP16 : QuantizerFP16 { + QuantizerFP16(size_t d, const std::vector& trained) + : QuantizerFP16(d, trained) {} FAISS_ALWAYS_INLINE float32x4x2_t reconstruct_8_components(const uint8_t* code, int i) const { @@ -663,14 +819,14 @@ struct QuantizerFP16<8> : QuantizerFP16<1> { * BF16 quantizer *******************************************************************/ -template +template struct QuantizerBF16 {}; -template <> -struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer { +template +struct QuantizerBF16 : ScalarQuantizer::SQuantizer { const size_t d; - QuantizerBF16(size_t d, const std::vector& /* unused */) : d(d) {} + QuantizerBF16(size_t d, const std::vector&) : d(d) {} void encode_vector(const float* x, uint8_t* code) const final { for (size_t i = 0; i < d; i++) { @@ -692,10 +848,10 @@ struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer { #if defined(__AVX512F__) -template <> -struct QuantizerBF16<16> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} +template +struct QuantizerBF16 : QuantizerBF16 { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16(d, trained) {} FAISS_ALWAYS_INLINE __m512 reconstruct_16_components(const uint8_t* code, int i) const { __m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i)); @@ -707,10 +863,10 @@ struct QuantizerBF16<16> : QuantizerBF16<1> { #elif defined(__AVX2__) -template <> -struct QuantizerBF16<8> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} +template +struct QuantizerBF16 : QuantizerBF16 { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16(d, trained) {} FAISS_ALWAYS_INLINE __m256 reconstruct_8_components(const uint8_t* code, int i) const { @@ -725,10 +881,10 @@ struct QuantizerBF16<8> : QuantizerBF16<1> { #ifdef USE_NEON -template <> -struct QuantizerBF16<8> : QuantizerBF16<1> { - QuantizerBF16(size_t d, const std::vector& trained) - : QuantizerBF16<1>(d, trained) {} +template +struct QuantizerBF16 : QuantizerBF16 { + QuantizerBF16(size_t d, const std::vector& trained) + : QuantizerBF16(d, trained) {} FAISS_ALWAYS_INLINE float32x4x2_t reconstruct_8_components(const uint8_t* code, int i) const { @@ -744,15 +900,14 @@ struct QuantizerBF16<8> : QuantizerBF16<1> { * 8bit_direct quantizer *******************************************************************/ -template +template struct Quantizer8bitDirect {}; -template <> -struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer { +template +struct Quantizer8bitDirect : ScalarQuantizer::SQuantizer { const size_t d; - Quantizer8bitDirect(size_t d, const std::vector& /* unused */) - : d(d) {} + Quantizer8bitDirect(size_t d, const std::vector /* unused */) : d(d) {} void encode_vector(const float* x, uint8_t* code) const final { for (size_t i = 0; i < d; i++) { @@ -766,18 +921,35 @@ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer { } } - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { + FAISS_ALWAYS_INLINE T + reconstruct_component(const uint8_t* code, int i) const { return code[i]; } }; +#if defined(USE_AVX512_FP16) + +template +struct Quantizer8bitDirect : Quantizer8bitDirect { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect(d, trained) {} + + FAISS_ALWAYS_INLINE __m512h + reconstruct_32_components(const uint8_t* code, int i) const { + __m256i x32 = _mm256_loadu_si256((__m256i*)(code + i)); + __m512i y32 = _mm512_cvtepu8_epi16(x32); + return _mm512_cvtepi16_ph(y32); + } +}; + +#endif + #if defined(__AVX512F__) -template <> -struct Quantizer8bitDirect<16> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} +template +struct Quantizer8bitDirect : Quantizer8bitDirect { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect(d, trained) {} FAISS_ALWAYS_INLINE __m512 reconstruct_16_components(const uint8_t* code, int i) const { @@ -789,10 +961,10 @@ struct Quantizer8bitDirect<16> : Quantizer8bitDirect<1> { #elif defined(__AVX2__) -template <> -struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} +template +struct Quantizer8bitDirect : Quantizer8bitDirect { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect(d, trained) {} FAISS_ALWAYS_INLINE __m256 reconstruct_8_components(const uint8_t* code, int i) const { @@ -806,10 +978,10 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { #ifdef USE_NEON -template <> -struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { - Quantizer8bitDirect(size_t d, const std::vector& trained) - : Quantizer8bitDirect<1>(d, trained) {} +template +struct Quantizer8bitDirect : Quantizer8bitDirect { + Quantizer8bitDirect(size_t d, const std::vector& trained) + : Quantizer8bitDirect(d, trained) {} FAISS_ALWAYS_INLINE float32x4x2_t reconstruct_8_components(const uint8_t* code, int i) const { @@ -829,14 +1001,14 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> { * 8bit_direct_signed quantizer *******************************************************************/ -template +template struct Quantizer8bitDirectSigned {}; -template <> -struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer { +template +struct Quantizer8bitDirectSigned : ScalarQuantizer::SQuantizer { const size_t d; - Quantizer8bitDirectSigned(size_t d, const std::vector& /* unused */) + Quantizer8bitDirectSigned(size_t d, const std::vector& /* unused */) : d(d) {} void encode_vector(const float* x, uint8_t* code) const final { @@ -851,18 +1023,37 @@ struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer { } } - FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i) - const { + FAISS_ALWAYS_INLINE T + reconstruct_component(const uint8_t* code, int i) const { return code[i] - 128; } }; +#if defined(USE_AVX512_FP16) + +template +struct Quantizer8bitDirectSigned : Quantizer8bitDirectSigned { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned(d, trained) {} + + FAISS_ALWAYS_INLINE __m512h + reconstruct_32_components(const uint8_t* code, int i) const { + __m256i x32 = _mm256_loadu_si256((__m256i*)(code + i)); + __m512i y32 = _mm512_cvtepu8_epi16(x32); + __m512i c32 = _mm512_set1_epi16(128); + __m512i z32 = _mm512_sub_epi16(y32, c32); + return _mm512_cvtepi16_ph(z32); + } +}; + +#endif + #if defined(__AVX512F__) -template <> -struct Quantizer8bitDirectSigned<16> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : Quantizer8bitDirectSigned<1>(d, trained) {} +template +struct Quantizer8bitDirectSigned : Quantizer8bitDirectSigned { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned(d, trained) {} FAISS_ALWAYS_INLINE __m512 reconstruct_16_components(const uint8_t* code, int i) const { @@ -876,10 +1067,10 @@ struct Quantizer8bitDirectSigned<16> : Quantizer8bitDirectSigned<1> { #elif defined(__AVX2__) -template <> -struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : Quantizer8bitDirectSigned<1>(d, trained) {} +template +struct Quantizer8bitDirectSigned : Quantizer8bitDirectSigned { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned(d, trained) {} FAISS_ALWAYS_INLINE __m256 reconstruct_8_components(const uint8_t* code, int i) const { @@ -895,10 +1086,10 @@ struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { #ifdef USE_NEON -template <> -struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { - Quantizer8bitDirectSigned(size_t d, const std::vector& trained) - : Quantizer8bitDirectSigned<1>(d, trained) {} +template +struct Quantizer8bitDirectSigned : Quantizer8bitDirectSigned { + Quantizer8bitDirectSigned(size_t d, const std::vector& trained) + : Quantizer8bitDirectSigned(d, trained) {} FAISS_ALWAYS_INLINE float32x4x2_t reconstruct_8_components(const uint8_t* code, int i) const { @@ -919,45 +1110,54 @@ struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> { #endif -template +template ScalarQuantizer::SQuantizer* select_quantizer_1( QuantizerType qtype, size_t d, - const std::vector& trained) { + const std::vector& trained) { switch (qtype) { case ScalarQuantizer::QT_8bit: return new QuantizerTemplate< + T, Codec8bit, QuantizerTemplateScaling::NON_UNIFORM, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_6bit: return new QuantizerTemplate< + T, Codec6bit, QuantizerTemplateScaling::NON_UNIFORM, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_4bit: return new QuantizerTemplate< + T, Codec4bit, QuantizerTemplateScaling::NON_UNIFORM, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_8bit_uniform: return new QuantizerTemplate< + T, Codec8bit, QuantizerTemplateScaling::UNIFORM, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_4bit_uniform: return new QuantizerTemplate< + T, Codec4bit, QuantizerTemplateScaling::UNIFORM, SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_fp16: - return new QuantizerFP16(d, trained); + return new QuantizerFP16(d, trained); case ScalarQuantizer::QT_bf16: - return new QuantizerBF16(d, trained); +#if defined(USE_AVX512_FP16) + return new QuantizerBF16(d, trained); +#else + return new QuantizerBF16(d, trained); +#endif case ScalarQuantizer::QT_8bit_direct: - return new Quantizer8bitDirect(d, trained); + return new Quantizer8bitDirect(d, trained); case ScalarQuantizer::QT_8bit_direct_signed: - return new Quantizer8bitDirectSigned(d, trained); + return new Quantizer8bitDirectSigned(d, trained); } FAISS_THROW_MSG("unknown qtype"); } @@ -970,29 +1170,32 @@ static float sqr(float x) { return x * x; } +template void train_Uniform( RangeStat rs, float rs_arg, idx_t n, int k, const float* x, - std::vector& trained) { + std::vector& trained) { trained.resize(2); - float& vmin = trained[0]; - float& vmax = trained[1]; + T& vmin = trained[0]; + T& vmax = trained[1]; if (rs == ScalarQuantizer::RS_minmax) { - vmin = HUGE_VAL; - vmax = -HUGE_VAL; + vmin = static_cast(HUGE_VAL); + vmax = static_cast(-HUGE_VAL); for (size_t i = 0; i < n; i++) { - if (x[i] < vmin) - vmin = x[i]; - if (x[i] > vmax) - vmax = x[i]; + float fmin = static_cast(vmin); + float fmax = static_cast(vmax); + if (x[i] < fmin) + vmin = static_cast(x[i]); + if (x[i] > fmax) + vmax = static_cast(x[i]); } float vexp = (vmax - vmin) * rs_arg; - vmin -= vexp; - vmax += vexp; + vmin = static_cast(vmin - vexp); + vmax = static_cast(vmax + vexp); } else if (rs == ScalarQuantizer::RS_meanstd) { double sum = 0, sum2 = 0; for (size_t i = 0; i < n; i++) { @@ -1003,34 +1206,37 @@ void train_Uniform( float var = sum2 / n - mean * mean; float std = var <= 0 ? 1.0 : sqrt(var); - vmin = mean - std * rs_arg; - vmax = mean + std * rs_arg; + vmin = static_cast(mean - std * rs_arg); + vmax = static_cast(mean + std * rs_arg); } else if (rs == ScalarQuantizer::RS_quantiles) { std::vector x_copy(n); memcpy(x_copy.data(), x, n * sizeof(*x)); // TODO just do a quickselect std::sort(x_copy.begin(), x_copy.end()); - int o = int(rs_arg * n); + int o = static_cast(rs_arg * n); if (o < 0) o = 0; if (o > n - o) o = n / 2; - vmin = x_copy[o]; - vmax = x_copy[n - 1 - o]; + vmin = static_cast(x_copy[o]); + vmax = static_cast(x_copy[n - 1 - o]); } else if (rs == ScalarQuantizer::RS_optim) { float a, b; float sx = 0; { - vmin = HUGE_VAL, vmax = -HUGE_VAL; + vmin = static_cast(HUGE_VAL); + vmax = static_cast(-HUGE_VAL); for (size_t i = 0; i < n; i++) { - if (x[i] < vmin) - vmin = x[i]; - if (x[i] > vmax) - vmax = x[i]; + float fmin = static_cast(vmin); + float fmax = static_cast(vmax); + if (x[i] < fmin) + vmin = static_cast(x[i]); + if (x[i] > fmax) + vmax = static_cast(x[i]); sx += x[i]; } - b = vmin; + b = static_cast(vmin); a = (vmax - vmin) / (k - 1); } int verbose = false; @@ -1074,8 +1280,8 @@ void train_Uniform( if (verbose) printf("\n"); - vmin = b; - vmax = b + a * (k - 1); + vmin = static_cast(b); + vmax = static_cast(b + a * (k - 1)); } else { FAISS_THROW_MSG("Invalid qtype"); @@ -1083,6 +1289,7 @@ void train_Uniform( vmax -= vmin; } +template void train_NonUniform( RangeStat rs, float rs_arg, @@ -1090,28 +1297,36 @@ void train_NonUniform( int d, int k, const float* x, - std::vector& trained) { + std::vector& trained) { trained.resize(2 * d); - float* vmin = trained.data(); - float* vmax = trained.data() + d; + T* vmin = trained.data(); + T* vmax = trained.data() + d; if (rs == ScalarQuantizer::RS_minmax) { - memcpy(vmin, x, sizeof(*x) * d); - memcpy(vmax, x, sizeof(*x) * d); + for (int i = 0; i < d; i++) { + vmin[i] = static_cast(x[i]); + vmax[i] = static_cast(x[i]); + } for (size_t i = 1; i < n; i++) { const float* xi = x + i * d; for (size_t j = 0; j < d; j++) { - if (xi[j] < vmin[j]) - vmin[j] = xi[j]; - if (xi[j] > vmax[j]) - vmax[j] = xi[j]; + float fmin = static_cast(vmin[j]); + float fmax = static_cast(vmax[j]); + if (xi[j] < fmin) + vmin[j] = static_cast(xi[j]); + if (xi[j] > fmax) + vmax[j] = static_cast(xi[j]); } } - float* vdiff = vmax; + T* vdiff = vmax; for (size_t j = 0; j < d; j++) { - float vexp = (vmax[j] - vmin[j]) * rs_arg; - vmin[j] -= vexp; - vmax[j] += vexp; - vdiff[j] = vmax[j] - vmin[j]; + float fmin = static_cast(vmin[j]); + float fmax = static_cast(vmax[j]); + float vexp = (fmax - fmin) * rs_arg; + fmin -= vexp; + fmax += vexp; + vmin[j] = static_cast(fmin); + vmax[j] = static_cast(fmax); + vdiff[j] = static_cast(fmax - fmin); } } else { // transpose @@ -1122,12 +1337,12 @@ void train_NonUniform( xt[j * n + i] = xi[j]; } } - std::vector trained_d(2); + std::vector trained_d(2); #pragma omp parallel for for (int j = 0; j < d; j++) { - train_Uniform(rs, rs_arg, n, k, xt.data() + j * n, trained_d); - vmin[j] = trained_d[0]; - vmax[j] = trained_d[1]; + train_Uniform(rs, rs_arg, n, k, xt.data() + j * n, trained_d); + vmin[j] = static_cast(trained_d[0]); + vmax[j] = static_cast(trained_d[1]); } } } @@ -1174,6 +1389,49 @@ struct SimilarityL2<1> { } }; +#if defined(USE_AVX512_FP16) + +template <> +struct SimilarityL2<32> { + static constexpr int simdwidth = 32; + static constexpr MetricType metric_type = METRIC_L2; + + const float *y, *yi; + + explicit SimilarityL2(const float* y) : y(y) {} + __m512h accu32; + + FAISS_ALWAYS_INLINE void begin_32() { + accu32 = _mm512_setzero_ph(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_32_components(__m512h x) { + __m256i lo = _mm512_cvtps_ph( + _mm512_loadu_ps(yi), + _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + __m256i hi = _mm512_cvtps_ph( + _mm512_loadu_ps(yi + 16), + _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + __m512h yiv = _mm512_castsi512_ph( + _mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1)); + yi += 32; + __m512h tmp = _mm512_sub_ph(yiv, x); + accu32 = _mm512_fmadd_ph(tmp, tmp, accu32); + } + + FAISS_ALWAYS_INLINE void add_32_components_2(__m512h x, __m512h y_2) { + __m512h tmp = _mm512_sub_ph(y_2, x); + accu32 = _mm512_fmadd_ph(tmp, tmp, accu32); + } + + FAISS_ALWAYS_INLINE float result_32() { + return _mm512_reduce_add_ph(accu32); + } +}; + +#endif + #if defined(__AVX512F__) template <> @@ -1333,6 +1591,48 @@ struct SimilarityIP<1> { } }; +#if defined(USE_AVX512_FP16) + +template <> +struct SimilarityIP<32> { + static constexpr int simdwidth = 32; + static constexpr MetricType metric_type = METRIC_INNER_PRODUCT; + + const float *y, *yi; + + explicit SimilarityIP(const float* y) : y(y) {} + + __m512h accu32; + + FAISS_ALWAYS_INLINE void begin_32() { + accu32 = _mm512_setzero_ph(); + yi = y; + } + + FAISS_ALWAYS_INLINE void add_32_components(__m512h x) { + __m256i lo = _mm512_cvtps_ph( + _mm512_loadu_ps(yi), + _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + __m256i hi = _mm512_cvtps_ph( + _mm512_loadu_ps(yi + 16), + _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + __m512h yiv = _mm512_castsi512_ph( + _mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1)); + yi += 32; + accu32 = _mm512_fmadd_ph(yiv, x, accu32); + } + + FAISS_ALWAYS_INLINE void add_32_components_2(__m512h x1, __m512h x2) { + accu32 = _mm512_fmadd_ph(x1, x2, accu32); + } + + FAISS_ALWAYS_INLINE float result_32() { + return _mm512_reduce_add_ph(accu32); + } +}; + +#endif + #if defined(__AVX512F__) template <> @@ -1463,23 +1763,22 @@ struct SimilarityIP<8> { * code-to-vector or code-to-code comparisons *******************************************************************/ -template +template struct DCTemplate : SQDistanceComputer {}; -template -struct DCTemplate : SQDistanceComputer { +template +struct DCTemplate : SQDistanceComputer { using Sim = Similarity; Quantizer quant; - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} + DCTemplate(size_t d, const std::vector& trained) : quant(d, trained) {} float compute_distance(const float* x, const uint8_t* code) const { Similarity sim(x); sim.begin(); for (size_t i = 0; i < quant.d; i++) { - float xi = quant.reconstruct_component(code, i); + T xi = static_cast(quant.reconstruct_component(code, i)); sim.add_component(xi); } return sim.result(); @@ -1490,8 +1789,8 @@ struct DCTemplate : SQDistanceComputer { Similarity sim(nullptr); sim.begin(); for (size_t i = 0; i < quant.d; i++) { - float x1 = quant.reconstruct_component(code1, i); - float x2 = quant.reconstruct_component(code2, i); + T x1 = static_cast(quant.reconstruct_component(code1, i)); + T x2 = static_cast(quant.reconstruct_component(code2, i)); sim.add_component_2(x1, x2); } return sim.result(); @@ -1511,17 +1810,155 @@ struct DCTemplate : SQDistanceComputer { } }; +#if defined(USE_AVX512_FP16) + +template +struct DCTemplate + : SQDistanceComputer { // Update to handle 32 lanes + using Sim = Similarity; + + Quantizer quant; + + DCTemplate(size_t d, const std::vector& trained) : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Similarity sim(x); + sim.begin_32(); + for (size_t i = 0; i < quant.d; i += 32) { + __m512h xi = quant.reconstruct_32_components(code, i); + // print_m512h(xi); + sim.add_32_components(xi); + } + return sim.result_32(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Similarity sim(nullptr); + sim.begin_32(); + for (size_t i = 0; i < quant.d; i += 32) { + __m512h x1 = quant.reconstruct_32_components(code1, i); + __m512h x2 = quant.reconstruct_32_components(code2, i); + sim.add_32_components_2(x1, x2); + } + return sim.result_32(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; + +template +struct DCTemplate, SimilarityL2<16>, 16> + : SQDistanceComputer { // Update to handle 16 lanes + using Sim = SimilarityL2<16>; + + QuantizerBF16 quant; + + DCTemplate(size_t d, const std::vector& trained) : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Sim sim(x); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + __m512 xi = quant.reconstruct_16_components(code, i); + sim.add_16_components(xi); + } + return sim.result_16(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Sim sim(nullptr); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + __m512 x1 = quant.reconstruct_16_components(code1, i); + __m512 x2 = quant.reconstruct_16_components(code2, i); + sim.add_16_components_2(x1, x2); + } + return sim.result_16(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; + +template +struct DCTemplate, SimilarityIP<16>, 16> + : SQDistanceComputer { // Update to handle 16 lanes + using Sim = SimilarityIP<16>; + + QuantizerBF16 quant; + + DCTemplate(size_t d, const std::vector& trained) : quant(d, trained) {} + + float compute_distance(const float* x, const uint8_t* code) const { + Sim sim(x); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + __m512 xi = quant.reconstruct_16_components(code, i); + sim.add_16_components(xi); + } + return sim.result_16(); + } + + float compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + Sim sim(nullptr); + sim.begin_16(); + for (size_t i = 0; i < quant.d; i += 16) { + __m512 x1 = quant.reconstruct_16_components(code1, i); + __m512 x2 = quant.reconstruct_16_components(code2, i); + sim.add_16_components_2(x1, x2); + } + return sim.result_16(); + } + + void set_query(const float* x) final { + q = x; + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_distance(q, code); + } +}; +#endif + #if defined(USE_AVX512_F16C) -template -struct DCTemplate +template +struct DCTemplate : SQDistanceComputer { // Update to handle 16 lanes using Sim = Similarity; Quantizer quant; - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} + DCTemplate(size_t d, const std::vector& trained) : quant(d, trained) {} float compute_distance(const float* x, const uint8_t* code) const { Similarity sim(x); @@ -1561,14 +1998,13 @@ struct DCTemplate #elif defined(USE_F16C) -template -struct DCTemplate : SQDistanceComputer { +template +struct DCTemplate : SQDistanceComputer { using Sim = Similarity; Quantizer quant; - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} + DCTemplate(size_t d, const std::vector& trained) : quant(d, trained) {} float compute_distance(const float* x, const uint8_t* code) const { Similarity sim(x); @@ -1610,14 +2046,13 @@ struct DCTemplate : SQDistanceComputer { #ifdef USE_NEON -template -struct DCTemplate : SQDistanceComputer { +template +struct DCTemplate : SQDistanceComputer { using Sim = Similarity; Quantizer quant; - DCTemplate(size_t d, const std::vector& trained) - : quant(d, trained) {} + DCTemplate(size_t d, const std::vector& trained) : quant(d, trained) {} float compute_distance(const float* x, const uint8_t* code) const { Similarity sim(x); sim.begin_8(); @@ -1659,17 +2094,17 @@ struct DCTemplate : SQDistanceComputer { * DistanceComputerByte: computes distances in the integer domain *******************************************************************/ -template +template struct DistanceComputerByte : SQDistanceComputer {}; -template -struct DistanceComputerByte : SQDistanceComputer { +template +struct DistanceComputerByte : SQDistanceComputer { using Sim = Similarity; int d; std::vector tmp; - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} int compute_code_distance(const uint8_t* code1, const uint8_t* code2) const { @@ -1706,16 +2141,71 @@ struct DistanceComputerByte : SQDistanceComputer { } }; +#if defined(USE_AVX512_FP16) + +template +struct DistanceComputerByte : SQDistanceComputer { + using Sim = Similarity; + + int d; + std::vector tmp; + + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + + int compute_code_distance(const uint8_t* code1, const uint8_t* code2) + const { + __m512i accu = _mm512_setzero_si512(); + for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time + __m512i c1 = _mm512_cvtepu8_epi16( + _mm256_loadu_si256((__m256i*)(code1 + i))); + __m512i c2 = _mm512_cvtepu8_epi16( + _mm256_loadu_si256((__m256i*)(code2 + i))); + __m512i prod32; + if (Sim::metric_type == METRIC_INNER_PRODUCT) { + prod32 = _mm512_madd_epi16(c1, c2); + } else { + __m512i diff = _mm512_sub_epi16(c1, c2); + prod32 = _mm512_madd_epi16(diff, diff); + } + accu = _mm512_add_epi32(accu, prod32); + } + // Horizontally add elements of accu + return _mm512_reduce_add_epi32(accu); + } + + void set_query(const float* x) final { + for (int i = 0; i < d; i++) { + tmp[i] = int(x[i]); + } + } + + int compute_distance(const float* x, const uint8_t* code) { + set_query(x); + return compute_code_distance(tmp.data(), code); + } + + float symmetric_dis(idx_t i, idx_t j) override { + return compute_code_distance( + codes + i * code_size, codes + j * code_size); + } + + float query_to_code(const uint8_t* code) const final { + return compute_code_distance(tmp.data(), code); + } +}; + +#endif + #if defined(__AVX512F__) -template -struct DistanceComputerByte : SQDistanceComputer { +template +struct DistanceComputerByte : SQDistanceComputer { using Sim = Similarity; int d; std::vector tmp; - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} int compute_code_distance(const uint8_t* code1, const uint8_t* code2) const { @@ -1761,14 +2251,14 @@ struct DistanceComputerByte : SQDistanceComputer { #elif defined(__AVX2__) -template -struct DistanceComputerByte : SQDistanceComputer { +template +struct DistanceComputerByte : SQDistanceComputer { using Sim = Similarity; int d; std::vector tmp; - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} int compute_code_distance(const uint8_t* code1, const uint8_t* code2) const { @@ -1826,14 +2316,14 @@ struct DistanceComputerByte : SQDistanceComputer { #ifdef USE_NEON -template -struct DistanceComputerByte : SQDistanceComputer { +template +struct DistanceComputerByte : SQDistanceComputer { using Sim = Similarity; int d; std::vector tmp; - DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} + DistanceComputerByte(int d, const std::vector&) : d(d), tmp(d) {} int compute_code_distance(const uint8_t* code1, const uint8_t* code2) const { @@ -1877,16 +2367,18 @@ struct DistanceComputerByte : SQDistanceComputer { * specialization *******************************************************************/ -template +template SQDistanceComputer* select_distance_computer( QuantizerType qtype, size_t d, - const std::vector& trained) { + const std::vector& trained) { constexpr int SIMDWIDTH = Sim::simdwidth; switch (qtype) { case ScalarQuantizer::QT_8bit_uniform: return new DCTemplate< + T, QuantizerTemplate< + T, Codec8bit, QuantizerTemplateScaling::UNIFORM, SIMDWIDTH>, @@ -1895,7 +2387,9 @@ SQDistanceComputer* select_distance_computer( case ScalarQuantizer::QT_4bit_uniform: return new DCTemplate< + T, QuantizerTemplate< + T, Codec4bit, QuantizerTemplateScaling::UNIFORM, SIMDWIDTH>, @@ -1904,7 +2398,9 @@ SQDistanceComputer* select_distance_computer( case ScalarQuantizer::QT_8bit: return new DCTemplate< + T, QuantizerTemplate< + T, Codec8bit, QuantizerTemplateScaling::NON_UNIFORM, SIMDWIDTH>, @@ -1913,7 +2409,9 @@ SQDistanceComputer* select_distance_computer( case ScalarQuantizer::QT_6bit: return new DCTemplate< + T, QuantizerTemplate< + T, Codec6bit, QuantizerTemplateScaling::NON_UNIFORM, SIMDWIDTH>, @@ -1922,7 +2420,9 @@ SQDistanceComputer* select_distance_computer( case ScalarQuantizer::QT_4bit: return new DCTemplate< + T, QuantizerTemplate< + T, Codec4bit, QuantizerTemplateScaling::NON_UNIFORM, SIMDWIDTH>, @@ -1930,32 +2430,54 @@ SQDistanceComputer* select_distance_computer( SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_fp16: - return new DCTemplate, Sim, SIMDWIDTH>( - d, trained); - + return new DCTemplate< + T, + QuantizerFP16, + Sim, + SIMDWIDTH>(d, trained); case ScalarQuantizer::QT_bf16: - return new DCTemplate, Sim, SIMDWIDTH>( - d, trained); - +#if defined(USE_AVX512_FP16) + if (Sim::metric_type == METRIC_L2) { + return new DCTemplate< + T, + QuantizerBF16, + SimilarityL2<16>, + 16>(d, trained); + } else { + return new DCTemplate< + T, + QuantizerBF16, + SimilarityIP<16>, + 16>(d, trained); + } +#else + return new DCTemplate< + T, + QuantizerBF16, + Sim, + SIMDWIDTH>(d, trained); +#endif case ScalarQuantizer::QT_8bit_direct: #if defined(__AVX512F__) if (d % 32 == 0) { - return new DistanceComputerByte(d, trained); + return new DistanceComputerByte(d, trained); } else #elif defined(__AVX2__) if (d % 16 == 0) { - return new DistanceComputerByte(d, trained); + return new DistanceComputerByte(d, trained); } else #endif { return new DCTemplate< - Quantizer8bitDirect, + T, + Quantizer8bitDirect, Sim, SIMDWIDTH>(d, trained); } case ScalarQuantizer::QT_8bit_direct_signed: return new DCTemplate< - Quantizer8bitDirectSigned, + T, + Quantizer8bitDirectSigned, Sim, SIMDWIDTH>(d, trained); } @@ -2016,7 +2538,7 @@ void ScalarQuantizer::train(size_t n, const float* x) { switch (qtype) { case QT_4bit_uniform: case QT_8bit_uniform: - train_Uniform( + train_Uniform( // TODO rangestat, rangestat_arg, n * d, @@ -2027,7 +2549,7 @@ void ScalarQuantizer::train(size_t n, const float* x) { case QT_4bit: case QT_8bit: case QT_6bit: - train_NonUniform( + train_NonUniform( rangestat, rangestat_arg, n, @@ -2046,17 +2568,23 @@ void ScalarQuantizer::train(size_t n, const float* x) { } ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const { -#if defined(USE_AVX512_F16C) + using T = decltype(trained)::value_type; + +#if defined(USE_AVX512_FP16) + if (d % 32 == 0 && qtype != QT_bf16) { + return select_quantizer_1(qtype, d, trained); + } else +#elif defined(USE_AVX512_F16C) if (d % 16 == 0) { - return select_quantizer_1<16>(qtype, d, trained); + return select_quantizer_1(qtype, d, trained); } else #elif defined(USE_F16C) || defined(USE_NEON) if (d % 8 == 0) { - return select_quantizer_1<8>(qtype, d, trained); + return select_quantizer_1(qtype, d, trained); } else #endif { - return select_quantizer_1<1>(qtype, d, trained); + return select_quantizer_1(qtype, d, trained); } } @@ -2080,31 +2608,47 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const { SQDistanceComputer* ScalarQuantizer::get_distance_computer( MetricType metric) const { + using T = decltype(trained)::value_type; + FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT); -#if defined(USE_AVX512_F16C) +#if defined(USE_AVX512_FP16) + if (d % 32 == 0 && qtype != QT_bf16) { + if (metric == METRIC_L2) { + return select_distance_computer>( + qtype, d, trained); + } else { + return select_distance_computer>( + qtype, d, trained); + } + } else +#elif defined(USE_AVX512_F16C) if (d % 16 == 0) { if (metric == METRIC_L2) { - return select_distance_computer>( + return select_distance_computer>( qtype, d, trained); } else { - return select_distance_computer>( + return select_distance_computer>( qtype, d, trained); } } else #elif defined(USE_F16C) || defined(USE_NEON) if (d % 8 == 0) { if (metric == METRIC_L2) { - return select_distance_computer>(qtype, d, trained); + return select_distance_computer>( + qtype, d, trained); } else { - return select_distance_computer>(qtype, d, trained); + return select_distance_computer>( + qtype, d, trained); } } else #endif { if (metric == METRIC_L2) { - return select_distance_computer>(qtype, d, trained); + return select_distance_computer>( + qtype, d, trained); } else { - return select_distance_computer>(qtype, d, trained); + return select_distance_computer>( + qtype, d, trained); } } } @@ -2118,7 +2662,7 @@ SQDistanceComputer* ScalarQuantizer::get_distance_computer( namespace { -template +template struct IVFSQScannerIP : InvertedListScanner { DCClass dc; bool by_residual; @@ -2127,7 +2671,7 @@ struct IVFSQScannerIP : InvertedListScanner { IVFSQScannerIP( int d, - const std::vector& trained, + const std::vector& trained, size_t code_size, bool store_pairs, const IDSelector* sel, @@ -2201,7 +2745,7 @@ struct IVFSQScannerIP : InvertedListScanner { * = 1: check on ids[j] * = 2: check in j directly (normally ids is nullptr and store_pairs) */ -template +template struct IVFSQScannerL2 : InvertedListScanner { DCClass dc; @@ -2213,7 +2757,7 @@ struct IVFSQScannerL2 : InvertedListScanner { IVFSQScannerL2( int d, - const std::vector& trained, + const std::vector& trained, size_t code_size, const Index* quantizer, bool store_pairs, @@ -2295,7 +2839,7 @@ struct IVFSQScannerL2 : InvertedListScanner { } }; -template +template InvertedListScanner* sel3_InvertedListScanner( const ScalarQuantizer* sq, const Index* quantizer, @@ -2303,7 +2847,7 @@ InvertedListScanner* sel3_InvertedListScanner( const IDSelector* sel, bool r) { if (DCClass::Sim::metric_type == METRIC_L2) { - return new IVFSQScannerL2( + return new IVFSQScannerL2( sq->d, sq->trained, sq->code_size, @@ -2312,14 +2856,14 @@ InvertedListScanner* sel3_InvertedListScanner( sel, r); } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) { - return new IVFSQScannerIP( + return new IVFSQScannerIP( sq->d, sq->trained, sq->code_size, store_pairs, sel, r); } else { FAISS_THROW_MSG("unsupported metric type"); } } -template +template InvertedListScanner* sel2_InvertedListScanner( const ScalarQuantizer* sq, const Index* quantizer, @@ -2328,19 +2872,23 @@ InvertedListScanner* sel2_InvertedListScanner( bool r) { if (sel) { if (store_pairs) { - return sel3_InvertedListScanner( + return sel3_InvertedListScanner( sq, quantizer, store_pairs, sel, r); } else { - return sel3_InvertedListScanner( + return sel3_InvertedListScanner( sq, quantizer, store_pairs, sel, r); } } else { - return sel3_InvertedListScanner( + return sel3_InvertedListScanner( sq, quantizer, store_pairs, sel, r); } } -template +template < + class T, + class Similarity, + class Codec, + QuantizerTemplateScaling SCALING> InvertedListScanner* sel12_InvertedListScanner( const ScalarQuantizer* sq, const Index* quantizer, @@ -2348,13 +2896,13 @@ InvertedListScanner* sel12_InvertedListScanner( const IDSelector* sel, bool r) { constexpr int SIMDWIDTH = Similarity::simdwidth; - using QuantizerClass = QuantizerTemplate; - using DCClass = DCTemplate; - return sel2_InvertedListScanner( + using QuantizerClass = QuantizerTemplate; + using DCClass = DCTemplate; + return sel2_InvertedListScanner( sq, quantizer, store_pairs, sel, r); } -template +template InvertedListScanner* sel1_InvertedListScanner( const ScalarQuantizer* sq, const Index* quantizer, @@ -2365,76 +2913,115 @@ InvertedListScanner* sel1_InvertedListScanner( switch (sq->qtype) { case ScalarQuantizer::QT_8bit_uniform: return sel12_InvertedListScanner< + T, Similarity, Codec8bit, QuantizerTemplateScaling::UNIFORM>( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_4bit_uniform: return sel12_InvertedListScanner< + T, Similarity, Codec4bit, QuantizerTemplateScaling::UNIFORM>( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_8bit: return sel12_InvertedListScanner< + T, Similarity, Codec8bit, QuantizerTemplateScaling::NON_UNIFORM>( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_4bit: return sel12_InvertedListScanner< + T, Similarity, Codec4bit, QuantizerTemplateScaling::NON_UNIFORM>( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_6bit: return sel12_InvertedListScanner< + T, Similarity, Codec6bit, QuantizerTemplateScaling::NON_UNIFORM>( sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_fp16: - return sel2_InvertedListScanner, - Similarity, - SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); + return sel2_InvertedListScanner< + T, + DCTemplate< + T, + QuantizerFP16, + Similarity, + SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); case ScalarQuantizer::QT_bf16: - return sel2_InvertedListScanner, - Similarity, - SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); +#if defined(USE_AVX512_FP16) + if (Similarity::metric_type == METRIC_L2) { + return sel2_InvertedListScanner< + T, + DCTemplate< + T, + QuantizerBF16, + SimilarityL2<16>, + 16>>(sq, quantizer, store_pairs, sel, r); + } else { + return sel2_InvertedListScanner< + T, + DCTemplate< + T, + QuantizerBF16, + SimilarityIP<16>, + 16>>(sq, quantizer, store_pairs, sel, r); + } +#else + return sel2_InvertedListScanner< + T, + DCTemplate< + T, + QuantizerBF16, + Similarity, + SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); +#endif case ScalarQuantizer::QT_8bit_direct: #if defined(__AVX512F__) if (sq->d % 32 == 0) { return sel2_InvertedListScanner< - DistanceComputerByte>( + T, + DistanceComputerByte>( sq, quantizer, store_pairs, sel, r); } else #elif defined(__AVX2__) if (sq->d % 16 == 0) { return sel2_InvertedListScanner< - DistanceComputerByte>( + T, + DistanceComputerByte>( sq, quantizer, store_pairs, sel, r); } else #endif { - return sel2_InvertedListScanner, - Similarity, - SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); + return sel2_InvertedListScanner< + T, + DCTemplate< + T, + Quantizer8bitDirect, + Similarity, + SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); } case ScalarQuantizer::QT_8bit_direct_signed: - return sel2_InvertedListScanner, - Similarity, - SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); + return sel2_InvertedListScanner< + T, + DCTemplate< + T, + Quantizer8bitDirectSigned, + Similarity, + SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r); } FAISS_THROW_MSG("unknown qtype"); return nullptr; } -template +template InvertedListScanner* sel0_InvertedListScanner( MetricType mt, const ScalarQuantizer* sq, @@ -2443,10 +3030,10 @@ InvertedListScanner* sel0_InvertedListScanner( const IDSelector* sel, bool by_residual) { if (mt == METRIC_L2) { - return sel1_InvertedListScanner>( + return sel1_InvertedListScanner>( sq, quantizer, store_pairs, sel, by_residual); } else if (mt == METRIC_INNER_PRODUCT) { - return sel1_InvertedListScanner>( + return sel1_InvertedListScanner>( sq, quantizer, store_pairs, sel, by_residual); } else { FAISS_THROW_MSG("unsupported metric type"); @@ -2461,19 +3048,25 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner( bool store_pairs, const IDSelector* sel, bool by_residual) const { -#if defined(USE_AVX512_F16C) + using T = decltype(trained)::value_type; +#if defined(USE_AVX512_FP16) + if (d % 32 == 0 && this->qtype != QT_bf16) { + return sel0_InvertedListScanner( + mt, this, quantizer, store_pairs, sel, by_residual); + } else +#elif defined(USE_AVX512_F16C) if (d % 16 == 0) { - return sel0_InvertedListScanner<16>( + return sel0_InvertedListScanner( mt, this, quantizer, store_pairs, sel, by_residual); } else #elif defined(USE_F16C) || defined(USE_NEON) if (d % 8 == 0) { - return sel0_InvertedListScanner<8>( + return sel0_InvertedListScanner( mt, this, quantizer, store_pairs, sel, by_residual); } else #endif { - return sel0_InvertedListScanner<1>( + return sel0_InvertedListScanner( mt, this, quantizer, store_pairs, sel, by_residual); } } diff --git a/faiss/impl/ScalarQuantizer.h b/faiss/impl/ScalarQuantizer.h index c1f4f98f63..82dcf69b8c 100644 --- a/faiss/impl/ScalarQuantizer.h +++ b/faiss/impl/ScalarQuantizer.h @@ -58,7 +58,12 @@ struct ScalarQuantizer : Quantizer { size_t bits = 0; /// trained values (including the range) +#if defined(ENABLE_AVX512_FP16) && defined(__AVX512FP16__) && \ + defined(__FLT16_MANT_DIG__) + std::vector<_Float16> trained; +#else std::vector trained; +#endif ScalarQuantizer(size_t d, QuantizerType qtype); ScalarQuantizer();