diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java index 218c9d8919..443b12b9c4 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java @@ -69,10 +69,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, final VectorDataType vectorDataType = extractVectorDataType(field); final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, valuesProducer.getBinary(field)); + // For BDV it is fine to use knnVectorValues.totalLiveDocs() as we already run the full loop to calculate total + // live docs if (isMerge) { - NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues); + NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues, (int) knnVectorValues.totalLiveDocs()); } else { - NativeIndexWriter.getWriter(field, state).flushIndex(knnVectorValues); + NativeIndexWriter.getWriter(field, state).flushIndex(knnVectorValues, (int) knnVectorValues.totalLiveDocs()); } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index af7f1c5765..dba0926ff2 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -21,6 +21,7 @@ import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.common.StopWatch; @@ 
-63,8 +64,6 @@ public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, Fla /** * Add new field for indexing. - * In Lucene, we use single file for all the vector fields so here we need to see how we are going to make things - * work. * @param fieldInfo {@link FieldInfo} */ @Override @@ -204,7 +203,7 @@ private KNNVectorValues getKNNVectorValuesForMerge( */ @FunctionalInterface private interface IndexOperation { - void buildAndWrite(NativeIndexWriter writer, KNNVectorValues knnVectorValues) throws IOException; + void buildAndWrite(NativeIndexWriter writer, KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException; } /** @@ -248,9 +247,11 @@ private void trainAndIndex( KNNVectorValues knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; - if (quantizationParams != null) { + // Count the docIds + int totalLiveDocs = getLiveDocs(vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext)); + if (quantizationParams != null && totalLiveDocs > 0) { initQuantizationStateWriterIfNecessary(); - quantizationState = quantizationService.train(quantizationParams, knnVectorValues); + quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs); quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); } NativeIndexWriter writer = (quantizationParams != null) @@ -261,12 +262,27 @@ private void trainAndIndex( StopWatch stopWatch = new StopWatch(); stopWatch.start(); - indexOperation.buildAndWrite(writer, knnVectorValues); + indexOperation.buildAndWrite(writer, knnVectorValues, totalLiveDocs); long time_in_millis = stopWatch.totalTime().millis(); graphBuildTime.incrementBy(time_in_millis); log.warn("Graph build took " + time_in_millis + " ms for " + operationName); } + /** + * The {@link 
KNNVectorValues} will be exhausted after this function runs. So make sure that you are not passing the + * vectorValues object which you plan to use later + */ + private int getLiveDocs(KNNVectorValues vectorValues) throws IOException { + // Count all the live docs because vectorValues.totalLiveDocs() just gives the cost for the FloatVectorValues, + // and doesn't tell the correct number of docs if there are deleted docs in the segment. So we are counting + // the total live docs here. + int liveDocs = 0; + while (vectorValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + liveDocs++; + } + return liveDocs; + } + private void initQuantizationStateWriterIfNecessary() throws IOException { if (quantizationStateWriter == null) { quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java index d2a6027dbf..e68121a7db 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -48,17 +48,17 @@ public static DefaultIndexBuildStrategy getInstance() { * flushed and used to build the index. The index is then written to the specified path using JNI calls.

* * @param indexInfo The {@link BuildIndexParams} containing the parameters and configuration for building the index. - * @param knnVectorValues The {@link KNNVectorValues} representing the vectors to be indexed. * @throws IOException If an I/O error occurs during the process of building and writing the index. */ - public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { + public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOException { + final KNNVectorValues knnVectorValues = indexInfo.getVectorValues(); // Needed to make sure we don't get 0 dimensions while initializing index iterateVectorValuesOnce(knnVectorValues); IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo); int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector()); try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { - final List transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs()); + final List transferredDocIds = new ArrayList<>(indexInfo.getTotalLiveDocs()); while (knnVectorValues.docId() != NO_MORE_DOCS) { Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup); diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java index 1115bfe05f..af3f4777f4 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -52,7 +52,7 @@ public static MemOptimizedNativeIndexBuildStrategy getInstance() { - * @param knnVectorValues The {@link KNNVectorValues} representing
the vectors to be indexed. * @throws IOException If an I/O error occurs during the process of building and writing the index. */ - public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { + public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOException { + final KNNVectorValues knnVectorValues = indexInfo.getVectorValues(); // Needed to make sure we don't get 0 dimensions while initializing index iterateVectorValuesOnce(knnVectorValues); KNNEngine engine = indexInfo.getKnnEngine(); @@ -62,7 +62,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector // Initialize the index long indexMemoryAddress = AccessController.doPrivileged( (PrivilegedAction) () -> JNIService.initIndex( - knnVectorValues.totalLiveDocs(), + indexInfo.getTotalLiveDocs(), indexBuildSetup.getDimensions(), indexParameters, engine diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java index 19475adfad..8c9f6de971 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java @@ -6,7 +6,6 @@ package org.opensearch.knn.index.codec.nativeindex; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; -import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import java.io.IOException; @@ -15,5 +14,5 @@ */ public interface NativeIndexBuildStrategy { - void buildAndWriteIndex(BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException; + void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException; } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java
index 0877730442..edc96c9e14 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -106,9 +106,9 @@ public static NativeIndexWriter getWriter( * @param knnVectorValues * @throws IOException */ - public void flushIndex(final KNNVectorValues knnVectorValues) throws IOException { + public void flushIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { iterateVectorValuesOnce(knnVectorValues); - buildAndWriteIndex(knnVectorValues); + buildAndWriteIndex(knnVectorValues, totalLiveDocs); recordRefreshStats(); } @@ -117,7 +117,7 @@ public void flushIndex(final KNNVectorValues knnVectorValues) throws IOExcept * @param knnVectorValues * @throws IOException */ - public void mergeIndex(final KNNVectorValues knnVectorValues) throws IOException { + public void mergeIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { iterateVectorValuesOnce(knnVectorValues); if (knnVectorValues.docId() == NO_MORE_DOCS) { // This is in place so we do not add metrics @@ -126,13 +126,13 @@ public void mergeIndex(final KNNVectorValues knnVectorValues) throws IOExcept } long bytesPerVector = knnVectorValues.bytesPerVector(); - startMergeStats((int) knnVectorValues.totalLiveDocs(), bytesPerVector); - buildAndWriteIndex(knnVectorValues); - endMergeStats((int) knnVectorValues.totalLiveDocs(), bytesPerVector); + startMergeStats(totalLiveDocs, bytesPerVector); + buildAndWriteIndex(knnVectorValues, totalLiveDocs); + endMergeStats(totalLiveDocs, bytesPerVector); } - private void buildAndWriteIndex(final KNNVectorValues knnVectorValues) throws IOException { - if (knnVectorValues.totalLiveDocs() == 0) { + private void buildAndWriteIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { + if (totalLiveDocs == 0) { log.debug("No live docs for field " + fieldInfo.name); return; } @@ -150,15 
+150,21 @@ private void buildAndWriteIndex(final KNNVectorValues knnVectorValues) throws ).toString(); state.directory.createOutput(engineFileName, state.context).close(); - final BuildIndexParams nativeIndexParams = indexParams(fieldInfo, indexPath, knnEngine); - indexBuilder.buildAndWriteIndex(nativeIndexParams, knnVectorValues); + final BuildIndexParams nativeIndexParams = indexParams(fieldInfo, indexPath, knnEngine, knnVectorValues, totalLiveDocs); + indexBuilder.buildAndWriteIndex(nativeIndexParams); writeFooter(indexPath, engineFileName, state); } // The logic for building parameters need to be cleaned up. There are various cases handled here // Currently it falls under two categories - with model and without model. Without model is further divided based on vector data type // TODO: Refactor this so its scalable. Possibly move it out of this class - private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNEngine knnEngine) throws IOException { + private BuildIndexParams indexParams( + FieldInfo fieldInfo, + String indexPath, + KNNEngine knnEngine, + KNNVectorValues vectorValues, + int totalLiveDocs + ) throws IOException { final Map parameters; VectorDataType vectorDataType; if (quantizationState != null) { @@ -180,6 +186,8 @@ private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNE .knnEngine(knnEngine) .indexPath(indexPath) .quantizationState(quantizationState) + .vectorValues(vectorValues) + .totalLiveDocs(totalLiveDocs) .build(); } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java index 78674c64bf..88507b1fc3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java @@ -11,6 +11,7 @@ import org.opensearch.common.Nullable; import 
org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.util.Map; @@ -29,4 +30,6 @@ public class BuildIndexParams { */ @Nullable QuantizationState quantizationState; + KNNVectorValues vectorValues; + int totalLiveDocs; } diff --git a/src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java b/src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java index d880a41788..f7ee129044 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java +++ b/src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java @@ -28,8 +28,8 @@ final class KNNVectorQuantizationTrainingRequest extends TrainingRequest { * * @param knnVectorValues the KNNVectorValues instance containing the vectors. */ - KNNVectorQuantizationTrainingRequest(KNNVectorValues knnVectorValues) { - super((int) knnVectorValues.totalLiveDocs()); + KNNVectorQuantizationTrainingRequest(KNNVectorValues knnVectorValues, long liveDocs) { + super((int) liveDocs); this.knnVectorValues = knnVectorValues; this.lastIndex = 0; } diff --git a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java index b1c94f993d..7718487306 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java +++ b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java @@ -57,12 +57,15 @@ public static QuantizationService getInstance() { * @return The {@link QuantizationState} containing the state of the trained quantizer. * @throws IOException If an I/O error occurs during the training process. 
*/ - public QuantizationState train(final QuantizationParams quantizationParams, final KNNVectorValues knnVectorValues) - throws IOException { + public QuantizationState train( + final QuantizationParams quantizationParams, + final KNNVectorValues knnVectorValues, + final long liveDocs + ) throws IOException { Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationParams); // Create the training request from the vector values - KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues); + KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs); // Train the quantizer and return the quantization state return quantizer.train(trainingRequest); diff --git a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java index 56ebd208fe..b12395185e 100644 --- a/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java +++ b/src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java @@ -71,9 +71,18 @@ public int bytesPerVector() { } /** - * Returns the total live docs for KNNVectorValues. + * Returns the total live docs for KNNVectorValues. This function is broken and doesn't always give the accurate + * live docs count when iterators are {@link FloatVectorValues}, {@link ByteVectorValues}. Avoid using this method; + * rather, count the docs with a simple loop like this: + * <pre>
+     *     int liveDocs = 0;
+     *     while (vectorValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+     *         liveDocs++;
+     *     }
+     * </pre>
* @return long */ + @Deprecated public long totalLiveDocs() { return vectorValuesIterator.liveDocs(); } diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java index 1a8a832aa4..96af0db197 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java @@ -74,10 +74,12 @@ public void testBuildAndWrite() { .knnEngine(KNNEngine.NMSLIB) .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("index", "param")) + .vectorValues(knnVectorValues) + .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) .build(); // When - DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams); // Then mockedJNIService.verify( @@ -166,10 +168,12 @@ public void testBuildAndWrite_withQuantization() { .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("index", "param")) .quantizationState(quantizationState) + .vectorValues(knnVectorValues) + .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) .build(); // When - MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams); // Then mockedJNIService.verify( @@ -250,10 +254,12 @@ public void testBuildAndWriteWithModel() { .knnEngine(KNNEngine.NMSLIB) .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("model_id", "id", "model_blob", modelBlob)) + .vectorValues(knnVectorValues) + .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) .build(); // When - DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + 
DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams); // Then mockedJNIService.verify( diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java index 81d490bb44..22f9b2dfd6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java @@ -75,10 +75,12 @@ public void testBuildAndWrite() { .knnEngine(KNNEngine.FAISS) .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("index", "param")) + .vectorValues(knnVectorValues) + .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) .build(); // When - MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams); // Then mockedJNIService.verify( @@ -193,10 +195,12 @@ public void testBuildAndWrite_withQuantization() { .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("index", "param")) .quantizationState(quantizationState) + .vectorValues(knnVectorValues) + .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) .build(); // When - MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams); // Then mockedJNIService.verify( diff --git a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java index 720b67fd51..690391dbdb 100644 --- a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java +++ 
b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java @@ -46,7 +46,7 @@ public void setUp() throws Exception { public void testTrain_oneBitQuantizer_success() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); assertTrue(quantizationState instanceof OneBitScalarQuantizationState); OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState; @@ -62,7 +62,7 @@ public void testTrain_oneBitQuantizer_success() throws IOException { public void testTrain_twoBitQuantizer_success() throws IOException { ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues); + QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; @@ -85,7 +85,7 @@ public void testTrain_twoBitQuantizer_success() throws IOException { public void testTrain_fourBitQuantizer_success() throws IOException { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues); + QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); MultiBitScalarQuantizationState multiBitState = 
(MultiBitScalarQuantizationState) quantizationState; @@ -110,7 +110,7 @@ public void testTrain_fourBitQuantizer_success() throws IOException { public void testQuantize_oneBitQuantizer_success() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); @@ -125,7 +125,7 @@ public void testQuantize_oneBitQuantizer_success() throws IOException { public void testQuantize_twoBitQuantizer_success() throws IOException { ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues); + QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams); byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput); @@ -138,7 +138,7 @@ public void testQuantize_twoBitQuantizer_success() throws IOException { public void testQuantize_fourBitQuantizer_success() throws IOException { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues); + QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams); byte[] quantizedVector = 
quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput); @@ -152,7 +152,7 @@ public void testQuantize_fourBitQuantizer_success() throws IOException { public void testQuantize_whenInvalidInput_thenThrows() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput)); } diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java index 2d664de9b6..a2bf46aa42 100644 --- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java +++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java @@ -192,6 +192,27 @@ public void testIndexCreation_whenValid_ThenSucceed() { } } + @SneakyThrows + public void testDeletedDocsWithSegmentMerge_whenValid_ThenSucceed() { + XContentBuilder builder; + CompressionLevel compressionLevel = CompressionLevel.x32; + String indexName = INDEX_NAME + compressionLevel; + builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(FIELD_NAME) + .field("type", "knn_vector") + .field("dimension", DIMENSION) + .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName()) + .field(MODE_PARAMETER, Mode.ON_DISK.getName()) + .endObject() + .endObject() + .endObject(); + String mapping = builder.toString(); + validateIndexWithDeletedDocs(indexName, mapping); + validateSearch(indexName, METHOD_PARAMETER_EF_SEARCH, 
KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH); + } + @SneakyThrows public void testTraining_whenInvalid_thenFail() { setupTrainingIndex(); @@ -319,6 +340,18 @@ private void validateIndex(String indexName, String mapping) { forceMergeKnnIndex(indexName, 1); } + @SneakyThrows + private void validateIndexWithDeletedDocs(String indexName, String mapping) { + createKnnIndex(indexName, mapping); + addKNNDocs(indexName, FIELD_NAME, DIMENSION, 0, NUM_DOCS); + refreshIndex(indexName); + // this will simulate the deletion of the docs + addKNNDocs(indexName, FIELD_NAME, DIMENSION, 0, NUM_DOCS); + refreshIndex(indexName); + forceMergeKnnIndex(indexName, 1); + refreshIndex(indexName); + } + @SneakyThrows private void setupTrainingIndex() { createBasicKnnIndex(TRAINING_INDEX_NAME, TRAINING_FIELD_NAME, DIMENSION);