Skip to content

Commit 7bfe32f

Browse files
gtwang01 authored and facebook-github-bot committed
Add testing for utils/hamming.cpp
Summary: As title Differential Revision: D66976823
1 parent 96bc9c7 commit 7bfe32f

3 files changed

Lines changed: 295 additions & 1 deletion

File tree

faiss/utils/hamming.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
#include <faiss/utils/hamming.h>
2525

2626
#include <algorithm>
27-
#include <cmath>
2827
#include <cstdio>
2928
#include <memory>
3029
#include <vector>

tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ set(FAISS_TEST_SRC
3535
test_common_ivf_empty_index.cpp
3636
test_callback.cpp
3737
test_utils.cpp
38+
test_hamming.cpp
3839
)
3940

4041
add_executable(faiss_test ${FAISS_TEST_SRC})

tests/test_hamming.cpp

Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include <gtest/gtest.h>
9+
10+
#include <faiss/impl/FaissAssert.h>
11+
#include <faiss/utils/hamming.h>
12+
#include <random>
13+
14+
using namespace ::testing;
15+
16+
/**
 * Render a vector as one line of text for debug / assertion messages.
 *
 * Every element is converted with std::to_string. A "|" is written in front
 * of each group of `divider` consecutive elements and once more at the very
 * end; elements inside a group are separated by a single space.
 * Example: {1, 2, 3, 4, 5} with divider 2 gives "|1 2|3 4|5|".
 *
 * @param data     values to print (must not be null); empty data gives "|"
 * @param divider  number of elements per group; 0 means "a single group"
 *                 (guards against a modulo by zero)
 * @return the formatted string
 */
template <typename T>
std::string print_data(
        std::shared_ptr<std::vector<T>> data,
        const size_t divider) {
    std::string ret;
    // size_t index: avoids the signed/unsigned comparison against size()
    for (size_t i = 0; i < data->size(); ++i) {
        // a group starts at index 0 and at every multiple of divider
        const bool starts_group =
                (divider == 0) ? (i == 0) : (i % divider == 0);
        ret += starts_group ? "|" : " ";
        ret += std::to_string((*data)[i]);
    }
    ret += "|";
    return ret;
}
32+
33+
std::stringstream get_correct_hamming_example(
34+
const size_t na, // number of queries
35+
const size_t nb, // number of candidates
36+
const size_t k,
37+
const size_t code_size,
38+
std::shared_ptr<std::vector<uint8_t>> a,
39+
std::shared_ptr<std::vector<uint8_t>> b,
40+
std::shared_ptr<std::vector<long>> true_ids,
41+
std::shared_ptr<std::vector<int>> true_distances) {
42+
assert(nb > k);
43+
44+
// Initialization
45+
std::default_random_engine rng(123);
46+
std::uniform_int_distribution<int32_t> uniform(0, nb - 1);
47+
48+
const size_t nresults = na * k;
49+
50+
a->clear();
51+
a->resize(na * code_size, 1); // query vectors are all 1
52+
b->clear();
53+
b->resize(nb * code_size, 2); // database vectors are all 2
54+
55+
true_ids->clear();
56+
true_ids->reserve(nresults);
57+
true_distances->clear();
58+
true_distances->reserve(nresults);
59+
60+
// define correct ids (must be unique)
61+
std::set<long> correct_ids;
62+
do {
63+
correct_ids.insert(uniform(rng));
64+
} while (correct_ids.size() < k);
65+
66+
// replace database vector at id with vector more similar to query
67+
// ordered, so earlier ids must be more similar
68+
for (size_t nmatches = k; nmatches > 0; --nmatches) {
69+
// get id and erase it
70+
const size_t id = *correct_ids.begin();
71+
*correct_ids.erase(correct_ids.begin());
72+
73+
// assemble true id and distance at locations
74+
true_ids->push_back(id);
75+
true_distances->push_back(code_size - nmatches); // hamming dist
76+
for (size_t i = 0; i < nmatches; ++i) {
77+
b->begin()[id * code_size + i] = 1;
78+
}
79+
}
80+
81+
// true_ids and true_distances only contain results for the first query
82+
// each query is identical, so copy the first query na-1 times
83+
for (size_t i = 1; i < na; ++i) {
84+
true_ids->insert(
85+
true_ids->end(), true_ids->begin(), true_ids->begin() + k);
86+
true_distances->insert(
87+
true_distances->end(),
88+
true_distances->begin(),
89+
true_distances->begin() + k);
90+
}
91+
92+
// assemble string for debugging
93+
std::stringstream ret;
94+
ret << "na: " << na << std::endl
95+
<< "nb: " << nb << std::endl
96+
<< "k: " << k << std::endl
97+
<< "code_size: " << code_size << std::endl
98+
<< "a: " << print_data(a, code_size) << std::endl
99+
<< "b: " << print_data(b, code_size) << std::endl
100+
<< "true_ids: " << print_data(true_ids, k) << std::endl
101+
<< "true_distances: " << print_data(true_distances, k) << std::endl;
102+
return ret;
103+
}
104+
/**
 * Compares faiss::crosshamming_count_thres with a count obtained from
 * faiss::hamming on random codes, for code sizes of 8, 16, 32 and 64 bytes,
 * and checks that a code size of 65 bytes raises a faiss::FaissException.
 */
