@@ -256,7 +256,12 @@ extern "C" cuvsError_t cuvsMultiGpuIvfPqBuild(cuvsResources_t res,
256256 cuvsMultiGpuIvfPqIndex_t index)
257257{
258258 return cuvs::core::translate_exceptions ([=] {
259- auto dataset = dataset_tensor->dl_tensor ;
259+ auto dataset = dataset_tensor->dl_tensor ;
260+
261+ // Multi-GPU IVF-PQ requires the dataset to be in host-compatible memory
262+ RAFT_EXPECTS (cuvs::core::is_dlpack_host_compatible (dataset),
263+ " Multi-GPU IVF-PQ build requires dataset to have host compatible memory" );
264+
260265 index->dtype .code = dataset.dtype .code ;
261266 index->dtype .bits = dataset.dtype .bits ;
262267
@@ -284,7 +289,29 @@ extern "C" cuvsError_t cuvsMultiGpuIvfPqSearch(cuvsResources_t res,
284289 DLManagedTensor* distances_tensor)
285290{
286291 return cuvs::core::translate_exceptions ([=] {
287- auto queries = queries_tensor->dl_tensor ;
292+ auto queries = queries_tensor->dl_tensor ;
293+ auto neighbors = neighbors_tensor->dl_tensor ;
294+ auto distances = distances_tensor->dl_tensor ;
295+
296+ // Multi-GPU IVF-PQ requires queries, neighbors and distances to be in host-compatible memory
297+ RAFT_EXPECTS (cuvs::core::is_dlpack_host_compatible (queries),
298+ " Multi-GPU IVF-PQ search requires queries to have host compatible memory" );
299+ RAFT_EXPECTS (cuvs::core::is_dlpack_host_compatible (neighbors),
300+ " Multi-GPU IVF-PQ search requires neighbors to have host compatible memory" );
301+ RAFT_EXPECTS (cuvs::core::is_dlpack_host_compatible (distances),
302+ " Multi-GPU IVF-PQ search requires distances to have host compatible memory" );
303+
304+ // Validate data types
305+ RAFT_EXPECTS (neighbors.dtype .code == kDLInt && neighbors.dtype .bits == 64 ,
306+ " neighbors should be of type int64_t" );
307+ RAFT_EXPECTS (distances.dtype .code == kDLFloat && distances.dtype .bits == 32 ,
308+ " distances should be of type float32" );
309+
310+ // Check type compatibility between index and queries
311+ RAFT_EXPECTS (queries.dtype .code == index->dtype .code ,
312+ " type mismatch between index and queries" );
313+ RAFT_EXPECTS (queries.dtype .bits == index->dtype .bits ,
314+ " type mismatch between index and queries" );
288315
289316 if (queries.dtype .code == kDLFloat && queries.dtype .bits == 32 ) {
290317 _mg_search<float >(res, *params, *index, queries_tensor, neighbors_tensor, distances_tensor);
@@ -310,6 +337,25 @@ extern "C" cuvsError_t cuvsMultiGpuIvfPqExtend(cuvsResources_t res,
310337 return cuvs::core::translate_exceptions ([=] {
311338 auto vectors = new_vectors_tensor->dl_tensor ;
312339
340+ // Multi-GPU IVF-PQ requires new_vectors to be in host-compatible memory
341+ RAFT_EXPECTS (cuvs::core::is_dlpack_host_compatible (vectors),
342+ " Multi-GPU IVF-PQ extend requires new_vectors to have host compatible memory" );
343+
344+ // Check type compatibility between index and vectors
345+ RAFT_EXPECTS (vectors.dtype .code == index->dtype .code ,
346+ " type mismatch between index and new_vectors" );
347+ RAFT_EXPECTS (vectors.dtype .bits == index->dtype .bits ,
348+ " type mismatch between index and new_vectors" );
349+
350+ // If indices are provided, they must also be in host-compatible memory and of type int64
351+ if (new_indices_tensor != nullptr ) {
352+ auto indices = new_indices_tensor->dl_tensor ;
353+ RAFT_EXPECTS (cuvs::core::is_dlpack_host_compatible (indices),
354+ " Multi-GPU IVF-PQ extend requires new_indices to have host compatible memory" );
355+ RAFT_EXPECTS (indices.dtype .code == kDLInt && indices.dtype .bits == 64 ,
356+ " new_indices should be of type int64_t" );
357+ }
358+
313359 if (vectors.dtype .code == kDLFloat && vectors.dtype .bits == 32 ) {
314360 _mg_extend<float >(res, *index, new_vectors_tensor, new_indices_tensor);
315361 } else if (vectors.dtype .code == kDLFloat && vectors.dtype .bits == 16 ) {
@@ -381,28 +427,8 @@ extern "C" cuvsError_t cuvsMultiGpuIvfPqDistribute(cuvsResources_t res,
381427 cuvsMultiGpuIvfPqIndex_t index)
382428{
383429 return cuvs::core::translate_exceptions ([=] {
384- std::ifstream is (filename, std::ios::in | std::ios::binary);
385- if (!is) { RAFT_FAIL (" Cannot open file %s" , filename); }
386- char dtype_string[4 ];
387- is.read (dtype_string, 4 );
388- auto dtype = raft::detail::numpy_serializer::parse_descr (std::string (dtype_string, 4 ));
389- is.close ();
390-
391- index->dtype .bits = dtype.itemsize * 8 ;
392- if (dtype.kind == ' f' && dtype.itemsize == 4 ) {
393- index->dtype .code = kDLFloat ;
394- index->addr = reinterpret_cast <uintptr_t >(_mg_distribute<float >(res, filename));
395- } else if (dtype.kind == ' f' && dtype.itemsize == 2 ) {
396- index->dtype .code = kDLFloat ;
397- index->addr = reinterpret_cast <uintptr_t >(_mg_distribute<half>(res, filename));
398- } else if (dtype.kind == ' i' && dtype.itemsize == 1 ) {
399- index->dtype .code = kDLInt ;
400- index->addr = reinterpret_cast <uintptr_t >(_mg_distribute<int8_t >(res, filename));
401- } else if (dtype.kind == ' u' && dtype.itemsize == 1 ) {
402- index->dtype .code = kDLUInt ;
403- index->addr = reinterpret_cast <uintptr_t >(_mg_distribute<uint8_t >(res, filename));
404- } else {
405- RAFT_FAIL (" Unsupported index dtype" );
406- }
430+ index->dtype .code = kDLFloat ;
431+ index->dtype .bits = 32 ;
432+ index->addr = reinterpret_cast <uintptr_t >(_mg_distribute<float >(res, filename));
407433 });
408434}
0 commit comments