Skip to content

Commit a6fb6d7

Browse files
jmazanec15github-actions[bot]
authored andcommitted
Introduce mode and compression param resolution (#2034)
* Introduce mode and compression param resolution Adds mode and compression based parameter resolution. With this, if a user specifies the mode and/or compression params, we will create a default configuration with the aim of meeting those hints. Currently, it does not contain support for overriding any of the parameters. This will be taken in a future commit. Signed-off-by: John Mazanec <jmazane@amazon.com> * Modify reader changes Signed-off-by: John Mazanec <jmazane@amazon.com> --------- Signed-off-by: John Mazanec <jmazane@amazon.com> (cherry picked from commit ef4922a)
1 parent c6efa81 commit a6fb6d7

25 files changed

Lines changed: 845 additions & 403 deletions

release-notes/opensearch-knn.release-notes-2.17.0.0.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Compatible with OpenSearch 2.17.0
77
* k-NN query rescore support for native engines [#1984](https://github.com/opensearch-project/k-NN/pull/1984)
88
* Add support for byte vector with Faiss Engine HNSW algorithm [#1823](https://github.com/opensearch-project/k-NN/pull/1823)
99
* Add support for byte vector with Faiss Engine IVF algorithm [#2002](https://github.com/opensearch-project/k-NN/pull/2002)
10+
* Add mode/compression configuration support for disk-based vector search [#2034](https://github.com/opensearch-project/k-NN/pull/2034)
1011
### Enhancements
1112
* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950)
1213
### Bug Fixes

src/main/java/org/opensearch/knn/index/KNNIndexShard.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import org.opensearch.common.lucene.Lucene;
1919
import org.opensearch.index.engine.Engine;
2020
import org.opensearch.index.shard.IndexShard;
21+
import org.opensearch.knn.common.FieldInfoExtractor;
22+
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
2123
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
2224
import org.opensearch.knn.index.memory.NativeMemoryAllocation;
2325
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
@@ -182,7 +184,11 @@ List<EngineFileContext> getEngineFileContexts(IndexReader indexReader, KNNEngine
182184
shardPath,
183185
spaceType,
184186
modelId,
185-
VectorDataType.get(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()))
187+
FieldInfoExtractor.extractQuantizationConfig(fieldInfo) == QuantizationConfig.EMPTY
188+
? VectorDataType.get(
189+
fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())
190+
)
191+
: VectorDataType.BINARY
186192
)
187193
);
188194
}

src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -88,45 +88,41 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr
8888
String quantizationStateFileName = getQuantizationStateFileName(segmentReadState);
8989
int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber();
9090

91-
try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) {
92-
CodecUtil.retrieveChecksum(input);
93-
int numFields = getNumFields(input);
94-
95-
long position = -1;
96-
int length = 0;
97-
98-
// Read each field's metadata from the index section, break when correct field is found
99-
for (int i = 0; i < numFields; i++) {
100-
int tempFieldNumber = input.readInt();
101-
int tempLength = input.readInt();
102-
long tempPosition = input.readVLong();
103-
if (tempFieldNumber == fieldNumber) {
104-
position = tempPosition;
105-
length = tempLength;
106-
break;
107-
}
91+
IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ);
92+
CodecUtil.retrieveChecksum(input);
93+
int numFields = getNumFields(input);
94+
95+
long position = -1;
96+
int length = 0;
97+
98+
// Read each field's metadata from the index section, break when correct field is found
99+
for (int i = 0; i < numFields; i++) {
100+
int tempFieldNumber = input.readInt();
101+
int tempLength = input.readInt();
102+
long tempPosition = input.readVLong();
103+
if (tempFieldNumber == fieldNumber) {
104+
position = tempPosition;
105+
length = tempLength;
106+
break;
108107
}
108+
}
109109

110-
if (position == -1 || length == 0) {
111-
throw new IllegalArgumentException(String.format("Field %s not found", field));
112-
}
110+
if (position == -1 || length == 0) {
111+
throw new IllegalArgumentException(String.format("Field %s not found", field));
112+
}
113113

114-
byte[] stateBytes = readStateBytes(input, position, length);
115-
116-
// Deserialize the byte array to a quantization state object
117-
ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType();
118-
switch (scalarQuantizationType) {
119-
case ONE_BIT:
120-
return OneBitScalarQuantizationState.fromByteArray(stateBytes);
121-
case TWO_BIT:
122-
case FOUR_BIT:
123-
return MultiBitScalarQuantizationState.fromByteArray(stateBytes);
124-
default:
125-
throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType));
126-
}
127-
} catch (Exception e) {
128-
log.warn(String.format("Unable to read the quantization state file for segment %s", segmentReadState.segmentInfo.name), e);
129-
return null;
114+
byte[] stateBytes = readStateBytes(input, position, length);
115+
116+
// Deserialize the byte array to a quantization state object
117+
ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType();
118+
switch (scalarQuantizationType) {
119+
case ONE_BIT:
120+
return OneBitScalarQuantizationState.fromByteArray(stateBytes);
121+
case TWO_BIT:
122+
case FOUR_BIT:
123+
return MultiBitScalarQuantizationState.fromByteArray(stateBytes);
124+
default:
125+
throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType));
130126
}
131127
}
132128