TEST(TestHamming, test_crosshamming_count_thres) {
    // Initialize randomizer
    std::default_random_engine rng(123);
    std::uniform_int_distribution<int32_t> uniform(0, 255);

    // Initialize inputs
    const size_t n = 10; // number of codes
    const hamdis_t hamming_threshold = 20;

    // one for each case - 65 is default
    // (ncodes is the size of one code in bytes: nbits = ncodes * 8)
    for (const int ncodes : {8, 16, 32, 64, 65}) {
        // initialize inputs
        const int nbits = ncodes * 8;
        const size_t nwords = nbits / 64;
        // 8 bytes per uint64_t word; the factor 2 is spare room so that the
        // shifted reads of the reference loop below stay inside the buffer
        std::vector<uint8_t> dbs(nwords * n * 8 * 2);
        for (uint8_t& byte : dbs) {
            byte = uniform(rng);
        }

        size_t count = 0;
        if (ncodes == 65) {
            // unsupported size: no reference count is needed
            ASSERT_THROW(
                    faiss::crosshamming_count_thres(
                            dbs.data(), n, hamming_threshold, ncodes, &count),
                    faiss::FaissException);
            continue;
        }

        // get true distance
        // NOTE(review): the second operand starts 2 words into the buffer
        // (bs2 = bs1 + 2), so code i is compared with the words at
        // 2 + j * nwords rather than with code j itself; presumably this
        // mirrors the access pattern of the library - confirm.
        // NOTE(review): for uniformly random codes of 64 bits or more a
        // distance below 20 is unlikely, so both counts may simply be 0;
        // consider adding near-duplicate codes to exercise the counting.
        size_t true_count = 0;
        uint64_t* bs1 = reinterpret_cast<uint64_t*>(dbs.data());
        uint64_t* bs2 = bs1 + 2;
        for (size_t i = 0; i < n; ++i) {
            for (size_t j = i + 1; j < n; ++j) {
                if (faiss::hamming(bs1 + i * nwords, bs2 + j * nwords, nwords) <
                    hamming_threshold) {
                    ++true_count;
                }
            }
        }

        // run test and check correctness
        faiss::crosshamming_count_thres(
                dbs.data(), n, hamming_threshold, ncodes, &count);

        ASSERT_EQ(count, true_count) << "ncodes = " << ncodes;
    }
}
152+
/**
 * Compares faiss::match_hamming_thres and faiss::hamming_count_thres with a
 * brute-force reference built on faiss::hamming, for code sizes of 8, 16, 32
 * and 64 bytes, and checks that a code size of 65 bytes makes both functions
 * raise a faiss::FaissException.
 */
