@@ -96,12 +96,26 @@ void IndexRefine::search(
9696 idx_t k,
9797 float * distances,
9898 idx_t * labels,
99- const SearchParameters* params) const {
100- FAISS_THROW_IF_NOT_MSG (
101- !params, " search params not supported for this index" );
99+ const SearchParameters* params_in) const {
100+ const IndexRefineSearchParameters* params = nullptr ;
101+ if (params_in) {
102+ params = dynamic_cast <const IndexRefineSearchParameters*>(params_in);
103+ FAISS_THROW_IF_NOT_MSG (
104+ params, " IndexRefine params have incorrect type" );
105+ }
106+
107+ idx_t k_base = (params != nullptr ) ? idx_t (k * params->k_factor )
108+ : idx_t (k * k_factor);
109+ SearchParameters* base_index_params =
110+ (params != nullptr ) ? params->base_index_params : nullptr ;
111+
112+ FAISS_THROW_IF_NOT (k_base >= k);
113+
114+ FAISS_THROW_IF_NOT (base_index);
115+ FAISS_THROW_IF_NOT (refine_index);
116+
102117 FAISS_THROW_IF_NOT (k > 0 );
103118 FAISS_THROW_IF_NOT (is_trained);
104- idx_t k_base = idx_t (k * k_factor);
105119 idx_t * base_labels = labels;
106120 float * base_distances = distances;
107121 ScopeDeleter<idx_t > del1;
@@ -114,7 +128,8 @@ void IndexRefine::search(
114128 del2.set (base_distances);
115129 }
116130
117- base_index->search (n, x, k_base, base_distances, base_labels);
131+ base_index->search (
132+ n, x, k_base, base_distances, base_labels, base_index_params);
118133
119134 for (int i = 0 ; i < n * k_base; i++)
120135 assert (base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -225,12 +240,26 @@ void IndexRefineFlat::search(
225240 idx_t k,
226241 float * distances,
227242 idx_t * labels,
228- const SearchParameters* params) const {
229- FAISS_THROW_IF_NOT_MSG (
230- !params, " search params not supported for this index" );
243+ const SearchParameters* params_in) const {
244+ const IndexRefineSearchParameters* params = nullptr ;
245+ if (params_in) {
246+ params = dynamic_cast <const IndexRefineSearchParameters*>(params_in);
247+ FAISS_THROW_IF_NOT_MSG (
248+ params, " IndexRefineFlat params have incorrect type" );
249+ }
250+
251+ idx_t k_base = (params != nullptr ) ? idx_t (k * params->k_factor )
252+ : idx_t (k * k_factor);
253+ SearchParameters* base_index_params =
254+ (params != nullptr ) ? params->base_index_params : nullptr ;
255+
256+ FAISS_THROW_IF_NOT (k_base >= k);
257+
258+ FAISS_THROW_IF_NOT (base_index);
259+ FAISS_THROW_IF_NOT (refine_index);
260+
231261 FAISS_THROW_IF_NOT (k > 0 );
232262 FAISS_THROW_IF_NOT (is_trained);
233- idx_t k_base = idx_t (k * k_factor);
234263 idx_t * base_labels = labels;
235264 float * base_distances = distances;
236265 ScopeDeleter<idx_t > del1;
@@ -243,7 +272,8 @@ void IndexRefineFlat::search(
243272 del2.set (base_distances);
244273 }
245274
246- base_index->search (n, x, k_base, base_distances, base_labels);
275+ base_index->search (
276+ n, x, k_base, base_distances, base_labels, base_index_params);
247277
248278 for (int i = 0 ; i < n * k_base; i++)
249279 assert (base_labels[i] >= -1 && base_labels[i] < ntotal);
0 commit comments