Skip to content

Commit f02ee48

Browse files
authored
[Java] Encapsulate on-heap float arrays into Dataset (#1024)
This PR is a follow-up from #902. Still WIP (see self-comments on the changes) but I'd like some early feedback. Authors: - Lorenzo Dematté (https://github.com/ldematte) - MithunR (https://github.com/mythrocks) Approvers: - Chris Hegarty (https://github.com/ChrisHegarty) - MithunR (https://github.com/mythrocks) URL: #1024
1 parent 52ff6af commit f02ee48

12 files changed

Lines changed: 177 additions & 193 deletions

File tree

java/README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,14 @@ do `./build.sh java` in the top level directory or just do `./build.sh` in this
2525

2626
Run `./build.sh --run-java-tests` from this directory.
2727

28-
To run a single test:
28+
To run a single test suite:
2929
```sh
3030
cd cuvs-java/
31-
mvn verify -Dit.test=com.nvidia.cuvs.CagraBuildAndSearchIT
31+
mvn clean integration-test -Dit.test=com.nvidia.cuvs.CagraBuildAndSearchIT
32+
```
33+
or, for a single test:
34+
```sh
35+
mvn clean integration-test -Dit.test=com.nvidia.cuvs.CagraBuildAndSearchIT#testMergeStrategies
3236
```
3337
Be sure to set (manually, if needed) your `LD_LIBRARY_PATH` to include the directory with the appropriate (matching)
3438
version of `libcuvs.so`.

java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,13 +143,6 @@ default void serializeToHNSW(OutputStream outputStream, Path tempFile) throws Th
143143
*/
144144
void serializeToHNSW(OutputStream outputStream, Path tempFile, int bufferLength) throws Throwable;
145145

146-
/**
147-
* Gets an instance of {@link CagraIndexParams}
148-
*
149-
* @return an instance of {@link CagraIndexParams}
150-
*/
151-
CagraIndexParams getCagraIndexParameters();
152-
153146
/**
154147
* Gets an instance of {@link CuVSResources}
155148
*

java/cuvs-java/src/main/java/com/nvidia/cuvs/Dataset.java

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,34 +28,47 @@
2828
public interface Dataset extends AutoCloseable {
2929

3030
/**
31-
* Add a single vector to the dataset.
31+
* Creates a dataset from a on-heap array of vectors
3232
*
33-
* @param vector A float array of as many elements as the dimensions
33+
* @since 25.08
3434
*/
35-
public void addVector(float[] vector);
35+
static Dataset ofArray(float[][] vectors) {
36+
return CuVSProvider.provider().newArrayDataset(vectors);
37+
}
38+
39+
interface Builder {
40+
/**
41+
* Add a single vector to the dataset.
42+
*
43+
* @param vector A float array of as many elements as the dimensions
44+
*/
45+
void addVector(float[] vector);
46+
47+
Dataset build();
48+
}
3649

3750
/**
38-
* Create a new instance of a dataset
51+
* Returns a builder to create a new instance of a dataset
3952
*
4053
* @param size Number of vectors in the dataset
4154
* @param dimensions Size of each vector in the dataset
4255
* @return new instance of {@link Dataset}
4356
*/
44-
static Dataset create(int size, int dimensions) {
45-
return CuVSProvider.provider().newDataset(size, dimensions);
57+
static Dataset.Builder builder(int size, int dimensions) {
58+
return CuVSProvider.provider().newDatasetBuilder(size, dimensions);
4659
}
4760

4861
/**
4962
* Gets the size of the dataset
5063
*
5164
* @return Size of the dataset
5265
*/
53-
public int size();
66+
int size();
5467

5568
/**
5669
* Gets the dimensions of the vectors in this dataset
5770
*
5871
* @return Dimensions of the vectors in the dataset
5972
*/
60-
public int dimensions();
73+
int dimensions();
6174
}

java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSProvider.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,11 @@ default Path nativeLibraryPath() {
4949
/** Creates a new CuVSResources. */
5050
CuVSResources newCuVSResources(Path tempDirectory) throws Throwable;
5151

52-
/** Create a {@link Dataset} instance **/
53-
Dataset newDataset(int size, int dimensions) throws UnsupportedOperationException;
52+
/** Create a {@link Dataset.Builder} instance **/
53+
Dataset.Builder newDatasetBuilder(int size, int dimensions);
54+
55+
/** Create a {@link Dataset} backed by a on-heap array **/
56+
Dataset newArrayDataset(float[][] vectors);
5457

5558
/** Creates a new BruteForceIndex Builder. */
5659
BruteForceIndex.Builder newBruteForceIndexBuilder(CuVSResources cuVSResources)

java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/UnsupportedProvider.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,12 @@ public CagraIndex mergeCagraIndexes(CagraIndex[] indexes) throws Throwable {
5353
}
5454

5555
@Override
56-
public Dataset newDataset(int size, int dimensions) throws UnsupportedOperationException {
56+
public Dataset.Builder newDatasetBuilder(int size, int dimensions) {
57+
throw new UnsupportedOperationException();
58+
}
59+
60+
@Override
61+
public Dataset newArrayDataset(float[][] vectors) {
5762
throw new UnsupportedOperationException();
5863
}
5964
}

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java

Lines changed: 16 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
import com.nvidia.cuvs.CuVSResources;
4848
import com.nvidia.cuvs.Dataset;
4949
import com.nvidia.cuvs.SearchResults;
50-
import com.nvidia.cuvs.internal.common.Util;
5150
import com.nvidia.cuvs.internal.panama.cuvsFilter;
5251
import java.io.InputStream;
5352
import java.io.OutputStream;
@@ -70,11 +69,8 @@
7069
*/
7170
public class BruteForceIndexImpl implements BruteForceIndex {
7271

73-
private final float[][] vectors;
74-
private final Dataset dataset;
7572
private final CuVSResourcesImpl resources;
7673
private final IndexReference bruteForceIndexReference;
77-
private final BruteForceIndexParams bruteForceIndexParams;
7874
private boolean destroyed;
7975

8076
/**
@@ -87,15 +83,14 @@ public class BruteForceIndexImpl implements BruteForceIndex {
8783
* holding the index parameters
8884
*/
8985
private BruteForceIndexImpl(
90-
float[][] vectors,
91-
Dataset dataset,
92-
CuVSResourcesImpl resources,
93-
BruteForceIndexParams bruteForceIndexParams) {
94-
this.vectors = vectors;
95-
this.dataset = dataset;
96-
this.resources = resources;
97-
this.bruteForceIndexParams = bruteForceIndexParams;
98-
this.bruteForceIndexReference = build();
86+
Dataset dataset, CuVSResourcesImpl resources, BruteForceIndexParams bruteForceIndexParams)
87+
throws Exception {
88+
Objects.requireNonNull(dataset);
89+
try (dataset) {
90+
this.resources = resources;
91+
assert dataset instanceof DatasetImpl;
92+
this.bruteForceIndexReference = build((DatasetImpl) dataset, bruteForceIndexParams);
93+
}
9994
}
10095

10196
/**
@@ -106,9 +101,6 @@ private BruteForceIndexImpl(
106101
*/
107102
private BruteForceIndexImpl(InputStream inputStream, CuVSResourcesImpl resources)
108103
throws Throwable {
109-
this.bruteForceIndexParams = null;
110-
this.vectors = null;
111-
this.dataset = null;
112104
this.resources = resources;
113105
this.bruteForceIndexReference = deserialize(inputStream);
114106
}
@@ -124,7 +116,7 @@ private void checkNotDestroyed() {
124116
* BRUTEFORCE index
125117
*/
126118
@Override
127-
public void destroyIndex() throws Throwable {
119+
public void destroyIndex() {
128120
checkNotDestroyed();
129121
try {
130122
int returnValue = cuvsBruteForceIndexDestroy(bruteForceIndexReference.indexPtr);
@@ -141,7 +133,6 @@ public void destroyIndex() throws Throwable {
141133
} finally {
142134
destroyed = true;
143135
}
144-
if (dataset != null) dataset.close();
145136
}
146137

147138
/**
@@ -151,16 +142,13 @@ public void destroyIndex() throws Throwable {
151142
* @return an instance of {@link IndexReference} that holds the pointer to the
152143
* index
153144
*/
154-
private IndexReference build() {
145+
private IndexReference build(DatasetImpl dataset, BruteForceIndexParams bruteForceIndexParams) {
155146
try (var localArena = Arena.ofConfined()) {
156-
long rows = dataset != null ? dataset.size() : vectors.length;
157-
long cols = dataset != null ? dataset.dimensions() : (rows > 0 ? vectors[0].length : 0);
147+
long rows = dataset.size();
148+
long cols = dataset.dimensions();
158149

159150
Arena arena = resources.getArena();
160-
MemorySegment datasetMemSegment =
161-
dataset != null
162-
? ((DatasetImpl) dataset).seg
163-
: Util.buildMemorySegment(resources.getArena(), vectors);
151+
MemorySegment datasetMemSegment = dataset.asMemorySegment();
164152

165153
long cuvsResources = resources.getHandle();
166154

@@ -248,7 +236,7 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
248236
long queriesBytes = C_FLOAT_BYTE_SIZE * numQueries * vectorDimension;
249237
long neighborsBytes = C_LONG_BYTE_SIZE * numQueries * topk;
250238
long distanceBytes = C_FLOAT_BYTE_SIZE * numQueries * topk;
251-
long prefilterBytes = 0;
239+
long prefilterBytes = 0; // size assigned later
252240

253241
returnValue = cuvsRMMAlloc(cuvsResources, queriesD, queriesBytes);
254242
checkCuVSError(returnValue, "cuvsRMMAlloc");
@@ -430,7 +418,6 @@ public static BruteForceIndex.Builder newBuilder(CuVSResources cuvsResources) {
430418
*/
431419
public static class Builder implements BruteForceIndex.Builder {
432420

433-
private float[][] vectors;
434421
private Dataset dataset;
435422
private final CuVSResourcesImpl cuvsResources;
436423
private BruteForceIndexParams bruteForceIndexParams;
@@ -479,7 +466,7 @@ public Builder from(InputStream inputStream) {
479466
*/
480467
@Override
481468
public Builder withDataset(float[][] vectors) {
482-
this.vectors = vectors;
469+
this.dataset = Dataset.ofArray(vectors);
483470
return this;
484471
}
485472

@@ -505,7 +492,7 @@ public BruteForceIndexImpl build() throws Throwable {
505492
if (inputStream != null) {
506493
return new BruteForceIndexImpl(inputStream, cuvsResources);
507494
} else {
508-
return new BruteForceIndexImpl(vectors, dataset, cuvsResources, bruteForceIndexParams);
495+
return new BruteForceIndexImpl(dataset, cuvsResources, bruteForceIndexParams);
509496
}
510497
}
511498
}

0 commit comments

Comments
 (0)