diff --git a/cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh b/cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh index d26bc39fd9..857d87fa0d 100644 --- a/cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh +++ b/cpp/src/distance/detail/sparse/coo_spmv_kernel.cuh @@ -1,11 +1,12 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once #include +#include #include #include @@ -120,10 +121,11 @@ RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy, extern __shared__ char smem[]; - typename strategy_t::smem_type A = (typename strategy_t::smem_type)(smem); - typename warp_reduce::TempStorage* temp_storage = (typename warp_reduce::TempStorage*)(A + dim); + void* A = smem; + typename warp_reduce::TempStorage* temp_storage = + (typename warp_reduce::TempStorage*)((char*)A + dim); - auto inserter = strategy.init_insert(A, dim); + auto map_ref = strategy.init_map(A, dim); __syncthreads(); @@ -134,13 +136,11 @@ RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy, // Convert current row vector in A to dense for (int i = tid; i <= (stop_offset_a - start_offset_a); i += blockDim.x) { - strategy.insert(inserter, indicesA[start_offset_a + i], dataA[start_offset_a + i]); + strategy.insert(map_ref, indicesA[start_offset_a + i], dataA[start_offset_a + i]); } __syncthreads(); - auto finder = strategy.init_find(A, dim); - if (cur_row_a > m || cur_chunk_offset > n_blocks_per_row) return; if (ind >= nnz_b) return; @@ -166,7 +166,7 @@ RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy, auto in_bounds = indptrA.check_indices_bounds(start_index_a, stop_index_a, index_b); if (in_bounds) { - value_t a_col = strategy.find(finder, index_b); + value_t a_col = strategy.find(map_ref, index_b); if (!rev || a_col == 0.0) { c = product_func(a_col, dataB[ind]); } } } @@ -204,7 +204,7 @@ RAFT_KERNEL balanced_coo_generalized_spmv_kernel(strategy_t strategy, auto index_b = indicesB[ind]; auto in_bounds = indptrA.check_indices_bounds(start_index_a, stop_index_a, index_b); if (in_bounds) { - value_t a_col = strategy.find(finder, index_b); + value_t a_col = strategy.find(map_ref, index_b); if (!rev || a_col == 0.0) { c = accum_func(c, product_func(a_col, dataB[ind])); } } diff --git a/cpp/src/distance/detail/sparse/coo_spmv_strategies/dense_smem_strategy.cuh b/cpp/src/distance/detail/sparse/coo_spmv_strategies/dense_smem_strategy.cuh index 00ac28983f..39d2bef3f3 100644 --- a/cpp/src/distance/detail/sparse/coo_spmv_strategies/dense_smem_strategy.cuh +++ b/cpp/src/distance/detail/sparse/coo_spmv_strategies/dense_smem_strategy.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -7,7 +7,7 @@ #include "base_strategy.cuh" -#include // raft::ceildiv +#include namespace cuvs { namespace distance { @@ -17,9 +17,7 @@ namespace sparse { template class dense_smem_strategy : public coo_spmv_strategy { public: - using smem_type = value_t*; - using insert_type = smem_type; - using find_type = smem_type; + using map_type = value_t*; dense_smem_strategy(const distances_config_t& config_) : coo_spmv_strategy(config_) @@ -83,25 +81,21 @@ class dense_smem_strategy : public coo_spmv_strategy { n_blocks_per_row); } - __device__ inline insert_type init_insert(smem_type cache, const value_idx& cache_size) + __device__ inline map_type init_map(void* storage, const value_idx& cache_size) { + auto cache = static_cast(storage); for (int k = threadIdx.x; k < cache_size; k += blockDim.x) { cache[k] = 0.0; } return cache; } - __device__ inline void insert(insert_type cache, const value_idx& key, const value_t& value) + __device__ inline void insert(map_type& cache, const value_idx& key, const value_t& value) { cache[key] = value; } - __device__ inline find_type init_find(smem_type cache, const value_idx& cache_size) - { - return cache; - } - - __device__ inline value_t find(find_type cache, const value_idx& key) { return cache[key]; } + __device__ inline value_t find(map_type& cache, const value_idx& key) { return cache[key]; } }; } // namespace sparse diff --git a/cpp/src/distance/detail/sparse/coo_spmv_strategies/hash_strategy.cuh b/cpp/src/distance/detail/sparse/coo_spmv_strategies/hash_strategy.cuh index 64a329a332..a05fe35cc7 100644 --- a/cpp/src/distance/detail/sparse/coo_spmv_strategies/hash_strategy.cuh +++ b/cpp/src/distance/detail/sparse/coo_spmv_strategies/hash_strategy.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -9,11 +9,15 @@ #include #include +#include #include #include #include +#include +#include + // this is needed by cuco as key, value must be bitwise comparable. // compilers don't declare float/double as bitwise comparable // but that is too strict @@ -32,11 +36,19 @@ namespace sparse { template class hash_strategy : public coo_spmv_strategy { public: - using insert_type = typename cuco::legacy:: - static_map::device_mutable_view; - using smem_type = typename insert_type::slot_type*; - using find_type = - typename cuco::legacy::static_map::device_view; + static constexpr value_idx empty_key_sentinel = value_idx{-1}; + static constexpr value_t empty_value_sentinel = value_t{0}; + using probing_scheme_type = cuco::linear_probing<1, cuco::murmurhash3_32>; + using storage_ref_type = + cuco::bucket_storage_ref, 1, cuco::extent>; + using map_type = cuco::static_map_ref, + probing_scheme_type, + storage_ref_type, + cuco::op::insert_tag, + cuco::op::find_tag>; hash_strategy(const distances_config_t& config_, float capacity_threshold_ = 0.5, @@ -220,32 +232,35 @@ class hash_strategy : public coo_spmv_strategy { } } - __device__ inline insert_type init_insert(smem_type cache, const value_idx& cache_size) + __device__ inline map_type init_map(void* storage, const value_idx& cache_size) { - return insert_type::make_from_uninitialized_slots(cooperative_groups::this_thread_block(), - cache, - cache_size, - cuco::empty_key{value_idx{-1}}, - cuco::empty_value{value_t{0}}); + auto map_ref = + map_type{cuco::empty_key{empty_key_sentinel}, + cuco::empty_value{empty_value_sentinel}, + cuda::std::equal_to{}, + probing_scheme_type{}, + cuco::cuda_thread_scope{}, + storage_ref_type{cuco::extent{cache_size}, + static_cast(storage)}}; + map_ref.initialize(cooperative_groups::this_thread_block()); + + return map_ref; } - __device__ inline void insert(insert_type cache, const value_idx& key, const value_t& value) + __device__ inline void insert(map_type& map_ref, const value_idx& key, const value_t& value) { - auto success = cache.insert(cuco::pair(key, value)); + map_ref.insert(cuco::pair{key, value}); } - __device__ inline find_type init_find(smem_type cache, const value_idx& cache_size) - { - return find_type( - cache, cache_size, cuco::empty_key{value_idx{-1}}, cuco::empty_value{value_t{0}}); - } + // Note: init_find is now merged with init_map since the new API uses the same ref for both + // operations - __device__ inline value_t find(find_type cache, const value_idx& key) + __device__ inline value_t find(map_type& map_ref, const value_idx& key) { - auto a_pair = cache.find(key); + auto a_pair = map_ref.find(key); value_t a_col = 0.0; - if (a_pair != cache.end()) { a_col = a_pair->second; } + if (a_pair != map_ref.end()) { a_col = a_pair->second; } return a_col; } @@ -271,7 +286,7 @@ class hash_strategy : public coo_spmv_strategy { inline static int get_map_size() { return (raft::getSharedMemPerBlock() - ((tpb / raft::warp_size()) * sizeof(value_t))) / - sizeof(typename insert_type::slot_type); + sizeof(cuco::pair); } private: