Skip to content

Commit 9591d33

Browse files
committed
Add graph merge executor override
1 parent 543294f commit 9591d33

1 file changed

Lines changed: 36 additions & 8 deletions

File tree

src/main/knn/KnnGraphTester.java

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ enum FilterStrategy {
188188
private boolean quantizeCompress;
189189
private int numMaxMerge;
190190
private int numMergeThread;
191+
private ForkJoinPool graphMergeExecutor;
191192
private int numSearchThread;
192193
private VectorSimilarityFunction similarityFunction;
193194
private VectorEncoding vectorEncoding;
@@ -221,6 +222,7 @@ private KnnGraphTester() {
221222
topK = 100;
222223
numMaxMerge = ConcurrentMergeScheduler.AUTO_DETECT_MERGES_AND_THREADS;
223224
numMergeThread = ConcurrentMergeScheduler.AUTO_DETECT_MERGES_AND_THREADS;
225+
graphMergeExecutor = null;
224226
fanout = topK;
225227
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
226228
vectorEncoding = VectorEncoding.FLOAT32;
@@ -470,6 +472,19 @@ private void run(String... args) throws Exception {
470472
throw new IllegalArgumentException("-numMergeThread should be >= 1");
471473
}
472474
break;
475+
case "-numMergeWorker":
476+
int numMergeWorker = Integer.parseInt(args[++iarg]);
477+
if (numMergeWorker <= 0) {
478+
throw new IllegalArgumentException("-numMergeWorker should be >= 1; use 1 for calling thread or omit the flag to use the intra-merge executor");
479+
} else if (numMergeWorker > 1) {
480+
final AtomicInteger id = new AtomicInteger(0);
481+
graphMergeExecutor = new ForkJoinPool(numMergeWorker, pool -> {
482+
var thread = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
483+
thread.setName("graph-merge-" + id.getAndIncrement());
484+
return thread;
485+
}, null, false);
486+
}
487+
break;
473488
case "-numSearchThread":
474489
// 0: single thread mode (not passing an executorService)
475490
// -1: use number of threads equal to the number available processors
@@ -618,9 +633,13 @@ private void run(String... args) throws Exception {
618633
ConcurrentMergeScheduler cms = (ConcurrentMergeScheduler) iwc.getMergeScheduler();
619634
cms.setMaxMergesAndThreads(numMaxMerge, numMergeThread);
620635
cms.setDefaultMaxMergesAndThreads(false);
621-
log("Indexing with %d worker(s) and %d thread(s)\n", cms.getMaxMergeCount(), cms.getMaxThreadCount());
622-
623-
iwc.setCodec(getCodec(maxConn, beamWidth, cms.getMaxThreadCount(), quantize, quantizeBits, indexType));
636+
int numMergeWorker = cms.getMaxThreadCount();
637+
if (graphMergeExecutor != null) {
638+
numMergeWorker = graphMergeExecutor.getParallelism();
639+
}
640+
log("Indexing with %d max merge(s), %d merge thread(s), and %d merge worker(s) (using intraMergeExecutor=%s)\n",
641+
cms.getMaxMergeCount(), cms.getMaxThreadCount(), numMergeWorker, graphMergeExecutor == null);
642+
iwc.setCodec(getCodec(maxConn, beamWidth, numMergeWorker, graphMergeExecutor, quantize, quantizeBits, indexType));
624643
});
625644
Files.writeString(indexKeyPath, indexKey);
626645
log("reindex takes %.2f sec\n", msToSec(reindexTimeMsec));
@@ -685,6 +704,10 @@ private void run(String... args) throws Exception {
685704
} else {
686705
printIndexStatistics(indexPath, KNN_FIELD);
687706
}
707+
if (graphMergeExecutor != null) {
708+
// Close this thread pool to avoid leaking into the search ThreadDetails
709+
graphMergeExecutor.close();
710+
}
688711
if (operation != null) {
689712
switch (operation) {
690713
case "-search":
@@ -1016,10 +1039,15 @@ private double forceMerge() throws IOException, InterruptedException {
10161039
KnnIndexer.TrackingConcurrentMergeScheduler tcms = new KnnIndexer.TrackingConcurrentMergeScheduler();
10171040
tcms.setMaxMergesAndThreads(numMaxMerge, numMergeThread);
10181041
tcms.setDefaultMaxMergesAndThreads(false);
1019-
log("Force merge using %d worker(s) and %d thread(s)\n", tcms.getMaxMergeCount(), tcms.getMaxThreadCount());
1042+
int numMergeWorker = tcms.getMaxThreadCount();
1043+
if (graphMergeExecutor != null) {
1044+
numMergeWorker = graphMergeExecutor.getParallelism();
1045+
}
1046+
log("Force merging with %d max merge(s), %d merge thread(s), and %d merge worker(s) (using intraMergeExecutor=%s)\n",
1047+
tcms.getMaxMergeCount(), tcms.getMaxThreadCount(), numMergeWorker, graphMergeExecutor == null);
10201048

10211049
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.APPEND);
1022-
iwc.setCodec(getCodec(maxConn, beamWidth, tcms.getMaxThreadCount(), quantize, quantizeBits, indexType));
1050+
iwc.setCodec(getCodec(maxConn, beamWidth, numMergeWorker, graphMergeExecutor, quantize, quantizeBits, indexType));
10231051
iwc.setMergeScheduler(tcms);
10241052
KnnIndexer.TrackingTieredMergePolicy ttmp = new KnnIndexer.TrackingTieredMergePolicy();
10251053
iwc.setMergePolicy(ttmp);
@@ -1993,17 +2021,17 @@ public void log(String msg, Object... args) {
19932021
}
19942022
}
19952023

1996-
static Codec getCodec(int maxConn, int beamWidth, int numMergeWorker, boolean quantize, int quantizeBits, IndexType indexType) {
2024+
static Codec getCodec(int maxConn, int beamWidth, int numMergeWorker, ForkJoinPool mergeExecutor, boolean quantize, int quantizeBits, IndexType indexType) {
19972025
KnnVectorsFormat knnVectorsFormat;
19982026
if (quantize) {
19992027
knnVectorsFormat = switch (indexType) {
20002028
case FLAT -> new Lucene104ScalarQuantizedVectorsFormat(getScalarEncodingForBits(quantizeBits));
2001-
case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat(getScalarEncodingForBits(quantizeBits), maxConn, beamWidth, numMergeWorker, null);
2029+
case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat(getScalarEncodingForBits(quantizeBits), maxConn, beamWidth, numMergeWorker, mergeExecutor);
20022030
};
20032031
} else {
20042032
knnVectorsFormat = switch (indexType) {
20052033
case FLAT -> new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
2006-
case HNSW -> new Lucene99HnswVectorsFormat(maxConn, beamWidth, numMergeWorker, null);
2034+
case HNSW -> new Lucene99HnswVectorsFormat(maxConn, beamWidth, numMergeWorker, mergeExecutor);
20072035
};
20082036
}
20092037
return new Lucene104Codec() {

0 commit comments

Comments
 (0)