@@ -977,48 +977,115 @@ void IndexIVF::search_and_reconstruct(
977977 std::min (nlist, params ? params->nprobe : this ->nprobe );
978978 FAISS_THROW_IF_NOT (nprobe > 0 );
979979
980- idx_t * idx = new idx_t [n * nprobe];
981- ScopeDeleter<idx_t > del (idx);
982- float * coarse_dis = new float [n * nprobe];
983- ScopeDeleter<float > del2 (coarse_dis);
980+ std::unique_ptr<idx_t []> idx (new idx_t [n * nprobe]);
981+ std::unique_ptr<float []> coarse_dis (new float [n * nprobe]);
984982
985- quantizer->search (n, x, nprobe, coarse_dis, idx);
983+ quantizer->search (n, x, nprobe, coarse_dis. get () , idx. get () );
986984
987- invlists->prefetch_lists (idx, n * nprobe);
985+ invlists->prefetch_lists (idx. get () , n * nprobe);
988986
989987 // search_preassigned() with `store_pairs` enabled to obtain the list_no
990988 // and offset into `codes` for reconstruction
991989 search_preassigned (
992990 n,
993991 x,
994992 k,
995- idx,
996- coarse_dis,
993+ idx. get () ,
994+ coarse_dis. get () ,
997995 distances,
998996 labels,
999997 true /* store_pairs */ ,
1000998 params);
1001- for (idx_t i = 0 ; i < n; ++i) {
1002- for (idx_t j = 0 ; j < k; ++j) {
1003- idx_t ij = i * k + j;
1004- idx_t key = labels[ij];
1005- float * reconstructed = recons + ij * d;
1006- if (key < 0 ) {
1007- // Fill with NaNs
1008- memset (reconstructed, -1 , sizeof (*reconstructed) * d);
1009- } else {
1010- int list_no = lo_listno (key);
1011- int offset = lo_offset (key);
999+ #pragma omp parallel for if(n * k > 1000)
1000+ for (idx_t ij = 0 ; ij < n * k; ij++) {
1001+ idx_t key = labels[ij];
1002+ float * reconstructed = recons + ij * d;
1003+ if (key < 0 ) {
1004+ // Fill with NaNs
1005+ memset (reconstructed, -1 , sizeof (*reconstructed) * d);
1006+ } else {
1007+ int list_no = lo_listno (key);
1008+ int offset = lo_offset (key);
1009+
1010+ // Update label to the actual id
1011+ labels[ij] = invlists->get_single_id (list_no, offset);
1012+
1013+ reconstruct_from_offset (list_no, offset, reconstructed);
1014+ }
1015+ }
1016+ }
1017+
1018+
1019+ void IndexIVF::search_and_return_codes (
1020+ idx_t n,
1021+ const float * x,
1022+ idx_t k,
1023+ float * distances,
1024+ idx_t * labels,
1025+ uint8_t * codes,
1026+ bool include_listno,
1027+ const SearchParameters* params_in) const
1028+ {
1029+ const IVFSearchParameters* params = nullptr ;
1030+ if (params_in) {
1031+ params = dynamic_cast <const IVFSearchParameters*>(params_in);
1032+ FAISS_THROW_IF_NOT_MSG (params, " IndexIVF params have incorrect type" );
1033+ }
1034+ const size_t nprobe =
1035+ std::min (nlist, params ? params->nprobe : this ->nprobe );
1036+ FAISS_THROW_IF_NOT (nprobe > 0 );
1037+
1038+ std::unique_ptr<idx_t []> idx (new idx_t [n * nprobe]);
1039+ std::unique_ptr<float []> coarse_dis (new float [n * nprobe]);
10121040
1013- // Update label to the actual id
1014- labels[ij] = invlists->get_single_id (list_no, offset);
1041+ quantizer->search (n, x, nprobe, coarse_dis.get (), idx.get ());
10151042
1016- reconstruct_from_offset (list_no, offset, reconstructed);
1043+ invlists->prefetch_lists (idx.get (), n * nprobe);
1044+
1045+ // search_preassigned() with `store_pairs` enabled to obtain the list_no
1046+ // and offset into `codes` for reconstruction
1047+ search_preassigned (
1048+ n,
1049+ x,
1050+ k,
1051+ idx.get (),
1052+ coarse_dis.get (),
1053+ distances,
1054+ labels,
1055+ true /* store_pairs */ ,
1056+ params);
1057+
1058+ size_t code_size_1 = code_size;
1059+ if (include_listno) {
1060+ code_size_1 += coarse_code_size ();
1061+ }
1062+
1063+ #pragma omp parallel for if(n * k > 1000)
1064+ for (idx_t ij = 0 ; ij < n * k; ij++) {
1065+ idx_t key = labels[ij];
1066+ uint8_t * code1 = codes + ij * code_size_1;
1067+
1068+ if (key < 0 ) {
1069+ // Fill with 0xff
1070+ memset (code1, -1 , code_size_1);
1071+ } else {
1072+ int list_no = lo_listno (key);
1073+ int offset = lo_offset (key);
1074+ const uint8_t *cc = invlists->get_single_code (list_no, offset);
1075+
1076+ labels[ij] = invlists->get_single_id (list_no, offset);
1077+
1078+ if (include_listno) {
1079+ encode_listno (list_no, code1);
1080+ code1 += code_size_1 - code_size;
10171081 }
1082+ memcpy (code1, cc, code_size);
10181083 }
10191084 }
10201085}
10211086
1087+
1088+
10221089void IndexIVF::reconstruct_from_offset (
10231090 int64_t /* list_no*/ ,
10241091 int64_t /* offset*/ ,
0 commit comments