Skip to content

Commit d5edf12

Browse files
authored
[Java] Uniform toHost/toDevice to work across all CuVSMatrix classes (#1328)
This PR moves the `toHost`/`toDevice` operations to `CuVSMatrix` in order to make them uniform; implementations then do not need to worry about or perform checks on the input type, but can simply accept any `CuVSMatrix` and then "request" the type they want or need (no ifs, custom cudaAlloc or cudaMemcpy operations needed): ``` CuVSDeviceMatrix matrix = input.toDevice(resources); ``` If the source matrix is a `CuVSHostMatrix`, device memory will be allocated and data will be copied; if the source matrix is already a `CuVSDeviceMatrix`, a (wrapped) reference to the original data is returned (very fast, no copies, no additional allocations). The wrapper is currently implemented to follow C++ `weak_ptr` semantics: it delegates everything to the wrapped matrix; the only difference is in `close()` (which is a no-op in the wrapper, as you'd expect). The lifecycle of the data follows that of the original matrix; the "weak" delegate does not influence it in any way. This way the caller can handle it uniformly (e.g. with try-with-resources), without side effects on the original matrix (whose lifecycle is already handled by its owner). An alternative would be to give this C++ `shared_ptr` semantics, making it reference counted. I think this would be more complex and not necessary, but let me know. Authors: - Lorenzo Dematté (https://github.com/ldematte) - MithunR (https://github.com/mythrocks) Approvers: - MithunR (https://github.com/mythrocks) URL: #1328
1 parent b3a4fd8 commit d5edf12

7 files changed

Lines changed: 317 additions & 14 deletions

File tree

java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSDeviceMatrix.java

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,7 @@
2121
public interface CuVSDeviceMatrix extends CuVSMatrix {
2222

2323
/**
24-
* Fills the provided, pre-allocated host matrix with data from this device matrix.
25-
* The content of the provided host matrix will be overwritten; the 2 matrices must have the
26-
* same element type and dimension.
27-
*
28-
* @param hostMatrix the host-memory-backed matrix to fill.
29-
*/
30-
void toHost(CuVSHostMatrix hostMatrix);
31-
32-
/**
33-
* Returns a new, host matrix with data from this device matrix.
24+
* Returns a new host matrix with data from this device matrix.
3425
* The returned host matrix will need to be managed by the caller, which will be
3526
* responsible to call {@link CuVSMatrix#close()} to free its resources when done.
3627
*/
@@ -39,4 +30,67 @@ default CuVSHostMatrix toHost() {
3930
toHost(hostMatrix);
4031
return hostMatrix;
4132
}
33+
34+
default CuVSDeviceMatrix toDevice(CuVSResources resources) {
35+
return new CuVSDeviceMatrixDelegate(this);
36+
}
37+
38+
class CuVSDeviceMatrixDelegate implements CuVSDeviceMatrix {
39+
40+
private final CuVSDeviceMatrix deviceMatrix;
41+
42+
private CuVSDeviceMatrixDelegate(CuVSDeviceMatrix deviceMatrix) {
43+
this.deviceMatrix = deviceMatrix;
44+
}
45+
46+
@Override
47+
public long size() {
48+
return deviceMatrix.size();
49+
}
50+
51+
@Override
52+
public long columns() {
53+
return deviceMatrix.columns();
54+
}
55+
56+
@Override
57+
public DataType dataType() {
58+
return deviceMatrix.dataType();
59+
}
60+
61+
@Override
62+
public RowView getRow(long row) {
63+
return deviceMatrix.getRow(row);
64+
}
65+
66+
@Override
67+
public void toArray(int[][] array) {
68+
deviceMatrix.toArray(array);
69+
}
70+
71+
@Override
72+
public void toArray(float[][] array) {
73+
deviceMatrix.toArray(array);
74+
}
75+
76+
@Override
77+
public void toArray(byte[][] array) {
78+
deviceMatrix.toArray(array);
79+
}
80+
81+
@Override
82+
public void toHost(CuVSHostMatrix hostMatrix) {
83+
deviceMatrix.toHost(hostMatrix);
84+
}
85+
86+
@Override
87+
public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
88+
this.deviceMatrix.toDevice(deviceMatrix, cuVSResources);
89+
}
90+
91+
@Override
92+
public void close() {
93+
// No-op: this delegate is a "weak" reference; the wrapped device matrix is closed by its owner.
94+
}
95+
}
4296
}

java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSHostMatrix.java

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,77 @@
2020
*/
2121
public interface CuVSHostMatrix extends CuVSMatrix {
2222
int get(int row, int col);
23+
24+
default CuVSHostMatrix toHost() {
25+
return new CuVSHostMatrixDelegate(this);
26+
}
27+
28+
default CuVSDeviceMatrix toDevice(CuVSResources resources) {
29+
var deviceMatrix = CuVSMatrix.deviceBuilder(resources, size(), columns(), dataType()).build();
30+
toDevice(deviceMatrix, resources);
31+
return deviceMatrix;
32+
}
33+
34+
class CuVSHostMatrixDelegate implements CuVSHostMatrix {
35+
private final CuVSHostMatrix hostMatrix;
36+
37+
public CuVSHostMatrixDelegate(CuVSHostMatrix cuVSHostMatrix) {
38+
this.hostMatrix = cuVSHostMatrix;
39+
}
40+
41+
@Override
42+
public int get(int row, int col) {
43+
return hostMatrix.get(row, col);
44+
}
45+
46+
@Override
47+
public long size() {
48+
return hostMatrix.size();
49+
}
50+
51+
@Override
52+
public long columns() {
53+
return hostMatrix.columns();
54+
}
55+
56+
@Override
57+
public DataType dataType() {
58+
return hostMatrix.dataType();
59+
}
60+
61+
@Override
62+
public RowView getRow(long row) {
63+
return hostMatrix.getRow(row);
64+
}
65+
66+
@Override
67+
public void toArray(int[][] array) {
68+
hostMatrix.toArray(array);
69+
}
70+
71+
@Override
72+
public void toArray(float[][] array) {
73+
hostMatrix.toArray(array);
74+
}
75+
76+
@Override
77+
public void toArray(byte[][] array) {
78+
hostMatrix.toArray(array);
79+
}
80+
81+
@Override
82+
public void toHost(CuVSHostMatrix hostMatrix) {
83+
this.hostMatrix.toHost(hostMatrix);
84+
}
85+
86+
@Override
87+
public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
88+
hostMatrix.toDevice(deviceMatrix, cuVSResources);
89+
}
90+
91+
@Override
92+
public void close() {
93+
// No-op: this delegate is a "weak" reference; the wrapped host matrix is closed by its owner.
94+
}
95+
}
2396
}

