@@ -96,12 +96,24 @@ void IndexRefine::search(
9696 idx_t k,
9797 float * distances,
9898 idx_t * labels,
99- const SearchParameters* params) const {
100- FAISS_THROW_IF_NOT_MSG (
101- !params, " search params not supported for this index" );
99+ const SearchParameters* params_in) const {
100+ const IndexRefineSearchParameters* params = nullptr ;
101+ if (params_in) {
102+ params = dynamic_cast <const IndexRefineSearchParameters*>(params_in);
103+ FAISS_THROW_IF_NOT_MSG (params, " IndexRefine params have incorrect type" );
104+ }
105+
106+ idx_t k_base = (params != nullptr ) ? idx_t (k * params->k_factor ) : idx_t (k * k_factor);
107+ SearchParameters* base_index_params =
108+ (params != nullptr ) ? params->base_index_params : nullptr ;
109+
110+ FAISS_THROW_IF_NOT (k_base >= k);
111+
112+ FAISS_THROW_IF_NOT (base_index);
113+ FAISS_THROW_IF_NOT (refine_index);
114+
102115 FAISS_THROW_IF_NOT (k > 0 );
103116 FAISS_THROW_IF_NOT (is_trained);
104- idx_t k_base = idx_t (k * k_factor);
105117 idx_t * base_labels = labels;
106118 float * base_distances = distances;
107119 ScopeDeleter<idx_t > del1;
@@ -114,7 +126,7 @@ void IndexRefine::search(
114126 del2.set (base_distances);
115127 }
116128
117- base_index->search (n, x, k_base, base_distances, base_labels);
129+ base_index->search (n, x, k_base, base_distances, base_labels, base_index_params );
118130
119131 for (int i = 0 ; i < n * k_base; i++)
120132 assert (base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -225,12 +237,24 @@ void IndexRefineFlat::search(
225237 idx_t k,
226238 float * distances,
227239 idx_t * labels,
228- const SearchParameters* params) const {
229- FAISS_THROW_IF_NOT_MSG (
230- !params, " search params not supported for this index" );
240+ const SearchParameters* params_in) const {
241+ const IndexRefineSearchParameters* params = nullptr ;
242+ if (params_in) {
243+ params = dynamic_cast <const IndexRefineSearchParameters*>(params_in);
244+ FAISS_THROW_IF_NOT_MSG (params, " IndexRefineFlat params have incorrect type" );
245+ }
246+
247+ idx_t k_base = (params != nullptr ) ? idx_t (k * params->k_factor ) : idx_t (k * k_factor);
248+ SearchParameters* base_index_params =
249+ (params != nullptr ) ? params->base_index_params : nullptr ;
250+
251+ FAISS_THROW_IF_NOT (k_base >= k);
252+
253+ FAISS_THROW_IF_NOT (base_index);
254+ FAISS_THROW_IF_NOT (refine_index);
255+
231256 FAISS_THROW_IF_NOT (k > 0 );
232257 FAISS_THROW_IF_NOT (is_trained);
233- idx_t k_base = idx_t (k * k_factor);
234258 idx_t * base_labels = labels;
235259 float * base_distances = distances;
236260 ScopeDeleter<idx_t > del1;
@@ -243,7 +267,7 @@ void IndexRefineFlat::search(
243267 del2.set (base_distances);
244268 }
245269
246- base_index->search (n, x, k_base, base_distances, base_labels);
270+ base_index->search (n, x, k_base, base_distances, base_labels, base_index_params );
247271
248272 for (int i = 0 ; i < n * k_base; i++)
249273 assert (base_labels[i] >= -1 && base_labels[i] < ntotal);
0 commit comments