Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions java/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,14 @@ do `./build.sh java` in the top level directory or just do `./build.sh` in this

Run `./build.sh --run-java-tests` from this directory.

To run a single test:
To run a single test suite:
```sh
cd cuvs-java/
mvn verify -Dintegration-test=com.nvidia.cuvs.CagraBuildAndSearchIT
mvn clean integration-test -Dit.test=com.nvidia.cuvs.CagraBuildAndSearchIT
```
or, for a single test:
```sh
mvn clean integration-test -Dit.test=com.nvidia.cuvs.CagraBuildAndSearchIT#testMergeStrategies
Comment on lines +31 to +35
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

```
Be sure to set (manually, if needed) your `LD_LIBRARY_PATH` to include the directory with the appropriate (matching)
version of `libcuvs.so`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,6 @@ default void serializeToHNSW(OutputStream outputStream, Path tempFile) throws Th
*/
void serializeToHNSW(OutputStream outputStream, Path tempFile, int bufferLength) throws Throwable;

/**
* Gets an instance of {@link CagraIndexParams}
*
* @return an instance of {@link CagraIndexParams}
*/
CagraIndexParams getCagraIndexParameters();
Comment thread
ldematte marked this conversation as resolved.

/**
* Gets an instance of {@link CuVSResources}
*
Expand Down
29 changes: 21 additions & 8 deletions java/cuvs-java/src/main/java/com/nvidia/cuvs/Dataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,34 +28,47 @@
public interface Dataset extends AutoCloseable {

/**
* Add a single vector to the dataset.
* Creates a dataset from a on-heap array of vectors
*
* @param vector A float array of as many elements as the dimensions
* @since 25.08
*/
public void addVector(float[] vector);
static Dataset ofArray(float[][] vectors) {
return CuVSProvider.provider().newArrayDataset(vectors);
}

interface Builder {
/**
* Add a single vector to the dataset.
*
* @param vector A float array of as many elements as the dimensions
*/
void addVector(float[] vector);

Dataset build();
}

