@@ -125,6 +125,10 @@ void ResidualQuantizer::initialize_from(
125125 }
126126}
127127
128+ /* ***************************************************************
129+ * Encoding steps, used both for training and search
130+ */
131+
128132void beam_search_encode_step (
129133 size_t d,
130134 size_t K,
@@ -277,6 +281,10 @@ void beam_search_encode_step(
277281 }
278282}
279283
284+ /* ***************************************************************
285+ * Training
286+ ****************************************************************/
287+
280288void ResidualQuantizer::train (size_t n, const float * x) {
281289 codebooks.resize (d * codebook_offsets.back ());
282290
@@ -568,7 +576,12 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
568576 return mem;
569577}
570578
571- // a namespace full of preallocated buffers
579+ /* ***************************************************************
580+ * Encoding
581+ ****************************************************************/
582+
583+ // a namespace full of preallocated buffers. This speeds up
584+ // computations, instead of re-allocating them at every encoding step
572585namespace {
573586
574587// Preallocated memory chunk for refine_beam_mp() call
@@ -609,8 +622,6 @@ struct ComputeCodesAddCentroidsLUT1MemoryPool {
609622 RefineBeamLUTMemoryPool refine_beam_lut_pool;
610623};
611624
612- } // namespace
613-
614625// forward declaration
615626void refine_beam_mp (
616627 const ResidualQuantizer& rq,
@@ -743,6 +754,8 @@ void compute_codes_add_centroids_mp_lut1(
743754 centroids);
744755}
745756
757+ } // namespace
758+
746759void ResidualQuantizer::compute_codes_add_centroids (
747760 const float * x,
748761 uint8_t * codes_out,
@@ -769,11 +782,6 @@ void ResidualQuantizer::compute_codes_add_centroids(
769782 cent = centroids + i0 * d;
770783 }
771784
772- // compute_codes_add_centroids(
773- // x + i0 * d,
774- // codes_out + i0 * code_size,
775- // i1 - i0,
776- // cent);
777785 if (use_beam_LUT == 0 ) {
778786 compute_codes_add_centroids_mp_lut0 (
779787 *this ,
@@ -794,6 +802,8 @@ void ResidualQuantizer::compute_codes_add_centroids(
794802 }
795803}
796804
805+ namespace {
806+
797807void refine_beam_mp (
798808 const ResidualQuantizer& rq,
799809 size_t n,
@@ -873,15 +883,11 @@ void refine_beam_mp(
873883 codebooks_m,
874884 n,
875885 cur_beam_size,
876- // residuals.data(),
877886 residuals_ptr,
878887 m,
879- // codes.data(),
880888 codes_ptr,
881889 new_beam_size,
882- // new_codes.data(),
883890 new_codes_ptr,
884- // new_residuals.data(),
885891 new_residuals_ptr,
886892 pool.distances .data (),
887893 assign_index.get (),
@@ -896,9 +902,6 @@ void refine_beam_mp(
896902
897903 if (rq.verbose ) {
898904 float sum_distances = 0 ;
899- // for (int j = 0; j < distances.size(); j++) {
900- // sum_distances += distances[j];
901- // }
902905 for (int j = 0 ; j < distances_size; j++) {
903906 sum_distances += pool.distances [j];
904907 }
@@ -914,27 +917,22 @@ void refine_beam_mp(
914917 }
915918
916919 if (out_codes) {
917- // memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
918920 memcpy (out_codes, codes_ptr, codes_size * sizeof (*codes_ptr));
919921 }
920922 if (out_residuals) {
921- // memcpy(out_residuals,
922- // residuals.data(),
923- // residuals.size() * sizeof(residuals[0]));
924923 memcpy (out_residuals,
925924 residuals_ptr,
926925 residuals_size * sizeof (*residuals_ptr));
927926 }
928927 if (out_distances) {
929- // memcpy(out_distances,
930- // distances.data(),
931- // distances.size() * sizeof(distances[0]));
932928 memcpy (out_distances,
933929 pool.distances .data (),
934930 distances_size * sizeof (pool.distances [0 ]));
935931 }
936932}
937933
934+ } // anonymous namespace
935+
938936void ResidualQuantizer::refine_beam (
939937 size_t n,
940938 size_t beam_size,
@@ -1165,7 +1163,7 @@ void accum_and_finalize_tab(
11651163 }
11661164}
11671165
1168- } // namespace
1166+ } // anonymous namespace
11691167
11701168void beam_search_encode_step_tab (
11711169 size_t K,
@@ -1390,6 +1388,8 @@ void beam_search_encode_step_tab(
13901388 }
13911389}
13921390
1391+ namespace {
1392+
13931393//
13941394void refine_beam_LUT_mp (
13951395 const ResidualQuantizer& rq,
@@ -1443,13 +1443,9 @@ void refine_beam_LUT_mp(
14431443 for (int m = 0 ; m < rq.M ; m++) {
14441444 int K = 1 << rq.nbits [m];
14451445
1446- // it is guaranteed that (new_beam_size <= than max_beam_size) ==
1447- // true
1446+ // it is guaranteed that (new_beam_size <= max_beam_size)
14481447 int new_beam_size = std::min (beam_size * K, out_beam_size);
14491448
1450- // std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
1451- // std::vector<float> new_distances(n * new_beam_size);
1452-
14531449 codes_size = n * new_beam_size * (m + 1 );
14541450 distances_size = n * new_beam_size;
14551451
@@ -1464,29 +1460,20 @@ void refine_beam_LUT_mp(
14641460 rq.total_codebook_size ,
14651461 rq.cent_norms .data () + rq.codebook_offsets [m],
14661462 m,
1467- // codes.data(),
14681463 codes_ptr,
1469- // distances.data(),
14701464 distances_ptr,
14711465 new_beam_size,
1472- // new_codes.data(),
14731466 new_codes_ptr,
1474- // new_distances.data()
14751467 new_distances_ptr,
14761468 rq.approx_topk_mode );
14771469
1478- // codes.swap(new_codes);
14791470 std::swap (codes_ptr, new_codes_ptr);
1480- // distances.swap(new_distances);
14811471 std::swap (distances_ptr, new_distances_ptr);
14821472
14831473 beam_size = new_beam_size;
14841474
14851475 if (rq.verbose ) {
14861476 float sum_distances = 0 ;
1487- // for (int j = 0; j < distances.size(); j++) {
1488- // sum_distances += distances[j];
1489- // }
14901477 for (int j = 0 ; j < distances_size; j++) {
14911478 sum_distances += distances_ptr[j];
14921479 }
@@ -1501,19 +1488,17 @@ void refine_beam_LUT_mp(
15011488 }
15021489
15031490 if (out_codes) {
1504- // memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
15051491 memcpy (out_codes, codes_ptr, codes_size * sizeof (*codes_ptr));
15061492 }
15071493 if (out_distances) {
1508- // memcpy(out_distances,
1509- // distances.data(),
1510- // distances.size() * sizeof(distances[0]));
15111494 memcpy (out_distances,
15121495 distances_ptr,
15131496 distances_size * sizeof (*distances_ptr));
15141497 }
15151498}
15161499
1500+ } // namespace
1501+
15171502void ResidualQuantizer::refine_beam_LUT (
15181503 size_t n,
15191504 const float * query_norms, // size n
0 commit comments