diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java index c9de5f6eeb..381701f126 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java @@ -16,6 +16,7 @@ package com.nvidia.cuvs; import static com.carrotsearch.randomizedtesting.RandomizedTest.assumeTrue; +import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween; import static com.nvidia.cuvs.CuVSMatrixIT.assertSame2dArray; import static org.junit.Assert.*; @@ -133,18 +134,19 @@ public void testIndexingAndSearchingFlow() throws Throwable { int numTestsRuns = 5; try (CuVSResources resources = CheckedCuVSResources.create()) { for (int j = 0; j < numTestsRuns; j++) { - var index = indexOnce(CuVSMatrix.ofArray(dataset), resources); - var indexPath = serializeOnce(index); - var loadedIndex = deserializeOnce(indexPath, resources); - queryAndCompare( - index, - loadedIndex, - SearchResults.IDENTITY_MAPPING, - queries, - expectedResults, - resources); - cleanup(index, loadedIndex); - Files.deleteIfExists(indexPath); + try (var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) { + var indexPath = serializeOnce(index); + try (var loadedIndex = deserializeOnce(indexPath, resources)) { + queryAndCompare( + index, + loadedIndex, + SearchResults.IDENTITY_MAPPING, + queries, + expectedResults, + resources); + Files.deleteIfExists(indexPath); + } + } } } } @@ -163,19 +165,19 @@ public void testIndexingAndSearchingFlowInDifferentThreads() throws Throwable { for (int j = 0; j < numTestsRuns; j++) { runInAnotherThread( () -> { - try { - var index = indexOnce(CuVSMatrix.ofArray(dataset), resources); + try (var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) { var indexPath = serializeOnce(index); - var loadedIndex = deserializeOnce(indexPath, resources); - queryAndCompare( - index, - loadedIndex, - SearchResults.IDENTITY_MAPPING, - queries, - expectedResults, - resources); - cleanup(index, loadedIndex); - Files.deleteIfExists(indexPath); + try (var loadedIndex = deserializeOnce(indexPath, resources)) { + queryAndCompare( + index, + loadedIndex, + SearchResults.IDENTITY_MAPPING, + queries, + expectedResults, + resources); + } finally { + Files.deleteIfExists(indexPath); + } } catch (Throwable e) { throw new RuntimeException(e); } @@ -199,19 +201,20 @@ public void testIndexingAndSearchingFlowConcurrently() throws Throwable { numTestsRuns, () -> () -> { - try (CuVSResources resources = CheckedCuVSResources.create()) { - var index = indexOnce(CuVSMatrix.ofArray(dataset), resources); + try (CuVSResources resources = CheckedCuVSResources.create(); + var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) { var indexPath = serializeOnce(index); - var loadedIndex = deserializeOnce(indexPath, resources); - queryAndCompare( - index, - loadedIndex, - SearchResults.IDENTITY_MAPPING, - queries, - expectedResults, - resources); - cleanup(index, loadedIndex); - Files.deleteIfExists(indexPath); + try (var loadedIndex = deserializeOnce(indexPath, resources)) { + queryAndCompare( + index, + loadedIndex, + SearchResults.IDENTITY_MAPPING, + queries, + expectedResults, + resources); + } finally { + Files.deleteIfExists(indexPath); + } } catch (Throwable e) { throw new RuntimeException(e); } @@ -219,16 +222,31 @@ public void testIndexingAndSearchingFlowConcurrently() throws Throwable { } @Test - public void testIndexing() throws Throwable { - for (int i = 0; i < 100; ++i) { - final float[][] dataset = createSampleData(); - int numTestsRuns = 10; + public void testFloatIndexing() throws Throwable { + testIndexing( + () -> + CuVSMatrix.ofArray( + createFloatMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048)))); + } + + @Test + public void testByteIndexing() throws Throwable { + testIndexing( + () -> + CuVSMatrix.ofArray( + createByteMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048)))); + } + + private void testIndexing(Supplier matrixFactory) throws Exception { + for (int i = 0; i < 10; ++i) { + var dataset = matrixFactory.get(); + int numTestsRuns = 4; runConcurrently( numTestsRuns, () -> () -> { try (CuVSResources resources = CheckedCuVSResources.create()) { - var index = indexOnce(CuVSMatrix.ofArray(dataset), resources); + var index = indexOnce(dataset, resources); index.close(); } catch (Throwable e) { throw new RuntimeException(e); @@ -238,16 +256,31 @@ public void testIndexing() throws Throwable { } @Test - public void testSerialization() throws Throwable { - for (int i = 0; i < 100; ++i) { - final float[][] dataset = createSampleData(); - int numTestsRuns = 10; + public void testFloatSerialization() throws Throwable { + testSerialization( + () -> + CuVSMatrix.ofArray( + createFloatMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048)))); + } + + @Test + public void testByteSerialization() throws Throwable { + testSerialization( + () -> + CuVSMatrix.ofArray( + createByteMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048)))); + } + + private void testSerialization(Supplier matrixFactory) throws Throwable { + for (int i = 0; i < 10; ++i) { + final var dataset = matrixFactory.get(); + int numTestsRuns = 4; runConcurrently( numTestsRuns, () -> () -> { try (CuVSResources resources = CheckedCuVSResources.create(); - var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) { + var index = indexOnce(dataset, resources)) { var indexPath = serializeOnce(index); Files.deleteIfExists(indexPath); } catch (Throwable e) { @@ -258,22 +291,43 @@ public void testSerialization() throws Throwable { } @Test - public void testDeserialization() throws Throwable { - var indexPath = createSerializedIndex(CuVSMatrix.ofArray(createSampleData())); - for (int i = 0; i < 100; ++i) { - int numTestsRuns = 10; - runConcurrently( - numTestsRuns, - () -> - () -> { - try (CuVSResources resources = CheckedCuVSResources.create()) { - deserializeOnce(indexPath, resources).close(); - } catch (Throwable e) { - throw new RuntimeException(e); - } - }); + public void testFloatDeserialization() throws Throwable { + testDeserialization( + () -> + CuVSMatrix.ofArray( + createFloatMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048)))); + } + + @Test + public void testByteDeserialization() throws Throwable { + testDeserialization( + () -> + CuVSMatrix.ofArray( + createByteMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048)))); + } + + private void testDeserialization(Supplier matrixFactory) throws Throwable { + Path indexPath; + try (var dataset = matrixFactory.get()) { + indexPath = createSerializedIndex(dataset); + } + try { + for (int i = 0; i < 10; ++i) { + int numTestsRuns = 4; + runConcurrently( + numTestsRuns, + () -> + () -> { + try (CuVSResources resources = CheckedCuVSResources.create()) { + deserializeOnce(indexPath, resources).close(); + } catch (Throwable e) { + throw new RuntimeException(e); + } + }); + } + } finally { + Files.deleteIfExists(indexPath); } - Files.deleteIfExists(indexPath); } private Path createSerializedIndex(CuVSMatrix dataset) throws Throwable { @@ -335,13 +389,14 @@ public void testIndexingAndSearchingFlowWithCustomMappingFunction() throws Throw Map.of(2, 0.15224178f, 1, 0.59063464f, 0, 0.5986642f)); LongToIntFunction rotate = l -> (int) ((l + 1) % dataset.size()); - try (CuVSResources resources = CheckedCuVSResources.create()) { - var index = indexOnce(dataset, resources); + try (CuVSResources resources = CheckedCuVSResources.create(); + var index = indexOnce(dataset, resources)) { var indexPath = serializeOnce(index); - var loadedIndex = deserializeOnce(indexPath, resources); - queryAndCompare(index, loadedIndex, rotate, queries, expectedResults, resources); - cleanup(index, loadedIndex); - Files.deleteIfExists(indexPath); + try (var loadedIndex = deserializeOnce(indexPath, resources)) { + queryAndCompare(index, loadedIndex, rotate, queries, expectedResults, resources); + } finally { + Files.deleteIfExists(indexPath); + } } } @@ -358,13 +413,14 @@ public void testIndexingAndSearchingFlowWithCustomMappingList() throws Throwable Map.of(3, 0.15224178f, 4, 0.59063464f, 1, 0.5986642f)); LongToIntFunction rotate = SearchResults.mappingsFromList(mappings); - try (CuVSResources resources = CheckedCuVSResources.create()) { - var index = indexOnce(dataset, resources); + try (CuVSResources resources = CheckedCuVSResources.create(); + var index = indexOnce(dataset, resources)) { var indexPath = serializeOnce(index); - var loadedIndex = deserializeOnce(indexPath, resources); - queryAndCompare(index, loadedIndex, rotate, queries, expectedResults, resources); - cleanup(index, loadedIndex); - Files.deleteIfExists(indexPath); + try (var loadedIndex = deserializeOnce(indexPath, resources)) { + queryAndCompare(index, loadedIndex, rotate, queries, expectedResults, resources); + } finally { + Files.deleteIfExists(indexPath); + } } } @@ -511,12 +567,6 @@ private void queryAndCompare( } } - private void cleanup(CagraIndex index, CagraIndex loadedIndex) throws Throwable { - // Cleanup - index.close(); - loadedIndex.close(); - } - /** * Tests that an index built starting from a native MemorySegment is identical to one built from * Java heap arrays @@ -545,11 +595,10 @@ public void testNativeDatasetEquivalent() throws Throwable { var javaDataset = CuVSMatrix.ofArray(sampleData); var nativeDataset = DatasetHelper.fromMemorySegment( - dataMemorySegment, rows, cols, CuVSMatrix.DataType.FLOAT)) { - - // Indexing with an on-heap and native datasets produce the same results - var javaIndex = indexOnce(javaDataset, resources); - var nativeIndex = indexOnce(nativeDataset, resources); + dataMemorySegment, rows, cols, CuVSMatrix.DataType.FLOAT); + // Indexing with an on-heap and native datasets produce the same results + var javaIndex = indexOnce(javaDataset, resources); + var nativeIndex = indexOnce(nativeDataset, resources)) { queryAndCompare( javaIndex, nativeIndex, @@ -557,7 +606,6 @@ public void testNativeDatasetEquivalent() throws Throwable { queries, expectedResults, resources); - cleanup(javaIndex, nativeIndex); } } } diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSMatrixIT.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSMatrixIT.java index 9e46845110..70a8fdbd04 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSMatrixIT.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSMatrixIT.java @@ -49,40 +49,6 @@ public void setup() { {0, 4, 2} }; - private int[][] createIntMatrix() { - int rows = randomIntBetween(1, 32); - int cols = randomIntBetween(1, 100); - - int[][] result = new int[rows][cols]; - - for (int r = 0; r < rows; ++r) { - for (int c = 0; c < cols; ++c) { - result[r][c] = randomInt(); - } - } - return result; - } - - private float[][] createFloatMatrix() { - int rows = randomIntBetween(1, 32); - int cols = randomIntBetween(1, 100); - - return createFloatMatrix(rows, cols); - } - - private float[][] createFloatMatrix(int rows, int cols) { - float[][] result = new float[rows][cols]; - - float value = 1; - for (int r = 0; r < rows; ++r) { - for (int c = 0; c < cols; ++c) { - result[r][c] = value; - value += 1; - } - } - return result; - } - private void testByteDatasetRowGetAccess(CuVSMatrix dataset) { for (int n = 0; n < dataset.size(); ++n) { var row = dataset.getRow(n); @@ -505,8 +471,7 @@ public void testHostToHostWithDifferentStrides() { } @Test - public void testHostBuilderWithDifferentStrides() throws Throwable { - + public void testHostBuilderWithDifferentStrides() { int size = randomIntBetween(1, 32 * 1024); int columns = randomIntBetween(16, 2048); int rowStride1 = randomIntBetween(columns, columns * 2); diff --git a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java index fee3c87cca..a3b42d47f3 100644 --- a/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java +++ b/java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSTestCase.java @@ -15,6 +15,7 @@ */ package com.nvidia.cuvs; +import static com.carrotsearch.randomizedtesting.RandomizedTest.*; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -58,7 +59,7 @@ protected List> generateExpectedResults( Map distances = new TreeMap<>(); for (int j = 0; j < dataset.length; j++) { double distance = 0; - if (prefilters != null && prefilters[q].get(j) == false) { + if (prefilters != null && !prefilters[q].get(j)) { distance = Double.POSITIVE_INFINITY; } else { for (int k = 0; k < dimensions; k++) { @@ -124,14 +125,18 @@ protected static void checkResults( List> sortedActual = new ArrayList>(); for (Map map : expected) { sortedExpected.add( - new TreeMap(map) { + new TreeMap<>(map) { @Override public boolean equals(Object o) { - Map map = (Map) o; + if (!(o instanceof Map)) { + return false; + } + @SuppressWarnings("unchecked") + var map = (Map) o; if (this.size() != map.size()) return false; for (Integer key : map.keySet()) { try { - if (Math.abs((float) map.get(key) - ((float) get(key))) < 0.0001f == false) { + if (Math.abs(map.get(key) - ((float) get(key))) >= 0.0001f) { return false; } } catch (Exception ex) { @@ -143,7 +148,7 @@ public boolean equals(Object o) { }); } for (Map map : actual) { - sortedActual.add(new TreeMap(map)); + sortedActual.add(new TreeMap<>(map)); } assertEquals(sortedExpected, sortedActual); } @@ -152,4 +157,58 @@ protected static boolean isLinuxAmd64() { String name = System.getProperty("os.name"); return (name.startsWith("Linux")) && System.getProperty("os.arch").equals("amd64"); } + + protected static int[][] createIntMatrix() { + int rows = randomIntBetween(1, 32); + int cols = randomIntBetween(1, 100); + + return createIntMatrix(rows, cols); + } + + protected static int[][] createIntMatrix(int rows, int cols) { + int[][] result = new int[rows][cols]; + + for (int r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + result[r][c] = randomInt(); + } + } + return result; + } + + protected static byte[][] createByteMatrix() { + int rows = randomIntBetween(1, 32); + int cols = randomIntBetween(1, 100); + + return createByteMatrix(rows, cols); + } + + protected static byte[][] createByteMatrix(int rows, int cols) { + byte[][] result = new byte[rows][cols]; + + for (int r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + result[r][c] = randomByte(); + } + } + return result; + } + + protected static float[][] createFloatMatrix() { + int rows = randomIntBetween(1, 32); + int cols = randomIntBetween(1, 100); + + return createFloatMatrix(rows, cols); + } + + protected static float[][] createFloatMatrix(int rows, int cols) { + float[][] result = new float[rows][cols]; + + for (int r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + result[r][c] = randomFloat(); + } + } + return result; + } }