Skip to content

Commit 0afbed8

Browse files
committed
[Java]Support for merge API for CAGRA index
1 parent 9e5a53e commit 0afbed8

3 files changed

Lines changed: 187 additions & 0 deletions

File tree

java/cuvs-java/src/main/java/com/nvidia/cuvs/CagraIndex.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ public interface CagraIndex {
5252
*/
5353
SearchResults search(CagraQuery query) throws Throwable;
5454

55+
/**
56+
* Merges multiple CAGRA indices into a single CAGRA index.
57+
*
58+
* @param indices Array of CAGRA indices to merge
59+
* @param mergeParams Parameters for the merge operation
60+
* @return A new merged CAGRA index
61+
* @throws Throwable if an error occurs during merging
62+
* @since 25.06
63+
*/
64+
CagraIndex merge(CagraIndex[] indices, CagraMergeParams mergeParams) throws Throwable;
65+
5566
/**
5667
* A method to persist a CAGRA index using an instance of {@link OutputStream}
5768
* for writing index bytes.
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Copyright (c) 2025, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.nvidia.cuvs;
18+
19+
/**
20+
* Parameters for merging CAGRA indices.
21+
*
22+
* @since 25.06
23+
*/
24+
public class CagraMergeParams {
25+
26+
/**
27+
* Strategy for merging CAGRA indices.
28+
*/
29+
public enum MergeStrategy {
30+
/**
31+
* Merge indices physically by combining their data structures.
32+
*/
33+
PHYSICAL(0),
34+
35+
/**
36+
* Merge indices logically (if supported).
37+
*/
38+
LOGICAL(1);
39+
40+
private final int value;
41+
42+
MergeStrategy(int value) {
43+
this.value = value;
44+
}
45+
46+
/**
47+
* Get the integer value of the strategy for native code.
48+
*
49+
* @return the integer value
50+
*/
51+
public int getValue() {
52+
return value;
53+
}
54+
}
55+
56+
private final CagraIndexParams outputIndexParams;
57+
private final MergeStrategy strategy;
58+
59+
/**
60+
* Creates a new instance of CagraMergeParams.
61+
*
62+
* @param outputIndexParams parameters for the output merged index
63+
* @param strategy the merge strategy to use
64+
*/
65+
public CagraMergeParams(CagraIndexParams outputIndexParams, MergeStrategy strategy) {
66+
this.outputIndexParams = outputIndexParams;
67+
this.strategy = strategy;
68+
}
69+
70+
/**
71+
* Gets the output index parameters.
72+
*
73+
* @return the output index parameters
74+
*/
75+
public CagraIndexParams getOutputIndexParams() {
76+
return outputIndexParams;
77+
}
78+
79+
/**
80+
* Gets the merge strategy.
81+
*
82+
* @return the merge strategy
83+
*/
84+
public MergeStrategy getStrategy() {
85+
return strategy;
86+
}
87+
}

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import com.nvidia.cuvs.CagraIndex;
4343
import com.nvidia.cuvs.CagraIndexParams;
4444
import com.nvidia.cuvs.CagraIndexParams.CagraGraphBuildAlgo;
45+
import com.nvidia.cuvs.CagraMergeParams;
4546
import com.nvidia.cuvs.CagraQuery;
4647
import com.nvidia.cuvs.CagraSearchParams;
4748
import com.nvidia.cuvs.CuVSResources;
@@ -87,6 +88,15 @@ public class CagraIndexImpl implements CagraIndex {
8788
private static final MethodHandle serializeCAGRAIndexToHNSWMethodHandle = downcallHandle("serialize_cagra_index_to_hnsw",
8889
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, ADDRESS, ADDRESS));
8990

91+
private static final MethodHandle mergeParamsCreateMethodHandle = downcallHandle("cuvsCagraMergeParamsCreate",
92+
FunctionDescriptor.of(C_INT, ADDRESS));
93+
94+
private static final MethodHandle mergeParamsDestroyMethodHandle = downcallHandle("cuvsCagraMergeParamsDestroy",
95+
FunctionDescriptor.of(C_INT, ADDRESS));
96+
97+
private static final MethodHandle mergeMethodHandle = downcallHandle("cuvsCagraMerge",
98+
FunctionDescriptor.of(C_INT, ADDRESS, ADDRESS, C_INT, ADDRESS, ADDRESS));
99+
90100
private final float[][] dataset;
91101
private final CuVSResourcesImpl resources;
92102
private final CagraIndexParams cagraIndexParameters;
@@ -127,6 +137,21 @@ private CagraIndexImpl(InputStream inputStream, CuVSResourcesImpl resources) thr
127137
this.cagraIndexReference = deserialize(inputStream);
128138
}
129139

