@@ -107,17 +107,15 @@ def randn(n, seed=12345):
107107def checksum (a ):
108108 """ compute a checksum for quick-and-dirty comparisons of arrays """
109109 a = a .view ('uint8' )
110- n = a . size
111- n4 = n & ~ 3
112- cs = ivec_checksum ( int ( n4 / 4 ), swig_ptr ( a [: n4 ]. view ( 'int32' )))
113- for i in range ( n4 , n ):
114- cs += x [ i ] * 33657
110+ if a . ndim == 1 :
111+ return bvec_checksum ( s . size , swig_ptr ( a ))
112+ n , d = a . shape
113+ cs = np . zeros ( n , dtype = 'uint64' )
114+ bvecs_checksum ( n , d , swig_ptr ( a ), swig_ptr ( cs ))
115115 return cs
116116
117-
118117rand_smooth_vectors_c = rand_smooth_vectors
119118
120-
121119def rand_smooth_vectors (n , d , seed = 1234 ):
122120 res = np .empty ((n , d ), dtype = 'float32' )
123121 rand_smooth_vectors_c (n , d , swig_ptr (res ), seed )
@@ -422,7 +420,7 @@ def __init__(self, d, k, **kwargs):
422420 including niter=25, verbose=False, spherical = False
423421 """
424422 self .d = d
425- self .k = k
423+ self .reset ( k )
426424 self .gpu = False
427425 if "progressive_dim_steps" in kwargs :
428426 self .cp = ProgressiveDimClusteringParameters ()
@@ -437,7 +435,32 @@ def __init__(self, d, k, **kwargs):
437435 # if this raises an exception, it means that it is a non-existent field
438436 getattr (self .cp , k )
439437 setattr (self .cp , k , v )
438+ self .set_index ()
439+
440+ def set_index (self ):
441+ d = self .d
442+ if self .cp .__class__ == ClusteringParameters :
443+ if self .cp .spherical :
444+ self .index = IndexFlatIP (d )
445+ else :
446+ self .index = IndexFlatL2 (d )
447+ if self .gpu :
448+ self .index = faiss .index_cpu_to_all_gpus (self .index , ngpu = self .gpu )
449+ else :
450+ if self .gpu :
451+ fac = GpuProgressiveDimIndexFactory (ngpu = self .gpu )
452+ else :
453+ fac = ProgressiveDimIndexFactory ()
454+ self .fac = fac
455+
456+ def reset (self , k = None ):
457+ """ prepare k-means object to perform a new clustering, possibly
458+ with another number of centroids """
459+ if k is not None :
460+ self .k = int (k )
440461 self .centroids = None
462+ self .obj = None
463+ self .iteration_stats = None
441464
442465 def train (self , x , weights = None , init_centroids = None ):
443466 """ Perform k-means clustering.
@@ -476,24 +499,14 @@ def train(self, x, weights=None, init_centroids=None):
476499 nc , d2 = init_centroids .shape
477500 assert d2 == d
478501 faiss .copy_array_to_vector (init_centroids .ravel (), clus .centroids )
479- if self .cp .spherical :
480- self .index = IndexFlatIP (d )
481- else :
482- self .index = IndexFlatL2 (d )
483- if self .gpu :
484- self .index = faiss .index_cpu_to_all_gpus (self .index , ngpu = self .gpu )
485502 clus .train (x , self .index , weights )
486503 else :
487504 # not supported for progressive dim
488505 assert weights is None
489506 assert init_centroids is None
490507 assert not self .cp .spherical
491508 clus = ProgressiveDimClustering (d , self .k , self .cp )
492- if self .gpu :
493- fac = GpuProgressiveDimIndexFactory (ngpu = self .gpu )
494- else :
495- fac = ProgressiveDimIndexFactory ()
496- clus .train (n , swig_ptr (x ), fac )
509+ clus .train (n , swig_ptr (x ), self .fac )
497510
498511 centroids = faiss .vector_float_to_array (clus .centroids )
499512
0 commit comments