11# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
22
3+ from contextlib import contextmanager
34import json
45import logging
5- import time
66from dataclasses import dataclass
77from multiprocessing .pool import ThreadPool
88from operator import itemgetter
99from statistics import median , mean
10+ from time import perf_counter
1011from typing import Any , List , Optional
1112from .descriptors import DatasetDescriptor , IndexDescriptor
1213
2627logger = logging .getLogger (__name__ )
2728
2829
@contextmanager
def timer(name: str):
    """Log and measure the wall-clock duration of the enclosed ``with`` body.

    Usage::

        with timer("knn search") as t:
            D, I = index.search(xq, k)
        elapsed_seconds = t()

    Args:
        name: Label inserted into the two log messages.

    Yields:
        A zero-argument callable returning the elapsed time in seconds as a
        float. It returns 0 while the body is still running, because the end
        timestamp is only recorded when the block exits.

    The end timestamp and the closing log line are produced in a ``finally``
    clause, so a body that raises still gets a measured duration and a log
    entry; the exception itself propagates unchanged.
    """
    logger.info(f"Measuring {name}")
    # Both stamps start equal so the yielded callable is usable (and returns
    # 0) before the block has finished.
    t1 = t2 = perf_counter()
    try:
        # The lambda closes over t2 and reads it at call time, so callers see
        # the final value once it is rebound below.
        yield lambda: t2 - t1
    finally:
        t2 = perf_counter()
        logger.info(f"Time for {name}: {t2 - t1:.3f} seconds")
37+
38+
2939def refine_distances_knn (
3040 D : np .ndarray , I : np .ndarray , xq : np .ndarray , xb : np .ndarray , metric
3141):
@@ -77,7 +87,7 @@ def range_search_pr_curve(
7787 tbl = np .vstack (
7888 [dist_ann , metric_score , cum_score , precision , recall , unique_key ]
7989 )
80- group_by_dist_max_cum_score = np .empty (len (dist_ann ), np . bool )
90+ group_by_dist_max_cum_score = np .empty (len (dist_ann ), bool )
8191 group_by_dist_max_cum_score [- 1 ] = True
8292 group_by_dist_max_cum_score [:- 1 ] = dist_ann [1 :] != dist_ann [:- 1 ]
8393 tbl = tbl [:, group_by_dist_max_cum_score ]
@@ -161,11 +171,13 @@ def optimizer(codec, search, cost_metric, perf_metric):
161171 op .add_operating_point (key , perf , cost )
162172
163173
def distance_ratio_measure(I, R, D_GT, metric):
    """Return the ratio between summed result distances and ground truth.

    Args:
        I: Array of result ids; entries with ``I < 0`` are excluded from both
            sums (presumably these mark missing results - confirm with the
            search code that produces them).
        R: Distances of the returned results, same shape as ``I``.
        D_GT: Ground-truth distances, same shape as ``I``.
        metric: ``faiss.METRIC_INNER_PRODUCT`` or ``faiss.METRIC_L2``.

    Returns:
        A Python scalar: ``sum(R) / sum(D_GT)`` for inner product and
        ``sum(D_GT) / sum(R)`` for L2, both computed over the kept entries.

    Raises:
        RuntimeError: If ``metric`` is neither of the two supported values.
    """
    valid = I >= 0
    result_total = np.sum(np.where(valid, R, 0))
    ground_truth_total = np.sum(np.where(valid, D_GT, 0))

    if metric == faiss.METRIC_INNER_PRODUCT:
        ratio = result_total / ground_truth_total
    elif metric == faiss.METRIC_L2:
        ratio = ground_truth_total / result_total
    else:
        raise RuntimeError(f"unknown metric {metric}")
    return ratio.item()
171183
@@ -188,7 +200,7 @@ def get_range_search_metric_function(range_metric, D, R):
188200 assert R is not None
189201 assert D .shape == R .shape
190202 if isinstance (range_metric , list ):
191- aradius , ascore = [], []
203+ aradius , ascore , aradius_from , aradius_to = [], [], [], []
192204 radius_to = 0
193205 for rsd in range_metric :
194206 assert isinstance (rsd , list )
@@ -212,6 +224,8 @@ def get_range_search_metric_function(range_metric, D, R):
212224 )
213225 aradius .append (real_radius )
214226 ascore .append (score )
227+ aradius_from .append (radius_from )
228+ aradius_to .append (radius_to )
215229
def sigmoid(x, a, b, c):
    """Evaluate ``a / (1 + exp(b * x - c))`` element-wise on ``x``."""
    exponent = b * x - c
    denominator = 1 + np.exp(exponent)
    return a / denominator
