diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java index a1645beb01..141f5b0ab4 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java @@ -24,8 +24,9 @@ import static com.nvidia.cuvs.internal.common.LinkerHelper.C_POINTER; import static com.nvidia.cuvs.internal.common.Util.buildMemorySegment; import static com.nvidia.cuvs.internal.common.Util.checkCuVSError; -import static com.nvidia.cuvs.internal.common.Util.checkCudaError; import static com.nvidia.cuvs.internal.common.Util.concatenate; +import static com.nvidia.cuvs.internal.common.Util.cudaMemcpy; +import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.*; import static com.nvidia.cuvs.internal.common.Util.prepareTensor; import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceBuild; import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceDeserialize; @@ -40,7 +41,6 @@ import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamGet; import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamSync; import static com.nvidia.cuvs.internal.panama.headers_h.omp_set_num_threads; -import static com.nvidia.cuvs.internal.panama.headers_h.cudaMemcpy; import static com.nvidia.cuvs.internal.panama.headers_h.cudaStream_t; import java.io.FileInputStream; @@ -170,8 +170,7 @@ private IndexReference build() throws Throwable { // IMPORTANT: this should only come AFTER cuvsRMMAlloc call MemorySegment datasetMemorySegmentP = datasetMemorySegment.get(C_POINTER, 0); - returnValue = cudaMemcpy(datasetMemorySegmentP, datasetMemSegment, datasetBytes, 4); - checkCudaError(returnValue, "cudaMemcpy"); + cudaMemcpy(datasetMemorySegmentP, datasetMemSegment, datasetBytes, INFER_DIRECTION); long datasetShape[] = { rows, cols }; MemorySegment 
datasetTensor = prepareTensor(arena, datasetMemorySegmentP, datasetShape, 2, 32, 2, 2, 1); @@ -261,8 +260,7 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable { MemorySegment neighborsDP = neighborsD.get(C_POINTER, 0); MemorySegment distancesDP = distancesD.get(C_POINTER, 0); - returnValue = cudaMemcpy(queriesDP, querySeg, queriesBytes, 4); - checkCudaError(returnValue, "cudaMemcpy"); + cudaMemcpy(queriesDP, querySeg, queriesBytes, INFER_DIRECTION); long queriesShape[] = { numQueries, vectorDimension }; MemorySegment queriesTensor = prepareTensor(arena, queriesDP, queriesShape, 2, 32, 2, 2, 1); @@ -287,8 +285,7 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable { prefilterDP = prefilterD.get(C_POINTER, 0); - returnValue = cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, 1); - checkCudaError(returnValue, "cudaMemcpy"); + cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, HOST_TO_DEVICE); prefilterTensor = prepareTensor(arena, prefilterDP, prefilterShape, 1, 32, 1, 2, 1); @@ -306,10 +303,8 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable { returnValue = cuvsStreamSync(cuvsResources); checkCuVSError(returnValue, "cuvsStreamSync"); - returnValue = cudaMemcpy(neighborsMemorySegment, neighborsDP, neighborsBytes, 4); - checkCudaError(returnValue, "cudaMemcpy"); - returnValue = cudaMemcpy(distancesMemorySegment, distancesDP, distanceBytes, 4); - checkCudaError(returnValue, "cudaMemcpy"); + cudaMemcpy(neighborsMemorySegment, neighborsDP, neighborsBytes, INFER_DIRECTION); + cudaMemcpy(distancesMemorySegment, distancesDP, distanceBytes, INFER_DIRECTION); returnValue = cuvsRMMFree(cuvsResources, neighborsDP, neighborsBytes); checkCuVSError(returnValue, "cuvsRMMFree"); diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java index d19106a2c8..33424e65c4 100644 
--- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java @@ -16,17 +16,16 @@ package com.nvidia.cuvs.internal; -import static java.lang.foreign.ValueLayout.ADDRESS; import static com.nvidia.cuvs.internal.common.LinkerHelper.C_FLOAT; import static com.nvidia.cuvs.internal.common.LinkerHelper.C_FLOAT_BYTE_SIZE; import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT; import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT_BYTE_SIZE; import static com.nvidia.cuvs.internal.common.LinkerHelper.C_POINTER; -import static com.nvidia.cuvs.internal.common.LinkerHelper.C_LONG; import static com.nvidia.cuvs.internal.common.Util.buildMemorySegment; import static com.nvidia.cuvs.internal.common.Util.checkCuVSError; -import static com.nvidia.cuvs.internal.common.Util.checkCudaError; import static com.nvidia.cuvs.internal.common.Util.concatenate; +import static com.nvidia.cuvs.internal.common.Util.cudaMemcpy; +import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.*; import static com.nvidia.cuvs.internal.common.Util.prepareTensor; import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraBuild; import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraDeserialize; @@ -45,7 +44,6 @@ import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamGet; import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamSync; import static com.nvidia.cuvs.internal.panama.headers_h.omp_set_num_threads; -import static com.nvidia.cuvs.internal.panama.headers_h.cudaMemcpy; import static com.nvidia.cuvs.internal.panama.headers_h.cudaStream_t; import java.io.FileInputStream; @@ -53,12 +51,10 @@ import java.io.InputStream; import java.io.OutputStream; import java.lang.foreign.Arena; -import java.lang.foreign.FunctionDescriptor; import java.lang.foreign.MemoryLayout; import java.lang.foreign.MemorySegment; import 
java.lang.foreign.SequenceLayout; import java.lang.foreign.ValueLayout; -import java.lang.invoke.MethodHandle; import java.nio.file.Files; import java.nio.file.Path; import java.util.Objects; @@ -287,8 +283,7 @@ public SearchResults search(CagraQuery query) throws Throwable { MemorySegment prefilterDP = MemorySegment.NULL; long prefilterLen = 0; - returnValue = cudaMemcpy(queriesDP, floatsSeg, queriesBytes, 4); - checkCudaError(returnValue, "cudaMemcpy"); + cudaMemcpy(queriesDP, floatsSeg, queriesBytes, INFER_DIRECTION); long queriesShape[] = { numQueries, vectorDimension }; MemorySegment queriesTensor = prepareTensor(arena, queriesDP, queriesShape, 2, 32, 2, 2, 1); @@ -329,8 +324,7 @@ public SearchResults search(CagraQuery query) throws Throwable { prefilterDP = prefilterD.get(C_POINTER, 0); - returnValue = cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, 1); - checkCudaError(returnValue, "cudaMemcpy"); + cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, HOST_TO_DEVICE); prefilterTensor = prepareTensor(arena, prefilterDP, prefilterShape, 1, 32, 1, 2, 1); @@ -348,10 +342,8 @@ public SearchResults search(CagraQuery query) throws Throwable { returnValue = cuvsStreamSync(cuvsRes); checkCuVSError(returnValue, "cuvsStreamSync"); - returnValue = cudaMemcpy(neighborsMemorySegment, neighborsDP, neighborsBytes, 4); - checkCudaError(returnValue, "cudaMemcpy"); - returnValue = cudaMemcpy(distancesMemorySegment, distancesDP, distancesBytes, 4); - checkCudaError(returnValue, "cudaMemcpy"); + cudaMemcpy(neighborsMemorySegment, neighborsDP, neighborsBytes, INFER_DIRECTION); + cudaMemcpy(distancesMemorySegment, distancesDP, distancesBytes, INFER_DIRECTION); returnValue = cuvsRMMFree(cuvsRes, distancesDP, distancesBytes); checkCuVSError(returnValue, "cuvsRMMFree"); diff --git a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java index 
36512225e3..513c66e257 100644 --- a/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java +++ b/java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java @@ -25,14 +25,11 @@ import static com.nvidia.cuvs.internal.panama.headers_h.cudaMemGetInfo; import static com.nvidia.cuvs.internal.panama.headers_h.cudaSetDevice; import static com.nvidia.cuvs.internal.panama.headers_h.size_t; -import static java.lang.foreign.ValueLayout.ADDRESS; import java.lang.foreign.Arena; -import java.lang.foreign.FunctionDescriptor; import java.lang.foreign.MemoryLayout; import java.lang.foreign.MemoryLayout.PathElement; import java.lang.foreign.MemorySegment; -import java.lang.invoke.MethodHandle; import java.lang.invoke.VarHandle; import java.util.ArrayList; import java.util.BitSet; @@ -79,6 +76,48 @@ public static void checkCudaError(int value, String caller) { } } + /** + * Java analog to CUDA's {@code cudaMemcpyKind}; each constant's {@link #kind} is the integer handed to the native cudaMemcpy call. + * @see <a href="https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html">CUDA Runtime API</a> + */ + public enum CudaMemcpyKind { + HOST_TO_HOST(0), + HOST_TO_DEVICE(1), + DEVICE_TO_HOST(2), + DEVICE_TO_DEVICE(3), + INFER_DIRECTION(4); + + CudaMemcpyKind(int k) { + this.kind = k; + } + + public final int kind; + } + + /** + * Invokes the CUDA runtime {@code cudaMemcpy} function to copy data between host/device memory and passes the returned status to {@code checkCudaError}. + * @param dest Destination address for data copy + * @param src Source address for data copy + * @param numBytes Number of bytes to be copied + * @param kind "Direction" of data copy (host-to-device, device-to-host, etc.) + * @throws RuntimeException on failure of copy + */ + public static void cudaMemcpy(MemorySegment dest, MemorySegment src, long numBytes, CudaMemcpyKind kind) { + int returnValue = com.nvidia.cuvs.internal.panama.headers_h.cudaMemcpy(dest, src, numBytes, kind.kind); + checkCudaError(returnValue, "cudaMemcpy"); + } + + /** + * Invokes the four-argument {@code cudaMemcpy} helper with {@link CudaMemcpyKind#INFER_DIRECTION} as the copy kind. 
+ * @param dest Destination address for data copy + * @param src Source address for data copy + * @param numBytes Number of bytes to be copied + * @throws RuntimeException on failure of copy + */ + public static void cudaMemcpy(MemorySegment dest, MemorySegment src, long numBytes) { + Util.cudaMemcpy(dest, src, numBytes, CudaMemcpyKind.INFER_DIRECTION); + } + static final long MAX_ERROR_TEXT = 1_000_000L; static String getLastErrorText() {