Skip to content

Commit d70e64b

Browse files
author
Gustav von Zitzewitz
committed
add missing sel to hamming.h, add no heap test case, simplify valid_counter
1 parent c5c9cab commit d70e64b

3 files changed

Lines changed: 29 additions & 13 deletions

File tree

faiss/utils/approx_topk_hamming/approx_topk_hamming.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,19 +99,19 @@ struct HeapWithBucketsForHamming32<
9999
for (uint32_t ip = 0; ip < nb; ip += NBUCKETS) {
100100
for (uint32_t j = 0; j < NBUCKETS_8; j++) {
101101
uint32_t hamming_distances[8];
102-
uint32_t valid_mask = 0;
102+
uint8_t valid_counter = 0;
103103
for (size_t j8 = 0; j8 < 8; j8++) {
104104
const uint32_t idx = j8 + j * 8 + ip + n_per_beam * beam_index;
105105
if (!sel || sel->is_member(idx)) {
106106
hamming_distances[j8] = hc.hamming(
107107
binary_vectors + idx * code_size);
108-
valid_mask |= (1 << j8);
108+
valid_counter++;
109109
} else {
110110
hamming_distances[j8] = std::numeric_limits<int32_t>::max();
111111
}
112112
}
113113

114-
if (valid_mask == 0) {
114+
if (valid_counter == 0) {
115115
continue; // Skip if all vectors are filtered out
116116
}
117117

faiss/utils/hamming.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ void hammings_knn_hc(
135135
size_t nb,
136136
size_t ncodes,
137137
int ordered,
138-
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK);
138+
ApproxTopK_mode_t approx_topk_mode = ApproxTopK_mode_t::EXACT_TOPK,
139+
const IDSelector* sel = nullptr);
139140

140141
/* Legacy alias to hammings_knn_hc. */
141142
void hammings_knn(
@@ -166,7 +167,8 @@ void hammings_knn_mc(
166167
size_t k,
167168
size_t ncodes,
168169
int32_t* distances,
169-
int64_t* labels);
170+
int64_t* labels,
171+
const IDSelector* sel = nullptr);
170172

171173
/** same as hammings_knn except we are doing a range search with radius */
172174
void hamming_range_search(
@@ -176,7 +178,8 @@ void hamming_range_search(
176178
size_t nb,
177179
int radius,
178180
size_t ncodes,
179-
RangeSearchResult* result);
181+
RangeSearchResult* result,
182+
const IDSelector* sel = nullptr);
180183

181184
/* Counting the number of matches or of cross-matches (without returning them)
182185
For use with function that assume pre-allocated memory */

tests/test_search_params.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ class TestSelector(unittest.TestCase):
2222
combinations as possible.
2323
"""
2424

25-
def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2, k=10):
25+
def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METRIC_L2, k=10, params=None):
2626
""" Verify that the id selector returns the subset of results that are
2727
members according to the IDSelector.
2828
Supports id_selector_type="batch", "bitmap", "range", "range_sorted", "and", "or", "xor"
29+
params: optional SearchParameters object to override default settings
2930
"""
3031
d = 32 # make sure dimension is multiple of 8 for binary
3132
ds = datasets.SyntheticDataset(d, 1000, 100, 20)
@@ -73,6 +74,8 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR
7374
subset = rs.choice(ds.nb, 50, replace=False).astype('int64')
7475

7576
index.add(xb[subset])
77+
if "IVF" in index_key and id_selector_type == "range_sorted":
78+
self.assertTrue(index.check_ids_sorted())
7679
Dref, Iref0 = index.search(xq, k)
7780
Iref = subset[Iref0]
7881
Iref[Iref0 < 0] = -1
@@ -134,11 +137,16 @@ def do_test_id_selector(self, index_key, id_selector_type="batch", mt=faiss.METR
134137
else:
135138
sel = faiss.IDSelectorBatch(subset)
136139

137-
params = (
138-
faiss.SearchParametersIVF(sel=sel) if "IVF" in index_key else
139-
faiss.SearchParametersPQ(sel=sel) if "PQ" in index_key else
140-
faiss.SearchParameters(sel=sel)
141-
)
140+
if params is None:
141+
params = (
142+
faiss.SearchParametersIVF(sel=sel) if "IVF" in index_key else
143+
faiss.SearchParametersPQ(sel=sel) if "PQ" in index_key else
144+
faiss.SearchParameters(sel=sel)
145+
)
146+
else:
147+
# Use provided params but ensure selector is set
148+
params.sel = sel
149+
142150
Dnew, Inew = index.search(xq, k, params=params)
143151
np.testing.assert_array_equal(Iref, Inew)
144152
np.testing.assert_almost_equal(Dref, Dnew, decimal=5)
@@ -308,6 +316,11 @@ def test_BinaryFlat_id_range(self):
308316
def test_BinaryFlat_id_array(self):
309317
self.do_test_id_selector("BinaryFlat", id_selector_type="array")
310318

319+
def test_BinaryFlat_no_heap(self):
320+
params = faiss.SearchParameters()
321+
params.use_heap = False
322+
self.do_test_id_selector("BinaryFlat", params=params)
323+
311324

312325
class TestSearchParams(unittest.TestCase):
313326

@@ -528,4 +541,4 @@ def test_knn_and_range_PQ(self):
528541
self.do_test_knn_and_range("IVF32,PQ8x4np")
529542

530543
def test_knn_and_range_FS(self):
531-
self.do_test_knn_and_range("IVF32,PQ8x4fs", range=False)
544+
self.do_test_knn_and_range("IVF32,PQ8x4fs", range=False)

0 commit comments

Comments
 (0)