Skip to content

Commit 900241c

Browse files
mdouzefacebook-github-bot
authored andcommitted
rewrite python kmeans without scipy (#3873)
Summary: Pull Request resolved: #3873 The previous version required scipy to do the accumulation, which is replaced here with a nifty piece of numpy accumulation. This removes the need for scipy for non-sparse data. Reviewed By: junjieqi Differential Revision: D62884307
1 parent 4e30901 commit 900241c

3 files changed

Lines changed: 28 additions & 28 deletions

File tree

contrib/clustering.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,12 @@ def assign_to(self, centroids, weights=None):
151151

152152
I = I.ravel()
153153
D = D.ravel()
154-
n = len(self.x)
154+
nc, d = centroids.shape
155+
sum_per_centroid = np.zeros((nc, d), dtype='float32')
155156
if weights is None:
156-
weights = np.ones(n, dtype='float32')
157-
nc = len(centroids)
158-
m = scipy.sparse.csc_matrix(
159-
(weights, I, np.arange(n + 1)),
160-
shape=(nc, n))
161-
sum_per_centroid = m * self.x
157+
np.add.at(sum_per_centroid, I, self.x)
158+
else:
159+
np.add.at(sum_per_centroid, I, weights[:, np.newaxis] * self.x)
162160

163161
return I, D, sum_per_centroid
164162

@@ -185,7 +183,8 @@ def perform_search(self, centroids):
185183

186184
def sparse_assign_to_dense(xq, xb, xq_norms=None, xb_norms=None):
187185
""" assignment function for xq is sparse, xb is dense
188-
uses a matrix multiplication. The squared norms can be provided if available.
186+
uses a matrix multiplication. The squared norms can be provided if
187+
available.
189188
"""
190189
nq = xq.shape[0]
191190
nb = xb.shape[0]
@@ -272,6 +271,7 @@ def assign_to(self, centroids, weights=None):
272271
if weights is None:
273272
weights = np.ones(n, dtype='float32')
274273
nc = len(centroids)
274+
275275
m = scipy.sparse.csc_matrix(
276276
(weights, I, np.arange(n + 1)),
277277
shape=(nc, n))

tests/test_contrib.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,26 @@ def test_binary(self):
517517

518518
class TestClustering(unittest.TestCase):
519519

520+
def test_python_kmeans(self):
521+
""" Test the python implementation of kmeans """
522+
ds = datasets.SyntheticDataset(32, 10000, 0, 0)
523+
x = ds.get_train()
524+
525+
# bad distribution to stress-test split code
526+
xt = x[:10000].copy()
527+
xt[:5000] = x[0]
528+
529+
km_ref = faiss.Kmeans(ds.d, 100, niter=10)
530+
km_ref.train(xt)
531+
err = faiss.knn(xt, km_ref.centroids, 1)[0].sum()
532+
533+
data = clustering.DatasetAssign(xt)
534+
centroids = clustering.kmeans(100, data, 10)
535+
err2 = faiss.knn(xt, centroids, 1)[0].sum()
536+
537+
# err=33498.332 err2=33380.477
538+
self.assertLess(err2, err * 1.1)
539+
520540
def test_2level(self):
521541
" verify that 2-level clustering is not too sub-optimal "
522542
ds = datasets.SyntheticDataset(32, 10000, 0, 0)

tests/test_contrib_with_scipy.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,26 +17,6 @@
1717

1818
class TestClustering(unittest.TestCase):
1919

20-
def test_python_kmeans(self):
21-
""" Test the python implementation of kmeans """
22-
ds = datasets.SyntheticDataset(32, 10000, 0, 0)
23-
x = ds.get_train()
24-
25-
# bad distribution to stress-test split code
26-
xt = x[:10000].copy()
27-
xt[:5000] = x[0]
28-
29-
km_ref = faiss.Kmeans(ds.d, 100, niter=10)
30-
km_ref.train(xt)
31-
err = faiss.knn(xt, km_ref.centroids, 1)[0].sum()
32-
33-
data = clustering.DatasetAssign(xt)
34-
centroids = clustering.kmeans(100, data, 10)
35-
err2 = faiss.knn(xt, centroids, 1)[0].sum()
36-
37-
# 33517.645 and 33031.098
38-
self.assertLess(err2, err * 1.1)
39-
4020
def test_sparse_routines(self):
4121
""" the sparse assignment routine """
4222
ds = datasets.SyntheticDataset(1000, 2000, 0, 200)

0 commit comments

Comments
 (0)