@@ -544,6 +544,7 @@ def replacement_range_search(self, x, thresh, *, params=None):
544544 n , d = x .shape
545545 assert d == self .d
546546 x = np .ascontiguousarray (x , dtype = 'float32' )
547+ thresh = float (thresh )
547548
548549 res = RangeSearchResult (n )
549550 self .range_search_c (n , swig_ptr (x ), thresh , res , params )
@@ -618,6 +619,64 @@ def replacement_search_preassigned(self, x, k, Iq, Dq, *, params=None, D=None, I
618619 )
619620 return D , I
620621
622+ def replacement_range_search_preassigned (self , x , thresh , Iq , Dq , * , params = None ):
623+ """Search vectors that are within a distance of the query vectors.
624+
625+ Parameters
626+ ----------
627+ x : array_like
628+ Query vectors, shape (n, d) where d is appropriate for the index.
629+ `dtype` must be float32.
630+ thresh : float
631+ Threshold to select neighbors. All elements within this radius are returned,
632+ except for maximum inner product indexes, where the elements above the
633+ threshold are returned
634+ Iq : array_like, optional
635+ Nearest centroids, size (n, nprobe)
636+ Dq : array_like, optional
637+ Distance array to the centroids, size (n, nprobe)
638+ params : SearchParameters
639+ Search parameters of the current search (overrides the class-level params)
640+
641+
642+ Returns
643+ -------
644+ lims: array_like
645+ Starting index of the results for each query vector, size n+1.
646+ D : array_like
647+ Distances of the nearest neighbors, shape `lims[n]`. The distances for
648+ query i are in `D[lims[i]:lims[i+1]]`.
649+ I : array_like
650+ Labels of nearest neighbors, shape `lims[n]`. The labels for query i
651+ are in `I[lims[i]:lims[i+1]]`.
652+
653+ """
654+ n , d = x .shape
655+ assert d == self .d
656+ x = np .ascontiguousarray (x , dtype = 'float32' )
657+
658+ Iq = np .ascontiguousarray (Iq , dtype = 'int64' )
659+ assert params is None , "params not supported"
660+ assert Iq .shape == (n , self .nprobe )
661+
662+ if Dq is not None :
663+ Dq = np .ascontiguousarray (Dq , dtype = 'float32' )
664+ assert Dq .shape == Iq .shape
665+
666+ thresh = float (thresh )
667+ res = RangeSearchResult (n )
668+ self .range_search_preassigned_c (
669+ n , swig_ptr (x ), thresh ,
670+ swig_ptr (Iq ), swig_ptr (Dq ),
671+ res
672+ )
673+ # get pointers and copy them
674+ lims = rev_swig_ptr (res .lims , n + 1 ).copy ()
675+ nd = int (lims [- 1 ])
676+ D = rev_swig_ptr (res .distances , nd ).copy ()
677+ I = rev_swig_ptr (res .labels , nd ).copy ()
678+ return lims , D , I
679+
621680 def replacement_sa_encode (self , x , codes = None ):
622681 n , d = x .shape
623682 assert d == self .d
@@ -675,8 +734,12 @@ def replacement_permute_entries(self, perm):
675734 ignore_missing = True )
676735 replace_method (the_class , 'search_and_reconstruct' ,
677736 replacement_search_and_reconstruct , ignore_missing = True )
737+
738+ # these ones are IVF-specific
678739 replace_method (the_class , 'search_preassigned' ,
679740 replacement_search_preassigned , ignore_missing = True )
741+ replace_method (the_class , 'range_search_preassigned' ,
742+ replacement_range_search_preassigned , ignore_missing = True )
680743 replace_method (the_class , 'sa_encode' , replacement_sa_encode )
681744 replace_method (the_class , 'sa_decode' , replacement_sa_decode )
682745 replace_method (the_class , 'add_sa_codes' , replacement_add_sa_codes ,
@@ -776,6 +839,36 @@ def replacement_range_search(self, x, thresh):
776839 I = rev_swig_ptr (res .labels , nd ).copy ()
777840 return lims , D , I
778841
842+ def replacement_range_search_preassigned (self , x , thresh , Iq , Dq , * , params = None ):
843+ n , d = x .shape
844+ x = _check_dtype_uint8 (x )
845+ assert d * 8 == self .d
846+
847+ Iq = np .ascontiguousarray (Iq , dtype = 'int64' )
848+ assert params is None , "params not supported"
849+ assert Iq .shape == (n , self .nprobe )
850+
851+ if Dq is not None :
852+ Dq = np .ascontiguousarray (Dq , dtype = 'int32' )
853+ assert Dq .shape == Iq .shape
854+
855+ thresh = int (thresh )
856+ res = RangeSearchResult (n )
857+ self .range_search_preassigned_c (
858+ n , swig_ptr (x ), thresh ,
859+ swig_ptr (Iq ), swig_ptr (Dq ),
860+ res
861+ )
862+ # get pointers and copy them
863+ lims = rev_swig_ptr (res .lims , n + 1 ).copy ()
864+ nd = int (lims [- 1 ])
865+ D = rev_swig_ptr (res .distances , nd ).copy ()
866+ I = rev_swig_ptr (res .labels , nd ).copy ()
867+ return lims , D , I
868+
869+
870+
871+
779872 def replacement_remove_ids (self , x ):
780873 if isinstance (x , IDSelector ):
781874 sel = x
@@ -794,6 +887,8 @@ def replacement_remove_ids(self, x):
794887 replace_method (the_class , 'remove_ids' , replacement_remove_ids )
795888 replace_method (the_class , 'search_preassigned' ,
796889 replacement_search_preassigned , ignore_missing = True )
890+ replace_method (the_class , 'range_search_preassigned' ,
891+ replacement_range_search_preassigned , ignore_missing = True )
797892
798893
799894def handle_VectorTransform (the_class ):
@@ -937,7 +1032,7 @@ def handle_MapLong2Long(the_class):
9371032
9381033 def replacement_map_add (self , keys , vals ):
9391034 n , = keys .shape
940- assert (n ,) == keys .shape
1035+ assert (n ,) == vals .shape
9411036 self .add_c (n , swig_ptr (keys ), swig_ptr (vals ))
9421037
9431038 def replacement_map_search_multiple (self , keys ):
0 commit comments