diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java
index 218c9d8919..443b12b9c4 100644
--- a/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java
+++ b/src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java
@@ -69,10 +69,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
final VectorDataType vectorDataType = extractVectorDataType(field);
final KNNVectorValues> knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, valuesProducer.getBinary(field));
+ // For BDV it is fine to use knnVectorValues.totalLiveDocs() as we already run the full loop to calculate total
+ // live docs
if (isMerge) {
- NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues);
+ NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues, (int) knnVectorValues.totalLiveDocs());
} else {
- NativeIndexWriter.getWriter(field, state).flushIndex(knnVectorValues);
+ NativeIndexWriter.getWriter(field, state).flushIndex(knnVectorValues, (int) knnVectorValues.totalLiveDocs());
}
}
diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java
index af7f1c5765..dba0926ff2 100644
--- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java
+++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java
@@ -21,6 +21,7 @@
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
+import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.common.StopWatch;
@@ -63,8 +64,6 @@ public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, Fla
/**
* Add new field for indexing.
- * In Lucene, we use single file for all the vector fields so here we need to see how we are going to make things
- * work.
* @param fieldInfo {@link FieldInfo}
*/
@Override
@@ -204,7 +203,7 @@ private
+ * int liveDocs = 0;
+ * while(vectorValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
+ * liveDocs++;
+ * }
+ *
* @return long
*/
+ @Deprecated
public long totalLiveDocs() {
return vectorValuesIterator.liveDocs();
}
diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java
index 1a8a832aa4..96af0db197 100644
--- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java
+++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java
@@ -74,10 +74,12 @@ public void testBuildAndWrite() {
.knnEngine(KNNEngine.NMSLIB)
.vectorDataType(VectorDataType.FLOAT)
.parameters(Map.of("index", "param"))
+ .vectorValues(knnVectorValues)
+ .totalLiveDocs((int) knnVectorValues.totalLiveDocs())
.build();
// When
- DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues);
+ DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams);
// Then
mockedJNIService.verify(
@@ -166,10 +168,12 @@ public void testBuildAndWrite_withQuantization() {
.vectorDataType(VectorDataType.FLOAT)
.parameters(Map.of("index", "param"))
.quantizationState(quantizationState)
+ .vectorValues(knnVectorValues)
+ .totalLiveDocs((int) knnVectorValues.totalLiveDocs())
.build();
// When
- MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues);
+ MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams);
// Then
mockedJNIService.verify(
@@ -250,10 +254,12 @@ public void testBuildAndWriteWithModel() {
.knnEngine(KNNEngine.NMSLIB)
.vectorDataType(VectorDataType.FLOAT)
.parameters(Map.of("model_id", "id", "model_blob", modelBlob))
+ .vectorValues(knnVectorValues)
+ .totalLiveDocs((int) knnVectorValues.totalLiveDocs())
.build();
// When
- DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues);
+ DefaultIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams);
// Then
mockedJNIService.verify(
diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java
index 81d490bb44..22f9b2dfd6 100644
--- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java
+++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java
@@ -75,10 +75,12 @@ public void testBuildAndWrite() {
.knnEngine(KNNEngine.FAISS)
.vectorDataType(VectorDataType.FLOAT)
.parameters(Map.of("index", "param"))
+ .vectorValues(knnVectorValues)
+ .totalLiveDocs((int) knnVectorValues.totalLiveDocs())
.build();
// When
- MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues);
+ MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams);
// Then
mockedJNIService.verify(
@@ -193,10 +195,12 @@ public void testBuildAndWrite_withQuantization() {
.vectorDataType(VectorDataType.FLOAT)
.parameters(Map.of("index", "param"))
.quantizationState(quantizationState)
+ .vectorValues(knnVectorValues)
+ .totalLiveDocs((int) knnVectorValues.totalLiveDocs())
.build();
// When
- MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues);
+ MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams);
// Then
mockedJNIService.verify(
diff --git a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java
index 720b67fd51..690391dbdb 100644
--- a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java
+++ b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java
@@ -46,7 +46,7 @@ public void setUp() throws Exception {
public void testTrain_oneBitQuantizer_success() throws IOException {
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
- QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues);
+ QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
assertTrue(quantizationState instanceof OneBitScalarQuantizationState);
OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState;
@@ -62,7 +62,7 @@ public void testTrain_oneBitQuantizer_success() throws IOException {
public void testTrain_twoBitQuantizer_success() throws IOException {
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
- QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues);
+ QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
assertTrue(quantizationState instanceof MultiBitScalarQuantizationState);
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState;
@@ -85,7 +85,7 @@ public void testTrain_twoBitQuantizer_success() throws IOException {
public void testTrain_fourBitQuantizer_success() throws IOException {
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
- QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues);
+ QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
assertTrue(quantizationState instanceof MultiBitScalarQuantizationState);
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState;
@@ -110,7 +110,7 @@ public void testTrain_fourBitQuantizer_success() throws IOException {
public void testQuantize_oneBitQuantizer_success() throws IOException {
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
- QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues);
+ QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams);
@@ -125,7 +125,7 @@ public void testQuantize_oneBitQuantizer_success() throws IOException {
public void testQuantize_twoBitQuantizer_success() throws IOException {
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
- QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues);
+ QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams);
byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput);
@@ -138,7 +138,7 @@ public void testQuantize_twoBitQuantizer_success() throws IOException {
public void testQuantize_fourBitQuantizer_success() throws IOException {
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
- QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues);
+ QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams);
byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput);
@@ -152,7 +152,7 @@ public void testQuantize_fourBitQuantizer_success() throws IOException {
public void testQuantize_whenInvalidInput_thenThrows() throws IOException {
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
- QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues);
+ QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams);
assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput));
}
diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java
index 2d664de9b6..a2bf46aa42 100644
--- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java
+++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java
@@ -192,6 +192,27 @@ public void testIndexCreation_whenValid_ThenSucceed() {
}
}
+ @SneakyThrows
+ public void testDeletedDocsWithSegmentMerge_whenValid_ThenSucceed() {
+ XContentBuilder builder;
+ CompressionLevel compressionLevel = CompressionLevel.x32;
+ String indexName = INDEX_NAME + compressionLevel;
+ builder = XContentFactory.jsonBuilder()
+ .startObject()
+ .startObject("properties")
+ .startObject(FIELD_NAME)
+ .field("type", "knn_vector")
+ .field("dimension", DIMENSION)
+ .field(COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName())
+ .field(MODE_PARAMETER, Mode.ON_DISK.getName())
+ .endObject()
+ .endObject()
+ .endObject();
+ String mapping = builder.toString();
+ validateIndexWithDeletedDocs(indexName, mapping);
+ validateSearch(indexName, METHOD_PARAMETER_EF_SEARCH, KNNSettings.INDEX_KNN_DEFAULT_ALGO_PARAM_EF_SEARCH);
+ }
+
@SneakyThrows
public void testTraining_whenInvalid_thenFail() {
setupTrainingIndex();
@@ -319,6 +340,18 @@ private void validateIndex(String indexName, String mapping) {
forceMergeKnnIndex(indexName, 1);
}
+ @SneakyThrows
+ private void validateIndexWithDeletedDocs(String indexName, String mapping) {
+ createKnnIndex(indexName, mapping);
+ addKNNDocs(indexName, FIELD_NAME, DIMENSION, 0, NUM_DOCS);
+ refreshIndex(indexName);
+ // this will simulate the deletion of the docs
+ addKNNDocs(indexName, FIELD_NAME, DIMENSION, 0, NUM_DOCS);
+ refreshIndex(indexName);
+ forceMergeKnnIndex(indexName, 1);
+ refreshIndex(indexName);
+ }
+
@SneakyThrows
private void setupTrainingIndex() {
createBasicKnnIndex(TRAINING_INDEX_NAME, TRAINING_FIELD_NAME, DIMENSION);