Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cpp/include/cuvs/neighbors/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ cuvsError_t cuvsHnswFromCagra(cuvsResources_t res,
cuvsCagraIndex_t cagra_index,
cuvsHnswIndex_t hnsw_index);

/**
 * @brief Convert a CAGRA index to an HNSW index, supplying the dataset explicitly
 *        instead of using the one stored in the CAGRA index.
 *
 * The dataset is passed as a DLPack tensor; the implementation interprets it as a
 * row-major host matrix whose element type must match the dtype of the CAGRA index
 * (float32, uint8 or int8 — other dtypes fail). When a dataset is provided, its
 * second extent (number of columns) is used as the dimensionality of the resulting
 * HNSW index in place of the CAGRA index's own dimension.
 *
 * NOTE(review): presumably the tensor must reside in host-accessible memory since it
 * is viewed through a host_matrix_view — confirm and document the required device type.
 *
 * @param[in] res cuvsResources_t opaque C handle
 * @param[in] params cuvsHnswIndexParams_t used to build the HNSW index
 * @param[in] cagra_index cuvsCagraIndex_t to convert
 * @param[out] hnsw_index cuvsHnswIndex_t resulting HNSW index
 * @param[in] dataset_tensor DLManagedTensor* dataset backing the index (host, row-major)
 * @return cuvsError_t
 */
cuvsError_t cuvsHnswFromCagraWithDataset(cuvsResources_t res,
                                         cuvsHnswIndexParams_t params,
                                         cuvsCagraIndex_t cagra_index,
                                         cuvsHnswIndex_t hnsw_index,
                                         DLManagedTensor* dataset_tensor);

/**
* @}
*/
Expand Down
10 changes: 8 additions & 2 deletions cpp/src/neighbors/detail/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,14 @@ std::enable_if_t<hierarchy == HnswHierarchy::NONE, std::unique_ptr<index<T>>> fr
cuvs::neighbors::cagra::serialize_to_hnswlib(res, filepath, cagra_index, dataset);

index<T>* hnsw_index = nullptr;
cuvs::neighbors::hnsw::deserialize(
res, params, filepath, cagra_index.dim(), cagra_index.metric(), &hnsw_index);
int dim;
if (dataset.has_value()) {
dim = dataset.value().extent(1);
} else {
dim = cagra_index.dim();
}

cuvs::neighbors::hnsw::deserialize(res, params, filepath, dim, cagra_index.metric(), &hnsw_index);
std::filesystem::remove(filepath);
return std::unique_ptr<index<T>>(hnsw_index);
}
Expand Down
38 changes: 33 additions & 5 deletions cpp/src/neighbors/hnsw_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,22 @@ template <typename T>
void _from_cagra(cuvsResources_t res,
cuvsHnswIndexParams_t params,
cuvsCagraIndex_t cagra_index,
cuvsHnswIndex_t hnsw_index)
cuvsHnswIndex_t hnsw_index,
std::optional<DLManagedTensor*> dataset_tensor)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(cagra_index->addr);
auto cpp_params = cuvs::neighbors::hnsw::index_params();
cpp_params.hierarchy = static_cast<cuvs::neighbors::hnsw::HnswHierarchy>(params->hierarchy);
cpp_params.ef_construction = params->ef_construction;
cpp_params.num_threads = params->num_threads;
std::optional<raft::host_matrix_view<const T, int64_t, raft::row_major>> dataset = std::nullopt;
std::optional<raft::host_matrix_view<const T, int64_t, raft::row_major>> dataset;
if (dataset_tensor.has_value()) {
using dataset_mdspan_type = raft::host_matrix_view<T const, int64_t, raft::row_major>;
dataset = cuvs::core::from_dlpack<dataset_mdspan_type>(*dataset_tensor);
} else {
dataset = std::nullopt;
}

auto hnsw_index_unique_ptr =
cuvs::neighbors::hnsw::from_cagra(*res_ptr, cpp_params, *index, dataset);
Expand Down Expand Up @@ -175,11 +182,32 @@ extern "C" cuvsError_t cuvsHnswFromCagra(cuvsResources_t res,
auto index = *cagra_index;
hnsw_index->dtype = index.dtype;
if (index.dtype.code == kDLFloat) {
_from_cagra<float>(res, params, cagra_index, hnsw_index);
_from_cagra<float>(res, params, cagra_index, hnsw_index, std::nullopt);
} else if (index.dtype.code == kDLUInt) {
_from_cagra<uint8_t>(res, params, cagra_index, hnsw_index);
_from_cagra<uint8_t>(res, params, cagra_index, hnsw_index, std::nullopt);
} else if (index.dtype.code == kDLInt) {
_from_cagra<int8_t>(res, params, cagra_index, hnsw_index);
_from_cagra<int8_t>(res, params, cagra_index, hnsw_index, std::nullopt);
} else {
RAFT_FAIL("Unsupported dtype: %d", index.dtype.code);
}
});
}

extern "C" cuvsError_t cuvsHnswFromCagraWithDataset(cuvsResources_t res,
cuvsHnswIndexParams_t params,
cuvsCagraIndex_t cagra_index,
cuvsHnswIndex_t hnsw_index,
DLManagedTensor* dataset_tensor)
{
return cuvs::core::translate_exceptions([=] {
auto index = *cagra_index;
hnsw_index->dtype = index.dtype;
if (index.dtype.code == kDLFloat) {
_from_cagra<float>(res, params, cagra_index, hnsw_index, dataset_tensor);
} else if (index.dtype.code == kDLUInt) {
_from_cagra<uint8_t>(res, params, cagra_index, hnsw_index, dataset_tensor);
} else if (index.dtype.code == kDLInt) {
_from_cagra<int8_t>(res, params, cagra_index, hnsw_index, dataset_tensor);
} else {
RAFT_FAIL("Unsupported dtype: %d", index.dtype.code);
}
Expand Down