java/cuvs-java/src/main/java/com/nvidia/cuvs/CuVSMatrix.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,42 @@ static Builder<CuVSDeviceMatrix> deviceBuilder(
177177
*/
178178
void toArray(byte[][] array);
179179

180+
/**
181+
* Fills the provided, pre-allocated host matrix with data from this matrix.
182+
* The content of the provided host matrix will be overwritten; the 2 matrices must have the
183+
* same element type and dimension.
184+
*
185+
* @param hostMatrix the host-memory-backed matrix to fill.
186+
*/
187+
void toHost(CuVSHostMatrix hostMatrix);
188+
189+
/**
190+
* Returns a host matrix; if the matrix is already a host matrix, a "weak" reference to the same host memory
191+
* is returned. If the matrix is a device matrix, a newly allocated matrix will be populated with data from
192+
* the device matrix.
193+
* The returned host matrix will need to be managed by the caller, which will be
194+
* responsible to call {@link CuVSMatrix#close()} to free its resources when done.
195+
*/
196+
CuVSHostMatrix toHost();
197+
198+
/**
199+
* Fills the provided, pre-allocated device matrix with data from this matrix.
200+
* The content of the provided device matrix will be overwritten; the 2 matrices must have the
201+
* same element type and dimension.
202+
*
203+
* @param deviceMatrix the device-memory-backed matrix to fill.
204+
*/
205+
void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources);
206+
207+
/**
208+
* Returns a device matrix; if this matrix is already a device matrix, a "weak" reference to the same device memory
209+
* is returned. If the matrix is a host matrix, a newly allocated matrix will be populated with data from
210+
* the host matrix.
211+
* The returned device matrix will need to be managed by the caller, which will be
212+
* responsible to call {@link CuVSMatrix#close()} to free its resources when done.
213+
*/
214+
CuVSDeviceMatrix toDevice(CuVSResources cuVSResources);
215+
180216
@Override
181217
void close();
182218
}

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSDeviceMatrixImpl.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,11 @@ public void toHost(CuVSHostMatrix hostMatrix) {
236236
}
237237
}
238238

