Skip to content

Commit f3d38bc

Browse files
Fix the force merge with Quantization failures when a segment has deleted docs in it (#2046) (#2051)
Signed-off-by: Navneet Verma <[email protected]> (cherry picked from commit da854c9) Co-authored-by: Navneet Verma <[email protected]>
1 parent 9689a37 commit f3d38bc

14 files changed

Lines changed: 128 additions & 44 deletions

src/main/java/org/opensearch/knn/index/codec/KNN80Codec/KNN80DocValuesConsumer.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,12 @@ public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer,
6969
final VectorDataType vectorDataType = extractVectorDataType(field);
7070
final KNNVectorValues<?> knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, valuesProducer.getBinary(field));
7171

72+
// For BDV it is fine to use knnVectorValues.totalLiveDocs() as we already run the full loop to calculate total
73+
// live docs
7274
if (isMerge) {
73-
NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues);
75+
NativeIndexWriter.getWriter(field, state).mergeIndex(knnVectorValues, (int) knnVectorValues.totalLiveDocs());
7476
} else {
75-
NativeIndexWriter.getWriter(field, state).flushIndex(knnVectorValues);
77+
NativeIndexWriter.getWriter(field, state).flushIndex(knnVectorValues, (int) knnVectorValues.totalLiveDocs());
7678
}
7779
}
7880

src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.apache.lucene.index.MergeState;
2222
import org.apache.lucene.index.SegmentWriteState;
2323
import org.apache.lucene.index.Sorter;
24+
import org.apache.lucene.search.DocIdSetIterator;
2425
import org.apache.lucene.util.IOUtils;
2526
import org.apache.lucene.util.RamUsageEstimator;
2627
import org.opensearch.common.StopWatch;
@@ -63,8 +64,6 @@ public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, Fla
6364

