Skip to content

Commit 5e3a5a6

Browse files
authored
[Java] Exception-safe RMM Allocations (#1215)
This commit introduces exception safety for RMM allocations. Previously, device memory allocated through `cuvsRmmAlloc()` was freed manually using `cuvsRmmFree()` in all of the index implementation classes. The problem is that if an exception were thrown between the allocation and the free, the device memory would leak. This commit extends the `CloseableHandle` class to encapsulate the allocation of device memory. This new class is used in try-with-resources blocks to make device memory allocations exception-safe. Authors: - MithunR (https://github.com/mythrocks) Approvers: - Lorenzo Dematté (https://github.com/ldematte) - Corey J. Nolet (https://github.com/cjnolet) URL: #1215
1 parent afc24ee commit 5e3a5a6

5 files changed

Lines changed: 424 additions & 354 deletions

File tree

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/BruteForceIndexImpl.java

Lines changed: 102 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
*/
1616
package com.nvidia.cuvs.internal;
1717

18+
import static com.nvidia.cuvs.internal.common.CloseableRMMAllocation.allocateRMMSegment;
1819
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_FLOAT;
1920
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_FLOAT_BYTE_SIZE;
2021
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT_BYTE_SIZE;
2122
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_LONG;
2223
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_LONG_BYTE_SIZE;
2324
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.HOST_TO_DEVICE;
2425
import static com.nvidia.cuvs.internal.common.Util.CudaMemcpyKind.INFER_DIRECTION;
25-
import static com.nvidia.cuvs.internal.common.Util.allocateRMMSegment;
2626
import static com.nvidia.cuvs.internal.common.Util.buildMemorySegment;
2727
import static com.nvidia.cuvs.internal.common.Util.checkCuVSError;
2828
import static com.nvidia.cuvs.internal.common.Util.concatenate;
@@ -35,7 +35,6 @@
3535
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceIndex_t;
3636
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceSearch;
3737
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsBruteForceSerialize;
38-
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsRMMFree;
3938
import static com.nvidia.cuvs.internal.panama.headers_h.cuvsStreamSync;
4039
import static com.nvidia.cuvs.internal.panama.headers_h.omp_set_num_threads;
4140

@@ -45,6 +44,7 @@
4544
import com.nvidia.cuvs.CuVSMatrix;
4645
import com.nvidia.cuvs.CuVSResources;
4746
import com.nvidia.cuvs.SearchResults;
47+
import com.nvidia.cuvs.internal.common.CloseableRMMAllocation;
4848
import com.nvidia.cuvs.internal.panama.cuvsFilter;
4949
import java.io.InputStream;
5050
import java.io.OutputStream;
@@ -118,20 +118,7 @@ public void destroyIndex() {
118118
try {
119119
int returnValue = cuvsBruteForceIndexDestroy(bruteForceIndexReference.indexPtr);
120120
checkCuVSError(returnValue, "cuvsBruteForceIndexDestroy");
121-
122-
if (bruteForceIndexReference.datasetBytes > 0) {
123-
try (var resourcesAccessor = resources.access()) {
124-
checkCuVSError(
125-
cuvsRMMFree(
126-
resourcesAccessor.handle(),
127-
bruteForceIndexReference.datasetPtr,
128-
bruteForceIndexReference.datasetBytes),
129-
"cuvsRMMFree");
130-
}
131-
}
132-
if (bruteForceIndexReference.tensorDataArena != null) {
133-
bruteForceIndexReference.tensorDataArena.close();
134-
}
121+
bruteForceIndexReference.close(resources);
135122
} finally {
136123
destroyed = true;
137124
}
@@ -158,25 +145,31 @@ private IndexReference build(
158145

159146
try (var resourcesAccessor = resources.access()) {
160147
long cuvsResources = resourcesAccessor.handle();
161-
MemorySegment datasetMemorySegmentP = allocateRMMSegment(cuvsResources, datasetBytes);
148+
try (var closeableDataMemorySegmentP = allocateRMMSegment(cuvsResources, datasetBytes)) {
149+
MemorySegment datasetMemorySegmentP = closeableDataMemorySegmentP.handle();
162150

163-
cudaMemcpy(datasetMemorySegmentP, datasetMemSegment, datasetBytes, INFER_DIRECTION);
151+
cudaMemcpy(datasetMemorySegmentP, datasetMemSegment, datasetBytes, INFER_DIRECTION);
164152

165-
long[] datasetShape = {rows, cols};
166-
var tensorDataArena = Arena.ofShared();
167-
MemorySegment datasetTensor =
168-
prepareTensor(tensorDataArena, datasetMemorySegmentP, datasetShape, 2, 32, 2, 1);
153+
long[] datasetShape = {rows, cols};
154+
var tensorDataArena = Arena.ofShared();
155+
MemorySegment datasetTensor =
156+
prepareTensor(tensorDataArena, datasetMemorySegmentP, datasetShape, 2, 32, 2, 1);
169157

170-
var returnValue = cuvsStreamSync(cuvsResources);
171-
checkCuVSError(returnValue, "cuvsStreamSync");
158+
var returnValue = cuvsStreamSync(cuvsResources);
159+
checkCuVSError(returnValue, "cuvsStreamSync");
172160

173-
returnValue = cuvsBruteForceBuild(cuvsResources, datasetTensor, 0, 0.0f, index);
174-
checkCuVSError(returnValue, "cuvsBruteForceBuild");
161+
returnValue = cuvsBruteForceBuild(cuvsResources, datasetTensor, 0, 0.0f, index);
162+
checkCuVSError(returnValue, "cuvsBruteForceBuild");
175163

176-
returnValue = cuvsStreamSync(cuvsResources);
177-
checkCuVSError(returnValue, "cuvsStreamSync");
164+
returnValue = cuvsStreamSync(cuvsResources);
165+
checkCuVSError(returnValue, "cuvsStreamSync");
178166

179-
return new IndexReference(datasetMemorySegmentP, datasetBytes, tensorDataArena, index);
167+
return new IndexReference(
168+
new CloseableRMMAllocation(closeableDataMemorySegmentP),
169+
datasetBytes,
170+
tensorDataArena,
171+
index);
172+
}
180173
} finally {
181174
omp_set_num_threads(1);
182175
}
@@ -205,15 +198,19 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
205198

206199
// prepare the prefiltering data
207200
final long prefilterDataLength;
201+
final long prefilterBytes;
208202
final MemorySegment prefilterDataMemorySegment;
209203
BitSet[] prefilters = cuvsQuery.getPrefilters();
210204
if (prefilters != null && prefilters.length > 0) {
211205
BitSet concatenatedFilters = concatenate(prefilters, cuvsQuery.getNumDocs());
212206
long[] filters = concatenatedFilters.toLongArray();
213207
prefilterDataMemorySegment = buildMemorySegment(localArena, filters);
214208
prefilterDataLength = (long) cuvsQuery.getNumDocs() * prefilters.length;
209+
long[] prefilterShape = {(prefilterDataLength + 31) / 32};
210+
prefilterBytes = C_INT_BYTE_SIZE * prefilterShape[0];
215211
} else {
216212
prefilterDataLength = 0;
213+
prefilterBytes = 0;
217214
prefilterDataMemorySegment = MemorySegment.NULL;
218215
}
219216

@@ -223,77 +220,66 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
223220
try (var resourcesAccessor = cuvsQuery.getResources().access()) {
224221
long cuvsResources = resourcesAccessor.handle();
225222

226-
long queriesBytes = C_FLOAT_BYTE_SIZE * numQueries * vectorDimension;
227-
long neighborsBytes = C_LONG_BYTE_SIZE * numQueries * topk;
228-
long distanceBytes = C_FLOAT_BYTE_SIZE * numQueries * topk;
229-
long prefilterBytes = 0; // size assigned later
230-
231-
MemorySegment queriesDP = allocateRMMSegment(cuvsResources, queriesBytes);
232-
MemorySegment neighborsDP = allocateRMMSegment(cuvsResources, neighborsBytes);
233-
MemorySegment distancesDP = allocateRMMSegment(cuvsResources, distanceBytes);
234-
MemorySegment prefilterDP = MemorySegment.NULL;
235-
236-
cudaMemcpy(queriesDP, querySeg, queriesBytes, INFER_DIRECTION);
237-
238-
long[] queriesShape = {numQueries, vectorDimension};
239-
MemorySegment queriesTensor =
240-
prepareTensor(localArena, queriesDP, queriesShape, 2, 32, 2, 1);
241-
long[] neighborsShape = {numQueries, topk};
242-
MemorySegment neighborsTensor =
243-
prepareTensor(localArena, neighborsDP, neighborsShape, 0, 64, 2, 1);
244-
long[] distancesShape = {numQueries, topk};
245-
MemorySegment distancesTensor =
246-
prepareTensor(localArena, distancesDP, distancesShape, 2, 32, 2, 1);
247-
248-
MemorySegment prefilter = cuvsFilter.allocate(localArena);
249-
MemorySegment prefilterTensor;
250-
251-
if (prefilterDataMemorySegment == MemorySegment.NULL) {
252-
cuvsFilter.type(prefilter, 0); // NO_FILTER
253-
cuvsFilter.addr(prefilter, 0);
254-
} else {
255-
long[] prefilterShape = {(prefilterDataLength + 31) / 32};
256-
long prefilterLen = prefilterShape[0];
257-
prefilterBytes = C_INT_BYTE_SIZE * prefilterLen;
258-
259-
prefilterDP = allocateRMMSegment(cuvsResources, prefilterBytes);
260-
261-
cudaMemcpy(prefilterDP, prefilterDataMemorySegment, prefilterBytes, HOST_TO_DEVICE);
262-
263-
prefilterTensor = prepareTensor(localArena, prefilterDP, prefilterShape, 1, 32, 2, 1);
264-
265-
cuvsFilter.type(prefilter, 2);
266-
cuvsFilter.addr(prefilter, prefilterTensor.address());
267-
}
268-
269-
var returnValue = cuvsStreamSync(cuvsResources);
270-
checkCuVSError(returnValue, "cuvsStreamSync");
271-
272-
returnValue =
273-
cuvsBruteForceSearch(
274-
cuvsResources,
275-
bruteForceIndexReference.indexPtr,
276-
queriesTensor,
277-
neighborsTensor,
278-
distancesTensor,
279-
prefilter);
280-
checkCuVSError(returnValue, "cuvsBruteForceSearch");
281-
282-
returnValue = cuvsStreamSync(cuvsResources);
283-
checkCuVSError(returnValue, "cuvsStreamSync");
284-
285-
cudaMemcpy(neighborsMemorySegment, neighborsDP, neighborsBytes, INFER_DIRECTION);
286-
cudaMemcpy(distancesMemorySegment, distancesDP, distanceBytes, INFER_DIRECTION);
287-
288-
returnValue = cuvsRMMFree(cuvsResources, neighborsDP, neighborsBytes);
289-
checkCuVSError(returnValue, "cuvsRMMFree");
290-
returnValue = cuvsRMMFree(cuvsResources, distancesDP, distanceBytes);
291-
checkCuVSError(returnValue, "cuvsRMMFree");
292-
returnValue = cuvsRMMFree(cuvsResources, queriesDP, queriesBytes);
293-
checkCuVSError(returnValue, "cuvsRMMFree");
294-
if (prefilterBytes > 0) {
295-
returnValue = cuvsRMMFree(cuvsResources, prefilterDP, prefilterBytes);
296-
checkCuVSError(returnValue, "cuvsRMMFree");
223+
final long queriesBytes = C_FLOAT_BYTE_SIZE * numQueries * vectorDimension;
224+
final long neighborsBytes = C_LONG_BYTE_SIZE * numQueries * topk;
225+
final long distanceBytes = C_FLOAT_BYTE_SIZE * numQueries * topk;
226+
227+
try (var queriesDP = allocateRMMSegment(cuvsResources, queriesBytes);
228+
var neighborsDP = allocateRMMSegment(cuvsResources, neighborsBytes);
229+
var distancesDP = allocateRMMSegment(cuvsResources, distanceBytes);
230+
var prefilterDP =
231+
prefilterBytes > 0
232+
? allocateRMMSegment(cuvsResources, prefilterBytes)
233+
: CloseableRMMAllocation.EMPTY) {
234+
235+
cudaMemcpy(queriesDP.handle(), querySeg, queriesBytes, INFER_DIRECTION);
236+
237+
long[] queriesShape = {numQueries, vectorDimension};
238+
MemorySegment queriesTensor =
239+
prepareTensor(localArena, queriesDP.handle(), queriesShape, 2, 32, 2, 1);
240+
long[] neighborsShape = {numQueries, topk};
241+
MemorySegment neighborsTensor =
242+
prepareTensor(localArena, neighborsDP.handle(), neighborsShape, 0, 64, 2, 1);
243+
long[] distancesShape = {numQueries, topk};
244+
MemorySegment distancesTensor =
245+
prepareTensor(localArena, distancesDP.handle(), distancesShape, 2, 32, 2, 1);
246+
247+
MemorySegment prefilter = cuvsFilter.allocate(localArena);
248+
MemorySegment prefilterTensor;
249+
250+
if (prefilterDataMemorySegment == MemorySegment.NULL) {
251+
cuvsFilter.type(prefilter, 0); // NO_FILTER
252+
cuvsFilter.addr(prefilter, 0);
253+
} else {
254+
long[] prefilterShape = {(prefilterDataLength + 31) / 32};
255+
cudaMemcpy(
256+
prefilterDP.handle(), prefilterDataMemorySegment, prefilterBytes, HOST_TO_DEVICE);
257+
258+
prefilterTensor =
259+
prepareTensor(localArena, prefilterDP.handle(), prefilterShape, 1, 32, 2, 1);
260+
261+
cuvsFilter.type(prefilter, 2);
262+
cuvsFilter.addr(prefilter, prefilterTensor.address());
263+
}
264+
265+
var returnValue = cuvsStreamSync(cuvsResources);
266+
checkCuVSError(returnValue, "cuvsStreamSync");
267+
268+
returnValue =
269+
cuvsBruteForceSearch(
270+
cuvsResources,
271+
bruteForceIndexReference.indexPtr,
272+
queriesTensor,
273+
neighborsTensor,
274+
distancesTensor,
275+
prefilter);
276+
checkCuVSError(returnValue, "cuvsBruteForceSearch");
277+
278+
returnValue = cuvsStreamSync(cuvsResources);
279+
checkCuVSError(returnValue, "cuvsStreamSync");
280+
281+
cudaMemcpy(neighborsMemorySegment, neighborsDP.handle(), neighborsBytes, INFER_DIRECTION);
282+
cudaMemcpy(distancesMemorySegment, distancesDP.handle(), distanceBytes, INFER_DIRECTION);
297283
}
298284
}
299285
return BruteForceSearchResults.create(
@@ -479,27 +465,39 @@ public BruteForceIndexImpl build() throws Throwable {
479465
*/
480466
private static class IndexReference {
481467

482-
private final MemorySegment datasetPtr;
468+
private final CloseableRMMAllocation datasetAllocationHandle;
483469
private final long datasetBytes;
484470
private final Arena tensorDataArena;
485471
private final MemorySegment indexPtr;
486472

487473
private IndexReference(
488-
MemorySegment datasetPtr,
474+
CloseableRMMAllocation datasetAllocationHandle,
489475
long datasetBytes,
490476
Arena tensorDataArena,
491477
MemorySegment indexPtr) {
492-
this.datasetPtr = datasetPtr;
478+
this.datasetAllocationHandle = datasetAllocationHandle;
493479
this.datasetBytes = datasetBytes;
494480
this.tensorDataArena = tensorDataArena;
495481
this.indexPtr = indexPtr;
496482
}
497483

498484
private IndexReference(MemorySegment indexPtr) {
499-
this.datasetPtr = MemorySegment.NULL;
485+
this.datasetAllocationHandle = CloseableRMMAllocation.EMPTY;
500486
this.datasetBytes = 0;
501487
this.tensorDataArena = null;
502488
this.indexPtr = indexPtr;
503489
}
490+
491+
/**
492+
* Free up the memory used for dataset, tensor-data.
493+
*/
494+
private void close(CuVSResources resources) {
495+
try (var resourcesAccessor = resources.access()) {
496+
datasetAllocationHandle.close();
497+
}
498+
if (tensorDataArena != null) {
499+
tensorDataArena.close();
500+
}
501+
}
504502
}
505503
}

0 commit comments

Comments (0)