Skip to content

Commit c093314

Browse files
IndexRefine params
Signed-off-by: Alexandr Guzhva <[email protected]>
1 parent 6c89c8b commit c093314

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

faiss/IndexRefine.cpp

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,24 @@ void IndexRefine::search(
9696
idx_t k,
9797
float* distances,
9898
idx_t* labels,
99-
const SearchParameters* params) const {
100-
FAISS_THROW_IF_NOT_MSG(
101-
!params, "search params not supported for this index");
99+
const SearchParameters* params_in) const {
100+
const IndexRefineSearchParameters* params = nullptr;
101+
if (params_in) {
102+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
103+
FAISS_THROW_IF_NOT_MSG(params, "IndexRefine params have incorrect type");
104+
}
105+
106+
idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor) : idx_t(k * k_factor);
107+
SearchParameters* base_index_params =
108+
(params != nullptr) ? params->base_index_params : nullptr;
109+
110+
FAISS_THROW_IF_NOT(k_base >= k);
111+
112+
FAISS_THROW_IF_NOT(base_index);
113+
FAISS_THROW_IF_NOT(refine_index);
114+
102115
FAISS_THROW_IF_NOT(k > 0);
103116
FAISS_THROW_IF_NOT(is_trained);
104-
idx_t k_base = idx_t(k * k_factor);
105117
idx_t* base_labels = labels;
106118
float* base_distances = distances;
107119
ScopeDeleter<idx_t> del1;
@@ -114,7 +126,7 @@ void IndexRefine::search(
114126
del2.set(base_distances);
115127
}
116128

117-
base_index->search(n, x, k_base, base_distances, base_labels);
129+
base_index->search(n, x, k_base, base_distances, base_labels, base_index_params);
118130

119131
for (int i = 0; i < n * k_base; i++)
120132
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
@@ -225,12 +237,24 @@ void IndexRefineFlat::search(
225237
idx_t k,
226238
float* distances,
227239
idx_t* labels,
228-
const SearchParameters* params) const {
229-
FAISS_THROW_IF_NOT_MSG(
230-
!params, "search params not supported for this index");
240+
const SearchParameters* params_in) const {
241+
const IndexRefineSearchParameters* params = nullptr;
242+
if (params_in) {
243+
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
244+
FAISS_THROW_IF_NOT_MSG(params, "IndexRefineFlat params have incorrect type");
245+
}
246+
247+
idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor) : idx_t(k * k_factor);
248+
SearchParameters* base_index_params =
249+
(params != nullptr) ? params->base_index_params : nullptr;
250+
251+
FAISS_THROW_IF_NOT(k_base >= k);
252+
253+
FAISS_THROW_IF_NOT(base_index);
254+
FAISS_THROW_IF_NOT(refine_index);
255+
231256
FAISS_THROW_IF_NOT(k > 0);
232257
FAISS_THROW_IF_NOT(is_trained);
233-
idx_t k_base = idx_t(k * k_factor);
234258
idx_t* base_labels = labels;
235259
float* base_distances = distances;
236260
ScopeDeleter<idx_t> del1;
@@ -243,7 +267,7 @@ void IndexRefineFlat::search(
243267
del2.set(base_distances);
244268
}
245269

246-
base_index->search(n, x, k_base, base_distances, base_labels);
270+
base_index->search(n, x, k_base, base_distances, base_labels, base_index_params);
247271

248272
for (int i = 0; i < n * k_base; i++)
249273
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);

faiss/IndexRefine.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111

1212
namespace faiss {
1313

14+
struct IndexRefineSearchParameters : SearchParameters {
15+
float k_factor = 1;
16+
SearchParameters* base_index_params = nullptr; // non-owning
17+
18+
virtual ~IndexRefineSearchParameters() = default;
19+
};
20+
1421
/** Index that queries in a base_index (a fast one) and refines the
1522
* results with an exact search, hopefully improving the results.
1623
*/

0 commit comments

Comments
 (0)