239+
@Override
240+
public void toDevice(CuVSDeviceMatrix targetMatrix, CuVSResources cuVSResources) {
241+
copyMatrix(this, (CuVSMatrixBaseImpl) targetMatrix, cuVSResources);
242+
}
243+
239244
@Override
240245
public void close() {
241246
if (hostBuffer != MemorySegment.NULL) {

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSHostMatrixImpl.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
import static com.nvidia.cuvs.internal.common.Util.prepareTensor;
1919
import static com.nvidia.cuvs.internal.panama.headers_h.*;
2020

21+
import com.nvidia.cuvs.CuVSDeviceMatrix;
2122
import com.nvidia.cuvs.CuVSHostMatrix;
23+
import com.nvidia.cuvs.CuVSResources;
2224
import com.nvidia.cuvs.RowView;
2325
import java.lang.foreign.*;
2426
import java.lang.invoke.VarHandle;
@@ -142,6 +144,19 @@ public void toArray(byte[][] array) {
142144
}
143145
}
144146

147+
@Override
148+
public void toHost(CuVSHostMatrix hostMatrix) {
149+
var targetMatrix = (CuVSHostMatrixImpl) hostMatrix;
150+
var valueByteSize = valueLayout.byteSize();
151+
MemorySegment.copy(
152+
this.memorySegment, 0L, targetMatrix.memorySegment, 0L, size * columns * valueByteSize);
153+
}
154+
155+
@Override
156+
public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
157+
copyMatrix(this, (CuVSMatrixBaseImpl) deviceMatrix, cuVSResources);
158+
}
159+
145160
@Override
146161
public void close() {}
147162

java/cuvs-java/src/main/java22/com/nvidia/cuvs/internal/CuVSMatrixBaseImpl.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_CHAR;
1919
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_FLOAT;
2020
import static com.nvidia.cuvs.internal.common.LinkerHelper.C_INT;
21+
import static com.nvidia.cuvs.internal.common.Util.checkCuVSError;
2122
import static com.nvidia.cuvs.internal.panama.headers_h.*;
2223

2324
import com.nvidia.cuvs.CuVSMatrix;
@@ -49,6 +50,29 @@ protected CuVSMatrixBaseImpl(
4950
this.columns = columns;
5051
}
5152

53+
protected static void copyMatrix(
54+
CuVSMatrixBaseImpl sourceMatrix, CuVSMatrixBaseImpl targetMatrix, CuVSResources resources) {
55+
if (targetMatrix.columns() != sourceMatrix.columns
56+
|| targetMatrix.size() != sourceMatrix.size) {
57+
throw new IllegalArgumentException(
58+
"Source and target matrices must have the same dimensions");
59+
}
60+
if (targetMatrix.dataType() != sourceMatrix.dataType) {
61+
throw new IllegalArgumentException("Source and target matrices must have the same dataType");
62+
}
63+
64+
try (var localArena = Arena.ofConfined()) {
65+
var targetTensor = targetMatrix.toTensor(localArena);
66+
67+
try (var resourceAccess = resources.access()) {
68+
var cuvsRes = resourceAccess.handle();
69+
var sourceTensor = sourceMatrix.toTensor(localArena);
70+
checkCuVSError(cuvsMatrixCopy(cuvsRes, sourceTensor, targetTensor), "cuvsMatrixCopy");
71+
checkCuVSError(cuvsStreamSync(cuvsRes), "cuvsStreamSync");
72+
}
73+
}
74+
}
75+
5276
@Override
5377
public long size() {
5478
return size;

0 commit comments

Comments
 (0)