src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,6 @@
2525
import org.apache.lucene.util.IOUtils;
2626
import org.opensearch.common.UUIDs;
2727
import org.opensearch.knn.index.quantizationservice.QuantizationService;
28-
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
29-
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
30-
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
31-
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
3228
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
3329
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
3430
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig;
@@ -50,8 +46,8 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader {
5046

5147
public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) throws IOException {
5248
this.segmentReadState = state;
53-
primeQuantizationStateCache();
5449
this.flatVectorsReader = flatVectorsReader;
50+
primeQuantizationStateCache();
5551
}
5652

5753
/**
@@ -197,28 +193,9 @@ public long ramBytesUsed() {
197193

198194
private void primeQuantizationStateCache() throws IOException {
199195
quantizationStateCacheKeyPerField = new HashMap<>();
200-
Map<String, byte[]> stateMap = KNN990QuantizationStateReader.read(segmentReadState);
201-
for (Map.Entry<String, byte[]> entry : stateMap.entrySet()) {
202-
FieldInfo fieldInfo = segmentReadState.fieldInfos.fieldInfo(entry.getKey());
203-
QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo);
204-
if (quantizationParams instanceof ScalarQuantizationParams) {
205-
QuantizationState quantizationState;
206-
ScalarQuantizationParams scalarQuantizationParams = (ScalarQuantizationParams) quantizationParams;
207-
switch (scalarQuantizationParams.getSqType()) {
208-
case ONE_BIT:
209-
quantizationState = OneBitScalarQuantizationState.fromByteArray(entry.getValue());
210-
break;
211-
case TWO_BIT:
212-
case FOUR_BIT:
213-
quantizationState = MultiBitScalarQuantizationState.fromByteArray(entry.getValue());
214-
break;
215-
default:
216-
throw new IllegalArgumentException("Unknown Scalar Quantization Type");
217-
}
218-
String cacheKey = UUIDs.base64UUID();
219-
quantizationStateCacheKeyPerField.put(entry.getKey(), cacheKey);
220-
quantizationStateCacheManager.addQuantizationState(cacheKey, quantizationState);
221-
}
196+
for (FieldInfo fieldInfo : segmentReadState.fieldInfos) {
197+
String cacheKey = UUIDs.base64UUID();
198+
quantizationStateCacheKeyPerField.put(fieldInfo.getName(), cacheKey);
222199
}
223200
}
224201
}

src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
import org.opensearch.core.xcontent.DeprecationHandler;
1919
import org.opensearch.core.xcontent.MediaTypeRegistry;
2020
import org.opensearch.core.xcontent.NamedXContentRegistry;
21+
import org.opensearch.knn.common.FieldInfoExtractor;
2122
import org.opensearch.knn.common.KNNConstants;
2223
import org.opensearch.knn.index.KNNSettings;
2324
import org.opensearch.knn.index.SpaceType;
2425
import org.opensearch.knn.index.VectorDataType;
2526
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
2627
import org.opensearch.knn.index.engine.KNNEngine;
28+
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
2729
import org.opensearch.knn.index.quantizationservice.QuantizationService;
2830
import org.opensearch.knn.index.util.IndexUtil;
2931
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
@@ -255,7 +257,12 @@ private Map<String, Object> getTemplateParameters(FieldInfo fieldInfo, Model mod
255257
parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY));
256258
parameters.put(KNNConstants.MODEL_ID, fieldInfo.attributes().get(MODEL_ID));
257259
parameters.put(KNNConstants.MODEL_BLOB_PARAMETER, model.getModelBlob());
258-
IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType());
260+
if (FieldInfoExtractor.extractQuantizationConfig(fieldInfo) != QuantizationConfig.EMPTY) {
261+
IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
262+
} else {
263+
IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType());
264+
}
265+
259266
return parameters;
260267
}
261268

src/main/java/org/opensearch/knn/index/engine/KNNMethodConfigContext.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import lombok.Setter;
1313
import org.opensearch.Version;
1414
import org.opensearch.knn.index.VectorDataType;
15+
import org.opensearch.knn.index.mapper.CompressionLevel;
16+
import org.opensearch.knn.index.mapper.Mode;
1517

1618
/**
1719
* This object provides additional context that the user does not provide when {@link KNNMethodContext} is
@@ -27,5 +29,10 @@ public final class KNNMethodConfigContext {
2729
private VectorDataType vectorDataType;
2830
private Integer dimension;
2931
private Version versionCreated;
32+
@Builder.Default
33+
private Mode mode = Mode.NOT_CONFIGURED;
34+
@Builder.Default
35+
private CompressionLevel compressionLevel = CompressionLevel.NOT_CONFIGURED;
36+
3037
public static final KNNMethodConfigContext EMPTY = KNNMethodConfigContext.builder().build();
3138
}

src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
import lombok.AllArgsConstructor;
99
import lombok.Getter;
1010
import org.opensearch.core.common.Strings;
11+
import org.opensearch.knn.index.query.rescore.RescoreContext;
1112

12-
import java.util.Arrays;
1313
import java.util.Locale;
14-
import java.util.stream.Collectors;
1514

1615
/**
1716
* Enum representing the compression level for float vectors. Compression in this sense refers to compressing a
@@ -20,20 +19,23 @@
2019
*/
2120
@AllArgsConstructor
2221
public enum CompressionLevel {
23-
NOT_CONFIGURED(-1, ""),
24-
x1(1, "1x"),
25-
x2(2, "2x"),
26-
x4(4, "4x"),
27-
x8(8, "8x"),
28-
x16(16, "16x"),
29-
x32(32, "32x");
22+
NOT_CONFIGURED(-1, "", null),
23+
x1(1, "1x", null),
24+
x2(2, "2x", null),
25+
x4(4, "4x", new RescoreContext(1.0f)),
26+
x8(8, "8x", new RescoreContext(1.5f)),
27+
x16(16, "16x", new RescoreContext(2.0f)),
28+
x32(32, "32x", new RescoreContext(2.0f));
3029

3130
// Internally, an empty string is easier to deal with them null. However, from the mapping,
3231
// we do not want users to pass in the empty string and instead want null. So we make the conversion herex
33-
static final String[] NAMES_ARRAY = Arrays.stream(CompressionLevel.values())
34-
.map(compressionLevel -> compressionLevel == NOT_CONFIGURED ? null : compressionLevel.getName())
35-
.collect(Collectors.toList())
36-
.toArray(new String[0]);
32+
public static final String[] NAMES_ARRAY = new String[] {
33+
NOT_CONFIGURED.getName(),
34+
x1.getName(),
35+
x2.getName(),
36+
x8.getName(),
37+
x16.getName(),
38+
x32.getName() };
3739

3840
/**
3941
* Default is set to 1x and is a noop
@@ -62,6 +64,8 @@ public static CompressionLevel fromName(String name) {
6264
private final int compressionLevel;
6365
@Getter
6466
private final String name;
67+
@Getter
68+
private final RescoreContext defaultRescoreContext;
6569

6670
/**
6771
* Gets the number of bits used to represent a float in order to achieve this compression. For instance, for

src/main/java/org/opensearch/knn/index/mapper/KNNMappingConfig.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,24 @@ default Optional<KNNMethodContext> getKnnMethodContext() {
3030
return Optional.empty();
3131
}
3232

33+
/**
34+
* Return the mode to be used for this field
35+
*
36+
* @return {@link Mode}
37+
*/
38+
default Mode getMode() {
39+
return Mode.NOT_CONFIGURED;
40+
}
41+
42+
/**
43+
* Return compression level to be used for this field
44+
*
45+
* @return {@link CompressionLevel}
46+
*/
47+
default CompressionLevel getCompressionLevel() {
48+
return CompressionLevel.NOT_CONFIGURED;
49+
}
50+
3351
/**
3452
*
3553
* @return the dimension of the index; for model based indices, it will be null

src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapper.java

Lines changed: 45 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -145,16 +145,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder {
145145
b.startObject(n);
146146
v.toXContent(b, ToXContent.EMPTY_PARAMS);
147147
b.endObject();
148-
}), m -> m.getMethodComponentContext().getName()).setValidator(v -> {
149-
if (v == null) return;
150-
151-
ValidationException validationException;
152-
if (v.isTrainingRequired()) {
153-
validationException = new ValidationException();
154-
validationException.addValidationError(String.format(Locale.ROOT, "\"%s\" requires training.", KNN_METHOD));
155-
throw validationException;
156-
}
157-
});
148+
}), m -> m.getMethodComponentContext().getName());
158149

159150
protected final Parameter<String> mode = Parameter.restrictedStringParam(
160151
KNNConstants.MODE_PARAMETER,
@@ -354,13 +345,34 @@ public Mapper.Builder<?> parse(String name, Map<String, Object> node, ParserCont
354345
} else if (builder.modelId.get() != null) {
355346
validateFromModel(builder);
356347
} else {
348+
validateMode(builder);
357349
resolveKNNMethodComponents(builder, parserContext);
358350
validateFromKNNMethod(builder);
359351
}
360352

361353
return builder;
362354
}
363355

356+
private void validateMode(KNNVectorFieldMapper.Builder builder) {
357+
boolean isKNNMethodContextConfigured = builder.originalParameters.getKnnMethodContext() != null;
358+
boolean isModeConfigured = builder.mode.isConfigured() || builder.compressionLevel.isConfigured();
359+
if (isModeConfigured && isKNNMethodContextConfigured) {
360+
throw new MapperParsingException(
361+
String.format(
362+
Locale.ROOT,
363+
"Compression and mode can not be specified in a \"method\" mapping configuration for field: %s",
364+
builder.name
365+
)
366+
);
367+
}
368+
369+
if (isModeConfigured && builder.vectorDataType.getValue() != VectorDataType.FLOAT) {
370+
throw new MapperParsingException(
371+
String.format(Locale.ROOT, "Compression and mode cannot be used for non-float32 data type for field %s", builder.name)
372+
);
373+
}
374+
}
375+
364376
private void validateFromFlat(KNNVectorFieldMapper.Builder builder) {
365377
if (builder.modelId.get() != null || builder.knnMethodContext.get() != null) {
366378
throw new IllegalArgumentException("Cannot set modelId or method parameters when index.knn setting is false");
@@ -378,9 +390,15 @@ private void validateFromModel(KNNVectorFieldMapper.Builder builder) {
378390
}
379391

380392
private void validateFromKNNMethod(KNNVectorFieldMapper.Builder builder) {
393+
ValidationException validationException;
394+
if (builder.originalParameters.getResolvedKnnMethodContext().isTrainingRequired()) {
395+
validationException = new ValidationException();
396+
validationException.addValidationError(String.format(Locale.ROOT, "\"%s\" requires training.", KNN_METHOD));
397+
throw validationException;
398+
}
399+
381400
if (builder.originalParameters.getResolvedKnnMethodContext() != null) {
382-
ValidationException validationException = builder.originalParameters.getResolvedKnnMethodContext()
383-
.validate(builder.knnMethodConfigContext);
401+
validationException = builder.originalParameters.getResolvedKnnMethodContext().validate(builder.knnMethodConfigContext);
384402
if (validationException != null) {
385403
throw validationException;
386404
}
@@ -410,9 +428,11 @@ private void validateCompressionAndModeNotSet(KNNVectorFieldMapper.Builder build
410428
private void resolveKNNMethodComponents(KNNVectorFieldMapper.Builder builder, ParserContext parserContext) {
411429
builder.setKnnMethodConfigContext(
412430
KNNMethodConfigContext.builder()
413-
.vectorDataType(builder.vectorDataType.getValue())
431+
.vectorDataType(builder.originalParameters.getVectorDataType())
414432
.versionCreated(parserContext.indexVersionCreated())
415-
.dimension(builder.dimension.getValue())
433+
.dimension(builder.originalParameters.getDimension())
434+
.mode(Mode.fromName(builder.originalParameters.getMode()))
435+
.compressionLevel(CompressionLevel.fromName(builder.originalParameters.getCompressionLevel()))
416436
.build()
417437
);
418438

@@ -421,8 +441,17 @@ private void resolveKNNMethodComponents(KNNVectorFieldMapper.Builder builder, Pa
421441
builder.originalParameters.setResolvedKnnMethodContext(
422442
createKNNMethodContextFromLegacy(parserContext.getSettings(), parserContext.indexVersionCreated())
423443
);
424-
}
425-
setDefaultSpaceType(builder.originalParameters.getResolvedKnnMethodContext(), builder.vectorDataType.getValue());
444+
} else if (Mode.isConfigured(Mode.fromName(builder.mode.get()))
445+
|| CompressionLevel.isConfigured(CompressionLevel.fromName(builder.compressionLevel.get()))) {
446+
builder.originalParameters.setResolvedKnnMethodContext(
447+
ModeBasedResolver.INSTANCE.resolveKNNMethodContext(
448+
builder.knnMethodConfigContext.getMode(),
449+
builder.knnMethodConfigContext.getCompressionLevel(),
450+
false
451+
)
452+
);
453+
}
454+
setDefaultSpaceType(builder.originalParameters.getResolvedKnnMethodContext(), builder.originalParameters.getVectorDataType());
426455
}
427456

428457
private boolean isKNNDisabled(Settings settings) {

0 commit comments

Comments
 (0)