@@ -333,6 +333,63 @@ void get_graph_view(cuvsCagraIndex_t index, DLManagedTensor* graph)
333333 auto index_ptr = reinterpret_cast <cuvs::neighbors::cagra::index<T, IdxT>*>(index->addr );
334334 cuvs::core::to_dlpack (index_ptr->graph (), graph);
335335}
336+
337+ // Helper function to populate C IVF-PQ params from C++ params
338+ static void _populate_c_ivf_pq_params (cuvsIvfPqParams* c_ivf_pq,
339+ const cuvs::neighbors::cagra::graph_build_params::ivf_pq_params& cpp_ivf_pq)
340+ {
341+ // Populate the IVF-PQ build params
342+ auto & bp = cpp_ivf_pq.build_params ;
343+ c_ivf_pq->ivf_pq_build_params ->metric = static_cast <cuvsDistanceType>(bp.metric );
344+ c_ivf_pq->ivf_pq_build_params ->metric_arg = bp.metric_arg ;
345+ c_ivf_pq->ivf_pq_build_params ->add_data_on_build = bp.add_data_on_build ;
346+ c_ivf_pq->ivf_pq_build_params ->n_lists = bp.n_lists ;
347+ c_ivf_pq->ivf_pq_build_params ->kmeans_n_iters = bp.kmeans_n_iters ;
348+ c_ivf_pq->ivf_pq_build_params ->kmeans_trainset_fraction = bp.kmeans_trainset_fraction ;
349+ c_ivf_pq->ivf_pq_build_params ->pq_bits = bp.pq_bits ;
350+ c_ivf_pq->ivf_pq_build_params ->pq_dim = bp.pq_dim ;
351+ c_ivf_pq->ivf_pq_build_params ->codebook_kind = static_cast <codebook_gen>(bp.codebook_kind );
352+ c_ivf_pq->ivf_pq_build_params ->force_random_rotation = bp.force_random_rotation ;
353+ c_ivf_pq->ivf_pq_build_params ->conservative_memory_allocation = bp.conservative_memory_allocation ;
354+ c_ivf_pq->ivf_pq_build_params ->max_train_points_per_pq_code = bp.max_train_points_per_pq_code ;
355+
356+ // Populate the IVF-PQ search params
357+ auto & sp = cpp_ivf_pq.search_params ;
358+ c_ivf_pq->ivf_pq_search_params ->n_probes = sp.n_probes ;
359+ c_ivf_pq->ivf_pq_search_params ->lut_dtype = sp.lut_dtype ;
360+ c_ivf_pq->ivf_pq_search_params ->internal_distance_dtype = sp.internal_distance_dtype ;
361+ c_ivf_pq->ivf_pq_search_params ->preferred_shmem_carveout = sp.preferred_shmem_carveout ;
362+
363+ c_ivf_pq->refinement_rate = cpp_ivf_pq.refinement_rate ;
364+ }
365+
366+ // Helper function to populate C struct from C++ index_params
367+ static void _populate_cagra_index_params_from_cpp (cuvsCagraIndexParams_t c_params,
368+ const cuvs::neighbors::cagra::index_params& cpp_params)
369+ {
370+ c_params->metric = static_cast <cuvsDistanceType>(cpp_params.metric );
371+ c_params->intermediate_graph_degree = cpp_params.intermediate_graph_degree ;
372+ c_params->graph_degree = cpp_params.graph_degree ;
373+
374+ // Set build algo and parameters based on the variant
375+ if (std::holds_alternative<cuvs::neighbors::cagra::graph_build_params::nn_descent_params>(
376+ cpp_params.graph_build_params )) {
377+ c_params->build_algo = NN_DESCENT;
378+ auto nn_params =
379+ std::get<cuvs::neighbors::cagra::graph_build_params::nn_descent_params>(
380+ cpp_params.graph_build_params );
381+ c_params->nn_descent_niter = nn_params.max_iterations ;
382+ } else if (std::holds_alternative<cuvs::neighbors::cagra::graph_build_params::ivf_pq_params>(
383+ cpp_params.graph_build_params )) {
384+ c_params->build_algo = IVF_PQ;
385+ auto ivf_pq_params =
386+ std::get<cuvs::neighbors::cagra::graph_build_params::ivf_pq_params>(
387+ cpp_params.graph_build_params );
388+
389+ _populate_c_ivf_pq_params (c_params->graph_build_params , ivf_pq_params);
390+ }
391+ }
392+
336393} // namespace
337394
338395namespace cuvs ::neighbors::cagra {
@@ -665,6 +722,24 @@ extern "C" cuvsError_t cuvsCagraCompressionParamsDestroy(cuvsCagraCompressionPar
665722 return cuvs::core::translate_exceptions ([=] { delete params; });
666723}
667724
725+ extern " C" cuvsError_t cuvsCagraIndexParamsFromHnswParams (cuvsCagraIndexParams_t params,
726+ int64_t n_rows,
727+ int64_t dim,
728+ int M,
729+ int ef_construction,
730+ enum cuvsCagraHnswHeuristicType heuristic,
731+ cuvsDistanceType metric)
732+ {
733+ return cuvs::core::translate_exceptions ([=] {
734+ auto cpp_metric = static_cast <cuvs::distance::DistanceType>((int )metric);
735+ auto cpp_heuristic = static_cast <cuvs::neighbors::cagra::hnsw_heuristic_type>((int )heuristic);
736+ auto cpp_params = cuvs::neighbors::cagra::index_params::from_hnsw_params (
737+ raft::matrix_extent<int64_t >(n_rows, dim), M, ef_construction, cpp_heuristic, cpp_metric);
738+
739+ _populate_cagra_index_params_from_cpp (params, cpp_params);
740+ });
741+ }
742+
668743extern " C" cuvsError_t cuvsCagraExtendParamsCreate (cuvsCagraExtendParams_t* params)
669744{
670745 return cuvs::core::translate_exceptions (
0 commit comments