1616package com .nvidia .cuvs ;
1717
1818import static com .carrotsearch .randomizedtesting .RandomizedTest .assumeTrue ;
19+ import static com .carrotsearch .randomizedtesting .RandomizedTest .randomIntBetween ;
1920import static com .nvidia .cuvs .CuVSMatrixIT .assertSame2dArray ;
2021import static org .junit .Assert .*;
2122
@@ -133,18 +134,19 @@ public void testIndexingAndSearchingFlow() throws Throwable {
133134 int numTestsRuns = 5 ;
134135 try (CuVSResources resources = CheckedCuVSResources .create ()) {
135136 for (int j = 0 ; j < numTestsRuns ; j ++) {
136- var index = indexOnce (CuVSMatrix .ofArray (dataset ), resources );
137- var indexPath = serializeOnce (index );
138- var loadedIndex = deserializeOnce (indexPath , resources );
139- queryAndCompare (
140- index ,
141- loadedIndex ,
142- SearchResults .IDENTITY_MAPPING ,
143- queries ,
144- expectedResults ,
145- resources );
146- cleanup (index , loadedIndex );
147- Files .deleteIfExists (indexPath );
137+ try (var index = indexOnce (CuVSMatrix .ofArray (dataset ), resources )) {
138+ var indexPath = serializeOnce (index );
139+ try (var loadedIndex = deserializeOnce (indexPath , resources )) {
140+ queryAndCompare (
141+ index ,
142+ loadedIndex ,
143+ SearchResults .IDENTITY_MAPPING ,
144+ queries ,
145+ expectedResults ,
146+ resources );
147+ Files .deleteIfExists (indexPath );
148+ }
149+ }
148150 }
149151 }
150152 }
@@ -163,19 +165,19 @@ public void testIndexingAndSearchingFlowInDifferentThreads() throws Throwable {
163165 for (int j = 0 ; j < numTestsRuns ; j ++) {
164166 runInAnotherThread (
165167 () -> {
166- try {
167- var index = indexOnce (CuVSMatrix .ofArray (dataset ), resources );
168+ try (var index = indexOnce (CuVSMatrix .ofArray (dataset ), resources )) {
168169 var indexPath = serializeOnce (index );
169- var loadedIndex = deserializeOnce (indexPath , resources );
170- queryAndCompare (
171- index ,
172- loadedIndex ,
173- SearchResults .IDENTITY_MAPPING ,
174- queries ,
175- expectedResults ,
176- resources );
177- cleanup (index , loadedIndex );
178- Files .deleteIfExists (indexPath );
170+ try (var loadedIndex = deserializeOnce (indexPath , resources )) {
171+ queryAndCompare (
172+ index ,
173+ loadedIndex ,
174+ SearchResults .IDENTITY_MAPPING ,
175+ queries ,
176+ expectedResults ,
177+ resources );
178+ } finally {
179+ Files .deleteIfExists (indexPath );
180+ }
179181 } catch (Throwable e ) {
180182 throw new RuntimeException (e );
181183 }
@@ -199,36 +201,52 @@ public void testIndexingAndSearchingFlowConcurrently() throws Throwable {
199201 numTestsRuns ,
200202 () ->
201203 () -> {
202- try (CuVSResources resources = CheckedCuVSResources .create ()) {
203- var index = indexOnce (CuVSMatrix .ofArray (dataset ), resources );
204+ try (CuVSResources resources = CheckedCuVSResources .create ();
205+ var index = indexOnce (CuVSMatrix .ofArray (dataset ), resources )) {
204206 var indexPath = serializeOnce (index );
205- var loadedIndex = deserializeOnce (indexPath , resources );
206- queryAndCompare (
207- index ,
208- loadedIndex ,
209- SearchResults .IDENTITY_MAPPING ,
210- queries ,
211- expectedResults ,
212- resources );
213- cleanup (index , loadedIndex );
214- Files .deleteIfExists (indexPath );
207+ try (var loadedIndex = deserializeOnce (indexPath , resources )) {
208+ queryAndCompare (
209+ index ,
210+ loadedIndex ,
211+ SearchResults .IDENTITY_MAPPING ,
212+ queries ,
213+ expectedResults ,
214+ resources );
215+ } finally {
216+ Files .deleteIfExists (indexPath );
217+ }
215218 } catch (Throwable e ) {
216219 throw new RuntimeException (e );
217220 }
218221 });
219222 }
220223
221224 @ Test
222- public void testIndexing () throws Throwable {
223- for (int i = 0 ; i < 100 ; ++i ) {
224- final float [][] dataset = createSampleData ();
225- int numTestsRuns = 10 ;
225+ public void testFloatIndexing () throws Throwable {
226+ testIndexing (
227+ () ->
228+ CuVSMatrix .ofArray (
229+ createFloatMatrix (randomIntBetween (2 , 1024 ), randomIntBetween (2 , 2048 ))));
230+ }
231+
232+ @ Test
233+ public void testByteIndexing () throws Throwable {
234+ testIndexing (
235+ () ->
236+ CuVSMatrix .ofArray (
237+ createByteMatrix (randomIntBetween (2 , 1024 ), randomIntBetween (2 , 2048 ))));
238+ }
239+
240+ private void testIndexing (Supplier <CuVSMatrix > matrixFactory ) throws Exception {
241+ for (int i = 0 ; i < 10 ; ++i ) {
242+ var dataset = matrixFactory .get ();
243+ int numTestsRuns = 4 ;
226244 runConcurrently (
227245 numTestsRuns ,
228246 () ->
229247 () -> {
230248 try (CuVSResources resources = CheckedCuVSResources .create ()) {
231- var index = indexOnce (CuVSMatrix . ofArray ( dataset ) , resources );
249+ var index = indexOnce (dataset , resources );
232250 index .close ();
233251 } catch (Throwable e ) {
234252 throw new RuntimeException (e );
@@ -238,16 +256,31 @@ public void testIndexing() throws Throwable {
238256 }
239257
240258 @ Test
241- public void testSerialization () throws Throwable {
242- for (int i = 0 ; i < 100 ; ++i ) {
243- final float [][] dataset = createSampleData ();
244- int numTestsRuns = 10 ;
259+ public void testFloatSerialization () throws Throwable {
260+ testSerialization (
261+ () ->
262+ CuVSMatrix .ofArray (
263+ createFloatMatrix (randomIntBetween (2 , 1024 ), randomIntBetween (2 , 2048 ))));
264+ }
265+
266+ @ Test
267+ public void testByteSerialization () throws Throwable {
268+ testSerialization (
269+ () ->
270+ CuVSMatrix .ofArray (
271+ createByteMatrix (randomIntBetween (2 , 1024 ), randomIntBetween (2 , 2048 ))));
272+ }
273+
274+ private void testSerialization (Supplier <CuVSMatrix > matrixFactory ) throws Throwable {
275+ for (int i = 0 ; i < 10 ; ++i ) {
276+ final var dataset = matrixFactory .get ();
277+ int numTestsRuns = 4 ;
245278 runConcurrently (
246279 numTestsRuns ,
247280 () ->
248281 () -> {
249282 try (CuVSResources resources = CheckedCuVSResources .create ();
250- var index = indexOnce (CuVSMatrix . ofArray ( dataset ) , resources )) {
283+ var index = indexOnce (dataset , resources )) {
251284 var indexPath = serializeOnce (index );
252285 Files .deleteIfExists (indexPath );
253286 } catch (Throwable e ) {
@@ -258,22 +291,43 @@ public void testSerialization() throws Throwable {
258291 }
259292
260293 @ Test
261- public void testDeserialization () throws Throwable {
262- var indexPath = createSerializedIndex (CuVSMatrix .ofArray (createSampleData ()));
263- for (int i = 0 ; i < 100 ; ++i ) {
264- int numTestsRuns = 10 ;
265- runConcurrently (
266- numTestsRuns ,
267- () ->
268- () -> {
269- try (CuVSResources resources = CheckedCuVSResources .create ()) {
270- deserializeOnce (indexPath , resources ).close ();
271- } catch (Throwable e ) {
272- throw new RuntimeException (e );
273- }
274- });
294+ public void testFloatDeserialization () throws Throwable {
295+ testDeserialization (
296+ () ->
297+ CuVSMatrix .ofArray (
298+ createFloatMatrix (randomIntBetween (2 , 1024 ), randomIntBetween (2 , 2048 ))));
299+ }
300+
301+ @ Test
302+ public void testByteDeserialization () throws Throwable {
303+ testDeserialization (
304+ () ->
305+ CuVSMatrix .ofArray (
306+ createByteMatrix (randomIntBetween (2 , 1024 ), randomIntBetween (2 , 2048 ))));
307+ }
308+
309+ private void testDeserialization (Supplier <CuVSMatrix > matrixFactory ) throws Throwable {
310+ Path indexPath ;
311+ try (var dataset = matrixFactory .get ()) {
312+ indexPath = createSerializedIndex (dataset );
313+ }
314+ try {
315+ for (int i = 0 ; i < 10 ; ++i ) {
316+ int numTestsRuns = 4 ;
317+ runConcurrently (
318+ numTestsRuns ,
319+ () ->
320+ () -> {
321+ try (CuVSResources resources = CheckedCuVSResources .create ()) {
322+ deserializeOnce (indexPath , resources ).close ();
323+ } catch (Throwable e ) {
324+ throw new RuntimeException (e );
325+ }
326+ });
327+ }
328+ } finally {
329+ Files .deleteIfExists (indexPath );
275330 }
276- Files .deleteIfExists (indexPath );
277331 }
278332
279333 private Path createSerializedIndex (CuVSMatrix dataset ) throws Throwable {
@@ -335,13 +389,14 @@ public void testIndexingAndSearchingFlowWithCustomMappingFunction() throws Throw
335389 Map .of (2 , 0.15224178f , 1 , 0.59063464f , 0 , 0.5986642f ));
336390
337391 LongToIntFunction rotate = l -> (int ) ((l + 1 ) % dataset .size ());
338- try (CuVSResources resources = CheckedCuVSResources .create ()) {
339- var index = indexOnce (dataset , resources );
392+ try (CuVSResources resources = CheckedCuVSResources .create ();
393+ var index = indexOnce (dataset , resources )) {
340394 var indexPath = serializeOnce (index );
341- var loadedIndex = deserializeOnce (indexPath , resources );
342- queryAndCompare (index , loadedIndex , rotate , queries , expectedResults , resources );
343- cleanup (index , loadedIndex );
344- Files .deleteIfExists (indexPath );
395+ try (var loadedIndex = deserializeOnce (indexPath , resources )) {
396+ queryAndCompare (index , loadedIndex , rotate , queries , expectedResults , resources );
397+ } finally {
398+ Files .deleteIfExists (indexPath );
399+ }
345400 }
346401 }
347402
@@ -358,13 +413,14 @@ public void testIndexingAndSearchingFlowWithCustomMappingList() throws Throwable
358413 Map .of (3 , 0.15224178f , 4 , 0.59063464f , 1 , 0.5986642f ));
359414
360415 LongToIntFunction rotate = SearchResults .mappingsFromList (mappings );
361- try (CuVSResources resources = CheckedCuVSResources .create ()) {
362- var index = indexOnce (dataset , resources );
416+ try (CuVSResources resources = CheckedCuVSResources .create ();
417+ var index = indexOnce (dataset , resources )) {
363418 var indexPath = serializeOnce (index );
364- var loadedIndex = deserializeOnce (indexPath , resources );
365- queryAndCompare (index , loadedIndex , rotate , queries , expectedResults , resources );
366- cleanup (index , loadedIndex );
367- Files .deleteIfExists (indexPath );
419+ try (var loadedIndex = deserializeOnce (indexPath , resources )) {
420+ queryAndCompare (index , loadedIndex , rotate , queries , expectedResults , resources );
421+ } finally {
422+ Files .deleteIfExists (indexPath );
423+ }
368424 }
369425 }
370426
@@ -511,12 +567,6 @@ private void queryAndCompare(
511567 }
512568 }
513569
514- private void cleanup (CagraIndex index , CagraIndex loadedIndex ) throws Throwable {
515- // Cleanup
516- index .close ();
517- loadedIndex .close ();
518- }
519-
520570 /**
521571 * Tests that an index built starting from a native MemorySegment is identical to one built from
522572 * Java heap arrays
@@ -545,19 +595,17 @@ public void testNativeDatasetEquivalent() throws Throwable {
545595 var javaDataset = CuVSMatrix .ofArray (sampleData );
546596 var nativeDataset =
547597 DatasetHelper .fromMemorySegment (
548- dataMemorySegment , rows , cols , CuVSMatrix .DataType .FLOAT )) {
549-
550- // Indexing with an on-heap and native datasets produce the same results
551- var javaIndex = indexOnce (javaDataset , resources );
552- var nativeIndex = indexOnce (nativeDataset , resources );
598+ dataMemorySegment , rows , cols , CuVSMatrix .DataType .FLOAT );
599+ // Indexing with an on-heap and native datasets produce the same results
600+ var javaIndex = indexOnce (javaDataset , resources );
601+ var nativeIndex = indexOnce (nativeDataset , resources )) {
553602 queryAndCompare (
554603 javaIndex ,
555604 nativeIndex ,
556605 SearchResults .IDENTITY_MAPPING ,
557606 queries ,
558607 expectedResults ,
559608 resources );
560- cleanup (javaIndex , nativeIndex );
561609 }
562610 }
563611 }
0 commit comments