Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.nvidia.cuvs;

import java.nio.file.Path;
import java.util.concurrent.locks.ReentrantLock;

/**
 * A decorator for {@link CuVSResources} that guarantees synchronized (blocking) access to the
 * wrapped {@link CuVSResources}.
 *
 * <p>{@link #access()} blocks until an internal {@link ReentrantLock} is acquired; the lock is
 * held until the returned {@link ScopedAccess} is closed. Because the lock is reentrant, the
 * thread that already holds an open {@link ScopedAccess} may call {@link #access()} again without
 * blocking, as long as it closes every scope it opens.
 *
 * <p>{@link #close()} and {@link #tempDirectory()} delegate directly to the wrapped instance and
 * do not take the lock.
 */
public class SynchronizedCuVSResources implements CuVSResources {

  private final CuVSResources inner;
  private final ReentrantLock lock;

  private SynchronizedCuVSResources(CuVSResources inner) {
    this.inner = inner;
    this.lock = new ReentrantLock();
  }

  /**
   * Creates a new underlying {@link CuVSResources} via {@link CuVSResources#create()} and wraps
   * it in a synchronizing decorator.
   *
   * @return the decorated resources
   * @throws Throwable whatever {@link CuVSResources#create()} throws
   */
  static CuVSResources create() throws Throwable {
    return new SynchronizedCuVSResources(CuVSResources.create());
  }

  /**
   * Acquires the lock (blocking if another thread holds it) and opens a scoped access on the
   * wrapped resources.
   *
   * <p>The wrapped {@code access()} is invoked exactly once per scope, and the resulting inner
   * {@link ScopedAccess} is closed when the returned scope is closed, so acquire/release pairs
   * on the wrapped instance stay balanced.
   *
   * @return a scope whose {@code close()} closes the inner access and then releases the lock
   */
  @Override
  public ScopedAccess access() {
    lock.lock();
    final ScopedAccess innerAccess;
    try {
      innerAccess = inner.access();
    } catch (RuntimeException | Error e) {
      // Do not leave the lock held if the wrapped resources refuse access.
      lock.unlock();
      throw e;
    }
    return new ScopedAccess() {
      // Only touched by the thread that owns the lock, so no extra synchronization is needed.
      private boolean closed;

      @Override
      public long handle() {
        return innerAccess.handle();
      }

      @Override
      public void close() {
        if (closed) {
          // Idempotent: a second close must not release a lock hold it does not own.
          return;
        }
        closed = true;
        try {
          innerAccess.close();
        } finally {
          lock.unlock();
        }
      }
    };
  }

  /** Closes the wrapped resources. */
  @Override
  public void close() {
    inner.close();
  }

