Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
e9fedcb
Add quantization state reader and writer
ryanbogan Aug 21, 2024
1f5c030
Make inner class private
ryanbogan Aug 21, 2024
627ac7e
Address PR Feedback
ryanbogan Aug 21, 2024
a81e99d
Fix tests
ryanbogan Aug 22, 2024
daa39ed
Address PR feedback
ryanbogan Aug 22, 2024
8abf3cd
Add writer tests
ryanbogan Aug 22, 2024
e7d5ac8
Add reader tests
ryanbogan Aug 22, 2024
c804652
Add changelog entry
ryanbogan Aug 22, 2024
f711e39
Remove extra line
ryanbogan Aug 22, 2024
425b920
Address PR Feedback
ryanbogan Aug 22, 2024
2ea5371
Fix javadocs
ryanbogan Aug 22, 2024
89e45de
Make reader methods static
ryanbogan Aug 22, 2024
8cd2ee3
Integrate with merge
ryanbogan Aug 22, 2024
d644c9b
Change field name writing to internal field number and change file su…
ryanbogan Aug 23, 2024
5fe570d
Merge branch 'main' into quantization_state_writer_reader
ryanbogan Aug 26, 2024
92bb539
Change integration with native engine writer
ryanbogan Aug 26, 2024
5b03f30
Fix tests
ryanbogan Aug 26, 2024
9b486d8
Integrate with query flow
ryanbogan Aug 26, 2024
0a7d80e
Remove duplicate writeFooter
ryanbogan Aug 26, 2024
de89987
Integrate with cache
ryanbogan Aug 26, 2024
59be504
Change implementation and fix tests
ryanbogan Aug 27, 2024
366072f
Add test for reading from QuantizationStateReadConfig
ryanbogan Aug 27, 2024
077a1b6
Add cache manager tests
ryanbogan Aug 28, 2024
3dbbad9
Port changes from feature branch to fix end to end flow
ryanbogan Aug 28, 2024
45b6fbf
Change integration with query flow
ryanbogan Aug 29, 2024
0e0eaf3
Remove unnecessary changes in KNNWeight
ryanbogan Aug 29, 2024
9f8ce0c
Merge branch 'main' into quantization_state_writer_reader
ryanbogan Aug 29, 2024
595ec6f
Merge branch 'main' into quantization_state_writer_reader
ryanbogan Sep 3, 2024
1f67103
Address PR Feedback and fix compile error from rebase
ryanbogan Sep 3, 2024
80e74e3
Abstract common functionality between read methods
ryanbogan Sep 3, 2024
60cd2fa
Avoid repeat calls to quantization cache manager get instance
ryanbogan Sep 3, 2024
c0b9e71
Address PR feedback
ryanbogan Sep 3, 2024
776d531
Merge branch 'main' into quantization_state_writer_reader
ryanbogan Sep 3, 2024
8c21304
Add unit tests for KNNWeight
ryanbogan Sep 4, 2024
b304e3c
Address PR Feedback
ryanbogan Sep 4, 2024
fe19fc0
Merge branch 'main' into quantization_state_writer_reader
ryanbogan Sep 4, 2024
d8959d0
Address feedbackK
ryanbogan Sep 4, 2024
ba931da
Fix bwc tests
ryanbogan Sep 4, 2024
fdfc301
Revert previous change
ryanbogan Sep 4, 2024
0e37d2d
Condense into one loop while reading
ryanbogan Sep 4, 2024
eaff8f0
Address PR Feedback
ryanbogan Sep 4, 2024
414b2e4
Address PR Feedback
ryanbogan Sep 4, 2024
2accfb1
Address feedback
ryanbogan Sep 4, 2024
a6c87e3
Revert "Address feedback"
ryanbogan Sep 4, 2024
ff07704
Address feedback
ryanbogan Sep 4, 2024
34cd60b
Address feedback
ryanbogan Sep 4, 2024
b82fb90
Merge branch 'main' into quantization_state_writer_reader
ryanbogan Sep 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Added Quantization Framework and implemented 1Bit and multibit quantizer [#1889](https://github.com/opensearch-project/k-NN/issues/1889)
* Encapsulate dimension, vector data type validation/processing inside Library [#1957](https://github.com/opensearch-project/k-NN/pull/1957)
* Add quantization state cache [#1960](https://github.com/opensearch-project/k-NN/pull/1960)
* Add quantization state reader and writer [#1997](https://github.com/opensearch-project/k-NN/pull/1997)
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public class KNNConstants {
public static final VectorDataType DEFAULT_VECTOR_DATA_TYPE_FIELD = VectorDataType.FLOAT;

public static final String RADIAL_SEARCH_KEY = "radial_search";
public static final String QUANTIZATION_STATE_FILE_SUFFIX = "qs";

// Lucene specific constants
public static final String LUCENE_NAME = "lucene";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import com.google.common.annotations.VisibleForTesting;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* Reads quantization states
*/
public class KNNQuantizationStateReader {

    /**
     * Reads every quantization state stored in a segment's quantization state file and returns the raw
     * (still serialized) bytes of each state, keyed by field name.
     *
     * <p>File format as consumed by this reader:
     * <pre>
     * Header
     * QS1 state bytes
     * QS2 state bytes
     * Number of quantization states (int, start of the index section)
     * QS1 field name (string)
     * QS1 state bytes length (int)
     * QS1 position of state bytes (vlong)
     * QS2 field name
     * QS2 state bytes length
     * QS2 position of state bytes
     * Position of index section (long, points at the number of quantization states)
     * -1 (int marker)
     * Footer
     * </pre>
     *
     * @param state the read state to read from
     * @return map from field name to that field's serialized quantization state
     * @throws IOException if the file cannot be opened or read
     */
    public Map<String, byte[]> read(SegmentReadState state) throws IOException {
        final String quantizationStateFileName = getQuantizationStateFileName(state.segmentInfo.name, state.segmentSuffix);
        final Map<String, byte[]> readQuantizationStateInfos = new HashMap<>();

        // try-with-resources guarantees the input is released even when a read fails part-way through
        try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) {
            CodecUtil.retrieveChecksum(input);

            final int numFields = getNumFields(input);

            final List<String> fieldNames = new ArrayList<>();
            final List<Long> positions = new ArrayList<>();
            final List<Integer> lengths = new ArrayList<>();

            // Pass 1: the index section is contiguous, so collect all metadata before seeking anywhere else
            for (int i = 0; i < numFields; i++) {
                fieldNames.add(input.readString());
                lengths.add(input.readInt());
                positions.add(input.readVLong());
            }

            // Pass 2: jump to each recorded position and read that field's state bytes
            for (int i = 0; i < numFields; i++) {
                final int length = lengths.get(i);
                input.seek(positions.get(i));
                final byte[] stateBytes = new byte[length];
                input.readBytes(stateBytes, 0, length);
                readQuantizationStateInfos.put(fieldNames.get(i), stateBytes);
            }
        }
        return readQuantizationStateInfos;
    }

    /**
     * Reads an individual quantization state for a given field.
     *
     * <p>The index section is scanned until an entry whose name equals the field's name is found; that
     * field's state bytes are then read from the recorded position. Deserialization is not implemented
     * yet, so this method currently always returns {@code null} when the field is found.
     *
     * @param directory directory to open input
     * @param segmentName segment name
     * @param segmentSuffix segment suffix
     * @param fieldInfo field information
     * @return quantization state (currently always {@code null}, see the TODO in the body)
     * @throws IOException if the file cannot be opened or read
     * @throws IllegalArgumentException if no entry with a non-empty state exists for the field
     */
    public QuantizationState read(Directory directory, String segmentName, String segmentSuffix, FieldInfo fieldInfo) throws IOException {
        final String quantizationStateFileName = getQuantizationStateFileName(segmentName, segmentSuffix);
        final String fieldName = fieldInfo.getName();

        // try-with-resources also closes the input on the "field not found" path below
        try (IndexInput input = directory.openInput(quantizationStateFileName, IOContext.READ)) {
            CodecUtil.retrieveChecksum(input);
            final int numFields = getNumFields(input);

            long position = -1;
            int length = 0;

            // Scan the index section and stop at the first entry matching the requested field
            for (int i = 0; i < numFields; i++) {
                final String tempFieldName = input.readString();
                final int tempLength = input.readInt();
                final long tempPosition = input.readVLong();
                if (tempFieldName.equals(fieldName)) {
                    position = tempPosition;
                    length = tempLength;
                    break;
                }
            }

            if (position == -1 || length == 0) {
                throw new IllegalArgumentException(String.format("Field %s not found", fieldName));
            }

            input.seek(position);
            final byte[] stateBytes = new byte[length];
            input.readBytes(stateBytes, 0, length);
            // Deserialize the byte array to a quantization state object
            // TODO: Get params from field info and deserialize
            return null;
        }
    }

    /**
     * Positions the input at the start of the index section and reads the number of stored states.
     *
     * <p>The trailer directly in front of the footer holds the index section position (long) followed
     * by an int marker. After this method returns, the input points at the first index entry.
     *
     * @param input input positioned anywhere in the quantization state file
     * @return number of quantization states recorded in the index section
     * @throws IOException if seeking or reading fails
     */
    @VisibleForTesting
    int getNumFields(IndexInput input) throws IOException {
        final long footerStart = input.length() - CodecUtil.footerLength();
        final long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES;
        input.seek(markerAndIndexPosition);
        final long indexStartPosition = input.readLong();
        // Consume the marker that follows the index position; its value is not checked here
        input.readInt();
        input.seek(indexStartPosition);
        return input.readInt();
    }

    /**
     * Builds the name of the quantization state file for a segment.
     */
    private static String getQuantizationStateFileName(String segmentName, String segmentSuffix) {
        return IndexFileNames.segmentFileName(segmentName, segmentSuffix, KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.AllArgsConstructor;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.IndexOutput;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

/**
* Writes quantization states to a per-segment quantization state file
*/
public class KNNQuantizationStateWriter {

    // Codec name and version recorded in the index header of the quantization state file
    private static final String CODEC_NAME = "QuantizationCodec";
    private static final int CODEC_VERSION = 0;

    private final IndexOutput output;
    // One entry per state written so far; turned into the index section by writeFooter()
    private final List<FieldQuantizationState> fieldQuantizationStates = new ArrayList<>();

    /**
     * Constructor. Creates the output for the segment's quantization state file.
     *
     * @param segmentWriteState segment write state containing segment information
     * @throws IOException exception could be thrown while creating the output
     */
    public KNNQuantizationStateWriter(SegmentWriteState segmentWriteState) throws IOException {
        final String quantizationStateFileName = IndexFileNames.segmentFileName(
            segmentWriteState.segmentInfo.name,
            segmentWriteState.segmentSuffix,
            KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX
        );

        output = segmentWriteState.directory.createOutput(quantizationStateFileName, segmentWriteState.context);
    }

    /**
     * Writes an index header
     *
     * @param segmentWriteState state containing segment information
     * @throws IOException exception could be thrown while writing header
     */
    public void writeHeader(SegmentWriteState segmentWriteState) throws IOException {
        CodecUtil.writeIndexHeader(
            output,
            CODEC_NAME,
            CODEC_VERSION,
            segmentWriteState.segmentInfo.getId(),
            segmentWriteState.segmentSuffix
        );
    }

    /**
     * Writes a quantization state as bytes and remembers where it was written so that
     * {@link #writeFooter()} can emit an index entry for it.
     *
     * @param fieldName field name
     * @param quantizationState quantization state
     * @throws IOException could be thrown while writing
     */
    public void writeState(String fieldName, QuantizationState quantizationState) throws IOException {
        final byte[] stateBytes = quantizationState.toByteArray();
        final long position = output.getFilePointer();
        output.writeBytes(stateBytes, stateBytes.length);
        // Only the length is needed for the index section, so the serialized bytes are not retained
        fieldQuantizationStates.add(new FieldQuantizationState(fieldName, stateBytes.length, position));
    }

    /**
     * Writes the index section (state count, then name / length / position per state), the position
     * of that section, a -1 marker and the index footer, and closes the output.
     *
     * <p>The output is closed even if one of the writes fails, so this should be the last call made
     * on the writer.
     *
     * @throws IOException could be thrown while writing
     */
    public void writeFooter() throws IOException {
        // try-with-resources closes the output on both the success and the failure path
        try (IndexOutput out = output) {
            final long indexStartPosition = out.getFilePointer();
            out.writeInt(fieldQuantizationStates.size());
            for (FieldQuantizationState fieldQuantizationState : fieldQuantizationStates) {
                out.writeString(fieldQuantizationState.fieldName);
                out.writeInt(fieldQuantizationState.stateBytesLength);
                out.writeVLong(fieldQuantizationState.position);
            }
            out.writeLong(indexStartPosition);
            out.writeInt(-1);
            CodecUtil.writeFooter(out);
        } finally {
            fieldQuantizationStates.clear();
        }
    }

    /**
     * Index-section bookkeeping for one written state.
     */
    @AllArgsConstructor
    private static class FieldQuantizationState {
        final String fieldName;
        final int stateBytesLength;
        final long position;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
Expand Down Expand Up @@ -39,14 +38,20 @@
* A KNNVectorsWriter class for writing the vector data structures and flat vectors for Native Engines.
*/
@Log4j2
@RequiredArgsConstructor
public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter {
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class);
private final SegmentWriteState segmentWriteState;
private final FlatVectorsWriter flatVectorsWriter;
private final KNNQuantizationStateWriter quantizationStateWriter;
private final List<NativeEngineFieldVectorsWriter<?>> fields = new ArrayList<>();
private boolean finished;

public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) throws IOException {
this.segmentWriteState = segmentWriteState;
this.flatVectorsWriter = flatVectorsWriter;
this.quantizationStateWriter = new KNNQuantizationStateWriter(segmentWriteState);
}

/**
* Add new field for indexing.
* In Lucene, we use single file for all the vector fields so here we need to see how we are going to make things
Expand All @@ -70,6 +75,9 @@ public KnnFieldVectorsWriter<?> addField(final FieldInfo fieldInfo) throws IOExc
public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
// simply write data in the flat file
flatVectorsWriter.flush(maxDoc, sortMap);

quantizationStateWriter.writeHeader(segmentWriteState);

for (final NativeEngineFieldVectorsWriter<?> field : fields) {
final VectorDataType vectorDataType = extractVectorDataType(field.getFieldInfo());
final KNNVectorValues<?> knnVectorValues = KNNVectorValuesFactory.getVectorValues(
Expand All @@ -78,8 +86,12 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
field.getVectors()
);

// TODO: Extract quantization state here, uncomment below line once implemented
// quantizationStateWriter.writeState(field.getFieldInfo().getName(), quantizationState);

NativeIndexWriter.getWriter(field.getFieldInfo(), segmentWriteState).flushIndex(knnVectorValues);
}
quantizationStateWriter.writeFooter();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.codec.KNN990Codec;

import lombok.SneakyThrows;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.SegmentInfo;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.search.Sort;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Version;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.opensearch.knn.KNNTestCase;

import java.util.Map;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.times;

public class KNNQuantizationStateReaderTests extends KNNTestCase {

    /**
     * Reading all states of a segment must retrieve the checksum and then, for each of the two
     * stubbed fields, read a name, a length and a position, seek once and bulk-read once.
     */
    @SneakyThrows
    public void testReadFromSegmentReadState() {
        final int stubbedFieldCount = 2;

        // Directory that hands out a mocked input for whatever file is requested
        final IndexInput stubbedInput = Mockito.mock(IndexInput.class);
        final Directory stubbedDirectory = Mockito.mock(Directory.class);
        Mockito.when(stubbedDirectory.openInput(any(), any())).thenReturn(stubbedInput);

        final SegmentReadState readState = new SegmentReadState(
            stubbedDirectory,
            createSegmentInfo("test-segment-name"),
            Mockito.mock(FieldInfos.class),
            Mockito.mock(IOContext.class),
            "test-segment-suffix"
        );

        // Partial mock: the field count is stubbed, the read itself runs the real implementation
        final KNNQuantizationStateReader reader = Mockito.mock(KNNQuantizationStateReader.class);
        Mockito.when(reader.getNumFields(stubbedInput)).thenReturn(stubbedFieldCount);
        Mockito.when(reader.read(readState)).thenCallRealMethod();

        try (MockedStatic<CodecUtil> codecUtil = mockStatic(CodecUtil.class)) {
            reader.read(readState);

            codecUtil.verify(() -> CodecUtil.retrieveChecksum(stubbedInput));
            Mockito.verify(stubbedInput, times(stubbedFieldCount)).readInt();
            Mockito.verify(stubbedInput, times(stubbedFieldCount)).readString();
            Mockito.verify(stubbedInput, times(stubbedFieldCount)).readVLong();
            Mockito.verify(stubbedInput, times(stubbedFieldCount)).readBytes(any(byte[].class), anyInt(), anyInt());
            Mockito.verify(stubbedInput, times(stubbedFieldCount)).seek(anyLong());
        }
    }

    /**
     * Locating the field count must query the length once, seek twice, read one long and read two ints.
     */
    @SneakyThrows
    public void testGetNumFields() {
        final IndexInput stubbedInput = Mockito.mock(IndexInput.class);

        new KNNQuantizationStateReader().getNumFields(stubbedInput);

        Mockito.verify(stubbedInput, times(1)).length();
        Mockito.verify(stubbedInput, times(2)).seek(anyLong());
        Mockito.verify(stubbedInput, times(1)).readLong();
        Mockito.verify(stubbedInput, times(2)).readInt();
    }

    /**
     * Builds a segment info with the given name; every collaborator is a plain mock.
     */
    private static SegmentInfo createSegmentInfo(final String segmentName) {
        return new SegmentInfo(
            Mockito.mock(Directory.class),
            Mockito.mock(Version.class),
            Mockito.mock(Version.class),
            segmentName,
            0,
            false,
            false,
            Mockito.mock(Codec.class),
            Mockito.mock(Map.class),
            new byte[16],
            Mockito.mock(Map.class),
            Mockito.mock(Sort.class)
        );
    }
}
Loading