1616
1717package com .nvidia .cuvs ;
1818
19+ import com .nvidia .cuvs .spi .CuVSProvider ;
1920import org .openjdk .jmh .annotations .Benchmark ;
21+ import org .openjdk .jmh .annotations .BenchmarkMode ;
2022import org .openjdk .jmh .annotations .Fork ;
23+ import org .openjdk .jmh .annotations .Mode ;
24+ import org .openjdk .jmh .annotations .OutputTimeUnit ;
2125import org .openjdk .jmh .annotations .Param ;
2226import org .openjdk .jmh .annotations .Scope ;
2327import org .openjdk .jmh .annotations .Setup ;
24- import org .openjdk .jmh .annotations .TearDown ;
2528import org .openjdk .jmh .annotations .State ;
29+ import org .openjdk .jmh .annotations .TearDown ;
2630import org .openjdk .jmh .infra .Blackhole ;
2731
2832import java .lang .foreign .*;
2933import java .nio .file .Files ;
3034import java .nio .file .Path ;
3135import java .util .Random ;
3236import java .util .UUID ;
37+ import java .util .concurrent .TimeUnit ;
3338
3439@ Fork (value = 1 , warmups = 0 )
3540@ State (Scope .Benchmark )
@@ -43,6 +48,10 @@ public class CagraIndexBenchmarks {
4348
4449 private float [][] arrayDataset ;
4550
51+ private Arena arena ;
52+
53+ private MemorySegment memorySegmentDataset ;
54+
4655 private static final Random random = new Random ();
4756
4857 private static float [][] createSampleData (int size , int dimensions ) {
@@ -55,9 +64,48 @@ private static float[][] createSampleData(int size, int dimensions) {
5564 return array ;
5665 }
5766
67+ private static MemorySegment createSampleDataSegment (Arena arena , float [][] array , int size , int dimensions ) {
68+ final ValueLayout .OfFloat C_FLOAT = (ValueLayout .OfFloat ) Linker .nativeLinker ().canonicalLayouts ().get ("float" );
69+
70+ MemoryLayout dataMemoryLayout = MemoryLayout .sequenceLayout ((long )size * dimensions , C_FLOAT );
71+
72+ var segment = arena .allocate (dataMemoryLayout );
73+ for (int i = 0 ; i < size ; ++i ) {
74+ var vector = array [i ];
75+ MemorySegment .copy (vector , 0 , segment , C_FLOAT , (i * dimensions * C_FLOAT .byteSize ()), dimensions );
76+ }
77+ return segment ;
78+ }
79+
80+ private static Dataset fromMemorySegment (MemorySegment memorySegment , int size , int dimensions ) {
81+ try {
82+ return (Dataset )
83+ CuVSProvider .provider ()
84+ .newNativeDatasetBuilder ()
85+ .invokeExact (memorySegment , size , dimensions );
86+ } catch (Throwable e ) {
87+ if (e instanceof Error err ) {
88+ throw err ;
89+ } else if (e instanceof RuntimeException re ) {
90+ throw re ;
91+ } else {
92+ throw new RuntimeException (e );
93+ }
94+ }
95+ }
96+
5897 @ Setup
5998 public void initialize () {
99+ arena = Arena .ofShared ();
60100 arrayDataset = createSampleData (size , dims );
101+ memorySegmentDataset = createSampleDataSegment (arena , arrayDataset , size , dims );
102+ }
103+
104+ @ TearDown
105+ public void cleanUp () {
106+ if (arena != null ) {
107+ arena .close ();
108+ }
61109 }
62110
63111 @ Benchmark
@@ -111,4 +159,43 @@ public void testIndexingFromHeap(Blackhole blackhole) throws Throwable {
111159 blackhole .consume (index );
112160 }
113161 }
162+
163+ @ Benchmark
164+ public void testIndexingFromMemorySegment (Blackhole blackhole ) throws Throwable {
165+ try (CuVSResources resources = CuVSResources .create ()) {
166+ // Configure index parameters
167+ CagraIndexParams indexParams = new CagraIndexParams .Builder ()
168+ .withCagraGraphBuildAlgo (CagraIndexParams .CagraGraphBuildAlgo .NN_DESCENT )
169+ .withGraphDegree (1 )
170+ .withIntermediateGraphDegree (2 )
171+ .withNumWriterThreads (32 )
172+ .withMetric (CagraIndexParams .CuvsDistanceType .L2Expanded )
173+ .build ();
174+
175+ // Create the index with the dataset
176+ CagraIndex index = CagraIndex .newBuilder (resources )
177+ .withDataset (fromMemorySegment (memorySegmentDataset , size , dims ))
178+ .withIndexParams (indexParams )
179+ .build ();
180+ blackhole .consume (index );
181+ }
182+ }
183+
184+ @ Benchmark
185+ @ OutputTimeUnit (TimeUnit .NANOSECONDS )
186+ @ BenchmarkMode (Mode .AverageTime )
187+ public void testDatasetFromHeap (Blackhole blackhole ) throws Throwable {
188+ try (var dataset = Dataset .ofArray (arrayDataset )) {
189+ blackhole .consume (dataset );
190+ }
191+ }
192+
193+ @ Benchmark
194+ @ OutputTimeUnit (TimeUnit .NANOSECONDS )
195+ @ BenchmarkMode (Mode .AverageTime )
196+ public void testDatasetFromMemorySegment (Blackhole blackhole ) throws Throwable {
197+ try (var dataset = fromMemorySegment (memorySegmentDataset , size , dims )) {
198+ blackhole .consume (dataset );
199+ }
200+ }
114201}
0 commit comments