2828import sys
2929import numpy as np
3030
31+ ##################################################################
32+ # Equivalent of swig_ptr for Torch tensors
33+ ##################################################################
34+
3135def swig_ptr_from_UInt8Tensor (x ):
3236 """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
3337 assert x .is_contiguous ()
3438 assert x .dtype == torch .uint8
3539 return faiss .cast_integer_to_uint8_ptr (
3640 x .untyped_storage ().data_ptr () + x .storage_offset ())
3741
42+
3843def swig_ptr_from_HalfTensor (x ):
3944 """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
4045 assert x .is_contiguous ()
@@ -43,27 +48,34 @@ def swig_ptr_from_HalfTensor(x):
4348 return faiss .cast_integer_to_void_ptr (
4449 x .untyped_storage ().data_ptr () + x .storage_offset () * 2 )
4550
51+
4652def swig_ptr_from_FloatTensor (x ):
4753 """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
4854 assert x .is_contiguous ()
4955 assert x .dtype == torch .float32
5056 return faiss .cast_integer_to_float_ptr (
5157 x .untyped_storage ().data_ptr () + x .storage_offset () * 4 )
5258
59+
5360def swig_ptr_from_IntTensor (x ):
5461 """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
5562 assert x .is_contiguous ()
5663 assert x .dtype == torch .int32 , 'dtype=%s' % x .dtype
5764 return faiss .cast_integer_to_int_ptr (
5865 x .untyped_storage ().data_ptr () + x .storage_offset () * 4 )
5966
67+
6068def swig_ptr_from_IndicesTensor (x ):
6169 """ gets a Faiss SWIG pointer from a pytorch tensor (on CPU or GPU) """
6270 assert x .is_contiguous ()
6371 assert x .dtype == torch .int64 , 'dtype=%s' % x .dtype
6472 return faiss .cast_integer_to_idx_t_ptr (
6573 x .untyped_storage ().data_ptr () + x .storage_offset () * 8 )
6674
75+ ##################################################################
76+ # utilities
77+ ##################################################################
78+
6779@contextlib .contextmanager
6880def using_stream (res , pytorch_stream = None ):
6981 """ Creates a scoping object to make Faiss GPU use the same stream
@@ -107,6 +119,10 @@ def torch_replace_method(the_class, name, replacement,
107119 setattr (the_class , name + '_numpy' , orig_method )
108120 setattr (the_class , name , replacement )
109121
122+ ##################################################################
123+ # Setup wrappers
124+ ##################################################################
125+
110126def handle_torch_Index (the_class ):
111127 def torch_replacement_add (self , x ):
112128 if type (x ) is np .ndarray :
@@ -493,6 +509,52 @@ def torch_replacement_sa_decode(self, codes, x=None):
493509 handle_torch_Index (the_class )
494510
495511
512+ # allows torch tensor usage with knn
513+ def torch_replacement_knn (xq , xb , k , metric = faiss .METRIC_L2 , metric_arg = 0 ):
514+ if type (xb ) is np .ndarray :
515+ # Forward to faiss __init__.py base method
516+ return faiss .knn_numpy (xq , xb , k , metric = metric , metric_arg = metric_arg )
517+
518+ nb , d = xb .size ()
519+ assert xb .is_contiguous ()
520+ assert xb .dtype == torch .float32
521+ assert not xb .is_cuda , "use knn_gpu for GPU tensors"
522+
523+ nq , d2 = xq .size ()
524+ assert d2 == d
525+ assert xq .is_contiguous ()
526+ assert xq .dtype == torch .float32
527+ assert not xq .is_cuda , "use knn_gpu for GPU tensors"
528+
529+ D = torch .empty (nq , k , device = xb .device , dtype = torch .float32 )
530+ I = torch .empty (nq , k , device = xb .device , dtype = torch .int64 )
531+ I_ptr = swig_ptr_from_IndicesTensor (I )
532+ D_ptr = swig_ptr_from_FloatTensor (D )
533+ xb_ptr = swig_ptr_from_FloatTensor (xb )
534+ xq_ptr = swig_ptr_from_FloatTensor (xq )
535+
536+ if metric == faiss .METRIC_L2 :
537+ faiss .knn_L2sqr (
538+ xq_ptr , xb_ptr ,
539+ d , nq , nb , k , D_ptr , I_ptr
540+ )
541+ elif metric == faiss .METRIC_INNER_PRODUCT :
542+ faiss .knn_inner_product (
543+ xq_ptr , xb_ptr ,
544+ d , nq , nb , k , D_ptr , I_ptr
545+ )
546+ else :
547+ faiss .knn_extra_metrics (
548+ xq_ptr , xb_ptr ,
549+ d , nq , nb , metric , metric_arg , k , D_ptr , I_ptr
550+ )
551+
552+ return D , I
553+
554+
555+ torch_replace_method (faiss_module , 'knn' , torch_replacement_knn , True , True )
556+
557+
496558# allows torch tensor usage with bfKnn
497559def torch_replacement_knn_gpu (res , xq , xb , k , D = None , I = None , metric = faiss .METRIC_L2 , device = - 1 , use_raft = False ):
498560 if type (xb ) is np .ndarray :
0 commit comments