Skip to content

Commit 50ed710

Browse files
authored
[Java] CuVSMatrix for device memory (#1232)
This PR introduces implementation classes for `CuVSDeviceMatrix` (a `CuVSMatrix` backed by device memory). It reworks the base implementation classes a bit to increase reuse, and adds benchmarks and tests for the new classes. Benchmarks were used to try out different implementations so the best one could be chosen: - Row access is backed by a buffer of pinned memory - Builders for device memory use `cudaMemcpyAsync` with the `critical` linker option to directly access heap-based memory - `cuvsMatrixCopy` is used across the board, as it has the same performance as the various `cudaMemcpy*` functions. There are some places in the codebase that will benefit from refactoring to use `CuVSDeviceMatrix` (or a generic `CuVSMatrix` plus `toHost`/`toTensor`/`fromTensor` functions); replacing these multiple ad-hoc implementations with `CuVSDeviceMatrix` will be addressed in a follow-up PR. Final numbers: ``` Benchmark (dims) (size) Mode Cnt Score Error Units CuVSDeviceMatrixBenchmarks.matrixCopyDeviceToHost 2048 16384 thrpt 5 70.531 ± 0.322 ops/s CuVSDeviceMatrixBenchmarks.matrixDeviceBuilder 2048 16384 thrpt 5 35.493 ± 0.772 ops/s CuVSDeviceMatrixBenchmarks.matrixReadRowsFromDevice 2048 16384 thrpt 5 83.616 ± 0.745 ops/s ``` With a theoretical max for the PCI-E bus of 15.7 GB/s and a data size of 128MB, we get close to 2/3 of the maximum theoretical throughput (see comments for details). Authors: - Lorenzo Dematté (https://github.com/ldematte) - MithunR (https://github.com/mythrocks) Approvers: - MithunR (https://github.com/mythrocks) URL: #1232
1 parent 8a1f2f6 commit 50ed710

24 files changed

Lines changed: 1337 additions & 326 deletions

java/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
*.iml
2+
hs_err*.log
23
target/
34
jextract-22/
45
openjdk-22-jextract*
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.nvidia.cuvs;
18+
19+
import org.openjdk.jmh.annotations.*;
20+
import org.openjdk.jmh.infra.Blackhole;
21+
22+
import java.util.Random;
23+
24+
@Fork(value = 1, warmups = 0)
25+
@State(Scope.Benchmark)
26+
public class CuVSDeviceMatrixBenchmarks {
27+
28+
@Param({"2048"})
29+
private int dims;
30+
31+
@Param({"16384"})
32+
private int size;
33+
34+
private static final Random random = new Random();
35+
36+
private float[][] data;
37+
38+
private CuVSResources resources;
39+
private CuVSDeviceMatrix deviceMatrix;
40+
private CuVSHostMatrix hostMatrix;
41+
42+
private float[][] createRandomData() {
43+
var array = new float[size][dims];
44+
45+
for (int i = 0; i < size; ++i) {
46+
for (int j = 0; j < dims; ++j) {
47+
array[i][j] = random.nextFloat();
48+
}
49+
}
50+
return array;
51+
}
52+
53+
@Setup
54+
public void initialize() throws Throwable {
55+
data = createRandomData();
56+
resources = CuVSResources.create();
57+
58+
var builder0 = CuVSMatrix.deviceBuilder(resources, size, dims, CuVSMatrix.DataType.FLOAT);
59+
60+
for (int i = 0; i < size; ++i) {
61+
var array = data[i];
62+
builder0.addVector(array);
63+
}
64+
65+
deviceMatrix = builder0.build();
66+
hostMatrix = CuVSMatrix.hostBuilder(size, dims, CuVSMatrix.DataType.FLOAT).build();
67+
}
68+
69+
@TearDown
70+
public void cleanUp() {
71+
if (deviceMatrix != null) {
72+
deviceMatrix.close();
73+
}
74+
if (hostMatrix != null) {
75+
hostMatrix.close();
76+
}
77+
if (resources != null) {
78+
resources.close();
79+
}
80+
}
81+
82+
@Benchmark
83+
public void matrixReadRowsFromDevice(Blackhole bh) {
84+
for (int i = 0; i < size; ++i) {
85+
bh.consume(deviceMatrix.getRow(i));
86+
}
87+
}
88+
89+
@Benchmark
90+
public void matrixCopyDeviceToHost() {
91+
deviceMatrix.toHost(hostMatrix);
92+
}
93+
94+
@Benchmark
95+
public void matrixDeviceBuilder() throws Throwable {
96+
try (CuVSResources resources = CuVSResources.create()) {
97+
var builder = CuVSMatrix.deviceBuilder(resources, size, dims, CuVSMatrix.DataType.FLOAT);
98+
99+
for (int i = 0; i < size; ++i) {
100+
var array = data[i];
101+
builder.addVector(array);
102+
}
103+
CuVSDeviceMatrix matrix = builder.build();
104+
matrix.close();
105+
}
106+
}
107+
}

java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ public interface CagraIndex extends AutoCloseable {
5353

5454
/** Returns the CAGRA graph
5555
*
56-
* @return a {@link CuVSMatrix} encapsulating the native int (uint32_t) array used to represent
56+
* @return a {@link CuVSDeviceMatrix} encapsulating the native int (uint32_t) array used to represent
5757
* the cagra graph
5858
*/
59-
CuVSMatrix getGraph();
59+
CuVSDeviceMatrix getGraph();
6060

6161
/**
6262
* A method to persist a CAGRA index using an instance of {@link OutputStream}

java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSDeviceMatrix.java

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,25 @@
1818
/**
1919
* A {@link CuVSMatrix} implementation backed by device (GPU) memory.
2020
*/
21-
public interface CuVSDeviceMatrix extends CuVSMatrix {}
21+
public interface CuVSDeviceMatrix extends CuVSMatrix {
22+
23+
/**
24+
* Fills the provided, pre-allocated host matrix with data from this device matrix.
25+
* The content of the provided host matrix will be overwritten; the 2 matrices must have the
26+
* same element type and dimension.
27+
*
28+
* @param hostMatrix the host-memory-backed matrix to fill.
29+
*/
30+
void toHost(CuVSHostMatrix hostMatrix);
31+
32+
/**
33+
* Returns a new host matrix filled with data from this device matrix.
34+
* The returned host matrix will need to be managed by the caller, which will be
35+
* responsible for calling {@link CuVSMatrix#close()} to free its resources when done.
36+
*/
37+
default CuVSHostMatrix toHost() {
38+
var hostMatrix = CuVSMatrix.hostBuilder(size(), columns(), dataType()).build();
39+
toHost(hostMatrix);
40+
return hostMatrix;
41+
}
42+
}

java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSMatrix.java

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,29 +69,33 @@ static CuVSMatrix ofArray(byte[][] vectors) {
6969
return CuVSProvider.provider().newMatrixFromArray(vectors);
7070
}
7171

72-
interface Builder {
72+
/**
73+
* A builder to construct a new matrix one row at a time
74+
* @param <T> the CuVSMatrix type to build
75+
*/
76+
interface Builder<T extends CuVSMatrix> {
7377
/**
74-
* Add a single vector to the dataset.
78+
* Adds a single vector to the matrix.
7579
*
7680
* @param vector A float array of as many elements as the dimensions
7781
*/
7882
void addVector(float[] vector);
7983

8084
/**
81-
* Add a single vector to the dataset.
85+
* Adds a single vector to the matrix.
8286
*
8387
* @param vector A byte array of as many elements as the dimensions
8488
*/
8589
void addVector(byte[] vector);
8690

8791
/**
88-
* Add a single vector to the dataset.
92+
* Adds a single vector to the matrix.
8993
*
90-
* @param vector A int array of as many elements as the dimensions
94+
* @param vector An int array of as many elements as the dimensions
9195
*/
9296
void addVector(int[] vector);
9397

94-
CuVSMatrix build();
98+
T build();
9599
}
96100

97101
/**
@@ -100,10 +104,24 @@ interface Builder {
100104
* @param size Number of vectors in the dataset
101105
* @param columns Size of each vector in the dataset
102106
* @param dataType The data type of the dataset elements
103-
* @return new instance of {@link CuVSMatrix}
107+
* @return a builder for creating a {@link CuVSHostMatrix}
108+
*/
109+
static Builder<CuVSHostMatrix> hostBuilder(long size, long columns, DataType dataType) {
110+
return CuVSProvider.provider().newHostMatrixBuilder(size, columns, dataType);
111+
}
112+
113+
/**
114+
* Returns a builder to create a new matrix backed by device (GPU) memory
115+
*
116+
* @param resources CuVS resources used to allocate the device memory needed
117+
* @param size Number of vectors in the dataset
118+
* @param columns Size of each vector in the dataset
119+
* @param dataType The data type of the dataset elements
120+
* @return a builder for creating a {@link CuVSDeviceMatrix}
104121
*/
105-
static CuVSMatrix.Builder builder(int size, int columns, DataType dataType) {
106-
return CuVSProvider.provider().newMatrixBuilder(size, columns, dataType);
122+
static Builder<CuVSDeviceMatrix> deviceBuilder(
123+
CuVSResources resources, long size, long columns, DataType dataType) {
124+
return CuVSProvider.provider().newDeviceMatrixBuilder(resources, size, columns, dataType);
107125
}
108126

109127
/**
@@ -121,6 +139,13 @@ static CuVSMatrix.Builder builder(int size, int columns, DataType dataType) {
121139
*/
122140
long columns();
123141

142+
/**
143+
* Gets the element type
144+
*
145+
* @return a {@link DataType} describing the matrix element type
146+
*/
147+
DataType dataType();
148+
124149
/**
125150
* Get a view (0-copy) of the row data, as a list of integers (32 bit)
126151
*

java/cuvs-java/src/main/java/com/nvidia/cuvs/RowView.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
package com.nvidia.cuvs;
1717

1818
/**
19-
* Represent a contiguous list of integers (32-bit) backed by off-heap memory.
19+
* Represents a contiguous list of elements backed by off-heap memory.
2020
*
2121
* @since 25.08
2222
*/

java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSProvider.java

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,7 @@
1515
*/
1616
package com.nvidia.cuvs.spi;
1717

18-
import com.nvidia.cuvs.BruteForceIndex;
19-
import com.nvidia.cuvs.CagraIndex;
20-
import com.nvidia.cuvs.CagraMergeParams;
21-
import com.nvidia.cuvs.CuVSMatrix;
22-
import com.nvidia.cuvs.CuVSResources;
23-
import com.nvidia.cuvs.HnswIndex;
24-
import com.nvidia.cuvs.TieredIndex;
18+
import com.nvidia.cuvs.*;
2519
import java.lang.invoke.MethodHandle;
2620
import java.lang.invoke.MethodType;
2721
import java.nio.file.Path;
@@ -52,23 +46,37 @@ default Path nativeLibraryPath() {
5246
/** Creates a new CuVSResources. */
5347
CuVSResources newCuVSResources(Path tempDirectory) throws Throwable;
5448

55-
/** Create a {@link CuVSMatrix.Builder} instance **/
56-
CuVSMatrix.Builder newMatrixBuilder(int size, int dimensions, CuVSMatrix.DataType dataType);
49+
/** Create a {@link CuVSMatrix.Builder} instance for a host memory matrix **/
50+
CuVSMatrix.Builder<CuVSHostMatrix> newHostMatrixBuilder(
51+
long size, long dimensions, CuVSMatrix.DataType dataType);
52+
53+
/** Create a {@link CuVSMatrix.Builder} instance for a device memory matrix **/
54+
CuVSMatrix.Builder<CuVSDeviceMatrix> newDeviceMatrixBuilder(
55+
CuVSResources cuVSResources, long size, long dimensions, CuVSMatrix.DataType dataType);
56+
57+
/** Create a {@link CuVSMatrix.Builder} instance for a device memory matrix with the given row and column strides **/
58+
CuVSMatrix.Builder<CuVSDeviceMatrix> newDeviceMatrixBuilder(
59+
CuVSResources cuVSResources,
60+
long size,
61+
long dimensions,
62+
int rowStride,
63+
int columnStride,
64+
CuVSMatrix.DataType dataType);
5765

5866
/**
59-
* Returns the factory method used to build a Dataset from native memory.
67+
* Returns the factory method used to build a CuVSMatrix from native memory.
6068
* The factory method will have this signature:
61-
* {@code Dataset createNativeDataset(memorySegment, size, dimensions, dataType)},
69+
* {@code CuVSMatrix createNativeMatrix(memorySegment, size, dimensions, dataType)},
6270
* where {@code memorySegment} is a {@code java.lang.foreign.MemorySegment} containing {@code int size} vectors of
6371
* {@code int dimensions} length of type {@link CuVSMatrix.DataType}.
6472
* <p>
6573
* In order to expose this factory in a way that is compatible with Java 21, the factory method is returned as a
6674
* {@link MethodHandle} with {@link MethodType} equal to
67-
* {@code (Dataset.class, MemorySegment.class, int.class, int.class, Dataset.DataType.class)}.
75+
* {@code (CuVSMatrix.class, MemorySegment.class, int.class, int.class, CuVSMatrix.DataType.class)}.
6876
* The caller will need to invoke the factory via the {@link MethodHandle#invokeExact} method:
69-
* {@code Dataset dataset = (Dataset)newNativeDatasetBuilder().invokeExact(memorySegment, size, dimensions, dataType)}
77+
* {@code var matrix = (CuVSMatrix)newNativeMatrixBuilder().invokeExact(memorySegment, size, dimensions, dataType)}
7078
* </p>
71-
* @return a MethodHandle which can be invoked to build a Dataset from an external {@code MemorySegment}
79+
* @return a MethodHandle which can be invoked to build a CuVSMatrix from an external {@code MemorySegment}
7280
*/
7381
MethodHandle newNativeMatrixBuilder();
7482

java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/UnsupportedProvider.java

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,7 @@
1515
*/
1616
package com.nvidia.cuvs.spi;
1717

18-
import com.nvidia.cuvs.BruteForceIndex;
19-
import com.nvidia.cuvs.CagraIndex;
20-
import com.nvidia.cuvs.CuVSMatrix;
21-
import com.nvidia.cuvs.CuVSResources;
22-
import com.nvidia.cuvs.HnswIndex;
23-
import com.nvidia.cuvs.TieredIndex;
18+
import com.nvidia.cuvs.*;
2419
import java.lang.invoke.MethodHandle;
2520
import java.nio.file.Path;
2621

@@ -55,13 +50,30 @@ public TieredIndex.Builder newTieredIndexBuilder(CuVSResources cuVSResources) {
5550
}
5651

5752
@Override
58-
public CagraIndex mergeCagraIndexes(CagraIndex[] indexes) throws Throwable {
53+
public CagraIndex mergeCagraIndexes(CagraIndex[] indexes) {
5954
throw new UnsupportedOperationException();
6055
}
6156

6257
@Override
63-
public CuVSMatrix.Builder newMatrixBuilder(
64-
int size, int dimensions, CuVSMatrix.DataType dataType) {
58+
public CuVSMatrix.Builder<CuVSHostMatrix> newHostMatrixBuilder(
59+
long size, long dimensions, CuVSMatrix.DataType dataType) {
60+
throw new UnsupportedOperationException();
61+
}
62+
63+
@Override
64+
public CuVSMatrix.Builder<CuVSDeviceMatrix> newDeviceMatrixBuilder(
65+
CuVSResources cuVSResources, long size, long dimensions, CuVSMatrix.DataType dataType) {
66+
throw new UnsupportedOperationException();
67+
}
68+
69+
@Override
70+
public CuVSMatrix.Builder<CuVSDeviceMatrix> newDeviceMatrixBuilder(
71+
CuVSResources cuVSResources,
72+
long size,
73+
long dimensions,
74+
int rowStride,
75+
int columnStride,
76+
CuVSMatrix.DataType dataType) {
6577
throw new UnsupportedOperationException();
6678
}
6779

0 commit comments

Comments
 (0)