Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,10 @@ if(BUILD_SHARED_LIBS)
src/neighbors/ivf_flat/ivf_flat_search_half_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_search_int8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_search_uint8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_float_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_half_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_int8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_interleaved_scan_uint8_t_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_half_int64_t.cu
src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu
Expand Down
16 changes: 9 additions & 7 deletions cpp/src/neighbors/ivf_common.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -148,17 +148,19 @@ __device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT
return ix_min;
}

template <int BlockDim, typename IdxT>
template <int BlockDim, typename IdxT, typename DbIdxT>
__launch_bounds__(BlockDim) RAFT_KERNEL
postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk]
const uint32_t* neighbors_in, // [n_queries, topk]
const IdxT* const* db_indices, // [n_clusters][..]
const DbIdxT* const* db_indices, // [n_clusters][..]
Comment thread
tfeher marked this conversation as resolved.
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
const uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t n_queries,
uint32_t n_probes,
uint32_t topk)
{
static_assert(!raft::is_narrowing_v<uint32_t, IdxT>,
"IdxT must be able to represent all values of uint32_t");
const uint64_t i = threadIdx.x + BlockDim * uint64_t(blockIdx.x);
const uint32_t query_ix = i / uint64_t(topk);
if (query_ix >= n_queries) { return; }
Expand All @@ -170,8 +172,8 @@ __launch_bounds__(BlockDim) RAFT_KERNEL
uint32_t data_ix = neighbors_in[k];
const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices);
const bool valid = chunk_ix < n_probes;
neighbors_out[k] =
valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord<IdxT>;
neighbors_out[k] = valid ? static_cast<IdxT>(db_indices[clusters_to_probe[chunk_ix]][data_ix])
: kOutOfBoundsRecord<IdxT>;
}

