Skip to content

Commit 9e962d1

Browse files
committed
[Java] Add Dataset based on MemorySegment (rapidsai#1034)
This PR adds the ability to define a Dataset directly over a MemorySegment, "wrapping" it instead of allocating a new one. - Depends on rapidsai#1033 and rapidsai#1024 - ~~The new API has an `Object memorySegment` parameter, as we target Java 21 for the API (but 22 for the implementation); it works but it's definitely a hack and we need to sort this out~~ - As discussed, we want to keep targeting Java 21 for the API. This means the API will return a `MethodHandle`, and the Java 22 implementation will use it to return a factory method to build a Dataset from a MemorySegment. - This factory method can then be used as shown in the tests (see the `DatasetHelper` convenience class/method). - Benchmarks show a sizeable speedup -- it is still tiny relative to the "big picture" (index build time), but there is an improvement and above all we avoid a whole new copy of the input data (halving the memory requirements). Fixes rapidsai#698 Authors: - Lorenzo Dematté (https://github.com/ldematte) - MithunR (https://github.com/mythrocks) - Ben Frederickson (https://github.com/benfred) Approvers: - Chris Hegarty (https://github.com/ChrisHegarty) - MithunR (https://github.com/mythrocks) URL: rapidsai#1034
1 parent 36da615 commit 9e962d1

8 files changed

Lines changed: 291 additions & 67 deletions

File tree

java/benchmarks/src/main/java/com/nvidia/cuvs/CagraIndexBenchmarks.java

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,25 @@
1616

1717
package com.nvidia.cuvs;
1818

19+
import com.nvidia.cuvs.spi.CuVSProvider;
1920
import org.openjdk.jmh.annotations.Benchmark;
21+
import org.openjdk.jmh.annotations.BenchmarkMode;
2022
import org.openjdk.jmh.annotations.Fork;
23+
import org.openjdk.jmh.annotations.Mode;
24+
import org.openjdk.jmh.annotations.OutputTimeUnit;
2125
import org.openjdk.jmh.annotations.Param;
2226
import org.openjdk.jmh.annotations.Scope;
2327
import org.openjdk.jmh.annotations.Setup;
24-
import org.openjdk.jmh.annotations.TearDown;
2528
import org.openjdk.jmh.annotations.State;
29+
import org.openjdk.jmh.annotations.TearDown;
2630
import org.openjdk.jmh.infra.Blackhole;
2731

2832
import java.lang.foreign.*;
2933
import java.nio.file.Files;
3034
import java.nio.file.Path;
3135
import java.util.Random;
3236
import java.util.UUID;
37+
import java.util.concurrent.TimeUnit;
3338

3439
@Fork(value = 1, warmups = 0)
3540
@State(Scope.Benchmark)
@@ -43,6 +48,10 @@ public class CagraIndexBenchmarks {
4348

4449
private float[][] arrayDataset;
4550

51+
private Arena arena;
52+
53+
private MemorySegment memorySegmentDataset;
54+
4655
private static final Random random = new Random();
4756

4857
private static float[][] createSampleData(int size, int dimensions) {
@@ -55,9 +64,48 @@ private static float[][] createSampleData(int size, int dimensions) {
5564
return array;
5665
}
5766

67+
private static MemorySegment createSampleDataSegment(Arena arena, float[][] array, int size, int dimensions) {
68+
final ValueLayout.OfFloat C_FLOAT = (ValueLayout.OfFloat) Linker.nativeLinker().canonicalLayouts().get("float");
69+
70+
MemoryLayout dataMemoryLayout = MemoryLayout.sequenceLayout((long)size * dimensions, C_FLOAT);
71+
72+
var segment = arena.allocate(dataMemoryLayout);
73+
for (int i = 0; i < size; ++i) {
74+
var vector = array[i];
75+
MemorySegment.copy(vector, 0, segment, C_FLOAT, (i * dimensions * C_FLOAT.byteSize()), dimensions);
76+
}
77+
return segment;
78+
}
79+
80+
private static Dataset fromMemorySegment(MemorySegment memorySegment, int size, int dimensions) {
81+
try {
82+
return (Dataset)
83+
CuVSProvider.provider()
84+
.newNativeDatasetBuilder()
85+
.invokeExact(memorySegment, size, dimensions);
86+
} catch (Throwable e) {
87+
if (e instanceof Error err) {
88+
throw err;
89+
} else if (e instanceof RuntimeException re) {
90+
throw re;
91+
} else {
92+
throw new RuntimeException(e);
93+
}
94+
}
95+
}
96+
5897
@Setup
5998
public void initialize() {
99+
arena = Arena.ofShared();
60100
arrayDataset = createSampleData(size, dims);
101+
memorySegmentDataset = createSampleDataSegment(arena, arrayDataset, size, dims);
102+
}
103+
104+
@TearDown
105+
public void cleanUp() {
106+
if (arena != null) {
107+
arena.close();
108+
}
61109
}
62110

63111
@Benchmark
@@ -111,4 +159,43 @@ public void testIndexingFromHeap(Blackhole blackhole) throws Throwable {
111159
blackhole.consume(index);
112160
}
113161
}
162+
163+
@Benchmark
164+
public void testIndexingFromMemorySegment(Blackhole blackhole) throws Throwable {
165+
try (CuVSResources resources = CuVSResources.create()) {
166+
// Configure index parameters
167+
CagraIndexParams indexParams = new CagraIndexParams.Builder()
168+
.withCagraGraphBuildAlgo(CagraIndexParams.CagraGraphBuildAlgo.NN_DESCENT)
169+
.withGraphDegree(1)
170+
.withIntermediateGraphDegree(2)
171+
.withNumWriterThreads(32)
172+
.withMetric(CagraIndexParams.CuvsDistanceType.L2Expanded)
173+
.build();
174+
175+
// Create the index with the dataset
176+
CagraIndex index = CagraIndex.newBuilder(resources)
177+
.withDataset(fromMemorySegment(memorySegmentDataset, size, dims))
178+
.withIndexParams(indexParams)
179+
.build();
180+
blackhole.consume(index);
181+
}
182+
}
183+
184+
@Benchmark
185+
@OutputTimeUnit(TimeUnit.NANOSECONDS)
186+
@BenchmarkMode(Mode.AverageTime)
187+
public void testDatasetFromHeap(Blackhole blackhole) throws Throwable {
188+
try (var dataset = Dataset.ofArray(arrayDataset)) {
189+
blackhole.consume(dataset);
190+
}
191+
}
192+
193+
@Benchmark
194+
@OutputTimeUnit(TimeUnit.NANOSECONDS)
195+
@BenchmarkMode(Mode.AverageTime)
196+
public void testDatasetFromMemorySegment(Blackhole blackhole) throws Throwable {
197+
try (var dataset = fromMemorySegment(memorySegmentDataset, size, dims)) {
198+
blackhole.consume(dataset);
199+
}
200+
}
114201
}

java/cuvs-java/src/main/java/com/nvidia/cuvs/Dataset.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
public interface Dataset extends AutoCloseable {
2929

3030
/**
31-
* Creates a dataset from a on-heap array of vectors
31+
* Creates a dataset from an on-heap array of vectors
3232
*
3333
* @since 25.08
3434
*/

java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/CuVSProvider.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import com.nvidia.cuvs.Dataset;
2323
import com.nvidia.cuvs.HnswIndex;
2424
import com.nvidia.cuvs.ScalarQuantizer;
25+
import java.lang.invoke.MethodHandle;
26+
import java.lang.invoke.MethodType;
2527
import java.nio.file.Path;
2628

2729
/**
@@ -53,6 +55,22 @@ default Path nativeLibraryPath() {
5355
/** Create a {@link Dataset.Builder} instance **/
5456
Dataset.Builder newDatasetBuilder(int size, int dimensions);
5557

58+
/**
59+
* Returns the factory method used to build a Dataset from native memory.
60+
* The factory method will have this signature: {@code Dataset createNativeDataset(memorySegment, size, dimensions)},
61+
* where {@code memorySegment} is a {@code java.lang.foreign.MemorySegment} containing {@code int size} vectors of
62+
* {@code int dimensions} length.
63+
* <p>
64+
* In order to expose this factory in a way that is compatible with Java 21, the factory method is returned as a
65+
* {@link MethodHandle} with {@link MethodType} equal to
66+
* {@code (Dataset.class, MemorySegment.class, int.class, int.class)}.
67+
* The caller will need to invoke the factory via the {@link MethodHandle#invokeExact} method:
68+
* {@code Dataset dataset = (Dataset)newNativeDatasetBuilder().invokeExact(memorySegment, size, dimensions)}
69+
* </p>
70+
* @return a MethodHandle which can be invoked to build a Dataset from a {@code MemorySegment}
71+
*/
72+
MethodHandle newNativeDatasetBuilder();
73+
5674
/** Create a {@link Dataset} backed by an on-heap array **/
5775
Dataset newArrayDataset(float[][] vectors);
5876

java/cuvs-java/src/main/java/com/nvidia/cuvs/spi/UnsupportedProvider.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import com.nvidia.cuvs.Dataset;
2222
import com.nvidia.cuvs.HnswIndex;
2323
import com.nvidia.cuvs.ScalarQuantizer;
24+
import java.lang.invoke.MethodHandle;
2425
import java.nio.file.Path;
2526

2627
/**
@@ -58,6 +59,11 @@ public Dataset.Builder newDatasetBuilder(int size, int dimensions) {
5859
throw new UnsupportedOperationException();
5960
}
6061

62+
/**
 * Not supported by this provider: always throws {@link UnsupportedOperationException}, like the
 * neighbouring factory methods in this class.
 */
@Override
63+
public MethodHandle newNativeDatasetBuilder() {
64+
throw new UnsupportedOperationException();
65+
}
66+
6167
@Override
6268
public Dataset newArrayDataset(float[][] vectors) {
6369
throw new UnsupportedOperationException();

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSResourcesImpl.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,8 @@ public class CuVSResourcesImpl implements CuVSResources {
3939
/**
4040
* Constructor that allocates the resources needed for cuVS
4141
*
42-
* @throws Throwable exception thrown when native function is invoked
4342
*/
44-
public CuVSResourcesImpl(Path tempDirectory) throws Throwable {
43+
public CuVSResourcesImpl(Path tempDirectory) {
4544
this.tempDirectory = tempDirectory;
4645
this.arena = Arena.ofShared();
4746
try (var localArena = Arena.ofConfined()) {

java/cuvs-java/src/main/java22/com/nvidia/cuvs/spi/JDKProvider.java

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,32 @@
3535
import java.lang.foreign.Arena;
3636
import java.lang.foreign.MemoryLayout;
3737
import java.lang.foreign.MemorySegment;
38+
import java.lang.invoke.MethodHandle;
39+
import java.lang.invoke.MethodHandles;
40+
import java.lang.invoke.MethodType;
3841
import java.nio.file.Files;
3942
import java.nio.file.Path;
4043
import java.util.Objects;
4144

4245
final class JDKProvider implements CuVSProvider {
4346

47+
// Handle to createNativeDataset(MemorySegment, int, int); resolved once when this class is
// initialized and handed out as-is by newNativeDatasetBuilder().
private static final MethodHandle createNativeDataset$mh = createNativeDatasetBuilder();
48+
49+
static MethodHandle createNativeDatasetBuilder() {
50+
try {
51+
var lookup = MethodHandles.lookup();
52+
var mt = MethodType.methodType(Dataset.class, MemorySegment.class, int.class, int.class);
53+
return lookup.findStatic(JDKProvider.class, "createNativeDataset", mt);
54+
} catch (NoSuchMethodException | IllegalAccessException e) {
55+
throw new RuntimeException(e);
56+
}
57+
}
58+
59+
/**
 * Factory target of the handle built in {@code createNativeDatasetBuilder()}, which looks this
 * method up by its name and its exact {@code (MemorySegment, int, int) -> Dataset} type; renaming
 * it or changing its signature breaks that lookup.
 *
 * <p>NOTE(review): the {@code null} first argument to {@code DatasetImpl} presumably means the
 * dataset does not own the segment's memory; confirm against {@code DatasetImpl}.
 */
private static Dataset createNativeDataset(
60+
MemorySegment memorySegment, int size, int dimensions) {
61+
return new DatasetImpl(null, memorySegment, size, dimensions);
62+
}
63+
4464
@Override
4565
public CuVSResources newCuVSResources(Path tempDirectory) throws Throwable {
4666
Objects.requireNonNull(tempDirectory);
@@ -110,6 +130,11 @@ public Dataset build() {
110130
};
111131
}
112132

133+
/**
 * Returns the handle stored in {@code createNativeDataset$mh}; the same instance is returned on
 * every call.
 */
@Override
134+
public MethodHandle newNativeDatasetBuilder() {
135+
return createNativeDataset$mh;
136+
}
137+
113138
@Override
114139
public Dataset newArrayDataset(float[][] vectors) {
115140
Objects.requireNonNull(vectors);

0 commit comments

Comments
 (0)