Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion release-notes/opensearch-knn.release-notes-2.17.0.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ Compatible with OpenSearch 2.17.0
* Restructure mappers to better handle null cases and avoid branching in parsing [#1939](https://github.com/opensearch-project/k-NN/pull/1939)
* Added Quantization Framework and implemented 1Bit and multibit quantizer [#1889](https://github.com/opensearch-project/k-NN/issues/1889)
* Encapsulate dimension, vector data type validation/processing inside Library [#1957](https://github.com/opensearch-project/k-NN/pull/1957)
* Add quantization state cache [#1960](https://github.com/opensearch-project/k-NN/pull/1960)
* Add quantization state cache [#1960](https://github.com/opensearch-project/k-NN/pull/1960)
* Add quantization state reader and writer [#1997](https://github.com/opensearch-project/k-NN/pull/1997)
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ public class KNNConstants {
public static final String MINIMAL_MODE_AND_COMPRESSION_FEATURE = "mode_and_compression_feature";

public static final String RADIAL_SEARCH_KEY = "radial_search";
public static final String QUANTIZATION_STATE_FILE_SUFFIX = "osknnqstate";

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";
Expand Down
9 changes: 5 additions & 4 deletions src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import org.opensearch.knn.index.memory.NativeMemoryCacheManager;
import org.opensearch.knn.index.memory.NativeMemoryCacheManagerDto;
import org.opensearch.knn.index.util.IndexHyperParametersUtil;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCache;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateCacheManager;
import org.opensearch.monitor.jvm.JvmInfo;
import org.opensearch.monitor.os.OsProbe;

Expand Down Expand Up @@ -60,6 +60,7 @@ public class KNNSettings {
private static final OsProbe osProbe = OsProbe.getInstance();

private static final int INDEX_THREAD_QTY_MAX = 32;
private static final QuantizationStateCacheManager quantizationStateCacheManager = QuantizationStateCacheManager.getInstance();

/**
* Settings name
Expand Down Expand Up @@ -379,11 +380,11 @@ private void setSettingsUpdateConsumers() {
NativeMemoryCacheManager.getInstance().rebuildCache(builder.build());
}, Stream.concat(dynamicCacheSettings.values().stream(), FEATURE_FLAGS.values().stream()).collect(Collectors.toUnmodifiableList()));
clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_SIZE_LIMIT_SETTING, it -> {
QuantizationStateCache.getInstance().setMaxCacheSizeInKB(it.getKb());
QuantizationStateCache.getInstance().rebuildCache();
quantizationStateCacheManager.setMaxCacheSizeInKB(it.getKb());
quantizationStateCacheManager.rebuildCache();
});
clusterService.getClusterSettings().addSettingsUpdateConsumer(QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES_SETTING, it -> {
QuantizationStateCache.getInstance().rebuildCache();
quantizationStateCacheManager.rebuildCache();
});
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import com.google.common.annotations.VisibleForTesting;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
 * Reads quantization states from the per-segment quantization state file.
 *
 * <p>File layout as consumed by this reader:
 * <pre>
 * Header
 * QS1 state bytes
 * QS2 state bytes
 * Number of quantization states      (int, start of the index section)
 * QS1 field number                   (int)
 * QS1 state bytes length             (int)
 * QS1 position of state bytes        (vlong)
 * QS2 field number
 * QS2 state bytes length
 * QS2 position of state bytes
 * Position of index section          (long)
 * -1 (marker)                        (int)
 * Footer
 * </pre>
 */
@Log4j2
public final class KNN990QuantizationStateReader {

    /**
     * Reads every quantization state stored for a segment and returns the raw, still-serialized
     * state bytes keyed by field name.
     *
     * <p>Any exception raised while opening or parsing the file is logged at WARN level and an
     * empty map is returned instead.
     *
     * @param state the read state to read from
     * @return map of field name to serialized quantization state; empty if the file could not be read
     * @throws IOException declared by the signature; failures inside the read are caught and logged
     */
    public static Map<String, byte[]> read(SegmentReadState state) throws IOException {
        String quantizationStateFileName = getQuantizationStateFileName(state);

        try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) {
            CodecUtil.retrieveChecksum(input);

            // Leaves the input positioned on the first index entry
            int numFields = getNumFields(input);
            Map<String, byte[]> readQuantizationStateInfos = new HashMap<>();

            // Read each field's metadata from the index section and then read bytes
            for (int i = 0; i < numFields; i++) {
                int fieldNumber = input.readInt();
                int length = input.readInt();
                long position = input.readVLong();

                // readStateBytes seeks into the data section, so remember where the next index
                // entry starts and return there; otherwise the next iteration would parse state
                // bytes as if they were index metadata.
                long nextEntryPosition = input.getFilePointer();
                byte[] stateBytes = readStateBytes(input, position, length);
                input.seek(nextEntryPosition);

                String fieldName = state.fieldInfos.fieldInfo(fieldNumber).getName();
                readQuantizationStateInfos.put(fieldName, stateBytes);
            }
            return readQuantizationStateInfos;
        } catch (Exception e) {
            // Best effort: callers receive an empty map rather than an exception
            log.warn(String.format("Unable to read the quantization state file for segment %s", state.segmentInfo.name), e);
            return Collections.emptyMap();
        }
    }

    /**
     * Reads an individual quantization state for a given field and deserializes it according to
     * the scalar quantization type carried by the read config.
     *
     * <p>Failures inside the read (I/O errors, no index entry for the field, an unexpected
     * quantization type) are logged at WARN level and result in {@code null}.
     * NOTE(review): the field-number lookup happens before the try block, so a field name that is
     * unknown to the segment's field infos would presumably surface as an exception rather than
     * {@code null} - confirm this is intended.
     *
     * @param readConfig a config class that contains necessary information for reading the state
     * @return quantization state, or {@code null} if it could not be read
     * @throws IOException declared by the signature; failures inside the read are caught and logged
     */
    public static QuantizationState read(QuantizationStateReadConfig readConfig) throws IOException {
        SegmentReadState segmentReadState = readConfig.getSegmentReadState();
        String field = readConfig.getField();
        String quantizationStateFileName = getQuantizationStateFileName(segmentReadState);
        int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber();

        try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) {
            CodecUtil.retrieveChecksum(input);
            int numFields = getNumFields(input);

            long position = -1;
            int length = 0;

            // Read each field's metadata from the index section, break when correct field is found
            for (int i = 0; i < numFields; i++) {
                int tempFieldNumber = input.readInt();
                int tempLength = input.readInt();
                long tempPosition = input.readVLong();
                if (tempFieldNumber == fieldNumber) {
                    position = tempPosition;
                    length = tempLength;
                    break;
                }
            }

            // A missing entry and a zero-length entry are both treated as "not found"
            if (position == -1 || length == 0) {
                throw new IllegalArgumentException(String.format("Field %s not found", field));
            }

            byte[] stateBytes = readStateBytes(input, position, length);

            // Deserialize the byte array to a quantization state object
            ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType();
            switch (scalarQuantizationType) {
                case ONE_BIT:
                    return OneBitScalarQuantizationState.fromByteArray(stateBytes);
                case TWO_BIT:
                case FOUR_BIT:
                    return MultiBitScalarQuantizationState.fromByteArray(stateBytes);
                default:
                    throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType));
            }
        } catch (Exception e) {
            // Best effort: callers receive null rather than an exception
            log.warn(String.format("Unable to read the quantization state file for segment %s", segmentReadState.segmentInfo.name), e);
            return null;
        }
    }

    /**
     * Locates the index section and reads the number of quantization states stored in the file.
     * On return the input is positioned immediately after the count, i.e. on the first index entry.
     *
     * @param input input over the quantization state file
     * @return number of index entries
     * @throws IOException if seeking or reading fails
     */
    @VisibleForTesting
    static int getNumFields(IndexInput input) throws IOException {
        long footerStart = input.length() - CodecUtil.footerLength();
        // Directly before the footer sit the index start position (long) and the marker (int)
        long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES;
        input.seek(markerAndIndexPosition);
        long indexStartPosition = input.readLong();
        input.seek(indexStartPosition);
        return input.readInt();
    }

    /**
     * Reads {@code length} bytes starting at {@code position}. This moves the input's file pointer.
     *
     * @param input input over the quantization state file
     * @param position absolute position of the first byte to read
     * @param length number of bytes to read
     * @return the bytes read
     * @throws IOException if seeking or reading fails
     */
    @VisibleForTesting
    static byte[] readStateBytes(IndexInput input, long position, int length) throws IOException {
        input.seek(position);
        byte[] stateBytes = new byte[length];
        input.readBytes(stateBytes, 0, length);
        return stateBytes;
    }

    /**
     * Builds the quantization state file name from the segment name, the segment suffix and
     * {@link KNNConstants#QUANTIZATION_STATE_FILE_SUFFIX}.
     *
     * @param state the read state describing the segment
     * @return file name of the segment's quantization state file
     */
    @VisibleForTesting
    static String getQuantizationStateFileName(SegmentReadState state) {
        return IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.AllArgsConstructor;
import lombok.Setter;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.IndexOutput;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Writes quantization states to a per-segment quantization state file.
 *
 * <p>The file layout implies the call order: {@link #writeHeader(SegmentWriteState)}, then
 * {@link #writeState(int, QuantizationState)} once per field, then {@link #writeFooter()}, and
 * finally {@link #closeOutput()}.
 */
public final class KNN990QuantizationStateWriter {

    private final IndexOutput output;
    // Index entries collected by writeState and flushed by writeFooter
    private final List<FieldQuantizationState> fieldQuantizationStates = new ArrayList<>();
    static final String NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA = "NativeEngines990KnnVectorsFormatQSData";

    /**
     * Constructor. Creates the output for the segment's quantization state file.
     * Overall file format for writer:
     * <pre>
     * Header
     * QS1 state bytes
     * QS2 state bytes
     * Number of quantization states      (int, start of the index section)
     * QS1 field number                   (int)
     * QS1 state bytes length             (int)
     * QS1 position of state bytes        (vlong)
     * QS2 field number
     * QS2 state bytes length
     * QS2 position of state bytes
     * Position of index section          (long)
     * -1 (marker)                        (int)
     * Footer
     * </pre>
     * @param segmentWriteState segment write state containing segment information
     * @throws IOException exception could be thrown while creating the output
     */
    public KNN990QuantizationStateWriter(SegmentWriteState segmentWriteState) throws IOException {
        String quantizationStateFileName = IndexFileNames.segmentFileName(
            segmentWriteState.segmentInfo.name,
            segmentWriteState.segmentSuffix,
            KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX
        );

        output = segmentWriteState.directory.createOutput(quantizationStateFileName, segmentWriteState.context);
    }

    /**
     * Writes an index header
     * @param segmentWriteState state containing segment information
     * @throws IOException exception could be thrown while writing header
     */
    public void writeHeader(SegmentWriteState segmentWriteState) throws IOException {
        CodecUtil.writeIndexHeader(
            output,
            NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA,
            0,
            segmentWriteState.segmentInfo.getId(),
            segmentWriteState.segmentSuffix
        );
    }

    /**
     * Writes a quantization state as bytes and records where it was written so that
     * {@link #writeFooter()} can emit the matching index entry.
     *
     * @param fieldNumber field number
     * @param quantizationState quantization state
     * @throws IOException could be thrown while writing
     */
    public void writeState(int fieldNumber, QuantizationState quantizationState) throws IOException {
        byte[] stateBytes = quantizationState.toByteArray();
        long position = output.getFilePointer();
        output.writeBytes(stateBytes, stateBytes.length);
        // Only the length is needed for the index entry; do not keep the serialized bytes on heap
        fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, stateBytes.length, position));
    }

    /**
     * Writes the index section (state count plus one entry per written state), the position of
     * the index section, the -1 marker and finally the codec footer.
     * @throws IOException could be thrown while writing
     */
    public void writeFooter() throws IOException {
        long indexStartPosition = output.getFilePointer();
        output.writeInt(fieldQuantizationStates.size());
        for (FieldQuantizationState fieldQuantizationState : fieldQuantizationStates) {
            output.writeInt(fieldQuantizationState.fieldNumber);
            output.writeInt(fieldQuantizationState.length);
            output.writeVLong(fieldQuantizationState.position);
        }
        output.writeLong(indexStartPosition);
        output.writeInt(-1);
        CodecUtil.writeFooter(output);
    }

    /**
     * Index entry for one written quantization state: the field it belongs to, the number of
     * state bytes written and the file position at which they start.
     */
    @AllArgsConstructor
    private static class FieldQuantizationState {
        final int fieldNumber;
        final int length;
        @Setter
        long position;
    }

    /**
     * Closes the underlying output.
     * @throws IOException could be thrown while closing
     */
    public void closeOutput() throws IOException {
        output.close();
    }
}
Loading