Skip to content

Commit eed3860

Browse files
Michael Norrisfacebook-github-bot
authored andcommitted
stop dealloc of coarse quantizer when it is deleted (#4045)
Summary: X-link: meta-pytorch/torchrec#2603 Pull Request resolved: #4045 Need to add to `__init__.py` like Matthijs mentioned on the github issue #3993. But we can't do it for non-GPU code, otherwise it will throw an exception and fail many tests than include fbcode/faiss. So we need to check if FAISS GPU is importable first. To find the class names like GpuIndexIVFFlat etc, I checked everything under faiss/gpu where the constructor accepts an Index. The other Index is always parameter at index 1 (0-indexed), so that's why we use 1 in the function calls. Reviewed By: pankajsingh88 Differential Revision: D66675910 fbshipit-source-id: f170dadb6318c620420689164f9522f9815aa980
1 parent f8ae5f4 commit eed3860

2 files changed

Lines changed: 108 additions & 0 deletions

File tree

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from __future__ import print_function
7+
import unittest
8+
import numpy as np
9+
import faiss
10+
from enum import Enum
11+
from faiss.contrib.datasets import SyntheticDataset
12+
13+
14+
class DeletionSite(Enum):
15+
BEFORE_TRAIN = 1
16+
BEFORE_ADD = 2
17+
BEFORE_SEARCH = 3
18+
19+
20+
def do_test(idx, index_to_delete, db, deletion_site: DeletionSite):
21+
if deletion_site == DeletionSite.BEFORE_TRAIN:
22+
del index_to_delete
23+
24+
idx.train(db)
25+
26+
if deletion_site == DeletionSite.BEFORE_ADD:
27+
del index_to_delete
28+
29+
idx.add(db)
30+
31+
if deletion_site == DeletionSite.BEFORE_SEARCH:
32+
del index_to_delete
33+
34+
idx.search(db, 1)
35+
36+
37+
def do_multi_test(idx, index_to_delete, db):
38+
for site in DeletionSite:
39+
do_test(idx, index_to_delete, db, site)
40+
41+
42+
#
43+
# Test
44+
#
45+
46+
47+
class TestRefs(unittest.TestCase):
48+
d = 32
49+
nv = 1000
50+
nlist = 10
51+
res = faiss.StandardGpuResources() # pyre-ignore
52+
db = np.random.rand(nv, d)
53+
54+
# These GPU classes reference another index.
55+
# This tests to make sure the deletion of the other index
56+
# does not cause a crash.
57+
58+
def test_GpuIndexIVFFlat(self):
59+
index_to_delete = faiss.IndexIVFFlat(
60+
faiss.IndexFlat(self.d), self.d, self.nlist
61+
)
62+
idx = faiss.GpuIndexIVFFlat(
63+
self.res, index_to_delete, faiss.GpuIndexIVFFlatConfig()
64+
)
65+
do_multi_test(idx, index_to_delete, self.db)
66+
67+
def test_GpuIndexBinaryFlat(self):
68+
ds = SyntheticDataset(64, 1000, 1000, 200)
69+
index_to_delete = faiss.IndexBinaryFlat(ds.d)
70+
idx = faiss.GpuIndexBinaryFlat(self.res, index_to_delete)
71+
tobinary = faiss.index_factory(ds.d, "LSHrt")
72+
tobinary.train(ds.get_train())
73+
xb = tobinary.sa_encode(ds.get_database())
74+
do_multi_test(idx, index_to_delete, xb)
75+
76+
def test_GpuIndexFlat(self):
77+
index_to_delete = faiss.IndexFlat(self.d, faiss.METRIC_L2)
78+
idx = faiss.GpuIndexFlat(self.res, index_to_delete)
79+
do_multi_test(idx, index_to_delete, self.db)
80+
81+
def test_GpuIndexIVFPQ(self):
82+
index_to_delete = faiss.IndexIVFPQ(
83+
faiss.IndexFlatL2(self.d),
84+
self.d, self.nlist, 2, 8)
85+
idx = faiss.GpuIndexIVFPQ(self.res, index_to_delete)
86+
do_multi_test(idx, index_to_delete, self.db)
87+
88+
def test_GpuIndexIVFScalarQuantizer(self):
89+
index_to_delete = faiss.IndexIVFScalarQuantizer(
90+
faiss.IndexFlat(self.d, faiss.METRIC_L2),
91+
self.d,
92+
self.nlist,
93+
faiss.ScalarQuantizer.QT_8bit_direct,
94+
faiss.METRIC_L2,
95+
False
96+
)
97+
idx = faiss.GpuIndexIVFScalarQuantizer(self.res, index_to_delete)
98+
do_multi_test(idx, index_to_delete, self.db)

faiss/python/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,16 @@ def replacement_function(*args):
160160
setattr(this_module, function_name, replacement_function)
161161

162162

163+
try:
164+
from swigfaiss_gpu import GpuIndexIVFFlat, GpuIndexBinaryFlat, GpuIndexFlat, GpuIndexIVFPQ, GpuIndexIVFScalarQuantizer
165+
add_ref_in_constructor(GpuIndexIVFFlat, 1)
166+
add_ref_in_constructor(GpuIndexBinaryFlat, 1)
167+
add_ref_in_constructor(GpuIndexFlat, 1)
168+
add_ref_in_constructor(GpuIndexIVFPQ, 1)
169+
add_ref_in_constructor(GpuIndexIVFScalarQuantizer, 1)
170+
except ImportError as e:
171+
print("Failed to load GPU Faiss: %s. Will not load constructor refs for GPU indexes." % e.args[0])
172+
163173
add_ref_in_constructor(IndexIVFFlat, 0)
164174
add_ref_in_constructor(IndexIVFFlatDedup, 0)
165175
add_ref_in_constructor(IndexPreTransform, {2: [0, 1], 1: [0]})

0 commit comments

Comments
 (0)