diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 2b1fe8aa09..5b5b47aa9a 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -218,7 +218,6 @@ add_library( src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu src/distance/distance.cu src/distance/pairwise_distance.cu - src/neighbors/brute_force_index.cu src/neighbors/brute_force.cu src/neighbors/cagra_build_float.cu src/neighbors/cagra_build_int8.cu diff --git a/cpp/include/cuvs/neighbors/brute_force.hpp b/cpp/include/cuvs/neighbors/brute_force.hpp index 26951e1ec3..755a941228 100644 --- a/cpp/include/cuvs/neighbors/brute_force.hpp +++ b/cpp/include/cuvs/neighbors/brute_force.hpp @@ -18,8 +18,10 @@ #include "ann_types.hpp" #include +#include #include #include +#include namespace cuvs::neighbors::brute_force { @@ -42,34 +44,90 @@ struct index : cuvs::neighbors::ann::index { index& operator=(const index&) = delete; index& operator=(index&&) = default; ~index() = default; - index(void* raft_index); + + /** Construct a brute force index from dataset + * + * Constructs a brute force index from a dataset. This lets us precompute norms for + * the dataset, providing a speed benefit over doing this at query time. + * The host dataset will be copied to the device and the index will own the device memory. + */ + index(raft::resources const& res, + raft::host_matrix_view dataset_view, + std::optional>&& norms, + cuvs::distance::DistanceType metric, + T metric_arg = 0.0); + + /** Construct a brute force index from dataset + * + * Constructs a brute force index from a dataset. This lets us precompute norms for + * the dataset, providing a speed benefit over doing this at query time. + * This index will store a non-owning reference to the device dataset. 
+ */ + index(raft::resources const& res, + raft::device_matrix_view dataset_view, + std::optional>&& norms, + cuvs::distance::DistanceType metric, + T metric_arg = 0.0); + + /** Construct a brute force index from dataset + * + * This class stores a non-owning reference to the dataset and norms here. + * Having precomputed norms gives us a performance advantage at query time. + */ + index(raft::resources const& res, + raft::device_matrix_view dataset_view, + std::optional> norms_view, + cuvs::distance::DistanceType metric, + T metric_arg = 0.0); + + /** + * Replace the dataset with a new dataset. + */ + void update_dataset(raft::resources const& res, + raft::device_matrix_view dataset); + + /** + * Replace the dataset with a new dataset. + * + * We create a copy of the dataset on the device. The index manages the lifetime of this copy. + */ + void update_dataset(raft::resources const& res, + raft::host_matrix_view dataset); /** Distance metric used for retrieval */ - cuvs::distance::DistanceType metric() const noexcept; + cuvs::distance::DistanceType metric() const noexcept { return metric_; } /** Metric argument */ - T metric_arg() const noexcept; + T metric_arg() const noexcept { return metric_arg_; } /** Total length of the index (number of vectors). */ - size_t size() const noexcept; + size_t size() const noexcept { return dataset_view_.extent(0); } /** Dimensionality of the data. 
*/ - size_t dim() const noexcept; + size_t dim() const noexcept { return dataset_view_.extent(1); } /** Dataset [size, dim] */ - raft::device_matrix_view dataset() const noexcept; + raft::device_matrix_view dataset() const noexcept + { + return dataset_view_; + } /** Dataset norms */ - raft::device_vector_view norms() const; + raft::device_vector_view norms() const + { + return norms_view_.value(); + } /** Whether ot not this index has dataset norms */ - bool has_norms() const noexcept; - - // Get pointer to underlying RAFT index, not meant to be used outside of cuVS - inline const void* get_raft_index() const noexcept { return raft_index_.get(); } + inline bool has_norms() const noexcept { return norms_view_.has_value(); } private: - std::unique_ptr raft_index_; + cuvs::distance::DistanceType metric_; + raft::device_matrix dataset_; + std::optional> norms_; + std::optional> norms_view_; + raft::device_matrix_view dataset_view_; + T metric_arg_; }; /** * @} diff --git a/cpp/src/neighbors/brute_force.cu b/cpp/src/neighbors/brute_force.cu index 33dc2088ca..b9b74a3f25 100644 --- a/cpp/src/neighbors/brute_force.cu +++ b/cpp/src/neighbors/brute_force.cu @@ -14,41 +14,98 @@ * limitations under the License. 
*/ +#include "./detail/knn_brute_force.cuh" #include -#include + +#include namespace cuvs::neighbors::brute_force { +template +index::index(raft::resources const& res, + raft::host_matrix_view dataset, + std::optional>&& norms, + cuvs::distance::DistanceType metric, + T metric_arg) + : ann::index(), + metric_(metric), + dataset_(raft::make_device_matrix(res, 0, 0)), + norms_(std::move(norms)), + metric_arg_(metric_arg) +{ + if (norms_) { norms_view_ = raft::make_const_mdspan(norms_.value().view()); } + update_dataset(res, dataset); + raft::resource::sync_stream(res); +} + +template +index::index(raft::resources const& res, + raft::device_matrix_view dataset, + std::optional>&& norms, + cuvs::distance::DistanceType metric, + T metric_arg) + : ann::index(), + metric_(metric), + dataset_(raft::make_device_matrix(res, 0, 0)), + norms_(std::move(norms)), + metric_arg_(metric_arg) +{ + if (norms_) { norms_view_ = raft::make_const_mdspan(norms_.value().view()); } + update_dataset(res, dataset); +} + +template +index::index(raft::resources const& res, + raft::device_matrix_view dataset_view, + std::optional> norms_view, + cuvs::distance::DistanceType metric, + T metric_arg) + : ann::index(), + metric_(metric), + dataset_(raft::make_device_matrix(res, 0, 0)), + dataset_view_(dataset_view), + norms_view_(norms_view), + metric_arg_(metric_arg) +{ +} + +template +void index::update_dataset(raft::resources const& res, + raft::device_matrix_view dataset) +{ + dataset_view_ = dataset; +} + +template +void index::update_dataset(raft::resources const& res, + raft::host_matrix_view dataset) +{ + dataset_ = raft::make_device_matrix(res, dataset.extent(0), dataset.extent(1)); + raft::copy(res, dataset_.view(), dataset); + dataset_view_ = raft::make_const_mdspan(dataset_.view()); +} -#define CUVS_INST_BFKNN(T, IdxT) \ - auto build(raft::resources const& res, \ - raft::device_matrix_view dataset, \ - cuvs::distance::DistanceType metric, \ - T metric_arg) \ - 
->cuvs::neighbors::brute_force::index \ - { \ - auto index_on_stack = raft::neighbors::brute_force::build( \ - res, dataset, static_cast(metric), metric_arg); \ - auto index_on_heap = \ - new raft::neighbors::brute_force::index(std::move(index_on_stack)); \ - return cuvs::neighbors::brute_force::index(index_on_heap); \ - } \ - \ - void search(raft::resources const& res, \ - const cuvs::neighbors::brute_force::index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances) \ - { \ - auto raft_idx = \ - reinterpret_cast*>(idx.get_raft_index()); \ - raft::neighbors::brute_force::search(res, *raft_idx, queries, neighbors, distances); \ - } \ - \ +#define CUVS_INST_BFKNN(T) \ + auto build(raft::resources const& res, \ + raft::device_matrix_view dataset, \ + cuvs::distance::DistanceType metric, \ + T metric_arg) \ + ->cuvs::neighbors::brute_force::index \ + { \ + return detail::build(res, dataset, metric, metric_arg); \ + } \ + \ + void search(raft::resources const& res, \ + const cuvs::neighbors::brute_force::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances) \ + { \ + detail::brute_force_search(res, idx, queries, neighbors, distances); \ + } \ + \ template struct cuvs::neighbors::brute_force::index; -CUVS_INST_BFKNN(float, int64_t); -// CUVS_INST_BFKNN(int8_t, int64_t); -// CUVS_INST_BFKNN(uint8_t, int64_t); +CUVS_INST_BFKNN(float); #undef CUVS_INST_BFKNN diff --git a/cpp/src/neighbors/brute_force_index.cu b/cpp/src/neighbors/brute_force_index.cu deleted file mode 100644 index b05fa7ced9..0000000000 --- a/cpp/src/neighbors/brute_force_index.cu +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace cuvs::neighbors::brute_force { - -template -inline const raft::neighbors::brute_force::index* get_underlying_index( - const cuvs::neighbors::brute_force::index* idx) -{ - return reinterpret_cast*>(idx->get_raft_index()); -} - -template -index::index(void* raft_index) - : cuvs::neighbors::ann::index(), raft_index_(reinterpret_cast(raft_index)) -{ -} - -template -cuvs::distance::DistanceType index::metric() const noexcept -{ - auto raft_index = cuvs::neighbors::brute_force::get_underlying_index(this); - return static_cast((int)raft_index->metric()); -} - -template -size_t index::size() const noexcept -{ - auto raft_index = get_underlying_index(this); - return raft_index->size(); -} - -template -size_t index::dim() const noexcept -{ - auto raft_index = get_underlying_index(this); - return raft_index->dim(); -} - -template -raft::device_matrix_view index::dataset() const noexcept -{ - auto raft_index = get_underlying_index(this); - return raft_index->dataset(); -} - -template -raft::device_vector_view index::norms() const -{ - auto raft_index = get_underlying_index(this); - return raft_index->norms(); -} - -template -bool index::has_norms() const noexcept -{ - auto raft_index = get_underlying_index(this); - return raft_index->has_norms(); -} - -template -T index::metric_arg() const noexcept -{ - auto raft_index = get_underlying_index(this); - return raft_index->metric_arg(); -} - -template struct index; - -} // namespace cuvs::neighbors::brute_force diff --git a/cpp/src/neighbors/detail/faiss_distance_utils.h 
b/cpp/src/neighbors/detail/faiss_distance_utils.h new file mode 100644 index 0000000000..e8a41c1aa6 --- /dev/null +++ b/cpp/src/neighbors/detail/faiss_distance_utils.h @@ -0,0 +1,52 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +namespace cuvs::neighbors::detail::faiss_select { +// If the inner size (dim) of the vectors is small, we want a larger query tile +// size, like 1024 +inline void chooseTileSize(size_t numQueries, + size_t numCentroids, + size_t dim, + size_t elementSize, + size_t totalMem, + size_t& tileRows, + size_t& tileCols) +{ + // The matrix multiplication should be large enough to be efficient, but if + // it is too large, we seem to lose efficiency as opposed to + // double-streaming. Each tile size here defines 1/2 of the memory use due + // to double streaming. We ignore available temporary memory, as that is + // adjusted independently by the user and can thus meet these requirements + // (or not). For <= 4 GB GPUs, prefer 512 MB of usage. For <= 8 GB GPUs, + // prefer 768 MB of usage. Otherwise, prefer 1 GB of usage. + size_t targetUsage = 0; + + if (totalMem <= ((size_t)4) * 1024 * 1024 * 1024) { + targetUsage = 512 * 1024 * 1024; + } else if (totalMem <= ((size_t)8) * 1024 * 1024 * 1024) { + targetUsage = 768 * 1024 * 1024; + } else { + targetUsage = 1024 * 1024 * 1024; + } + + targetUsage /= 2 * elementSize; + + // 512 seems to be a batch size sweet spot for float32. + // The same starting value is used for every element size (e.g. float16). + // If the k size (vec dim) of the matrix multiplication is small (<= 32), + // increase to 1024. 
+ size_t preferredTileRows = 512; + if (dim <= 32) { preferredTileRows = 1024; } + + tileRows = std::min(preferredTileRows, numQueries); + + // tileCols is the remainder size + tileCols = std::min(targetUsage / preferredTileRows, numCentroids); +} +} // namespace cuvs::neighbors::detail::faiss_select diff --git a/cpp/src/neighbors/detail/fused_l2_knn.cuh b/cpp/src/neighbors/detail/fused_l2_knn.cuh new file mode 100644 index 0000000000..13ea4d4189 --- /dev/null +++ b/cpp/src/neighbors/detail/fused_l2_knn.cuh @@ -0,0 +1,1061 @@ +/* + * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once +#include +#include +#include +#include + +#include + +#include + +#include "../../distance/detail/distance.cuh" +#include "../../distance/detail/distance_ops/l2_exp.cuh" +#include "../../distance/detail/distance_ops/l2_unexp.cuh" +#include "../../distance/detail/pairwise_distance_base.cuh" +#include "../../distance/distance.cuh" + +namespace cuvs::neighbors::detail { + +template +DI void loadAllWarpQShmem(myWarpSelect** heapArr, + Pair* shDumpKV, + const IdxT m, + const unsigned int numOfNN) +{ + const int lid = raft::laneId(); +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (rowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const int idx = j * warpSize + lid; + if (idx < numOfNN) { + Pair KVPair = shDumpKV[rowId * numOfNN + idx]; + heapArr[i]->warpV[j] = KVPair.key; + heapArr[i]->warpK[j] = KVPair.value; + } + } + } + } +} + +template +DI void loadWarpQShmem(myWarpSelect* heapArr, + Pair* shDumpKV, + const int rowId, + const unsigned int numOfNN) +{ + const int lid = raft::laneId(); +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const int idx = j * warpSize + lid; + if (idx < numOfNN) { + Pair KVPair = shDumpKV[rowId * numOfNN + idx]; + heapArr->warpV[j] = KVPair.key; + heapArr->warpK[j] = KVPair.value; + } + } +} + +template +DI void storeWarpQShmem(myWarpSelect* heapArr, + Pair* shDumpKV, + const IdxT rowId, + const unsigned int numOfNN) +{ + const int lid = raft::laneId(); + +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const int idx = j * warpSize + lid; + if (idx < numOfNN) { + Pair otherKV = Pair(heapArr->warpV[j], heapArr->warpK[j]); + shDumpKV[rowId * numOfNN + idx] = otherKV; + } + } +} + +template +DI void storeWarpQGmem(myWarpSelect** heapArr, + volatile OutT* out_dists, + volatile IdxT* out_inds, + const IdxT m, + const 
unsigned int numOfNN, + const IdxT starty) +{ + const int lid = raft::laneId(); +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + out_dists[std::size_t(gmemRowId) * numOfNN + idx] = heapArr[i]->warpK[j]; + out_inds[std::size_t(gmemRowId) * numOfNN + idx] = (IdxT)heapArr[i]->warpV[j]; + } + } + } + } +} + +template +DI void loadPrevTopKsGmemWarpQ(myWarpSelect** heapArr, + volatile OutT* out_dists, + volatile IdxT* out_inds, + const IdxT m, + const unsigned int numOfNN, + const IdxT starty) +{ + const int lid = raft::laneId(); +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + heapArr[i]->warpK[j] = out_dists[std::size_t(gmemRowId) * numOfNN + idx]; + heapArr[i]->warpV[j] = (uint32_t)out_inds[std::size_t(gmemRowId) * numOfNN + idx]; + } + } + static constexpr auto kLaneWarpKTop = myWarpSelect::kNumWarpQRegisters - 1; + heapArr[i]->warpKTop = raft::shfl(heapArr[i]->warpK[kLaneWarpKTop], heapArr[i]->kLane); + } + } +} + +template +DI void updateSortedWarpQ( + myWarpSelect& heapArr, Pair* allWarpTopKs, int rowId, int finalNumVals, int startId = 0) +{ + constexpr uint32_t mask = 0xffffffffu; + const int lid = raft::laneId(); + // calculate srcLane such that tid 0 -> 31, 1 -> 0,... 31 -> 30. 
+ // wrap around 0 to 31 required for NN > 32 + const auto srcLane = (warpSize + (lid - 1)) & (warpSize - 1); + + for (int k = startId; k < finalNumVals; k++) { + Pair KVPair = allWarpTopKs[rowId * (256) + k]; +#pragma unroll + for (int i = 0; i < NumWarpQRegs; i++) { + unsigned activeLanes = __ballot_sync(mask, KVPair.value < heapArr->warpK[i]); + if (activeLanes) { + Pair tempKV; + tempKV.value = raft::shfl(heapArr->warpK[i], srcLane); + tempKV.key = raft::shfl(heapArr->warpV[i], srcLane); + const auto firstActiveLane = __ffs(activeLanes) - 1; + if (firstActiveLane == lid) { + heapArr->warpK[i] = KVPair.value; + heapArr->warpV[i] = KVPair.key; + } else if (lid > firstActiveLane) { + heapArr->warpK[i] = tempKV.value; + heapArr->warpV[i] = tempKV.key; + } + if (i == 0 && NumWarpQRegs > 1) { + heapArr->warpK[1] = __shfl_up_sync(mask, heapArr->warpK[1], 1); + heapArr->warpV[1] = __shfl_up_sync(mask, heapArr->warpV[1], 1); + if (lid == 0) { + heapArr->warpK[1] = tempKV.value; + heapArr->warpV[1] = tempKV.key; + } + break; + } + } + } + } +} + +template +__launch_bounds__(Policy::Nthreads, 2) RAFT_KERNEL fusedL2kNN(const DataT* x, + const DataT* y, + const DataT* _xn, + const DataT* _yn, + const IdxT m, + const IdxT n, + const IdxT k, + const IdxT lda, + const IdxT ldb, + const IdxT ldd, + OpT distance_op, + FinalLambda fin_op, + unsigned int numOfNN, + volatile int* mutexes, + volatile OutT* out_dists, + volatile IdxT* out_inds) +{ + using AccT = typename OpT::AccT; + extern __shared__ char smem[]; + + typedef cub::KeyValuePair Pair; + constexpr auto identity = std::numeric_limits::max(); + constexpr auto keyMax = std::numeric_limits::max(); + constexpr auto Dir = false; + using namespace raft::neighbors::detail::faiss_select; + typedef WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; + + auto rowEpilog_lambda = + [m, n, &distance_op, numOfNN, out_dists, out_inds, mutexes] __device__(IdxT gridStrideY) { + if (gridDim.x == 1) { return; } + + // Use ::template to 
disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + const int lid = threadIdx.x % warpSize; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + + // 0 -> consumer done consuming the buffer. + // -1 -> consumer started consuming the buffer + // -2 -> producer done filling the buffer + // 1 -> prod acquired to fill the buffer + if (blockIdx.x == 0) { + auto cta_processed = 0; + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + __syncwarp(); + + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + + while (cta_processed < gridDim.x - 1) { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) + ; + } + __threadfence(); + __syncthreads(); + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + otherKV.value = out_dists[rowId * numOfNN + idx]; + otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + } + } + } + } + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } + __threadfence(); + + // Perform merging of otherKV with topk's across warp. 
+#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { +#pragma unroll + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + } + heapArr[i]->add(otherKV.value, otherKV.key); + } + } + } + cta_processed++; + } +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(0xffffffff, needSort); + if (needSort) { heapArr[i]->reduce(); } + } + } + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } else { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) + ; + } + __threadfence(); + __syncthreads(); + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + for (int idx = lid; idx < numOfNN; idx += warpSize) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; + out_dists[rowId * numOfNN + idx] = KVPair.value; + out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + } + } + } + __threadfence(); + __syncthreads(); + + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } + __threadfence(); + } + }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = + [&distance_op, numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( + AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + // Use 
::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + constexpr uint32_t mask = 0xffffffffu; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); + const int lid = raft::laneId(); + + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + if (usePrevTopKs) { + if (gridStrideX == blockIdx.x * Policy::Nblk) { + loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + } + } + + if (gridStrideX > blockIdx.x * Policy::Nblk) { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; + heapArr[i]->warpKTop = tempKV.value; + } + + // total vals can atmost be 256, (32*8) + int numValsWarpTopK[Policy::AccRowsPerTh]; + int anyWarpTopKs = 0; +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + numValsWarpTopK[i] = 0; + if (rowId < m) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + } + } + anyWarpTopKs += numValsWarpTopK[i]; + } + } + anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); + if (anyWarpTopKs) { + Pair* allWarpTopKs = (Pair*)(&smem[0]); + uint32_t needScanSort[Policy::AccRowsPerTh]; + +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + needScanSort[i] = 0; + if (gmemRowId < m) { + int myVals = numValsWarpTopK[i]; + needScanSort[i] = __ballot_sync(mask, myVals > 0); + if 
(needScanSort[i]) { +#pragma unroll + for (unsigned int k = 1; k <= 16; k *= 2) { + const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); + if (lid >= k) { numValsWarpTopK[i] += n; } + } + } + // As each thread will know its total vals to write. + // we only store its starting location. + numValsWarpTopK[i] -= myVals; + } + + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { + if (needScanSort[i] & ((uint32_t)1 << lid)) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { + Pair otherKV = {colId, acc[i][j]}; + allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; + numValsWarpTopK[i]++; + } + } + } + } + __syncwarp(); + const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); + loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); + updateSortedWarpQ( + heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); + } + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { + storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + } + } + } + } + } else { +#pragma unroll + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { +#pragma unroll + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + Pair otherKV = {keyMax, identity}; + if (colId < ldd) { + otherKV.value = acc[i][j]; + otherKV.key = colId; + } + heapArr[i]->add(otherKV.value, otherKV.key); + } + + bool needSort = (heapArr[i]->numVals > 0); + needSort = 
__any_sync(mask, needSort); + if (needSort) { heapArr[i]->reduce(); } + storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + } + } + } + + if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { + // This is last iteration of grid stride X + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } + }; + + constexpr bool write_out = false; + cuvs::distance::detail::PairwiseDistances + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + _xn, + _yn, + nullptr, // output ptr, can be null as write_out == false. + smem, + distance_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +} + +template +void fusedL2UnexpKnnImpl(const DataT* x, + const DataT* y, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + bool sqrt, + OutT* out_dists, + IdxT* out_inds, + IdxT numOfNN, + cudaStream_t stream, + void* workspace, + size_t& worksize) +{ + typedef typename raft::linalg::Policy2x8::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef typename std::conditional::type KPolicy; + + ASSERT(isRowMajor, "Only Row major inputs are allowed"); + + dim3 blk(KPolicy::Nthreads); + // Accumulation operation lambda + typedef cub::KeyValuePair Pair; + + cuvs::distance::detail::ops::l2_unexp_distance_op distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; + constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + + auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; + if (numOfNN <= 32) { + fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; + } else if (numOfNN <= 64) { + fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn64RowMajor; + } else { + ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); + } + + const auto sharedMemSize = + distance_op.template shared_mem_size() + KPolicy::Mblk * numOfNN * sizeof(Pair); + + dim3 
grid = cuvs::distance::detail::launchConfigGenerator( + m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); + + if (grid.x > 1) { + const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); + if (workspace == nullptr || worksize < (sizeof(int32_t) * numMutexes)) { + worksize = sizeof(int32_t) * numMutexes; + return; + } else { + RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int32_t) * numMutexes, stream)); + } + } + + fusedL2UnexpKnnRowMajor<<>>(x, + y, + nullptr, + nullptr, + m, + n, + k, + lda, + ldb, + ldd, + distance_op, + fin_op, + (uint32_t)numOfNN, + (int*)workspace, + out_dists, + out_inds); + } else { + } + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +void fusedL2UnexpKnn(IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + const DataT* x, + const DataT* y, + bool sqrt, + OutT* out_dists, + IdxT* out_inds, + IdxT numOfNN, + cudaStream_t stream, + void* workspace, + size_t& worksize) +{ + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + fusedL2UnexpKnnImpl( + x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + fusedL2UnexpKnnImpl( + x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } else { + fusedL2UnexpKnnImpl(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } +} + +template +void fusedL2ExpKnnImpl(const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + bool sqrt, + OutT* out_dists, + IdxT* out_inds, + IdxT numOfNN, + cudaStream_t stream, + void* workspace, + size_t& worksize) +{ + typedef typename raft::linalg::Policy2x8::Policy RowPolicy; + typedef 
typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef typename std::conditional::type KPolicy; + + ASSERT(isRowMajor, "Only Row major inputs are allowed"); + + ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))), + "workspace size error"); + ASSERT(workspace != nullptr, "workspace is null"); + + dim3 blk(KPolicy::Nthreads); + + typedef cub::KeyValuePair Pair; + + cuvs::distance::detail::ops::l2_exp_distance_op distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN; + constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + + auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; + if (numOfNN <= 32) { + fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; + } else if (numOfNN <= 64) { + fusedL2ExpKnnRowMajor = fusedL2ExpKnn64RowMajor; + } else { + ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); + } + + const auto sharedMemSize = + distance_op.template shared_mem_size() + (KPolicy::Mblk * numOfNN * sizeof(Pair)); + dim3 grid = cuvs::distance::detail::launchConfigGenerator( + m, n, sharedMemSize, fusedL2ExpKnnRowMajor); + int32_t* mutexes = nullptr; + if (grid.x > 1) { + const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); + const auto normsSize = (x != y) ? 
(m + n) * sizeof(DataT) : n * sizeof(DataT); + const auto requiredSize = sizeof(int32_t) * numMutexes + normsSize; + if (worksize < requiredSize) { + worksize = requiredSize; + return; + } else { + mutexes = (int32_t*)((char*)workspace + normsSize); + RAFT_CUDA_TRY(cudaMemsetAsync(mutexes, 0, sizeof(int32_t) * numMutexes, stream)); + } + } + + // calculate norms if they haven't been passed in + if (!xn) { + DataT* xn_ = (DataT*)workspace; + workspace = xn_ + m; + raft::linalg::rowNorm( + xn_, x, k, m, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + xn = xn_; + } + if (!yn) { + if (x == y) { + yn = xn; + } else { + DataT* yn_ = (DataT*)(workspace); + raft::linalg::rowNorm( + yn_, y, k, n, raft::linalg::L2Norm, isRowMajor, stream, raft::identity_op{}); + yn = yn_; + } + } + + fusedL2ExpKnnRowMajor<<>>(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + distance_op, + fin_op, + (uint32_t)numOfNN, + mutexes, + out_dists, + out_inds); + } else { + } + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +template +void fusedL2ExpKnn(IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + bool sqrt, + OutT* out_dists, + IdxT* out_inds, + IdxT numOfNN, + cudaStream_t stream, + void* workspace, + size_t& worksize) +{ + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + fusedL2ExpKnnImpl( + x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + fusedL2ExpKnnImpl( + x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + numOfNN, + stream, + workspace, + worksize); + } else { + fusedL2ExpKnnImpl(x, + y, + xn, + yn, + m, + n, + k, + lda, + ldb, + ldd, + sqrt, + out_dists, + out_inds, + 
numOfNN, + stream, + workspace, + worksize); + } +} + +/** + * Compute the k-nearest neighbors using L2 expanded/unexpanded distance. + + * @tparam value_idx type of the output neighbor indices (out_inds) + * @tparam value_t type of the input vectors and of the output distances + * @param[out] out_inds output indices array on device (size n_query_rows * k) + * @param[out] out_dists output dists array on device (size n_query_rows * k) + * @param[in] index input index array on device (size n_index_rows * D) + * @param[in] query input query array on device (size n_query_rows * D) + * @param[in] n_index_rows number of rows in index array + * @param[in] n_query_rows number of rows in query array + * @param[in] k number of closest neighbors to return + * @param[in] rowMajorIndex are the index arrays in row-major layout? (asserted to be true) + * @param[in] rowMajorQuery is the query array in row-major layout? (asserted equal to rowMajorIndex) + * @param[in] stream stream to order kernel launch. Also taken: D (number of columns, used for both lda and ldb), metric (selects the expanded or the unexpanded L2 path) and the optional index_norms / query_norms (forwarded on the expanded path only). + */ +template +void fusedL2Knn(size_t D, + value_idx* out_inds, + value_t* out_dists, + const value_t* index, + const value_t* query, + size_t n_index_rows, + size_t n_query_rows, + int k, + bool rowMajorIndex, + bool rowMajorQuery, + cudaStream_t stream, + cuvs::distance::DistanceType metric, + const value_t* index_norms = NULL, + const value_t* query_norms = NULL) +{ + // Validate the input data + ASSERT(k > 0, "l2Knn: k must be > 0"); + ASSERT(D > 0, "l2Knn: D must be > 0"); + ASSERT(n_index_rows > 0, "l2Knn: n_index_rows must be > 0"); + ASSERT(index, "l2Knn: index must be provided (passed null)"); + ASSERT(n_query_rows > 0, "l2Knn: n_query_rows must be > 0"); + ASSERT(query, "l2Knn: query must be provided (passed null)"); + ASSERT(out_dists, "l2Knn: out_dists must be provided (passed null)"); + ASSERT(out_inds, "l2Knn: out_inds must be provided (passed null)"); + // Currently we only support same layout for x & y inputs.
+ ASSERT(rowMajorIndex == rowMajorQuery, + "l2Knn: rowMajorIndex and rowMajorQuery should have same layout"); + // TODO: Add support for column major layout + ASSERT(rowMajorIndex == true, "l2Knn: only rowMajor inputs are supported for now."); + + // Even for the L2 Sqrt distance case we use the non-sqrt version, as FAISS bfKNN only supports + // the non-sqrt metric & some tests in RAFT/cuML (like Linkage) fail if we use L2 sqrt. + constexpr bool sqrt = false; + + size_t worksize = 0, tempWorksize = 0; + rmm::device_uvector workspace(worksize, stream); + value_idx lda = D, ldb = D, ldd = n_index_rows; + // Expanded metrics size the workspace up front and call again if worksize comes back larger; unexpanded metrics first call with an empty workspace and call again if a size is reported. + switch (metric) { + case cuvs::distance::DistanceType::L2SqrtExpanded: + case cuvs::distance::DistanceType::L2Expanded: + tempWorksize = + cuvs::distance::getWorkspaceSize(query, index, n_query_rows, n_index_rows, D); + worksize = tempWorksize; + workspace.resize(worksize, stream); + fusedL2ExpKnn(n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + query_norms, + index_norms, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); + if (worksize > tempWorksize) { + workspace.resize(worksize, stream); + fusedL2ExpKnn(n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + query_norms, + index_norms, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); + } + break; + case cuvs::distance::DistanceType::L2Unexpanded: + case cuvs::distance::DistanceType::L2SqrtUnexpanded: + fusedL2UnexpKnn(n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); + if (worksize) { + workspace.resize(worksize, stream); + fusedL2UnexpKnn(n_query_rows, + n_index_rows, + D, + lda, + ldb, + ldd, + query, + index, + sqrt, + out_dists, + out_inds, + k, + stream, + workspace.data(), + worksize); + } + break; + default: printf("only L2 distance metric is supported\n"); break; /* NOTE(review): any other metric only prints this message; out_dists and out_inds are left unwritten - confirm callers never reach this. */ + }; +} + +} // namespace 
cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/haversine_distance.cuh b/cpp/src/neighbors/detail/haversine_distance.cuh new file mode 100644 index 0000000000..fc6aa477d1 --- /dev/null +++ b/cpp/src/neighbors/detail/haversine_distance.cuh @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::detail { +template +DI value_t compute_haversine(value_t x1, value_t y1, value_t x2, value_t y2) +{ /* In haversine_knn_kernel, x1 and x2 are elements 0 and 1 of the query row, y1 and y2 elements 0 and 1 of the index row. NOTE(review): 0.5 is a double literal, so with value_t = float the half-differences are formed in double - confirm intended. */ + value_t sin_0 = raft::sin(0.5 * (x1 - y1)); + value_t sin_1 = raft::sin(0.5 * (x2 - y2)); + value_t rdist = sin_0 * sin_0 + raft::cos(x1) * raft::cos(y1) * sin_1 * sin_1; + + return 2 * raft::asin(raft::sqrt(rdist)); +} + +/** + * @tparam value_idx data type of indices + * @tparam value_t data type of values and distances + * @tparam warp_q forwarded to BlockSelect; also sizes the shared smemK / smemV buffers (kNumWarps * warp_q entries each) + * @tparam thread_q forwarded to BlockSelect + * @tparam tpb block size the kernel is written for: it is the loop stride, and kNumWarps = tpb / raft::WarpSize + * @param[out] out_inds output indices (k entries per query row) + * @param[out] out_dists output distances (k entries per query row) + * @param[in] index index array, read as two values per row + * @param[in] query query array, read as two values per row; each block handles the row given by blockIdx.x + * @param[in] n_index_rows number of rows in index array + * @param[in] k number of closest neighbors to return + */ +template +RAFT_KERNEL haversine_knn_kernel(value_idx* out_inds, + value_t* out_dists, + const value_t* index, + const value_t* query, + size_t n_index_rows, + int k) +{ + constexpr int kNumWarps = tpb / raft::WarpSize; + + __shared__ value_t 
smemK[kNumWarps * warp_q]; + __shared__ value_idx smemV[kNumWarps * warp_q]; + + using namespace raft::neighbors::detail::faiss_select; + BlockSelect, warp_q, thread_q, tpb> heap( + std::numeric_limits::max(), std::numeric_limits::max(), smemK, smemV, k); + + // Grid is exactly sized to rows available: block blockIdx.x handles query row blockIdx.x. NOTE(review): limit and i are int while n_index_rows is size_t - confirm row counts stay within int range. + int limit = raft::Pow2::roundDown(n_index_rows); + + const value_t* query_ptr = query + (blockIdx.x * 2); + value_t x1 = query_ptr[0]; + value_t x2 = query_ptr[1]; + + int i = threadIdx.x; + + for (; i < limit; i += tpb) { + const value_t* idx_ptr = index + (i * 2); + value_t y1 = idx_ptr[0]; + value_t y2 = idx_ptr[1]; + + value_t dist = compute_haversine(x1, y1, x2, y2); + + heap.add(dist, i); + } + + // Handle last remainder fraction of a warp of elements (rows between limit and n_index_rows) + if (i < n_index_rows) { + const value_t* idx_ptr = index + (i * 2); + value_t y1 = idx_ptr[0]; + value_t y2 = idx_ptr[1]; + + value_t dist = compute_haversine(x1, y1, x2, y2); + + heap.addThreadQ(dist, i); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += tpb) { + out_dists[blockIdx.x * k + i] = smemK[i]; + out_inds[blockIdx.x * k + i] = smemV[i]; + } +} + +/** + * Compute the k-nearest neighbors using the Haversine + * (great circle arc) distance. Input is assumed to have + * 2 dimensions (latitude, longitude) in radians. 
+ + * @tparam value_idx + * @tparam value_t + * @param[out] out_inds output indices array on device (size n_query_rows * k) + * @param[out] out_dists output dists array on device (size n_query_rows * k) + * @param[in] index input index array on device (size n_index_rows * 2) + * @param[in] query input query array on device (size n_query_rows * 2) + * @param[in] n_index_rows number of rows in index array + * @param[in] n_query_rows number of rows in query array + * @param[in] k number of closest neighbors to return + * @param[in] stream stream to order kernel launch + */ +template +void haversine_knn(value_idx* out_inds, + value_t* out_dists, + const value_t* index, + const value_t* query, + size_t n_index_rows, + size_t n_query_rows, + int k, + cudaStream_t stream) +{ + // ensure kernel does not breach shared memory limits: value types wider than 4 bytes get a 512-entry queue instead of 1024. NOTE(review): k is not checked against kWarpQ here, while the kernel copies smemK[i] / smemV[i] for every i < k - confirm callers bound k. + constexpr int kWarpQ = sizeof(value_t) > 4 ? 512 : 1024; + haversine_knn_kernel + <<>>(out_inds, out_dists, index, query, n_index_rows, k); +} + +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh new file mode 100644 index 0000000000..2f5aa176d8 --- /dev/null +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -0,0 +1,567 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include +#include + +#include "../../distance/detail/distance_ops/l2_exp.cuh" +#include "./faiss_distance_utils.h" +#include "./fused_l2_knn.cuh" +#include "./haversine_distance.cuh" +#include "./knn_merge_parts.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include +#include + +namespace cuvs::neighbors::detail { +/** + * Calculates brute force knn, using a fixed memory budget + * by tiling over both the rows and columns of pairwise_distances. max_row_tile_size / max_col_tile_size cap the chosen tile shape when non-zero; precomputed_index_norms / precomputed_search_norms, when non-null, are used instead of computing norms here (expanded L2 and cosine metrics). + */ +template +void tiled_brute_force_knn(const raft::resources& handle, + const ElementType* search, // size (m ,d) + const ElementType* index, // size (n ,d) + size_t m, + size_t n, + size_t d, + size_t k, + ElementType* distances, // size (m, k) + IndexType* indices, // size (m, k) + cuvs::distance::DistanceType metric, + float metric_arg = 2.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0, + const ElementType* precomputed_index_norms = nullptr, + const ElementType* precomputed_search_norms = nullptr) +{ + // Figure out the number of rows/cols to tile for. NOTE(review): device_memory below is not referenced again in this function. + size_t tile_rows = 0; + size_t tile_cols = 0; + auto stream = raft::resource::get_cuda_stream(handle); + auto device_memory = raft::resource::get_workspace_resource(handle); + auto total_mem = rmm::available_device_memory().second; + + cuvs::neighbors::detail::faiss_select::chooseTileSize( + m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols); + + // for unit testing, it's convenient to be able to put a max size on the tiles + // so we can test the tiling logic without having to use huge inputs.
+ if (max_row_tile_size && (tile_rows > max_row_tile_size)) { tile_rows = max_row_tile_size; } + if (max_col_tile_size && (tile_cols > max_col_tile_size)) { tile_cols = max_col_tile_size; } + + // tile_cols must be at least k items + tile_cols = std::max(tile_cols, k); + + // stores pairwise distances for the current tile + rmm::device_uvector temp_distances(tile_rows * tile_cols, stream); + + // calculate norms for L2 expanded and cosine expanded distances - this lets us avoid calculating + // norms repeatedly per-tile, and just do it once for the entire input + auto pairwise_metric = metric; + rmm::device_uvector search_norms(0, stream); + rmm::device_uvector index_norms(0, stream); + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded) { + if (!precomputed_search_norms) { search_norms.resize(m, stream); } + if (!precomputed_index_norms) { index_norms.resize(n, stream); } + // cosine needs the l2norm, whereas l2 distances need the squared norm + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + if (!precomputed_search_norms) { + raft::linalg::rowNorm(search_norms.data(), + search, + d, + m, + raft::linalg::NormType::L2Norm, + true, + stream, + raft::sqrt_op{}); + } + if (!precomputed_index_norms) { + raft::linalg::rowNorm(index_norms.data(), + index, + d, + n, + raft::linalg::NormType::L2Norm, + true, + stream, + raft::sqrt_op{}); + } + } else { + if (!precomputed_search_norms) { + raft::linalg::rowNorm( + search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); + } + if (!precomputed_index_norms) { + raft::linalg::rowNorm( + index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); + } + } + pairwise_metric = cuvs::distance::DistanceType::InnerProduct; + } + + // if we're tiling over columns, we need additional buffers for temporary output + // distances/indices + size_t num_col_tiles = 
raft::ceildiv(n, tile_cols); + size_t temp_out_cols = k * num_col_tiles; + + // the final column tile could have fewer than 'k' items in it + // in which case the number of columns here is too high in the temp output. + // adjust if necessary + auto last_col_tile_size = n % tile_cols; + if (last_col_tile_size && (last_col_tile_size < k)) { temp_out_cols -= k - last_col_tile_size; } + + // if we have fewer than k items in the index, we should fill out the result + // to indicate that we are missing items (and match behaviour in faiss): distances get lowest() and, for signed IndexType, indices get -1 + if (n < k) { + raft::matrix::fill(handle, + raft::make_device_matrix_view(distances, m, k), + std::numeric_limits::lowest()); + + if constexpr (std::is_signed_v) { + raft::matrix::fill(handle, raft::make_device_matrix_view(indices, m, k), IndexType{-1}); + } + } + + rmm::device_uvector temp_out_distances(tile_rows * temp_out_cols, stream); + rmm::device_uvector temp_out_indices(tile_rows * temp_out_cols, stream); + + bool select_min = cuvs::distance::is_min_close(metric); + + for (size_t i = 0; i < m; i += tile_rows) { + size_t current_query_size = std::min(tile_rows, m - i); + + for (size_t j = 0; j < n; j += tile_cols) { + size_t current_centroid_size = std::min(tile_cols, n - j); + size_t current_k = std::min(current_centroid_size, k); + + // calculate the top-k elements for the current tile, by calculating the + // full pairwise distance for the tile - and then selecting the top-k from that + cuvs::distance::pairwise_distance( + handle, + raft::make_device_matrix_view( + search + i * d, current_query_size, d), + raft::make_device_matrix_view( + index + j * d, current_centroid_size, d), + raft::make_device_matrix_view( + temp_distances.data(), current_query_size, current_centroid_size), + pairwise_metric, + metric_arg); + + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + auto row_norms = precomputed_search_norms ? 
precomputed_search_norms : search_norms.data(); + auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); + auto dist = temp_distances.data(); + bool sqrt = metric == cuvs::distance::DistanceType::L2SqrtExpanded; + + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(dist, current_query_size * current_centroid_size), + [=] __device__(IndexType idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); + + cuvs::distance::detail::ops::l2_exp_cutlass_op l2_op(sqrt); + return l2_op(row_norms[row], col_norms[col], dist[idx]); + }); + } else if (metric == cuvs::distance::DistanceType::CosineExpanded) { + auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); + auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); + auto dist = temp_distances.data(); + + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(dist, current_query_size * current_centroid_size), + [=] __device__(IndexType idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); + auto val = 1.0 - dist[idx] / (row_norms[row] * col_norms[col]); + return val; + }); + } + + raft::matrix::select_k( + handle, + raft::make_device_matrix_view( + temp_distances.data(), current_query_size, current_centroid_size), + std::nullopt, + raft::make_device_matrix_view( + distances + i * k, current_query_size, current_k), + raft::make_device_matrix_view( + indices + i * k, current_query_size, current_k), + select_min, + true); + + // if we're tiling over columns, we need to do a couple of things to fix up + // the output of select_k + // 1. The column ids in the output are relative to the tile, so we need + // to adjust the column ids by adding the column the tile starts at (j) + // 2. 
select_k writes out output in a row-major format, which means we + // can't just concat the output of all the tiles and do a select_k on the + // concatenation. + // Fix both of these problems in a single pass here + if (tile_cols != n) { + const ElementType* in_distances = distances + i * k; + const IndexType* in_indices = indices + i * k; + ElementType* out_distances = temp_out_distances.data(); + IndexType* out_indices = temp_out_indices.data(); + + auto count = thrust::make_counting_iterator(0); + thrust::for_each(raft::resource::get_thrust_policy(handle), + count, + count + current_query_size * current_k, + [=] __device__(IndexType i) { + IndexType row = i / current_k, col = i % current_k; + IndexType out_index = row * temp_out_cols + j * k / tile_cols + col; + + out_distances[out_index] = in_distances[i]; + out_indices[out_index] = in_indices[i] + j; + }); + } + } + + if (tile_cols != n) { + // select the actual top-k items here from the temporary output + raft::matrix::select_k( + handle, + raft::make_device_matrix_view( + temp_out_distances.data(), current_query_size, temp_out_cols), + raft::make_device_matrix_view( + temp_out_indices.data(), current_query_size, temp_out_cols), + raft::make_device_matrix_view( + distances + i * k, current_query_size, k), + raft::make_device_matrix_view( + indices + i * k, current_query_size, k), + select_min, + true); + } + } +} + +/** + * Search the kNN for the k-nearest neighbors of a set of query vectors. NOTE(review): the userStream, internalStreams and n_int_streams entries below do not match any parameter of brute_force_knn_impl; the function obtains its streams from handle. + * @param[in] input vector of device memory array pointers to search + * @param[in] sizes vector of memory sizes for each device array pointer in input + * @param[in] D number of cols in input and search_items + * @param[in] search_items set of vectors to query for neighbors + * @param[in] n number of items in search_items + * @param[out] res_I pointer to device memory for returning k nearest indices + * @param[out] res_D pointer to device memory for returning k nearest distances + * @param[in] k number of 
neighbors to query + * @param[in] userStream the main cuda stream to use + * @param[in] internalStreams optional when n_params > 0, the index partitions can be + * queried in parallel using these streams. Note that n_int_streams also + * has to be > 0 for these to be used and their cardinality does not need + * to correspond to n_parts. + * @param[in] n_int_streams size of internalStreams. When this is <= 0, only the + * user stream will be used. + * @param[in] rowMajorIndex are the index arrays in row-major layout? + * @param[in] rowMajorQuery are the query array in row-major layout? + * @param[in] translations translation ids for indices when index rows represent + * non-contiguous partitions + * @param[in] metric corresponds to the cuvs::distance::DistanceType enum (default is L2Expanded) + * @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm + */ +template +void brute_force_knn_impl( + raft::resources const& handle, + std::vector& input, + std::vector& sizes, + IntType D, + value_t* search_items, + IntType n, + IdxType* res_I, + value_t* res_D, + IntType k, + bool rowMajorIndex = true, + bool rowMajorQuery = true, + std::vector* translations = nullptr, + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded, + float metricArg = 0, + std::vector* input_norms = nullptr, + const value_t* search_norms = nullptr) +{ + auto userStream = raft::resource::get_cuda_stream(handle); + + ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size"); + + std::vector* id_ranges; + if (translations == nullptr) { + // If we don't have explicit translations + // for offsets of the indices, build them + // from the local partitions. NOTE(review): raw new, paired only with the delete at the very end of the function; an earlier exit (for instance if one of the RAFT_CUDA_TRY checks throws - confirm) would leak it. + id_ranges = new std::vector(); + IdxType total_n = 0; + for (size_t i = 0; i < input.size(); i++) { + id_ranges->push_back(total_n); + total_n += sizes[i]; + } + } else { + // otherwise, use the given translations + id_ranges = translations; + } + + int device; + 
RAFT_CUDA_TRY(cudaGetDevice(&device)); + + rmm::device_uvector trans(id_ranges->size(), userStream); + raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream); + + rmm::device_uvector all_D(0, userStream); + rmm::device_uvector all_I(0, userStream); + + value_t* out_D = res_D; + IdxType* out_I = res_I; + + if (input.size() > 1) { + all_D.resize(input.size() * k * n, userStream); + all_I.resize(input.size() * k * n, userStream); + + out_D = all_D.data(); + out_I = all_I.data(); + } + + // currently we don't support col_major inside tiled_brute_force_knn, because + // of limitations of the pairwise_distance API: + // 1) pairwise_distance takes a single 'isRowMajor' parameter - and we have + // multiple options here (like rowMajorQuery/rowMajorIndex) + // 2) because of tiling, we need to be able to set a custom stride in the PW + // api, which isn't supported + // Instead, transpose the input matrices if they are passed as col-major. + auto search = search_items; + rmm::device_uvector search_row_major(0, userStream); + if (!rowMajorQuery) { + search_row_major.resize(n * D, userStream); + raft::linalg::transpose(handle, search, search_row_major.data(), n, D, userStream); + search = search_row_major.data(); + } + + // reserve one temporary buffer for all transposed index partitions if necessary (each partition is transposed inside the loop below) + rmm::device_uvector index_row_major(0, userStream); + if (!rowMajorIndex) { + size_t total_size = 0; + for (auto size : sizes) { + total_size += size; + } + index_row_major.resize(total_size * D, userStream); + } + + // Make other streams from pool wait on main stream + raft::resource::wait_stream_pool_on_stream(handle); + + size_t total_rows_processed = 0; + for (size_t i = 0; i < input.size(); i++) { + value_t* out_d_ptr = out_D + (i * k * n); + IdxType* out_i_ptr = out_I + (i * k * n); + + auto stream = raft::resource::get_next_usable_stream(handle, i); + + if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && + (metric == 
cuvs::distance::DistanceType::L2Unexpanded || + metric == cuvs::distance::DistanceType::L2SqrtUnexpanded || + metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded)) { + fusedL2Knn(D, + out_i_ptr, + out_d_ptr, + input[i], + search_items, + sizes[i], + n, + k, + rowMajorIndex, + rowMajorQuery, + stream, + metric, + input_norms ? (*input_norms)[i] : nullptr, + search_norms); + + // Perform necessary post-processing (fusedL2Knn always runs with sqrt = false, so the sqrt metrics are finished here). NOTE(review): the unaryOp below reads and writes res_D rather than this partition's out_d_ptr; when input.size() > 1 the partition results live in all_D and res_D is only written later by knn_merge_parts - confirm. The LpUnexpanded test also looks unreachable, since the enclosing branch only admits the four L2 metrics. + if (metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::L2SqrtUnexpanded || + metric == cuvs::distance::DistanceType::LpUnexpanded) { + value_t p = 0.5; // standard l2 + if (metric == cuvs::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; + raft::linalg::unaryOp( + res_D, + res_D, + n * k, + [p] __device__(value_t input) { return powf(fabsf(input), p); }, + stream); + } + } else { + switch (metric) { + case cuvs::distance::DistanceType::Haversine: + ASSERT(D == 2, + "Haversine distance requires 2 dimensions " + "(latitude / longitude)."); + + haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); + break; + default: + // Create a new handle with the current stream from the stream pool + raft::resources stream_pool_handle(handle); + raft::resource::set_cuda_stream(stream_pool_handle, stream); + + auto index = input[i]; + if (!rowMajorIndex) { + index = index_row_major.data() + total_rows_processed * D; + total_rows_processed += sizes[i]; + raft::linalg::transpose(handle, input[i], index, sizes[i], D, stream); + } + + tiled_brute_force_knn(stream_pool_handle, + search, + index, + n, + sizes[i], + D, + k, + out_d_ptr, + out_i_ptr, + metric, + metricArg, + 0, + 0, + input_norms ? (*input_norms)[i] : nullptr, + search_norms); + break; + } + } + + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + // Sync internal streams if used. 
We don't need to + // sync the user stream because we'll already have + // fully serial execution. + raft::resource::sync_stream_pool(handle); + + if (input.size() > 1 || translations != nullptr) { + // This is necessary for proper index translations. If there are + // no translations or partitions to combine, it can be skipped. + knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data()); + } + + if (translations == nullptr) delete id_ranges; /* frees the offsets vector built above when the caller passed no translations */ +}; + +template +void brute_force_search( + raft::resources const& res, + const cuvs::neighbors::brute_force::index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances, + std::optional> query_norms = std::nullopt) +{ + RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); + RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), + "Number of columns in queries must match brute force index"); + + auto k = neighbors.extent(1); + auto d = idx.dataset().extent(1); + + std::vector dataset = {const_cast(idx.dataset().data_handle())}; + std::vector sizes = {idx.dataset().extent(0)}; + std::vector norms; + if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } + + brute_force_knn_impl(res, + dataset, + sizes, + d, + const_cast(queries.data_handle()), + queries.extent(0), + neighbors.data_handle(), + distances.data_handle(), + k, + true, + true, + nullptr, + idx.metric(), + idx.metric_arg(), + norms.size() ? &norms : nullptr, + query_norms ? 
query_norms->data_handle() : nullptr); +} + +template +cuvs::neighbors::brute_force::index build( + raft::resources const& res, + raft::device_matrix_view dataset, + cuvs::distance::DistanceType metric, + T metric_arg) +{ + // certain distance metrics can benefit from pre-calculating the norms for the index dataset + // which lets us avoid calculating these at query time. NOTE(review): the index returned below is constructed from `dataset`, not from dataset_view, so a copy placed in dataset_storage by the else-branch lives only until this function returns - confirm that branch is reachable and intended. + std::optional> norms; + auto dataset_storage = std::optional>{}; + auto dataset_view = [&res, &dataset_storage, dataset]() { + if constexpr (std::is_same_v>) { + return dataset; + } else { + dataset_storage = + raft::make_device_matrix(res, dataset.extent(0), dataset.extent(1)); + raft::copy(res, dataset_storage->view(), dataset); + return raft::make_const_mdspan(dataset_storage->view()); + } + }(); + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded || + metric == cuvs::distance::DistanceType::CosineExpanded) { + norms = raft::make_device_vector(res, dataset.extent(0)); + // cosine needs the l2norm, whereas l2 distances need the squared norm + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm(res, + dataset_view, + norms->view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op{}); + } else { + raft::linalg::norm(res, + dataset_view, + norms->view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS); + } + } + + return cuvs::neighbors::brute_force::index(res, dataset, std::move(norms), metric, metric_arg); +} +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/knn_merge_parts.cuh b/cpp/src/neighbors/detail/knn_merge_parts.cuh new file mode 100644 index 0000000000..0e9410c7b4 --- /dev/null +++ b/cpp/src/neighbors/detail/knn_merge_parts.cuh @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace cuvs::neighbors::detail { + +template +RAFT_KERNEL knn_merge_parts_kernel(const value_t* inK, + const value_idx* inV, + value_t* outK, + value_idx* outV, + size_t n_samples, + int n_parts, + value_t initK, + value_idx initV, + int k, + value_idx* translations) +{ + constexpr int kNumWarps = tpb / raft::WarpSize; + + __shared__ value_t smemK[kNumWarps * warp_q]; + __shared__ value_idx smemV[kNumWarps * warp_q]; + + /** + * Uses shared memory: smemK / smemV (kNumWarps * warp_q entries each) are handed to BlockSelect as its key and value buffers + */ + raft::neighbors::detail::faiss_select::BlockSelect< + value_t, + value_idx, + false, + raft::neighbors::detail::faiss_select::Comparator, + warp_q, + thread_q, + tpb> + heap(initK, initV, smemK, smemV, k); + + // Grid is exactly sized to rows available: block blockIdx.x merges output row `row` from the k * n_parts candidates of that row + int row = blockIdx.x; + int total_k = k * n_parts; + + int i = threadIdx.x; + + // Get starting pointers for cols in current thread (candidate i belongs to partition i / k, column i % k; consecutive partitions are n_samples * k elements apart). NOTE(review): row * k is evaluated in int before being widened to size_t - confirm n_samples * k stays within int range. + int part = i / k; + size_t row_idx = (row * k) + (part * n_samples * k); + + int col = i % k; + + const value_t* inKStart = inK + (row_idx + col); + const value_idx* inVStart = inV + (row_idx + col); + + int limit = raft::Pow2::roundDown(total_k); + value_idx translation = 0; + + for (; i < limit; i += tpb) { + translation = translations[part]; + heap.add(*inKStart, (*inVStart) + translation); + + part = (i + tpb) / k; + row_idx = (row * k) + (part * n_samples * k); + + col = (i 
+ tpb) % k; + + inKStart = inK + (row_idx + col); + inVStart = inV + (row_idx + col); + } + + // Handle last remainder fraction of a warp of elements (candidates between limit and total_k) + if (i < total_k) { + translation = translations[part]; + heap.addThreadQ(*inKStart, (*inVStart) + translation); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += tpb) { + outK[row * k + i] = smemK[i]; + outV[row * k + i] = smemV[i]; + } +} + +template +inline void knn_merge_parts_impl(const value_t* inK, + const value_idx* inV, + value_t* outK, + value_idx* outV, + size_t n_samples, + int n_parts, + int k, + cudaStream_t stream, + value_idx* translations) +{ + auto grid = dim3(n_samples); /* one block per sample row: the kernel uses blockIdx.x as the row */ + + constexpr int n_threads = (warp_q < 1024) ? 128 : 64; + auto block = dim3(n_threads); + + auto kInit = std::numeric_limits::max(); + auto vInit = -1; + knn_merge_parts_kernel + <<>>( + inK, inV, outK, outV, n_samples, n_parts, kInit, vInit, k, translations); + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +/** + * @brief Merge knn distances and index matrix, which have been partitioned + * by row, into a single matrix with only the k-nearest neighbors. Dispatches on k up to 1024; larger k reaches the THROW at the end of the function.
+ * + * @param inK partitioned knn distance matrix + * @param inV partitioned knn index matrix + * @param outK merged knn distance matrix + * @param outV merged knn index matrix + * @param n_samples number of samples per partition + * @param n_parts number of partitions + * @param k number of neighbors per partition (also number of merged neighbors) + * @param stream CUDA stream to use + * @param translations mapping of index offsets for each partition + */ +template +inline void knn_merge_parts(const value_t* inK, + const value_idx* inV, + value_t* outK, + value_idx* outV, + size_t n_samples, + int n_parts, + int k, + cudaStream_t stream, + value_idx* translations) +{ + if (k == 1) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 32) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 64) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 128) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 256) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 512) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else if (k <= 1024) + knn_merge_parts_impl( + inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); + else + THROW("Unimplemented for k=%d, knn_merge_parts works for k<=1024", k); +} +} // namespace cuvs::neighbors::detail