diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index f4115ba9874b..804c3a79228e 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -64,6 +64,9 @@ Bug Fixes returns a string which can be parsed back into the original node. (Peter Barna, Adam Schwartz) +* GITHUB#14847: Allow Faiss vector format to index >2GB of vectors per-field per-segment by using MemorySegment APIs + (instead of ByteBuffer) to copy bytes to native memory. (Kaival Parikh) + Changes in Runtime Behavior --------------------- * GITHUB#14187: The query cache is now disabled by default. (Adrien Grand) diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java index c521c4c20108..c7ad78ec6729 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/faiss/LibFaissC.java @@ -17,6 +17,7 @@ package org.apache.lucene.sandbox.codecs.faiss; import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_BYTE; import static java.lang.foreign.ValueLayout.JAVA_FLOAT; import static java.lang.foreign.ValueLayout.JAVA_INT; import static java.lang.foreign.ValueLayout.JAVA_LONG; @@ -32,8 +33,6 @@ import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; import java.nio.ByteOrder; -import java.nio.FloatBuffer; -import java.nio.LongBuffer; import java.util.Arrays; import java.util.Locale; import org.apache.lucene.index.FloatVectorValues; @@ -221,16 +220,22 @@ public static MemorySegment createIndex( // Allocate docs in native memory MemorySegment docs = temp.allocate(JAVA_FLOAT, (long) size * dimension); - FloatBuffer docsBuffer = docs.asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer(); + long docsOffset = 0; + long perDocByteSize = dimension * JAVA_FLOAT.byteSize(); // Allocate ids in native memory MemorySegment ids = temp.allocate(JAVA_LONG, size); - LongBuffer idsBuffer = ids.asByteBuffer().order(ByteOrder.nativeOrder()).asLongBuffer(); + int idsIndex = 0; KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); for (int i = iterator.nextDoc(); i != NO_MORE_DOCS; i = iterator.nextDoc()) { - idsBuffer.put(oldToNewDocId.apply(i)); - docsBuffer.put(floatVectorValues.vectorValue(iterator.index())); + int id = oldToNewDocId.apply(i); + ids.setAtIndex(JAVA_LONG, idsIndex, id); + idsIndex++; + + float[] vector = floatVectorValues.vectorValue(iterator.index()); + MemorySegment.copy(vector, 0, docs, JAVA_FLOAT, docsOffset, vector.length); + docsOffset += perDocByteSize; } // Train index @@ -254,18 +259,12 @@ private static long writeBytes( inputPointer = inputPointer.reinterpret(size); if (size <= BUFFER_SIZE) { // simple case, avoid buffering - byte[] bytes = new byte[(int) size]; - inputPointer.asSlice(0, size).asByteBuffer().order(ByteOrder.nativeOrder()).get(bytes); - output.writeBytes(bytes, bytes.length); + output.writeBytes(inputPointer.toArray(JAVA_BYTE), (int) size); } else { // copy buffered number of bytes repeatedly byte[] bytes = new byte[BUFFER_SIZE]; for (long offset = 0; offset < size; offset += BUFFER_SIZE) { int length = (int) Math.min(size - offset, BUFFER_SIZE); - inputPointer - .asSlice(offset, length) - .asByteBuffer() - .order(ByteOrder.nativeOrder()) - .get(bytes, 0, length); + MemorySegment.copy(inputPointer, JAVA_BYTE, offset, bytes, 0, length); output.writeBytes(bytes, length); } } @@ -282,21 +281,13 @@ private static long readBytes( if (size <= BUFFER_SIZE) { // simple case, avoid buffering byte[] bytes = new byte[(int) size]; input.readBytes(bytes, 0, bytes.length); - outputPointer - .asSlice(0, bytes.length) - .asByteBuffer() - .order(ByteOrder.nativeOrder()) - .put(bytes); + MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, 0, bytes.length); } else { // copy buffered number of bytes repeatedly byte[] bytes = new byte[BUFFER_SIZE]; for (long offset = 0; offset < size; offset += BUFFER_SIZE) { int length = (int) Math.min(size - offset, BUFFER_SIZE); input.readBytes(bytes, 0, length); - outputPointer - .asSlice(offset, length) - .asByteBuffer() - .order(ByteOrder.nativeOrder()) - .put(bytes, 0, length); + MemorySegment.copy(bytes, 0, outputPointer, JAVA_BYTE, offset, length); } } return numItems; @@ -411,8 +402,7 @@ public static void indexSearch( }; // Allocate queries in native memory - MemorySegment queries = temp.allocate(JAVA_FLOAT, query.length); - queries.asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer().put(query); + MemorySegment queries = temp.allocateFrom(JAVA_FLOAT, query); // Faiss knn search int k = knnCollector.k(); @@ -427,10 +417,9 @@ public static void indexSearch( MemorySegment pointer = temp.allocate(ADDRESS); long[] bits = fixedBitSet.getBits(); - MemorySegment nativeBits = temp.allocate(JAVA_LONG, bits.length); - - // Use LITTLE_ENDIAN to convert long[] -> uint8_t* - nativeBits.asByteBuffer().order(ByteOrder.LITTLE_ENDIAN).asLongBuffer().put(bits); + MemorySegment nativeBits = + // Use LITTLE_ENDIAN to convert long[] -> uint8_t* + temp.allocateFrom(JAVA_LONG.withOrder(ByteOrder.LITTLE_ENDIAN), bits); callAndHandleError(ID_SELECTOR_BITMAP_NEW, pointer, fixedBitSet.length(), nativeBits); MemorySegment idSelectorBitmapPointer = @@ -458,13 +447,9 @@ public static void indexSearch( idsPointer); } - // Retrieve scores - float[] distances = new float[k]; - distancesPointer.asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer().get(distances); - - // Retrieve ids - long[] ids = new long[k]; - idsPointer.asByteBuffer().order(ByteOrder.nativeOrder()).asLongBuffer().get(ids); + // Retrieve scores and ids + float[] distances = distancesPointer.toArray(JAVA_FLOAT); + long[] ids = idsPointer.toArray(JAVA_LONG); // Record hits for (int i = 0; i < k; i++) { diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java index 4239e3d0b3b1..4a3fb3661e6d 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/faiss/TestFaissKnnVectorsFormat.java @@ -21,9 +21,16 @@ import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; import java.io.IOException; +import java.util.Collections; +import java.util.List; import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.util.TestUtil; import org.junit.BeforeClass; @@ -108,4 +115,27 @@ public void testEmptyByteVectorData() {} @Override @Ignore // does not support byte vectors public void testMergingWithDifferentByteKnnFields() {} + + @Monster("Uses large amount of heap and RAM") + public void testLargeVectorData() throws IOException { + KnnVectorsFormat format = + new FaissKnnVectorsFormat( + "IDMap,Flat", // no need for special indexing like HNSW + ""); + IndexWriterConfig config = + newIndexWriterConfig().setCodec(TestUtil.alwaysKnnVectorsFormat(format)); + + float[] largeVector = + new float[format.getMaxDimensions("vector")]; // largest vector accepted by the format + int numDocs = + Math.ceilDivExact( + Integer.MAX_VALUE, Float.BYTES * largeVector.length); // find minimum number of docs + + // Check that we can index vectors larger than Integer.MAX_VALUE number of bytes + try (Directory directory = newDirectory(); + IndexWriter writer = new IndexWriter(directory, config)) { + writer.addDocuments( + Collections.nCopies(numDocs, List.of(new KnnFloatVectorField("vector", largeVector)))); + } + } }