3535
3636class TestComputeGT (unittest .TestCase ):
3737
38- def do_test_compute_GT (self , metric = faiss .METRIC_L2 ):
38+ def do_test_compute_GT (self , metric = faiss .METRIC_L2 , ngpu = 0 ):
3939 d = 64
4040 xt , xb , xq = get_dataset_2 (d , 0 , 10000 , 100 )
4141
@@ -50,7 +50,7 @@ def matrix_iterator(xb, bs):
5050 yield xb [i0 :i0 + bs ]
5151
5252 Dnew , Inew = knn_ground_truth (
53- xq , matrix_iterator (xb , 1000 ), 10 , metric , ngpu = 0 )
53+ xq , matrix_iterator (xb , 1000 ), 10 , metric , ngpu = ngpu )
5454
5555 np .testing .assert_array_equal (Iref , Inew )
5656 # decimal = 4 required when run on GPU
@@ -62,6 +62,12 @@ def test_compute_GT(self):
6262 def test_compute_GT_ip (self ):
6363 self .do_test_compute_GT (faiss .METRIC_INNER_PRODUCT )
6464
65+ def test_compute_GT_gpu (self ):
66+ self .do_test_compute_GT (ngpu = - 1 )
67+
68+ def test_compute_GT_ip_gpu (self ):
69+ self .do_test_compute_GT (faiss .METRIC_INNER_PRODUCT , ngpu = - 1 )
70+
6571
6672class TestDatasets (unittest .TestCase ):
6773 """here we test only the synthetic dataset. Datasets that require
0 commit comments