Skip to content

Commit 8c775c7

Browse files
committed
Test indexing and serialization with integral (byte) dataset
# Conflicts: # java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java # java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSMatrixIT.java
1 parent a05de3e commit 8c775c7

3 files changed

Lines changed: 198 additions & 126 deletions

File tree

java/cuvs-java/src/test/java/com/nvidia/cuvs/CagraBuildAndSearchIT.java

Lines changed: 133 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package com.nvidia.cuvs;
1717

1818
import static com.carrotsearch.randomizedtesting.RandomizedTest.assumeTrue;
19+
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
1920
import static com.nvidia.cuvs.CuVSMatrixIT.assertSame2dArray;
2021
import static org.junit.Assert.*;
2122

@@ -133,18 +134,19 @@ public void testIndexingAndSearchingFlow() throws Throwable {
133134
int numTestsRuns = 5;
134135
try (CuVSResources resources = CheckedCuVSResources.create()) {
135136
for (int j = 0; j < numTestsRuns; j++) {
136-
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources);
137-
var indexPath = serializeOnce(index);
138-
var loadedIndex = deserializeOnce(indexPath, resources);
139-
queryAndCompare(
140-
index,
141-
loadedIndex,
142-
SearchResults.IDENTITY_MAPPING,
143-
queries,
144-
expectedResults,
145-
resources);
146-
cleanup(index, loadedIndex);
147-
Files.deleteIfExists(indexPath);
137+
try (var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) {
138+
var indexPath = serializeOnce(index);
139+
try (var loadedIndex = deserializeOnce(indexPath, resources)) {
140+
queryAndCompare(
141+
index,
142+
loadedIndex,
143+
SearchResults.IDENTITY_MAPPING,
144+
queries,
145+
expectedResults,
146+
resources);
147+
Files.deleteIfExists(indexPath);
148+
}
149+
}
148150
}
149151
}
150152
}
@@ -163,19 +165,19 @@ public void testIndexingAndSearchingFlowInDifferentThreads() throws Throwable {
163165
for (int j = 0; j < numTestsRuns; j++) {
164166
runInAnotherThread(
165167
() -> {
166-
try {
167-
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources);
168+
try (var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) {
168169
var indexPath = serializeOnce(index);
169-
var loadedIndex = deserializeOnce(indexPath, resources);
170-
queryAndCompare(
171-
index,
172-
loadedIndex,
173-
SearchResults.IDENTITY_MAPPING,
174-
queries,
175-
expectedResults,
176-
resources);
177-
cleanup(index, loadedIndex);
178-
Files.deleteIfExists(indexPath);
170+
try (var loadedIndex = deserializeOnce(indexPath, resources)) {
171+
queryAndCompare(
172+
index,
173+
loadedIndex,
174+
SearchResults.IDENTITY_MAPPING,
175+
queries,
176+
expectedResults,
177+
resources);
178+
} finally {
179+
Files.deleteIfExists(indexPath);
180+
}
179181
} catch (Throwable e) {
180182
throw new RuntimeException(e);
181183
}
@@ -199,36 +201,52 @@ public void testIndexingAndSearchingFlowConcurrently() throws Throwable {
199201
numTestsRuns,
200202
() ->
201203
() -> {
202-
try (CuVSResources resources = CheckedCuVSResources.create()) {
203-
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources);
204+
try (CuVSResources resources = CheckedCuVSResources.create();
205+
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) {
204206
var indexPath = serializeOnce(index);
205-
var loadedIndex = deserializeOnce(indexPath, resources);
206-
queryAndCompare(
207-
index,
208-
loadedIndex,
209-
SearchResults.IDENTITY_MAPPING,
210-
queries,
211-
expectedResults,
212-
resources);
213-
cleanup(index, loadedIndex);
214-
Files.deleteIfExists(indexPath);
207+
try (var loadedIndex = deserializeOnce(indexPath, resources)) {
208+
queryAndCompare(
209+
index,
210+
loadedIndex,
211+
SearchResults.IDENTITY_MAPPING,
212+
queries,
213+
expectedResults,
214+
resources);
215+
} finally {
216+
Files.deleteIfExists(indexPath);
217+
}
215218
} catch (Throwable e) {
216219
throw new RuntimeException(e);
217220
}
218221
});
219222
}
220223

221224
@Test
222-
public void testIndexing() throws Throwable {
223-
for (int i = 0; i < 100; ++i) {
224-
final float[][] dataset = createSampleData();
225-
int numTestsRuns = 10;
225+
public void testFloatIndexing() throws Throwable {
226+
testIndexing(
227+
() ->
228+
CuVSMatrix.ofArray(
229+
createFloatMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048))));
230+
}
231+
232+
@Test
233+
public void testByteIndexing() throws Throwable {
234+
testIndexing(
235+
() ->
236+
CuVSMatrix.ofArray(
237+
createByteMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048))));
238+
}
239+
240+
private void testIndexing(Supplier<CuVSMatrix> matrixFactory) throws Exception {
241+
for (int i = 0; i < 10; ++i) {
242+
var dataset = matrixFactory.get();
243+
int numTestsRuns = 4;
226244
runConcurrently(
227245
numTestsRuns,
228246
() ->
229247
() -> {
230248
try (CuVSResources resources = CheckedCuVSResources.create()) {
231-
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources);
249+
var index = indexOnce(dataset, resources);
232250
index.close();
233251
} catch (Throwable e) {
234252
throw new RuntimeException(e);
@@ -238,16 +256,31 @@ public void testIndexing() throws Throwable {
238256
}
239257

240258
@Test
241-
public void testSerialization() throws Throwable {
242-
for (int i = 0; i < 100; ++i) {
243-
final float[][] dataset = createSampleData();
244-
int numTestsRuns = 10;
259+
public void testFloatSerialization() throws Throwable {
260+
testSerialization(
261+
() ->
262+
CuVSMatrix.ofArray(
263+
createFloatMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048))));
264+
}
265+
266+
@Test
267+
public void testByteSerialization() throws Throwable {
268+
testSerialization(
269+
() ->
270+
CuVSMatrix.ofArray(
271+
createByteMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048))));
272+
}
273+
274+
private void testSerialization(Supplier<CuVSMatrix> matrixFactory) throws Throwable {
275+
for (int i = 0; i < 10; ++i) {
276+
final var dataset = matrixFactory.get();
277+
int numTestsRuns = 4;
245278
runConcurrently(
246279
numTestsRuns,
247280
() ->
248281
() -> {
249282
try (CuVSResources resources = CheckedCuVSResources.create();
250-
var index = indexOnce(CuVSMatrix.ofArray(dataset), resources)) {
283+
var index = indexOnce(dataset, resources)) {
251284
var indexPath = serializeOnce(index);
252285
Files.deleteIfExists(indexPath);
253286
} catch (Throwable e) {
@@ -258,22 +291,43 @@ public void testSerialization() throws Throwable {
258291
}
259292

260293
@Test
261-
public void testDeserialization() throws Throwable {
262-
var indexPath = createSerializedIndex(CuVSMatrix.ofArray(createSampleData()));
263-
for (int i = 0; i < 100; ++i) {
264-
int numTestsRuns = 10;
265-
runConcurrently(
266-
numTestsRuns,
267-
() ->
268-
() -> {
269-
try (CuVSResources resources = CheckedCuVSResources.create()) {
270-
deserializeOnce(indexPath, resources).close();
271-
} catch (Throwable e) {
272-
throw new RuntimeException(e);
273-
}
274-
});
294+
public void testFloatDeserialization() throws Throwable {
295+
testDeserialization(
296+
() ->
297+
CuVSMatrix.ofArray(
298+
createFloatMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048))));
299+
}
300+
301+
@Test
302+
public void testByteDeserialization() throws Throwable {
303+
testDeserialization(
304+
() ->
305+
CuVSMatrix.ofArray(
306+
createByteMatrix(randomIntBetween(2, 1024), randomIntBetween(2, 2048))));
307+
}
308+
309+
private void testDeserialization(Supplier<CuVSMatrix> matrixFactory) throws Throwable {
310+
Path indexPath;
311+
try (var dataset = matrixFactory.get()) {
312+
indexPath = createSerializedIndex(dataset);
313+
}
314+
try {
315+
for (int i = 0; i < 10; ++i) {
316+
int numTestsRuns = 4;
317+
runConcurrently(
318+
numTestsRuns,
319+
() ->
320+
() -> {
321+
try (CuVSResources resources = CheckedCuVSResources.create()) {
322+
deserializeOnce(indexPath, resources).close();
323+
} catch (Throwable e) {
324+
throw new RuntimeException(e);
325+
}
326+
});
327+
}
328+
} finally {
329+
Files.deleteIfExists(indexPath);
275330
}
276-
Files.deleteIfExists(indexPath);
277331
}
278332

279333
private Path createSerializedIndex(CuVSMatrix dataset) throws Throwable {
@@ -335,13 +389,14 @@ public void testIndexingAndSearchingFlowWithCustomMappingFunction() throws Throw
335389
Map.of(2, 0.15224178f, 1, 0.59063464f, 0, 0.5986642f));
336390

337391
LongToIntFunction rotate = l -> (int) ((l + 1) % dataset.size());
338-
try (CuVSResources resources = CheckedCuVSResources.create()) {
339-
var index = indexOnce(dataset, resources);
392+
try (CuVSResources resources = CheckedCuVSResources.create();
393+
var index = indexOnce(dataset, resources)) {
340394
var indexPath = serializeOnce(index);
341-
var loadedIndex = deserializeOnce(indexPath, resources);
342-
queryAndCompare(index, loadedIndex, rotate, queries, expectedResults, resources);
343-
cleanup(index, loadedIndex);
344-
Files.deleteIfExists(indexPath);
395+
try (var loadedIndex = deserializeOnce(indexPath, resources)) {
396+
queryAndCompare(index, loadedIndex, rotate, queries, expectedResults, resources);
397+
} finally {
398+
Files.deleteIfExists(indexPath);
399+
}
345400
}
346401
}
347402

@@ -358,13 +413,14 @@ public void testIndexingAndSearchingFlowWithCustomMappingList() throws Throwable
358413
Map.of(3, 0.15224178f, 4, 0.59063464f, 1, 0.5986642f));
359414

360415
LongToIntFunction rotate = SearchResults.mappingsFromList(mappings);
361-
try (CuVSResources resources = CheckedCuVSResources.create()) {
362-
var index = indexOnce(dataset, resources);
416+
try (CuVSResources resources = CheckedCuVSResources.create();
417+
var index = indexOnce(dataset, resources)) {
363418
var indexPath = serializeOnce(index);
364-
var loadedIndex = deserializeOnce(indexPath, resources);
365-
queryAndCompare(index, loadedIndex, rotate, queries, expectedResults, resources);
366-
cleanup(index, loadedIndex);
367-
Files.deleteIfExists(indexPath);
419+
try (var loadedIndex = deserializeOnce(indexPath, resources)) {
420+
queryAndCompare(index, loadedIndex, rotate, queries, expectedResults, resources);
421+
} finally {
422+
Files.deleteIfExists(indexPath);
423+
}
368424
}
369425
}
370426

@@ -511,12 +567,6 @@ private void queryAndCompare(
511567
}
512568
}
513569

514-
private void cleanup(CagraIndex index, CagraIndex loadedIndex) throws Throwable {
515-
// Cleanup
516-
index.close();
517-
loadedIndex.close();
518-
}
519-
520570
/**
521571
* Tests that an index built starting from a native MemorySegment is identical to one built from
522572
* Java heap arrays
@@ -545,19 +595,17 @@ public void testNativeDatasetEquivalent() throws Throwable {
545595
var javaDataset = CuVSMatrix.ofArray(sampleData);
546596
var nativeDataset =
547597
DatasetHelper.fromMemorySegment(
548-
dataMemorySegment, rows, cols, CuVSMatrix.DataType.FLOAT)) {
549-
550-
// Indexing with an on-heap and native datasets produce the same results
551-
var javaIndex = indexOnce(javaDataset, resources);
552-
var nativeIndex = indexOnce(nativeDataset, resources);
598+
dataMemorySegment, rows, cols, CuVSMatrix.DataType.FLOAT);
599+
// Indexing with an on-heap and native datasets produce the same results
600+
var javaIndex = indexOnce(javaDataset, resources);
601+
var nativeIndex = indexOnce(nativeDataset, resources)) {
553602
queryAndCompare(
554603
javaIndex,
555604
nativeIndex,
556605
SearchResults.IDENTITY_MAPPING,
557606
queries,
558607
expectedResults,
559608
resources);
560-
cleanup(javaIndex, nativeIndex);
561609
}
562610
}
563611
}

java/cuvs-java/src/test/java/com/nvidia/cuvs/CuVSMatrixIT.java

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -49,40 +49,6 @@ public void setup() {
4949
{0, 4, 2}
5050
};
5151

52-
private int[][] createIntMatrix() {
53-
int rows = randomIntBetween(1, 32);
54-
int cols = randomIntBetween(1, 100);
55-
56-
int[][] result = new int[rows][cols];
57-
58-
for (int r = 0; r < rows; ++r) {
59-
for (int c = 0; c < cols; ++c) {
60-
result[r][c] = randomInt();
61-
}
62-
}
63-
return result;
64-
}
65-
66-
private float[][] createFloatMatrix() {
67-
int rows = randomIntBetween(1, 32);
68-
int cols = randomIntBetween(1, 100);
69-
70-
return createFloatMatrix(rows, cols);
71-
}
72-
73-
private float[][] createFloatMatrix(int rows, int cols) {
74-
float[][] result = new float[rows][cols];
75-
76-
float value = 1;
77-
for (int r = 0; r < rows; ++r) {
78-
for (int c = 0; c < cols; ++c) {
79-
result[r][c] = value;
80-
value += 1;
81-
}
82-
}
83-
return result;
84-
}
85-
8652
private void testByteDatasetRowGetAccess(CuVSMatrix dataset) {
8753
for (int n = 0; n < dataset.size(); ++n) {
8854
var row = dataset.getRow(n);
@@ -505,8 +471,7 @@ public void testHostToHostWithDifferentStrides() {
505471
}
506472

507473
@Test
508-
public void testHostBuilderWithDifferentStrides() throws Throwable {
509-
474+
public void testHostBuilderWithDifferentStrides() {
510475
int size = randomIntBetween(1, 32 * 1024);
511476
int columns = randomIntBetween(16, 2048);
512477
int rowStride1 = randomIntBetween(columns, columns * 2);

0 commit comments

Comments
 (0)