@@ -52,10 +52,12 @@ def accumulate_perf_counter(
5252def run_on_dataset (
5353 ds : Dataset ,
5454 M : int ,
55- num_threads :
56- int ,
55+ num_threads : int ,
56+ num_add_iterations : int ,
57+ num_search_iterations : int ,
5758 efSearch : int = 16 ,
58- efConstruction : int = 40
59+ efConstruction : int = 40 ,
60+ search_bounded_queue : bool = True ,
5961) -> Dict [str , int ]:
6062 xq = ds .get_queries ()
6163 xb = ds .get_database ()
@@ -67,22 +69,27 @@ def run_on_dataset(
6769 # pyre-ignore[16]: Module `faiss` has no attribute `omp_set_num_threads`.
6870 faiss .omp_set_num_threads (num_threads )
6971 index = faiss .IndexHNSWFlat (d , M )
70- index .hnsw .efConstruction = 40 # default
72+ index .hnsw .efConstruction = efConstruction # default
7173 with timed_execution () as t :
72- index .add (xb )
74+ for _ in range (num_add_iterations ):
75+ index .add (xb )
7376 counters = {}
7477 accumulate_perf_counter ("add" , t , counters )
7578 counters ["nb" ] = nb
79+ counters ["num_add_iterations" ] = num_add_iterations
7680
7781 index .hnsw .efSearch = efSearch
82+ index .hnsw .search_bounded_queue = search_bounded_queue
7883 with timed_execution () as t :
79- D , I = index .search (xq , k )
84+ for _ in range (num_search_iterations ):
85+ D , I = index .search (xq , k )
8086 accumulate_perf_counter ("search" , t , counters )
8187 counters ["nq" ] = nq
8288 counters ["efSearch" ] = efSearch
8389 counters ["efConstruction" ] = efConstruction
8490 counters ["M" ] = M
8591 counters ["d" ] = d
92+ counters ["num_search_iterations" ] = num_search_iterations
8693
8794 return counters
8895
@@ -93,61 +100,25 @@ def run(
93100 nq : int ,
94101 M : int ,
95102 num_threads : int ,
103+ num_add_iterations : int = 1 ,
104+ num_search_iterations : int = 1 ,
96105 efSearch : int = 16 ,
97106 efConstruction : int = 40 ,
107+ search_bounded_queue : bool = True ,
98108) -> Dict [str , int ]:
99109 ds = SyntheticDataset (d = d , nb = nb , nt = 0 , nq = nq , metric = "L2" , seed = 1338 )
100110 return run_on_dataset (
101111 ds ,
102112 M = M ,
113+ num_add_iterations = num_add_iterations ,
114+ num_search_iterations = num_search_iterations ,
103115 num_threads = num_threads ,
104116 efSearch = efSearch ,
105117 efConstruction = efConstruction ,
118+ search_bounded_queue = search_bounded_queue ,
106119 )
107120
108121
109- def _merge_counters (
110- element : Dict [str , int ], accu : Optional [Dict [str , int ]] = None
111- ) -> Dict [str , int ]:
112- if accu is None :
113- return dict (element )
114- else :
115- assert accu .keys () <= element .keys (), (
116- "Accu keys must be a subset of element keys: "
117- f"{ accu .keys ()} not a subset of { element .keys ()} "
118- )
119- for key in accu .keys ():
120- if is_perf_counter (key ):
121- accu [key ] += element [key ]
122- return accu
123-
124-
125- def run_with_iterations (
126- iterations : int ,
127- d : int ,
128- nb : int ,
129- nq : int ,
130- M : int ,
131- num_threads : int ,
132- efSearch : int = 16 ,
133- efConstruction : int = 40 ,
134- ) -> Dict [str , int ]:
135- result = None
136- for _ in range (iterations ):
137- counters = run (
138- d = d ,
139- nb = nb ,
140- nq = nq ,
141- M = M ,
142- num_threads = num_threads ,
143- efSearch = efSearch ,
144- efConstruction = efConstruction ,
145- )
146- result = _merge_counters (counters , result )
147- assert result is not None
148- return result
149-
150-
151122def _accumulate_counters (
152123 element : Dict [str , int ], accu : Optional [Dict [str , List [int ]]] = None
153124) -> Dict [str , List [int ]]:
@@ -169,10 +140,13 @@ def main():
169140 parser .add_argument ("-M" , "--M" , type = int , required = True )
170141 parser .add_argument ("-t" , "--num-threads" , type = int , required = True )
171142 parser .add_argument ("-w" , "--warm-up-iterations" , type = int , default = 0 )
172- parser .add_argument ("-i" , "--num-iterations" , type = int , default = 20 )
143+ parser .add_argument ("-i" , "--num-search-iterations" , type = int , default = 20 )
144+ parser .add_argument ("-i" , "--num-add-iterations" , type = int , default = 20 )
173145 parser .add_argument ("-r" , "--num-repetitions" , type = int , default = 20 )
174146 parser .add_argument ("-s" , "--ef-search" , type = int , default = 16 )
175147 parser .add_argument ("-c" , "--ef-construction" , type = int , default = 40 )
148+ parser .add_argument ("-b" , "--search-bounded-queue" , action = "store_true" )
149+
176150 parser .add_argument ("-n" , "--nb" , type = int , default = 5000 )
177151 parser .add_argument ("-q" , "--nq" , type = int , default = 500 )
178152 parser .add_argument ("-d" , "--d" , type = int , default = 128 )
@@ -181,15 +155,17 @@ def main():
181155 if args .warm_up_iterations > 0 :
182156 print (f"Warming up for { args .warm_up_iterations } iterations..." )
183157 # warm-up
184- run_with_iterations (
185- iterations = args .warm_up_iterations ,
158+ run (
159+ num_search_iterations = args .warm_up_iterations ,
160+ num_add_iterations = args .warm_up_iterations ,
186161 d = args .d ,
187162 nb = args .nb ,
188163 nq = args .nq ,
189164 M = args .M ,
190165 num_threads = args .num_threads ,
191166 efSearch = args .ef_search ,
192167 efConstruction = args .ef_construction ,
168+ search_bounded_queue = args .search_bounded_queue ,
193169 )
194170 print (
195171 f"Running benchmark with dataset(nb={ args .nb } , nq={ args .nq } , "
@@ -198,24 +174,23 @@ def main():
198174 )
199175 result = None
200176 for _ in range (args .num_repetitions ):
201- counters = run_with_iterations (
202- iterations = args .num_iterations ,
177+ counters = run (
178+ num_search_iterations = args .num_search_iterations ,
179+ num_add_iterations = args .num_add_iterations ,
203180 d = args .d ,
204181 nb = args .nb ,
205182 nq = args .nq ,
206183 M = args .M ,
207184 num_threads = args .num_threads ,
208185 efSearch = args .ef_search ,
209186 efConstruction = args .ef_construction ,
187+ search_bounded_queue = args .search_bounded_queue ,
210188 )
211189 result = _accumulate_counters (counters , result )
212190 assert result is not None
213191 for counter , values in result .items ():
214192 if is_perf_counter (counter ):
215193 print (
216- "%s t=%.3f us (± %.4f)" % (
217- counter ,
218- np .mean (values ),
219- np .std (values )
220- )
194+ "%s t=%.3f us (± %.4f)" %
195+ (counter , np .mean (values ), np .std (values ))
221196 )
0 commit comments