@@ -488,7 +488,7 @@ __device__ __forceinline__ void remove_duplicates(
488488// MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048
489489// For architectures 750 and 860 (890), the values of MAX_RESIDENT_THREAD_PER_SM
490490// are 1024 and 1536 respectively, which means the bounds no longer hold
491- template <typename Index_t, typename ID_t = InternalID_t<Index_t>>
491+ template <typename Index_t, typename ID_t = InternalID_t<Index_t>, typename DistEpilogue_t >
492492RAFT_KERNEL
493493#ifdef __CUDA_ARCH__
494494// Use minBlocksPerMultiprocessor = 4 on specific arches
@@ -513,7 +513,8 @@ __launch_bounds__(BLOCK_SIZE)
513513 int graph_width,
514514 int * locks,
515515 DistData_t* l2_norms,
516- cuvs::distance::DistanceType metric)
516+ cuvs::distance::DistanceType metric,
517+ DistEpilogue_t dist_epilogue)
517518{
518519#if (__CUDA_ARCH__ >= 700)
519520 using namespace nvcuda ;
@@ -623,20 +624,22 @@ __launch_bounds__(BLOCK_SIZE)
623624 __syncthreads ();
624625
625626 for (int i = threadIdx .x ; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim .x ) {
626- if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_new_size &&
627- i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) {
627+ int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES;
628+ int col_id = i / SKEWED_MAX_NUM_BI_SAMPLES;
629+
630+ if (row_id < list_new_size && col_id < list_new_size) {
628631 if (metric == cuvs::distance::DistanceType::InnerProduct) {
629632 s_distances[i] = -s_distances[i];
630633 } else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
631634 s_distances[i] = 1.0 - s_distances[i];
632635 } else { // L2Expanded or L2SqrtExpanded
633- s_distances[i] = l2_norms[new_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] +
634- l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] -
635- 2.0 * s_distances[i];
636+ s_distances[i] =
637+ l2_norms[new_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i];
636638 // fp32 vs fp16 precision differences can yield negative distances when the distance
637639 // should be 0, so clamp at zero. Related issue: https://github.com/rapidsai/cuvs/issues/991
638640 s_distances[i] = s_distances[i] < 0 .0f ? 0 .0f : s_distances[i];
639641 }
642+ s_distances[i] = dist_epilogue (s_distances[i], new_neighbors[row_id], new_neighbors[col_id]);
640643 } else {
641644 s_distances[i] = std::numeric_limits<float >::max ();
642645 }
@@ -707,20 +710,21 @@ __launch_bounds__(BLOCK_SIZE)
707710 __syncthreads ();
708711
709712 for (int i = threadIdx .x ; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim .x ) {
710- if (i % SKEWED_MAX_NUM_BI_SAMPLES < list_old_size &&
711- i / SKEWED_MAX_NUM_BI_SAMPLES < list_new_size) {
713+ int row_id = i % SKEWED_MAX_NUM_BI_SAMPLES;
714+ int col_id = i / SKEWED_MAX_NUM_BI_SAMPLES;
715+ if (row_id < list_old_size && col_id < list_new_size) {
712716 if (metric == cuvs::distance::DistanceType::InnerProduct) {
713717 s_distances[i] = -s_distances[i];
714718 } else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
715719 s_distances[i] = 1.0 - s_distances[i];
716720 } else { // L2Expanded or L2SqrtExpanded
717- s_distances[i] = l2_norms[old_neighbors[i % SKEWED_MAX_NUM_BI_SAMPLES]] +
718- l2_norms[new_neighbors[i / SKEWED_MAX_NUM_BI_SAMPLES]] -
719- 2.0 * s_distances[i];
721+ s_distances[i] =
722+ l2_norms[old_neighbors[row_id]] + l2_norms[new_neighbors[col_id]] - 2.0 * s_distances[i];
720723 // fp32 vs fp16 precision differences can yield negative distances when the distance
721724 // should be 0, so clamp at zero. Related issue: https://github.com/rapidsai/cuvs/issues/991
722725 s_distances[i] = s_distances[i] < 0 .0f ? 0 .0f : s_distances[i];
723726 }
727+ s_distances[i] = dist_epilogue (s_distances[i], old_neighbors[row_id], new_neighbors[col_id]);
724728 } else {
725729 s_distances[i] = std::numeric_limits<float >::max ();
726730 }
@@ -1034,7 +1038,8 @@ void GNND<Data_t, Index_t>::add_reverse_edges(Index_t* graph_ptr,
10341038}
10351039
10361040template <typename Data_t, typename Index_t>
1037- void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
1041+ template <typename DistEpilogue_t>
1042+ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream, DistEpilogue_t dist_epilogue)
10381043{
10391044 raft::matrix::fill (res, dists_buffer_.view (), std::numeric_limits<float >::max ());
10401045 local_join_kernel<<<nrow_, BLOCK_SIZE, 0 , stream>>> (graph_.h_graph_new .data_handle (),
@@ -1051,15 +1056,18 @@ void GNND<Data_t, Index_t>::local_join(cudaStream_t stream)
10511056 DEGREE_ON_DEVICE,
10521057 d_locks_.data_handle (),
10531058 l2_norms_.data_handle (),
1054- build_config_.metric );
1059+ build_config_.metric ,
1060+ dist_epilogue);
10551061}
10561062
10571063template <typename Data_t, typename Index_t>
1064+ template <typename DistEpilogue_t>
10581065void GNND<Data_t, Index_t>::build(Data_t* data,
10591066 const Index_t nrow,
10601067 Index_t* output_graph,
10611068 bool return_distances,
1062- DistData_t* output_distances)
1069+ DistData_t* output_distances,
1070+ DistEpilogue_t dist_epilogue)
10631071{
10641072 using input_t = typename std::remove_const<Data_t>::type;
10651073
@@ -1154,7 +1162,7 @@ void GNND<Data_t, Index_t>::build(Data_t* data,
11541162 raft::util::arch::SM_range (raft::util::arch::SM_70 (), raft::util::arch::SM_future ());
11551163
11561164 if (wmma_range.contains (runtime_arch)) {
1157- local_join (stream);
1165+ local_join (stream, dist_epilogue );
11581166 } else {
11591167 THROW (" NN_DESCENT cannot be run for __CUDA_ARCH__ < 700" );
11601168 }
0 commit comments