Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions c_api/IndexIVF_c_ex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,23 @@ int faiss_IndexIVF_compute_distance_to_codes_for_list(
const float* x,
idx_t n,
const uint8_t* codes,
float* dists) {
float* dists,
float* dist_table) {
try {
reinterpret_cast<IndexIVF*>(index)->compute_distance_to_codes_for_list(
list_no, x, n, codes, dists);
list_no, x, n, codes, dists, dist_table);
return 0;
}
CATCH_AND_HANDLE
}

int faiss_IndexIVF_compute_distance_table(
FaissIndexIVF* index,
const float* x,
float* dist_table) {
try {
reinterpret_cast<IndexIVF*>(index)->compute_distance_table(
x, dist_table);
return 0;
}
CATCH_AND_HANDLE
Expand Down
18 changes: 17 additions & 1 deletion c_api/IndexIVF_c_ex.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ int faiss_IndexIVF_search_preassigned_with_params(
@param n - number of codes
@param codes - input codes
@param dists - output computed distances
@param dist_table - input precomputed distance table for PQ
*/

int faiss_IndexIVF_compute_distance_to_codes_for_list(
Expand All @@ -89,7 +90,8 @@ int faiss_IndexIVF_compute_distance_to_codes_for_list(
const float* x,
idx_t n,
const uint8_t* codes,
float* dists);
float* dists,
float* dist_table);

/*
Given multiple vector IDs, retrieve the corresponding list (cluster) IDs
Expand All @@ -108,6 +110,20 @@ int faiss_get_lists_for_keys(
size_t n_keys,
idx_t* lists);

/*
Given a query vector x, compute the distance table for that query and
copy it into a caller-provided buffer.

The call is forwarded to IndexIVF::compute_distance_table on the wrapped
index. The base IndexIVF implementation does nothing, so dist_table is only
filled for index types that override it (IndexIVFPQ does).

@param index - IVF index handle
@param x - input query vector
@param dist_table - output buffer receiving the distance table for PQ;
                    must be allocated by the caller (IndexIVFPQ copies
                    pq.ksub * pq.M floats into it)

@return 0 on success; failures are reported through the C API's usual
        exception-to-error-code handling

*/

int faiss_IndexIVF_compute_distance_table(
FaissIndexIVF* index,
const float* x,
float* dist_table);

#ifdef __cplusplus
}
#endif
Expand Down
16 changes: 15 additions & 1 deletion faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,14 +459,28 @@ struct IndexIVF : Index, IndexIVFInterface {
* @param n - number of codes
* @param codes - input codes
* @param dists - output computed distances
* @param dist_table - input precomputed distance table for PQ
*/

virtual void compute_distance_to_codes_for_list(
const idx_t list_no,
const float* x,
idx_t n,
const uint8_t* codes,
float* dists) const {};
float* dists,
float* dist_table) const {};

/** Given a query vector x, compute the distance table for that query
* and return it to the caller through dist_table.
*
* The default implementation is intentionally empty: it neither reads x
* nor writes dist_table. Index types that work with a precomputed table
* (e.g. IndexIVFPQ) override it.
*
* @param x - input query vector
* @param dist_table - output precomputed distance table for PQ
*                     (caller-allocated; untouched by this default)
*
*/

virtual void compute_distance_table(
const float* x,
float* dist_table) const {};


IndexIVF();
Expand Down
101 changes: 101 additions & 0 deletions faiss/IndexIVFPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <cstdio>

#include <algorithm>
#include <numeric>

#include <faiss/utils/Heap.h>
#include <faiss/utils/distances.h>
Expand Down Expand Up @@ -565,6 +566,7 @@ struct QueryTables {
}
}


/*****************************************************
* When inverted list is known: prepare computations
*****************************************************/
Expand Down Expand Up @@ -748,6 +750,33 @@ struct QueryTables {

return dis0;
}


// Installs the query vector and a caller-supplied distance table instead
// of computing the table from the query.
//
// @param qi     query vector to remember for later computations
// @param table  precomputed table of pq.ksub * pq.M floats
void init_sim_table(const float* qi, const float* table) {
    this->qi = qi;

    // Only the non-inner-product, by-residual configuration stores the
    // table in sim_table_2; every other configuration uses sim_table.
    const bool into_secondary =
            metric_type != METRIC_INNER_PRODUCT && by_residual;
    auto* destination = into_secondary ? sim_table_2 : sim_table;

    const size_t table_bytes = pq.ksub * pq.M * sizeof(float);
    memcpy(destination, table, table_bytes);
}

