|
42 | 42 | import com.nvidia.cuvs.CagraIndex; |
43 | 43 | import com.nvidia.cuvs.CagraIndexParams; |
44 | 44 | import com.nvidia.cuvs.CagraIndexParams.CagraGraphBuildAlgo; |
| 45 | +import com.nvidia.cuvs.CagraMergeParams; |
45 | 46 | import com.nvidia.cuvs.CagraQuery; |
46 | 47 | import com.nvidia.cuvs.CagraSearchParams; |
47 | 48 | import com.nvidia.cuvs.CuVSResources; |
@@ -87,6 +88,15 @@ public class CagraIndexImpl implements CagraIndex { |
87 | 88 | private static final MethodHandle serializeCAGRAIndexToHNSWMethodHandle = downcallHandle("serialize_cagra_index_to_hnsw", |
88 | 89 | FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, ADDRESS, ADDRESS)); |
89 | 90 |
|
| 91 | + private static final MethodHandle mergeParamsCreateMethodHandle = downcallHandle("cuvsCagraMergeParamsCreate", |
| 92 | + FunctionDescriptor.of(C_INT, ADDRESS)); |
| 93 | + |
| 94 | + private static final MethodHandle mergeParamsDestroyMethodHandle = downcallHandle("cuvsCagraMergeParamsDestroy", |
| 95 | + FunctionDescriptor.of(C_INT, ADDRESS)); |
| 96 | + |
| 97 | + private static final MethodHandle mergeMethodHandle = downcallHandle("cuvsCagraMerge", |
| 98 | + FunctionDescriptor.of(C_INT, ADDRESS, ADDRESS, C_INT, ADDRESS, ADDRESS)); |
| 99 | + |
90 | 100 | private final float[][] dataset; |
91 | 101 | private final CuVSResourcesImpl resources; |
92 | 102 | private final CagraIndexParams cagraIndexParameters; |
@@ -127,6 +137,21 @@ private CagraIndexImpl(InputStream inputStream, CuVSResourcesImpl resources) thr |
127 | 137 | this.cagraIndexReference = deserialize(inputStream); |
128 | 138 | } |
129 | 139 |
|
| 140 | + /** |
| 141 | + * Constructor for creating a CagraIndexImpl from an existing index reference |
| 142 | + * Used internally by merge operation |
| 143 | + * |
| 144 | + * @param indexRef The index reference to use |
| 145 | + * @param resources The CuVS resources instance |
| 146 | + */ |
| 147 | + private CagraIndexImpl(IndexReference indexRef, CuVSResourcesImpl resources) { |
| 148 | + this.cagraIndexParameters = null; |
| 149 | + this.cagraCompressionParams = null; |
| 150 | + this.dataset = null; |
| 151 | + this.resources = resources; |
| 152 | + this.cagraIndexReference = indexRef; |
| 153 | + } |
| 154 | + |
130 | 155 | private void checkNotDestroyed() { |
131 | 156 | if (destroyed) { |
132 | 157 | throw new IllegalStateException("destroyed"); |
@@ -230,6 +255,70 @@ public SearchResults search(CagraQuery query) throws Throwable { |
230 | 255 | distancesMemorySegment, topK, query.getMapping(), numQueries); |
231 | 256 | } |
232 | 257 |
|
| 258 | + /** |
| 259 | + * Merges multiple CAGRA indices into a single CAGRA index. |
| 260 | + * |
| 261 | + * @param indices Array of CAGRA indices to merge |
| 262 | + * @param mergeParams Parameters for the merge operation |
| 263 | + * @return A new merged CAGRA index |
| 264 | + * @throws Throwable if an error occurs during merging |
| 265 | + */ |
| 266 | + @Override |
| 267 | + public CagraIndex merge(CagraIndex[] indices, CagraMergeParams mergeParams) throws Throwable { |
| 268 | + checkNotDestroyed(); |
| 269 | + Objects.requireNonNull(indices, "indices cannot be null"); |
| 270 | + Objects.requireNonNull(mergeParams, "mergeParams cannot be null"); |
| 271 | + |
| 272 | + if (indices.length == 0) { |
| 273 | + throw new IllegalArgumentException("indices array cannot be empty"); |
| 274 | + } |
| 275 | + |
| 276 | + try (var arena = Arena.ofConfined()) { |
| 277 | + // Create merge params |
| 278 | + MemorySegment paramsPtr = arena.allocate(ADDRESS); |
| 279 | + MemorySegment returnValue = arena.allocate(C_INT); |
| 280 | + |
| 281 | + mergeParamsCreateMethodHandle.invokeExact(paramsPtr); |
| 282 | + checkError(returnValue.get(C_INT, 0L), "mergeParamsCreateMethodHandle"); |
| 283 | + |
| 284 | + // Get array of index pointers |
| 285 | + MemorySegment[] indexPtrs = new MemorySegment[indices.length]; |
| 286 | + for (int i = 0; i < indices.length; i++) { |
| 287 | + if (!(indices[i] instanceof CagraIndexImpl)) { |
| 288 | + throw new IllegalArgumentException("All indices must be CagraIndexImpl instances"); |
| 289 | + } |
| 290 | + indexPtrs[i] = ((CagraIndexImpl)indices[i]).cagraIndexReference.getMemorySegment(); |
| 291 | + } |
| 292 | + |
| 293 | + // Create array segment for index pointers |
| 294 | + SequenceLayout indexPtrLayout = MemoryLayout.sequenceLayout(indices.length, ADDRESS); |
| 295 | + MemorySegment indexPtrArray = arena.allocate(indexPtrLayout); |
| 296 | + for (int i = 0; i < indices.length; i++) { |
| 297 | + indexPtrArray.setAtIndex(ADDRESS, i, indexPtrs[i]); |
| 298 | + } |
| 299 | + |
| 300 | + // Create output index pointer |
| 301 | + MemorySegment outputIndexPtr = arena.allocate(ADDRESS); |
| 302 | + |
| 303 | + // Call merge function |
| 304 | + mergeMethodHandle.invokeExact( |
| 305 | + resources.getMemorySegment(), |
| 306 | + indexPtrArray, |
| 307 | + indices.length, |
| 308 | + paramsPtr, |
| 309 | + outputIndexPtr |
| 310 | + ); |
| 311 | + checkError(returnValue.get(C_INT, 0L), "mergeMethodHandle"); |
| 312 | + |
| 313 | + // Clean up merge params |
| 314 | + mergeParamsDestroyMethodHandle.invokeExact(paramsPtr); |
| 315 | + checkError(returnValue.get(C_INT, 0L), "mergeParamsDestroyMethodHandle"); |
| 316 | + |
| 317 | + // Create new CagraIndexImpl with merged index |
| 318 | + return new CagraIndexImpl(new IndexReference(outputIndexPtr.get(ADDRESS, 0)), resources); |
| 319 | + } |
| 320 | + } |
| 321 | + |
233 | 322 | @Override |
234 | 323 | public void serialize(OutputStream outputStream) throws Throwable { |
235 | 324 | Path p = Files.createTempFile(resources.tempDirectory(), UUID.randomUUID().toString(), ".cag"); |
|
0 commit comments