Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.mapper;

import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;

import java.util.Optional;

Expand Down Expand Up @@ -48,6 +49,14 @@ default CompressionLevel getCompressionLevel() {
return CompressionLevel.NOT_CONFIGURED;
}

/**
* Returns quantization config
* @return
*/
default QuantizationConfig getQuantizationConfig() {
return QuantizationConfig.EMPTY;
}

/**
*
* @return the dimension of the index; for model based indices, it will be null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ public KNNVectorFieldMapper build(BuilderContext context) {
hasDocValues.get(),
modelDao,
indexCreatedVersion,
originalParameters
originalParameters,
knnMethodConfigContext
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ public static MethodFieldMapper createFieldMapper(
boolean hasDocValues,
OriginalMappingParameters originalMappingParameters
) {

KNNMethodContext knnMethodContext = originalMappingParameters.getResolvedKnnMethodContext();
QuantizationConfig quantizationConfig = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext)
.getQuantizationConfig();

final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(
fullname,
metaValue,
Expand All @@ -75,6 +81,11 @@ public Mode getMode() {
public CompressionLevel getCompressionLevel() {
return knnMethodConfigContext.getCompressionLevel();
}

@Override
public QuantizationConfig getQuantizationConfig() {
return quantizationConfig;
}
}
);
return new MethodFieldMapper(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,17 @@ public static ModelFieldMapper createFieldMapper(
boolean hasDocValues,
ModelDao modelDao,
Version indexCreatedVersion,
OriginalMappingParameters originalMappingParameters
OriginalMappingParameters originalMappingParameters,
KNNMethodConfigContext knnMethodConfigContext
) {

final KNNMethodContext knnMethodContext = originalMappingParameters.getKnnMethodContext();
final QuantizationConfig quantizationConfig = knnMethodContext == null
? QuantizationConfig.EMPTY
: knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext)
.getQuantizationConfig();

final KNNVectorFieldType mappedFieldType = new KNNVectorFieldType(fullname, metaValue, vectorDataType, new KNNMappingConfig() {
private Integer dimension = null;
private Mode mode = null;
Expand Down Expand Up @@ -94,6 +102,11 @@ public CompressionLevel getCompressionLevel() {
return compressionLevel;
}

@Override
public QuantizationConfig getQuantizationConfig() {
return quantizationConfig;
}

// ModelMetadata relies on cluster state which may not be available during field mapper creation. Thus,
// we lazily initialize it.
private void initFromModelMetadata() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.model.QueryContext;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.query.parser.RescoreParser;
Expand Down Expand Up @@ -451,6 +452,10 @@ protected Query doToQuery(QueryShardContext context) {
if (vectorDataType == VectorDataType.BINARY) {
throw new UnsupportedOperationException(String.format(Locale.ROOT, "Binary data type does not support radial search"));
}

if (knnMappingConfig.getQuantizationConfig() != QuantizationConfig.EMPTY) {
throw new UnsupportedOperationException("Radial search is not supported for indices which have quantization enabled");
}
}

// Currently, k-NN supports distance and score types radial search
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
Expand All @@ -18,6 +19,7 @@
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.index.query.ResultUtil;
Expand All @@ -39,6 +41,7 @@
* for k-NN query if required. This is done by overriding rewrite method to execute ANN on each leaf
* {@link KNNQuery} does not give the ability to post process segment results.
*/
@Log4j2
@Getter
@RequiredArgsConstructor
public class NativeEngineKnnVectorQuery extends Query {
Expand All @@ -60,7 +63,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
int firstPassK = rescoreContext.getFirstPassK(finalK);
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
ResultUtil.reduceToTopK(perLeafResults, firstPassK);

StopWatch stopWatch = new StopWatch().start();
perLeafResults = doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK);
long rescoreTime = stopWatch.stop().totalTime().millis();
log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", rescoreTime, firstPassK, leafReaderContexts.size());
}
ResultUtil.reduceToTopK(perLeafResults, finalK);
TopDocs[] topDocs = new TopDocs[perLeafResults.size()];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -979,6 +979,12 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy
MockedStatic<KNNVectorFieldMapperUtil> utilMockedStatic = Mockito.mockStatic(KNNVectorFieldMapperUtil.class);
MockedStatic<ModelUtil> modelUtilMockedStatic = Mockito.mockStatic(ModelUtil.class)
) {
KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder()
.vectorDataType(VectorDataType.FLOAT)
.versionCreated(CURRENT)
.dimension(TEST_DIMENSION)
.build();

for (VectorDataType dataType : VectorDataType.values()) {
log.info("Vector Data Type is : {}", dataType);
SpaceType spaceType = VectorDataType.BINARY == dataType ? SpaceType.DEFAULT_BINARY : SpaceType.INNER_PRODUCT;
Expand Down Expand Up @@ -1022,7 +1028,8 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy
false,
modelDao,
CURRENT,
originalMappingParameters
originalMappingParameters,
knnMethodConfigContext
);

modelFieldMapper.parseCreateField(parseContext);
Expand Down Expand Up @@ -1063,7 +1070,8 @@ public void testModelFieldMapperParseCreateField_validInput_thenDifferentFieldTy
false,
modelDao,
CURRENT,
originalMappingParameters
originalMappingParameters,
knnMethodConfigContext
);

modelFieldMapper.parseCreateField(parseContext);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.mapper.KNNMappingConfig;
import org.opensearch.knn.index.mapper.KNNVectorFieldType;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.query.rescore.RescoreContext;
import org.opensearch.knn.index.util.KNNClusterUtil;
import org.opensearch.knn.index.engine.KNNMethodContext;
Expand All @@ -41,6 +44,7 @@
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -440,6 +444,47 @@ public void testDoToQuery_whenRadialSearchOnBinaryIndex_thenException() {
assertTrue(e.getMessage().contains("Binary data type does not support radial search"));
}

public void testDoToQuery_whenRadialSearchOnDiskMode_thenException() {
float[] queryVector = { 1.0f };
KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder()
.fieldName(FIELD_NAME)
.vector(queryVector)
.maxDistance(MAX_DISTANCE)
.build();
Index dummyIndex = new Index("dummy", "dummy");
QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldType.class);
when(mockQueryShardContext.index()).thenReturn(dummyIndex);
when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.FLOAT);
when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);
MethodComponentContext methodComponentContext = new MethodComponentContext(
org.opensearch.knn.common.KNNConstants.METHOD_HNSW,
ImmutableMap.of()
);
KNNMethodContext knnMethodContext = new KNNMethodContext(KNNEngine.FAISS, SpaceType.L2, methodComponentContext);
when(mockKNNVectorField.getKnnMappingConfig()).thenReturn(new KNNMappingConfig() {
@Override
public Optional<KNNMethodContext> getKnnMethodContext() {
return Optional.of(knnMethodContext);
}

@Override
public int getDimension() {
return 1;
}

public Mode getMode() {
return Mode.ON_DISK;
}

public QuantizationConfig getQuantizationConfig() {
return QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build();
}
});
Exception e = expectThrows(UnsupportedOperationException.class, () -> knnQueryBuilder.doToQuery(mockQueryShardContext));
assertEquals("Radial search is not supported for indices which have quantization enabled", e.getMessage());
}

public void testDoToQuery_KnnQueryWithFilter_Lucene() throws Exception {
// Given
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
Expand Down