1818import static com .nvidia .cuvs .internal .common .Util .*;
1919import static com .nvidia .cuvs .internal .panama .headers_h .cuvsVersionGet ;
2020import static com .nvidia .cuvs .internal .panama .headers_h .uint16_t ;
21+ import static com .nvidia .cuvs .internal .panama .headers_h_1 .cudaStreamSynchronize ;
2122
2223import com .nvidia .cuvs .*;
2324import com .nvidia .cuvs .internal .*;
25+ import com .nvidia .cuvs .internal .common .PinnedMemoryBuffer ;
2426import com .nvidia .cuvs .internal .common .Util ;
2527import java .io .IOException ;
2628import java .lang .foreign .Arena ;
@@ -216,7 +218,7 @@ public CuVSHostMatrix build() {
216218 public CuVSMatrix .Builder <CuVSDeviceMatrix > newDeviceMatrixBuilder (
217219 CuVSResources resources , long size , long columns , CuVSMatrix .DataType dataType )
218220 throws UnsupportedOperationException {
219- return new HeapSegmentBuilder (resources , size , columns , dataType );
221+ return new BufferedSegmentBuilder (resources , size , columns , dataType );
220222 }
221223
222224 @ Override
@@ -227,7 +229,7 @@ public CuVSMatrix.Builder<CuVSDeviceMatrix> newDeviceMatrixBuilder(
227229 int rowStride ,
228230 int columnStride ,
229231 CuVSMatrix .DataType dataType ) {
230- return new HeapSegmentBuilder (resources , size , columns , rowStride , columnStride , dataType );
232+ return new BufferedSegmentBuilder (resources , size , columns , rowStride , columnStride , dataType );
231233 }
232234
233235 @ Override
@@ -279,28 +281,38 @@ public CuVSMatrix newMatrixFromArray(byte[][] vectors) {
279281
/**
 * This {@link CuVSDeviceMatrix} builder implementation returns a {@link CuVSDeviceMatrix} backed by managed RMM
 * device memory. It uses a {@link PinnedMemoryBuffer} to stage rows in pinned host memory and batch them into
 * fewer, larger host-to-device copies, instead of issuing one {@code cudaMemcpyAsync} per row.
 */
287- private static class HeapSegmentBuilder implements CuVSMatrix .Builder <CuVSDeviceMatrix > {
286+ private static class BufferedSegmentBuilder implements CuVSMatrix .Builder <CuVSDeviceMatrix > {
287+
288288 private final long columns ;
289289 private final long size ;
290290 private final CuVSDeviceMatrixImpl matrix ;
291291 private final MemorySegment stream ;
292- private int current ;
293292
294- private HeapSegmentBuilder (
293+ private final long rowBytes ;
294+ private int currentRow ;
295+
296+ private final PinnedMemoryBuffer hostBuffer ;
297+ private final long bufferRowCount ;
298+ private int currentBufferRow ;
299+
/**
 * Creates a builder for a {@code size} x {@code columns} device matrix with default strides.
 * Allocates the RMM-backed device matrix up front and a pinned host staging buffer sized by
 * {@link PinnedMemoryBuffer}.
 *
 * @param resources the CuVS resources owning the CUDA stream
 * @param size      number of rows
 * @param columns   number of columns
 * @param dataType  element type of the matrix
 * @throws IllegalStateException if the pinned buffer cannot hold even a single row — in that
 *     case {@code internalAddVector} would never trigger a flush ({@code currentBufferRow}
 *     can never equal a {@code bufferRowCount} of 0 after being incremented) and would write
 *     past the end of the pinned allocation
 */
private BufferedSegmentBuilder(
    CuVSResources resources, long size, long columns, CuVSMatrix.DataType dataType) {
  this.columns = columns;
  this.size = size;
  this.matrix = CuVSDeviceMatrixRMMImpl.create(resources, size, columns, dataType);
  this.stream = Util.getStream(resources);
  this.currentRow = 0;

  this.hostBuffer = new PinnedMemoryBuffer(size, columns, matrix.valueLayout());

  this.rowBytes = columns * matrix.valueLayout().byteSize();
  this.bufferRowCount = Math.min((hostBuffer.size() / rowBytes), size);
  if (size > 0 && this.bufferRowCount == 0) {
    // Fail fast instead of silently overflowing the staging buffer on the first addVector.
    hostBuffer.close();
    throw new IllegalStateException(
        "pinned staging buffer (" + hostBuffer.size() + " bytes) is smaller than one row (" + rowBytes + " bytes)");
  }
  this.currentBufferRow = 0;
}
302314
303- private HeapSegmentBuilder (
315+ private BufferedSegmentBuilder (
304316 CuVSResources resources ,
305317 long size ,
306318 long columns ,
@@ -313,7 +325,13 @@ private HeapSegmentBuilder(
313325 CuVSDeviceMatrixRMMImpl .create (
314326 resources , size , columns , rowStride , columnStride , dataType );
315327 this .stream = Util .getStream (resources );
316- this .current = 0 ;
328+ this .currentRow = 0 ;
329+
330+ this .hostBuffer = new PinnedMemoryBuffer (size , columns , matrix .valueLayout ());
331+
332+ this .rowBytes = columns * matrix .valueLayout ().byteSize ();
333+ this .bufferRowCount = Math .min ((hostBuffer .size () / rowBytes ), size );
334+ this .currentBufferRow = 0 ;
317335 }
318336
319337 @ Override
@@ -347,19 +365,38 @@ public void addVector(int[] vector) {
347365 }
348366
349367 private void internalAddVector (MemorySegment vector ) {
350- if (current >= size ) {
368+ if (currentRow >= size ) {
351369 throw new ArrayIndexOutOfBoundsException ();
352370 }
371+ var hostBufferOffset = currentBufferRow * rowBytes ;
372+ MemorySegment .copy (vector , 0 , hostBuffer .address (), hostBufferOffset , rowBytes );
353373
354- long rowBytes = columns * matrix .valueLayout ().byteSize ();
374+ currentRow ++;
375+ currentBufferRow ++;
376+ if (currentBufferRow == bufferRowCount ) {
377+ flushBuffer ();
378+ }
379+ }
355380
356- var dstOffset = ((current ++) * rowBytes );
357- var dst = matrix .memorySegment ().asSlice (dstOffset );
358- cudaMemcpyAsync (dst , vector , rowBytes , CudaMemcpyKind .HOST_TO_DEVICE , stream );
381+ private void flushBuffer () {
382+ if (currentBufferRow > 0 ) {
383+ var deviceMemoryOffset = (currentRow - currentBufferRow ) * rowBytes ;
384+ var dst = matrix .memorySegment ().asSlice (deviceMemoryOffset );
385+ cudaMemcpyAsync (
386+ dst ,
387+ hostBuffer .address (),
388+ currentBufferRow * rowBytes ,
389+ CudaMemcpyKind .HOST_TO_DEVICE ,
390+ stream );
391+ currentBufferRow = 0 ;
392+ checkCudaError (cudaStreamSynchronize (stream ), "cudaStreamSynchronize" );
393+ }
359394 }
360395
361396 @ Override
362397 public CuVSDeviceMatrix build () {
398+ flushBuffer ();
399+ hostBuffer .close ();
363400 return matrix ;
364401 }
365402 }
0 commit comments