1414#include < faiss/utils/Heap.h>
1515#include < faiss/utils/distances.h>
1616#include < faiss/utils/extra_distances.h>
17+ #include < faiss/utils/prefetch.h>
1718#include < faiss/utils/sorting.h>
1819#include < faiss/utils/utils.h>
1920#include < cstring>
@@ -122,6 +123,39 @@ struct FlatL2Dis : FlatCodesDistanceComputer {
122123 void set_query (const float * x) override {
123124 q = x;
124125 }
126+
127+ // compute four distances
128+ void distances_batch_4 (
129+ const idx_t idx0,
130+ const idx_t idx1,
131+ const idx_t idx2,
132+ const idx_t idx3,
133+ float & dis0,
134+ float & dis1,
135+ float & dis2,
136+ float & dis3) final override {
137+ ndis += 4 ;
138+
139+ // compute first, assign next
140+ const float * __restrict y0 =
141+ reinterpret_cast <const float *>(codes + idx0 * code_size);
142+ const float * __restrict y1 =
143+ reinterpret_cast <const float *>(codes + idx1 * code_size);
144+ const float * __restrict y2 =
145+ reinterpret_cast <const float *>(codes + idx2 * code_size);
146+ const float * __restrict y3 =
147+ reinterpret_cast <const float *>(codes + idx3 * code_size);
148+
149+ float dp0 = 0 ;
150+ float dp1 = 0 ;
151+ float dp2 = 0 ;
152+ float dp3 = 0 ;
153+ fvec_L2sqr_batch_4 (q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
154+ dis0 = dp0;
155+ dis1 = dp1;
156+ dis2 = dp2;
157+ dis3 = dp3;
158+ }
125159};
126160
127161struct FlatIPDis : FlatCodesDistanceComputer {
@@ -131,13 +165,13 @@ struct FlatIPDis : FlatCodesDistanceComputer {
131165 const float * b;
132166 size_t ndis;
133167
134- float symmetric_dis (idx_t i, idx_t j) override {
168+ float symmetric_dis (idx_t i, idx_t j) final override {
135169 return fvec_inner_product (b + j * d, b + i * d, d);
136170 }
137171
138- float distance_to_code (const uint8_t * code) final {
172+ float distance_to_code (const uint8_t * code) final override {
139173 ndis++;
140- return fvec_inner_product (q, (float *)code, d);
174+ return fvec_inner_product (q, (const float *)code, d);
141175 }
142176
143177 explicit FlatIPDis (const IndexFlat& storage, const float * q = nullptr )
@@ -153,6 +187,39 @@ struct FlatIPDis : FlatCodesDistanceComputer {
153187 void set_query (const float * x) override {
154188 q = x;
155189 }
190+
191+ // compute four distances
192+ void distances_batch_4 (
193+ const idx_t idx0,
194+ const idx_t idx1,
195+ const idx_t idx2,
196+ const idx_t idx3,
197+ float & dis0,
198+ float & dis1,
199+ float & dis2,
200+ float & dis3) final override {
201+ ndis += 4 ;
202+
203+ // compute first, assign next
204+ const float * __restrict y0 =
205+ reinterpret_cast <const float *>(codes + idx0 * code_size);
206+ const float * __restrict y1 =
207+ reinterpret_cast <const float *>(codes + idx1 * code_size);
208+ const float * __restrict y2 =
209+ reinterpret_cast <const float *>(codes + idx2 * code_size);
210+ const float * __restrict y3 =
211+ reinterpret_cast <const float *>(codes + idx3 * code_size);
212+
213+ float dp0 = 0 ;
214+ float dp1 = 0 ;
215+ float dp2 = 0 ;
216+ float dp3 = 0 ;
217+ fvec_inner_product_batch_4 (q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
218+ dis0 = dp0;
219+ dis1 = dp1;
220+ dis2 = dp2;
221+ dis3 = dp3;
222+ }
156223};
157224
158225} // namespace
@@ -184,6 +251,131 @@ void IndexFlat::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
184251 }
185252}
186253
254+ /* **************************************************
255+ * IndexFlatL2
256+ ***************************************************/
257+
258+ namespace {
259+ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
260+ size_t d;
261+ idx_t nb;
262+ const float * q;
263+ const float * b;
264+ size_t ndis;
265+
266+ const float * l2norms;
267+ float query_l2norm;
268+
269+ float distance_to_code (const uint8_t * code) final override {
270+ ndis++;
271+ return fvec_L2sqr (q, (float *)code, d);
272+ }
273+
274+ float operator ()(const idx_t i) final override {
275+ const float * __restrict y =
276+ reinterpret_cast <const float *>(codes + i * code_size);
277+
278+ prefetch_L2 (l2norms + i);
279+ const float dp0 = fvec_inner_product (q, y, d);
280+ return query_l2norm + l2norms[i] - 2 * dp0;
281+ }
282+
283+ float symmetric_dis (idx_t i, idx_t j) final override {
284+ const float * __restrict yi =
285+ reinterpret_cast <const float *>(codes + i * code_size);
286+ const float * __restrict yj =
287+ reinterpret_cast <const float *>(codes + j * code_size);
288+
289+ prefetch_L2 (l2norms + i);
290+ prefetch_L2 (l2norms + j);
291+ const float dp0 = fvec_inner_product (yi, yj, d);
292+ return l2norms[i] + l2norms[j] - 2 * dp0;
293+ }
294+
295+ explicit FlatL2WithNormsDis (
296+ const IndexFlatL2& storage,
297+ const float * q = nullptr )
298+ : FlatCodesDistanceComputer(
299+ storage.codes.data(),
300+ storage.code_size),
301+ d(storage.d),
302+ nb(storage.ntotal),
303+ q(q),
304+ b(storage.get_xb()),
305+ ndis(0 ),
306+ l2norms(storage.cached_l2norms.data()),
307+ query_l2norm(0 ) {}
308+
309+ void set_query (const float * x) override {
310+ q = x;
311+ query_l2norm = fvec_norm_L2sqr (q, d);
312+ }
313+
314+ // compute four distances
315+ void distances_batch_4 (
316+ const idx_t idx0,
317+ const idx_t idx1,
318+ const idx_t idx2,
319+ const idx_t idx3,
320+ float & dis0,
321+ float & dis1,
322+ float & dis2,
323+ float & dis3) final override {
324+ ndis += 4 ;
325+
326+ // compute first, assign next
327+ const float * __restrict y0 =
328+ reinterpret_cast <const float *>(codes + idx0 * code_size);
329+ const float * __restrict y1 =
330+ reinterpret_cast <const float *>(codes + idx1 * code_size);
331+ const float * __restrict y2 =
332+ reinterpret_cast <const float *>(codes + idx2 * code_size);
333+ const float * __restrict y3 =
334+ reinterpret_cast <const float *>(codes + idx3 * code_size);
335+
336+ prefetch_L2 (l2norms + idx0);
337+ prefetch_L2 (l2norms + idx1);
338+ prefetch_L2 (l2norms + idx2);
339+ prefetch_L2 (l2norms + idx3);
340+
341+ float dp0 = 0 ;
342+ float dp1 = 0 ;
343+ float dp2 = 0 ;
344+ float dp3 = 0 ;
345+ fvec_inner_product_batch_4 (q, y0, y1, y2, y3, d, dp0, dp1, dp2, dp3);
346+ dis0 = query_l2norm + l2norms[idx0] - 2 * dp0;
347+ dis1 = query_l2norm + l2norms[idx1] - 2 * dp1;
348+ dis2 = query_l2norm + l2norms[idx2] - 2 * dp2;
349+ dis3 = query_l2norm + l2norms[idx3] - 2 * dp3;
350+ }
351+ };
352+
353+ } // namespace
354+
355+ void IndexFlatL2::sync_l2norms () {
356+ cached_l2norms.resize (ntotal);
357+ fvec_norms_L2sqr (
358+ cached_l2norms.data (),
359+ reinterpret_cast <const float *>(codes.data ()),
360+ d,
361+ ntotal);
362+ }
363+
364+ void IndexFlatL2::clear_l2norms () {
365+ cached_l2norms.clear ();
366+ cached_l2norms.shrink_to_fit ();
367+ }
368+
369+ FlatCodesDistanceComputer* IndexFlatL2::get_FlatCodesDistanceComputer () const {
370+ if (metric_type == METRIC_L2) {
371+ if (!cached_l2norms.empty ()) {
372+ return new FlatL2WithNormsDis (*this );
373+ }
374+ }
375+
376+ return IndexFlat::get_FlatCodesDistanceComputer ();
377+ }
378+
187379/* **************************************************
188380 * IndexFlat1D
189381 ***************************************************/
0 commit comments