Skip to content

Commit 6c00719

Browse files
committed
Fix merge concurrency
1 parent a0761c9 commit 6c00719

2 files changed

Lines changed: 58 additions & 68 deletions

File tree

src/main/knn/KnnGraphTester.java

Lines changed: 40 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,17 @@
4848
import org.apache.lucene.codecs.Codec;
4949
import org.apache.lucene.codecs.KnnVectorsFormat;
5050
import org.apache.lucene.codecs.KnnVectorsReader;
51+
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
5152
import org.apache.lucene.codecs.lucene104.Lucene104Codec;
5253
import org.apache.lucene.codecs.lucene104.Lucene104HnswScalarQuantizedVectorsFormat;
5354
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding;
5455
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat;
56+
import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat;
5557
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
5658
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
5759
import org.apache.lucene.index.ByteVectorValues;
5860
import org.apache.lucene.index.CodecReader;
61+
import org.apache.lucene.index.ConcurrentMergeScheduler;
5962
import org.apache.lucene.index.DirectoryReader;
6063
import org.apache.lucene.index.FloatVectorValues;
6164
import org.apache.lucene.index.IndexReader;
@@ -178,7 +181,6 @@ enum FilterStrategy {
178181
private int numMergeThread;
179182
private int numMergeWorker;
180183
private int numSearchThread;
181-
private ExecutorService exec;
182184
private VectorSimilarityFunction similarityFunction;
183185
private VectorEncoding vectorEncoding;
184186
private Query filterQuery;
@@ -203,8 +205,8 @@ private KnnGraphTester() {
203205
numQueryVectors = 1000;
204206
dim = 256;
205207
topK = 100;
206-
numMergeThread = 1;
207-
numMergeWorker = 1;
208+
numMergeThread = ConcurrentMergeScheduler.AUTO_DETECT_MERGES_AND_THREADS;
209+
numMergeWorker = ConcurrentMergeScheduler.AUTO_DETECT_MERGES_AND_THREADS;
208210
fanout = topK;
209211
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
210212
vectorEncoding = VectorEncoding.FLOAT32;
@@ -233,21 +235,7 @@ private static FileChannel getVectorFileChannel(Path path, int dim, VectorEncodi
233235
}
234236

235237
public static void main(String... args) throws Exception {
236-
new KnnGraphTester().runWithCleanUp(args);
237-
}
238-
239-
private void runWithCleanUp(String... args) throws Exception {
240-
try {
241-
run(args);
242-
} finally {
243-
cleanUp();
244-
}
245-
}
246-
247-
private void cleanUp() {
248-
if (exec != null) {
249-
exec.shutdownNow();
250-
}
238+
new KnnGraphTester().run(args);
251239
}
252240

253241
private void run(String... args) throws Exception {
@@ -464,9 +452,6 @@ private void run(String... args) throws Exception {
464452
break;
465453
case "-numMergeThread":
466454
numMergeThread = Integer.parseInt(args[++iarg]);
467-
if (numMergeThread > 1) {
468-
exec = Executors.newFixedThreadPool(numMergeThread, new NamedThreadFactory("hnsw-merge"));
469-
}
470455
if (numMergeThread <= 0) {
471456
throw new IllegalArgumentException("-numMergeThread should be >= 1");
472457
}
@@ -595,7 +580,6 @@ private void run(String... args) throws Exception {
595580
reindexTimeMsec = new KnnIndexer(
596581
docVectorsPath,
597582
indexPath,
598-
getCodec(maxConn, beamWidth, exec, numMergeWorker, quantize, quantizeBits, indexType),
599583
numIndexThreads,
600584
vectorEncoding,
601585
dim,
@@ -607,7 +591,14 @@ private void run(String... args) throws Exception {
607591
parentJoinMetaFile,
608592
useBp,
609593
indexTimeFilter
610-
).createIndex();
594+
).createIndex(iwc -> {
595+
ConcurrentMergeScheduler cms = (ConcurrentMergeScheduler) iwc.getMergeScheduler();
596+
cms.setMaxMergesAndThreads(numMergeWorker, numMergeThread);
597+
cms.setDefaultMaxMergesAndThreads(false);
598+
log("Indexing with %d worker(s) and %d thread(s)\n", cms.getMaxMergeCount(), cms.getMaxThreadCount());
599+
600+
iwc.setCodec(getCodec(maxConn, beamWidth, cms.getMaxMergeCount(), quantize, quantizeBits, indexType));
601+
});
611602
Files.writeString(indexKeyPath, indexKey);
612603
log("reindex takes %.2f sec\n", msToSec(reindexTimeMsec));
613604
// save indexing time so future runs that re-use this index remember:
@@ -986,13 +977,17 @@ private void printFanoutHist(Path indexPath, String field) throws IOException {
986977

987978
@SuppressForbidden(reason = "Prints stuff")
988979
private double forceMerge() throws IOException, InterruptedException {
989-
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.APPEND);
990-
iwc.setCodec(getCodec(maxConn, beamWidth, exec, numMergeWorker, quantize, quantizeBits, indexType));
991980
KnnIndexer.TrackingConcurrentMergeScheduler tcms = new KnnIndexer.TrackingConcurrentMergeScheduler();
981+
tcms.setMaxMergesAndThreads(numMergeWorker, numMergeThread);
982+
tcms.setDefaultMaxMergesAndThreads(false);
983+
log("Force merge using %d worker(s) and %d thread(s)\n", tcms.getMaxMergeCount(), tcms.getMaxThreadCount());
984+
985+
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.APPEND);
986+
iwc.setCodec(getCodec(maxConn, beamWidth, tcms.getMaxMergeCount(), quantize, quantizeBits, indexType));
992987
iwc.setMergeScheduler(tcms);
993988
KnnIndexer.TrackingTieredMergePolicy ttmp = new KnnIndexer.TrackingTieredMergePolicy();
994989
iwc.setMergePolicy(ttmp);
995-
log("Force merge index in " + indexPath + "\n");
990+
log("Force merge index in %s\n", indexPath);
996991
long startNS = System.nanoTime();
997992
try (IndexWriter iw = new IndexWriter(FSDirectory.open(indexPath), iwc)) {
998993
iw.forceMerge(1, false);
@@ -1701,41 +1696,36 @@ public void log(String msg, Object... args) {
17011696
}
17021697
}
17031698

1704-
static Codec getCodec(int maxConn, int beamWidth, ExecutorService exec, int numMergeWorker, boolean quantize, int quantizeBits, IndexType indexType) {
1699+
static Codec getCodec(int maxConn, int beamWidth, int numMergeWorker, boolean quantize, int quantizeBits, IndexType indexType) {
17051700
return new Lucene104Codec() {
17061701
@Override
17071702
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
17081703
if (quantize) {
1709-
return switch (quantizeBits) {
1710-
case 1 -> switch (indexType) {
1711-
case FLAT -> new Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE);
1712-
case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat(ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE, maxConn, beamWidth, numMergeWorker, exec);
1713-
};
1714-
case 2 -> switch (indexType) {
1715-
case FLAT -> new Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding.DIBIT_QUERY_NIBBLE);
1716-
case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat(ScalarEncoding.DIBIT_QUERY_NIBBLE, maxConn, beamWidth, numMergeWorker, exec);
1717-
};
1718-
case 4 -> switch (indexType) {
1719-
case FLAT -> new Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding.PACKED_NIBBLE);
1720-
case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat(ScalarEncoding.PACKED_NIBBLE, maxConn, beamWidth, numMergeWorker, exec);
1721-
};
1722-
case 7 -> switch (indexType) {
1723-
case FLAT -> new Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding.SEVEN_BIT);
1724-
case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat(ScalarEncoding.SEVEN_BIT, maxConn, beamWidth, numMergeWorker, exec);
1725-
};
1726-
case 8 -> switch (indexType) {
1727-
case FLAT -> new Lucene104ScalarQuantizedVectorsFormat(ScalarEncoding.UNSIGNED_BYTE);
1728-
case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat(ScalarEncoding.UNSIGNED_BYTE, maxConn, beamWidth, numMergeWorker, exec);
1729-
};
1730-
default -> throw new IllegalArgumentException("Unsupported quantizeBits: " + quantizeBits);
1704+
return switch (indexType) {
1705+
case FLAT -> new Lucene104ScalarQuantizedVectorsFormat(getScalarEncodingForBits(quantizeBits));
1706+
case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat(getScalarEncodingForBits(quantizeBits), maxConn, beamWidth, numMergeWorker, null);
17311707
};
17321708
} else {
1733-
return new Lucene99HnswVectorsFormat(maxConn, beamWidth, numMergeWorker, exec);
1709+
return switch (indexType) {
1710+
case FLAT -> new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
1711+
case HNSW -> new Lucene99HnswVectorsFormat(maxConn, beamWidth, numMergeWorker, null);
1712+
};
17341713
}
17351714
}
17361715
};
17371716
}
17381717

1718+
static ScalarEncoding getScalarEncodingForBits(int quantizeBits) {
1719+
return switch (quantizeBits) {
1720+
case 1 -> ScalarEncoding.SINGLE_BIT_QUERY_NIBBLE;
1721+
case 2 -> ScalarEncoding.DIBIT_QUERY_NIBBLE;
1722+
case 4 -> ScalarEncoding.PACKED_NIBBLE;
1723+
case 7 -> ScalarEncoding.SEVEN_BIT;
1724+
case 8 -> ScalarEncoding.UNSIGNED_BYTE;
1725+
default -> throw new IllegalArgumentException("Unsupported quantizeBits: " + quantizeBits);
1726+
};
1727+
}
1728+
17391729
private static void usage() {
17401730
String error =
17411731
"Usage: TestKnnGraph [-reindex] [-search {queryfile}|-stats|-check] [-docs {datafile}] [-niter N] [-fanout N] [-maxConn N] [-beamWidth N] [-filterSelectivity N] [-prefilter]";

src/main/knn/KnnIndexer.java

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.concurrent.Executors;
3131
import java.util.concurrent.TimeUnit;
3232
import java.util.concurrent.atomic.AtomicInteger;
33+
import java.util.function.Consumer;
3334
import java.util.stream.Collectors;
3435

3536
import org.apache.lucene.codecs.Codec;
@@ -49,6 +50,9 @@
4950
import org.apache.lucene.index.SegmentCommitInfo;
5051
import org.apache.lucene.index.SegmentInfos;
5152
import org.apache.lucene.index.TieredMergePolicy;
53+
import org.apache.lucene.document.Field;
54+
import org.apache.lucene.index.IndexWriter;
55+
import org.apache.lucene.index.IndexWriterConfig;
5256
import org.apache.lucene.index.VectorEncoding;
5357
import org.apache.lucene.index.VectorSimilarityFunction;
5458
import org.apache.lucene.misc.index.BPReorderingMergePolicy;
@@ -71,7 +75,6 @@ public class KnnIndexer implements FormatterLogger {
7175
private final VectorEncoding vectorEncoding;
7276
private final int dim;
7377
private final VectorSimilarityFunction similarityFunction;
74-
private final Codec codec;
7578
private final int numDocs;
7679
private final int docsStartIndex;
7780
private final int numIndexThreads;
@@ -80,16 +83,12 @@ public class KnnIndexer implements FormatterLogger {
8083
private final Path parentJoinMetaPath;
8184
private final boolean useBp;
8285
private final FilterScheme filterScheme;
83-
private final TrackingConcurrentMergeScheduler tcms;
84-
private final TrackingTieredMergePolicy ttmp;
8586

86-
public KnnIndexer(Path docsPath, Path indexPath, Codec codec, int numIndexThreads,
87-
VectorEncoding vectorEncoding, int dim,
87+
public KnnIndexer(Path docsPath, Path indexPath, int numIndexThreads, VectorEncoding vectorEncoding, int dim,
8888
VectorSimilarityFunction similarityFunction, int numDocs, int docsStartIndex, boolean quiet,
8989
boolean parentJoin, Path parentJoinMetaPath, boolean useBp, FilterScheme filterScheme) {
9090
this.docsPath = docsPath;
9191
this.indexPath = indexPath;
92-
this.codec = codec;
9392
this.numIndexThreads = numIndexThreads;
9493
this.vectorEncoding = vectorEncoding;
9594
this.dim = dim;
@@ -101,13 +100,12 @@ public KnnIndexer(Path docsPath, Path indexPath, Codec codec, int numIndexThread
101100
this.parentJoinMetaPath = parentJoinMetaPath;
102101
this.useBp = useBp;
103102
this.filterScheme = filterScheme;
104-
this.tcms = new TrackingConcurrentMergeScheduler();
105-
this.ttmp = new TrackingTieredMergePolicy();
106103
}
107104

108-
public int createIndex() throws IOException, InterruptedException {
105+
public int createIndex(Consumer<IndexWriterConfig> iwcMutator) throws IOException, InterruptedException {
109106
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.CREATE);
110-
iwc.setCodec(codec);
107+
108+
TrackingConcurrentMergeScheduler tcms = new TrackingConcurrentMergeScheduler();
111109
iwc.setMergeScheduler(tcms);
112110
// iwc.setMergePolicy(NoMergePolicy.INSTANCE);
113111
iwc.setRAMBufferSizeMB(WRITER_BUFFER_MB);
@@ -119,7 +117,7 @@ public int createIndex() throws IOException, InterruptedException {
119117
iwc.setMaxFullFlushMergeWaitMillis(0);
120118

121119
// aim for more compact/realistic index:
122-
120+
TrackingTieredMergePolicy ttmp = new TrackingTieredMergePolicy();
123121
iwc.setMergePolicy(ttmp);
124122
ttmp.setFloorSegmentMB(256);
125123
iwc.getCodec().compoundFormat().setShouldUseCompoundFile(false);
@@ -128,8 +126,8 @@ public int createIndex() throws IOException, InterruptedException {
128126
iwc.setMergePolicy(new BPReorderingMergePolicy(iwc.getMergePolicy(), new BpVectorReorderer(KnnGraphTester.KNN_FIELD)));
129127
}
130128

131-
ConcurrentMergeScheduler cms = (ConcurrentMergeScheduler) iwc.getMergeScheduler();
132-
// cms.setMaxMergesAndThreads(24, 12);
129+
// Apply any other changes on top of these defaults
130+
iwcMutator.accept(iwc);
133131

134132
FieldType fieldType =
135133
switch (vectorEncoding) {
@@ -251,22 +249,24 @@ public int createIndex() throws IOException, InterruptedException {
251249
}
252250

253251
// give merges a chance to kick off and finish:
254-
log("now IndexWriter.commit()\n");
252+
elapsedNS = System.nanoTime() - startNS;
253+
log("now IndexWriter.commit() after %d seconds\n", TimeUnit.NANOSECONDS.toSeconds(elapsedNS));
255254
iw.commit();
256255

257256
elapsedNS = System.nanoTime() - startNS;
258-
257+
log("now wait for already running merges to finish after %d seconds\n", TimeUnit.NANOSECONDS.toSeconds(elapsedNS));
259258
waitForMergesWithStatus(ttmp, tcms, this);
259+
260+
elapsedNS = System.nanoTime() - startNS;
261+
log("now IndexWriter.close() after %d seconds\n", TimeUnit.NANOSECONDS.toSeconds(elapsedNS));
260262
}
263+
elapsedNS = System.nanoTime() - startNS;
261264
log("Indexed %d docs in %d seconds\n", numDocs, TimeUnit.NANOSECONDS.toSeconds(elapsedNS));
262265
return (int) TimeUnit.NANOSECONDS.toMillis(elapsedNS);
263266
}
264267

265268
public static void waitForMergesWithStatus(TrackingTieredMergePolicy ttmp, TrackingConcurrentMergeScheduler tcms, FormatterLogger log) throws InterruptedException {
266269
long startNS = System.nanoTime();
267-
268-
// wait for running merges to complete, and print coarse status updates
269-
log.log("now wait for already running merges to finish\n");
270270

271271
// silliness to just be able to print progress in waiting (so long...) for merges:
272272
long nextPrintNS = System.nanoTime();

0 commit comments

Comments
 (0)