4848import org .apache .lucene .codecs .Codec ;
4949import org .apache .lucene .codecs .KnnVectorsFormat ;
5050import org .apache .lucene .codecs .KnnVectorsReader ;
51+ import org .apache .lucene .codecs .hnsw .FlatVectorScorerUtil ;
5152import org .apache .lucene .codecs .lucene104 .Lucene104Codec ;
5253import org .apache .lucene .codecs .lucene104 .Lucene104HnswScalarQuantizedVectorsFormat ;
5354import org .apache .lucene .codecs .lucene104 .Lucene104ScalarQuantizedVectorsFormat .ScalarEncoding ;
5455import org .apache .lucene .codecs .lucene104 .Lucene104ScalarQuantizedVectorsFormat ;
56+ import org .apache .lucene .codecs .lucene99 .Lucene99FlatVectorsFormat ;
5557import org .apache .lucene .codecs .lucene99 .Lucene99HnswVectorsFormat ;
5658import org .apache .lucene .codecs .lucene99 .Lucene99HnswVectorsReader ;
5759import org .apache .lucene .index .ByteVectorValues ;
5860import org .apache .lucene .index .CodecReader ;
61+ import org .apache .lucene .index .ConcurrentMergeScheduler ;
5962import org .apache .lucene .index .DirectoryReader ;
6063import org .apache .lucene .index .FloatVectorValues ;
6164import org .apache .lucene .index .IndexReader ;
@@ -178,7 +181,6 @@ enum FilterStrategy {
178181 private int numMergeThread ;
179182 private int numMergeWorker ;
180183 private int numSearchThread ;
181- private ExecutorService exec ;
182184 private VectorSimilarityFunction similarityFunction ;
183185 private VectorEncoding vectorEncoding ;
184186 private Query filterQuery ;
@@ -203,8 +205,8 @@ private KnnGraphTester() {
203205 numQueryVectors = 1000 ;
204206 dim = 256 ;
205207 topK = 100 ;
206- numMergeThread = 1 ;
207- numMergeWorker = 1 ;
208+ numMergeThread = ConcurrentMergeScheduler . AUTO_DETECT_MERGES_AND_THREADS ;
209+ numMergeWorker = ConcurrentMergeScheduler . AUTO_DETECT_MERGES_AND_THREADS ;
208210 fanout = topK ;
209211 similarityFunction = VectorSimilarityFunction .DOT_PRODUCT ;
210212 vectorEncoding = VectorEncoding .FLOAT32 ;
@@ -233,21 +235,7 @@ private static FileChannel getVectorFileChannel(Path path, int dim, VectorEncodi
233235 }
234236
235237 public static void main (String ... args ) throws Exception {
236- new KnnGraphTester ().runWithCleanUp (args );
237- }
238-
239- private void runWithCleanUp (String ... args ) throws Exception {
240- try {
241- run (args );
242- } finally {
243- cleanUp ();
244- }
245- }
246-
247- private void cleanUp () {
248- if (exec != null ) {
249- exec .shutdownNow ();
250- }
238+ new KnnGraphTester ().run (args );
251239 }
252240
253241 private void run (String ... args ) throws Exception {
@@ -464,9 +452,6 @@ private void run(String... args) throws Exception {
464452 break ;
465453 case "-numMergeThread" :
466454 numMergeThread = Integer .parseInt (args [++iarg ]);
467- if (numMergeThread > 1 ) {
468- exec = Executors .newFixedThreadPool (numMergeThread , new NamedThreadFactory ("hnsw-merge" ));
469- }
470455 if (numMergeThread <= 0 ) {
471456 throw new IllegalArgumentException ("-numMergeThread should be >= 1" );
472457 }
@@ -595,7 +580,6 @@ private void run(String... args) throws Exception {
595580 reindexTimeMsec = new KnnIndexer (
596581 docVectorsPath ,
597582 indexPath ,
598- getCodec (maxConn , beamWidth , exec , numMergeWorker , quantize , quantizeBits , indexType ),
599583 numIndexThreads ,
600584 vectorEncoding ,
601585 dim ,
@@ -607,7 +591,14 @@ private void run(String... args) throws Exception {
607591 parentJoinMetaFile ,
608592 useBp ,
609593 indexTimeFilter
610- ).createIndex ();
594+ ).createIndex (iwc -> {
595+ ConcurrentMergeScheduler cms = (ConcurrentMergeScheduler ) iwc .getMergeScheduler ();
596+ cms .setMaxMergesAndThreads (numMergeWorker , numMergeThread );
597+ cms .setDefaultMaxMergesAndThreads (false );
598+ log ("Indexing with %d worker(s) and %d thread(s)\n " , cms .getMaxMergeCount (), cms .getMaxThreadCount ());
599+
600+ iwc .setCodec (getCodec (maxConn , beamWidth , cms .getMaxMergeCount (), quantize , quantizeBits , indexType ));
601+ });
611602 Files .writeString (indexKeyPath , indexKey );
612603 log ("reindex takes %.2f sec\n " , msToSec (reindexTimeMsec ));
613604 // save indexing time so future runs that re-use this index remember:
@@ -986,13 +977,17 @@ private void printFanoutHist(Path indexPath, String field) throws IOException {
986977
987978 @ SuppressForbidden (reason = "Prints stuff" )
988979 private double forceMerge () throws IOException , InterruptedException {
989- IndexWriterConfig iwc = new IndexWriterConfig ().setOpenMode (IndexWriterConfig .OpenMode .APPEND );
990- iwc .setCodec (getCodec (maxConn , beamWidth , exec , numMergeWorker , quantize , quantizeBits , indexType ));
991980 KnnIndexer .TrackingConcurrentMergeScheduler tcms = new KnnIndexer .TrackingConcurrentMergeScheduler ();
981+ tcms .setMaxMergesAndThreads (numMergeWorker , numMergeThread );
982+ tcms .setDefaultMaxMergesAndThreads (false );
983+ log ("Force merge using %d worker(s) and %d thread(s)\n " , tcms .getMaxMergeCount (), tcms .getMaxThreadCount ());
984+
985+ IndexWriterConfig iwc = new IndexWriterConfig ().setOpenMode (IndexWriterConfig .OpenMode .APPEND );
986+ iwc .setCodec (getCodec (maxConn , beamWidth , tcms .getMaxMergeCount (), quantize , quantizeBits , indexType ));
992987 iwc .setMergeScheduler (tcms );
993988 KnnIndexer .TrackingTieredMergePolicy ttmp = new KnnIndexer .TrackingTieredMergePolicy ();
994989 iwc .setMergePolicy (ttmp );
995- log ("Force merge index in " + indexPath + " \n " );
990+ log ("Force merge index in %s \n " , indexPath );
996991 long startNS = System .nanoTime ();
997992 try (IndexWriter iw = new IndexWriter (FSDirectory .open (indexPath ), iwc )) {
998993 iw .forceMerge (1 , false );
@@ -1701,41 +1696,36 @@ public void log(String msg, Object... args) {
17011696 }
17021697 }
17031698
1704- static Codec getCodec (int maxConn , int beamWidth , ExecutorService exec , int numMergeWorker , boolean quantize , int quantizeBits , IndexType indexType ) {
1699+ static Codec getCodec (int maxConn , int beamWidth , int numMergeWorker , boolean quantize , int quantizeBits , IndexType indexType ) {
17051700 return new Lucene104Codec () {
17061701 @ Override
17071702 public KnnVectorsFormat getKnnVectorsFormatForField (String field ) {
17081703 if (quantize ) {
1709- return switch (quantizeBits ) {
1710- case 1 -> switch (indexType ) {
1711- case FLAT -> new Lucene104ScalarQuantizedVectorsFormat (ScalarEncoding .SINGLE_BIT_QUERY_NIBBLE );
1712- case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat (ScalarEncoding .SINGLE_BIT_QUERY_NIBBLE , maxConn , beamWidth , numMergeWorker , exec );
1713- };
1714- case 2 -> switch (indexType ) {
1715- case FLAT -> new Lucene104ScalarQuantizedVectorsFormat (ScalarEncoding .DIBIT_QUERY_NIBBLE );
1716- case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat (ScalarEncoding .DIBIT_QUERY_NIBBLE , maxConn , beamWidth , numMergeWorker , exec );
1717- };
1718- case 4 -> switch (indexType ) {
1719- case FLAT -> new Lucene104ScalarQuantizedVectorsFormat (ScalarEncoding .PACKED_NIBBLE );
1720- case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat (ScalarEncoding .PACKED_NIBBLE , maxConn , beamWidth , numMergeWorker , exec );
1721- };
1722- case 7 -> switch (indexType ) {
1723- case FLAT -> new Lucene104ScalarQuantizedVectorsFormat (ScalarEncoding .SEVEN_BIT );
1724- case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat (ScalarEncoding .SEVEN_BIT , maxConn , beamWidth , numMergeWorker , exec );
1725- };
1726- case 8 -> switch (indexType ) {
1727- case FLAT -> new Lucene104ScalarQuantizedVectorsFormat (ScalarEncoding .UNSIGNED_BYTE );
1728- case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat (ScalarEncoding .UNSIGNED_BYTE , maxConn , beamWidth , numMergeWorker , exec );
1729- };
1730- default -> throw new IllegalArgumentException ("Unsupported quantizeBits: " + quantizeBits );
1704+ return switch (indexType ) {
1705+ case FLAT -> new Lucene104ScalarQuantizedVectorsFormat (getScalarEncodingForBits (quantizeBits ));
1706+ case HNSW -> new Lucene104HnswScalarQuantizedVectorsFormat (getScalarEncodingForBits (quantizeBits ), maxConn , beamWidth , numMergeWorker , null );
17311707 };
17321708 } else {
1733- return new Lucene99HnswVectorsFormat (maxConn , beamWidth , numMergeWorker , exec );
1709+ return switch (indexType ) {
1710+ case FLAT -> new Lucene99FlatVectorsFormat (FlatVectorScorerUtil .getLucene99FlatVectorsScorer ());
1711+ case HNSW -> new Lucene99HnswVectorsFormat (maxConn , beamWidth , numMergeWorker , null );
1712+ };
17341713 }
17351714 }
17361715 };
17371716 }
17381717
1718+ static ScalarEncoding getScalarEncodingForBits (int quantizeBits ) {
1719+ return switch (quantizeBits ) {
1720+ case 1 -> ScalarEncoding .SINGLE_BIT_QUERY_NIBBLE ;
1721+ case 2 -> ScalarEncoding .DIBIT_QUERY_NIBBLE ;
1722+ case 4 -> ScalarEncoding .PACKED_NIBBLE ;
1723+ case 7 -> ScalarEncoding .SEVEN_BIT ;
1724+ case 8 -> ScalarEncoding .UNSIGNED_BYTE ;
1725+ default -> throw new IllegalArgumentException ("Unsupported quantizeBits: " + quantizeBits );
1726+ };
1727+ }
1728+
17391729 private static void usage () {
17401730 String error =
17411731 "Usage: TestKnnGraph [-reindex] [-search {queryfile}|-stats|-check] [-docs {datafile}] [-niter N] [-fanout N] [-maxConn N] [-beamWidth N] [-filterSelectivity N] [-prefilter]" ;
0 commit comments