@@ -78,7 +78,8 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size,
7878 const uint32_t dataset_dim,
7979 IdxT* const knn_graph, // [graph_chunk_size, graph_degree]
8080 const uint32_t graph_size,
81- const uint32_t graph_degree)
81+ const uint32_t graph_degree,
82+ const cuvs::distance::DistanceType metric)
8283{
8384 const IdxT srcNode = (blockDim .x * blockIdx .x + threadIdx .x ) / raft::WarpSize;
8485 if (srcNode >= graph_size) { return ; }
@@ -91,19 +92,46 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size,
9192 // Compute distance from a src node to its neighbors
9293 for (int k = 0 ; k < graph_degree; k++) {
9394 const IdxT dstNode = knn_graph[k + static_cast <uint64_t >(graph_degree) * srcNode];
94- float dist = 0.0 ;
95- for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) {
96- float diff = cuvs::spatial::knn::detail::utils::mapping<float >{}(
97- dataset[d + static_cast <uint64_t >(dataset_dim) * srcNode]) -
98- cuvs::spatial::knn::detail::utils::mapping<float >{}(
99- dataset[d + static_cast <uint64_t >(dataset_dim) * dstNode]);
100- dist += diff * diff;
95+ float dist = 0 ;
96+ float norm2_dst = 0 ;
97+ if (metric == cuvs::distance::DistanceType::InnerProduct ||
98+ metric == cuvs::distance::DistanceType::CosineExpanded) {
99+ for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) {
100+ auto elem_b = cuvs::spatial::knn::detail::utils::mapping<float >{}(
101+ dataset[d + static_cast <uint64_t >(dataset_dim) * dstNode]);
102+ dist -= cuvs::spatial::knn::detail::utils::mapping<float >{}(
103+ dataset[d + static_cast <uint64_t >(dataset_dim) * srcNode]) *
104+ elem_b;
105+
106+ if (metric == cuvs::distance::DistanceType::CosineExpanded) {
107+ norm2_dst += elem_b * elem_b;
108+ }
109+ }
110+ } else {
111+ // L2Expanded
112+ for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) {
113+ float diff = cuvs::spatial::knn::detail::utils::mapping<float >{}(
114+ dataset[d + static_cast <uint64_t >(dataset_dim) * srcNode]) -
115+ cuvs::spatial::knn::detail::utils::mapping<float >{}(
116+ dataset[d + static_cast <uint64_t >(dataset_dim) * dstNode]);
117+ dist += diff * diff;
118+ }
101119 }
102120 dist += __shfl_xor_sync (0xffffffff , dist, 1 );
103121 dist += __shfl_xor_sync (0xffffffff , dist, 2 );
104122 dist += __shfl_xor_sync (0xffffffff , dist, 4 );
105123 dist += __shfl_xor_sync (0xffffffff , dist, 8 );
106124 dist += __shfl_xor_sync (0xffffffff , dist, 16 );
125+
126+ if (metric == cuvs::distance::DistanceType::CosineExpanded) {
127+ norm2_dst += __shfl_xor_sync (0xffffffff , norm2_dst, 1 );
128+ norm2_dst += __shfl_xor_sync (0xffffffff , norm2_dst, 2 );
129+ norm2_dst += __shfl_xor_sync (0xffffffff , norm2_dst, 4 );
130+ norm2_dst += __shfl_xor_sync (0xffffffff , norm2_dst, 8 );
131+ norm2_dst += __shfl_xor_sync (0xffffffff , norm2_dst, 16 );
132+ if (lane_id == (k % raft::WarpSize)) { dist /= sqrt (norm2_dst); }
133+ }
134+
107135 if (lane_id == (k % raft::WarpSize)) {
108136 my_keys[k / raft::WarpSize] = dist;
109137 my_vals[k / raft::WarpSize] = dstNode;
@@ -471,11 +499,17 @@ template <
471499 raft::host_device_accessor<std::experimental::default_accessor<IdxT>, raft::memory_type::host>>
472500void sort_knn_graph (
473501 raft::resources const & res,
502+ const cuvs::distance::DistanceType metric,
474503 raft::mdspan<const DataT, raft::matrix_extent<int64_t >, raft::row_major, d_accessor> dataset,
475504 raft::mdspan<IdxT, raft::matrix_extent<int64_t >, raft::row_major, g_accessor> knn_graph)
476505{
477506 RAFT_EXPECTS (dataset.extent (0 ) == knn_graph.extent (0 ),
478507 " dataset size is expected to have the same number of graph index size" );
508+ RAFT_EXPECTS (
509+ metric == cuvs::distance::DistanceType::InnerProduct ||
510+ metric == cuvs::distance::DistanceType::CosineExpanded ||
511+ metric == cuvs::distance::DistanceType::L2Expanded,
512+ " Unsupported metric. Only InnerProduct, CosineExpanded, and L2Expanded are supported" );
479513 const uint64_t dataset_size = dataset.extent (0 );
480514 const uint64_t dataset_dim = dataset.extent (1 );
481515 const DataT* dataset_ptr = dataset.data_handle ();
@@ -507,8 +541,13 @@ void sort_knn_graph(
507541 graph_size * input_graph_degree,
508542 raft::resource::get_cuda_stream (res));
509543
510- void (*kernel_sort)(
511- const DataT* const , const IdxT, const uint32_t , IdxT* const , const uint32_t , const uint32_t );
544+ void (*kernel_sort)(const DataT* const ,
545+ const IdxT,
546+ const uint32_t ,
547+ IdxT* const ,
548+ const uint32_t ,
549+ const uint32_t ,
550+ const cuvs::distance::DistanceType);
512551 if (input_graph_degree <= 32 ) {
513552 constexpr int numElementsPerThread = 1 ;
514553 kernel_sort = kern_sort<DataT, IdxT, numElementsPerThread>;
@@ -545,7 +584,8 @@ void sort_knn_graph(
545584 dataset_dim,
546585 d_input_graph.data_handle (),
547586 graph_size,
548- input_graph_degree);
587+ input_graph_degree,
588+ metric);
549589 raft::resource::sync_stream (res);
550590 RAFT_LOG_DEBUG (" ." );
551591 raft::copy (input_graph_ptr,
0 commit comments