Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions java/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ Also, ensure that your panama bindings are up-to-date. They can be re-generated
./panama-bindings/generate-bindings.sh
```

Tests run using a randomized runner. Specific failures can be reproduced running a test suite with a specific seed,
by passing `-Dtests.seed=42FC5CC6B4C6BA8E` (where `42FC5CC6B4C6BA8E` has to be
replaced with your specific seed). It also possible to re-run a single test, but
in this case it's necessary to pass the extended seed (suite:method), e.g.
Comment on lines +44 to +47
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

```shell
mvn integration-test -Dit.test=com.nvidia.cuvs.CagraBuildAndSearchIT#testFloatIndexing -Dtests.seed=66039A8CAFB9D3C9:449B6310296799E0
```

It is also possible to ask the test runner to run a specific test or suite multiple
times, by passing `-Dtests.iters=10` through the command line.

## Examples

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ protected CuVSMatrixBaseImpl(
this.columns = columns;
}

@Override
public String toString() {
return String.format("%dx%d %s @ 0x%016X", size, columns, dataType, memorySegment.address());
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

}

protected static void copyMatrix(
CuVSMatrixInternal sourceMatrix, CuVSMatrixInternal targetMatrix, CuVSResources resources) {
if (targetMatrix.columns() != sourceMatrix.columns()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.function.LongToIntFunction;
import java.util.function.Supplier;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.slf4j.Logger;
Expand All @@ -52,12 +52,12 @@ public void setup() {
log.trace("Random context initialized for test.");
}

private static void runConcurrently(int nThreads, Supplier<Runnable> runnableSupplier)
private static void runConcurrently(int nThreads, Function<Integer, Runnable> runnableSupplier)
throws ExecutionException, InterruptedException, TimeoutException {
try (ExecutorService parallelExecutor = Executors.newFixedThreadPool(nThreads)) {
var futures = new CompletableFuture[nThreads];
for (int j = 0; j < nThreads; j++) {
futures[j] = CompletableFuture.runAsync(runnableSupplier.get(), parallelExecutor);
futures[j] = CompletableFuture.runAsync(runnableSupplier.apply(j), parallelExecutor);
}

CompletableFuture.allOf(futures)
Expand All @@ -68,6 +68,11 @@ private static void runConcurrently(int nThreads, Supplier<Runnable> runnableSup
return null;
})
.get(2000, TimeUnit.SECONDS);

parallelExecutor.shutdown();
assertTrue(
"Timeout waiting for parallelExecutor to finish",
parallelExecutor.awaitTermination(10, TimeUnit.SECONDS));
}
}

Expand Down Expand Up @@ -189,12 +194,16 @@ public void testIndexingAndSearchingFlowConcurrently() throws Throwable {

runConcurrently(
numTestsRuns,
() ->
threadIdx ->
() -> {
log.debug("Indexing threadIdx:{}-{}", threadIdx, Thread.currentThread().getName());
try (CuVSResources resources = CheckedCuVSResources.create();
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) {
var matrix = CuVSMatrix.ofArray(dataset);
var index = indexOnce(matrix, resources)) {
var indexPath = serializeOnce(index);
try (var loadedIndex = deserializeOnce(indexPath, resources)) {
log.debug(
"Querying threadIdx:{}-{}", threadIdx, Thread.currentThread().getName());
queryAndCompare(
index,
loadedIndex,
Expand All @@ -208,10 +217,10 @@ public void testIndexingAndSearchingFlowConcurrently() throws Throwable {
} catch (Throwable e) {
throw new RuntimeException(e);
}
log.debug("Done threadIdx:{}-{}", threadIdx, Thread.currentThread().getName());
});
}

@Ignore // https://github.com/rapidsai/cuvs/issues/1467
@Test
public void testFloatIndexing() throws Throwable {
testIndexing(
Expand All @@ -230,23 +239,33 @@ public void testByteIndexing() throws Throwable {

private void testIndexing(Supplier<CuVSMatrix> matrixFactory) throws Exception {
for (int i = 0; i < 10; ++i) {
var dataset = matrixFactory.get();
int numTestsRuns = 4;
runConcurrently(
numTestsRuns,
() ->
() -> {
try (CuVSResources resources = CheckedCuVSResources.create()) {
var index = indexOnce(dataset, resources);
index.close();
} catch (Throwable e) {
throw new RuntimeException(e);
}
});
try (var dataset = matrixFactory.get()) {
int numRunners = 4;
final int iteration = i;
runConcurrently(
numRunners,
threadIdx ->
() -> {
try (CuVSResources resources = CheckedCuVSResources.create()) {
// Create a local reference to the dataset, as index will close the dataset too
// when it gets closed.
var indexDatasetReference = dataset.toHost();
log.debug(
"Indexing iteration:{} threadIdx:{} dataset:{}",
iteration,
threadIdx,
dataset);
var index = indexOnce(indexDatasetReference, resources);
log.debug("Done {} {}", iteration, threadIdx);
index.close();
} catch (Throwable e) {
throw new RuntimeException(e);
}
});
}
}
}

@Ignore // https://github.com/rapidsai/cuvs/issues/1467
@Test
public void testFloatSerialization() throws Throwable {
testSerialization(
Expand All @@ -265,20 +284,24 @@ public void testByteSerialization() throws Throwable {

private void testSerialization(Supplier<CuVSMatrix> matrixFactory) throws Throwable {
for (int i = 0; i < 10; ++i) {
final var dataset = matrixFactory.get();
int numTestsRuns = 4;
runConcurrently(
numTestsRuns,
() ->
() -> {
try (CuVSResources resources = CheckedCuVSResources.create();
var index = indexOnce(dataset, resources)) {
var indexPath = serializeOnce(index);
Files.deleteIfExists(indexPath);
} catch (Throwable e) {
throw new RuntimeException(e);
}
});
try (final var dataset = matrixFactory.get()) {
int numRunners = 4;
runConcurrently(
numRunners,
threadIdx ->
() -> {
// Create a local reference to the dataset, as index will close the dataset too
// when it gets closed.
var indexDatasetReference = dataset.toHost();
try (CuVSResources resources = CheckedCuVSResources.create();
var index = indexOnce(indexDatasetReference, resources)) {
var indexPath = serializeOnce(index);
Files.deleteIfExists(indexPath);
} catch (Throwable e) {
throw new RuntimeException(e);
}
});
}
}
}

Expand Down Expand Up @@ -308,7 +331,7 @@ private void testDeserialization(Supplier<CuVSMatrix> matrixFactory) throws Thro
int numTestsRuns = 4;
runConcurrently(
numTestsRuns,
() ->
threadIdx ->
() -> {
try (CuVSResources resources = CheckedCuVSResources.create()) {
deserializeOnce(indexPath, resources).close();
Expand Down