Skip to content

Commit c3b9374

Browse files
algoriddlefacebook-github-bot
authored andcommitted
bench_fw - fixes & nits for oss (facebookresearch#3102)
Summary: Pull Request resolved: facebookresearch#3102 Reviewed By: pemazare Differential Revision: D50426528 Pulled By: algoriddle fbshipit-source-id: 886960b8b522318967fc5ec305666871b496cae8
1 parent 0a00d81 commit c3b9374

File tree

3 files changed

+49
-95
lines changed

3 files changed

+49
-95
lines changed

benchs/bench_fw/benchmark.py

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
22

3+
from contextlib import contextmanager
34
import json
45
import logging
5-
import time
66
from dataclasses import dataclass
77
from multiprocessing.pool import ThreadPool
88
from operator import itemgetter
99
from statistics import median, mean
10+
from time import perf_counter
1011
from typing import Any, List, Optional
1112
from .descriptors import DatasetDescriptor, IndexDescriptor
1213

@@ -26,6 +27,15 @@
2627
logger = logging.getLogger(__name__)
2728

2829

30+
@contextmanager
31+
def timer(name) -> float:
32+
logger.info(f"Measuring {name}")
33+
t1 = t2 = perf_counter()
34+
yield lambda: t2 - t1
35+
t2 = perf_counter()
36+
logger.info(f"Time for {name}: {t2 - t1:.3f} seconds")
37+
38+
2939
def refine_distances_knn(
3040
D: np.ndarray, I: np.ndarray, xq: np.ndarray, xb: np.ndarray, metric
3141
):
@@ -77,7 +87,7 @@ def range_search_pr_curve(
7787
tbl = np.vstack(
7888
[dist_ann, metric_score, cum_score, precision, recall, unique_key]
7989
)
80-
group_by_dist_max_cum_score = np.empty(len(dist_ann), np.bool)
90+
group_by_dist_max_cum_score = np.empty(len(dist_ann), bool)
8191
group_by_dist_max_cum_score[-1] = True
8292
group_by_dist_max_cum_score[:-1] = dist_ann[1:] != dist_ann[:-1]
8393
tbl = tbl[:, group_by_dist_max_cum_score]
@@ -161,11 +171,13 @@ def optimizer(codec, search, cost_metric, perf_metric):
161171
op.add_operating_point(key, perf, cost)
162172

163173

164-
def distance_ratio_measure(R, D_GT, metric):
174+
def distance_ratio_measure(I, R, D_GT, metric):
175+
sum_of_R = np.sum(np.where(I >= 0, R, 0))
176+
sum_of_D_GT = np.sum(np.where(I >= 0, D_GT, 0))
165177
if metric == faiss.METRIC_INNER_PRODUCT:
166-
return (np.sum(R) / np.sum(D_GT)).item()
178+
return (sum_of_R / sum_of_D_GT).item()
167179
elif metric == faiss.METRIC_L2:
168-
return (np.sum(D_GT) / np.sum(R)).item()
180+
return (sum_of_D_GT / sum_of_R).item()
169181
else:
170182
raise RuntimeError(f"unknown metric {metric}")
171183

@@ -188,7 +200,7 @@ def get_range_search_metric_function(range_metric, D, R):
188200
assert R is not None
189201
assert D.shape == R.shape
190202
if isinstance(range_metric, list):
191-
aradius, ascore = [], []
203+
aradius, ascore, aradius_from, aradius_to = [], [], [], []
192204
radius_to = 0
193205
for rsd in range_metric:
194206
assert isinstance(rsd, list)
@@ -212,6 +224,8 @@ def get_range_search_metric_function(range_metric, D, R):
212224
)
213225
aradius.append(real_radius)
214226
ascore.append(score)
227+
aradius_from.append(radius_from)
228+
aradius_to.append(radius_to)
215229

216230
def sigmoid(x, a, b, c):
217231
return a / (1 + np.exp(b * x - c))
@@ -229,6 +243,7 @@ def sigmoid(x, a, b, c):
229243
cutoff,
230244
lambda x: np.where(x < cutoff, sigmoid(x, *popt), 0),
231245
popt.tolist(),
246+
list(zip(aradius, ascore, aradius_from, aradius_to, strict=True))
232247
)
233248
else:
234249
# Assuming that the range_metric is a float,
@@ -244,7 +259,7 @@ def sigmoid(x, a, b, c):
244259
f"range_search_metric_function {range_metric=} {real_range=}"
245260
)
246261
assert isinstance(real_range, float)
247-
return real_range * 2, lambda x: np.where(x < real_range, 1, 0), []
262+
return real_range * 2, lambda x: np.where(x < real_range, 1, 0), [], []
248263

