Skip to content

Commit ce18dea

Browse files
Alexandr Guzhva authored and facebook-github-bot committed
HNSW speedup + Distance 4 points (#2841)
Summary: Pull Request resolved: #2841
* Add virtual void DistanceComputer::distances_batch_4() (called distances_to_four_indices() in the original description)
* Add the infrastructure
* HNSW::search() uses DistanceComputer::distances_batch_4()
* Add IndexFlatL2::sync_l2norms() and IndexFlatL2::clear_l2norms(), which allow precomputing a cache of L2 norms for the stored vectors so that L2 distances can be computed using dot products
* Add downcasting of IndexFlatL2 and IndexFlatIP in swig
* Add general-purpose prefetch utilities
Reviewed By: mdouze
Differential Revision: D45427064
fbshipit-source-id: b8e731a95abe9d6bd026882f8ea3c9862a25bedf
1 parent f276c47 commit ce18dea

File tree

11 files changed

+566
-31
lines changed

11 files changed

+566
-31
lines changed

faiss/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ set(FAISS_HEADERS
190190
utils/hamming.h
191191
utils/ordered_key_value.h
192192
utils/partitioning.h
193+
utils/prefetch.h
193194
utils/quantize_lut.h
194195
utils/random.h
195196
utils/simdlib.h

faiss/IndexFlat.cpp

Lines changed: 195 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <faiss/utils/Heap.h>
1515
#include <faiss/utils/distances.h>
1616
#include <faiss/utils/extra_distances.h>
17+
#include <faiss/utils/prefetch.h>
1718
#include <faiss/utils/sorting.h>
1819
#include <faiss/utils/utils.h>
1920
#include <cstring>
@@ -122,6 +123,39 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
122123
void set_query(const float* x) override {
123124
q = x;
124125
}
126+
127+
// compute four distances
128+
void distances_batch_4(
129+
const idx_t idx0,
130+
const idx_t idx1,
131+
const idx_t idx2,
132+
const idx_t idx3,
133+
float& dis0,
134+
float& dis1,
135+
float& dis2,
136+
float& dis3) final override {
137+
ndis += 4;
138+
139+
// compute first, assign next
140+
const float* __restrict y0 =
141+
reinterpret_cast<const float*>(codes + idx0 * code_size);
142+
const float* __restrict y1 =
143+
reinterpret_cast<const float*>(codes + idx1 * code_size);
144+
const float* __restrict y2 =
145+
reinterpret_cast<const float*>(codes + idx2 * code_size);
146+
const float* __restrict y3 =
147+
reinterpret_cast<const float*>(codes + idx3 * code_size);
148+
149+
float dp0 = 0;
150+
float dp1 = 0;
151+
float dp2 = 0;
152+
float dp3 = 0;
153+
fvec_L2sqr_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
154+
dis0 = dp0;
155+
dis1 = dp1;
156+
dis2 = dp2;
157+
dis3 = dp3;
158+
}
125159
};
126160

127161
struct FlatIPDis : FlatCodesDistanceComputer {
@@ -131,13 +165,13 @@ struct FlatIPDis : FlatCodesDistanceComputer {
131165
const float* b;
132166
size_t ndis;
133167

134-
float symmetric_dis(idx_t i, idx_t j) override {
168+
float symmetric_dis(idx_t i, idx_t j) final override {
135169
return fvec_inner_product(b + j * d, b + i * d, d);
136170
}
137171

138-
float distance_to_code(const uint8_t* code) final {
172+
float distance_to_code(const uint8_t* code) final override {
139173
ndis++;
140-
return fvec_inner_product(q, (float*)code, d);
174+
return fvec_inner_product(q, (const float*)code, d);
141175
}
142176

143177
explicit FlatIPDis(const IndexFlat& storage, const float* q = nullptr)
@@ -153,6 +187,39 @@ struct FlatIPDis : FlatCodesDistanceComputer {
153187
void set_query(const float* x) override {
154188
q = x;
155189
}
190+
191+
// compute four distances
192+
void distances_batch_4(
193+
const idx_t idx0,
194+
const idx_t idx1,
195+
const idx_t idx2,
196+
const idx_t idx3,
197+
float& dis0,
198+
float& dis1,
199+
float& dis2,
200+
float& dis3) final override {
201+
ndis += 4;
202+
203+
// compute first, assign next
204+
const float* __restrict y0 =
205+
reinterpret_cast<const float*>(codes + idx0 * code_size);
206+
const float* __restrict y1 =
207+
reinterpret_cast<const float*>(codes + idx1 * code_size);
208+
const float* __restrict y2 =
209+
reinterpret_cast<const float*>(codes + idx2 * code_size);
210+
const float* __restrict y3 =
211+
reinterpret_cast<const float*>(codes + idx3 * code_size);
212+
213+
float dp0 = 0;
214+
float dp1 = 0;
215+
float dp2 = 0;
216+
float dp3 = 0;
217+
fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
218+
dis0 = dp0;
219+
dis1 = dp1;
220+
dis2 = dp2;
221+
dis3 = dp3;
222+
}
156223
};
157224

158225
} // namespace
@@ -184,6 +251,131 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
184251
}
185252
}
186253

254+
/***************************************************
255+
* IndexFlatL2
256+
***************************************************/
257+
258+
namespace {
259+
struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
260+
size_t d;
261+
idx_t nb;
262+
const float* q;
263+
const float* b;
264+
size_t ndis;
265+
266+
const float* l2norms;
267+
float query_l2norm;
268+
269+
float distance_to_code(const uint8_t* code) final override {
270+
ndis++;
271+
return fvec_L2sqr(q, (float*)code, d);
272+
}
273+
274+
float operator()(const idx_t i) final override {
275+
const float* __restrict y =
276+
reinterpret_cast<const float*>(codes + i * code_size);
277+
278+
prefetch_L2(l2norms + i);
279+
const float dp0 = fvec_inner_product(q, y, d);
280+
return query_l2norm + l2norms[i] - 2 * dp0;
281+
}
282+
283+
float symmetric_dis(idx_t i, idx_t j) final override {
284+
const float* __restrict yi =
285+
reinterpret_cast<const float*>(codes + i * code_size);
286+
const float* __restrict yj =
287+
reinterpret_cast<const float*>(codes + j * code_size);
288+
289+
prefetch_L2(l2norms + i);
290+
prefetch_L2(l2norms + j);
291+
const float dp0 = fvec_inner_product(yi, yj, d);
292+
return l2norms[i] + l2norms[j] - 2 * dp0;
293+
}
294+
295+
explicit FlatL2WithNormsDis(
296+
const IndexFlatL2& storage,
297+
const float* q = nullptr)
298+
: FlatCodesDistanceComputer(
299+
storage.codes.data(),
300+
storage.code_size),
301+
d(storage.d),
302+
nb(storage.ntotal),
303+
q(q),
304+
b(storage.get_xb()),
305+
ndis(0),
306+
l2norms(storage.cached_l2norms.data()),
307+
query_l2norm(0) {}
308+
309+
void set_query(const float* x) override {
310+
q = x;
311+
query_l2norm = fvec_norm_L2sqr(q, d);
312+
}
313+
314+
// compute four distances
315+
void distances_batch_4(
316+
const idx_t idx0,
317+
const idx_t idx1,
318+
const idx_t idx2,
319+
const idx_t idx3,
320+
float& dis0,
321+
float& dis1,
322+
float& dis2,
323+
float& dis3) final override {
324+
ndis += 4;
325+
326+
// compute first, assign next
327+
const float* __restrict y0 =
328+
reinterpret_cast<const float*>(codes + idx0 * code_size);
329+
const float* __restrict y1 =
330+
reinterpret_cast<const float*>(codes + idx1 * code_size);
331+
const float* __restrict y2 =
332+
reinterpret_cast<const float*>(codes + idx2 * code_size);
333+
const float* __restrict y3 =
334+
reinterpret_cast<const float*>(codes + idx3 * code_size);
335+
336+
prefetch_L2(l2norms + idx0);
337+
prefetch_L2(l2norms + idx1);
338+
prefetch_L2(l2norms + idx2);
339+
prefetch_L2(l2norms + idx3);
340+
341+
float dp0 = 0;
342+
float dp1 = 0;
343+
float dp2 = 0;
344+
float dp3 = 0;
345+
fvec_inner_product_batch_4(q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
346+
dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
347+
dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
348+
dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
349+
dis3 = query_l2norm + l2norms[idx3] - 2 * dp3;
350+
}
351+
};
352+
353+
} // namespace
354+
355+
/// Rebuilds the norm cache: sizes `cached_l2norms` to `ntotal` entries and
/// fills it with fvec_norms_L2sqr over the stored vectors.
void IndexFlatL2::sync_l2norms() {
    cached_l2norms.resize(ntotal);

    // The codes buffer is read as a contiguous array of float vectors.
    const float* stored_vectors = reinterpret_cast<const float*>(codes.data());
    float* norms_out = cached_l2norms.data();
    fvec_norms_L2sqr(norms_out, stored_vectors, d, ntotal);
}
363+
364+
/// Drops the norm cache: empties `cached_l2norms` and asks the vector to
/// give back its storage.
void IndexFlatL2::clear_l2norms() {
    auto& norms = cached_l2norms;
    norms.clear();
    norms.shrink_to_fit();
}
368+
369+
/// Returns a newly allocated distance computer for this index.
///
/// The norm-based computer (FlatL2WithNormsDis) is chosen only when the
/// metric is L2 and `cached_l2norms` holds exactly one entry per stored
/// vector. A cache whose size differs from `ntotal` (for example one that
/// was synced before more vectors were added) is ignored, because the
/// norm-based computer indexes the cache by vector id and would read past
/// its end. In every other case the IndexFlat implementation is used.
///
/// @return heap-allocated computer, as with the base-class implementation
FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer() const {
    const bool norms_cover_all_vectors = !cached_l2norms.empty() &&
            cached_l2norms.size() == static_cast<size_t>(ntotal);

    if (metric_type == METRIC_L2 && norms_cover_all_vectors) {
        return new FlatL2WithNormsDis(*this);
    }

    return IndexFlat::get_FlatCodesDistanceComputer();
}
378+
187379
/***************************************************
188380
* IndexFlat1D
189381
***************************************************/

faiss/IndexFlat.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,22 @@ struct IndexFlatIP : IndexFlat {
7676
};
7777

7878
struct IndexFlatL2 : IndexFlat {
79+
// Cache of per-vector L2 norms, filled by sync_l2norms().
80+
// While this cache is populated, get_FlatCodesDistanceComputer() returns
81+
// a specialized computer that derives L2 distances from dot products
82+
// and these cached norms.
83+
std::vector<float> cached_l2norms;
84+
7985
explicit IndexFlatL2(idx_t d) : IndexFlat(d, METRIC_L2) {}
8086
IndexFlatL2() {}
87+
88+
// Overridden so that the L2 norm cache can be used when it is available.
89+
FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
90+
91+
// (Re)compute cached_l2norms for the currently stored vectors.
92+
void sync_l2norms();
93+
// Empty cached_l2norms and request that its memory be released.
94+
void clear_l2norms();
8195
};
8296

8397
/// optimized version for 1D "vectors".

faiss/IndexHNSW.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,10 @@ IndexHNSWFlat::IndexHNSWFlat() {
872872
}
873873

874874
IndexHNSWFlat::IndexHNSWFlat(int d, int M, MetricType metric)
875-
: IndexHNSW(new IndexFlat(d, metric), M) {
875+
: IndexHNSW(
876+
(metric == METRIC_L2) ? new IndexFlatL2(d)
877+
: new IndexFlat(d, metric),
878+
M) {
876879
own_fields = true;
877880
is_trained = true;
878881
}

faiss/impl/DistanceComputer.h

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,29 @@ struct DistanceComputer {
3030
/// compute distance of vector i to current query
3131
virtual float operator()(idx_t i) = 0;
3232

33+
/// compute distances of current query to 4 stored vectors.
34+
/// certain DistanceComputer implementations may benefit
35+
/// heavily from this.
36+
virtual void distances_batch_4(
37+
const idx_t idx0,
38+
const idx_t idx1,
39+
const idx_t idx2,
40+
const idx_t idx3,
41+
float& dis0,
42+
float& dis1,
43+
float& dis2,
44+
float& dis3) {
45+
// compute first, assign next
46+
const float d0 = this->operator()(idx0);
47+
const float d1 = this->operator()(idx1);
48+
const float d2 = this->operator()(idx2);
49+
const float d3 = this->operator()(idx3);
50+
dis0 = d0;
51+
dis1 = d1;
52+
dis2 = d2;
53+
dis3 = d3;
54+
}
55+
3356
/// compute distance between two stored vectors
3457
virtual float symmetric_dis(idx_t i, idx_t j) = 0;
3558

@@ -49,7 +72,7 @@ struct FlatCodesDistanceComputer : DistanceComputer {
4972

5073
FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {}
5174

52-
float operator()(idx_t i) final {
75+
float operator()(idx_t i) override {
5376
return distance_to_code(codes + i * code_size);
5477
}
5578

0 commit comments

Comments
 (0)