Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT_BYTE_SIZE;
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_LONG;
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_LONG_BYTE_SIZE;
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_POINTER;
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.HOST_TO_DEVICE;
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.INFER_DIRECTION;
import static com.nvidia.cuvs.internal.common.Util.allocateRMMSegment;
import static com.nvidia.cuvs.internal.common.Util.buildMemorySegment;
import static com.nvidia.cuvs.internal.common.Util.checkCuVSError;
import static com.nvidia.cuvs.internal.common.Util.concatenate;
import static com.nvidia.cuvs.internal.common.Util.cudaMemcpy;
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.*;
import static com.nvidia.cuvs.internal.common.Util.prepareTensor;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceBuild;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceDeserialize;
Expand All @@ -35,16 +36,13 @@
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceIndex_t;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceSearch;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceSerialize;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsRMMAlloc;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsRMMFree;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsResources_t;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamGet;
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamSync;
import static com.nvidia.cuvs.internal.panama.headers_h.omp_set_num_threads;
import static com.nvidia.cuvs.internal.panama.headers_h.cudaStream_t;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.foreign.Arena;
Expand Down Expand Up @@ -132,6 +130,15 @@ public void destroyIndex() throws Throwable {
try {
int returnValue = cuvsBruteForceIndexDestroy(bruteForceIndexReference.getMemorySegment());
checkCuVSError(returnValue, "cuvsBruteForceIndexDestroy");

if (bruteForceIndexReference.datasetBytes > 0) {
long cuvsResources = resources.getMemorySegment().get(cuvsResources_t, 0);
returnValue = cuvsRMMFree(cuvsResources, bruteForceIndexReference.datasetPtr, bruteForceIndexReference.datasetBytes);
checkCuVSError(returnValue, "cuvsRMMFree");
}
if (bruteForceIndexReference.tensorDataArena != null) {
bruteForceIndexReference.tensorDataArena.close();
}
} finally {
destroyed = true;
}
Expand All @@ -145,54 +152,47 @@ public void destroyIndex() throws Throwable {
* @return an instance of {@link IndexReference} that holds the pointer to the
* index
*/
private IndexReference build() throws Throwable {
try (var localArena = Arena.ofConfined()) {
private IndexReference build() {

long rows = dataset != null? dataset.size(): vectors.length;
long cols = dataset != null? dataset.dimensions(): (rows > 0 ? vectors[0].length : 0);

Arena arena = resources.getArena();
MemorySegment datasetMemSegment = dataset != null? ((DatasetImpl) dataset).seg:
Util.buildMemorySegment(resources.getArena(), vectors);
Comment thread
ldematte marked this conversation as resolved.
Outdated

long cuvsResources = resources.getMemorySegment().get(cuvsResources_t, 0);
MemorySegment stream = arena.allocate(cudaStream_t);
MemorySegment stream = resources.getArena().allocate(cudaStream_t);
Comment thread
ldematte marked this conversation as resolved.
Outdated
var returnValue = cuvsStreamGet(cuvsResources, stream);
checkCuVSError(returnValue, "cuvsStreamGet");

omp_set_num_threads(bruteForceIndexParams.getNumWriterThreads());

MemorySegment datasetMemorySegment = arena.allocate(C_POINTER);

long datasetBytes = C_FLOAT_BYTE_SIZE * rows * cols;
returnValue = cuvsRMMAlloc(cuvsResources, datasetMemorySegment, datasetBytes);
checkCuVSError(returnValue, "cuvsRMMAlloc");

// IMPORTANT: this should only come AFTER cuvsRMMAlloc call
MemorySegment datasetMemorySegmentP = datasetMemorySegment.get(C_POINTER, 0);
MemorySegment datasetMemorySegmentP = allocateRMMSegment(cuvsResources, datasetBytes);

cudaMemcpy(datasetMemorySegmentP, datasetMemSegment, datasetBytes, INFER_DIRECTION);

long datasetShape[] = { rows, cols };
MemorySegment datasetTensor = prepareTensor(arena, datasetMemorySegmentP, datasetShape, 2, 32, 2, 2, 1);

MemorySegment index = arena.allocate(cuvsBruteForceIndex_t);
long[] datasetShape = { rows, cols };
var tensorDataArena = Arena.ofShared();
MemorySegment datasetTensor = prepareTensor(tensorDataArena, datasetMemorySegmentP, datasetShape, 2, 32, 2, 2, 1);

MemorySegment index = resources.getArena().allocate(cuvsBruteForceIndex_t);
Comment thread
ldematte marked this conversation as resolved.
Outdated
returnValue = cuvsBruteForceIndexCreate(index);
checkCuVSError(returnValue, "cuvsBruteForceIndexCreate");

var indexReference = new IndexReference(datasetMemorySegmentP, datasetBytes, tensorDataArena, index);

returnValue = cuvsStreamSync(cuvsResources);
checkCuVSError(returnValue, "cuvsStreamSync");

returnValue = cuvsBruteForceBuild(cuvsResources, datasetTensor, 0, 0.0f, index);
returnValue = cuvsBruteForceBuild(cuvsResources, datasetTensor, 0, 0.0f, indexReference.getMemorySegment());
checkCuVSError(returnValue, "cuvsBruteForceBuild");

returnValue = cuvsStreamSync(cuvsResources);
checkCuVSError(returnValue, "cuvsStreamSync");

omp_set_num_threads(1);

return new IndexReference(index);
}
return indexReference;
}

/**
Expand All @@ -210,84 +210,66 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
long numQueries = cuvsQuery.getQueryVectors().length;
long numBlocks = cuvsQuery.getTopK() * numQueries;
int vectorDimension = numQueries > 0 ? cuvsQuery.getQueryVectors()[0].length : 0;
Arena arena = resources.getArena();

SequenceLayout neighborsSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_LONG);
SequenceLayout distancesSequenceLayout = MemoryLayout.sequenceLayout(numBlocks, C_FLOAT);
MemorySegment neighborsMemorySegment = arena.allocate(neighborsSequenceLayout);
MemorySegment distancesMemorySegment = arena.allocate(distancesSequenceLayout);
MemorySegment neighborsMemorySegment = localArena.allocate(neighborsSequenceLayout);
MemorySegment distancesMemorySegment = localArena.allocate(distancesSequenceLayout);

// prepare the prefiltering data
long prefilterDataLength = 0;
MemorySegment prefilterDataMemorySegment = MemorySegment.NULL;
BitSet[] prefilters = cuvsQuery.getPrefilters();
if (prefilters != null && prefilters.length > 0) {
BitSet concatenatedFilters = concatenate(prefilters, cuvsQuery.getNumDocs());
long filters[] = concatenatedFilters.toLongArray();
prefilterDataMemorySegment = buildMemorySegment(arena, filters);
prefilterDataLength = cuvsQuery.getNumDocs() * prefilters.length;
long[] filters = concatenatedFilters.toLongArray();
prefilterDataMemorySegment = buildMemorySegment(localArena, filters);
prefilterDataLength = (long)cuvsQuery.getNumDocs() * prefilters.length;
}

MemorySegment querySeg = buildMemorySegment(arena, cuvsQuery.getQueryVectors());
MemorySegment querySeg = buildMemorySegment(localArena, cuvsQuery.getQueryVectors());

int topk = cuvsQuery.getTopK();
long cuvsResources = resources.getMemorySegment().get(cuvsResources_t, 0);
MemorySegment stream = arena.allocate(cudaStream_t);
MemorySegment stream = localArena.allocate(cudaStream_t);
var returnValue = cuvsStreamGet(cuvsResources, stream);
checkCuVSError(returnValue, "cuvsStreamGet");

MemorySegment queriesD = arena.allocate(C_POINTER);
MemorySegment neighborsD = arena.allocate(C_POINTER);
MemorySegment distancesD = arena.allocate(C_POINTER);
MemorySegment prefilterD = arena.allocate(C_POINTER);
MemorySegment prefilterDP = MemorySegment.NULL;
long prefilterLen = 0;

long queriesBytes = C_FLOAT_BYTE_SIZE * numQueries * vectorDimension;
long neighborsBytes = C_LONG_BYTE_SIZE * numQueries * topk;
long distanceBytes = C_FLOAT_BYTE_SIZE * numQueries * topk;
long prefilterBytes = 0; // size assigned later

returnValue = cuvsRMMAlloc(cuvsResources, queriesD, queriesBytes);
checkCuVSError(returnValue, "cuvsRMMAlloc");
returnValue = cuvsRMMAlloc(cuvsResources, neighborsD, neighborsBytes);
checkCuVSError(returnValue, "cuvsRMMAlloc");
returnValue = cuvsRMMAlloc(cuvsResources, distancesD, distanceBytes);
checkCuVSError(returnValue, "cuvsRMMAlloc");

// IMPORTANT: these three should only come AFTER cuvsRMMAlloc calls
MemorySegment queriesDP = queriesD.get(C_POINTER, 0);
MemorySegment neighborsDP = neighborsD.get(C_POINTER, 0);
MemorySegment distancesDP = distancesD.get(C_POINTER, 0);
MemorySegment queriesDP = allocateRMMSegment(cuvsResources, queriesBytes);
MemorySegment neighborsDP = allocateRMMSegment(cuvsResources, neighborsBytes);
MemorySegment distancesDP = allocateRMMSegment(cuvsResources, distanceBytes);
MemorySegment prefilterDP = MemorySegment.NULL;

cudaMemcpy(queriesDP, querySeg, queriesBytes, INFER_DIRECTION);

long queriesShape[] = { numQueries, vectorDimension };
MemorySegment queriesTensor = prepareTensor(arena, queriesDP, queriesShape, 2, 32, 2, 2, 1);
long neighborsShape[] = { numQueries, topk };
MemorySegment neighborsTensor = prepareTensor(arena, neighborsDP, neighborsShape, 0, 64, 2, 2, 1);
long distancesShape[] = { numQueries, topk };
MemorySegment distancesTensor = prepareTensor(arena, distancesDP, distancesShape, 2, 32, 2, 2, 1);
long[] queriesShape = { numQueries, vectorDimension };
MemorySegment queriesTensor = prepareTensor(localArena, queriesDP, queriesShape, 2, 32, 2, 2, 1);
long[] neighborsShape = { numQueries, topk };
MemorySegment neighborsTensor = prepareTensor(localArena, neighborsDP, neighborsShape, 0, 64, 2, 2, 1);
long[] distancesShape = { numQueries, topk };
MemorySegment distancesTensor = prepareTensor(localArena, distancesDP, distancesShape, 2, 32, 2, 2, 1);

MemorySegment prefilter = cuvsFilter.allocate(arena);
MemorySegment prefilter = cuvsFilter.allocate(localArena);
MemorySegment prefilterTensor;

if (prefilterDataMemorySegment == MemorySegment.NULL) {
cuvsFilter.type(prefilter, 0); // NO_FILTER
cuvsFilter.addr(prefilter, 0);
} else {
long prefilterShape[] = { (prefilterDataLength + 31) / 32 };
prefilterLen = prefilterShape[0];
long[] prefilterShape = { (prefilterDataLength + 31) / 32 };
long prefilterLen = prefilterShape[0];
prefilterBytes = C_INT_BYTE_SIZE * prefilterLen;

returnValue = cuvsRMMAlloc(cuvsResources, prefilterD, prefilterBytes);
checkCuVSError(returnValue, "cuvsRMMAlloc");

prefilterDP = prefilterD.get(C_POINTER, 0);
prefilterDP = allocateRMMSegment(cuvsResources, prefilterBytes);

cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, HOST_TO_DEVICE);

prefilterTensor = prepareTensor(arena, prefilterDP, prefilterShape, 1, 32, 1, 2, 1);
prefilterTensor = prepareTensor(localArena, prefilterDP, prefilterShape, 1, 32, 1, 2, 1);

cuvsFilter.type(prefilter, 2);
cuvsFilter.addr(prefilter, prefilterTensor.address());
Expand All @@ -312,8 +294,10 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
checkCuVSError(returnValue, "cuvsRMMFree");
returnValue = cuvsRMMFree(cuvsResources, queriesDP, queriesBytes);
checkCuVSError(returnValue, "cuvsRMMFree");
returnValue = cuvsRMMFree(cuvsResources, prefilterDP, prefilterBytes);
checkCuVSError(returnValue, "cuvsRMMFree");
if (prefilterBytes > 0) {
returnValue = cuvsRMMFree(cuvsResources, prefilterDP, prefilterBytes);
checkCuVSError(returnValue, "cuvsRMMFree");
}

return new BruteForceSearchResults(neighborsSequenceLayout, distancesSequenceLayout, neighborsMemorySegment,
distancesMemorySegment, cuvsQuery.getTopK(), cuvsQuery.getMapping(), numQueries);
Expand All @@ -332,12 +316,14 @@ public void serialize(OutputStream outputStream, Path tempFile) throws Throwable
tempFile = tempFile.toAbsolutePath();

long cuvsRes = resources.getMemorySegment().get(cuvsResources_t, 0);
int returnValue = cuvsBruteForceSerialize(cuvsRes, resources.getArena().allocateFrom(tempFile.toString()),
bruteForceIndexReference.getMemorySegment());
checkCuVSError(returnValue, "cuvsBruteForceSerialize");
try (var localArena = Arena.ofConfined()) {
int returnValue = cuvsBruteForceSerialize(cuvsRes, localArena.allocateFrom(tempFile.toString()),
bruteForceIndexReference.getMemorySegment());
checkCuVSError(returnValue, "cuvsBruteForceSerialize");
}

try (FileInputStream fileInputStream = new FileInputStream(tempFile.toFile())) {
fileInputStream.transferTo(outputStream);
try (var inputStream = Files.newInputStream(tempFile)) {
inputStream.transferTo(outputStream);
} finally {
Files.deleteIfExists(tempFile);
}
Expand All @@ -356,11 +342,11 @@ private IndexReference deserialize(InputStream inputStream) throws Throwable {
tmpIndexFile = tmpIndexFile.toAbsolutePath();
IndexReference indexReference = new IndexReference(resources);

try (var in = inputStream; FileOutputStream fileOutputStream = new FileOutputStream(tmpIndexFile.toFile())) {
in.transferTo(fileOutputStream);
try (inputStream; var outputStream = Files.newOutputStream(tmpIndexFile); var arena = Arena.ofConfined()) {
inputStream.transferTo(outputStream);

long cuvsRes = resources.getMemorySegment().get(cuvsResources_t, 0);
int returnValue = cuvsBruteForceDeserialize(cuvsRes, resources.getArena().allocateFrom(tmpIndexFile.toString()),
int returnValue = cuvsBruteForceDeserialize(cuvsRes, arena.allocateFrom(tmpIndexFile.toString()),
indexReference.getMemorySegment());
checkCuVSError(returnValue, "cuvsBruteForceDeserialize");

Expand Down Expand Up @@ -464,37 +450,39 @@ public BruteForceIndexImpl build() throws Throwable {
}

/**
* Holds the memory reference to a BRUTEFORCE index.
* Holds the memory reference to a BRUTEFORCE index, its associated dataset, and the arena used to allocate
* input data
*/
protected static class IndexReference {
private static class IndexReference {

private final MemorySegment memorySegment;
private final MemorySegment datasetPtr;
private final long datasetBytes;
private final Arena tensorDataArena;
private final MemorySegment indexPtr;

/**
* Constructs CagraIndexReference and allocate the MemorySegment.
*/
protected IndexReference(CuVSResourcesImpl resources) {
memorySegment = cuvsBruteForceIndex.allocate(resources.getArena());
this(cuvsBruteForceIndex.allocate(resources.getArena()));
Comment thread
ldematte marked this conversation as resolved.
Outdated
}

/**
* Constructs BruteForceIndexReference with an instance of MemorySegment passed
* as a parameter.
*
* @param indexMemorySegment the MemorySegment instance to use for containing
* index reference
*/
protected IndexReference(MemorySegment indexMemorySegment) {
this.memorySegment = indexMemorySegment;
private IndexReference(MemorySegment datasetPtr, long datasetBytes, Arena tensorDataArena, MemorySegment indexPtr) {
this.datasetPtr = datasetPtr;
this.datasetBytes = datasetBytes;
this.tensorDataArena = tensorDataArena;
this.indexPtr = indexPtr;
}

/**
* Gets the instance of index MemorySegment.
*
* @return index MemorySegment
*/
protected MemorySegment getMemorySegment() {
return memorySegment;
private IndexReference(MemorySegment indexPtr) {
this.datasetPtr = MemorySegment.NULL;
this.datasetBytes = 0;
this.tensorDataArena = null;
this.indexPtr = indexPtr;
}

private MemorySegment getMemorySegment() {
return indexPtr;
}
}
}
Loading