249264

250265
@dataclass
@@ -312,9 +327,9 @@ def range_search_reference(self, index_desc, range_metric):
312327
assert len(range_metric) > 0
313328
ri = len(range_metric[0]) - 1
314329
m_radius = (
315-
max(range_metric, key=itemgetter(ri))[ri]
330+
max(rm[ri] for rm in range_metric)
316331
if self.distance_metric_type == faiss.METRIC_L2
317-
else min(range_metric, key=itemgetter(ri))[ri]
332+
else min(rm[ri] for rm in range_metric)
318333
)
319334
else:
320335
m_radius = range_metric
@@ -329,13 +344,14 @@ def range_search_reference(self, index_desc, range_metric):
329344
gt_radius,
330345
range_search_metric_function,
331346
coefficients,
347+
coefficients_training_data,
332348
) = get_range_search_metric_function(
333349
range_metric,
334350
D if not flat else None,
335351
R if not flat else None,
336352
)
337353
logger.info("range_search_reference: end")
338-
return gt_radius, range_search_metric_function, coefficients
354+
return gt_radius, range_search_metric_function, coefficients, coefficients_training_data
339355

340356
def estimate_range(self, index_desc, parameters, range_scoring_radius):
341357
D, I, R, P = self.knn_search(
@@ -397,16 +413,12 @@ def range_search(
397413
)
398414
# QD = QD[:, :index.nprobe]
399415
# QI = QI[:, :index.nprobe]
400-
logger.info("Timing range_search_preassigned")
401416
faiss.cvar.indexIVF_stats.reset()
402-
t0 = time.time()
403-
lims, D, I = index.range_search_preassigned(xq, radius, QI, QD)
404-
t = time.time() - t0
417+
with timer("range_search_preassigned") as t:
418+
lims, D, I = index.range_search_preassigned(xq, radius, QI, QD)
405419
else:
406-
logger.info("Timing range_search")
407-
t0 = time.time()
408-
lims, D, I = index.range_search(xq, radius)
409-
t = time.time() - t0
420+
with timer("range_search") as t:
421+
lims, D, I = index.range_search(xq, radius)
410422
if flat:
411423
R = D
412424
else:
@@ -415,7 +427,7 @@ def range_search(
415427
lims, D, I, xq, xb, self.distance_metric_type
416428
)
417429
P = {
418-
"time": t,
430+
"time": t(),
419431
"radius": radius,
420432
"count": lims[-1].item(),
421433
"parameters": parameters,
@@ -560,16 +572,12 @@ def knn_search(
560572
)
561573
# QD = QD[:, :index.nprobe]
562574
# QI = QI[:, :index.nprobe]
563-
logger.info("Timing knn search_preassigned")
564575
faiss.cvar.indexIVF_stats.reset()
565-
t0 = time.time()
566-
D, I = index.search_preassigned(xq, k, QI, QD)
567-
t = time.time() - t0
576+
with timer("knn search_preassigned") as t:
577+
D, I = index.search_preassigned(xq, k, QI, QD)
568578
else:
569-
logger.info("Timing knn search")
570-
t0 = time.time()
571-
D, I = index.search(xq, k)
572-
t = time.time() - t0
579+
with timer("knn search") as t:
580+
D, I = index.search(xq, k)
573581
if flat or level > 0:
574582
R = D
575583
else:
@@ -578,7 +586,7 @@ def knn_search(
578586
D, I, xq, xb, self.distance_metric_type
579587
)
580588
P = {
581-
"time": t,
589+
"time": t(),
582590
"parameters": parameters,
583591
"index": index_desc.factory,
584592
"level": level,
@@ -646,7 +654,7 @@ def experiment(parameters, cost_metric, perf_metric):
646654
I, self.gt_knn_I
647655
),
648656
"distance_ratio": distance_ratio_measure(
649-
R, self.gt_knn_D, self.distance_metric_type
657+
I, R, self.gt_knn_D, self.distance_metric_type
650658
),
651659
}
652660
results["experiments"][key] = metrics
@@ -691,8 +699,12 @@ def benchmark(self) -> str:
691699
gt_radius,
692700
range_search_metric_function,
693701
coefficients,
702+
coefficients_training_data,
694703
) = self.range_search_reference(index_desc, range_metric)
695-
results["metrics"][metric_key] = coefficients
704+
results["metrics"][metric_key] = {
705+
"coefficients": coefficients,
706+
"training_data": coefficients_training_data,
707+
}
696708
gt_rsm = self.range_ground_truth(
697709
gt_radius, range_search_metric_function
698710
)

