1515 */
1616package com .nvidia .cuvs .internal ;
1717
18+ import static com .nvidia .cuvs .internal .common .CloseableRMMAllocation .allocateRMMSegment ;
1819import static com .nvidia .cuvs .internal .common .LinkerHelper .C_FLOAT ;
1920import static com .nvidia .cuvs .internal .common .LinkerHelper .C_FLOAT_BYTE_SIZE ;
2021import static com .nvidia .cuvs .internal .common .LinkerHelper .C_INT_BYTE_SIZE ;
2122import static com .nvidia .cuvs .internal .common .LinkerHelper .C_LONG ;
2223import static com .nvidia .cuvs .internal .common .LinkerHelper .C_LONG_BYTE_SIZE ;
2324import static com .nvidia .cuvs .internal .common .Util .CudaMemcpyKind .HOST_TO_DEVICE ;
2425import static com .nvidia .cuvs .internal .common .Util .CudaMemcpyKind .INFER_DIRECTION ;
25- import static com .nvidia .cuvs .internal .common .Util .allocateRMMSegment ;
2626import static com .nvidia .cuvs .internal .common .Util .buildMemorySegment ;
2727import static com .nvidia .cuvs .internal .common .Util .checkCuVSError ;
2828import static com .nvidia .cuvs .internal .common .Util .concatenate ;
3535import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceIndex_t ;
3636import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceSearch ;
3737import static com .nvidia .cuvs .internal .panama .headers_h .cuvsBruteForceSerialize ;
38- import static com .nvidia .cuvs .internal .panama .headers_h .cuvsRMMFree ;
3938import static com .nvidia .cuvs .internal .panama .headers_h .cuvsStreamSync ;
4039import static com .nvidia .cuvs .internal .panama .headers_h .omp_set_num_threads ;
4140
4544import com .nvidia .cuvs .CuVSMatrix ;
4645import com .nvidia .cuvs .CuVSResources ;
4746import com .nvidia .cuvs .SearchResults ;
47+ import com .nvidia .cuvs .internal .common .CloseableRMMAllocation ;
4848import com .nvidia .cuvs .internal .panama .cuvsFilter ;
4949import java .io .InputStream ;
5050import java .io .OutputStream ;
@@ -118,20 +118,7 @@ public void destroyIndex() {
118118 try {
119119 int returnValue = cuvsBruteForceIndexDestroy (bruteForceIndexReference .indexPtr );
120120 checkCuVSError (returnValue , "cuvsBruteForceIndexDestroy" );
121-
122- if (bruteForceIndexReference .datasetBytes > 0 ) {
123- try (var resourcesAccessor = resources .access ()) {
124- checkCuVSError (
125- cuvsRMMFree (
126- resourcesAccessor .handle (),
127- bruteForceIndexReference .datasetPtr ,
128- bruteForceIndexReference .datasetBytes ),
129- "cuvsRMMFree" );
130- }
131- }
132- if (bruteForceIndexReference .tensorDataArena != null ) {
133- bruteForceIndexReference .tensorDataArena .close ();
134- }
121+ bruteForceIndexReference .close (resources );
135122 } finally {
136123 destroyed = true ;
137124 }
@@ -158,25 +145,31 @@ private IndexReference build(
158145
159146 try (var resourcesAccessor = resources .access ()) {
160147 long cuvsResources = resourcesAccessor .handle ();
161- MemorySegment datasetMemorySegmentP = allocateRMMSegment (cuvsResources , datasetBytes );
148+ try (var closeableDataMemorySegmentP = allocateRMMSegment (cuvsResources , datasetBytes )) {
149+ MemorySegment datasetMemorySegmentP = closeableDataMemorySegmentP .handle ();
162150
163- cudaMemcpy (datasetMemorySegmentP , datasetMemSegment , datasetBytes , INFER_DIRECTION );
151+ cudaMemcpy (datasetMemorySegmentP , datasetMemSegment , datasetBytes , INFER_DIRECTION );
164152
165- long [] datasetShape = {rows , cols };
166- var tensorDataArena = Arena .ofShared ();
167- MemorySegment datasetTensor =
168- prepareTensor (tensorDataArena , datasetMemorySegmentP , datasetShape , 2 , 32 , 2 , 1 );
153+ long [] datasetShape = {rows , cols };
154+ var tensorDataArena = Arena .ofShared ();
155+ MemorySegment datasetTensor =
156+ prepareTensor (tensorDataArena , datasetMemorySegmentP , datasetShape , 2 , 32 , 2 , 1 );
169157
170- var returnValue = cuvsStreamSync (cuvsResources );
171- checkCuVSError (returnValue , "cuvsStreamSync" );
158+ var returnValue = cuvsStreamSync (cuvsResources );
159+ checkCuVSError (returnValue , "cuvsStreamSync" );
172160
173- returnValue = cuvsBruteForceBuild (cuvsResources , datasetTensor , 0 , 0.0f , index );
174- checkCuVSError (returnValue , "cuvsBruteForceBuild" );
161+ returnValue = cuvsBruteForceBuild (cuvsResources , datasetTensor , 0 , 0.0f , index );
162+ checkCuVSError (returnValue , "cuvsBruteForceBuild" );
175163
176- returnValue = cuvsStreamSync (cuvsResources );
177- checkCuVSError (returnValue , "cuvsStreamSync" );
164+ returnValue = cuvsStreamSync (cuvsResources );
165+ checkCuVSError (returnValue , "cuvsStreamSync" );
178166
179- return new IndexReference (datasetMemorySegmentP , datasetBytes , tensorDataArena , index );
167+ return new IndexReference (
168+ new CloseableRMMAllocation (closeableDataMemorySegmentP ),
169+ datasetBytes ,
170+ tensorDataArena ,
171+ index );
172+ }
180173 } finally {
181174 omp_set_num_threads (1 );
182175 }
@@ -205,15 +198,19 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
205198
206199 // prepare the prefiltering data
207200 final long prefilterDataLength ;
201+ final long prefilterBytes ;
208202 final MemorySegment prefilterDataMemorySegment ;
209203 BitSet [] prefilters = cuvsQuery .getPrefilters ();
210204 if (prefilters != null && prefilters .length > 0 ) {
211205 BitSet concatenatedFilters = concatenate (prefilters , cuvsQuery .getNumDocs ());
212206 long [] filters = concatenatedFilters .toLongArray ();
213207 prefilterDataMemorySegment = buildMemorySegment (localArena , filters );
214208 prefilterDataLength = (long ) cuvsQuery .getNumDocs () * prefilters .length ;
209+ long [] prefilterShape = {(prefilterDataLength + 31 ) / 32 };
210+ prefilterBytes = C_INT_BYTE_SIZE * prefilterShape [0 ];
215211 } else {
216212 prefilterDataLength = 0 ;
213+ prefilterBytes = 0 ;
217214 prefilterDataMemorySegment = MemorySegment .NULL ;
218215 }
219216
@@ -223,77 +220,66 @@ public SearchResults search(BruteForceQuery cuvsQuery) throws Throwable {
223220 try (var resourcesAccessor = cuvsQuery .getResources ().access ()) {
224221 long cuvsResources = resourcesAccessor .handle ();
225222
226- long queriesBytes = C_FLOAT_BYTE_SIZE * numQueries * vectorDimension ;
227- long neighborsBytes = C_LONG_BYTE_SIZE * numQueries * topk ;
228- long distanceBytes = C_FLOAT_BYTE_SIZE * numQueries * topk ;
229- long prefilterBytes = 0 ; // size assigned later
230-
231- MemorySegment queriesDP = allocateRMMSegment (cuvsResources , queriesBytes );
232- MemorySegment neighborsDP = allocateRMMSegment (cuvsResources , neighborsBytes );
233- MemorySegment distancesDP = allocateRMMSegment (cuvsResources , distanceBytes );
234- MemorySegment prefilterDP = MemorySegment .NULL ;
235-
236- cudaMemcpy (queriesDP , querySeg , queriesBytes , INFER_DIRECTION );
237-
238- long [] queriesShape = {numQueries , vectorDimension };
239- MemorySegment queriesTensor =
240- prepareTensor (localArena , queriesDP , queriesShape , 2 , 32 , 2 , 1 );
241- long [] neighborsShape = {numQueries , topk };
242- MemorySegment neighborsTensor =
243- prepareTensor (localArena , neighborsDP , neighborsShape , 0 , 64 , 2 , 1 );
244- long [] distancesShape = {numQueries , topk };
245- MemorySegment distancesTensor =
246- prepareTensor (localArena , distancesDP , distancesShape , 2 , 32 , 2 , 1 );
247-
248- MemorySegment prefilter = cuvsFilter .allocate (localArena );
249- MemorySegment prefilterTensor ;
250-
251- if (prefilterDataMemorySegment == MemorySegment .NULL ) {
252- cuvsFilter .type (prefilter , 0 ); // NO_FILTER
253- cuvsFilter .addr (prefilter , 0 );
254- } else {
255- long [] prefilterShape = {(prefilterDataLength + 31 ) / 32 };
256- long prefilterLen = prefilterShape [0 ];
257- prefilterBytes = C_INT_BYTE_SIZE * prefilterLen ;
258-
259- prefilterDP = allocateRMMSegment (cuvsResources , prefilterBytes );
260-
261- cudaMemcpy (prefilterDP , prefilterDataMemorySegment , prefilterBytes , HOST_TO_DEVICE );
262-
263- prefilterTensor = prepareTensor (localArena , prefilterDP , prefilterShape , 1 , 32 , 2 , 1 );
264-
265- cuvsFilter .type (prefilter , 2 );
266- cuvsFilter .addr (prefilter , prefilterTensor .address ());
267- }
268-
269- var returnValue = cuvsStreamSync (cuvsResources );
270- checkCuVSError (returnValue , "cuvsStreamSync" );
271-
272- returnValue =
273- cuvsBruteForceSearch (
274- cuvsResources ,
275- bruteForceIndexReference .indexPtr ,
276- queriesTensor ,
277- neighborsTensor ,
278- distancesTensor ,
279- prefilter );
280- checkCuVSError (returnValue , "cuvsBruteForceSearch" );
281-
282- returnValue = cuvsStreamSync (cuvsResources );
283- checkCuVSError (returnValue , "cuvsStreamSync" );
284-
285- cudaMemcpy (neighborsMemorySegment , neighborsDP , neighborsBytes , INFER_DIRECTION );
286- cudaMemcpy (distancesMemorySegment , distancesDP , distanceBytes , INFER_DIRECTION );
287-
288- returnValue = cuvsRMMFree (cuvsResources , neighborsDP , neighborsBytes );
289- checkCuVSError (returnValue , "cuvsRMMFree" );
290- returnValue = cuvsRMMFree (cuvsResources , distancesDP , distanceBytes );
291- checkCuVSError (returnValue , "cuvsRMMFree" );
292- returnValue = cuvsRMMFree (cuvsResources , queriesDP , queriesBytes );
293- checkCuVSError (returnValue , "cuvsRMMFree" );
294- if (prefilterBytes > 0 ) {
295- returnValue = cuvsRMMFree (cuvsResources , prefilterDP , prefilterBytes );
296- checkCuVSError (returnValue , "cuvsRMMFree" );
223+ final long queriesBytes = C_FLOAT_BYTE_SIZE * numQueries * vectorDimension ;
224+ final long neighborsBytes = C_LONG_BYTE_SIZE * numQueries * topk ;
225+ final long distanceBytes = C_FLOAT_BYTE_SIZE * numQueries * topk ;
226+
227+ try (var queriesDP = allocateRMMSegment (cuvsResources , queriesBytes );
228+ var neighborsDP = allocateRMMSegment (cuvsResources , neighborsBytes );
229+ var distancesDP = allocateRMMSegment (cuvsResources , distanceBytes );
230+ var prefilterDP =
231+ prefilterBytes > 0
232+ ? allocateRMMSegment (cuvsResources , prefilterBytes )
233+ : CloseableRMMAllocation .EMPTY ) {
234+
235+ cudaMemcpy (queriesDP .handle (), querySeg , queriesBytes , INFER_DIRECTION );
236+
237+ long [] queriesShape = {numQueries , vectorDimension };
238+ MemorySegment queriesTensor =
239+ prepareTensor (localArena , queriesDP .handle (), queriesShape , 2 , 32 , 2 , 1 );
240+ long [] neighborsShape = {numQueries , topk };
241+ MemorySegment neighborsTensor =
242+ prepareTensor (localArena , neighborsDP .handle (), neighborsShape , 0 , 64 , 2 , 1 );
243+ long [] distancesShape = {numQueries , topk };
244+ MemorySegment distancesTensor =
245+ prepareTensor (localArena , distancesDP .handle (), distancesShape , 2 , 32 , 2 , 1 );
246+
247+ MemorySegment prefilter = cuvsFilter .allocate (localArena );
248+ MemorySegment prefilterTensor ;
249+
250+ if (prefilterDataMemorySegment == MemorySegment .NULL ) {
251+ cuvsFilter .type (prefilter , 0 ); // NO_FILTER
252+ cuvsFilter .addr (prefilter , 0 );
253+ } else {
254+ long [] prefilterShape = {(prefilterDataLength + 31 ) / 32 };
255+ cudaMemcpy (
256+ prefilterDP .handle (), prefilterDataMemorySegment , prefilterBytes , HOST_TO_DEVICE );
257+
258+ prefilterTensor =
259+ prepareTensor (localArena , prefilterDP .handle (), prefilterShape , 1 , 32 , 2 , 1 );
260+
261+ cuvsFilter .type (prefilter , 2 );
262+ cuvsFilter .addr (prefilter , prefilterTensor .address ());
263+ }
264+
265+ var returnValue = cuvsStreamSync (cuvsResources );
266+ checkCuVSError (returnValue , "cuvsStreamSync" );
267+
268+ returnValue =
269+ cuvsBruteForceSearch (
270+ cuvsResources ,
271+ bruteForceIndexReference .indexPtr ,
272+ queriesTensor ,
273+ neighborsTensor ,
274+ distancesTensor ,
275+ prefilter );
276+ checkCuVSError (returnValue , "cuvsBruteForceSearch" );
277+
278+ returnValue = cuvsStreamSync (cuvsResources );
279+ checkCuVSError (returnValue , "cuvsStreamSync" );
280+
281+ cudaMemcpy (neighborsMemorySegment , neighborsDP .handle (), neighborsBytes , INFER_DIRECTION );
282+ cudaMemcpy (distancesMemorySegment , distancesDP .handle (), distanceBytes , INFER_DIRECTION );
297283 }
298284 }
299285 return BruteForceSearchResults .create (
@@ -479,27 +465,39 @@ public BruteForceIndexImpl build() throws Throwable {
479465 */
480466 private static class IndexReference {
481467
482- private final MemorySegment datasetPtr ;
468+ private final CloseableRMMAllocation datasetAllocationHandle ;
483469 private final long datasetBytes ;
484470 private final Arena tensorDataArena ;
485471 private final MemorySegment indexPtr ;
486472
487473 private IndexReference (
488- MemorySegment datasetPtr ,
474+ CloseableRMMAllocation datasetAllocationHandle ,
489475 long datasetBytes ,
490476 Arena tensorDataArena ,
491477 MemorySegment indexPtr ) {
492- this .datasetPtr = datasetPtr ;
478+ this .datasetAllocationHandle = datasetAllocationHandle ;
493479 this .datasetBytes = datasetBytes ;
494480 this .tensorDataArena = tensorDataArena ;
495481 this .indexPtr = indexPtr ;
496482 }
497483
498484 private IndexReference (MemorySegment indexPtr ) {
499- this .datasetPtr = MemorySegment . NULL ;
485+ this .datasetAllocationHandle = CloseableRMMAllocation . EMPTY ;
500486 this .datasetBytes = 0 ;
501487 this .tensorDataArena = null ;
502488 this .indexPtr = indexPtr ;
503489 }
490+
491+ /**
492+ * Free up the memory used for dataset, tensor-data.
493+ */
494+ private void close (CuVSResources resources ) {
495+ try (var resourcesAccessor = resources .access ()) {
496+ datasetAllocationHandle .close ();
497+ }
498+ if (tensorDataArena != null ) {
499+ tensorDataArena .close ();
500+ }
501+ }
504502 }
505503}
0 commit comments