Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

/**
* Reads quantization states
Expand All @@ -32,7 +29,7 @@
public final class KNN990QuantizationStateReader {

/**
* Reads all quantization states and returns a map from field name to serialized state bytes
* Reads an individual quantization state for a given field
* File format:
* Header
* QS1 state bytes
Expand All @@ -48,37 +45,6 @@ public final class KNN990QuantizationStateReader {
* -1 (marker)
* Footer
*
* @param state the read state to read from
*/
public static Map<String, byte[]> read(SegmentReadState state) throws IOException {
    final String stateFileName = getQuantizationStateFileName(state);

    try (IndexInput in = state.directory.openInput(stateFileName, IOContext.READ)) {
        CodecUtil.retrieveChecksum(in);

        final int fieldCount = getNumFields(in);
        final Map<String, byte[]> stateBytesByField = new HashMap<>();

        // Each index entry is read as: field number (int), state length (int), state position (vlong).
        int entry = 0;
        while (entry < fieldCount) {
            final int fieldNumber = in.readInt();
            final int stateLength = in.readInt();
            final long statePosition = in.readVLong();
            final byte[] bytes = readStateBytes(in, statePosition, stateLength);
            // Results are keyed by field name, resolved from the field number via the segment's FieldInfos.
            stateBytesByField.put(state.fieldInfos.fieldInfo(fieldNumber).getName(), bytes);
            entry++;
        }
        return stateBytesByField;
    } catch (Exception e) {
        // Best effort: any failure while opening or reading the file is logged and yields an empty map.
        log.warn(String.format("Unable to read the quantization state file for segment %s", state.segmentInfo.name), e);
        return Collections.emptyMap();
    }
}

/**
* Reads an individual quantization state for a given field
* @param readConfig a config class that contains necessary information for reading the state
* @return quantization state
*/
Expand All @@ -88,41 +54,43 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr
String quantizationStateFileName = getQuantizationStateFileName(segmentReadState);
int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber();

IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ);
CodecUtil.retrieveChecksum(input);
int numFields = getNumFields(input);
try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) {

long position = -1;
int length = 0;
CodecUtil.retrieveChecksum(input);
int numFields = getNumFields(input);

// Read each field's metadata from the index section, break when correct field is found
for (int i = 0; i < numFields; i++) {
int tempFieldNumber = input.readInt();
int tempLength = input.readInt();
long tempPosition = input.readVLong();
if (tempFieldNumber == fieldNumber) {
position = tempPosition;
length = tempLength;
break;
}
}
long position = -1;
int length = 0;

if (position == -1 || length == 0) {
throw new IllegalArgumentException(String.format("Field %s not found", field));
}
// Read each field's metadata from the index section, break when correct field is found
for (int i = 0; i < numFields; i++) {
int tempFieldNumber = input.readInt();
int tempLength = input.readInt();
long tempPosition = input.readVLong();
if (tempFieldNumber == fieldNumber) {
position = tempPosition;
length = tempLength;
break;
}
}

byte[] stateBytes = readStateBytes(input, position, length);
if (position == -1 || length == 0) {
throw new IllegalArgumentException(String.format("Field %s not found", field));
}

// Deserialize the byte array to a quantization state object
ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType();
switch (scalarQuantizationType) {
case ONE_BIT:
return OneBitScalarQuantizationState.fromByteArray(stateBytes);
case TWO_BIT:
case FOUR_BIT:
return MultiBitScalarQuantizationState.fromByteArray(stateBytes);
default:
throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType));
byte[] stateBytes = readStateBytes(input, position, length);

