@@ -83,6 +83,34 @@ void* _build(cuvsResources_t res, cuvsCagraIndexParams params, DLManagedTensor*
8383 return index;
8484}
8585
86+ template <typename T>
87+ void _extend (cuvsResources_t res,
88+ cuvsCagraExtendParams params,
89+ cuvsCagraIndex index,
90+ DLManagedTensor* additional_dataset_tensor)
91+ {
92+ auto dataset = additional_dataset_tensor->dl_tensor ;
93+ auto index_ptr = reinterpret_cast <cuvs::neighbors::cagra::index<T, uint32_t >*>(index.addr );
94+ auto res_ptr = reinterpret_cast <raft::resources*>(res);
95+
96+ auto extend_params = cuvs::neighbors::cagra::extend_params ();
97+ extend_params.max_chunk_size = params.max_chunk_size ;
98+
99+ if (cuvs::core::is_dlpack_device_compatible (dataset)) {
100+ using mdspan_type = raft::device_matrix_view<T const , int64_t , raft::row_major>;
101+ auto mds = cuvs::core::from_dlpack<mdspan_type>(additional_dataset_tensor);
102+ cuvs::neighbors::cagra::extend (*res_ptr, extend_params, mds, *index_ptr);
103+ } else if (cuvs::core::is_dlpack_host_compatible (dataset)) {
104+ using mdspan_type = raft::host_matrix_view<T const , int64_t , raft::row_major>;
105+ auto mds = cuvs::core::from_dlpack<mdspan_type>(additional_dataset_tensor);
106+ cuvs::neighbors::cagra::extend (*res_ptr, extend_params, mds, *index_ptr);
107+ } else {
108+ RAFT_FAIL (" Unsupported dataset DLtensor dtype: %d and bits: %d" ,
109+ dataset.dtype .code ,
110+ dataset.dtype .bits );
111+ }
112+ }
113+
86114template <typename T>
87115void _search (cuvsResources_t res,
88116 cuvsCagraSearchParams params,
@@ -190,6 +218,29 @@ extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res,
190218 });
191219}
192220
221+ extern " C" cuvsError_t cuvsCagraExtend (cuvsResources_t res,
222+ cuvsCagraExtendParams_t params,
223+ DLManagedTensor* additional_dataset_tensor,
224+ cuvsCagraIndex_t index_c_ptr)
225+ {
226+ return cuvs::core::translate_exceptions ([=] {
227+ auto dataset = additional_dataset_tensor->dl_tensor ;
228+ auto index = *index_c_ptr;
229+
230+ if ((dataset.dtype .code == kDLFloat ) && (dataset.dtype .bits == 32 )) {
231+ _extend<float >(res, *params, index, additional_dataset_tensor);
232+ } else if (dataset.dtype .code == kDLInt && dataset.dtype .bits == 8 ) {
233+ _extend<int8_t >(res, *params, index, additional_dataset_tensor);
234+ } else if (dataset.dtype .code == kDLUInt && dataset.dtype .bits == 8 ) {
235+ _extend<uint8_t >(res, *params, index, additional_dataset_tensor);
236+ } else {
237+ RAFT_FAIL (" Unsupported dataset DLtensor dtype: %d and bits: %d" ,
238+ dataset.dtype .code ,
239+ dataset.dtype .bits );
240+ }
241+ });
242+ }
243+
193244extern " C" cuvsError_t cuvsCagraSearch (cuvsResources_t res,
194245 cuvsCagraSearchParams_t params,
195246 cuvsCagraIndex_t index_c_ptr,
@@ -265,6 +316,17 @@ extern "C" cuvsError_t cuvsCagraCompressionParamsDestroy(cuvsCagraCompressionPar
265316 return cuvs::core::translate_exceptions ([=] { delete params; });
266317}
267318
319+ extern " C" cuvsError_t cuvsCagraExtendParamsCreate (cuvsCagraExtendParams_t* params)
320+ {
321+ return cuvs::core::translate_exceptions (
322+ [=] { *params = new cuvsCagraExtendParams{.max_chunk_size = 0 }; });
323+ }
324+
325+ extern " C" cuvsError_t cuvsCagraExtendParamsDestroy (cuvsCagraExtendParams_t params)
326+ {
327+ return cuvs::core::translate_exceptions ([=] { delete params; });
328+ }
329+
268330extern " C" cuvsError_t cuvsCagraSearchParamsCreate (cuvsCagraSearchParams_t* params)
269331{
270332 return cuvs::core::translate_exceptions ([=] {
0 commit comments