Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -230,20 +230,15 @@ struct search : public search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
num_queries,
dev_seed_ptr,
num_executed_iterations,
*this,
topk,
thread_block_size,
result_buffer_size,
smem_size,
hash_bitlen,
hashmap.data(),
num_cta_per_query,
num_random_samplings,
rand_xor_mask,
num_seeds,
itopk_size,
search_width,
min_iterations,
max_iterations,
sample_filter,
this->metric,
stream);
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,15 @@ namespace cuvs::neighbors::cagra::detail::multi_cta_search {
const uint32_t num_queries, \
const typename DATASET_DESC_T::INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
typename DATASET_DESC_T::INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down
19 changes: 7 additions & 12 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ void select_and_run(
const uint32_t num_queries,
const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds]
uint32_t* const num_executed_iterations, // [num_queries,]
const search_params& ps,
uint32_t topk,
// multi_cta_search (params struct)
uint32_t block_size, //
Expand All @@ -466,13 +467,7 @@ void select_and_run(
int64_t hash_bitlen,
typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cuvs::distance::DistanceType metric,
cudaStream_t stream)
Expand Down Expand Up @@ -507,16 +502,16 @@ void select_and_run(
queries_ptr,
graph.data_handle(),
graph.extent(1),
num_random_samplings,
rand_xor_mask,
ps.num_random_samplings,
ps.rand_xor_mask,
dev_seed_ptr,
num_seeds,
hashmap_ptr,
hash_bitlen,
itopk_size,
search_width,
min_iterations,
max_iterations,
ps.itopk_size,
ps.search_width,
ps.min_iterations,
ps.max_iterations,
num_executed_iterations,
sample_filter,
metric);
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_single_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
num_queries,
dev_seed_ptr,
num_executed_iterations,
*this,
topk,
num_itopk_candidates,
static_cast<uint32_t>(thread_block_size),
Expand All @@ -241,13 +242,7 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
hashmap.data(),
small_hash_bitlen,
small_hash_reset_interval,
num_random_samplings,
rand_xor_mask,
num_seeds,
itopk_size,
search_width,
min_iterations,
max_iterations,
sample_filter,
this->metric,
stream);
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search {
const uint32_t num_queries, \
const typename DATASET_DESC_T::INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t num_itopk_candidates, \
uint32_t block_size, \
Expand All @@ -40,13 +41,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search {
typename DATASET_DESC_T::INDEX_T* hashmap_ptr, \
size_t small_hash_bitlen, \
size_t small_hash_reset_interval, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down
21 changes: 8 additions & 13 deletions cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -919,6 +919,7 @@ void select_and_run(
const uint32_t num_queries,
const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds]
uint32_t* const num_executed_iterations, // [num_queries,]
const search_params& ps,
uint32_t topk,
uint32_t num_itopk_candidates,
uint32_t block_size, //
Expand All @@ -927,20 +928,14 @@ void select_and_run(
typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr,
size_t small_hash_bitlen,
size_t small_hash_reset_interval,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cuvs::distance::DistanceType metric,
cudaStream_t stream)
{
auto kernel =
search_kernel_config<TEAM_SIZE, DATASET_BLOCK_DIM, DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T>::
choose_itopk_and_mx_candidates(itopk_size, num_itopk_candidates, block_size);
choose_itopk_and_mx_candidates(ps.itopk_size, num_itopk_candidates, block_size);
RAFT_CUDA_TRY(cudaFuncSetAttribute(kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte));
Expand All @@ -955,15 +950,15 @@ void select_and_run(
queries_ptr,
graph.data_handle(),
graph.extent(1),
num_random_samplings,
rand_xor_mask,
ps.num_random_samplings,
ps.rand_xor_mask,
dev_seed_ptr,
num_seeds,
hashmap_ptr,
itopk_size,
search_width,
min_iterations,
max_iterations,
ps.itopk_size,
ps.search_width,
ps.min_iterations,
ps.max_iterations,
num_executed_iterations,
hash_bitlen,
small_hash_bitlen,
Expand Down