diff --git a/benchs/bench_fw/descriptors.py b/benchs/bench_fw/descriptors.py index 8b1d65a505..d40243bb5b 100644 --- a/benchs/bench_fw/descriptors.py +++ b/benchs/bench_fw/descriptors.py @@ -106,6 +106,8 @@ class DatasetDescriptor: # desc_name desc_name: Optional[str] = None + normalize_L2: bool = False + def __hash__(self): return hash(self.get_filename()) diff --git a/benchs/bench_fw/index.py b/benchs/bench_fw/index.py index fe2fe103ef..b1252ad1b0 100644 --- a/benchs/bench_fw/index.py +++ b/benchs/bench_fw/index.py @@ -1138,6 +1138,8 @@ def assemble(self, dry_run): return None, None, "" logger.info(f"assemble, train {self.factory}") xt = self.io.get_dataset(self.training_vectors) + if self.training_vectors.normalize_L2: + faiss.normalize_L2(xt) _, t, _ = timer("train", lambda: codec.train(xt), once=True) t_aggregate += t