Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,67 +30,4 @@ default CuVSHostMatrix toHost() {
toHost(hostMatrix);
return hostMatrix;
}

default CuVSDeviceMatrix toDevice(CuVSResources resources) {
return new CuVSDeviceMatrixDelegate(this);
}

class CuVSDeviceMatrixDelegate implements CuVSDeviceMatrix {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These got moved to internal implementation classes, for better encapsulation and to use the internal interface type.


private final CuVSDeviceMatrix deviceMatrix;

private CuVSDeviceMatrixDelegate(CuVSDeviceMatrix deviceMatrix) {
this.deviceMatrix = deviceMatrix;
}

@Override
public long size() {
return deviceMatrix.size();
}

@Override
public long columns() {
return deviceMatrix.columns();
}

@Override
public DataType dataType() {
return deviceMatrix.dataType();
}

@Override
public RowView getRow(long row) {
return deviceMatrix.getRow(row);
}

@Override
public void toArray(int[][] array) {
deviceMatrix.toArray(array);
}

@Override
public void toArray(float[][] array) {
deviceMatrix.toArray(array);
}

@Override
public void toArray(byte[][] array) {
deviceMatrix.toArray(array);
}

@Override
public void toHost(CuVSHostMatrix hostMatrix) {
deviceMatrix.toHost(hostMatrix);
}

@Override
public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
this.deviceMatrix.toDevice(deviceMatrix, cuVSResources);
}