benchs/bench_fw/benchmark_io.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def write_file(
198198
def get_dataset(self, dataset):
199199
if dataset not in self.cached_ds:
200200
self.cached_ds[dataset] = self.read_nparray(
201-
os.path.join(self.path, dataset.name)
201+
os.path.join(self.path, dataset.tablename)
202202
)
203203
return self.cached_ds[dataset]
204204

@@ -207,9 +207,9 @@ def read_nparray(
207207
filename: str,
208208
):
209209
fn = self.download_file_from_blobstore(filename)
210-
logger.info(f"Loading nparray from {fn}\n")
210+
logger.info(f"Loading nparray from {fn}")
211211
nparray = np.load(fn)
212-
logger.info(f"Loaded nparray {nparray.shape} from {fn}\n")
212+
logger.info(f"Loaded nparray {nparray.shape} from {fn}")
213213
return nparray
214214

215215
def write_nparray(
@@ -218,7 +218,7 @@ def write_nparray(
218218
filename: str,
219219
):
220220
fn = self.get_local_filename(filename)
221-
logger.info(f"Saving nparray {nparray.shape} to {fn}\n")
221+
logger.info(f"Saving nparray {nparray.shape} to {fn}")
222222
np.save(fn, nparray)
223223
self.upload_file_to_blobstore(filename)
224224

@@ -227,10 +227,10 @@ def read_json(
227227
filename: str,
228228
):
229229
fn = self.download_file_from_blobstore(filename)
230-
logger.info(f"Loading json {fn}\n")
230+
logger.info(f"Loading json {fn}")
231231
with open(fn, "r") as fp:
232232
json_dict = json.load(fp)
233-
logger.info(f"Loaded json {json_dict} from {fn}\n")
233+
logger.info(f"Loaded json {json_dict} from {fn}")
234234
return json_dict
235235

236236
def write_json(
@@ -240,7 +240,7 @@ def write_json(
240240
overwrite: bool = False,
241241
):
242242
fn = self.get_local_filename(filename)
243-
logger.info(f"Saving json {json_dict} to {fn}\n")
243+
logger.info(f"Saving json {json_dict} to {fn}")
244244
with open(fn, "w") as fp:
245245
json.dump(json_dict, fp)
246246
self.upload_file_to_blobstore(filename, overwrite=overwrite)

build.sh

Lines changed: 0 additions & 58 deletions
This file was deleted.

0 commit comments

Comments
 (0)