Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cpp/src/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ void search_main(raft::resources const& res,
if (auto* strided_dset = dynamic_cast<const strided_dataset<T, ds_idx_type>*>(&index.data());
strided_dset != nullptr) {
// Search using a plain (strided) row-major dataset
auto& desc = dataset_descriptor_init_with_cache<T, InternalIdxT, DistanceT>(
auto desc = dataset_descriptor_init_with_cache<T, InternalIdxT, DistanceT>(
res, params, *strided_dset, index.metric());
search_main_core<T, InternalIdxT, DistanceT, CagraSampleFilterT>(
res, params, desc, graph_internal, queries, neighbors, distances, sample_filter);
Expand All @@ -161,7 +161,7 @@ void search_main(raft::resources const& res,
RAFT_FAIL("FP32 VPQ dataset support is coming soon");
} else if (auto* vpq_dset = dynamic_cast<const vpq_dataset<half, ds_idx_type>*>(&index.data());
vpq_dset != nullptr) {
auto& desc = dataset_descriptor_init_with_cache<T, InternalIdxT, DistanceT>(
auto desc = dataset_descriptor_init_with_cache<T, InternalIdxT, DistanceT>(
res, params, *vpq_dset, index.metric());
search_main_core<T, InternalIdxT, DistanceT, CagraSampleFilterT>(
res, params, desc, graph_internal, queries, neighbors, distances, sample_filter);
Expand Down
77 changes: 52 additions & 25 deletions cpp/src/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
#include <raft/util/device_loads_stores.cuh>
#include <raft/util/vectorized.cuh>

#include <atomic>
#include <functional>
#include <memory>
#include <mutex>
#include <type_traits>
#include <variant>

Expand Down Expand Up @@ -232,52 +234,77 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t {
*/
template <typename DataT, typename IndexT, typename DistanceT>
struct dataset_descriptor_host {
using dev_descriptor_t = dataset_descriptor_base_t<DataT, IndexT, DistanceT>;
using dd_ptr_t = std::shared_ptr<dev_descriptor_t>;
using init_f =
std::tuple<std::function<void(dev_descriptor_t*, rmm::cuda_stream_view stream)>, size_t>;
using dev_descriptor_t = dataset_descriptor_base_t<DataT, IndexT, DistanceT>;
uint32_t smem_ws_size_in_bytes = 0;
uint32_t team_size = 0;

struct state {
using ready_t = std::tuple<dev_descriptor_t*, rmm::cuda_stream_view>;
using init_f =
std::tuple<std::function<void(dev_descriptor_t*, rmm::cuda_stream_view)>, size_t>;

std::mutex mutex;
std::atomic<bool> ready; // Not sure if std::holds_alternative is thread-safe
std::variant<ready_t, init_f> value;

template <typename InitF>
state(InitF init, size_t size) : ready{false}, value{std::make_tuple(init, size)}
{
}

~state() noexcept
{
if (std::holds_alternative<ready_t>(value)) {
auto& [ptr, stream] = std::get<ready_t>(value);
RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(ptr, stream));
}
}

void eval(rmm::cuda_stream_view stream)
{
std::lock_guard<std::mutex> lock(mutex);
if (std::holds_alternative<init_f>(value)) {
auto& [fun, size] = std::get<init_f>(value);
dev_descriptor_t* ptr = nullptr;
RAFT_CUDA_TRY(cudaMallocAsync(&ptr, size, stream));
fun(ptr, stream);
value = std::make_tuple(ptr, stream);
ready.store(true, std::memory_order_release);
}
}

auto get(rmm::cuda_stream_view stream) -> dev_descriptor_t*
{
if (!ready.load(std::memory_order_acquire)) { eval(stream); }
return std::get<0>(std::get<ready_t>(value));
}
};

template <typename DescriptorImpl, typename InitF>
dataset_descriptor_host(const DescriptorImpl& dd_host, InitF init)
: value_{std::make_tuple(init, sizeof(DescriptorImpl))},
: value_{std::make_shared<state>(init, sizeof(DescriptorImpl))},
smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()},
team_size{dd_host.team_size()}
{
}

dataset_descriptor_host() = default;

/**
* Return the device pointer, possibly evaluating it in the given thread.
*/
[[nodiscard]] auto dev_ptr(rmm::cuda_stream_view stream) const -> const dev_descriptor_t*
{
if (std::holds_alternative<init_f>(value_)) { value_ = eval(std::get<init_f>(value_), stream); }
return std::get<dd_ptr_t>(value_).get();
return value_->get(stream);
}

[[nodiscard]] auto dev_ptr(rmm::cuda_stream_view stream) -> dev_descriptor_t*
{
if (std::holds_alternative<init_f>(value_)) { value_ = eval(std::get<init_f>(value_), stream); }
return std::get<dd_ptr_t>(value_).get();
return value_->get(stream);
}

private:
mutable std::variant<dd_ptr_t, init_f> value_;

static auto eval(init_f init, rmm::cuda_stream_view stream) -> dd_ptr_t
{
using raft::RAFT_NAME;
auto& [fun, size] = init;
dd_ptr_t dev_ptr{
[stream, s = size]() {
dev_descriptor_t* p;
RAFT_CUDA_TRY(cudaMallocAsync(&p, s, stream));
return p;
}(),
[stream](dev_descriptor_t* p) { RAFT_CUDA_TRY_NO_THROW(cudaFreeAsync(p, stream)); }};
fun(dev_ptr.get(), stream);
return dev_ptr;
}
mutable std::shared_ptr<state> value_;
};

/**
Expand Down
20 changes: 8 additions & 12 deletions cpp/src/neighbors/detail/cagra/factory.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,9 @@ template <typename DataT, typename IndexT, typename DistanceT>
struct store {
/** Number of descriptors to cache. */
static constexpr size_t kDefaultSize = 100;
raft::cache::lru<key,
key_hash,
std::equal_to<>,
std::shared_ptr<dataset_descriptor_host<DataT, IndexT, DistanceT>>>
value{kDefaultSize};
raft::cache::
lru<key, key_hash, std::equal_to<>, dataset_descriptor_host<DataT, IndexT, DistanceT>>
value{kDefaultSize};
};

} // namespace descriptor_cache
Expand All @@ -159,20 +157,18 @@ auto dataset_descriptor_init_with_cache(const raft::resources& res,
const cagra::search_params& params,
const DatasetT& dataset,
cuvs::distance::DistanceType metric)
-> const dataset_descriptor_host<DataT, IndexT, DistanceT>&
-> dataset_descriptor_host<DataT, IndexT, DistanceT>
{
using desc_t = dataset_descriptor_host<DataT, IndexT, DistanceT>;
auto key = descriptor_cache::make_key(params, dataset, metric);
auto key = descriptor_cache::make_key(params, dataset, metric);
auto& cache =
raft::resource::get_custom_resource<descriptor_cache::store<DataT, IndexT, DistanceT>>(res)
->value;
std::shared_ptr<desc_t> desc{nullptr};
dataset_descriptor_host<DataT, IndexT, DistanceT> desc;
if (!cache.get(key, &desc)) {
desc = std::make_shared<desc_t>(
std::move(dataset_descriptor_init<DataT, IndexT, DistanceT>(params, dataset, metric)));
desc = dataset_descriptor_init<DataT, IndexT, DistanceT>(params, dataset, metric);
cache.set(key, desc);
}
return *desc;
return desc;
}

}; // namespace cuvs::neighbors::cagra::detail
12 changes: 6 additions & 6 deletions cpp/src/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_
using base_type::num_seeds;

uint32_t num_cta_per_query;
rmm::device_uvector<INDEX_T> intermediate_indices;
rmm::device_uvector<float> intermediate_distances;
lightweight_uvector<INDEX_T> intermediate_indices;
lightweight_uvector<float> intermediate_distances;
size_t topk_workspace_size;
rmm::device_uvector<uint32_t> topk_workspace;
lightweight_uvector<uint32_t> topk_workspace;

search(raft::resources const& res,
search_params params,
Expand All @@ -105,9 +105,9 @@ struct search : public search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_
int64_t graph_degree,
uint32_t topk)
: base_type(res, params, dataset_desc, dim, graph_degree, topk),
intermediate_indices(0, raft::resource::get_cuda_stream(res)),
intermediate_distances(0, raft::resource::get_cuda_stream(res)),
topk_workspace(0, raft::resource::get_cuda_stream(res))
intermediate_indices(res),
intermediate_distances(res),
topk_workspace(res)

{
set_params(res, params);
Expand Down
53 changes: 31 additions & 22 deletions cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,15 @@ void get_value(T* const host_ptr, const T* const dev_ptr, cudaStream_t cuda_stre
get_value_kernel<T><<<1, 1, 0, cuda_stream>>>(host_ptr, dev_ptr);
}

template <class T>
auto get_value(const T* const dev_ptr, cudaStream_t stream) -> T
{
T value;
RAFT_CUDA_TRY(cudaMemcpyAsync(&value, dev_ptr, sizeof(value), cudaMemcpyDefault, stream));
RAFT_CUDA_TRY(cudaStreamSynchronize(stream));
return value;
}

// MAX_DATASET_DIM : must equal to or greater than dataset_dim
template <class DATASET_DESCRIPTOR_T>
RAFT_KERNEL random_pickup_kernel(
Expand Down Expand Up @@ -609,18 +618,18 @@ struct search : search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_T> {
using base_type::num_seeds;

size_t result_buffer_allocation_size;
rmm::device_uvector<INDEX_T> result_indices; // results_indices_buffer
rmm::device_uvector<DISTANCE_T> result_distances; // result_distances_buffer
rmm::device_uvector<INDEX_T> parent_node_list;
rmm::device_uvector<uint32_t> topk_hint;
rmm::device_scalar<uint32_t> terminate_flag; // dev_terminate_flag, host_terminate_flag.;
rmm::device_uvector<uint32_t> topk_workspace;
lightweight_uvector<INDEX_T> result_indices; // results_indices_buffer
lightweight_uvector<DISTANCE_T> result_distances; // result_distances_buffer
lightweight_uvector<INDEX_T> parent_node_list;
lightweight_uvector<uint32_t> topk_hint;
lightweight_uvector<uint32_t> terminate_flag; // dev_terminate_flag, host_terminate_flag.;
lightweight_uvector<uint32_t> topk_workspace;

// temporary storage for _find_topk
rmm::device_uvector<float> input_keys_storage;
rmm::device_uvector<float> output_keys_storage;
rmm::device_uvector<INDEX_T> input_values_storage;
rmm::device_uvector<INDEX_T> output_values_storage;
lightweight_uvector<float> input_keys_storage;
lightweight_uvector<float> output_keys_storage;
lightweight_uvector<INDEX_T> input_values_storage;
lightweight_uvector<INDEX_T> output_values_storage;

search(raft::resources const& res,
search_params params,
Expand All @@ -629,16 +638,16 @@ struct search : search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_T> {
int64_t graph_degree,
uint32_t topk)
: base_type(res, params, dataset_desc, dim, graph_degree, topk),
result_indices(0, raft::resource::get_cuda_stream(res)),
result_distances(0, raft::resource::get_cuda_stream(res)),
parent_node_list(0, raft::resource::get_cuda_stream(res)),
topk_hint(0, raft::resource::get_cuda_stream(res)),
topk_workspace(0, raft::resource::get_cuda_stream(res)),
terminate_flag(raft::resource::get_cuda_stream(res)),
input_keys_storage(0, raft::resource::get_cuda_stream(res)),
output_keys_storage(0, raft::resource::get_cuda_stream(res)),
input_values_storage(0, raft::resource::get_cuda_stream(res)),
output_values_storage(0, raft::resource::get_cuda_stream(res))
result_indices(res),
result_distances(res),
parent_node_list(res),
topk_hint(res),
topk_workspace(res),
terminate_flag(res),
input_keys_storage(res),
output_keys_storage(res),
input_values_storage(res),
output_values_storage(res)
{
set_params(res);
}
Expand All @@ -662,7 +671,7 @@ struct search : search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_T> {
itopk_size, max_queries, result_buffer_size, utils::get_cuda_data_type<DATA_T>());
RAFT_LOG_DEBUG("# topk_workspace_size: %lu", topk_workspace_size);
topk_workspace.resize(topk_workspace_size, raft::resource::get_cuda_stream(res));

terminate_flag.resize(1, raft::resource::get_cuda_stream(res));
hashmap.resize(hashmap_size, raft::resource::get_cuda_stream(res));
}

Expand Down Expand Up @@ -847,7 +856,7 @@ struct search : search_plan_impl<DataT, IndexT, DistanceT, SAMPLE_FILTER_T> {
stream);

// termination (2)
if (iter + 1 >= min_iterations && terminate_flag.value(stream)) {
if (iter + 1 >= min_iterations && get_value(terminate_flag.data(), stream)) {
iter++;
break;
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ struct search_plan_impl : public search_plan_impl_base {
lightweight_uvector<INDEX_T> hashmap;
lightweight_uvector<uint32_t> num_executed_iterations; // device or managed?
lightweight_uvector<INDEX_T> dev_seed;
const dataset_descriptor_host<DataT, IndexT, DistanceT>& dataset_desc;
dataset_descriptor_host<DataT, IndexT, DistanceT> dataset_desc;

search_plan_impl(raft::resources const& res,
search_params params,
Expand Down