How to modify (override) kmeans centroids and do inference #2353

@scaramouche88

Summary

I'm new to FAISS, but it works nicely so far.
I may not fully understand it yet, so apologies if my questions are silly.

I would like to train a k-means model, run inference, do some math on the centroids, and then run inference again with the modified centroids, without retraining.
What is the best way to do this?
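
To make clear what I mean by "inference": it is only a nearest-centroid lookup, so in principle it should depend on the centroids alone and not need retraining. A minimal NumPy-only sketch (no FAISS involved; `assign_to_centroids`, the toy points and the row swap are hypothetical, used only for illustration):

```python
import numpy as np


def assign_to_centroids(x, centroids):
    """Return, for each row of x, the index of the closest centroid (squared L2)."""
    # (n, 1, d) - (1, k, d) broadcasts to (n, k, d); summing over d gives (n, k)
    dists = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
    return dists.argmin(axis=1)


x = np.array([[0.0, 0.0], [1.0, 1.0], [9.0, 9.0]], dtype="float32")

# Centroids as they would come out of training (made-up values here)
centroids1 = np.array([[0.0, 0.0], [10.0, 10.0]], dtype="float32")
labels1 = assign_to_centroids(x, centroids1)

# "Math on the centroids": here just a swap of the two rows, as a placeholder
centroids2 = centroids1[::-1].copy()
labels2 = assign_to_centroids(x, centroids2)
```

The second call uses only `centroids2`; nothing from the first assignment is reused, which is the behaviour I would like to get from FAISS.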

Reproduction instructions

Here's the important part of my class:

import copy

import faiss
import numpy as np

class Clustering():
    def __init__(self, nb_class, start_centr, nb_iter, threshold, gpu=True):
        self.nb_class = nb_class
        self.centroids = start_centr
        self.iter = nb_iter
        self.threshold = threshold
        self.gpu = gpu
        self.nb_feat = 16
        # the seed is passed as a keyword argument so that it reaches the clustering parameters
        self.kmeans = faiss.Kmeans(self.nb_feat, self.nb_class, niter=self.iter,
                                   seed=int(np.random.randint(1234)), gpu=self.gpu)
        self.kmeans.centroids = self.centroids

    def train_kmean(self, array, nb_feat):
        # note: self.kmeans was already built with 16 features in __init__
        self.nb_feat = nb_feat
        # this function prepares the data for FAISS - not reported
        array = self.reshape_array_for_faiss(array)
        self.kmeans.train(array, init_centroids=self.centroids)
        _, I = self.kmeans.index.search(array, 1)
        self.centroids = copy.deepcopy(self.kmeans.centroids)
        loss = self.kmeans.obj[-1]
        # this function prepares the data for the rest of the code - not reported
        I = self.reshape_array_for_keras(I)
        return I, loss

    def val_kmean(self, array):
        array = self.reshape_array_for_faiss(array)
        # searches whatever centroids the index currently holds
        _, I = self.kmeans.index.search(array, 1)
        # obj is only filled in by train(), so this is the last training loss
        loss = self.kmeans.obj[-1]
        I = self.reshape_array_for_keras(I)
        return I, loss

Then in the main code:

[...]
self.clustering = Clustering(2, self.centroids, 30, 10000, gpu=True)
data_kmeans, kloss = self.clustering.train_kmean(data, data.shape[-1])

new_data_kmeans1, _ = self.clustering.val_kmean(new_data)
centroids1 = getattr(self.clustering.kmeans,'centroids')

[...doing maths on centroids1 to obtain centroids2...]

setattr(self.clustering.kmeans,'centroids',centroids2)
new_data_kmeans2, _ = self.clustering.val_kmean(new_data)

However, looking at #1940 and https://gist.github.com/mdouze/9eb96d941c94ef59482a069e5862a650, I have the impression that setting the centroids attribute does not actually update the index used for the search.

Should it be something like this?
self.clustering.kmeans.index.add(centroids2)
new_data_kmeans2, _ = self.clustering.val_kmean(new_data)
If so, how do I override the old centroids instead of just adding to them?

Running on:

  • CPU
  • GPU

Interface:

  • C++
  • Python