@@ -229,6 +243,7 @@ def sigmoid(x, a, b, c):
229243 cutoff ,
230244 lambda x : np .where (x < cutoff , sigmoid (x , * popt ), 0 ),
231245 popt .tolist (),
246+ list (zip (aradius , ascore , aradius_from , aradius_to , strict = True ))
232247 )
233248 else :
234249 # Assuming that the range_metric is a float,
@@ -244,7 +259,7 @@ def sigmoid(x, a, b, c):
244259 f"range_search_metric_function { range_metric = } { real_range = } "
245260 )
246261 assert isinstance (real_range , float )
247- return real_range * 2 , lambda x : np .where (x < real_range , 1 , 0 ), []
262+ return real_range * 2 , lambda x : np .where (x < real_range , 1 , 0 ), [], []
248263
249264
250265@dataclass
@@ -312,9 +327,9 @@ def range_search_reference(self, index_desc, range_metric):
312327 assert len (range_metric ) > 0
313328 ri = len (range_metric [0 ]) - 1
314329 m_radius = (
315- max (range_metric , key = itemgetter ( ri )) [ri ]
330+ max (rm [ri ] for rm in range_metric )
316331 if self .distance_metric_type == faiss .METRIC_L2
317- else min (range_metric , key = itemgetter ( ri )) [ri ]
332+ else min (rm [ri ] for rm in range_metric )
318333 )
319334 else :
320335 m_radius = range_metric
@@ -329,13 +344,14 @@ def range_search_reference(self, index_desc, range_metric):
329344 gt_radius ,
330345 range_search_metric_function ,
331346 coefficients ,
347+ coefficients_training_data ,
332348 ) = get_range_search_metric_function (
333349 range_metric ,
334350 D if not flat else None ,
335351 R if not flat else None ,
336352 )
337353 logger .info ("range_search_reference: end" )
338- return gt_radius , range_search_metric_function , coefficients
354+ return gt_radius , range_search_metric_function , coefficients , coefficients_training_data
339355
340356 def estimate_range (self , index_desc , parameters , range_scoring_radius ):
341357 D , I , R , P = self .knn_search (
@@ -397,16 +413,12 @@ def range_search(
397413 )
398414 # QD = QD[:, :index.nprobe]
399415 # QI = QI[:, :index.nprobe]
400- logger .info ("Timing range_search_preassigned" )
401416 faiss .cvar .indexIVF_stats .reset ()
402- t0 = time .time ()
403- lims , D , I = index .range_search_preassigned (xq , radius , QI , QD )
404- t = time .time () - t0
417+ with timer ("range_search_preassigned" ) as t :
418+ lims , D , I = index .range_search_preassigned (xq , radius , QI , QD )
405419 else :
406- logger .info ("Timing range_search" )
407- t0 = time .time ()
408- lims , D , I = index .range_search (xq , radius )
409- t = time .time () - t0
420+ with timer ("range_search" ) as t :
421+ lims , D , I = index .range_search (xq , radius )
410422 if flat :
411423 R = D
412424 else :
@@ -415,7 +427,7 @@ def range_search(
415427 lims , D , I , xq , xb , self .distance_metric_type
416428 )
417429 P = {
418- "time" : t ,
430+ "time" : t () ,
419431 "radius" : radius ,
420432 "count" : lims [- 1 ].item (),
421433 "parameters" : parameters ,
@@ -560,16 +572,12 @@ def knn_search(
560572 )
561573 # QD = QD[:, :index.nprobe]
562574 # QI = QI[:, :index.nprobe]
563- logger .info ("Timing knn search_preassigned" )
564575 faiss .cvar .indexIVF_stats .reset ()
565- t0 = time .time ()
566- D , I = index .search_preassigned (xq , k , QI , QD )
567- t = time .time () - t0
576+ with timer ("knn search_preassigned" ) as t :
577+ D , I = index .search_preassigned (xq , k , QI , QD )
568578 else :
569- logger .info ("Timing knn search" )
570- t0 = time .time ()
571- D , I = index .search (xq , k )
572- t = time .time () - t0
579+ with timer ("knn search" ) as t :
580+ D , I = index .search (xq , k )
573581 if flat or level > 0 :
574582 R = D
575583 else :
@@ -578,7 +586,7 @@ def knn_search(
578586 D , I , xq , xb , self .distance_metric_type
579587 )
580588 P = {
581- "time" : t ,
589+ "time" : t () ,
582590 "parameters" : parameters ,
583591 "index" : index_desc .factory ,
584592 "level" : level ,
@@ -646,7 +654,7 @@ def experiment(parameters, cost_metric, perf_metric):
646654 I , self .gt_knn_I
647655 ),
648656 "distance_ratio" : distance_ratio_measure (
649- R , self .gt_knn_D , self .distance_metric_type
657+ I , R , self .gt_knn_D , self .distance_metric_type
650658 ),
651659 }
652660 results ["experiments" ][key ] = metrics
@@ -691,8 +699,12 @@ def benchmark(self) -> str:
691699 gt_radius ,
692700 range_search_metric_function ,
693701 coefficients ,
702+ coefficients_training_data ,
694703 ) = self .range_search_reference (index_desc , range_metric )
695- results ["metrics" ][metric_key ] = coefficients
704+ results ["metrics" ][metric_key ] = {
705+ "coefficients" : coefficients ,
706+ "training_data" : coefficients_training_data ,
707+ }
696708 gt_rsm = self .range_ground_truth (
697709 gt_radius , range_search_metric_function
698710 )
0 commit comments