|
| 1 | +/* |
| 2 | + * Copyright (c) 2025, NVIDIA CORPORATION. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +#include "../common/ann_types.hpp" |
| 18 | +#include "cuvs_ann_bench_param_parser.h" |
| 19 | +#include "cuvs_cagra_diskann_wrapper.h" |
| 20 | + |
| 21 | +#include <rmm/cuda_device.hpp> |
| 22 | +#include <rmm/mr/device/pool_memory_resource.hpp> |
| 23 | +#include <rmm/resource_ref.hpp> |
| 24 | + |
| 25 | +namespace cuvs::bench { |
| 26 | + |
| 27 | +template <typename T, typename IdxT> |
| 28 | +void parse_search_param(const nlohmann::json& conf, |
| 29 | + typename cuvs::bench::cuvs_cagra_diskann<T, IdxT>::search_param& param) |
| 30 | +{ |
| 31 | + if (conf.contains("L_search")) { param.L_search = conf.at("L_search"); } |
| 32 | +} |
| 33 | + |
| 34 | +template <typename T> |
| 35 | +auto create_algo(const std::string& algo_name, |
| 36 | + const std::string& distance, |
| 37 | + int dim, |
| 38 | + const nlohmann::json& conf) -> std::unique_ptr<cuvs::bench::algo<T>> |
| 39 | +{ |
| 40 | + [[maybe_unused]] cuvs::bench::Metric metric = parse_metric(distance); |
| 41 | + std::unique_ptr<cuvs::bench::algo<T>> a; |
| 42 | + |
| 43 | + if constexpr (std::is_same_v<T, float> || std::is_same_v<T, std::uint8_t> || |
| 44 | + std::is_same_v<T, std::int8_t>) { |
| 45 | + if (algo_name == "cuvs_cagra_diskann") { |
| 46 | + typename cuvs::bench::cuvs_cagra_diskann<T, uint32_t>::build_param param; |
| 47 | + ::parse_build_param<T, uint32_t>(conf, param); |
| 48 | + a = std::make_unique<cuvs::bench::cuvs_cagra_diskann<T, uint32_t>>(metric, dim, param); |
| 49 | + } |
| 50 | + } |
| 51 | + |
| 52 | + if (!a) { throw std::runtime_error("invalid algo: '" + algo_name + "'"); } |
| 53 | + |
| 54 | + return a; |
| 55 | +} |
| 56 | + |
| 57 | +template <typename T> |
| 58 | +auto create_search_param(const std::string& algo_name, const nlohmann::json& conf) |
| 59 | + -> std::unique_ptr<typename cuvs::bench::algo<T>::search_param> |
| 60 | +{ |
| 61 | + if (algo_name == "cuvs_cagra_diskann") { |
| 62 | + auto param = |
| 63 | + std::make_unique<typename cuvs::bench::cuvs_cagra_diskann<T, uint32_t>::search_param>(); |
| 64 | + parse_search_param<T, uint32_t>(conf, *param); |
| 65 | + return param; |
| 66 | + } |
| 67 | + |
| 68 | + throw std::runtime_error("invalid algo: '" + algo_name + "'"); |
| 69 | +} |
| 70 | + |
| 71 | +} // namespace cuvs::bench |
| 72 | + |
| 73 | +REGISTER_ALGO_INSTANCE(float); |
| 74 | +REGISTER_ALGO_INSTANCE(std::int8_t); |
| 75 | +REGISTER_ALGO_INSTANCE(std::uint8_t); |
| 76 | + |
| 77 | +#ifdef ANN_BENCH_BUILD_MAIN |
| 78 | +#include "../common/benchmark.hpp" |
| 79 | +/* |
| 80 | +[NOTE] Dear developer, |
| 81 | +
|
| 82 | +Please don't modify the content of the `main` function; this will make the behavior of the benchmark |
| 83 | +executable differ depending on the cmake flags and will complicate the debugging. In particular, |
| 84 | +don't try to setup an RMM memory resource here; it will anyway be modified by the memory resource |
| 85 | +set on per-algorithm basis. For example, see `cuvs/cuvs_ann_bench_utils.h`. |
| 86 | +*/ |
| 87 | +int main(int argc, char** argv) { return cuvs::bench::run_main(argc, argv); } |
| 88 | +#endif |
0 commit comments