TEST(TestHamming, test_hamming_thres) {
    // Initialize randomizer
    std::default_random_engine rng(123);
    std::uniform_int_distribution<int32_t> uniform(0, 255);

    // Initialize inputs
    const size_t n1 = 10;
    const size_t n2 = 15;
    const hamdis_t hamming_threshold = 100;

    // one for each case - 65 is default
    // (ncodes is the size of one code in bytes: nbits = ncodes * 8)
    for (const int ncodes : {8, 16, 32, 64, 65}) {
        // initialize inputs
        const int nbits = ncodes * 8;
        const size_t nwords = nbits / 64;
        std::vector<uint8_t> bs1(nwords * n1 * 8);
        std::vector<uint8_t> bs2(nwords * n2 * 8);
        for (uint8_t& byte : bs1) {
            byte = uniform(rng);
        }
        for (uint8_t& byte : bs2) {
            byte = uniform(rng);
        }

        // Output buffers sized for the worst case (every one of the n1 * n2
        // pairs matches; a match takes two index slots and one distance).
        // Sizing them from the expected result instead would turn any extra
        // match reported by the library into an out-of-bounds write rather
        // than a test failure.
        std::vector<int64_t> idx(2 * n1 * n2);
        std::vector<hamdis_t> dis(n1 * n2);

        if (ncodes == 65) {
            // unsupported size: no reference result is needed
            ASSERT_THROW(
                    faiss::match_hamming_thres(
                            bs1.data(),
                            bs2.data(),
                            n1,
                            n2,
                            hamming_threshold,
                            ncodes,
                            idx.data(),
                            dis.data()),
                    faiss::FaissException);
            ASSERT_THROW(
                    faiss::hamming_count_thres(
                            bs1.data(),
                            bs2.data(),
                            n1,
                            n2,
                            hamming_threshold,
                            ncodes,
                            nullptr),
                    faiss::FaissException);
            continue;
        }

        // get true distance
        size_t true_count = 0;
        std::vector<int64_t> true_idx; // flattened (i, j) pairs
        std::vector<hamdis_t> true_dis;

        uint64_t* bs1_64 = reinterpret_cast<uint64_t*>(bs1.data());
        uint64_t* bs2_64 = reinterpret_cast<uint64_t*>(bs2.data());
        for (size_t i = 0; i < n1; ++i) {
            for (size_t j = 0; j < n2; ++j) {
                const hamdis_t ham_dist = faiss::hamming(
                        bs1_64 + i * nwords, bs2_64 + j * nwords, nwords);
                // NOTE(review): strict comparison; confirm that the library
                // also excludes distances equal to the threshold
                if (ham_dist < hamming_threshold) {
                    ++true_count;
                    true_idx.push_back(i);
                    true_idx.push_back(j);
                    true_dis.push_back(ham_dist);
                }
            }
        }

        // run test and check correctness for both
        // match_hamming_thres and hamming_count_thres
        const size_t match_count = faiss::match_hamming_thres(
                bs1.data(),
                bs2.data(),
                n1,
                n2,
                hamming_threshold,
                ncodes,
                idx.data(),
                dis.data());
        size_t count_count = 0;
        faiss::hamming_count_thres(
                bs1.data(),
                bs2.data(),
                n1,
                n2,
                hamming_threshold,
                ncodes,
                &count_count);

        ASSERT_EQ(match_count, true_count) << "ncodes = " << ncodes;
        ASSERT_EQ(count_count, true_count) << "ncodes = " << ncodes;

        // keep only the entries that were reported before comparing
        idx.resize(2 * match_count);
        dis.resize(match_count);
        ASSERT_EQ(idx, true_idx) << "ncodes = " << ncodes;
        ASSERT_EQ(dis, true_dis) << "ncodes = " << ncodes;
    }
}
249+
250+
/**
 * Runs faiss::generalized_hammings_knn_hc and faiss::hammings_knn on the
 * synthetic dataset produced by get_correct_hamming_example and checks the
 * returned ids and distances against the known answer, for several code
 * sizes.
 */
TEST(TestHamming, test_hamming_knn) {
    // Initialize inputs
    const size_t na = 4;
    const size_t nb = 12; // number of candidates
    const size_t k = 6;

    auto a = std::make_shared<std::vector<uint8_t>>();
    auto b = std::make_shared<std::vector<uint8_t>>();
    auto true_ids = std::make_shared<std::vector<long>>();
    auto true_distances = std::make_shared<std::vector<int>>();

    // 8, 16, 32 are cases - 24 will hit default case
    // all should be multiples of 8
    for (const int code_size : {8, 16, 24, 32}) {
        // get example
        std::stringstream assert_str = get_correct_hamming_example(
                na, nb, k, code_size, a, b, true_ids, true_distances);

        // run test on generalized_hammings_knn_hc
        std::vector<long> ids_gen(na * k);
        std::vector<int> dist_gen(na * k);
        faiss::int_maxheap_array_t res = {
                na, k, ids_gen.data(), dist_gen.data()};
        faiss::generalized_hammings_knn_hc(
                &res, a->data(), b->data(), nb, code_size, true);
        ASSERT_EQ(ids_gen, *true_ids) << assert_str.str();
        ASSERT_EQ(dist_gen, *true_distances) << assert_str.str();

        // run test on hammings_knn
        std::vector<long> ids_ham_knn(na * k, 0);
        std::vector<int> dist_ham_knn(na * k, 0);
        res = {na, k, ids_ham_knn.data(), dist_ham_knn.data()};
        faiss::hammings_knn(&res, a->data(), b->data(), nb, code_size, true);
        ASSERT_EQ(ids_ham_knn, *true_ids) << assert_str.str();

        // true_distances counts the bytes that differ from the query. In the
        // generated data a differing byte is always 0x01 (query) against
        // 0x02 (database), and 0x01 ^ 0x02 has two bits set, so a distance
        // counted in bits is exactly twice the byte count. The expectation
        // is doubled (rather than halving the result with an integer
        // division) so that an odd, off-by-one-bit distance cannot pass.
        std::vector<int> true_bit_distances(*true_distances);
        for (int& distance : true_bit_distances) {
            distance *= 2;
        }
        ASSERT_EQ(dist_ham_knn, true_bit_distances) << assert_str.str();
    }
}

0 commit comments

Comments
 (0)