/**
Expand All @@ -181,10 +183,10 @@ __launch_bounds__(BlockDim) RAFT_KERNEL
* probed clusters / defined by the `chunk_indices`.
* We assume the searched sample sizes (for a single query) fit into `uint32_t`.
*/
template <typename IdxT>
template <typename IdxT, typename DbIdxT>
void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk]
const uint32_t* neighbors_in, // [n_queries, topk]
const IdxT* const* db_indices, // [n_clusters][..]
const DbIdxT* const* db_indices, // [n_clusters][..]
const uint32_t* clusters_to_probe, // [n_queries, n_probes]
const uint32_t* chunk_indices, // [n_queries, n_probes]
uint32_t n_queries,
Expand Down
18 changes: 11 additions & 7 deletions cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -100,10 +100,11 @@ auto clone(const raft::resources& res, const index<T, IdxT>& source) -> index<T,
* there are no dependencies between threads, hence no constraints on the block size.
*
* @tparam T element type.
* @tparam IdxT type of the indices in the source source_vecs
* @tparam IdxT type of the vector ids in the index (corresponds to second arg of index<T, IdxT>)
* @tparam LabelT label type
* @tparam gather_src if false, then we build the index from vectors source_vecs[i,:], otherwise
* we use source_vecs[source_ixs[i],:]. In both cases i=0..n_rows-1.
* @tparam SourceIdxT type of the entries in source_ixs (usually same as IdxT)
*
* @param[in] labels device pointer to the cluster ids for each row [n_rows]
* @param[in] source_vecs device pointer to the input data [n_rows, dim]
Expand All @@ -118,10 +119,10 @@ auto clone(const raft::resources& res, const index<T, IdxT>& source) -> index<T,
* @param veclen size of vectorized loads/stores; must satisfy `dim % veclen == 0`.
*
*/
template <typename T, typename IdxT, typename LabelT, bool gather_src = false>
template <typename T, typename IdxT, typename LabelT, bool gather_src = false, typename SourceIdxT>
Comment thread
tfeher marked this conversation as resolved.
Outdated
RAFT_KERNEL build_index_kernel(const LabelT* labels,
const T* source_vecs,
const IdxT* source_ixs,
const SourceIdxT* source_ixs,
T** list_data_ptrs,
IdxT** list_index_ptrs,
uint32_t* list_sizes_ptr,
Expand All @@ -135,7 +136,10 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels,
auto source_ix = source_ixs == nullptr ? i + batch_offset : source_ixs[i];
// In the context of refinement, some indices may be invalid (the generating NN algorithm does
// not return enough valid items). Do not add the item to the index in this case.
if (source_ix == ivf::kInvalidRecord<IdxT> || source_ix == raft::upper_bound<IdxT>()) { return; }
if (source_ix == ivf::kInvalidRecord<SourceIdxT> ||
source_ix == raft::upper_bound<SourceIdxT>()) {
return;
}
Comment thread
tfeher marked this conversation as resolved.

auto list_id = labels[i];
auto inlist_id = atomicAdd(list_sizes_ptr + list_id, 1);
Expand Down Expand Up @@ -460,11 +464,11 @@ inline auto build(raft::resources const& handle,
* @param[in] candidate_idx device pointer to neighbor candidates, size [n_queries, n_candidates]
* @param[in] n_candidates of neighbor_candidates
*/
template <typename T, typename IdxT>
template <typename T, typename IdxT, typename CandidateIdxT>
inline void fill_refinement_index(raft::resources const& handle,
index<T, IdxT>* refinement_index,
const T* dataset,
const IdxT* candidate_idx,
const CandidateIdxT* candidate_idx,
IdxT n_queries,
uint32_t n_candidates)
{
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -40,11 +40,6 @@ using namespace cuvs::spatial::knn::detail; // NOLINT

constexpr int kThreadsPerBlock = 128;

/**
 * @brief Tell whether `k` is within the warpsort capacity limit.
 *
 * @param[in] k the value to compare against the limit
 * @return true when `k` does not exceed
 *         raft::matrix::detail::select::warpsort::kMaxCapacity, false otherwise.
 */
auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k) -> bool
{
  const bool exceeds_capacity = k > raft::matrix::detail::select::warpsort::kMaxCapacity;
  return !exceeds_capacity;
}

Comment thread
tfeher marked this conversation as resolved.
/**
* @brief Copy `n` elements per block from one place to another.
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "../detail/ann_utils.cuh"
#include "ivf_flat_interleaved_scan.cuh"
#include <cstdint>
#include <cuvs/neighbors/common.hpp>
#include <cuvs/neighbors/ivf_flat.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

// Expands to an explicit instantiation definition (`template void ...`) of
// ivfflat_interleaved_scan for:
//   T             - first template argument, also the element type of `queries` and of the index
//   IdxT          - third template argument, also the second argument of index<T, IdxT>
//   SampleFilterT - fourth template argument and the type of `sample_filter`
// The second template argument is not a macro parameter: it is always taken from
// cuvs::spatial::knn::detail::utils::config<T>::value_t.
// `index` and `ivfflat_interleaved_scan` are spelled unqualified, so the macro has to be expanded
// in a scope where both names are visible.
// NOTE: no comments are placed between the continuation lines below; a `//` comment before a
// trailing backslash would swallow the rest of the macro.
#define CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(T, IdxT, SampleFilterT) \
template void \
ivfflat_interleaved_scan<T, \
typename cuvs::spatial::knn::detail::utils::config<T>::value_t, \
IdxT, \
SampleFilterT>(const index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const uint32_t queries_offset, \
const cuvs::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const uint32_t max_samples, \
const uint32_t* chunk_indices, \
const bool select_min, \
SampleFilterT sample_filter, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream);

// Stands in for a literal comma inside a macro argument, so that a type with several template
// arguments (e.g. `bitset_filter<uint32_t COMMA int64_t>`) is passed as one macro argument
// instead of being split at the comma.
#define COMMA ,
96 changes: 96 additions & 0 deletions cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstdint>
#include <cuda_fp16.h>

#include "../detail/ann_utils.cuh"
#include <cuvs/neighbors/common.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
#include <raft/util/raft_explicit.hpp>

namespace cuvs::neighbors::ivf_flat::detail {
/**
 * @brief Declaration of the IVF-Flat interleaved scan entry point.
 *
 * Only the signature is given here; the function body is supplied by the RAFT_EXPLICIT macro
 * from <raft/util/raft_explicit.hpp>.
 * NOTE(review): presumably this header is meant to be paired with the `extern template`
 * declarations that follow, so that including translation units do not instantiate the
 * template themselves - confirm against the RAFT_EXPLICIT definition.
 *
 * The meaning of the individual arguments is not visible from this declaration; consult the
 * definition of the template (assumed to live in ivf_flat_interleaved_scan.cuh - TODO confirm).
 *
 * @tparam T                element type of `queries` and first argument of index<T, IdxT>
 * @tparam AccT             second template argument; the instantiation macro in this header
 *                          passes cuvs::spatial::knn::detail::utils::config<T>::value_t here
 * @tparam IdxT             second argument of index<T, IdxT>
 * @tparam IvfSampleFilterT type of `sample_filter`
 *
 * @param grid_dim_x passed by non-const reference, so the callee may write to it
 * @param stream     rmm::cuda_stream_view handed to the implementation
 */
template <typename T, typename AccT, typename IdxT, typename IvfSampleFilterT>
void ivfflat_interleaved_scan(const index<T, IdxT>& index,
const T* queries,
const uint32_t* coarse_query_results,
const uint32_t n_queries,
const uint32_t queries_offset,
const cuvs::distance::DistanceType metric,
const uint32_t n_probes,
const uint32_t k,
const uint32_t max_samples,
const uint32_t* chunk_indices,
const bool select_min,
IvfSampleFilterT sample_filter,
uint32_t* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream) RAFT_EXPLICIT;

#define CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(T, IdxT, SampleFilterT) \
Comment thread
tfeher marked this conversation as resolved.
extern template void \
ivfflat_interleaved_scan<T, \
typename cuvs::spatial::knn::detail::utils::config<T>::value_t, \
IdxT, \
SampleFilterT>(const index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const uint32_t queries_offset, \
const cuvs::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const uint32_t max_samples, \
const uint32_t* chunk_indices, \
const bool select_min, \
SampleFilterT sample_filter, \
uint32_t* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream);

// Extern-template declarations for every supported combination with IdxT = int64_t:
// element types float, half, int8_t and uint8_t.

// Variants using none_sample_filter.
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(float, int64_t, cuvs::neighbors::filtering::none_sample_filter);

CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(half, int64_t, cuvs::neighbors::filtering::none_sample_filter);

CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(int8_t,
int64_t,
cuvs::neighbors::filtering::none_sample_filter);

CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(uint8_t,
int64_t,
cuvs::neighbors::filtering::none_sample_filter);

// Variants using bitset_filter<uint32_t, int64_t>. COMMA keeps the two template arguments of
// bitset_filter inside a single macro argument.
#define COMMA ,
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(
float, int64_t, cuvs::neighbors::filtering::bitset_filter<uint32_t COMMA int64_t>);

CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(
half, int64_t, cuvs::neighbors::filtering::bitset_filter<uint32_t COMMA int64_t>);

CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(
int8_t, int64_t, cuvs::neighbors::filtering::bitset_filter<uint32_t COMMA int64_t>);

CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(
uint8_t, int64_t, cuvs::neighbors::filtering::bitset_filter<uint32_t COMMA int64_t>);
// Both helper macros are local to this header; remove them so they do not leak into includers.
#undef COMMA
#undef CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN

} // namespace cuvs::neighbors::ivf_flat::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ivf_flat_interleaved_scan_explicit_inst.cuh"

namespace cuvs::neighbors::ivf_flat::detail {

// Explicit instantiations of ivfflat_interleaved_scan for T = float, IdxT = int64_t:
// first with filtering::none_sample_filter, then with filtering::bitset_filter<uint32_t, int64_t>
// (COMMA expands to `,` and keeps the filter type a single macro argument).
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(float, int64_t, filtering::none_sample_filter);
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(float,
int64_t,
filtering::bitset_filter<uint32_t COMMA int64_t>);
Comment thread
tfeher marked this conversation as resolved.
Outdated
} // namespace cuvs::neighbors::ivf_flat::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ivf_flat_interleaved_scan_explicit_inst.cuh"
#include <cuda_fp16.h>

namespace cuvs::neighbors::ivf_flat::detail {

// Explicit instantiations of ivfflat_interleaved_scan for T = half, IdxT = int64_t:
// first with filtering::none_sample_filter, then with filtering::bitset_filter<uint32_t, int64_t>
// (COMMA expands to `,` and keeps the filter type a single macro argument).
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(half, int64_t, filtering::none_sample_filter);
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(half,
int64_t,
filtering::bitset_filter<uint32_t COMMA int64_t>);

} // namespace cuvs::neighbors::ivf_flat::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ivf_flat_interleaved_scan_explicit_inst.cuh"

namespace cuvs::neighbors::ivf_flat::detail {

// Explicit instantiations of ivfflat_interleaved_scan for T = int8_t, IdxT = int64_t:
// first with filtering::none_sample_filter, then with filtering::bitset_filter<uint32_t, int64_t>
// (COMMA expands to `,` and keeps the filter type a single macro argument).
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(int8_t, int64_t, filtering::none_sample_filter);
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(int8_t,
int64_t,
filtering::bitset_filter<uint32_t COMMA int64_t>);

// The helper macros are not needed past this point.
// NOTE(review): the float, half and uint8_t instantiation files do not undefine them - consider
// making the four files consistent.
#undef COMMA
#undef CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN

} // namespace cuvs::neighbors::ivf_flat::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ivf_flat_interleaved_scan_explicit_inst.cuh"

namespace cuvs::neighbors::ivf_flat::detail {

// Explicit instantiations of ivfflat_interleaved_scan for T = uint8_t, IdxT = int64_t:
// first with filtering::none_sample_filter, then with filtering::bitset_filter<uint32_t, int64_t>
// (COMMA expands to `,` and keeps the filter type a single macro argument).
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(uint8_t, int64_t, filtering::none_sample_filter);
CUVS_INST_IVF_FLAT_INTERLEAVED_SCAN(uint8_t,
int64_t,
filtering::bitset_filter<uint32_t COMMA int64_t>);

} // namespace cuvs::neighbors::ivf_flat::detail
Loading