  /**
   * Returns the temporary directory reported by the wrapped resources.
   *
   * @return the wrapped instance's {@code tempDirectory()}
   */
  @Override
  public Path tempDirectory() {
    return inner.tempDirectory();
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
package com.nvidia.cuvs;

import static com.carrotsearch.randomizedtesting.RandomizedTest.assumeTrue;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.junit.Assert.*;

import com.carrotsearch.randomizedtesting.RandomizedRunner;
import com.nvidia.cuvs.CagraIndexParams.CagraGraphBuildAlgo;
Expand All @@ -45,6 +43,16 @@ public class CagraMultiThreadStabilityIT extends CuVSTestCase {

private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

// Number of float components per generated vector (second dimension of the random datasets).
private final int dimensions = 256;
// Number of times each worker thread invokes the query action.
private final int queriesPerThread = 500;
// Number of query vectors generated for a single search call.
private final int queryBatchSize = 1; // Small batch size to increase frequency of calls
// Value passed to CagraQuery.Builder#withTopK for every query.
private final int topK = 10;

/**
 * A single query operation executed against a {@link CagraIndex} by a worker thread.
 *
 * <p>Declared to throw {@link Throwable} so implementations can call query helpers that declare
 * {@code throws Throwable} without wrapping.
 */
@FunctionalInterface
private interface QueryAction {
/**
 * Runs one query against the given index.
 *
 * @param index the index to search
 * @throws Throwable if building or executing the query fails
 */
void run(CagraIndex index) throws Throwable;
}

@Before
public void setup() {
assumeTrue("not supported on " + System.getProperty("os.name"), isLinuxAmd64());
Expand All @@ -53,18 +61,27 @@ public void setup() {
}

@Test
public void testQueryingUsingMultipleThreads() throws Throwable {
public void testQueryingUsingMultipleThreadsWithSharedSynchronizedResources() throws Throwable {
  // One synchronized resources instance is created up front and captured by the query action,
  // so every query issued by the multi-threaded driver goes through the same instance.
  try (CuVSResources sharedResources = SynchronizedCuVSResources.create()) {
    QueryAction queryWithSharedResources =
        index -> performQueryWithSharedSynchronizedResource(sharedResources, index);
    testQueryingUsingMultipleThreads(queryWithSharedResources);
  }
}

/**
 * Runs the multi-threaded query driver with an action that creates and closes its own resources
 * for every single query.
 */
@Test
public void testQueryingUsingMultipleThreadsWithPrivateResources() throws Throwable {
  QueryAction queryWithPrivateResources = index -> performQueryWithPrivateResource(index);
  testQueryingUsingMultipleThreads(queryWithPrivateResources);
}

private void testQueryingUsingMultipleThreads(QueryAction queryAction) throws Throwable {
final int dataSize = 10000;
final int dimensions = 256;
final int numThreads = 16; // High thread count to increase contention
final int queriesPerThread = 500;
final int queryBatchSize = 1; // Small batch size to increase frequency of calls
final int topK = 10;
final int numThreads = 16;

log.info(" Dataset: {}x{}", dataSize, dimensions);
// High thread count to increase contention
log.info(" Threads: {}, Queries per thread: {}", numThreads, queriesPerThread);

float[][] dataset = generateRandomDataset(dataSize, dimensions);
float[][] dataset = generateRandomDataset(dataSize);

try (CuVSResources resources = CheckedCuVSResources.create()) {
log.info("Creating CAGRA index for MultiThreaded stability test...");
Expand All @@ -86,105 +103,125 @@ public void testQueryingUsingMultipleThreads() throws Throwable {

log.info("CAGRA index created, starting high-contention multi-threaded search...");

// Create high contention scenario that would fail without using separate resources in every thread
ExecutorService executor = Executors.newFixedThreadPool(numThreads);
List<Future<?>> futures = new ArrayList<>();
CountDownLatch startLatch = new CountDownLatch(1);
AtomicInteger successfulQueries = new AtomicInteger(0);
AtomicReference<Throwable> firstError = new AtomicReference<>();

for (int threadId = 0; threadId < numThreads; threadId++) {
final int finalThreadId = threadId;

Future<?> future =
executor.submit(
() -> {
try {
// Wait for all threads to start simultaneously
startLatch.await();

for (int queryId = 0; queryId < queriesPerThread; queryId++) {
float[][] queries = generateRandomDataset(queryBatchSize, dimensions);

try (CuVSResources threadResources = CheckedCuVSResources.create()) {
CagraSearchParams searchParams = new CagraSearchParams.Builder().build();
CagraQuery query =
new CagraQuery.Builder(threadResources)
.withTopK(topK)
.withSearchParams(searchParams)
.withQueryVectors(queries)
.build();

// This call should now work with per-thread resources
SearchResults results = index.search(query);
assertNotNull("Query should return results", results);
assertTrue(
"Query should return some results", !results.getResults().isEmpty());

// Create high contention scenario that would fail without using separate resources in every
// thread
try (ExecutorService executor = Executors.newFixedThreadPool(numThreads)) {
List<Future<?>> futures = new ArrayList<>();
CountDownLatch startLatch = new CountDownLatch(1);
AtomicInteger successfulQueries = new AtomicInteger(0);
AtomicReference<Throwable> firstError = new AtomicReference<>();

for (int threadId = 0; threadId < numThreads; threadId++) {
final int finalThreadId = threadId;

Future<?> future =
executor.submit(
() -> {
try {
// Wait for all threads to start simultaneously
startLatch.await();

for (int queryId = 0; queryId < queriesPerThread; queryId++) {
queryAction.run(index);
successfulQueries.incrementAndGet();
}

// No Thread.yield() - maximize contention
}
// No Thread.yield() - maximize contention
}

log.info("Thread {} completed successfully", finalThreadId);
log.info("Thread {} completed successfully", finalThreadId);

} catch (Throwable t) {
log.error("Thread {} failed: {}", finalThreadId, t.getMessage(), t);
firstError.compareAndSet(null, t);
throw new RuntimeException("Thread failed", t);
}
});
} catch (Throwable t) {
log.error("Thread {} failed: {}", finalThreadId, t.getMessage(), t);
firstError.compareAndSet(null, t);
throw new RuntimeException("Thread failed", t);
}
});

futures.add(future);
}
futures.add(future);
}

// Start all threads simultaneously to maximize contention
log.info("Starting all {} threads simultaneously...", numThreads);
startLatch.countDown();

// Wait for all threads to complete
boolean allCompleted = true;
for (Future<?> future : futures) {
try {
future.get(120, TimeUnit.SECONDS); // Longer timeout for stress test
} catch (Exception e) {
allCompleted = false;
log.error("Thread failed: {}", e.getMessage(), e);
if (firstError.get() == null) {
firstError.set(e);
// Start all threads simultaneously to maximize contention
log.info("Starting all {} threads simultaneously...", numThreads);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a problem for this PR here. But I would like the tests to run without noise on the screen. We've done this, to a degree, in the TieredIndex tests.

This test is far less noisy than some of the other (extant) ones. We'll address that separately.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

++ to review logging levels. Probably a bit of "info" level feedback is OK, but most statements should probably be debug or trace, so we can turn them on when needed.

startLatch.countDown();

// Wait for all threads to complete
boolean allCompleted = true;
for (Future<?> future : futures) {
try {
future.get(120, TimeUnit.SECONDS); // Longer timeout for stress test
} catch (Exception e) {
allCompleted = false;
log.error("Thread failed: {}", e.getMessage(), e);
if (firstError.get() == null) {
firstError.set(e);
}
}
}
}

executor.shutdown();
if (!executor.awaitTermination(10, TimeUnit.SECONDS)) {
executor.shutdownNow();
}
executor.shutdown();
if (!executor.awaitTermination(10, TimeUnit.SECONDS)) {
executor.shutdownNow();
}

// Verify results
int expectedTotalQueries = numThreads * queriesPerThread;
int actualSuccessfulQueries = successfulQueries.get();
// Verify results
int expectedTotalQueries = numThreads * queriesPerThread;
int actualSuccessfulQueries = successfulQueries.get();

log.info(" Successful queries: {} / {}", actualSuccessfulQueries, expectedTotalQueries);
log.info(" Successful queries: {} / {}", actualSuccessfulQueries, expectedTotalQueries);

if (firstError.get() != null) {
fail(
"MultiThreaded stablity test failed:"
+ " "
+ firstError.get().getMessage());
}
if (firstError.get() != null) {
fail("MultiThreaded stablity test failed:" + " " + firstError.get().getMessage());
}

assertTrue("All threads should complete successfully", allCompleted);
assertTrue(
"All queries should complete successfully",
actualSuccessfulQueries == expectedTotalQueries);
assertTrue("All threads should complete successfully", allCompleted);
assertEquals(
"All queries should complete successfully",
expectedTotalQueries,
actualSuccessfulQueries);
}

index.destroyIndex();
}
}

private float[][] generateRandomDataset(int size, int dimensions) {
private void performQueryWithPrivateResource(CagraIndex index) throws Throwable {
  // Fresh random query vectors for this call; generated before any resources are opened.
  final float[][] queryVectors = generateRandomDataset(queryBatchSize);

  // The resources are created for this one query only and closed as soon as it completes.
  try (CuVSResources privateResources = CheckedCuVSResources.create()) {
    final CagraSearchParams defaultParams = new CagraSearchParams.Builder().build();
    final CagraQuery.Builder queryBuilder = new CagraQuery.Builder(privateResources);
    final CagraQuery privateQuery =
        queryBuilder
            .withTopK(topK)
            .withSearchParams(defaultParams)
            .withQueryVectors(queryVectors)
            .build();

    // Searching with resources that no other thread uses should succeed.
    final SearchResults found = index.search(privateQuery);
    assertNotNull("Query should return results", found);
    assertFalse("Query should return some results", found.getResults().isEmpty());
  }
}

private void performQueryWithSharedSynchronizedResource(
    CuVSResources threadResources, CagraIndex index) throws Throwable {
  // No resources are created or closed here: the caller owns `threadResources`. In this test
  // the caller hands the same instance to every invocation, from every worker thread.
  final float[][] queryVectors = generateRandomDataset(queryBatchSize);
  final CagraSearchParams defaultParams = new CagraSearchParams.Builder().build();

  final CagraQuery.Builder queryBuilder = new CagraQuery.Builder(threadResources);
  final CagraQuery sharedQuery =
      queryBuilder
          .withTopK(topK)
          .withSearchParams(defaultParams)
          .withQueryVectors(queryVectors)
          .build();

  // The search is expected to succeed even though the resources are not private to this thread.
  final SearchResults found = index.search(sharedQuery);
  assertNotNull("Query should return results", found);
  assertFalse("Query should return some results", found.getResults().isEmpty());
}

private float[][] generateRandomDataset(int size) {
Random random = new Random(42 + System.nanoTime());
float[][] data = new float[size][dimensions];

Expand Down