@Override
public void close() {
// Do nothing
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,76 +21,9 @@
public interface CuVSHostMatrix extends CuVSMatrix {
int get(int row, int col);

default CuVSHostMatrix toHost() {
return new CuVSHostMatrixDelegate(this);
}

default CuVSDeviceMatrix toDevice(CuVSResources resources) {
var deviceMatrix = CuVSMatrix.deviceBuilder(resources, size(), columns(), dataType()).build();
toDevice(deviceMatrix, resources);
return deviceMatrix;
}

class CuVSHostMatrixDelegate implements CuVSHostMatrix {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

private final CuVSHostMatrix hostMatrix;

public CuVSHostMatrixDelegate(CuVSHostMatrix cuVSHostMatrix) {
this.hostMatrix = cuVSHostMatrix;
}

@Override
public int get(int row, int col) {
return hostMatrix.get(row, col);
}

@Override
public long size() {
return hostMatrix.size();
}

@Override
public long columns() {
return hostMatrix.columns();
}

@Override
public DataType dataType() {
return hostMatrix.dataType();
}

@Override
public RowView getRow(long row) {
return hostMatrix.getRow(row);
}

@Override
public void toArray(int[][] array) {
hostMatrix.toArray(array);
}

@Override
public void toArray(float[][] array) {
hostMatrix.toArray(array);
}

@Override
public void toArray(byte[][] array) {
hostMatrix.toArray(array);
}

@Override
public void toHost(CuVSHostMatrix hostMatrix) {
this.hostMatrix.toHost(hostMatrix);
}

@Override
public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
hostMatrix.toDevice(deviceMatrix, cuVSResources);
}

@Override
public void close() {
// Do nothing
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ private BruteForceIndexImpl(
Objects.requireNonNull(dataset);
try (dataset) {
this.resources = resources;
assert dataset instanceof CuVSMatrixBaseImpl;
this.bruteForceIndexReference = build((CuVSMatrixBaseImpl) dataset, bruteForceIndexParams);
assert dataset instanceof CuVSMatrixInternal;
this.bruteForceIndexReference = build((CuVSMatrixInternal) dataset, bruteForceIndexParams);
}
}

Expand Down Expand Up @@ -124,7 +124,7 @@ public void close() {
* index
*/
private IndexReference build(
CuVSMatrixBaseImpl dataset, BruteForceIndexParams bruteForceIndexParams) {
CuVSMatrixInternal dataset, BruteForceIndexParams bruteForceIndexParams) {
long rows = dataset.size();
long cols = dataset.columns();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import com.nvidia.cuvs.internal.common.CloseableRMMAllocation;
import com.nvidia.cuvs.internal.common.CompositeCloseableHandle;
import com.nvidia.cuvs.internal.panama.*;

import java.io.FileInputStream;
import java.io.InputStream;
import java.io.OutputStream;
Expand Down Expand Up @@ -78,8 +77,8 @@ private CagraIndexImpl(
CagraIndexParams indexParameters, CuVSMatrix dataset, CuVSResources resources) {
Objects.requireNonNull(dataset);
this.resources = resources;
assert dataset instanceof CuVSMatrixBaseImpl;
this.cagraIndexReference = build(indexParameters, (CuVSMatrixBaseImpl) dataset);
assert dataset instanceof CuVSMatrixInternal;
this.cagraIndexReference = build(indexParameters, (CuVSMatrixInternal) dataset);
}

/**
Expand Down Expand Up @@ -124,11 +123,11 @@ private CagraIndexImpl(

this.resources = resources;

assert graph instanceof CuVSMatrixBaseImpl;
assert dataset instanceof CuVSMatrixBaseImpl;
assert graph instanceof CuVSMatrixInternal;
assert dataset instanceof CuVSMatrixInternal;

this.cagraIndexReference =
fromGraph(metric, (CuVSMatrixBaseImpl) graph, (CuVSMatrixBaseImpl) dataset);
fromGraph(metric, (CuVSMatrixInternal) graph, (CuVSMatrixInternal) dataset);
}

private void checkNotDestroyed() {
Expand Down Expand Up @@ -161,7 +160,7 @@ public void close() {
* @return an instance of {@link IndexReference} that holds the pointer to the
* index
*/
private IndexReference build(CagraIndexParams indexParameters, CuVSMatrixBaseImpl dataset) {
private IndexReference build(CagraIndexParams indexParameters, CuVSMatrixInternal dataset) {
long rows = dataset.size();

try (var indexParams = segmentFromIndexParams(indexParameters);
Expand Down Expand Up @@ -410,8 +409,8 @@ public CuVSDeviceMatrix getGraph() {

private IndexReference fromGraph(
CagraIndexParams.CuvsDistanceType metric,
CuVSMatrixBaseImpl graph,
CuVSMatrixBaseImpl dataset) {
CuVSMatrixInternal graph,
CuVSMatrixInternal dataset) {
try (var localArena = Arena.ofConfined()) {
var index = createCagraIndex();
try (var resourcesAccess = resources.access()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ public void toHost(CuVSHostMatrix hostMatrix) {
throw new IllegalArgumentException("[hostMatrix] must have the same dataType");
}
try (var localArena = Arena.ofConfined()) {
var hostMatrixTensor = ((CuVSHostMatrixImpl) hostMatrix).toTensor(localArena);
var hostMatrixTensor = ((CuVSMatrixInternal) hostMatrix).toTensor(localArena);

try (var resourceAccess = resources.access()) {
var cuvsRes = resourceAccess.handle();
Expand All @@ -236,9 +236,14 @@ public void toHost(CuVSHostMatrix hostMatrix) {
}
}

@Override
public CuVSDeviceMatrix toDevice(CuVSResources resources) {
return new CuVSDeviceMatrixDelegate(this);
}

@Override
public void toDevice(CuVSDeviceMatrix targetMatrix, CuVSResources cuVSResources) {
copyMatrix(this, (CuVSMatrixBaseImpl) targetMatrix, cuVSResources);
copyMatrix(this, (CuVSMatrixInternal) targetMatrix, cuVSResources);
}

@Override
Expand All @@ -248,4 +253,92 @@ public void close() {
hostBuffer = MemorySegment.NULL;
}
}

private static class CuVSDeviceMatrixDelegate implements CuVSDeviceMatrix, CuVSMatrixInternal {
private final CuVSDeviceMatrixImpl deviceMatrix;

private CuVSDeviceMatrixDelegate(CuVSDeviceMatrixImpl deviceMatrix) {
this.deviceMatrix = deviceMatrix;
}

@Override
public long size() {
return deviceMatrix.size();
}

@Override
public long columns() {
return deviceMatrix.columns();
}

@Override
public DataType dataType() {
return deviceMatrix.dataType();
}

@Override
public RowView getRow(long row) {
return deviceMatrix.getRow(row);
}

@Override
public void toArray(int[][] array) {
deviceMatrix.toArray(array);
}

@Override
public void toArray(float[][] array) {
deviceMatrix.toArray(array);
}

@Override
public void toArray(byte[][] array) {
deviceMatrix.toArray(array);
}

@Override
public void toHost(CuVSHostMatrix hostMatrix) {
deviceMatrix.toHost(hostMatrix);
}

@Override
public void toDevice(CuVSDeviceMatrix deviceMatrix, CuVSResources cuVSResources) {
deviceMatrix.toDevice(deviceMatrix, cuVSResources);
}

@Override
public CuVSDeviceMatrix toDevice(CuVSResources cuVSResources) {
return this;
}

@Override
public MemorySegment memorySegment() {
return deviceMatrix.memorySegment();
}

@Override
public ValueLayout valueLayout() {
return deviceMatrix.valueLayout();
}

@Override
public int bits() {
return deviceMatrix.bits();
}

@Override
public int code() {
return 0;
}

@Override
public MemorySegment toTensor(Arena arena) {
return deviceMatrix.toTensor(arena);
}

@Override
public void close() {
// Do nothing
}
}
}
Loading