Skip to content

Commit beabc37

Browse files
Mengdi Linfacebook-github-bot
authored andcommitted
RCQ search microbenchmark (#3863)
Summary: add RCQ search microbenchmark with parameters similar to the ones provided by ivansopin in D62529163 Differential Revision: D62769834
1 parent 6da9952 commit beabc37

1 file changed

Lines changed: 65 additions & 0 deletions

File tree

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/**
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include <gflags/gflags.h>
9+
10+
#include <benchmark/benchmark.h>
11+
#include <faiss/IndexAdditiveQuantizer.h> // @manual=//faiss:faiss_no_multithreading
12+
#include <faiss/utils/random.h>
13+
14+
using namespace faiss;
15+
DEFINE_uint32(iterations, 20, "iterations");
16+
DEFINE_uint32(nprobe, 1, "nprobe");
17+
DEFINE_uint32(batch_size, 1, "batch_size");
18+
DEFINE_double(beam_factor, 4.0, "beam factor");
19+
20+
static void bench_search(
21+
benchmark::State& state,
22+
int batch_size,
23+
int nprobe,
24+
float beam_factor) {
25+
int d = 512;
26+
int nt = 2 << 15;
27+
std::vector<float> xt(d * nt);
28+
29+
float_rand(xt.data(), d * nt, 12345);
30+
ResidualCoarseQuantizer rq(d, {16, 8});
31+
rq.verbose = false;
32+
rq.train(nt, xt.data());
33+
34+
std::vector<float> xq(d * batch_size);
35+
float_rand(xq.data(), d * batch_size, 12345);
36+
37+
std::vector<float> distances(nprobe * batch_size);
38+
std::vector<int64_t> clusterIndices(nprobe * batch_size);
39+
SearchParametersResidualCoarseQuantizer param;
40+
param.beam_factor = beam_factor;
41+
for (auto _ : state) {
42+
rq.search(
43+
batch_size,
44+
xq.data(),
45+
nprobe,
46+
distances.data(),
47+
clusterIndices.data(),
48+
&param);
49+
}
50+
}
51+
52+
int main(int argc, char** argv) {
53+
benchmark::Initialize(&argc, argv);
54+
gflags::AllowCommandLineReparsing();
55+
gflags::ParseCommandLineFlags(&argc, &argv, true);
56+
int iterations = FLAGS_iterations;
57+
int nprobe = FLAGS_nprobe;
58+
float beam_factor = FLAGS_beam_factor;
59+
int batch_size = FLAGS_batch_size;
60+
benchmark::RegisterBenchmark(
61+
"search", bench_search, batch_size, nprobe, beam_factor)
62+
->Iterations(iterations);
63+
benchmark::RunSpecifiedBenchmarks();
64+
benchmark::Shutdown();
65+
}

0 commit comments

Comments
 (0)