From 47a8152a4f7ab3c735e378673f35e0d7123d464a Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 14 Oct 2025 13:55:49 -0700 Subject: [PATCH] Fix serialization with FP16. Add python tests Signed-off-by: Mickael Ide --- c/src/neighbors/brute_force.cpp | 2 +- c/src/neighbors/cagra.cpp | 2 +- c/src/neighbors/ivf_flat.cpp | 2 +- c/src/neighbors/mg_cagra.cpp | 4 ++-- c/src/neighbors/mg_ivf_flat.cpp | 4 ++-- c/src/neighbors/mg_ivf_pq.cpp | 2 +- python/cuvs/cuvs/tests/test_brute_force.py | 21 +++++++++++++++++++-- python/cuvs/cuvs/tests/test_cagra.py | 13 ++++++++++++- python/cuvs/cuvs/tests/test_ivf_flat.py | 15 +++++++++++++-- python/cuvs/cuvs/tests/test_ivf_pq.py | 13 ++++++++++++- python/cuvs/cuvs/tests/test_mg_cagra.py | 2 +- 11 files changed, 65 insertions(+), 15 deletions(-) diff --git a/c/src/neighbors/brute_force.cpp b/c/src/neighbors/brute_force.cpp index 9330539487..72bb9d3f03 100644 --- a/c/src/neighbors/brute_force.cpp +++ b/c/src/neighbors/brute_force.cpp @@ -255,7 +255,7 @@ extern "C" cuvsError_t cuvsBruteForceDeserialize(cuvsResources_t res, if (dtype.kind == 'f' && dtype.itemsize == 4) { index->dtype.code = kDLFloat; index->addr = reinterpret_cast(_deserialize(res, filename)); - } else if (dtype.kind == 'f' && dtype.itemsize == 2) { + } else if (dtype.kind == 'e' && dtype.itemsize == 2) { index->dtype.code = kDLFloat; index->addr = reinterpret_cast(_deserialize(res, filename)); } else { diff --git a/c/src/neighbors/cagra.cpp b/c/src/neighbors/cagra.cpp index 6c3cbafdad..8df6e387b3 100644 --- a/c/src/neighbors/cagra.cpp +++ b/c/src/neighbors/cagra.cpp @@ -742,7 +742,7 @@ extern "C" cuvsError_t cuvsCagraDeserialize(cuvsResources_t res, if (dtype.kind == 'f' && dtype.itemsize == 4) { index->addr = reinterpret_cast(_deserialize(res, filename)); index->dtype.code = kDLFloat; - } else if (dtype.kind == 'f' && dtype.itemsize == 2) { + } else if (dtype.kind == 'e' && dtype.itemsize == 2) { index->addr = reinterpret_cast(_deserialize(res, filename)); index->dtype.code = kDLFloat; } else if (dtype.kind == 'i' && dtype.itemsize == 1) { diff --git a/c/src/neighbors/ivf_flat.cpp b/c/src/neighbors/ivf_flat.cpp index c5a249f85a..02cae8f2ae 100644 --- a/c/src/neighbors/ivf_flat.cpp +++ b/c/src/neighbors/ivf_flat.cpp @@ -318,7 +318,7 @@ extern "C" cuvsError_t cuvsIvfFlatDeserialize(cuvsResources_t res, if (dtype.kind == 'f' && dtype.itemsize == 4) { index->addr = reinterpret_cast(_deserialize(res, filename)); index->dtype.code = kDLFloat; - } else if (dtype.kind == 'f' && dtype.itemsize == 2) { + } else if (dtype.kind == 'e' && dtype.itemsize == 2) { index->addr = reinterpret_cast(_deserialize(res, filename)); index->dtype.code = kDLFloat; index->dtype.bits = 16; diff --git a/c/src/neighbors/mg_cagra.cpp b/c/src/neighbors/mg_cagra.cpp index 2695c69f8a..297fe3359f 100644 --- a/c/src/neighbors/mg_cagra.cpp +++ b/c/src/neighbors/mg_cagra.cpp @@ -419,7 +419,7 @@ extern "C" cuvsError_t cuvsMultiGpuCagraDeserialize(cuvsResources_t res, if (dtype.kind == 'f' && dtype.itemsize == 4) { index->dtype.code = kDLFloat; index->addr = reinterpret_cast(_mg_deserialize(res, filename)); - } else if (dtype.kind == 'f' && dtype.itemsize == 2) { + } else if (dtype.kind == 'e' && dtype.itemsize == 2) { index->dtype.code = kDLFloat; index->addr = reinterpret_cast(_mg_deserialize(res, filename)); } else if (dtype.kind == 'i' && dtype.itemsize == 1) { @@ -450,7 +450,7 @@ extern "C" cuvsError_t cuvsMultiGpuCagraDistribute(cuvsResources_t res, if (dtype.kind == 'f' && dtype.itemsize == 4) { index->dtype.code = kDLFloat; index->addr = reinterpret_cast(_mg_distribute(res, filename)); - } else if (dtype.kind == 'f' && dtype.itemsize == 2) { + } else if (dtype.kind == 'e' && dtype.itemsize == 2) { index->dtype.code = kDLFloat; index->addr = reinterpret_cast(_mg_distribute(res, filename)); } else if (dtype.kind == 'i' && dtype.itemsize == 1) { diff --git a/c/src/neighbors/mg_ivf_flat.cpp b/c/src/neighbors/mg_ivf_flat.cpp index 641d1395f2..87de0f6148 100644 --- a/c/src/neighbors/mg_ivf_flat.cpp +++ b/c/src/neighbors/mg_ivf_flat.cpp @@ -416,7 +416,7 @@ extern "C" cuvsError_t cuvsMultiGpuIvfFlatDeserialize(cuvsResources_t res, if (dtype.kind == 'f' && dtype.itemsize == 4) { index->dtype.code = kDLFloat; index->addr = reinterpret_cast(_mg_deserialize(res, filename)); - } else if (dtype.kind == 'f' && dtype.itemsize == 2) { + } else if (dtype.kind == 'e' && dtype.itemsize == 2) { index->dtype.code = kDLFloat; index->addr = reinterpret_cast(_mg_deserialize(res, filename)); } else if (dtype.kind == 'i' && dtype.itemsize == 1) { @@ -447,7 +447,7 @@ extern "C" cuvsError_t cuvsMultiGpuIvfFlatDistribute(cuvsResources_t res, if (dtype.kind == 'f' && dtype.itemsize == 4) { index->dtype.code = kDLFloat; index->addr = reinterpret_cast(_mg_distribute(res, filename)); - } else if (dtype.kind == 'f' && dtype.itemsize == 2) { + } else if (dtype.kind == 'e' && dtype.itemsize == 2) { index->dtype.code = kDLFloat; index->addr = reinterpret_cast(_mg_distribute(res, filename)); } else if (dtype.kind == 'i' && dtype.itemsize == 1) { diff --git a/c/src/neighbors/mg_ivf_pq.cpp b/c/src/neighbors/mg_ivf_pq.cpp index a16878299a..0015cb7685 100644 --- a/c/src/neighbors/mg_ivf_pq.cpp +++ b/c/src/neighbors/mg_ivf_pq.cpp @@ -408,7 +408,7 @@ extern "C" cuvsError_t cuvsMultiGpuIvfPqDeserialize(cuvsResources_t res, if (dtype.kind == 'f' && dtype.itemsize == 4) { index->dtype.code = kDLFloat; index->addr = reinterpret_cast(_mg_deserialize(res, filename)); - } else if (dtype.kind == 'f' && dtype.itemsize == 2) { + } else if (dtype.kind == 'e' && dtype.itemsize == 2) { index->dtype.code = kDLFloat; index->addr = reinterpret_cast(_mg_deserialize(res, filename)); } else if (dtype.kind == 'i' && dtype.itemsize == 1) { diff --git a/python/cuvs/cuvs/tests/test_brute_force.py b/python/cuvs/cuvs/tests/test_brute_force.py index 3e0cad3e94..336b337c68 100644 --- a/python/cuvs/cuvs/tests/test_brute_force.py +++ b/python/cuvs/cuvs/tests/test_brute_force.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# Copyright (c) 2022-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,8 @@ # limitations under the License. # +import tempfile + import numpy as np import pytest from pylibraft.common import device_ndarray @@ -42,8 +44,17 @@ @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("order", ["F", "C"]) @pytest.mark.parametrize("dtype", [np.float32, np.float16]) +@pytest.mark.parametrize("serialize", [True, False]) def test_brute_force_knn( - n_index_rows, n_query_rows, n_cols, k, inplace, order, metric, dtype + n_index_rows, + n_query_rows, + n_cols, + k, + inplace, + order, + metric, + dtype, + serialize, ): index = np.random.random_sample((n_index_rows, n_cols)) index = np.asarray(index, order=order).astype(dtype) @@ -68,6 +79,12 @@ def test_brute_force_knn( prefilter = filters.no_filter() brute_force_index = brute_force.build(index_device, metric) + if serialize: + with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f: + temp_filename = f.name + brute_force.save(temp_filename, brute_force_index) + brute_force_index = brute_force.load(temp_filename) + ret_distances, ret_indices = brute_force.search( brute_force_index, queries_device, diff --git a/python/cuvs/cuvs/tests/test_cagra.py b/python/cuvs/cuvs/tests/test_cagra.py index 4ec3c2a8fb..62cc65c333 100644 --- a/python/cuvs/cuvs/tests/test_cagra.py +++ b/python/cuvs/cuvs/tests/test_cagra.py @@ -13,6 +13,8 @@ # limitations under the License. # +import tempfile + import cupy as cp import numpy as np import pytest @@ -44,6 +46,7 @@ def run_cagra_build_search_test( test_extend=False, search_params={}, compression=None, + serialize=False, ): dataset = generate_data((n_rows, n_cols), dtype) if metric == "inner_product" or metric == "cosine": @@ -79,6 +82,12 @@ def run_cagra_build_search_test( else: index = cagra.build(build_params, dataset) + if serialize: + with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f: + temp_filename = f.name + cagra.save(temp_filename, index) + index = cagra.load(temp_filename) + queries = generate_data((n_queries, n_cols), dtype) out_idx = np.zeros((n_queries, k), dtype=np.uint32) out_dist = np.zeros((n_queries, k), dtype=np.float32) @@ -159,8 +168,9 @@ def run_cagra_build_search_test( @pytest.mark.parametrize("array_type", ["device"]) @pytest.mark.parametrize("build_algo", ["ivf_pq", "nn_descent"]) @pytest.mark.parametrize("metric", ["sqeuclidean", "inner_product", "cosine"]) +@pytest.mark.parametrize("serialize", [True, False]) def test_cagra_dataset_dtype_host_device( - dtype, array_type, inplace, build_algo, metric + dtype, array_type, inplace, build_algo, metric, serialize ): # Note that inner_product tests use normalized input which we cannot @@ -171,6 +181,7 @@ def test_cagra_dataset_dtype_host_device( array_type=array_type, build_algo=build_algo, metric=metric, + serialize=serialize, ) diff --git a/python/cuvs/cuvs/tests/test_ivf_flat.py b/python/cuvs/cuvs/tests/test_ivf_flat.py index ea51f3a02c..09e0efcee4 100644 --- a/python/cuvs/cuvs/tests/test_ivf_flat.py +++ b/python/cuvs/cuvs/tests/test_ivf_flat.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. +# Copyright (c) 2024-2025, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,8 @@ # limitations under the License. # +import tempfile + import numpy as np import pytest from pylibraft.common import device_ndarray @@ -38,6 +40,7 @@ def run_ivf_flat_build_search_test( compare=True, inplace=True, search_params={}, + serialize=False, ): dataset = generate_data((n_rows, n_cols), dtype) if metric == "inner_product": @@ -51,6 +54,12 @@ def run_ivf_flat_build_search_test( index = ivf_flat.build(build_params, dataset_device) + if serialize: + with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f: + temp_filename = f.name + ivf_flat.save(temp_filename, index) + index = ivf_flat.load(temp_filename) + if not add_data_on_build: dataset_1 = dataset[: n_rows // 2, :] dataset_2 = dataset[n_rows // 2 :, :] @@ -127,7 +136,8 @@ def test_ivf_flat(inplace, dtype, metric): @pytest.mark.parametrize("dtype", [np.float32, np.float16, np.int8, np.uint8]) -def test_extend(dtype): +@pytest.mark.parametrize("serialize", [True, False]) +def test_extend(dtype, serialize): run_ivf_flat_build_search_test( n_rows=10000, n_cols=10, @@ -136,6 +146,7 @@ def test_extend(dtype): metric="sqeuclidean", dtype=dtype, add_data_on_build=False, + serialize=serialize, ) diff --git a/python/cuvs/cuvs/tests/test_ivf_pq.py b/python/cuvs/cuvs/tests/test_ivf_pq.py index a7367b9736..a3eb8a9df4 100644 --- a/python/cuvs/cuvs/tests/test_ivf_pq.py +++ b/python/cuvs/cuvs/tests/test_ivf_pq.py @@ -13,6 +13,8 @@ # limitations under the License. # +import tempfile + import numpy as np import pytest from pylibraft.common import device_ndarray @@ -44,6 +46,7 @@ def run_ivf_pq_build_search_test( compare=True, inplace=True, array_type="device", + serialize=False, ): dataset = generate_data((n_rows, n_cols), dtype) if metric == "inner_product": @@ -67,6 +70,12 @@ def run_ivf_pq_build_search_test( else: index = ivf_pq.build(build_params, dataset) + if serialize: + with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f: + temp_filename = f.name + ivf_pq.save(temp_filename, index) + index = ivf_pq.load(temp_filename) + if not add_data_on_build: dataset_1 = dataset[: n_rows // 2, :] dataset_2 = dataset[n_rows // 2 :, :] @@ -216,9 +225,11 @@ def test_extend(dtype, array_type): @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("dtype", [np.float32, np.float16, np.int8, np.uint8]) @pytest.mark.parametrize("array_type", ["host", "device"]) -def test_ivf_pq_dtype(inplace, dtype, array_type): +@pytest.mark.parametrize("serialize", [True, False]) +def test_ivf_pq_dtype(inplace, dtype, array_type, serialize): run_ivf_pq_build_search_test( dtype=dtype, inplace=inplace, array_type=array_type, + serialize=serialize, ) diff --git a/python/cuvs/cuvs/tests/test_mg_cagra.py b/python/cuvs/cuvs/tests/test_mg_cagra.py index 16d40f9c17..e03b49605f 100644 --- a/python/cuvs/cuvs/tests/test_mg_cagra.py +++ b/python/cuvs/cuvs/tests/test_mg_cagra.py @@ -139,7 +139,7 @@ def run_mg_cagra_build_search_test( @requires_multiple_gpus -@pytest.mark.parametrize("dtype", [np.float32]) +@pytest.mark.parametrize("dtype", [np.float32, np.float16]) @pytest.mark.parametrize( "metric", ["sqeuclidean"] ) # Start with just sqeuclidean