Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cpp/include/cuvs/neighbors/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,13 @@ cuvsError_t cuvsHnswFromCagra(cuvsResources_t res,
cuvsCagraIndex_t cagra_index,
cuvsHnswIndex_t hnsw_index);


cuvsError_t cuvsHnswFromCagraWithDataset(cuvsResources_t res,
cuvsHnswIndexParams_t params,
cuvsCagraIndex_t cagra_index,
cuvsHnswIndex_t hnsw_index,
DLManagedTensor* dataset_tensor);

/**
* @}
*/
Expand Down
10 changes: 8 additions & 2 deletions cpp/src/neighbors/detail/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,14 @@ std::enable_if_t<hierarchy == HnswHierarchy::NONE, std::unique_ptr<index<T>>> fr
cuvs::neighbors::cagra::serialize_to_hnswlib(res, filepath, cagra_index, dataset);

index<T>* hnsw_index = nullptr;
cuvs::neighbors::hnsw::deserialize(
res, params, filepath, cagra_index.dim(), cagra_index.metric(), &hnsw_index);
int dim;
if (dataset.has_value()) {
dim = dataset.value().extent(1);
} else {
dim = cagra_index.dim();
}

cuvs::neighbors::hnsw::deserialize(res, params, filepath, dim, cagra_index.metric(), &hnsw_index);
std::filesystem::remove(filepath);
return std::unique_ptr<index<T>>(hnsw_index);
}
Expand Down
39 changes: 34 additions & 5 deletions cpp/src/neighbors/hnsw_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,23 @@ template <typename T>
void _from_cagra(cuvsResources_t res,
cuvsHnswIndexParams_t params,
cuvsCagraIndex_t cagra_index,
cuvsHnswIndex_t hnsw_index)
cuvsHnswIndex_t hnsw_index,
std::optional<DLManagedTensor*> dataset_tensor)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(cagra_index->addr);
auto cpp_params = cuvs::neighbors::hnsw::index_params();
cpp_params.hierarchy = static_cast<cuvs::neighbors::hnsw::HnswHierarchy>(params->hierarchy);
cpp_params.ef_construction = params->ef_construction;
cpp_params.num_threads = params->num_threads;
std::optional<raft::host_matrix_view<const T, int64_t, raft::row_major>> dataset = std::nullopt;
std::optional<raft::host_matrix_view<const T, int64_t, raft::row_major>> dataset;
if (dataset_tensor.has_value()) {
using dataset_mdspan_type =
raft::host_matrix_view<T const, int64_t, raft::row_major>;
dataset = cuvs::core::from_dlpack<dataset_mdspan_type>(*dataset_tensor);
} else {
dataset = std::nullopt;
}

auto hnsw_index_unique_ptr =
cuvs::neighbors::hnsw::from_cagra(*res_ptr, cpp_params, *index, dataset);
Expand Down Expand Up @@ -175,11 +183,32 @@ extern "C" cuvsError_t cuvsHnswFromCagra(cuvsResources_t res,
auto index = *cagra_index;
hnsw_index->dtype = index.dtype;
if (index.dtype.code == kDLFloat) {
_from_cagra<float>(res, params, cagra_index, hnsw_index);
_from_cagra<float>(res, params, cagra_index, hnsw_index, std::nullopt);
} else if (index.dtype.code == kDLUInt) {
_from_cagra<uint8_t>(res, params, cagra_index, hnsw_index);
_from_cagra<uint8_t>(res, params, cagra_index, hnsw_index, std::nullopt);
} else if (index.dtype.code == kDLInt) {
_from_cagra<int8_t>(res, params, cagra_index, hnsw_index);
_from_cagra<int8_t>(res, params, cagra_index, hnsw_index, std::nullopt);
} else {
RAFT_FAIL("Unsupported dtype: %d", index.dtype.code);
}
});
}

extern "C" cuvsError_t cuvsHnswFromCagraWithDataset(cuvsResources_t res,
cuvsHnswIndexParams_t params,
cuvsCagraIndex_t cagra_index,
cuvsHnswIndex_t hnsw_index,
DLManagedTensor* dataset_tensor)
{
return cuvs::core::translate_exceptions([=] {
auto index = *cagra_index;
hnsw_index->dtype = index.dtype;
if (index.dtype.code == kDLFloat) {
_from_cagra<float>(res, params, cagra_index, hnsw_index, dataset_tensor);
} else if (index.dtype.code == kDLUInt) {
_from_cagra<uint8_t>(res, params, cagra_index, hnsw_index, dataset_tensor);
} else if (index.dtype.code == kDLInt) {
_from_cagra<int8_t>(res, params, cagra_index, hnsw_index, dataset_tensor);
} else {
RAFT_FAIL("Unsupported dtype: %d", index.dtype.code);
}
Expand Down