/**
* Create a new instance of a dataset
* Returns a builder to create a new instance of a dataset
*
* @param size Number of vectors in the dataset
* @param dimensions Size of each vector in the dataset
* @return new instance of {@link Dataset}
*/
static Dataset create(int size, int dimensions) {
return CuVSProvider.provider().newDataset(size, dimensions);
static Dataset.Builder builder(int size, int dimensions) {
return CuVSProvider.provider().newDatasetBuilder(size, dimensions);
}

/**
* Gets the size of the dataset
*
* @return Size of the dataset
*/
public int size();
int size();

/**
* Gets the dimensions of the vectors in this dataset
*
* @return Dimensions of the vectors in the dataset
*/
public int dimensions();
int dimensions();
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,11 @@ default Path nativeLibraryPath() {
/** Creates a new CuVSResources. */
CuVSResources newCuVSResources(Path tempDirectory) throws Throwable;

/** Create a {@link Dataset} instance **/
Dataset newDataset(int size, int dimensions) throws UnsupportedOperationException;
/** Create a {@link Dataset.Builder} instance **/
Dataset.Builder newDatasetBuilder(int size, int dimensions);

/** Create a {@link Dataset} backed by a on-heap array **/
Dataset newArrayDataset(float[][] vectors);

/** Creates a new BruteForceIndex Builder. */
BruteForceIndex.Builder newBruteForceIndexBuilder(CuVSResources cuVSResources)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ public CagraIndex mergeCagraIndexes(CagraIndex[] indexes) throws Throwable {
}

@Override
public Dataset newDataset(int size, int dimensions) throws UnsupportedOperationException {
public Dataset.Builder newDatasetBuilder(int size, int dimensions) {
throw new UnsupportedOperationException();
}

@Override
public Dataset newArrayDataset(float[][] vectors) {
throw new UnsupportedOperationException();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.Dataset;
import com.nvidia.cuvs.SearchResults;
import com.nvidia.cuvs.internal.common.Util;
import com.nvidia.cuvs.internal.panama.cuvsBruteForceIndex;
import com.nvidia.cuvs.internal.panama.cuvsFilter;
import java.io.FileInputStream;
Expand All @@ -74,11 +73,8 @@
*/
public class BruteForceIndexImpl implements BruteForceIndex {

private final float[][] vectors;
private final Dataset dataset;
private final CuVSResourcesImpl resources;
private final IndexReference bruteForceIndexReference;
private final BruteForceIndexParams bruteForceIndexParams;
private boolean destroyed;

/**
Expand All @@ -91,16 +87,14 @@ public class BruteForceIndexImpl implements BruteForceIndex {
* holding the index parameters
*/
private BruteForceIndexImpl(
float[][] vectors,
Dataset dataset,
CuVSResourcesImpl resources,
BruteForceIndexParams bruteForceIndexParams)
throws Throwable {
this.vectors = vectors;
this.dataset = dataset;
this.resources = resources;
this.bruteForceIndexParams = bruteForceIndexParams;
this.bruteForceIndexReference = build();
Dataset dataset, CuVSResourcesImpl resources, BruteForceIndexParams bruteForceIndexParams)
throws Exception {
Objects.requireNonNull(dataset);
try (dataset) {
this.resources = resources;
assert dataset instanceof DatasetImpl;
this.bruteForceIndexReference = build((DatasetImpl) dataset, bruteForceIndexParams);
}
}

/**
Expand All @@ -111,9 +105,6 @@ private BruteForceIndexImpl(
*/
private BruteForceIndexImpl(InputStream inputStream, CuVSResourcesImpl resources)
throws Throwable {
this.bruteForceIndexParams = null;
this.vectors = null;
this.dataset = null;
this.resources = resources;
this.bruteForceIndexReference = deserialize(inputStream);
}
Expand All @@ -137,7 +128,6 @@ public void destroyIndex() throws Throwable {
} finally {
destroyed = true;
}
if (dataset != null) dataset.close();
}

/**
Expand All @@ -147,16 +137,13 @@ public void destroyIndex() throws Throwable {
* @return an instance of {@link IndexReference} that holds the pointer to the
* index
*/
private IndexReference build() throws Throwable {
private IndexReference build(DatasetImpl dataset, BruteForceIndexParams bruteForceIndexParams) {
try (var localArena = Arena.ofConfined()) {
long rows = dataset != null ? dataset.size() : vectors.length;
long cols = dataset != null ? dataset.dimensions() : (rows > 0 ? vectors[0].length : 0);
long rows = dataset.size();
long cols = dataset.dimensions();

Arena arena = resources.getArena();
MemorySegment datasetMemSegment =
dataset != null
? ((DatasetImpl) dataset).seg
: Util.buildMemorySegment(resources.getArena(), vectors);
MemorySegment datasetMemSegment = dataset.asMemorySegment();

long cuvsResources = resources.getMemorySegment().get(cuvsResources_t, 0);
MemorySegment stream = arena.allocate(cudaStream_t);
Expand All @@ -176,7 +163,7 @@ private IndexReference build() throws Throwable {

cudaMemcpy(datasetMemorySegmentP, datasetMemSegment, datasetBytes, INFER_DIRECTION);

long datasetShape[] = {rows, cols};
long[] datasetShape = {rows, cols};
MemorySegment datasetTensor =
prepareTensor(arena, datasetMemorySegmentP, datasetShape, 2, 32, 2, 2, 1);

Expand Down Expand Up @@ -228,9 +215,9 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
BitSet[] prefilters = cuvsQuery.getPrefilters();
if (prefilters != null && prefilters.length > 0) {
BitSet concatenatedFilters = concatenate(prefilters, cuvsQuery.getNumDocs());
long filters[] = concatenatedFilters.toLongArray();
long[] filters = concatenatedFilters.toLongArray();
prefilterDataMemorySegment = buildMemorySegment(arena, filters);
prefilterDataLength = cuvsQuery.getNumDocs() * prefilters.length;
prefilterDataLength = (long) cuvsQuery.getNumDocs() * prefilters.length;
}

MemorySegment querySeg = buildMemorySegment(arena, cuvsQuery.getQueryVectors());
Expand Down Expand Up @@ -267,12 +254,12 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {

cudaMemcpy(queriesDP, querySeg, queriesBytes, INFER_DIRECTION);

long queriesShape[] = {numQueries, vectorDimension};
long[] queriesShape = {numQueries, vectorDimension};
MemorySegment queriesTensor = prepareTensor(arena, queriesDP, queriesShape, 2, 32, 2, 2, 1);
long neighborsShape[] = {numQueries, topk};
long[] neighborsShape = {numQueries, topk};
MemorySegment neighborsTensor =
prepareTensor(arena, neighborsDP, neighborsShape, 0, 64, 2, 2, 1);
long distancesShape[] = {numQueries, topk};
long[] distancesShape = {numQueries, topk};
MemorySegment distancesTensor =
prepareTensor(arena, distancesDP, distancesShape, 2, 32, 2, 2, 1);

Expand All @@ -283,7 +270,7 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
cuvsFilter.type(prefilter, 0); // NO_FILTER
cuvsFilter.addr(prefilter, 0);
} else {
long prefilterShape[] = {(prefilterDataLength + 31) / 32};
long[] prefilterShape = {(prefilterDataLength + 31) / 32};
prefilterLen = prefilterShape[0];
prefilterBytes = C_INT_BYTE_SIZE * prefilterLen;

Expand Down Expand Up @@ -411,7 +398,6 @@ public static BruteForceIndex.Builder newBuilder(CuVSResources cuvsResources) {
*/
public static class Builder implements BruteForceIndex.Builder {

private float[][] vectors;
private Dataset dataset;
private final CuVSResourcesImpl cuvsResources;
private BruteForceIndexParams bruteForceIndexParams;
Expand Down Expand Up @@ -460,7 +446,7 @@ public Builder from(InputStream inputStream) {
*/
@Override
public Builder withDataset(float[][] vectors) {
this.vectors = vectors;
this.dataset = Dataset.ofArray(vectors);
return this;
}

Expand All @@ -486,7 +472,7 @@ public BruteForceIndexImpl build() throws Throwable {
if (inputStream != null) {
return new BruteForceIndexImpl(inputStream, cuvsResources);
} else {
return new BruteForceIndexImpl(vectors, dataset, cuvsResources, bruteForceIndexParams);
return new BruteForceIndexImpl(dataset, cuvsResources, bruteForceIndexParams);
}
}
}
Expand Down
Loading