@@ -65,16 +65,22 @@ using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
  */
 
 struct Codec8bit {
-    static void encode_component(float x, uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE void encode_component(
+            float x,
+            uint8_t* code,
+            int i) {
         code[i] = (int)(255 * x);
     }
 
-    static float decode_component(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE float decode_component(
+            const uint8_t* code,
+            int i) {
         return (code[i] + 0.5f) / 255.0f;
     }
 
 #ifdef __AVX2__
-    static inline __m256 decode_8_components(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE __m256
+    decode_8_components(const uint8_t* code, int i) {
         const uint64_t c8 = *(uint64_t*)(code + i);
 
         const __m128i i8 = _mm_set1_epi64x(c8);
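For context (not part of the patch): Codec8bit truncates a float in [0, 1) to a byte on encode and decodes back to the center of that byte's bucket, which is why decode adds 0.5 before dividing. A minimal scalar round trip, as a sketch:

    #include <cstdint>
    #include <cstdio>

    int main() {
        float x = 0.7f;
        uint8_t code = (uint8_t)(int)(255 * x); // encode_component: truncate
        float x2 = (code + 0.5f) / 255.0f;      // decode_component: bucket center
        std::printf("in=%f out=%f err=%g\n", x, x2, x2 - x);
    }
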
@@ -88,16 +94,22 @@ struct Codec8bit {
 };
 
 struct Codec4bit {
-    static void encode_component(float x, uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE void encode_component(
+            float x,
+            uint8_t* code,
+            int i) {
         code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
     }
 
-    static float decode_component(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE float decode_component(
+            const uint8_t* code,
+            int i) {
         return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
     }
 
 #ifdef __AVX2__
-    static __m256 decode_8_components(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE __m256
+    decode_8_components(const uint8_t* code, int i) {
         uint32_t c4 = *(uint32_t*)(code + (i >> 1));
         uint32_t mask = 0x0f0f0f0f;
         uint32_t c4ev = c4 & mask;
@@ -120,7 +132,10 @@ struct Codec4bit {
 };
 
 struct Codec6bit {
-    static void encode_component(float x, uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE void encode_component(
+            float x,
+            uint8_t* code,
+            int i) {
         int bits = (int)(x * 63.0);
         code += (i >> 2) * 3;
         switch (i & 3) {
@@ -141,7 +156,9 @@ struct Codec6bit {
         }
     }
 
-    static float decode_component(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE float decode_component(
+            const uint8_t* code,
+            int i) {
         uint8_t bits;
         code += (i >> 2) * 3;
         switch (i & 3) {
@@ -167,7 +184,7 @@ struct Codec6bit {
 
     /* Load 6 bytes that represent 8 6-bit values, return them as a
      * 8*32 bit vector register */
-    static __m256i load6(const uint16_t* code16) {
+    static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) {
         const __m128i perm = _mm_set_epi8(
                 -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
         const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0);
@@ -186,15 +203,28 @@ struct Codec6bit {
         return c5;
     }
 
-    static __m256 decode_8_components(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE __m256
+    decode_8_components(const uint8_t* code, int i) {
+        // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here
+        // // for reference; maybe it becomes used one day.
+        // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3);
+        // const uint32_t* data32 = (const uint32_t*)data16;
+        // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32);
+        // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL);
+        // const __m128i i8 = _mm_set1_epi64x(vext);
+        // const __m256i i32 = _mm256_cvtepi8_epi32(i8);
+        // const __m256 f8 = _mm256_cvtepi32_ps(i32);
+        // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
+        // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
+        // return _mm256_fmadd_ps(f8, one_255, half_one_255);
+
         __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3));
         __m256 f8 = _mm256_cvtepi32_ps(i8);
         // this could also be done with bit manipulations but it is
         // not obviously faster
-        __m256 half = _mm256_set1_ps(0.5f);
-        f8 = _mm256_add_ps(f8, half);
-        __m256 one_63 = _mm256_set1_ps(1.f / 63.f);
-        return _mm256_mul_ps(f8, one_63);
+        const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
+        const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
+        return _mm256_fmadd_ps(f8, one_255, half_one_255);
     }
 
 #endif
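The constant folding above relies on the identity (q + 0.5)/63 = q*(1/63) + 0.5/63, which lets a single _mm256_fmadd_ps replace the add/multiply pair. A minimal scalar sketch (my own check, not part of the patch) comparing the two forms:

    #include <cmath>
    #include <cstdio>

    int main() {
        float max_err = 0.f;
        for (int q = 0; q < 64; ++q) {
            float old_form = (q + 0.5f) / 63.f;              // add, then divide
            float new_form = q * (1.f / 63.f) + 0.5f / 63.f; // one fma on AVX2
            max_err = std::fmax(max_err, std::fabs(old_form - new_form));
        }
        std::printf("max |old - new| = %g\n", max_err); // on the order of 1 ulp
    }
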
@@ -239,7 +269,8 @@ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
         }
     }
 
-    float reconstruct_component(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
+            const {
         float xi = Codec::decode_component(code, i);
         return vmin + xi * vdiff;
     }
@@ -252,11 +283,11 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
     QuantizerTemplate(size_t d, const std::vector<float>& trained)
             : QuantizerTemplate<Codec, true, 1>(d, trained) {}
 
-    __m256 reconstruct_8_components(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE __m256
+    reconstruct_8_components(const uint8_t* code, int i) const {
         __m256 xi = Codec::decode_8_components(code, i);
-        return _mm256_add_ps(
-                _mm256_set1_ps(this->vmin),
-                _mm256_mul_ps(xi, _mm256_set1_ps(this->vdiff)));
+        return _mm256_fmadd_ps(
+                xi, _mm256_set1_ps(this->vdiff), _mm256_set1_ps(this->vmin));
     }
 };
 
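The same pattern applies to reconstruction: vmin + xi * vdiff maps the decoded value in (0, 1) back to the trained range, and _mm256_fmadd_ps(xi, vdiff, vmin) computes it in one instruction. A self-contained sketch with made-up values (compile with -mavx2 -mfma):

    #include <immintrin.h>
    #include <cstdio>

    int main() {
        __m256 xi = _mm256_set1_ps(0.25f);   // decoded component in (0, 1)
        __m256 vdiff = _mm256_set1_ps(2.0f); // trained range width (made up)
        __m256 vmin = _mm256_set1_ps(-1.0f); // trained range minimum (made up)
        __m256 r = _mm256_fmadd_ps(xi, vdiff, vmin); // vmin + xi * vdiff
        float out[8];
        _mm256_storeu_ps(out, r);
        std::printf("%f\n", out[0]); // -1 + 0.25 * 2 = -0.5
    }
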
@@ -293,7 +324,8 @@ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
         }
     }
 
-    float reconstruct_component(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
+            const {
         float xi = Codec::decode_component(code, i);
         return vmin[i] + xi * vdiff[i];
     }
@@ -306,11 +338,13 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
     QuantizerTemplate(size_t d, const std::vector<float>& trained)
             : QuantizerTemplate<Codec, false, 1>(d, trained) {}
 
-    __m256 reconstruct_8_components(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE __m256
+    reconstruct_8_components(const uint8_t* code, int i) const {
         __m256 xi = Codec::decode_8_components(code, i);
-        return _mm256_add_ps(
-                _mm256_loadu_ps(this->vmin + i),
-                _mm256_mul_ps(xi, _mm256_loadu_ps(this->vdiff + i)));
+        return _mm256_fmadd_ps(
+                xi,
+                _mm256_loadu_ps(this->vdiff + i),
+                _mm256_loadu_ps(this->vmin + i));
     }
 };
 
@@ -341,7 +375,8 @@ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
         }
     }
 
-    float reconstruct_component(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
+            const {
         return decode_fp16(((uint16_t*)code)[i]);
     }
 };
@@ -353,7 +388,8 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
     QuantizerFP16(size_t d, const std::vector<float>& trained)
             : QuantizerFP16<1>(d, trained) {}
 
-    __m256 reconstruct_8_components(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE __m256
+    reconstruct_8_components(const uint8_t* code, int i) const {
         __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i));
         return _mm256_cvtph_ps(codei);
     }
@@ -387,7 +423,8 @@ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
         }
     }
 
-    float reconstruct_component(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
+            const {
         return code[i];
     }
 };
@@ -399,7 +436,8 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
     Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
             : Quantizer8bitDirect<1>(d, trained) {}
 
-    __m256 reconstruct_8_components(const uint8_t* code, int i) const {
+    FAISS_ALWAYS_INLINE __m256
+    reconstruct_8_components(const uint8_t* code, int i) const {
         __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
         __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
         return _mm256_cvtepi32_ps(y8); // 8 * float32
@@ -629,22 +667,22 @@ struct SimilarityL2<1> {
 
     float accu;
 
-    void begin() {
+    FAISS_ALWAYS_INLINE void begin() {
         accu = 0;
         yi = y;
     }
 
-    void add_component(float x) {
+    FAISS_ALWAYS_INLINE void add_component(float x) {
         float tmp = *yi++ - x;
         accu += tmp * tmp;
     }
 
-    void add_component_2(float x1, float x2) {
+    FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) {
         float tmp = x1 - x2;
         accu += tmp * tmp;
     }
 
-    float result() {
+    FAISS_ALWAYS_INLINE float result() {
         return accu;
     }
 };
@@ -660,29 +698,31 @@ struct SimilarityL2<8> {
     explicit SimilarityL2(const float* y) : y(y) {}
     __m256 accu8;
 
-    void begin_8() {
+    FAISS_ALWAYS_INLINE void begin_8() {
         accu8 = _mm256_setzero_ps();
         yi = y;
     }
 
-    void add_8_components(__m256 x) {
+    FAISS_ALWAYS_INLINE void add_8_components(__m256 x) {
         __m256 yiv = _mm256_loadu_ps(yi);
         yi += 8;
         __m256 tmp = _mm256_sub_ps(yiv, x);
-        accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
+        accu8 = _mm256_fmadd_ps(tmp, tmp, accu8);
     }
 
-    void add_8_components_2(__m256 x, __m256 y) {
+    FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x, __m256 y) {
         __m256 tmp = _mm256_sub_ps(y, x);
-        accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
+        accu8 = _mm256_fmadd_ps(tmp, tmp, accu8);
     }
 
-    float result_8() {
-        __m256 sum = _mm256_hadd_ps(accu8, accu8);
-        __m256 sum2 = _mm256_hadd_ps(sum, sum);
-        // now add the 0th and 4th component
-        return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
-                _mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
+    FAISS_ALWAYS_INLINE float result_8() {
+        const __m128 sum = _mm_add_ps(
+                _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1));
+        const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2));
+        const __m128 v1 = _mm_add_ps(sum, v0);
+        __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
+        const __m128 v3 = _mm_add_ps(v1, v2);
+        return _mm_cvtss_f32(v3);
     }
 };
 
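The new result_8 replaces the two _mm256_hadd_ps steps with an extract/add that folds the 256-bit accumulator down to 128 bits, then reduces four lanes to one via shuffles; hadd is comparatively slow on many x86 microarchitectures. A standalone model of that reduction (my sketch, assuming -mavx2):

    #include <immintrin.h>
    #include <cstdio>

    static float hsum256(__m256 v) {
        // Fold the upper 128-bit lane onto the lower one: 8 lanes -> 4.
        __m128 sum = _mm_add_ps(
                _mm256_castps256_ps128(v), _mm256_extractf128_ps(v, 1));
        // 4 lanes -> 2: add lanes {2,3} onto {0,1}.
        sum = _mm_add_ps(sum, _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2)));
        // 2 lanes -> 1: add lane 1 onto lane 0.
        sum = _mm_add_ps(sum, _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 0, 1)));
        return _mm_cvtss_f32(sum);
    }

    int main() {
        __m256 v = _mm256_set_ps(8, 7, 6, 5, 4, 3, 2, 1);
        std::printf("%f\n", hsum256(v)); // 36
    }
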
@@ -701,20 +741,20 @@ struct SimilarityIP<1> {
 
     explicit SimilarityIP(const float* y) : y(y) {}
 
-    void begin() {
+    FAISS_ALWAYS_INLINE void begin() {
         accu = 0;
         yi = y;
     }
 
-    void add_component(float x) {
+    FAISS_ALWAYS_INLINE void add_component(float x) {
         accu += *yi++ * x;
     }
 
-    void add_component_2(float x1, float x2) {
+    FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) {
         accu += x1 * x2;
     }
 
-    float result() {
+    FAISS_ALWAYS_INLINE float result() {
         return accu;
     }
 };
@@ -734,27 +774,29 @@ struct SimilarityIP<8> {
 
     __m256 accu8;
 
-    void begin_8() {
+    FAISS_ALWAYS_INLINE void begin_8() {
         accu8 = _mm256_setzero_ps();
         yi = y;
     }
 
-    void add_8_components(__m256 x) {
+    FAISS_ALWAYS_INLINE void add_8_components(__m256 x) {
         __m256 yiv = _mm256_loadu_ps(yi);
         yi += 8;
-        accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(yiv, x));
+        accu8 = _mm256_fmadd_ps(yiv, x, accu8);
     }
 
-    void add_8_components_2(__m256 x1, __m256 x2) {
-        accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(x1, x2));
+    FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x1, __m256 x2) {
+        accu8 = _mm256_fmadd_ps(x1, x2, accu8);
     }
 
-    float result_8() {
-        __m256 sum = _mm256_hadd_ps(accu8, accu8);
-        __m256 sum2 = _mm256_hadd_ps(sum, sum);
-        // now add the 0th and 4th component
-        return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
-                _mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
+    FAISS_ALWAYS_INLINE float result_8() {
+        const __m128 sum = _mm_add_ps(
+                _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1));
+        const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2));
+        const __m128 v1 = _mm_add_ps(sum, v0);
+        __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
+        const __m128 v3 = _mm_add_ps(v1, v2);
+        return _mm_cvtss_f32(v3);
     }
 };
 #endif