Skip to content

Commit e8dbb88

Browse files
authored
[Java] New off-heap Dataset support for CAGRA and Bruteforce (#902)
As reported in #698, the current `withDataset(float[][] arr)` requires the entire dataset to be copied onto the heap first, before the MemorySegment for it is written out. This commit introduces a new `Dataset` (interface and implementation) with an `addVector(float[] vector)` method for adding vectors to the MemorySegment one by one, without needing to load them all at once. Authors: - Ishan Chattopadhyaya (https://github.com/chatman) - Vivek Narang (https://github.com/narangvivek10) Approvers: - MithunR (https://github.com/mythrocks) - Corey J. Nolet (https://github.com/cjnolet) URL: #902
1 parent 5460c62 commit e8dbb88

15 files changed

Lines changed: 306 additions & 69 deletions

File tree

java/cuvs-java/src/main/java/com/nvidia/cuvs/BruteForceIndex.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,21 @@ interface Builder {
101101
*/
102102
Builder from(InputStream inputStream);
103103

104+
/**
105+
* Sets the dataset vectors for building the {@link BruteForceIndex}.
106+
*
107+
* @param vectors a two-dimensional float array
108+
* @return an instance of this Builder
109+
*/
110+
Builder withDataset(float[][] vectors);
111+
104112
/**
105113
* Sets the dataset for building the {@link BruteForceIndex}.
106114
*
107-
* @param dataset a two-dimensional float array
115+
* @param dataset a {@link Dataset} object containing the vectors
108116
* @return an instance of this Builder
109117
*/
110-
Builder withDataset(float[][] dataset);
118+
Builder withDataset(Dataset dataset);
111119

112120
/**
113121
* Builds and returns an instance of {@link BruteForceIndex}.

java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.nio.file.Path;
2222
import java.util.Objects;
2323

24+
import com.nvidia.cuvs.BruteForceIndex.Builder;
2425
import com.nvidia.cuvs.spi.CuVSProvider;
2526

2627
/**
@@ -183,13 +184,21 @@ interface Builder {
183184
*/
184185
Builder from(InputStream inputStream);
185186

187+
/**
188+
* Sets the dataset vectors for building the {@link CagraIndex}.
189+
*
190+
* @param vectors a two-dimensional float array
191+
* @return an instance of this Builder
192+
*/
193+
Builder withDataset(float[][] vectors);
194+
186195
/**
187196
* Sets the dataset for building the {@link CagraIndex}.
188197
*
189-
* @param dataset a two-dimensional float array
198+
* @param dataset a {@link Dataset} object containing the vectors
190199
* @return an instance of this Builder
191200
*/
192-
Builder withDataset(float[][] dataset);
201+
Builder withDataset(Dataset dataset);
193202

194203
/**
195204
* Registers an instance of configured {@link CagraIndexParams} with this
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.nvidia.cuvs;
18+
19+
import com.nvidia.cuvs.spi.CuVSProvider;
20+
21+
/**
22+
* This represents a wrapper for a dataset to be used for index construction.
23+
* The purpose is to allow a caller to place the vectors into native memory
24+
* directly, instead of requiring the caller to load all the vectors into the heap
25+
* (e.g. with a float[][]).
26+
*
27+
* @since 25.06
28+
*/
29+
public interface Dataset extends AutoCloseable {
30+
31+
/**
32+
* Adds a single vector to the dataset.
33+
*
34+
* @param vector a float array with as many elements as the dataset has dimensions
35+
*/
36+
public void addVector(float[] vector);
37+
38+
/**
39+
* Creates a new dataset by delegating to {@code CuVSProvider.provider().newDataset(size, dimensions)}.
40+
*
41+
* @param size       number of vectors in the dataset
42+
* @param dimensions number of elements in each vector of the dataset
43+
* @return a new instance of {@link Dataset}
44+
*/
45+
static Dataset create(int size, int dimensions) {
46+
return CuVSProvider.provider().newDataset(size, dimensions);
47+
}
48+
49+
/**
50+
* Gets the size of the dataset.
51+
*
52+
* @return the size of the dataset (presumably its number of vectors, as passed to {@code create} - confirm in the implementation)
53+
*/
54+
public int size();
55+
56+
/**
57+
* Gets the dimensions of the vectors in this dataset.
58+
*
59+
* @return the dimensions of the vectors in the dataset
60+
*/
61+
public int dimensions();
62+
}

java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSProvider.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import com.nvidia.cuvs.BruteForceIndex;
2020
import com.nvidia.cuvs.CagraIndex;
2121
import com.nvidia.cuvs.CuVSResources;
22+
import com.nvidia.cuvs.Dataset;
2223
import com.nvidia.cuvs.HnswIndex;
2324

2425
import java.nio.file.Path;
@@ -50,6 +51,9 @@ default Path nativeLibraryPath() {
5051
CuVSResources newCuVSResources(Path tempDirectory)
5152
throws Throwable;
5253

54+
/** Creates a new {@link Dataset} instance. */
55+
Dataset newDataset(int size, int dimensions) throws UnsupportedOperationException;
56+
5357
/** Creates a new BruteForceIndex Builder. */
5458
BruteForceIndex.Builder newBruteForceIndexBuilder(CuVSResources cuVSResources)
5559
throws UnsupportedOperationException;

java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/UnsupportedProvider.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import com.nvidia.cuvs.BruteForceIndex;
2020
import com.nvidia.cuvs.CagraIndex;
2121
import com.nvidia.cuvs.CuVSResources;
22+
import com.nvidia.cuvs.Dataset;
2223
import com.nvidia.cuvs.HnswIndex;
2324

2425
import java.nio.file.Path;
@@ -47,4 +48,9 @@ public CagraIndex.Builder newCagraIndexBuilder(CuVSResources cuVSResources) {
4748
public HnswIndex.Builder newHnswIndexBuilder(CuVSResources cuVSResources) {
4849
throw new UnsupportedOperationException();
4950
}
51+
52+
@Override
53+
public Dataset newDataset(int size, int dimensions) throws UnsupportedOperationException {
54+
throw new UnsupportedOperationException(); // unconditional, same as newHnswIndexBuilder in this class
55+
}
5056
}

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import com.nvidia.cuvs.BruteForceIndexParams;
4444
import com.nvidia.cuvs.BruteForceQuery;
4545
import com.nvidia.cuvs.CuVSResources;
46+
import com.nvidia.cuvs.Dataset;
4647
import com.nvidia.cuvs.SearchResults;
4748
import com.nvidia.cuvs.internal.common.Util;
4849
import com.nvidia.cuvs.internal.panama.cuvsBruteForceIndex;
@@ -71,7 +72,8 @@ public class BruteForceIndexImpl implements BruteForceIndex{
7172
private static final MethodHandle deserializeMethodHandle = downcallHandle("deserialize_brute_force_index",
7273
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, ADDRESS, ADDRESS));
7374

74-
private final float[][] dataset;
75+
private final float[][] vectors;
76+
private final Dataset dataset;
7577
private final CuVSResourcesImpl resources;
7678
private final IndexReference bruteForceIndexReference;
7779
private final BruteForceIndexParams bruteForceIndexParams;
@@ -86,8 +88,10 @@ public class BruteForceIndexImpl implements BruteForceIndex{
8688
* @param bruteForceIndexParams an instance of {@link BruteForceIndexParams}
8789
* holding the index parameters
8890
*/
89-
private BruteForceIndexImpl(float[][] dataset, CuVSResourcesImpl resources, BruteForceIndexParams bruteForceIndexParams)
91+
private BruteForceIndexImpl(float[][] vectors, Dataset dataset, CuVSResourcesImpl resources,
92+
BruteForceIndexParams bruteForceIndexParams)
9093
throws Throwable {
94+
this.vectors = vectors;
9195
this.dataset = dataset;
9296
this.resources = resources;
9397
this.bruteForceIndexParams = bruteForceIndexParams;
@@ -102,6 +106,7 @@ private BruteForceIndexImpl(float[][] dataset, CuVSResourcesImpl resources, Brut
102106
*/
103107
private BruteForceIndexImpl(InputStream inputStream, CuVSResourcesImpl resources) throws Throwable {
104108
this.bruteForceIndexParams = null;
109+
this.vectors = null;
105110
this.dataset = null;
106111
this.resources = resources;
107112
this.bruteForceIndexReference = deserialize(inputStream);
@@ -127,6 +132,7 @@ public void destroyIndex() throws Throwable {
127132
} finally {
128133
destroyed = true;
129134
}
135+
if (dataset != null) dataset.close();
130136
}
131137

132138
/**
@@ -137,10 +143,11 @@ public void destroyIndex() throws Throwable {
137143
* index
138144
*/
139145
private IndexReference build() throws Throwable {
140-
long rows = dataset.length;
141-
long cols = rows > 0 ? dataset[0].length : 0;
146+
long rows = dataset != null? dataset.size(): vectors.length;
147+
long cols = dataset != null? dataset.dimensions(): (rows > 0 ? vectors[0].length : 0);
142148

143-
MemorySegment dataSeg = Util.buildMemorySegment(resources.getArena(), dataset);
149+
MemorySegment dataSeg = dataset != null? ((DatasetImpl) dataset).seg:
150+
Util.buildMemorySegment(resources.getArena(), vectors);
144151
try (var localArena = Arena.ofConfined()) {
145152
MemorySegment returnValue = localArena.allocate(C_INT);
146153
MemorySegment indexSeg = (MemorySegment) indexMethodHandle.invokeExact(
@@ -284,7 +291,8 @@ public static BruteForceIndex.Builder newBuilder(CuVSResources cuvsResources) {
284291
*/
285292
public static class Builder implements BruteForceIndex.Builder {
286293

287-
private float[][] dataset;
294+
private float[][] vectors;
295+
private Dataset dataset;
288296
private final CuVSResourcesImpl cuvsResources;
289297
private BruteForceIndexParams bruteForceIndexParams;
290298
private InputStream inputStream;
@@ -327,11 +335,23 @@ public Builder from(InputStream inputStream) {
327335
/**
328336
* Sets the dataset for building the {@link BruteForceIndex}.
329337
*
330-
* @param dataset a two-dimensional float array
338+
* @param vectors a two-dimensional float array
331339
* @return an instance of this Builder
332340
*/
333341
@Override
334-
public Builder withDataset(float[][] dataset) {
342+
public Builder withDataset(float[][] vectors) {
343+
this.vectors = vectors;
344+
return this;
345+
}
346+
347+
/**
348+
* Sets the dataset for building the {@link BruteForceIndex}.
349+
*
350+
* @param dataset a {@link Dataset} object containing the vectors
351+
* @return an instance of this Builder
352+
*/
353+
@Override
354+
public Builder withDataset(Dataset dataset) {
335355
this.dataset = dataset;
336356
return this;
337357
}
@@ -346,7 +366,7 @@ public BruteForceIndexImpl build() throws Throwable {
346366
if (inputStream != null) {
347367
return new BruteForceIndexImpl(inputStream, cuvsResources);
348368
} else {
349-
return new BruteForceIndexImpl(dataset, cuvsResources, bruteForceIndexParams);
369+
return new BruteForceIndexImpl(vectors, dataset, cuvsResources, bruteForceIndexParams);
350370
}
351371
}
352372
}

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import com.nvidia.cuvs.CagraQuery;
4646
import com.nvidia.cuvs.CagraSearchParams;
4747
import com.nvidia.cuvs.CuVSResources;
48+
import com.nvidia.cuvs.Dataset;
4849
import com.nvidia.cuvs.SearchResults;
4950
import com.nvidia.cuvs.internal.common.Util;
5051
import com.nvidia.cuvs.internal.panama.cuvsCagraCompressionParams;
@@ -87,7 +88,8 @@ public class CagraIndexImpl implements CagraIndex {
8788
private static final MethodHandle serializeCAGRAIndexToHNSWMethodHandle = downcallHandle("serialize_cagra_index_to_hnsw",
8889
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, ADDRESS, ADDRESS));
8990

90-
private final float[][] dataset;
91+
private final float[][] vectors;
92+
private final Dataset dataset;
9193
private final CuVSResourcesImpl resources;
9294
private final CagraIndexParams cagraIndexParameters;
9395
private final CagraCompressionParams cagraCompressionParams;
@@ -104,10 +106,11 @@ public class CagraIndexImpl implements CagraIndex {
104106
* @param dataset the dataset for indexing
105107
* @param resources an instance of {@link CuVSResources}
106108
*/
107-
private CagraIndexImpl(CagraIndexParams indexParameters, CagraCompressionParams cagraCompressionParams, float[][] dataset,
108-
CuVSResourcesImpl resources) throws Throwable {
109+
private CagraIndexImpl(CagraIndexParams indexParameters, CagraCompressionParams cagraCompressionParams, float[][] vectors,
110+
Dataset dataset, CuVSResourcesImpl resources) throws Throwable {
109111
this.cagraIndexParameters = indexParameters;
110112
this.cagraCompressionParams = cagraCompressionParams;
113+
this.vectors = vectors;
111114
this.dataset = dataset;
112115
this.resources = resources;
113116
this.cagraIndexReference = build();
@@ -122,6 +125,7 @@ private CagraIndexImpl(CagraIndexParams indexParameters, CagraCompressionParams
122125
private CagraIndexImpl(InputStream inputStream, CuVSResourcesImpl resources) throws Throwable {
123126
this.cagraIndexParameters = null;
124127
this.cagraCompressionParams = null;
128+
this.vectors = null;
125129
this.dataset = null;
126130
this.resources = resources;
127131
this.cagraIndexReference = deserialize(inputStream);
@@ -146,6 +150,7 @@ public void destroyIndex() throws Throwable {
146150
} finally {
147151
destroyed = true;
148152
}
153+
if (dataset != null) dataset.close();
149154
}
150155

151156
/**
@@ -156,8 +161,8 @@ public void destroyIndex() throws Throwable {
156161
* index
157162
*/
158163
private IndexReference build() throws Throwable {
159-
long rows = dataset.length;
160-
long cols = rows > 0 ? dataset[0].length : 0;
164+
long rows = dataset != null? dataset.size(): vectors.length;
165+
long cols = dataset != null? dataset.dimensions(): (rows > 0 ? vectors[0].length : 0);
161166

162167
MemorySegment indexParamsMemorySegment = cagraIndexParameters != null
163168
? segmentFromIndexParams(cagraIndexParameters)
@@ -169,7 +174,8 @@ private IndexReference build() throws Throwable {
169174
? segmentFromCompressionParams(cagraCompressionParams)
170175
: MemorySegment.NULL;
171176

172-
MemorySegment dataSeg = Util.buildMemorySegment(resources.getArena(), dataset);
177+
MemorySegment dataSeg = dataset != null? ((DatasetImpl) dataset).seg:
178+
Util.buildMemorySegment(resources.getArena(), vectors);
173179

174180
try (var localArena = Arena.ofConfined()) {
175181
MemorySegment returnValue = localArena.allocate(C_INT);
@@ -470,7 +476,8 @@ public static CagraIndex.Builder newBuilder(CuVSResources cuvsResources) {
470476
*/
471477
public static class Builder implements CagraIndex.Builder{
472478

473-
private float[][] dataset;
479+
private float[][] vectors;
480+
private Dataset dataset;
474481
private CagraIndexParams cagraIndexParams;
475482
private CagraCompressionParams cagraCompressionParams;
476483
private CuVSResourcesImpl cuvsResources;
@@ -487,7 +494,13 @@ public Builder from(InputStream inputStream) {
487494
}
488495

489496
@Override
490-
public Builder withDataset(float[][] dataset) {
497+
public Builder withDataset(float[][] vectors) {
498+
this.vectors = vectors;
499+
return this;
500+
}
501+
502+
@Override
503+
public Builder withDataset(Dataset dataset) {
491504
this.dataset = dataset;
492505
return this;
493506
}
@@ -509,7 +522,10 @@ public CagraIndexImpl build() throws Throwable {
509522
if (inputStream != null) {
510523
return new CagraIndexImpl(inputStream, cuvsResources);
511524
} else {
512-
return new CagraIndexImpl(cagraIndexParams, cagraCompressionParams, dataset, cuvsResources);
525+
if (vectors != null && dataset != null) {
526+
throw new IllegalArgumentException("Please specify only one type of dataset (a float[] or a Dataset instance)");
527+
}
528+
return new CagraIndexImpl(cagraIndexParams, cagraCompressionParams, vectors, dataset, cuvsResources);
513529
}
514530
}
515531
}

0 commit comments

Comments
 (0)