Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import com.nvidia.cuvs.internal.common.CloseableRMMAllocation;
import com.nvidia.cuvs.internal.common.CompositeCloseableHandle;
import com.nvidia.cuvs.internal.panama.*;

import java.io.FileInputStream;
import java.io.InputStream;
import java.io.OutputStream;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ private boolean mustClose() {
public void close() {
if (mustClose()) {
checkCuVSError(cuvsRMMFree(cuvsResourceHandle, pointer, numBytes), "cuvsRMMFree");
pointer = MemorySegment.NULL;
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a small enhancement that came up while testing: it protects us from a double close().
I thought it was worth adding.

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@
package com.nvidia.cuvs;

import static com.carrotsearch.randomizedtesting.RandomizedTest.assumeTrue;
import static com.nvidia.cuvs.CuVSMatrixIT.assertSame2dArray;
import static org.junit.Assert.*;

import com.carrotsearch.randomizedtesting.RandomizedRunner;
import com.nvidia.cuvs.CagraIndexParams.CagraGraphBuildAlgo;
import com.nvidia.cuvs.CagraIndexParams.CuvsDistanceType;
import com.nvidia.cuvs.CagraMergeParams.MergeStrategy;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.lang.foreign.Arena;
import java.lang.foreign.Linker;
Expand Down Expand Up @@ -397,12 +395,12 @@ public void testPrefilteringReducesResults() throws Throwable {
.withMetric(CuvsDistanceType.L2Expanded)
.build();

try (CuVSResources resources = CheckedCuVSResources.create()) {
CagraIndex index =
CagraIndex.newBuilder(resources)
.withDataset(dataset)
.withIndexParams(indexParams)
.build();
try (CuVSResources resources = CheckedCuVSResources.create();
CagraIndex index =
CagraIndex.newBuilder(resources)
.withDataset(dataset)
.withIndexParams(indexParams)
.build()) {

// No prefilter (all points allowed)
CagraSearchParams searchParams = new CagraSearchParams.Builder().build();
Expand Down Expand Up @@ -514,6 +512,10 @@ private void cleanup(CagraIndex index, CagraIndex loadedIndex) throws Throwable
loadedIndex.close();
}

/**
* Tests that an index built starting from a native MemorySegment is identical to one built from
* Java heap arrays
*/
@Test
public void testNativeDatasetEquivalent() throws Throwable {
float[][] sampleData = createSampleData();
Expand Down Expand Up @@ -555,6 +557,37 @@ public void testNativeDatasetEquivalent() throws Throwable {
}
}

/**
 * Tests that an index built starting from device memory ({@link CuVSDeviceMatrix}) is identical
 * to one built from Java heap arrays.
 *
 * @throws Throwable if index construction or comparison fails
 */
@Test
public void testDeviceDatasetEquivalent() throws Throwable {
  float[][] sampleData = createSampleData();

  try (var resources = CuVSResources.create();
      var javaDataset = CuVSMatrix.ofArray(sampleData);
      var deviceDataset = javaDataset.toDevice(resources);
      // Manage the indexes with try-with-resources too, so their native memory is
      // released even when an assertion below fails (matches the rest of this suite).
      var javaIndex = indexOnce(javaDataset, resources);
      var deviceIndex = indexOnce(deviceDataset, resources)) {

    // Indexing with on-heap and device datasets must produce identically shaped graphs...
    int size = (int) javaIndex.getGraph().size();
    assertEquals(size, (int) deviceIndex.getGraph().size());

    int columns = (int) javaIndex.getGraph().columns();
    assertEquals(columns, (int) deviceIndex.getGraph().columns());

    var javaIndexGraph = new int[size][columns];
    var deviceIndexGraph = new int[size][columns];
    javaIndex.getGraph().toArray(javaIndexGraph);
    deviceIndex.getGraph().toArray(deviceIndexGraph);

    // ...and identical graph contents.
    assertSame2dArray(size, columns, javaIndexGraph, deviceIndexGraph);
  }
}

@Test
public void testMergingIndexes() throws Throwable {
float[][] vector1 = {
Expand Down Expand Up @@ -626,22 +659,23 @@ public void testMergingIndexes() throws Throwable {

// --- Serialization/deserialization check ---
String indexFileName = UUID.randomUUID() + ".cag";
mergedIndex.serialize(new FileOutputStream(indexFileName));
var indexFile = Path.of(indexFileName);

File indexFile = new File(indexFileName);
InputStream inputStream = new FileInputStream(indexFile);
CagraIndex loadedMergedIndex = CagraIndex.newBuilder(resources).from(inputStream).build();
try (var out = Files.newOutputStream(indexFile)) {
mergedIndex.serialize(out);
}

SearchResults resultsFromLoaded = loadedMergedIndex.search(query);
assertEquals(expectedResults, resultsFromLoaded.getResults());
try (InputStream inputStream = Files.newInputStream(indexFile)) {
CagraIndex loadedMergedIndex = CagraIndex.newBuilder(resources).from(inputStream).build();

if (indexFile.exists()) {
indexFile.delete();
SearchResults resultsFromLoaded = loadedMergedIndex.search(query);
assertEquals(expectedResults, resultsFromLoaded.getResults());
mergedIndex.close();
loadedMergedIndex.close();
}
Files.deleteIfExists(indexFile);
index1.close();
index2.close();
mergedIndex.close();
loadedMergedIndex.close();
}
}

Expand Down Expand Up @@ -710,50 +744,51 @@ public void testMergeStrategies() throws Throwable {
.build();

log.trace("Merging indexes with PHYSICAL strategy...");
CagraIndex physicalMergedIndex =
CagraIndex.merge(new CagraIndex[] {index1, index2}, physicalMergeParams);
log.trace("Physical merge completed successfully");
try (CagraIndex physicalMergedIndex =
CagraIndex.merge(new CagraIndex[] {index1, index2}, physicalMergeParams)) {
log.trace("Physical merge completed successfully");

CagraSearchParams searchParams = new CagraSearchParams.Builder().build();

CagraQuery query =
new CagraQuery.Builder(resources)
.withTopK(3)
.withSearchParams(searchParams)
.withQueryVectors(queries)
.withMapping(SearchResults.IDENTITY_MAPPING)
.build();

log.trace("Searching physically merged index...");
SearchResults physicalResults = physicalMergedIndex.search(query);
assertNotNull("Physical merge search results should not be null", physicalResults);
assertEquals(
"Physical merge search results should match expected",
expectedResults,
physicalResults.getResults());

CagraSearchParams searchParams = new CagraSearchParams.Builder().build();
// --- Serialization/deserialization check for both merged indexes ---
String physicalIndexFileName = UUID.randomUUID() + ".cag";
var physicalIndexFile = Path.of(physicalIndexFileName);

CagraQuery query =
new CagraQuery.Builder(resources)
.withTopK(3)
.withSearchParams(searchParams)
.withQueryVectors(queries)
.withMapping(SearchResults.IDENTITY_MAPPING)
.build();
try (var out = Files.newOutputStream(physicalIndexFile)) {
physicalMergedIndex.serialize(out);
}

log.trace("Searching physically merged index...");
SearchResults physicalResults = physicalMergedIndex.search(query);
assertNotNull("Physical merge search results should not be null", physicalResults);
assertEquals(
"Physical merge search results should match expected",
expectedResults,
physicalResults.getResults());

// --- Serialization/deserialization check for both merged indexes ---
String physicalIndexFileName = UUID.randomUUID().toString() + ".cag";
physicalMergedIndex.serialize(new FileOutputStream(physicalIndexFileName));

File physicalIndexFile = new File(physicalIndexFileName);
InputStream physicalInputStream = new FileInputStream(physicalIndexFile);
CagraIndex loadedPhysicalIndex =
CagraIndex.newBuilder(resources).from(physicalInputStream).build();

SearchResults resultsFromLoadedPhysical = loadedPhysicalIndex.search(query);
assertEquals(
"Loaded physical index search results should match expected",
expectedResults,
resultsFromLoadedPhysical.getResults());

if (physicalIndexFile.exists()) {
physicalIndexFile.delete();
try (InputStream physicalInputStream = Files.newInputStream(physicalIndexFile);
CagraIndex loadedPhysicalIndex =
CagraIndex.newBuilder(resources).from(physicalInputStream).build()) {

Files.deleteIfExists(physicalIndexFile);

SearchResults resultsFromLoadedPhysical = loadedPhysicalIndex.search(query);
assertEquals(
"Loaded physical index search results should match expected",
expectedResults,
resultsFromLoadedPhysical.getResults());
}
}
index1.close();
index2.close();
physicalMergedIndex.close();
loadedPhysicalIndex.close();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,26 @@ public void setup() {
log.trace("Random context initialized for test.");
}

/** Where a test dataset's backing memory lives; drives which {@code CuVSMatrix} builder is used. */
enum TestDatasetMemoryKind {
/** Plain Java on-heap {@code float[][]} arrays. */
HEAP,
/** Native host memory (built via {@code CuVSMatrix.hostBuilder}). */
NATIVE,
/** GPU device memory (built via {@code CuVSMatrix.deviceBuilder}). */
DEVICE
}

@Test
public void testResultsTopKWithRandomValues() throws Throwable {
boolean[] useNativeMemoryDatasets = {true, false};
TestDatasetMemoryKind[] testDatasetMemoryKinds = {
TestDatasetMemoryKind.HEAP, TestDatasetMemoryKind.NATIVE, TestDatasetMemoryKind.DEVICE
};
for (int i = 0; i < 100; i++) {
for (boolean use : useNativeMemoryDatasets) {
tmpResultsTopKWithRandomValues(use);
for (var datasetMemoryKind : testDatasetMemoryKinds) {
tmpResultsTopKWithRandomValues(datasetMemoryKind);
}
}
}

private void tmpResultsTopKWithRandomValues(boolean useNativeMemoryDataset) throws Throwable {
private void tmpResultsTopKWithRandomValues(TestDatasetMemoryKind datasetMemoryKind)
throws Throwable {
int DATASET_SIZE_LIMIT = 10_000;
int DIMENSIONS_LIMIT = 2048;
int NUM_QUERIES_LIMIT = 10;
Expand Down Expand Up @@ -90,7 +99,7 @@ private void tmpResultsTopKWithRandomValues(boolean useNativeMemoryDataset) thro
log.debug("Dataset size: {}x{}", datasetSize, dimensions);
log.debug("Query size: {}x{}", numQueries, dimensions);
log.debug("TopK: {}", topK);
log.debug("Use native memory dataset? " + useNativeMemoryDataset);
log.debug("Use memory dataset: " + datasetMemoryKind.name());

// Debugging: Log dataset and queries
if (log.isDebugEnabled()) {
Expand Down Expand Up @@ -119,8 +128,8 @@ private void tmpResultsTopKWithRandomValues(boolean useNativeMemoryDataset) thro
.withCagraGraphBuildAlgo(CagraGraphBuildAlgo.NN_DESCENT)
.build();

CagraIndex index;
if (useNativeMemoryDataset) {
final CagraIndex index;
if (datasetMemoryKind == TestDatasetMemoryKind.NATIVE) {
var datasetBuilder =
CuVSMatrix.hostBuilder(vectors.length, vectors[0].length, CuVSMatrix.DataType.FLOAT);
for (float[] v : vectors) {
Expand All @@ -131,7 +140,20 @@ private void tmpResultsTopKWithRandomValues(boolean useNativeMemoryDataset) thro
.withDataset(datasetBuilder.build())
.withIndexParams(indexParams)
.build();
} else if (datasetMemoryKind == TestDatasetMemoryKind.DEVICE) {
var datasetBuilder =
CuVSMatrix.deviceBuilder(
resources, vectors.length, vectors[0].length, CuVSMatrix.DataType.FLOAT);
for (float[] v : vectors) {
datasetBuilder.addVector(v);
}
index =
CagraIndex.newBuilder(resources)
.withDataset(datasetBuilder.build())
.withIndexParams(indexParams)
.build();
} else {
assert datasetMemoryKind == TestDatasetMemoryKind.HEAP;
index =
CagraIndex.newBuilder(resources)
.withDataset(vectors)
Expand Down
Loading