diff --git a/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx b/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx index 62486e46d3..afae653f35 100644 --- a/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx +++ b/python/cuvs/cuvs/neighbors/nn_descent/nn_descent.pyx @@ -223,7 +223,7 @@ def build(IndexParams index_params, dataset, graph=None, resources=None): >>> graph = index.graph """ dataset_ai = wrap_array(dataset) - _check_input_array(dataset_ai, [np.dtype('float32'), np.dtype('byte'), + _check_input_array(dataset_ai, [np.dtype('float32'), np.dtype('float16'), np.dtype('byte'), np.dtype('ubyte')]) cdef Index idx = Index() diff --git a/python/cuvs/cuvs/tests/test_nn_descent.py b/python/cuvs/cuvs/tests/test_nn_descent.py index c2d128edbb..862f82dc12 100644 --- a/python/cuvs/cuvs/tests/test_nn_descent.py +++ b/python/cuvs/cuvs/tests/test_nn_descent.py @@ -13,7 +13,7 @@ @pytest.mark.parametrize("n_rows", [1024, 2048]) @pytest.mark.parametrize("n_cols", [32, 64]) @pytest.mark.parametrize("device_memory", [True, False]) -@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize("dtype", [np.float32, np.float16]) @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("return_distances", [True, False]) def test_nn_descent(