140+
/**
141+
* Constructor for creating a CagraIndexImpl from an existing index reference
142+
* Used internally by merge operation
143+
*
144+
* @param indexRef The index reference to use
145+
* @param resources The CuVS resources instance
146+
*/
147+
private CagraIndexImpl(IndexReference indexRef, CuVSResourcesImpl resources) {
148+
this.cagraIndexParameters = null;
149+
this.cagraCompressionParams = null;
150+
this.dataset = null;
151+
this.resources = resources;
152+
this.cagraIndexReference = indexRef;
153+
}
154+
130155
private void checkNotDestroyed() {
131156
if (destroyed) {
132157
throw new IllegalStateException("destroyed");
@@ -230,6 +255,70 @@ public SearchResults search(CagraQuery query) throws Throwable {
230255
distancesMemorySegment, topK, query.getMapping(), numQueries);
231256
}
232257

258+
/**
259+
* Merges multiple CAGRA indices into a single CAGRA index.
260+
*
261+
* @param indices Array of CAGRA indices to merge
262+
* @param mergeParams Parameters for the merge operation
263+
* @return A new merged CAGRA index
264+
* @throws Throwable if an error occurs during merging
265+
*/
266+
@Override
267+
public CagraIndex merge(CagraIndex[] indices, CagraMergeParams mergeParams) throws Throwable {
268+
checkNotDestroyed();
269+
Objects.requireNonNull(indices, "indices cannot be null");
270+
Objects.requireNonNull(mergeParams, "mergeParams cannot be null");
271+
272+
if (indices.length == 0) {
273+
throw new IllegalArgumentException("indices array cannot be empty");
274+
}
275+
276+
try (var arena = Arena.ofConfined()) {
277+
// Create merge params
278+
MemorySegment paramsPtr = arena.allocate(ADDRESS);
279+
MemorySegment returnValue = arena.allocate(C_INT);
280+
281+
mergeParamsCreateMethodHandle.invokeExact(paramsPtr);
282+
checkError(returnValue.get(C_INT, 0L), "mergeParamsCreateMethodHandle");
283+
284+
// Get array of index pointers
285+
MemorySegment[] indexPtrs = new MemorySegment[indices.length];
286+
for (int i = 0; i < indices.length; i++) {
287+
if (!(indices[i] instanceof CagraIndexImpl)) {
288+
throw new IllegalArgumentException("All indices must be CagraIndexImpl instances");
289+
}
290+
indexPtrs[i] = ((CagraIndexImpl)indices[i]).cagraIndexReference.getMemorySegment();
291+
}
292+
293+
// Create array segment for index pointers
294+
SequenceLayout indexPtrLayout = MemoryLayout.sequenceLayout(indices.length, ADDRESS);
295+
MemorySegment indexPtrArray = arena.allocate(indexPtrLayout);
296+
for (int i = 0; i < indices.length; i++) {
297+
indexPtrArray.setAtIndex(ADDRESS, i, indexPtrs[i]);
298+
}
299+
300+
// Create output index pointer
301+
MemorySegment outputIndexPtr = arena.allocate(ADDRESS);
302+
303+
// Call merge function
304+
mergeMethodHandle.invokeExact(
305+
resources.getMemorySegment(),
306+
indexPtrArray,
307+
indices.length,
308+
paramsPtr,
309+
outputIndexPtr
310+
);
311+
checkError(returnValue.get(C_INT, 0L), "mergeMethodHandle");
312+
313+
// Clean up merge params
314+
mergeParamsDestroyMethodHandle.invokeExact(paramsPtr);
315+
checkError(returnValue.get(C_INT, 0L), "mergeParamsDestroyMethodHandle");
316+
317+
// Create new CagraIndexImpl with merged index
318+
return new CagraIndexImpl(new IndexReference(outputIndexPtr.get(ADDRESS, 0)), resources);
319+
}
320+
}
321+
233322
@Override
234323
public void serialize(OutputStream outputStream) throws Throwable {
235324
Path p = Files.createTempFile(resources.tempDirectory(), UUID.randomUUID().toString(), ".cag");

0 commit comments

Comments
 (0)