// Deserialize the byte array to a quantization state object
ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType();
switch (scalarQuantizationType) {
case ONE_BIT:
return OneBitScalarQuantizationState.fromByteArray(stateBytes);
case TWO_BIT:
case FOUR_BIT:
return MultiBitScalarQuantizationState.fromByteArray(stateBytes);
default:
throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,12 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader {

private final FlatVectorsReader flatVectorsReader;
private final SegmentReadState segmentReadState;
private final QuantizationStateCacheManager quantizationStateCacheManager = QuantizationStateCacheManager.getInstance();
private Map<String, String> quantizationStateCacheKeyPerField;

public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) throws IOException {
this.segmentReadState = state;
this.flatVectorsReader = flatVectorsReader;
primeQuantizationStateCache();
loadCacheKeyMap();
}

/**
Expand Down Expand Up @@ -178,8 +177,10 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
@Override
public void close() throws IOException {
IOUtils.close(flatVectorsReader);
for (String cacheKey : quantizationStateCacheKeyPerField.values()) {
QuantizationStateCacheManager.getInstance().evict(cacheKey);
if (quantizationStateCacheKeyPerField != null) {
for (String cacheKey : quantizationStateCacheKeyPerField.values()) {
QuantizationStateCacheManager.getInstance().evict(cacheKey);
}
}
}

Expand All @@ -191,7 +192,7 @@ public long ramBytesUsed() {
return flatVectorsReader.ramBytesUsed();
}

private void primeQuantizationStateCache() throws IOException {
private void loadCacheKeyMap() throws IOException {
quantizationStateCacheKeyPerField = new HashMap<>();
for (FieldInfo fieldInfo : segmentReadState.fieldInfos) {
String cacheKey = UUIDs.base64UUID();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Version;
import org.junit.Ignore;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.opensearch.knn.KNNTestCase;
Expand All @@ -40,58 +39,6 @@

public class KNN990QuantizationStateReaderTests extends KNNTestCase {

/**
 * Checks that {@code KNN990QuantizationStateReader.read(SegmentReadState)} retrieves the checksum
 * and reads one index entry per field when {@code getNumFields} is stubbed to report two fields.
 */
@SneakyThrows
public void testReadFromSegmentReadState() {
final String segmentName = "test-segment-name";
final String segmentSuffix = "test-segment-suffix";

// Segment metadata assembled mostly from mocks; the segment name is the value set above.
final SegmentInfo segmentInfo = new SegmentInfo(
Mockito.mock(Directory.class),
Mockito.mock(Version.class),
Mockito.mock(Version.class),
segmentName,
0,
false,
false,
Mockito.mock(Codec.class),
Mockito.mock(Map.class),
new byte[16],
Mockito.mock(Map.class),
Mockito.mock(Sort.class)
);

// Any openInput call on the mocked directory hands back the same mocked IndexInput.
Directory directory = Mockito.mock(Directory.class);
IndexInput input = Mockito.mock(IndexInput.class);
Mockito.when(directory.openInput(any(), any())).thenReturn(input);

// Every field number resolves to a single FieldInfo named "test-field".
String fieldName = "test-field";
FieldInfos fieldInfos = Mockito.mock(FieldInfos.class);
FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
Mockito.when(fieldInfo.getName()).thenReturn(fieldName);
Mockito.when(fieldInfos.fieldInfo(anyInt())).thenReturn(fieldInfo);

final SegmentReadState segmentReadState = new SegmentReadState(
directory,
segmentInfo,
fieldInfos,
Mockito.mock(IOContext.class),
segmentSuffix
);

// Static-mock the reader: getNumFields(input) is stubbed to 2, while read(segmentReadState)
// is routed to the real implementation so the method under test actually runs.
try (MockedStatic<KNN990QuantizationStateReader> mockedStaticReader = Mockito.mockStatic(KNN990QuantizationStateReader.class)) {
mockedStaticReader.when(() -> KNN990QuantizationStateReader.getNumFields(input)).thenReturn(2);
mockedStaticReader.when(() -> KNN990QuantizationStateReader.read(segmentReadState)).thenCallRealMethod();
// CodecUtil statics are mocked so the checksum retrieval can be verified below.
try (MockedStatic<CodecUtil> mockedStaticCodecUtil = mockStatic(CodecUtil.class)) {
KNN990QuantizationStateReader.read(segmentReadState);

// With two fields reported, expect four readInt calls and two readVLong calls in total.
mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input));
Mockito.verify(input, times(4)).readInt();
Mockito.verify(input, times(2)).readVLong();
}
}
}

@Ignore
@SneakyThrows
public void testReadFromQuantizationStateReadConfig() {
String fieldName = "test-field";
Expand Down Expand Up @@ -143,11 +90,6 @@ public void testReadFromQuantizationStateReadConfig() {
mockedStaticReader.when(() -> KNN990QuantizationStateReader.readStateBytes(any(IndexInput.class), anyLong(), anyInt()))
.thenReturn(new byte[8]);
try (MockedStatic<CodecUtil> mockedStaticCodecUtil = mockStatic(CodecUtil.class)) {
assertThrows(IllegalArgumentException.class, () -> KNN990QuantizationStateReader.read(quantizationStateReadConfig));

mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input));
Mockito.verify(input, times(4)).readInt();
Mockito.verify(input, times(2)).readVLong();

Mockito.when(input.readInt()).thenReturn(fieldNumber);

Expand All @@ -158,6 +100,7 @@ public void testReadFromQuantizationStateReadConfig() {
.thenReturn(oneBitScalarQuantizationState);
QuantizationState quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig);
assertEquals(oneBitScalarQuantizationState, quantizationState);
mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input));
}

try (MockedStatic<MultiBitScalarQuantizationState> mockedStaticOneBit = mockStatic(MultiBitScalarQuantizationState.class)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import org.apache.lucene.util.Version;
import org.junit.After;
import org.junit.Assert;
import org.junit.Ignore;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
Expand All @@ -69,6 +68,7 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -96,7 +96,6 @@ public void tearDown() throws Exception {
super.tearDown();
}

@Ignore
@SneakyThrows
public void testReaderAndWriter_whenValidInput_thenSuccess() {
final Lucene99FlatVectorsFormat mockedFlatVectorsFormat = Mockito.mock(Lucene99FlatVectorsFormat.class);
Expand Down Expand Up @@ -129,6 +128,17 @@ public void testReaderAndWriter_whenValidInput_thenSuccess() {
FieldInfo fieldInfo = Mockito.mock(FieldInfo.class);
Mockito.when(fieldInfo.getName()).thenReturn(fieldName);
Mockito.when(fieldInfos.fieldInfo(anyInt())).thenReturn(fieldInfo);
Mockito.when(fieldInfos.iterator()).thenReturn(new Iterator<FieldInfo>() {
@Override
public boolean hasNext() {
return false;
}

@Override
public FieldInfo next() {
return null;
}
});

final SegmentReadState mockedSegmentReadState = new SegmentReadState(
directory,
Expand Down