@@ -384,11 +384,11 @@ void beam_search_encode_step_tab(
384384 size_t n,
385385 size_t beam_size, // input sizes
386386 const float * codebook_cross_norms, // size K * ldc
387- size_t ldc, // >= K
388- const uint64_t * codebook_offsets, // m
389- const float * query_cp, // size n * ldqc
390- size_t ldqc, // >= K
391- const float * cent_norms_i, // size K
387+ size_t ldc,
388+ const uint64_t * codebook_offsets, // m
389+ const float * query_cp, // size n * ldqc
390+ size_t ldqc, // >= K
391+ const float * cent_norms_i, // size K
392392 size_t m,
393393 const int32_t * codes, // n * beam_size * m
394394 const float * distances, // n * beam_size
@@ -412,35 +412,38 @@ void beam_search_encode_step_tab(
412412 cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
413413 }
414414
415- /*
415+ bool use_baseline_implementation = false ;
416+
416417 // This is the baseline implementation. Its primary flaw
417418 // that it writes way too many info to the temporary buffer
418419 // called dp.
419420 //
420421 // This baseline code is kept intentionally because it is easy to
421422 // understand what an optimized version optimizes exactly.
422423 //
423- for (size_t b = 0; b < beam_size; b++) {
424- std::vector<float> dp(K);
425-
426- for (size_t m1 = 0; m1 < m; m1++) {
427- size_t c = codes_i[b * m + m1];
428- const float* cb =
429- &codebook_cross_norms[(codebook_offsets[m1] + c) * ldc];
430- fvec_add(K, cb, dp.data(), dp.data());
431- }
424+ if (use_baseline_implementation) {
425+ for (size_t b = 0 ; b < beam_size; b++) {
426+ std::vector<float > dp (K);
432427
433- for (size_t k = 0; k < K; k++) {
434- cent_distances[b * K + k] =
435- distances_i[b] + cd_common[k] + 2 * dp[k];
428+ for (size_t m1 = 0 ; m1 < m; m1++) {
429+ size_t c = codes_i[b * m + m1];
430+ const float * cb =
431+ &codebook_cross_norms
432+ [(codebook_offsets[m1] + c) * ldc];
433+ fvec_add (K, cb, dp.data (), dp.data ());
434+ }
435+
436+ for (size_t k = 0 ; k < K; k++) {
437+ cent_distances[b * K + k] =
438+ distances_i[b] + cd_common[k] + 2 * dp[k];
439+ }
436440 }
437- }
438- */
439441
440- // An optimized implementation that avoids using a temporary buffer
441- // and does the accumulation in registers.
442+ } else {
443+ // An optimized implementation that avoids using a temporary buffer
444+ // and does the accumulation in registers.
442445
443- // Compute a sum of NK AQ codes.
446+ // Compute a sum of NK AQ codes.
444447#define ACCUM_AND_FINALIZE_TAB (NK ) \
445448 case NK: \
446449 for (size_t b = 0 ; b < beam_size; b++) { \
@@ -457,51 +460,52 @@ void beam_search_encode_step_tab(
457460 } \
458461 break ;
459462
460- // this version contains many switch-case scenarios, but
461- // they won't affect branch predictor.
462- switch (m) {
463- case 0 :
464- // trivial case
465- for (size_t b = 0 ; b < beam_size; b++) {
466- for (size_t k = 0 ; k < K; k++) {
467- cent_distances[b * K + k] =
468- distances_i[b] + cd_common[k];
463+ // this version contains many switch-case scenarios, but
464+ // they won't affect branch predictor.
465+ switch (m) {
466+ case 0 :
467+ // trivial case
468+ for (size_t b = 0 ; b < beam_size; b++) {
469+ for (size_t k = 0 ; k < K; k++) {
470+ cent_distances[b * K + k] =
471+ distances_i[b] + cd_common[k];
472+ }
469473 }
470- }
471- break ;
472-
473- ACCUM_AND_FINALIZE_TAB (1 )
474- ACCUM_AND_FINALIZE_TAB (2 )
475- ACCUM_AND_FINALIZE_TAB (3 )
476- ACCUM_AND_FINALIZE_TAB (4 )
477- ACCUM_AND_FINALIZE_TAB (5 )
478- ACCUM_AND_FINALIZE_TAB (6 )
479- ACCUM_AND_FINALIZE_TAB ( 7 )
480-
481- default : {
482- // m >= 8 case.
483-
484- // A temporary buffer has to be used due to the lack of
485- // registers. But we'll try to accumulate up to 8 AQ codes in
486- // registers and issue a single write operation to the buffer,
487- // while the baseline does no accumulation. So, the number of
488- // write operations to the temporary buffer is reduced 8x.
489-
490- // allocate a temporary buffer
491- std::vector<float > dp (K);
492-
493- for (size_t b = 0 ; b < beam_size; b++) {
494- // Initialize it. Compute a sum of first 8 AQ codes
495- // because m >= 8 .
496- accum_and_store_tab<8 , 4 >(
497- m,
498- codebook_cross_norms,
499- codebook_offsets,
500- codes_i,
501- b,
502- ldc,
503- K,
504- dp.data ());
474+ break ;
475+
476+ ACCUM_AND_FINALIZE_TAB ( 1 )
477+ ACCUM_AND_FINALIZE_TAB (2 )
478+ ACCUM_AND_FINALIZE_TAB (3 )
479+ ACCUM_AND_FINALIZE_TAB (4 )
480+ ACCUM_AND_FINALIZE_TAB (5 )
481+ ACCUM_AND_FINALIZE_TAB (6 )
482+ ACCUM_AND_FINALIZE_TAB (7 )
483+
484+ default : {
485+ // m >= 8 case.
486+
487+ // A temporary buffer has to be used due to the lack of
488+ // registers. But we'll try to accumulate up to 8 AQ codes
489+ // in registers and issue a single write operation to the
490+ // buffer, while the baseline does no accumulation. So, the
491+ // number of write operations to the temporary buffer is
492+ // reduced 8x.
493+
494+ // allocate a temporary buffer
495+ std::vector<float > dp (K);
496+
497+ for (size_t b = 0 ; b < beam_size; b++) {
498+ // Initialize it. Compute a sum of first 8 AQ codes
499+ // because m >= 8 .
500+ accum_and_store_tab<8 , 4 >(
501+ m,
502+ codebook_cross_norms,
503+ codebook_offsets,
504+ codes_i,
505+ b,
506+ ldc,
507+ K,
508+ dp.data ());
505509
506510#define ACCUM_AND_ADD_TAB (NK ) \
507511 case NK: \
@@ -516,37 +520,37 @@ void beam_search_encode_step_tab(
516520 dp.data ()); \
517521 break ;
518522
519- // accumulate up to 8 additional AQ codes into
520- // a temporary buffer
521- for (size_t im = 8 ; im < ((m + 7 ) / 8 ) * 8 ; im += 8 ) {
522- size_t m_left = m - im;
523- if (m_left > 8 ) {
524- m_left = 8 ;
523+ // accumulate up to 8 additional AQ codes into
524+ // a temporary buffer
525+ for (size_t im = 8 ; im < ((m + 7 ) / 8 ) * 8 ; im += 8 ) {
526+ size_t m_left = m - im;
527+ if (m_left > 8 ) {
528+ m_left = 8 ;
529+ }
530+
531+ switch (m_left) {
532+ ACCUM_AND_ADD_TAB (1 )
533+ ACCUM_AND_ADD_TAB (2 )
534+ ACCUM_AND_ADD_TAB (3 )
535+ ACCUM_AND_ADD_TAB (4 )
536+ ACCUM_AND_ADD_TAB (5 )
537+ ACCUM_AND_ADD_TAB (6 )
538+ ACCUM_AND_ADD_TAB (7 )
539+ ACCUM_AND_ADD_TAB (8 )
540+ }
525541 }
526542
527- switch (m_left) {
528- ACCUM_AND_ADD_TAB (1 )
529- ACCUM_AND_ADD_TAB (2 )
530- ACCUM_AND_ADD_TAB (3 )
531- ACCUM_AND_ADD_TAB (4 )
532- ACCUM_AND_ADD_TAB (5 )
533- ACCUM_AND_ADD_TAB (6 )
534- ACCUM_AND_ADD_TAB (7 )
535- ACCUM_AND_ADD_TAB (8 )
543+ // done. finalize the result
544+ for (size_t k = 0 ; k < K; k++) {
545+ cent_distances[b * K + k] =
546+ distances_i[b] + cd_common[k] + 2 * dp[k];
536547 }
537548 }
538-
539- // done. finalize the result
540- for (size_t k = 0 ; k < K; k++) {
541- cent_distances[b * K + k] =
542- distances_i[b] + cd_common[k] + 2 * dp[k];
543- }
544549 }
545550 }
546- }
547-
548- // the optimized implementation ends here
549551
552+ // the optimized implementation ends here
553+ }
550554 using C = CMax<float , int >;
551555 int32_t * new_codes_i = new_codes + i * (m + 1 ) * new_beam_size;
552556 float * new_distances_i = new_distances + i * new_beam_size;
@@ -784,6 +788,7 @@ void refine_beam_LUT_mp(
784788 // main loop
785789 size_t codes_size = 0 ;
786790 size_t distances_size = 0 ;
791+ size_t cross_ofs = 0 ;
787792 for (int m = 0 ; m < rq.M ; m++) {
788793 int K = 1 << rq.nbits [m];
789794
@@ -792,13 +797,15 @@ void refine_beam_LUT_mp(
792797
793798 codes_size = n * new_beam_size * (m + 1 );
794799 distances_size = n * new_beam_size;
795-
800+ FAISS_THROW_IF_NOT (
801+ cross_ofs + rq.codebook_offsets [m] * K <=
802+ rq.codebook_cross_products .size ());
796803 beam_search_encode_step_tab (
797804 K,
798805 n,
799806 beam_size,
800- rq.codebook_cross_products .data () + rq. codebook_offsets [m] ,
801- rq. total_codebook_size ,
807+ rq.codebook_cross_products .data () + cross_ofs ,
808+ K ,
802809 rq.codebook_offsets .data (),
803810 query_cp + rq.codebook_offsets [m],
804811 rq.total_codebook_size ,
@@ -810,7 +817,7 @@ void refine_beam_LUT_mp(
810817 new_codes_ptr,
811818 new_distances_ptr,
812819 rq.approx_topk_mode );
813-
820+ cross_ofs += rq. codebook_offsets [m] * K;
814821 std::swap (codes_ptr, new_codes_ptr);
815822 std::swap (distances_ptr, new_distances_ptr);
816823
@@ -830,7 +837,6 @@ void refine_beam_LUT_mp(
830837 beam_size);
831838 }
832839 }
833-
834840 if (out_codes) {
835841 memcpy (out_codes, codes_ptr, codes_size * sizeof (*codes_ptr));
836842 }
@@ -903,8 +909,7 @@ void compute_codes_add_centroids_mp_lut1(
903909 pool.distances .resize (rq.max_beam_size * n);
904910
905911 FAISS_THROW_IF_NOT_MSG (
906- rq.codebook_cross_products .size () ==
907- rq.total_codebook_size * rq.total_codebook_size ,
912+ rq.M == 1 || rq.codebook_cross_products .size () > 0 ,
908913 " call compute_codebook_tables first" );
909914
910915 pool.query_norms .resize (n);
0 commit comments