Skip to content

Commit 71a408a

Browse files
authored
[java] Utility function for cudaMemcpy (#983)
This commit does a minor rephrasing of calls to `cudaMemcpy`. Prior to this change: 1. Calls to `cudaMemcpy` were using hardcoded / undocumented magic numbers for the direction of the copy. These have been replaced with an enum that mirrors what is used in the CUDA Runtime call. 2. `cudaMemcpy` would return an error value that needed to be checked / converted into a RuntimeException. This has been moved into the utility function, to reduce the noise at the call site. This change also includes some minor cleanup of unused imports in the modified files. Authors: - MithunR (https://github.com/mythrocks) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #983
1 parent 41995c2 commit 71a408a

3 files changed

Lines changed: 55 additions & 29 deletions

File tree

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_POINTER;
2525
import static com.nvidia.cuvs.internal.common.Util.buildMemorySegment;
2626
import static com.nvidia.cuvs.internal.common.Util.checkCuVSError;
27-
import static com.nvidia.cuvs.internal.common.Util.checkCudaError;
2827
import static com.nvidia.cuvs.internal.common.Util.concatenate;
28+
import static com.nvidia.cuvs.internal.common.Util.cudaMemcpy;
29+
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.*;
2930
import static com.nvidia.cuvs.internal.common.Util.prepareTensor;
3031
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceBuild;
3132
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceDeserialize;
@@ -40,7 +41,6 @@
4041
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamGet;
4142
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamSync;
4243
import static com.nvidia.cuvs.internal.panama.headers_h.omp_set_num_threads;
43-
import static com.nvidia.cuvs.internal.panama.headers_h.cudaMemcpy;
4444
import static com.nvidia.cuvs.internal.panama.headers_h.cudaStream_t;
4545

4646
import java.io.FileInputStream;
@@ -170,8 +170,7 @@ private IndexReference build() throws Throwable {
170170
// IMPORTANT: this should only come AFTER cuvsRMMAlloc call
171171
MemorySegment datasetMemorySegmentP = datasetMemorySegment.get(C_POINTER, 0);
172172

173-
returnValue = cudaMemcpy(datasetMemorySegmentP, datasetMemSegment, datasetBytes, 4);
174-
checkCudaError(returnValue, "cudaMemcpy");
173+
cudaMemcpy(datasetMemorySegmentP, datasetMemSegment, datasetBytes, INFER_DIRECTION);
175174

176175
long datasetShape[] = { rows, cols };
177176
MemorySegment datasetTensor = prepareTensor(arena, datasetMemorySegmentP, datasetShape, 2, 32, 2, 2, 1);
@@ -261,8 +260,7 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
261260
MemorySegment neighborsDP = neighborsD.get(C_POINTER, 0);
262261
MemorySegment distancesDP = distancesD.get(C_POINTER, 0);
263262

264-
returnValue = cudaMemcpy(queriesDP, querySeg, queriesBytes, 4);
265-
checkCudaError(returnValue, "cudaMemcpy");
263+
cudaMemcpy(queriesDP, querySeg, queriesBytes, INFER_DIRECTION);
266264

267265
long queriesShape[] = { numQueries, vectorDimension };
268266
MemorySegment queriesTensor = prepareTensor(arena, queriesDP, queriesShape, 2, 32, 2, 2, 1);
@@ -287,8 +285,7 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
287285

288286
prefilterDP = prefilterD.get(C_POINTER, 0);
289287

290-
returnValue = cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, 1);
291-
checkCudaError(returnValue, "cudaMemcpy");
288+
cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, HOST_TO_DEVICE);
292289

293290
prefilterTensor = prepareTensor(arena, prefilterDP, prefilterShape, 1, 32, 1, 2, 1);
294291

@@ -306,10 +303,8 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
306303
returnValue = cuvsStreamSync(cuvsResources);
307304
checkCuVSError(returnValue, "cuvsStreamSync");
308305

309-
returnValue = cudaMemcpy(neighborsMemorySegment, neighborsDP, neighborsBytes, 4);
310-
checkCudaError(returnValue, "cudaMemcpy");
311-
returnValue = cudaMemcpy(distancesMemorySegment, distancesDP, distanceBytes, 4);
312-
checkCudaError(returnValue, "cudaMemcpy");
306+
cudaMemcpy(neighborsMemorySegment, neighborsDP, neighborsBytes, INFER_DIRECTION);
307+
cudaMemcpy(distancesMemorySegment, distancesDP, distanceBytes, INFER_DIRECTION);
313308

314309
returnValue = cuvsRMMFree(cuvsResources, neighborsDP, neighborsBytes);
315310
checkCuVSError(returnValue, "cuvsRMMFree");

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CagraIndexImpl.java

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,16 @@
1616

1717
package com.nvidia.cuvs.internal;
1818

19-
import static java.lang.foreign.ValueLayout.ADDRESS;
2019
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_FLOAT;
2120
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_FLOAT_BYTE_SIZE;
2221
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT;
2322
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT_BYTE_SIZE;
2423
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_POINTER;
25-
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_LONG;
2624
import static com.nvidia.cuvs.internal.common.Util.buildMemorySegment;
2725
import static com.nvidia.cuvs.internal.common.Util.checkCuVSError;
28-
import static com.nvidia.cuvs.internal.common.Util.checkCudaError;
2926
import static com.nvidia.cuvs.internal.common.Util.concatenate;
27+
import static com.nvidia.cuvs.internal.common.Util.cudaMemcpy;
28+
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.*;
3029
import static com.nvidia.cuvs.internal.common.Util.prepareTensor;
3130
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraBuild;
3231
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsCagraDeserialize;
@@ -45,20 +44,17 @@
4544
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamGet;
4645
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamSync;
4746
import static com.nvidia.cuvs.internal.panama.headers_h.omp_set_num_threads;
48-
import static com.nvidia.cuvs.internal.panama.headers_h.cudaMemcpy;
4947
import static com.nvidia.cuvs.internal.panama.headers_h.cudaStream_t;
5048

5149
import java.io.FileInputStream;
5250
import java.io.FileOutputStream;
5351
import java.io.InputStream;
5452
import java.io.OutputStream;
5553
import java.lang.foreign.Arena;
56-
import java.lang.foreign.FunctionDescriptor;
5754
import java.lang.foreign.MemoryLayout;
5855
import java.lang.foreign.MemorySegment;
5956
import java.lang.foreign.SequenceLayout;
6057
import java.lang.foreign.ValueLayout;
61-
import java.lang.invoke.MethodHandle;
6258
import java.nio.file.Files;
6359
import java.nio.file.Path;
6460
import java.util.Objects;
@@ -287,8 +283,7 @@ public SearchResults search(CagraQuery query) throws Throwable {
287283
MemorySegment prefilterDP = MemorySegment.NULL;
288284
long prefilterLen = 0;
289285

290-
returnValue = cudaMemcpy(queriesDP, floatsSeg, queriesBytes, 4);
291-
checkCudaError(returnValue, "cudaMemcpy");
286+
cudaMemcpy(queriesDP, floatsSeg, queriesBytes, INFER_DIRECTION);
292287

293288
long queriesShape[] = { numQueries, vectorDimension };
294289
MemorySegment queriesTensor = prepareTensor(arena, queriesDP, queriesShape, 2, 32, 2, 2, 1);
@@ -329,8 +324,7 @@ public SearchResults search(CagraQuery query) throws Throwable {
329324

330325
prefilterDP = prefilterD.get(C_POINTER, 0);
331326

332-
returnValue = cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, 1);
333-
checkCudaError(returnValue, "cudaMemcpy");
327+
cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, HOST_TO_DEVICE);
334328

335329
prefilterTensor = prepareTensor(arena, prefilterDP, prefilterShape, 1, 32, 1, 2, 1);
336330

@@ -348,10 +342,8 @@ public SearchResults search(CagraQuery query) throws Throwable {
348342
returnValue = cuvsStreamSync(cuvsRes);
349343
checkCuVSError(returnValue, "cuvsStreamSync");
350344

351-
returnValue = cudaMemcpy(neighborsMemorySegment, neighborsDP, neighborsBytes, 4);
352-
checkCudaError(returnValue, "cudaMemcpy");
353-
returnValue = cudaMemcpy(distancesMemorySegment, distancesDP, distancesBytes, 4);
354-
checkCudaError(returnValue, "cudaMemcpy");
345+
cudaMemcpy(neighborsMemorySegment, neighborsDP, neighborsBytes, INFER_DIRECTION);
346+
cudaMemcpy(distancesMemorySegment, distancesDP, distancesBytes, INFER_DIRECTION);
355347

356348
returnValue = cuvsRMMFree(cuvsRes, distancesDP, distancesBytes);
357349
checkCuVSError(returnValue, "cuvsRMMFree");

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/common/Util.java

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,11 @@
2525
import static com.nvidia.cuvs.internal.panama.headers_h.cudaMemGetInfo;
2626
import static com.nvidia.cuvs.internal.panama.headers_h.cudaSetDevice;
2727
import static com.nvidia.cuvs.internal.panama.headers_h.size_t;
28-
import static java.lang.foreign.ValueLayout.ADDRESS;
2928

3029
import java.lang.foreign.Arena;
31-
import java.lang.foreign.FunctionDescriptor;
3230
import java.lang.foreign.MemoryLayout;
3331
import java.lang.foreign.MemoryLayout.PathElement;
3432
import java.lang.foreign.MemorySegment;
35-
import java.lang.invoke.MethodHandle;
3633
import java.lang.invoke.VarHandle;
3734
import java.util.ArrayList;
3835
import java.util.BitSet;
@@ -79,6 +76,48 @@ public static void checkCudaError(int value, String caller) {
7976
}
8077
}
8178

79+
/**
80+
* Java analog to CUDA's cudaMemcpyKind, used for cudaMemcpy() calls.
81+
* @see <a href="https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html">CUDA Runtime API</a>
82+
*/
83+
public enum CudaMemcpyKind {
84+
HOST_TO_HOST(0), // CUDA runtime: cudaMemcpyHostToHost
85+
HOST_TO_DEVICE(1), // CUDA runtime: cudaMemcpyHostToDevice
86+
DEVICE_TO_HOST(2), // CUDA runtime: cudaMemcpyDeviceToHost
87+
DEVICE_TO_DEVICE(3), // CUDA runtime: cudaMemcpyDeviceToDevice
88+
INFER_DIRECTION(4); // CUDA runtime: cudaMemcpyDefault (value 4); the runtime is expected to infer the direction from the pointers
89+
90+
CudaMemcpyKind(int k) { // k is the integer stored in the kind field below
91+
this.kind = k;
92+
}
93+
94+
public final int kind; // raw integer passed as the last argument of the native cudaMemcpy binding
95+
}
96+
97+
/**
98+
* Helper to invoke cudaMemcpy CUDA runtime function to copy data between host/device memory.
99+
* @param dest Destination address for data copy
100+
* @param src Source address for data copy
101+
* @param numBytes Number of bytes to be copied
102+
* @param kind "Direction" of data copy (Host->Device, Device->Host, etc.)
103+
* @throws RuntimeException on failure of copy
104+
*/
105+
public static void cudaMemcpy(MemorySegment dest, MemorySegment src, long numBytes, CudaMemcpyKind kind) {
106+
int returnValue = com.nvidia.cuvs.internal.panama.headers_h.cudaMemcpy(dest, src, numBytes, kind.kind); // qualified name selects the headers_h binding, not this class's same-named overloads; kind.kind is the enum's raw int
107+
checkCudaError(returnValue, "cudaMemcpy"); // status check lives here so call sites need not inspect the return code; "cudaMemcpy" is passed as the caller label
108+
}
109+
110+
/**
111+
* Helper to invoke cudaMemcpy CUDA runtime function to copy data between host/device memory.
112+
* @param dest Destination address for data copy
113+
* @param src Source address for data copy
114+
* @param numBytes Number of bytes to be copied
115+
* @throws RuntimeException on failure of copy
116+
*/
117+
public static void cudaMemcpy(MemorySegment dest, MemorySegment src, long numBytes) {
118+
Util.cudaMemcpy(dest, src, numBytes, CudaMemcpyKind.INFER_DIRECTION); // convenience overload: delegates to the four-argument version with INFER_DIRECTION as the kind
119+
}
120+
82121
static final long MAX_ERROR_TEXT = 1_000_000L;
83122

84123
static String getLastErrorText() {

0 commit comments

Comments
 (0)