Skip to content

Commit 5f5c556

Browse files
Alexandr Guzhvafacebook-github-bot
authored andcommitted
Speedup exhaustive_L2sqr_blas for AVX2, ARM NEON and AVX512 (#2568)
Summary: Pull Request resolved: #2568 Add a fused kernel for exhaustive_L2sqr_blas() call that combines a computation of dot product and the search for the nearest centroid. As a result, no temporary dot product values are written and read in RAM. Speeds up the training of PQx[1] indices for dsub = 1, 2, 4, 8, and the effect is higher for higher values of [1]. AVX512 version provides additional overloads for dsub = 12, 16. The speedup is also beneficial for higher values of pq.cp.max_points_per_centroid (which is 256 by default). Speeds up IVFPQ training as well. AVX512 kernel is not enabled, but I've seen it speeding up the training TWICE versus AVX2 version. So, please feel free to use it by enabling AVX512 manually. Reviewed By: mdouze Differential Revision: D41166766 fbshipit-source-id: f490d7e60f1c1b94a3f412a92f3d72ca8c5d8e1e
1 parent ab13122 commit 5f5c556

File tree

13 files changed

+1202
-5
lines changed

13 files changed

+1202
-5
lines changed

faiss/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ set(FAISS_SRC
8686
utils/quantize_lut.cpp
8787
utils/random.cpp
8888
utils/utils.cpp
89+
utils/distances_fused/avx512.cpp
90+
utils/distances_fused/distances_fused.cpp
91+
utils/distances_fused/simdlib_based.cpp
8992
)
9093

9194
set(FAISS_HEADERS
@@ -187,6 +190,9 @@ set(FAISS_HEADERS
187190
utils/simdlib_emulated.h
188191
utils/simdlib_neon.h
189192
utils/utils.h
193+
utils/distances_fused/avx512.h
194+
utils/distances_fused/distances_fused.h
195+
utils/distances_fused/simdlib_based.h
190196
)
191197

192198
if(NOT WIN32)

faiss/utils/distances.cpp

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
#include <faiss/impl/IDSelector.h>
2727
#include <faiss/impl/ResultHandler.h>
2828

29+
#include <faiss/utils/distances_fused/distances_fused.h>
30+
2931
#ifndef FINTEGER
3032
#define FINTEGER long
3133
#endif
@@ -229,7 +231,7 @@ void exhaustive_inner_product_blas(
229231
// distance correction is an operator that can be applied to transform
230232
// the distances
231233
template <class ResultHandler>
232-
void exhaustive_L2sqr_blas(
234+
void exhaustive_L2sqr_blas_default_impl(
233235
const float* x,
234236
const float* y,
235237
size_t d,
@@ -311,10 +313,20 @@ void exhaustive_L2sqr_blas(
311313
}
312314
}
313315

316+
template <class ResultHandler>
317+
void exhaustive_L2sqr_blas(
318+
const float* x,
319+
const float* y,
320+
size_t d,
321+
size_t nx,
322+
size_t ny,
323+
ResultHandler& res,
324+
const float* y_norms = nullptr) {
325+
exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res);
326+
}
327+
314328
#ifdef __AVX2__
315-
// an override for AVX2 if only a single closest point is needed.
316-
template <>
317-
void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
329+
void exhaustive_L2sqr_blas_cmax_avx2(
318330
const float* x,
319331
const float* y,
320332
size_t d,
@@ -513,11 +525,53 @@ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
513525
res.add_result(i, current_min_distance, current_min_index);
514526
}
515527
}
528+
// Does nothing for SingleBestResultHandler, but
529+
// keeping the call for the consistency.
530+
res.end_multiple();
516531
InterruptCallback::check();
517532
}
518533
}
519534
#endif
520535

536+
// an override if only a single closest point is needed
537+
template <>
538+
void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>(
539+
const float* x,
540+
const float* y,
541+
size_t d,
542+
size_t nx,
543+
size_t ny,
544+
SingleBestResultHandler<CMax<float, int64_t>>& res,
545+
const float* y_norms) {
546+
#if defined(__AVX2__)
547+
// use a faster fused kernel if available
548+
if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
549+
// the kernel is available and it is complete, we're done.
550+
return;
551+
}
552+
553+
// run the specialized AVX2 implementation
554+
exhaustive_L2sqr_blas_cmax_avx2(x, y, d, nx, ny, res, y_norms);
555+
556+
#elif defined(__aarch64__)
557+
// use a faster fused kernel if available
558+
if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) {
559+
// the kernel is available and it is complete, we're done.
560+
return;
561+
}
562+
563+
// run the default implementation
564+
exhaustive_L2sqr_blas_default_impl<
565+
SingleBestResultHandler<CMax<float, int64_t>>>(
566+
x, y, d, nx, ny, res, y_norms);
567+
#else
568+
// run the default implementation
569+
exhaustive_L2sqr_blas_default_impl<
570+
SingleBestResultHandler<CMax<float, int64_t>>>(
571+
x, y, d, nx, ny, res, y_norms);
572+
#endif
573+
}
574+
521575
template <class ResultHandler>
522576
void knn_L2sqr_select(
523577
const float* x,

0 commit comments

Comments
 (0)