From 38f5e8de97e477c07c995f9f31830e6ad3ec0602 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Thu, 5 Sep 2024 05:57:20 -0700 Subject: [PATCH] Fix tests related to quantization state (#2045) Signed-off-by: Ryan Bogan (cherry picked from commit 589a27b8e0b4e7f06cd32388d905d620f576c689) --- .../KNN990QuantizationStateReader.java | 98 +++++++------------ .../NativeEngines990KnnVectorsReader.java | 11 ++- .../KNN990QuantizationStateReaderTests.java | 59 +---------- ...NativeEngines990KnnVectorsFormatTests.java | 14 ++- 4 files changed, 52 insertions(+), 130 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java index cea496c5b1..d9b73d621f 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -21,9 +21,6 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; /** * Reads quantization states @@ -32,7 +29,7 @@ public final class KNN990QuantizationStateReader { /** - * Read quantization states and return list of fieldNames and bytes + * Reads an individual quantization state for a given field * File format: * Header * QS1 state bytes @@ -48,37 +45,6 @@ public final class KNN990QuantizationStateReader { * -1 (marker) * Footer * - * @param state the read state to read from - */ - public static Map read(SegmentReadState state) throws IOException { - String quantizationStateFileName = getQuantizationStateFileName(state); - Map readQuantizationStateInfos = null; - - try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) { - CodecUtil.retrieveChecksum(input); - - int numFields = getNumFields(input); - - readQuantizationStateInfos = new HashMap<>(); - - // Read each field's metadata from the index section and then read bytes - for (int i = 0; i < numFields; i++) { - int fieldNumber = input.readInt(); - int length = input.readInt(); - long position = input.readVLong(); - byte[] stateBytes = readStateBytes(input, position, length); - String fieldName = state.fieldInfos.fieldInfo(fieldNumber).getName(); - readQuantizationStateInfos.put(fieldName, stateBytes); - } - } catch (Exception e) { - log.warn(String.format("Unable to read the quantization state file for segment %s", state.segmentInfo.name), e); - return Collections.emptyMap(); - } - return readQuantizationStateInfos; - } - - /** - * Reads an individual quantization state for a given field * @param readConfig a config class that contains necessary information for reading the state * @return quantization state */ @@ -88,41 +54,43 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr String quantizationStateFileName = getQuantizationStateFileName(segmentReadState); int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber(); - IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ); - CodecUtil.retrieveChecksum(input); - int numFields = getNumFields(input); + try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) { - long position = -1; - int length = 0; + CodecUtil.retrieveChecksum(input); + int numFields = getNumFields(input); - // Read each field's metadata from the index section, break when correct field is found - for (int i = 0; i < numFields; i++) { - int tempFieldNumber = input.readInt(); - int tempLength = input.readInt(); - long tempPosition = input.readVLong(); - if (tempFieldNumber == fieldNumber) { - position = tempPosition; - length = tempLength; - break; - } - } + long position = -1; + int length = 0; - if (position == -1 || length == 0) { - throw new IllegalArgumentException(String.format("Field %s not found", field)); - } + // Read each field's metadata from the index section, break when correct field is found + for (int i = 0; i < numFields; i++) { + int tempFieldNumber = input.readInt(); + int tempLength = input.readInt(); + long tempPosition = input.readVLong(); + if (tempFieldNumber == fieldNumber) { + position = tempPosition; + length = tempLength; + break; + } + } - byte[] stateBytes = readStateBytes(input, position, length); + if (position == -1 || length == 0) { + throw new IllegalArgumentException(String.format("Field %s not found", field)); + } - // Deserialize the byte array to a quantization state object - ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType(); - switch (scalarQuantizationType) { - case ONE_BIT: - return OneBitScalarQuantizationState.fromByteArray(stateBytes); - case TWO_BIT: - case FOUR_BIT: - return MultiBitScalarQuantizationState.fromByteArray(stateBytes); - default: - throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType)); + byte[] stateBytes = readStateBytes(input, position, length); + + // Deserialize the byte array to a quantization state object + ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType(); + switch (scalarQuantizationType) { + case ONE_BIT: + return OneBitScalarQuantizationState.fromByteArray(stateBytes); + case TWO_BIT: + case FOUR_BIT: + return MultiBitScalarQuantizationState.fromByteArray(stateBytes); + default: + throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType)); + } } } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java index ae077188aa..16631fd97b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsReader.java @@ -41,13 +41,12 @@ public class NativeEngines990KnnVectorsReader extends KnnVectorsReader { private final FlatVectorsReader flatVectorsReader; private final SegmentReadState segmentReadState; - private final QuantizationStateCacheManager quantizationStateCacheManager = QuantizationStateCacheManager.getInstance(); private Map quantizationStateCacheKeyPerField; public NativeEngines990KnnVectorsReader(final SegmentReadState state, final FlatVectorsReader flatVectorsReader) throws IOException { this.segmentReadState = state; this.flatVectorsReader = flatVectorsReader; - primeQuantizationStateCache(); + loadCacheKeyMap(); } /** @@ -178,8 +177,10 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits @Override public void close() throws IOException { IOUtils.close(flatVectorsReader); - for (String cacheKey : quantizationStateCacheKeyPerField.values()) { - QuantizationStateCacheManager.getInstance().evict(cacheKey); + if (quantizationStateCacheKeyPerField != null) { + for (String cacheKey : quantizationStateCacheKeyPerField.values()) { + QuantizationStateCacheManager.getInstance().evict(cacheKey); + } } } @@ -191,7 +192,7 @@ public long ramBytesUsed() { return flatVectorsReader.ramBytesUsed(); } - private void primeQuantizationStateCache() throws IOException { + private void loadCacheKeyMap() throws IOException { quantizationStateCacheKeyPerField = new HashMap<>(); for (FieldInfo fieldInfo : segmentReadState.fieldInfos) { String cacheKey = UUIDs.base64UUID(); diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java index b20bcacc49..da790e947d 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReaderTests.java @@ -18,7 +18,6 @@ import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Version; -import org.junit.Ignore; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.opensearch.knn.KNNTestCase; @@ -40,58 +39,6 @@ public class KNN990QuantizationStateReaderTests extends KNNTestCase { - @SneakyThrows - public void testReadFromSegmentReadState() { - final String segmentName = "test-segment-name"; - final String segmentSuffix = "test-segment-suffix"; - - final SegmentInfo segmentInfo = new SegmentInfo( - Mockito.mock(Directory.class), - Mockito.mock(Version.class), - Mockito.mock(Version.class), - segmentName, - 0, - false, - false, - Mockito.mock(Codec.class), - Mockito.mock(Map.class), - new byte[16], - Mockito.mock(Map.class), - Mockito.mock(Sort.class) - ); - - Directory directory = Mockito.mock(Directory.class); - IndexInput input = Mockito.mock(IndexInput.class); - Mockito.when(directory.openInput(any(), any())).thenReturn(input); - - String fieldName = "test-field"; - FieldInfos fieldInfos = Mockito.mock(FieldInfos.class); - FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); - Mockito.when(fieldInfo.getName()).thenReturn(fieldName); - Mockito.when(fieldInfos.fieldInfo(anyInt())).thenReturn(fieldInfo); - - final SegmentReadState segmentReadState = new SegmentReadState( - directory, - segmentInfo, - fieldInfos, - Mockito.mock(IOContext.class), - segmentSuffix - ); - - try (MockedStatic mockedStaticReader = Mockito.mockStatic(KNN990QuantizationStateReader.class)) { - mockedStaticReader.when(() -> KNN990QuantizationStateReader.getNumFields(input)).thenReturn(2); - mockedStaticReader.when(() -> KNN990QuantizationStateReader.read(segmentReadState)).thenCallRealMethod(); - try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { - KNN990QuantizationStateReader.read(segmentReadState); - - mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); - Mockito.verify(input, times(4)).readInt(); - Mockito.verify(input, times(2)).readVLong(); - } - } - } - - @Ignore @SneakyThrows public void testReadFromQuantizationStateReadConfig() { String fieldName = "test-field"; @@ -143,11 +90,6 @@ public void testReadFromQuantizationStateReadConfig() { mockedStaticReader.when(() -> KNN990QuantizationStateReader.readStateBytes(any(IndexInput.class), anyLong(), anyInt())) .thenReturn(new byte[8]); try (MockedStatic mockedStaticCodecUtil = mockStatic(CodecUtil.class)) { - assertThrows(IllegalArgumentException.class, () -> KNN990QuantizationStateReader.read(quantizationStateReadConfig)); - - mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); - Mockito.verify(input, times(4)).readInt(); - Mockito.verify(input, times(2)).readVLong(); Mockito.when(input.readInt()).thenReturn(fieldNumber); @@ -158,6 +100,7 @@ public void testReadFromQuantizationStateReadConfig() { .thenReturn(oneBitScalarQuantizationState); QuantizationState quantizationState = KNN990QuantizationStateReader.read(quantizationStateReadConfig); assertEquals(oneBitScalarQuantizationState, quantizationState); + mockedStaticCodecUtil.verify(() -> CodecUtil.retrieveChecksum(input)); } try (MockedStatic mockedStaticOneBit = mockStatic(MultiBitScalarQuantizationState.class)) { diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index 2b5c1f3ec6..f1e48c3a16 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -52,7 +52,6 @@ import org.apache.lucene.util.Version; import org.junit.After; import org.junit.Assert; -import org.junit.Ignore; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.mockito.stubbing.Answer; @@ -69,6 +68,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -96,7 +96,6 @@ public void tearDown() throws Exception { super.tearDown(); } - @Ignore @SneakyThrows public void testReaderAndWriter_whenValidInput_thenSuccess() { final Lucene99FlatVectorsFormat mockedFlatVectorsFormat = Mockito.mock(Lucene99FlatVectorsFormat.class); @@ -129,6 +128,17 @@ public void testReaderAndWriter_whenValidInput_thenSuccess() { FieldInfo fieldInfo = Mockito.mock(FieldInfo.class); Mockito.when(fieldInfo.getName()).thenReturn(fieldName); Mockito.when(fieldInfos.fieldInfo(anyInt())).thenReturn(fieldInfo); + Mockito.when(fieldInfos.iterator()).thenReturn(new Iterator() { + @Override + public boolean hasNext() { + return false; + } + + @Override + public FieldInfo next() { + return null; + } + }); final SegmentReadState mockedSegmentReadState = new SegmentReadState( directory,