6465
/**
6566
* Add new field for indexing.
66-
* In Lucene, we use single file for all the vector fields so here we need to see how we are going to make things
67-
* work.
6867
* @param fieldInfo {@link FieldInfo}
6968
*/
7069
@Override
@@ -204,7 +203,7 @@ private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(
204203
*/
205204
@FunctionalInterface
206205
private interface IndexOperation<T> {
207-
void buildAndWrite(NativeIndexWriter writer, KNNVectorValues<T> knnVectorValues) throws IOException;
206+
void buildAndWrite(NativeIndexWriter writer, KNNVectorValues<T> knnVectorValues, int totalLiveDocs) throws IOException;
208207
}
209208

210209
/**
@@ -248,9 +247,11 @@ private <T, C> void trainAndIndex(
248247
KNNVectorValues<T> knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext);
249248
QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo);
250249
QuantizationState quantizationState = null;
251-
if (quantizationParams != null) {
250+
// Count the docIds
251+
int totalLiveDocs = getLiveDocs(vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext));
252+
if (quantizationParams != null && totalLiveDocs > 0) {
252253
initQuantizationStateWriterIfNecessary();
253-
quantizationState = quantizationService.train(quantizationParams, knnVectorValues);
254+
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
254255
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
255256
}
256257
NativeIndexWriter writer = (quantizationParams != null)
@@ -261,12 +262,27 @@ private <T, C> void trainAndIndex(
261262

262263
StopWatch stopWatch = new StopWatch();
263264
stopWatch.start();
264-
indexOperation.buildAndWrite(writer, knnVectorValues);
265+
indexOperation.buildAndWrite(writer, knnVectorValues, totalLiveDocs);
265266
long time_in_millis = stopWatch.totalTime().millis();
266267
graphBuildTime.incrementBy(time_in_millis);
267268
log.warn("Graph build took " + time_in_millis + " ms for " + operationName);
268269
}
269270

271+
/**
272+
* The {@link KNNVectorValues} will be exhausted after this function run. So make sure that you are not sending the
273+
* vectorsValues object which you plan to use later
274+
*/
275+
private int getLiveDocs(KNNVectorValues<?> vectorValues) throws IOException {
276+
// Count all the live docs as there vectorValues.totalLiveDocs() just gives the cost for the FloatVectorValues,
277+
// and doesn't tell the correct number of docs, if there are deleted docs in the segment. So we are counting
278+
// the total live docs here.
279+
int liveDocs = 0;
280+
while (vectorValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
281+
liveDocs++;
282+
}
283+
return liveDocs;
284+
}
285+
270286
private void initQuantizationStateWriterIfNecessary() throws IOException {
271287
if (quantizationStateWriter == null) {
272288
quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState);

src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,17 +48,17 @@ public static DefaultIndexBuildStrategy getInstance() {
4848
* flushed and used to build the index. The index is then written to the specified path using JNI calls.</p>
4949
*
5050
* @param indexInfo The {@link BuildIndexParams} containing the parameters and configuration for building the index.
51-
* @param knnVectorValues The {@link KNNVectorValues} representing the vectors to be indexed.
5251
* @throws IOException If an I/O error occurs during the process of building and writing the index.
5352
*/
54-
public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException {
53+
public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOException {
54+
final KNNVectorValues<?> knnVectorValues = indexInfo.getVectorValues();
5555
// Needed to make sure we don't get 0 dimensions while initializing index
5656
iterateVectorValuesOnce(knnVectorValues);
5757
IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo);
5858

5959
int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector());
6060
try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) {
61-
final List<Integer> transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs());
61+
final List<Integer> transferredDocIds = new ArrayList<>(indexInfo.getTotalLiveDocs());
6262

6363
while (knnVectorValues.docId() != NO_MORE_DOCS) {
6464
Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup);

src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ public static MemOptimizedNativeIndexBuildStrategy getInstance() {
5252
* @param knnVectorValues The {@link KNNVectorValues} representing the vectors to be indexed.
5353
* @throws IOException If an I/O error occurs during the process of building and writing the index.
5454
*/
55-
public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException {
55+
public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOException {
56+
final KNNVectorValues<?> knnVectorValues = indexInfo.getVectorValues();
5657
// Needed to make sure we don't get 0 dimensions while initializing index
5758
iterateVectorValuesOnce(knnVectorValues);
5859
KNNEngine engine = indexInfo.getKnnEngine();
@@ -62,7 +63,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector
6263
// Initialize the index
6364
long indexMemoryAddress = AccessController.doPrivileged(
6465
(PrivilegedAction<Long>) () -> JNIService.initIndex(
65-
knnVectorValues.totalLiveDocs(),
66+
indexInfo.getTotalLiveDocs(),
6667
indexBuildSetup.getDimensions(),
6768
indexParameters,
6869
engine

src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexBuildStrategy.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.opensearch.knn.index.codec.nativeindex;
77

88
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
9-
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
109

1110
import java.io.IOException;
1211

@@ -15,5 +14,5 @@
1514
*/
1615
public interface NativeIndexBuildStrategy {
1716

18-
void buildAndWriteIndex(BuildIndexParams indexInfo, final KNNVectorValues<?> knnVectorValues) throws IOException;
17+
void buildAndWriteIndex(BuildIndexParams indexInfo) throws IOException;
1918
}

src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ public static NativeIndexWriter getWriter(
106106
* @param knnVectorValues
107107
* @throws IOException
108108
*/
109-
public void flushIndex(final KNNVectorValues<?> knnVectorValues) throws IOException {
109+
public void flushIndex(final KNNVectorValues<?> knnVectorValues, int totalLiveDocs) throws IOException {
110110
iterateVectorValuesOnce(knnVectorValues);
111-
buildAndWriteIndex(knnVectorValues);
111+
buildAndWriteIndex(knnVectorValues, totalLiveDocs);
112112
recordRefreshStats();
113113
}
114114

@@ -117,7 +117,7 @@ public void flushIndex(final KNNVectorValues<?> knnVectorValues) throws IOExcept
117117
* @param knnVectorValues
118118
* @throws IOException
119119
*/
120-
public void mergeIndex(final KNNVectorValues<?> knnVectorValues) throws IOException {
120+
public void mergeIndex(final KNNVectorValues<?> knnVectorValues, int totalLiveDocs) throws IOException {
121121
iterateVectorValuesOnce(knnVectorValues);
122122
if (knnVectorValues.docId() == NO_MORE_DOCS) {
123123
// This is in place so we do not add metrics
@@ -126,13 +126,13 @@ public void mergeIndex(final KNNVectorValues<?> knnVectorValues) throws IOExcept
126126
}
127127

128128
long bytesPerVector = knnVectorValues.bytesPerVector();
129-
startMergeStats((int) knnVectorValues.totalLiveDocs(), bytesPerVector);
130-
buildAndWriteIndex(knnVectorValues);
131-
endMergeStats((int) knnVectorValues.totalLiveDocs(), bytesPerVector);
129+
startMergeStats(totalLiveDocs, bytesPerVector);
130+
buildAndWriteIndex(knnVectorValues, totalLiveDocs);
131+
endMergeStats(totalLiveDocs, bytesPerVector);
132132
}
133133

134-
private void buildAndWriteIndex(final KNNVectorValues<?> knnVectorValues) throws IOException {
135-
if (knnVectorValues.totalLiveDocs() == 0) {
134+
private void buildAndWriteIndex(final KNNVectorValues<?> knnVectorValues, int totalLiveDocs) throws IOException {
135+
if (totalLiveDocs == 0) {
136136
log.debug("No live docs for field " + fieldInfo.name);
137137
return;
138138
}
@@ -150,15 +150,21 @@ private void buildAndWriteIndex(final KNNVectorValues<?> knnVectorValues) throws
150150
).toString();
151151
state.directory.createOutput(engineFileName, state.context).close();
152152

153-
final BuildIndexParams nativeIndexParams = indexParams(fieldInfo, indexPath, knnEngine);
154-
indexBuilder.buildAndWriteIndex(nativeIndexParams, knnVectorValues);
153+
final BuildIndexParams nativeIndexParams = indexParams(fieldInfo, indexPath, knnEngine, knnVectorValues, totalLiveDocs);
154+
indexBuilder.buildAndWriteIndex(nativeIndexParams);
155155
writeFooter(indexPath, engineFileName, state);
156156
}
157157

158158
// The logic for building parameters need to be cleaned up. There are various cases handled here
159159
// Currently it falls under two categories - with model and without model. Without model is further divided based on vector data type
160160
// TODO: Refactor this so its scalable. Possibly move it out of this class
161-
private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNEngine knnEngine) throws IOException {
161+
private BuildIndexParams indexParams(
162+
FieldInfo fieldInfo,
163+
String indexPath,
164+
KNNEngine knnEngine,
165+
KNNVectorValues<?> vectorValues,
166+
int totalLiveDocs
167+
) throws IOException {
162168
final Map<String, Object> parameters;
163169
VectorDataType vectorDataType;
164170
if (quantizationState != null) {
@@ -180,6 +186,8 @@ private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNE
180186
.knnEngine(knnEngine)
181187
.indexPath(indexPath)
182188
.quantizationState(quantizationState)
189+
.vectorValues(vectorValues)
190+
.totalLiveDocs(totalLiveDocs)
183191
.build();
184192
}
185193

src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.opensearch.common.Nullable;
1212
import org.opensearch.knn.index.VectorDataType;
1313
import org.opensearch.knn.index.engine.KNNEngine;
14+
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
1415
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
1516

1617
import java.util.Map;
@@ -29,4 +30,6 @@ public class BuildIndexParams {
2930
*/
3031
@Nullable
3132
QuantizationState quantizationState;
33+
KNNVectorValues<?> vectorValues;
34+
int totalLiveDocs;
3235
}

src/main/java/org/opensearch/knn/index/quantizationservice/KNNVectorQuantizationTrainingRequest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ final class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {
2828
*
2929
* @param knnVectorValues the KNNVectorValues instance containing the vectors.
3030
*/
31-
KNNVectorQuantizationTrainingRequest(KNNVectorValues<T> knnVectorValues) {
32-
super((int) knnVectorValues.totalLiveDocs());
31+
KNNVectorQuantizationTrainingRequest(KNNVectorValues<T> knnVectorValues, long liveDocs) {
32+
super((int) liveDocs);
3333
this.knnVectorValues = knnVectorValues;
3434
this.lastIndex = 0;
3535
}

src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,15 @@ public static <T, R> QuantizationService<T, R> getInstance() {
5757
* @return The {@link QuantizationState} containing the state of the trained quantizer.
5858
* @throws IOException If an I/O error occurs during the training process.
5959
*/
60-
public QuantizationState train(final QuantizationParams quantizationParams, final KNNVectorValues<T> knnVectorValues)
61-
throws IOException {
60+
public QuantizationState train(
61+
final QuantizationParams quantizationParams,
62+
final KNNVectorValues<T> knnVectorValues,
63+
final long liveDocs
64+
) throws IOException {
6265
Quantizer<T, R> quantizer = QuantizerFactory.getQuantizer(quantizationParams);
6366

6467
// Create the training request from the vector values
65-
KNNVectorQuantizationTrainingRequest<T> trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues);
68+
KNNVectorQuantizationTrainingRequest<T> trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs);
6669

6770
// Train the quantizer and return the quantization state
6871
return quantizer.train(trainingRequest);

src/main/java/org/opensearch/knn/index/vectorvalues/KNNVectorValues.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,18 @@ public int bytesPerVector() {
7171
}
7272

7373
/**
74-
* Returns the total live docs for KNNVectorValues.
74+
* Returns the total live docs for KNNVectorValues. This function is broken and doesn't always give the accurate
75+
* live docs count when iterators are {@link FloatVectorValues}, {@link ByteVectorValues}. Avoid using this iterator,
76+
* rather use a simple function like this:
77+
* <pre class="prettyprint">
78+
* int liveDocs = 0;
79+
* while(vectorValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
80+
* liveDocs++;
81+
* }
82+
* </pre>
7583
* @return long
7684
*/
85+
@Deprecated
7786
public long totalLiveDocs() {
7887
return vectorValuesIterator.liveDocs();
7988
}

0 commit comments

Comments
 (0)