diff --git a/cpp/src/neighbors/detail/nn_descent.cuh b/cpp/src/neighbors/detail/nn_descent.cuh index 7cec3bae61..941af82a0f 100644 --- a/cpp/src/neighbors/detail/nn_descent.cuh +++ b/cpp/src/neighbors/detail/nn_descent.cuh @@ -491,11 +491,12 @@ __device__ __forceinline__ void remove_duplicates( template > RAFT_KERNEL #ifdef __CUDA_ARCH__ -#if (__CUDA_ARCH__) == 750 || ((__CUDA_ARCH__) >= 860 && (__CUDA_ARCH__) <= 890) || \ - (__CUDA_ARCH__) == 1200 -__launch_bounds__(BLOCK_SIZE) -#else +// Use minBlocksPerMultiprocessor = 4 on specific arches +#if (__CUDA_ARCH__) == 700 || (__CUDA_ARCH__) == 800 || (__CUDA_ARCH__) == 900 || \ + (__CUDA_ARCH__) == 1000 __launch_bounds__(BLOCK_SIZE, 4) +#else +__launch_bounds__(BLOCK_SIZE) #endif #endif local_join_kernel(const Index_t* graph_new, diff --git a/cpp/tests/neighbors/ann_cagra.cuh b/cpp/tests/neighbors/ann_cagra.cuh index 4ddc3bd6a7..c0f80f1e82 100644 --- a/cpp/tests/neighbors/ann_cagra.cuh +++ b/cpp/tests/neighbors/ann_cagra.cuh @@ -1035,29 +1035,22 @@ class AnnCagraIndexMergeTest : public ::testing::TestWithParam { auto database1_view = raft::make_device_matrix_view( (const DataT*)database.data() + database0_view.size(), database1_size, ps.dim); - cagra::index index0(handle_); - cagra::index index1(handle_); + cagra::index index0(handle_, index_params.metric); + cagra::index index1(handle_, index_params.metric); + std::optional> database_host{std::nullopt}; if (ps.host_dataset) { + database_host = raft::make_host_matrix(handle_, ps.n_rows, ps.dim); + raft::copy(database_host->data_handle(), database.data(), database.size(), stream_); { - std::optional> database_host{std::nullopt}; - database_host = raft::make_host_matrix(database0_size, ps.dim); - raft::copy(database_host->data_handle(), - database0_view.data_handle(), - database0_view.size(), - stream_); auto database_host_view = raft::make_host_matrix_view( (const DataT*)database_host->data_handle(), database0_size, ps.dim); index0 = cagra::build(handle_, index_params, database_host_view); } { - std::optional> database_host{std::nullopt}; - database_host = raft::make_host_matrix(database1_size, ps.dim); - raft::copy(database_host->data_handle(), - database1_view.data_handle(), - database1_view.size(), - stream_); auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host->data_handle(), database1_size, ps.dim); + (const DataT*)database_host->data_handle() + database0_size * ps.dim, + database1_size, + ps.dim); index1 = cagra::build(handle_, index_params, database_host_view); } } else {