|
26 | 26 | #include <faiss/impl/IDSelector.h> |
27 | 27 | #include <faiss/impl/ResultHandler.h> |
28 | 28 |
|
| 29 | +#include <faiss/utils/distances_fused/distances_fused.h> |
| 30 | + |
29 | 31 | #ifndef FINTEGER |
30 | 32 | #define FINTEGER long |
31 | 33 | #endif |
@@ -229,7 +231,7 @@ void exhaustive_inner_product_blas( |
229 | 231 | // distance correction is an operator that can be applied to transform |
230 | 232 | // the distances |
231 | 233 | template <class ResultHandler> |
232 | | -void exhaustive_L2sqr_blas( |
| 234 | +void exhaustive_L2sqr_blas_default_impl( |
233 | 235 | const float* x, |
234 | 236 | const float* y, |
235 | 237 | size_t d, |
@@ -311,10 +313,20 @@ void exhaustive_L2sqr_blas( |
311 | 313 | } |
312 | 314 | } |
313 | 315 |
|
/// Generic entry point for the BLAS-based exhaustive squared-L2 search.
///
/// This primary template simply delegates to
/// exhaustive_L2sqr_blas_default_impl. Explicit specializations of this
/// template can route specific result handlers to other kernels.
///
/// @param x        first input array (presumably nx vectors of dimension d
///                 -- confirm against the default implementation)
/// @param y        second input array (presumably ny vectors of dimension d
///                 -- confirm against the default implementation)
/// @param d        size argument forwarded unchanged
/// @param nx       size argument forwarded unchanged
/// @param ny       size argument forwarded unchanged
/// @param res      result handler receiving the output
/// @param y_norms  optional pointer forwarded to the default implementation;
///                 nullptr (the default) when the caller supplies none
template <class ResultHandler>
void exhaustive_L2sqr_blas(
        const float* x,
        const float* y,
        size_t d,
        size_t nx,
        size_t ny,
        ResultHandler& res,
        const float* y_norms = nullptr) {
    // y_norms has to be passed along explicitly: leaving it out makes the
    // default implementation fall back to its own default argument and
    // silently discards whatever the caller provided.
    exhaustive_L2sqr_blas_default_impl(x, y, d, nx, ny, res, y_norms);
}
| 327 | + |
314 | 328 | #ifdef __AVX2__ |
315 | | -// an override for AVX2 if only a single closest point is needed. |
316 | | -template <> |
317 | | -void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>( |
| 329 | +void exhaustive_L2sqr_blas_cmax_avx2( |
318 | 330 | const float* x, |
319 | 331 | const float* y, |
320 | 332 | size_t d, |
@@ -513,11 +525,53 @@ void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>( |
513 | 525 | res.add_result(i, current_min_distance, current_min_index); |
514 | 526 | } |
515 | 527 | } |
| 528 | + // Does nothing for SingleBestResultHandler, but |
| 529 | + // keeping the call for the consistency. |
| 530 | + res.end_multiple(); |
516 | 531 | InterruptCallback::check(); |
517 | 532 | } |
518 | 533 | } |
519 | 534 | #endif |
520 | 535 |
|
| 536 | +// an override if only a single closest point is needed |
| 537 | +template <> |
| 538 | +void exhaustive_L2sqr_blas<SingleBestResultHandler<CMax<float, int64_t>>>( |
| 539 | + const float* x, |
| 540 | + const float* y, |
| 541 | + size_t d, |
| 542 | + size_t nx, |
| 543 | + size_t ny, |
| 544 | + SingleBestResultHandler<CMax<float, int64_t>>& res, |
| 545 | + const float* y_norms) { |
| 546 | +#if defined(__AVX2__) |
| 547 | + // use a faster fused kernel if available |
| 548 | + if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) { |
| 549 | + // the kernel is available and it is complete, we're done. |
| 550 | + return; |
| 551 | + } |
| 552 | + |
| 553 | + // run the specialized AVX2 implementation |
| 554 | + exhaustive_L2sqr_blas_cmax_avx2(x, y, d, nx, ny, res, y_norms); |
| 555 | + |
| 556 | +#elif defined(__aarch64__) |
| 557 | + // use a faster fused kernel if available |
| 558 | + if (exhaustive_L2sqr_fused_cmax(x, y, d, nx, ny, res, y_norms)) { |
| 559 | + // the kernel is available and it is complete, we're done. |
| 560 | + return; |
| 561 | + } |
| 562 | + |
| 563 | + // run the default implementation |
| 564 | + exhaustive_L2sqr_blas_default_impl< |
| 565 | + SingleBestResultHandler<CMax<float, int64_t>>>( |
| 566 | + x, y, d, nx, ny, res, y_norms); |
| 567 | +#else |
| 568 | + // run the default implementation |
| 569 | + exhaustive_L2sqr_blas_default_impl< |
| 570 | + SingleBestResultHandler<CMax<float, int64_t>>>( |
| 571 | + x, y, d, nx, ny, res, y_norms); |
| 572 | +#endif |
| 573 | +} |
| 574 | + |
521 | 575 | template <class ResultHandler> |
522 | 576 | void knn_L2sqr_select( |
523 | 577 | const float* x, |
|
0 commit comments