-
Notifications
You must be signed in to change notification settings - Fork 186
Add quantization state reader and writer #1997
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e9fedcb
1f5c030
627ac7e
a81e99d
daa39ed
8abf3cd
e7d5ac8
c804652
f711e39
425b920
2ea5371
89e45de
8cd2ee3
d644c9b
5fe570d
92bb539
5b03f30
9b486d8
0a7d80e
de89987
59be504
366072f
077a1b6
3dbbad9
45b6fbf
0e0eaf3
9f8ce0c
595ec6f
1f67103
80e74e3
60cd2fa
c0b9e71
776d531
8c21304
b304e3c
fe19fc0
d8959d0
ba931da
fdfc301
0e37d2d
eaff8f0
414b2e4
2accfb1
a6c87e3
ff07704
34cd60b
b82fb90
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.knn.index.codec.KNN990Codec; | ||
|
|
||
| import com.google.common.annotations.VisibleForTesting; | ||
| import lombok.extern.log4j.Log4j2; | ||
| import org.apache.lucene.codecs.CodecUtil; | ||
| import org.apache.lucene.index.IndexFileNames; | ||
| import org.apache.lucene.index.SegmentReadState; | ||
| import org.apache.lucene.store.IOContext; | ||
| import org.apache.lucene.store.IndexInput; | ||
| import org.opensearch.knn.common.KNNConstants; | ||
| import org.opensearch.knn.quantization.enums.ScalarQuantizationType; | ||
| import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; | ||
| import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; | ||
| import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; | ||
| import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; | ||
| import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateReadConfig; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.Collections; | ||
| import java.util.HashMap; | ||
| import java.util.Map; | ||
|
|
||
| /** | ||
| * Reads quantization states | ||
| */ | ||
| @Log4j2 | ||
| public final class KNN990QuantizationStateReader { | ||
|
|
||
| /** | ||
| * Read quantization states and return list of fieldNames and bytes | ||
| * File format: | ||
| * Header | ||
| * QS1 state bytes | ||
| * QS2 state bytes | ||
| * Number of quantization states | ||
| * QS1 field number | ||
| * QS1 state bytes length | ||
| * QS1 position of state bytes | ||
| * QS2 field number | ||
| * QS2 state bytes length | ||
| * QS2 position of state bytes | ||
| * Position of index section (where QS1 field name is located) | ||
| * -1 (marker) | ||
| * Footer | ||
| * | ||
| * @param state the read state to read from | ||
| */ | ||
| public static Map<String, byte[]> read(SegmentReadState state) throws IOException { | ||
| String quantizationStateFileName = getQuantizationStateFileName(state); | ||
| Map<String, byte[]> readQuantizationStateInfos = null; | ||
|
|
||
| try (IndexInput input = state.directory.openInput(quantizationStateFileName, IOContext.READ)) { | ||
| CodecUtil.retrieveChecksum(input); | ||
|
|
||
| int numFields = getNumFields(input); | ||
|
|
||
| readQuantizationStateInfos = new HashMap<>(); | ||
|
|
||
| // Read each field's metadata from the index section and then read bytes | ||
| for (int i = 0; i < numFields; i++) { | ||
| int fieldNumber = input.readInt(); | ||
| int length = input.readInt(); | ||
| long position = input.readVLong(); | ||
| byte[] stateBytes = readStateBytes(input, position, length); | ||
| String fieldName = state.fieldInfos.fieldInfo(fieldNumber).getName(); | ||
| readQuantizationStateInfos.put(fieldName, stateBytes); | ||
| } | ||
| } catch (Exception e) { | ||
| log.warn(String.format("Unable to read the quantization state file for segment %s", state.segmentInfo.name), e); | ||
| return Collections.emptyMap(); | ||
| } | ||
| return readQuantizationStateInfos; | ||
| } | ||
|
|
||
| /** | ||
| * Reads an individual quantization state for a given field | ||
| * @param readConfig a config class that contains necessary information for reading the state | ||
| * @return quantization state | ||
| */ | ||
| public static QuantizationState read(QuantizationStateReadConfig readConfig) throws IOException { | ||
| SegmentReadState segmentReadState = readConfig.getSegmentReadState(); | ||
| String field = readConfig.getField(); | ||
| String quantizationStateFileName = getQuantizationStateFileName(segmentReadState); | ||
| int fieldNumber = segmentReadState.fieldInfos.fieldInfo(field).getFieldNumber(); | ||
|
|
||
| try (IndexInput input = segmentReadState.directory.openInput(quantizationStateFileName, IOContext.READ)) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a try without a catch, The same applies here, return null for a specific exception and throw for others
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ideally we can have it , but this piece of code is getting called from KnnWeight only when QuantizationParm is not null.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same for this #1997 (comment)
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. File not found excpetion will never be for this case , as in KNN wight we already have check before calling this function |
||
| CodecUtil.retrieveChecksum(input); | ||
| int numFields = getNumFields(input); | ||
|
|
||
| long position = -1; | ||
| int length = 0; | ||
|
|
||
| // Read each field's metadata from the index section, break when correct field is found | ||
| for (int i = 0; i < numFields; i++) { | ||
| int tempFieldNumber = input.readInt(); | ||
| int tempLength = input.readInt(); | ||
| long tempPosition = input.readVLong(); | ||
| if (tempFieldNumber == fieldNumber) { | ||
| position = tempPosition; | ||
| length = tempLength; | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| if (position == -1 || length == 0) { | ||
| throw new IllegalArgumentException(String.format("Field %s not found", field)); | ||
| } | ||
|
|
||
| byte[] stateBytes = readStateBytes(input, position, length); | ||
|
|
||
| // Deserialize the byte array to a quantization state object | ||
| ScalarQuantizationType scalarQuantizationType = ((ScalarQuantizationParams) readConfig.getQuantizationParams()).getSqType(); | ||
| switch (scalarQuantizationType) { | ||
| case ONE_BIT: | ||
| return OneBitScalarQuantizationState.fromByteArray(stateBytes); | ||
| case TWO_BIT: | ||
| case FOUR_BIT: | ||
| return MultiBitScalarQuantizationState.fromByteArray(stateBytes); | ||
| default: | ||
| throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType)); | ||
| } | ||
| } catch (Exception e) { | ||
| log.warn(String.format("Unable to read the quantization state file for segment %s", segmentReadState.segmentInfo.name), e); | ||
| return null; | ||
| } | ||
| } | ||
|
|
||
| @VisibleForTesting | ||
| static int getNumFields(IndexInput input) throws IOException { | ||
| long footerStart = input.length() - CodecUtil.footerLength(); | ||
| long markerAndIndexPosition = footerStart - Integer.BYTES - Long.BYTES; | ||
| input.seek(markerAndIndexPosition); | ||
| long indexStartPosition = input.readLong(); | ||
| input.seek(indexStartPosition); | ||
| return input.readInt(); | ||
| } | ||
|
|
||
| @VisibleForTesting | ||
| static byte[] readStateBytes(IndexInput input, long position, int length) throws IOException { | ||
| input.seek(position); | ||
| byte[] stateBytes = new byte[length]; | ||
| input.readBytes(stateBytes, 0, length); | ||
| return stateBytes; | ||
| } | ||
|
|
||
| @VisibleForTesting | ||
| static String getQuantizationStateFileName(SegmentReadState state) { | ||
| return IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX); | ||
| } | ||
ryanbogan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| /* | ||
| * Copyright OpenSearch Contributors | ||
| * SPDX-License-Identifier: Apache-2.0 | ||
| */ | ||
|
|
||
| package org.opensearch.knn.index.codec.KNN990Codec; | ||
|
|
||
| import lombok.AllArgsConstructor; | ||
| import lombok.Setter; | ||
| import org.apache.lucene.codecs.CodecUtil; | ||
| import org.apache.lucene.index.IndexFileNames; | ||
| import org.apache.lucene.index.SegmentWriteState; | ||
| import org.apache.lucene.store.IndexOutput; | ||
| import org.opensearch.knn.common.KNNConstants; | ||
| import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.ArrayList; | ||
| import java.util.List; | ||
|
|
||
| /** | ||
| * Writes quantization states to off heap memory | ||
| */ | ||
| public final class KNN990QuantizationStateWriter { | ||
|
|
||
| private final IndexOutput output; | ||
| private List<FieldQuantizationState> fieldQuantizationStates = new ArrayList<>(); | ||
| static final String NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA = "NativeEngines990KnnVectorsFormatQSData"; | ||
|
|
||
| /** | ||
| * Constructor | ||
| * Overall file format for writer: | ||
| * Header | ||
| * QS1 state bytes | ||
| * QS2 state bytes | ||
| * Number of quantization states | ||
| * QS1 field number | ||
| * QS1 state bytes length | ||
| * QS1 position of state bytes | ||
| * QS2 field number | ||
| * QS2 state bytes length | ||
| * QS2 position of state bytes | ||
| * Position of index section (where QS1 field name is located) | ||
| * -1 (marker) | ||
| * Footer | ||
| * @param segmentWriteState segment write state containing segment information | ||
| * @throws IOException exception could be thrown while creating the output | ||
| */ | ||
| public KNN990QuantizationStateWriter(SegmentWriteState segmentWriteState) throws IOException { | ||
| String quantizationStateFileName = IndexFileNames.segmentFileName( | ||
| segmentWriteState.segmentInfo.name, | ||
| segmentWriteState.segmentSuffix, | ||
| KNNConstants.QUANTIZATION_STATE_FILE_SUFFIX | ||
| ); | ||
|
|
||
| output = segmentWriteState.directory.createOutput(quantizationStateFileName, segmentWriteState.context); | ||
| } | ||
|
|
||
| /** | ||
| * Writes an index header | ||
| * @param segmentWriteState state containing segment information | ||
| * @throws IOException exception could be thrown while writing header | ||
| */ | ||
| public void writeHeader(SegmentWriteState segmentWriteState) throws IOException { | ||
| CodecUtil.writeIndexHeader( | ||
| output, | ||
| NATIVE_ENGINES_990_KNN_VECTORS_FORMAT_QS_DATA, | ||
| 0, | ||
| segmentWriteState.segmentInfo.getId(), | ||
| segmentWriteState.segmentSuffix | ||
| ); | ||
| } | ||
|
|
||
| /** | ||
| * Writes a quantization state as bytes | ||
| * | ||
| * @param fieldNumber field number | ||
| * @param quantizationState quantization state | ||
| * @throws IOException could be thrown while writing | ||
| */ | ||
| public void writeState(int fieldNumber, QuantizationState quantizationState) throws IOException { | ||
| byte[] stateBytes = quantizationState.toByteArray(); | ||
| long position = output.getFilePointer(); | ||
| output.writeBytes(stateBytes, stateBytes.length); | ||
| fieldQuantizationStates.add(new FieldQuantizationState(fieldNumber, stateBytes, position)); | ||
| } | ||
|
|
||
| /** | ||
| * Writes index footer and other index information for parsing later | ||
| * @throws IOException could be thrown while writing | ||
| */ | ||
| public void writeFooter() throws IOException { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thinking out loud here, There is a small risk here that writing of footer is dependent on client. To safe guard that the footer is written you can check if its written in close (if this class implement closeable). One of the downsides I see there is that if its used with try with resources the footer will be written even in exception cases and I am not sure if thats considered as a corrupt file. Not blocking |
||
| long indexStartPosition = output.getFilePointer(); | ||
| output.writeInt(fieldQuantizationStates.size()); | ||
| for (FieldQuantizationState fieldQuantizationState : fieldQuantizationStates) { | ||
| output.writeInt(fieldQuantizationState.fieldNumber); | ||
| output.writeInt(fieldQuantizationState.stateBytes.length); | ||
| output.writeVLong(fieldQuantizationState.position); | ||
| } | ||
| output.writeLong(indexStartPosition); | ||
| output.writeInt(-1); | ||
| CodecUtil.writeFooter(output); | ||
| } | ||
|
|
||
| @AllArgsConstructor | ||
| private static class FieldQuantizationState { | ||
| final int fieldNumber; | ||
| final byte[] stateBytes; | ||
| @Setter | ||
| Long position; | ||
| } | ||
|
|
||
| public void closeOutput() throws IOException { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe implement Closeable so you can leverage try with resources wherever needed |
||
| output.close(); | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.