// Copies the current distance table out to a caller-provided buffer.
//
// @param table  output buffer; receives pq.ksub * pq.M floats
void copy_sim_table(float* table) const {
    // Only the non-inner-product, by-residual configuration exports
    // sim_table_2; every other configuration exports sim_table.
    const bool from_secondary =
            metric_type != METRIC_INNER_PRODUCT && by_residual;
    const auto* origin = from_secondary ? sim_table_2 : sim_table;

    const size_t table_bytes = pq.ksub * pq.M * sizeof(float);
    memcpy(table, origin, table_bytes);
}
};

// This way of handling the selector is not optimal since all distances
Expand Down Expand Up @@ -1387,4 +1416,76 @@ size_t IndexIVFPQ::find_duplicates(idx_t* dup_ids, size_t* lims) const {
return ngroup;
}

void IndexIVFPQ::compute_distance_to_codes_for_list(
const idx_t list_no,
const float* x,
idx_t n,
const uint8_t* codes,
float* dists,
float* dist_table) const {

std::unique_ptr<InvertedListScanner> scanner(
get_InvertedListScanner(true, nullptr));


if (dist_table) {
if (auto* pqscanner = dynamic_cast<QueryTables*>(scanner.get())) {
pqscanner->init_sim_table(x, dist_table);
}
} else {
scanner->set_query(x);
}

// Initialize distances with default values
std::vector<float> dist_out(n, metric_type == METRIC_L2 ? HUGE_VAL : -HUGE_VAL);


//find the centroid corresponding to the input list_no
//and compute its distance from the query vector
std::vector<float> centroid(d);
quantizer->reconstruct(list_no, centroid.data());

float coarse_dis = quantizer->metric_type == faiss::METRIC_L2
? faiss::fvec_L2sqr(x, centroid.data(), d)
: faiss::fvec_inner_product(x, centroid.data(), d);


scanner->set_list(list_no, coarse_dis);

// Initialize ids_in as sequential numbers to allow mapping with output distances.
std::vector<idx_t> ids_in(n);
std::iota(ids_in.begin(), ids_in.end(), 0);

//ids_out contain the order of distances in dist_out after scan_codes returns.
std::vector<idx_t> ids_out(n, 0);

scanner->scan_codes(n, codes, ids_in.data(), dist_out.data(), ids_out.data(), n);

// Reorder the returned distances in dist_out based on ids_out.
// This function needs to return the distances in the same order as input codes.
for (int j = 0; j < n; j++) {
int k = ids_out[j];
dists[k] = dist_out[j];
}

return;
}

//This function computes the distance table for the input vector x and returns it in dtable.
void IndexIVFPQ::compute_distance_table(
const float* x,
float* dist_table) const {

std::unique_ptr<InvertedListScanner> scanner(
get_InvertedListScanner(true, nullptr));

scanner->set_query(x);

if (auto* pqscanner = dynamic_cast<QueryTables*>(scanner.get())) {
pqscanner->copy_sim_table(dist_table);
}

return;
}

} // namespace faiss
12 changes: 12 additions & 0 deletions faiss/IndexIVFPQ.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,18 @@ struct IndexIVFPQ : IndexIVF {
/// build precomputed table
void precompute_table();

/// Compute the distances from query x to n codes scanned as members of
/// inverted list list_no; results are written to dists in the order of
/// the input codes. A non-null dist_table is used as the precomputed
/// distance table for x; with a null dist_table the tables are derived
/// from x.
void compute_distance_to_codes_for_list(
const idx_t list_no,
const float* x,
idx_t n,
const uint8_t* codes,
float* dists,
float* dist_table) const override;

/// Compute the distance table for query x and copy it into the
/// caller-allocated buffer dist_table (pq.ksub * pq.M floats).
void compute_distance_table(
const float* x,
float* dist_table) const override;

IndexIVFPQ();
};

Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexScalarQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ void IndexIVFScalarQuantizer::compute_distance_to_codes_for_list(
const float* x,
idx_t n,
const uint8_t* codes,
float* dists) const {
float* dists,
float* dist_table) const {

std::unique_ptr<ScalarQuantizer::SQDistanceComputer> dc(
sq.get_distance_computer(metric_type));
Expand Down
3 changes: 2 additions & 1 deletion faiss/IndexScalarQuantizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ struct IndexIVFScalarQuantizer : IndexIVF {
const float* x,
idx_t n,
const uint8_t* codes,
float* dists) const override;
float* dists,
float* dist_table) const override;

};

Expand Down