Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package com.nvidia.cuvs;

import static com.carrotsearch.randomizedtesting.RandomizedTest.assumeTrue;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
import static com.nvidia.cuvs.CuVSMatrixIT.assertSame2dArray;
import static org.junit.Assert.*;

Expand Down Expand Up @@ -133,18 +134,19 @@ public void testIndexingAndSearchingFlow() throws Throwable {
int numTestsRuns = 5;
try (CuVSResources resources = CheckedCuVSResources.create()) {
for (int j = 0; j < numTestsRuns; j++) {
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources);
var indexPath = serializeOnce(index);
var loadedIndex = deserializeOnce(indexPath, resources);
queryAndCompare(
index,
loadedIndex,
SearchResults.IDENTITY_MAPPING,
queries,
expectedResults,
resources);
cleanup(index, loadedIndex);
Files.deleteIfExists(indexPath);
try (var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) {
var indexPath = serializeOnce(index);
try (var loadedIndex = deserializeOnce(indexPath, resources)) {
queryAndCompare(
index,
loadedIndex,
SearchResults.IDENTITY_MAPPING,
queries,
expectedResults,
resources);
Files.deleteIfExists(indexPath);
Comment on lines +137 to +147
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

}
}
}
}
}
Expand All @@ -163,19 +165,19 @@ public void testIndexingAndSearchingFlowInDifferentThreads() throws Throwable {
for (int j = 0; j < numTestsRuns; j++) {
runInAnotherThread(
() -> {
try {
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources);
try (var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) {
var indexPath = serializeOnce(index);
var loadedIndex = deserializeOnce(indexPath, resources);
queryAndCompare(
index,
loadedIndex,
SearchResults.IDENTITY_MAPPING,
queries,
expectedResults,
resources);
cleanup(index, loadedIndex);
Files.deleteIfExists(indexPath);
try (var loadedIndex = deserializeOnce(indexPath, resources)) {
queryAndCompare(
index,
loadedIndex,
SearchResults.IDENTITY_MAPPING,
queries,
expectedResults,
resources);
} finally {
Files.deleteIfExists(indexPath);
}
} catch (Throwable e) {
throw new RuntimeException(e);
}
Expand All @@ -199,36 +201,52 @@ public void testIndexingAndSearchingFlowConcurrently() throws Throwable {
numTestsRuns,
() ->
() -> {
try (CuVSResources resources = CheckedCuVSResources.create()) {
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources);
try (CuVSResources resources = CheckedCuVSResources.create();
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) {
var indexPath = serializeOnce(index);
var loadedIndex = deserializeOnce(indexPath, resources);
queryAndCompare(
index,
loadedIndex,
SearchResults.IDENTITY_MAPPING,
queries,
expectedResults,
resources);
cleanup(index, loadedIndex);
Files.deleteIfExists(indexPath);
try (var loadedIndex = deserializeOnce(indexPath, resources)) {
queryAndCompare(
index,
loadedIndex,
SearchResults.IDENTITY_MAPPING,
queries,
expectedResults,
resources);
} finally {
Files.deleteIfExists(indexPath);
}
} catch (Throwable e) {
throw new RuntimeException(e);
}
});
}

@Test
public void testIndexing() throws Throwable {
for (int i = 0; i < 100; ++i) {
final float[][] dataset = createSampleData();
int numTestsRuns = 10;
public void testFloatIndexing() throws Throwable {
testIndexing(
() ->
CuVSMatrix.ofArray(
createFloatMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048))));
}

@Test
public void testByteIndexing() throws Throwable {
testIndexing(
() ->
CuVSMatrix.ofArray(
createByteMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048))));
}

private void testIndexing(Supplier<CuVSMatrix> matrixFactory) throws Exception {
for (int i = 0; i < 10; ++i) {
var dataset = matrixFactory.get();
int numTestsRuns = 4;
runConcurrently(
numTestsRuns,
() ->
() -> {
try (CuVSResources resources = CheckedCuVSResources.create()) {
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources);
var index = indexOnce(dataset, resources);
index.close();
} catch (Throwable e) {
throw new RuntimeException(e);
Expand All @@ -238,16 +256,31 @@ public void testIndexing() throws Throwable {
}

@Test
public void testSerialization() throws Throwable {
for (int i = 0; i < 100; ++i) {
final float[][] dataset = createSampleData();
int numTestsRuns = 10;
public void testFloatSerialization() throws Throwable {
testSerialization(
() ->
CuVSMatrix.ofArray(
createFloatMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048))));
}

@Test
public void testByteSerialization() throws Throwable {
testSerialization(
() ->
CuVSMatrix.ofArray(
createByteMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048))));
}

private void testSerialization(Supplier<CuVSMatrix> matrixFactory) throws Throwable {
for (int i = 0; i < 10; ++i) {
final var dataset = matrixFactory.get();
int numTestsRuns = 4;
runConcurrently(
numTestsRuns,
() ->
() -> {
try (CuVSResources resources = CheckedCuVSResources.create();
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) {
var index = indexOnce(dataset, resources)) {
var indexPath = serializeOnce(index);
Files.deleteIfExists(indexPath);
} catch (Throwable e) {
Expand All @@ -258,22 +291,43 @@ public void testSerialization() throws Throwable {
}

@Test
public void testDeserialization() throws Throwable {
var indexPath = createSerializedIndex(CuVSMatrix.ofArray(createSampleData()));
for (int i = 0; i < 100; ++i) {
int numTestsRuns = 10;
runConcurrently(
numTestsRuns,
() ->
() -> {
try (CuVSResources resources = CheckedCuVSResources.create()) {
deserializeOnce(indexPath, resources).close();
} catch (Throwable e) {
throw new RuntimeException(e);
}
});
public void testFloatDeserialization() throws Throwable {
testDeserialization(
() ->
CuVSMatrix.ofArray(
createFloatMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048))));
}

@Test
public void testByteDeserialization() throws Throwable {
testDeserialization(
() ->
CuVSMatrix.ofArray(
createByteMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048))));
}

private void testDeserialization(Supplier<CuVSMatrix> matrixFactory) throws Throwable {
Path indexPath;
try (var dataset = matrixFactory.get()) {
indexPath = createSerializedIndex(dataset);
}
try {
for (int i = 0; i < 10; ++i) {
int numTestsRuns = 4;
runConcurrently(
numTestsRuns,
() ->
() -> {
try (CuVSResources resources = CheckedCuVSResources.create()) {
deserializeOnce(indexPath, resources).close();
} catch (Throwable e) {
throw new RuntimeException(e);
}
});
}
} finally {
Files.deleteIfExists(indexPath);
}
Files.deleteIfExists(indexPath);
}

private Path createSerializedIndex(CuVSMatrix dataset) throws Throwable {
Expand Down Expand Up @@ -335,13 +389,14 @@ public void testIndexingAndSearchingFlowWithCustomMappingFunction() throws Throw
Map.of(2, 0.15224178f, 1, 0.59063464f, 0, 0.5986642f));

LongToIntFunction rotate = l -> (int) ((l + 1) % dataset.size());
try (CuVSResources resources = CheckedCuVSResources.create()) {
var index = indexOnce(dataset, resources);
try (CuVSResources resources = CheckedCuVSResources.create();
var index = indexOnce(dataset, resources)) {
var indexPath = serializeOnce(index);
var loadedIndex = deserializeOnce(indexPath, resources);
queryAndCompare(index, loadedIndex, rotate, queries, expectedResults, resources);
cleanup(index, loadedIndex);
Files.deleteIfExists(indexPath);
try (var loadedIndex = deserializeOnce(indexPath, resources)) {
queryAndCompare(index, loadedIndex, rotate, queries, expectedResults, resources);
} finally {
Files.deleteIfExists(indexPath);
}
}
}

Expand All @@ -358,13 +413,14 @@ public void testIndexingAndSearchingFlowWithCustomMappingList() throws Throwable
Map.of(3, 0.15224178f, 4, 0.59063464f, 1, 0.5986642f));

LongToIntFunction rotate = SearchResults.mappingsFromList(mappings);
try (CuVSResources resources = CheckedCuVSResources.create()) {
var index = indexOnce(dataset, resources);
try (CuVSResources resources = CheckedCuVSResources.create();
var index = indexOnce(dataset, resources)) {
var indexPath = serializeOnce(index);
var loadedIndex = deserializeOnce(indexPath, resources);
queryAndCompare(index, loadedIndex, rotate, queries, expectedResults, resources);
cleanup(index, loadedIndex);
Files.deleteIfExists(indexPath);
try (var loadedIndex = deserializeOnce(indexPath, resources)) {
queryAndCompare(index, loadedIndex, rotate, queries, expectedResults, resources);
} finally {
Files.deleteIfExists(indexPath);
}
}
}

Expand Down Expand Up @@ -511,12 +567,6 @@ private void queryAndCompare(
}
}

private void cleanup(CagraIndex index, CagraIndex loadedIndex) throws Throwable {
// Cleanup
index.close();
loadedIndex.close();
}

/**
* Tests that an index built starting from a native MemorySegment is identical to one built from
* Java heap arrays
Expand Down Expand Up @@ -545,19 +595,17 @@ public void testNativeDatasetEquivalent() throws Throwable {
var javaDataset = CuVSMatrix.ofArray(sampleData);
var nativeDataset =
DatasetHelper.fromMemorySegment(
dataMemorySegment, rows, cols, CuVSMatrix.DataType.FLOAT)) {

// Indexing with an on-heap and native datasets produce the same results
var javaIndex = indexOnce(javaDataset, resources);
var nativeIndex = indexOnce(nativeDataset, resources);
dataMemorySegment, rows, cols, CuVSMatrix.DataType.FLOAT);
// Indexing with an on-heap and native datasets produce the same results
var javaIndex = indexOnce(javaDataset, resources);
var nativeIndex = indexOnce(nativeDataset, resources)) {
queryAndCompare(
javaIndex,
nativeIndex,
SearchResults.IDENTITY_MAPPING,
queries,
expectedResults,
resources);
cleanup(javaIndex, nativeIndex);
}
}
}
Expand Down
37 changes: 1 addition & 36 deletions java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSMatrixIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,40 +49,6 @@ public void setup() {
{0, 4, 2}
};

private int[][] createIntMatrix() {
int rows = randomIntBetween(1, 32);
int cols = randomIntBetween(1, 100);

int[][] result = new int[rows][cols];

for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
result[r][c] = randomInt();
}
}
return result;
}

private float[][] createFloatMatrix() {
int rows = randomIntBetween(1, 32);
int cols = randomIntBetween(1, 100);

return createFloatMatrix(rows, cols);
}

private float[][] createFloatMatrix(int rows, int cols) {
float[][] result = new float[rows][cols];

float value = 1;
for (int r = 0; r < rows; ++r) {
for (int c = 0; c < cols; ++c) {
result[r][c] = value;
value += 1;
}
}
return result;
}

private void testByteDatasetRowGetAccess(CuVSMatrix dataset) {
for (int n = 0; n < dataset.size(); ++n) {
var row = dataset.getRow(n);
Expand Down Expand Up @@ -505,8 +471,7 @@ public void testHostToHostWithDifferentStrides() {
}

@Test
public void testHostBuilderWithDifferentStrides() throws Throwable {

public void testHostBuilderWithDifferentStrides() {
int size = randomIntBetween(1, 32 * 1024);
int columns = randomIntBetween(16, 2048);
int rowStride1 = randomIntBetween(columns, columns * 2);
Expand Down
Loading