Skip to content

Commit 4cecf38

Browse files
KyleFromNVIDIA authored and lowener committed
Remove JIT+LTO fragment database (rapidsai#1927)
Rather than register each fragment in a runtime class with a string key, "register" them with the linker using template specialization. This solves a number of problems:

1. It simplifies the code by removing the `FragmentDatabase` class.
2. It addresses rapidsai#1909 (comment) by bypassing the issue entirely. There is no longer a need to build the fragment name string at runtime.
3. For clients that use the `cuvs_static` static library, it allows the linker to pick and choose which fragment symbols it needs rather than including all of them with every client just in case any of them are needed.
4. Since there is no longer a need for `$<WHOLE_ARCHIVE:...>` linkage, there is no need for the `cuvs_jit_lto_kernels` target at all, thus simplifying the CMake code too.

Authors:
- Kyle Edwards (https://github.com/KyleFromNVIDIA)

Approvers:
- Divye Gala (https://github.com/divyegala)

URL: rapidsai#1927
1 parent ed011a4 commit 4cecf38

23 files changed

Lines changed: 196 additions & 333 deletions

cpp/CMakeLists.txt

Lines changed: 21 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,7 @@ if(NOT BUILD_CPU_ONLY)
357357

358358
set(JIT_LTO_TARGET_ARCHITECTURE "")
359359
set(JIT_LTO_COMPILATION OFF)
360+
set(jit_lto_files)
360361
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
361362
set(JIT_LTO_TARGET_ARCHITECTURE "75-real")
362363
set(JIT_LTO_COMPILATION ON)
@@ -434,28 +435,16 @@ if(NOT BUILD_CPU_ONLY)
434435
)
435436
endblock()
436437

437-
add_library(
438-
cuvs_jit_lto_kernels STATIC
439-
${interleaved_scan_files}
440-
${metric_files}
441-
${filter_files}
442-
${post_lambda_files}
443-
src/detail/jit_lto/AlgorithmLauncher.cpp
444-
src/detail/jit_lto/AlgorithmPlanner.cpp
445-
src/detail/jit_lto/FragmentDatabase.cpp
446-
src/detail/jit_lto/FragmentEntry.cpp
447-
src/detail/jit_lto/nvjitlink_checker.cpp
438+
set(jit_lto_files
439+
${interleaved_scan_files}
440+
${metric_files}
441+
${filter_files}
442+
${post_lambda_files}
443+
src/detail/jit_lto/AlgorithmLauncher.cpp
444+
src/detail/jit_lto/AlgorithmPlanner.cpp
445+
src/detail/jit_lto/FragmentEntry.cpp
446+
src/detail/jit_lto/nvjitlink_checker.cpp
448447
)
449-
set_target_properties(
450-
cuvs_jit_lto_kernels PROPERTIES POSITION_INDEPENDENT_CODE ON CXX_STANDARD 20
451-
)
452-
target_include_directories(
453-
cuvs_jit_lto_kernels
454-
PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/include" "${CMAKE_CURRENT_SOURCE_DIR}/src"
455-
"${CMAKE_CURRENT_SOURCE_DIR}/../c/include"
456-
)
457-
target_link_libraries(cuvs_jit_lto_kernels PRIVATE raft::raft)
458-
add_library(cuvs::cuvs_jit_lto_kernels ALIAS cuvs_jit_lto_kernels)
459448
endif()
460449

461450
add_library(
@@ -666,6 +655,7 @@ if(NOT BUILD_CPU_ONLY)
666655
src/stats/silhouette_score.cu
667656
src/stats/trustworthiness_score.cu
668657
${CUVS_MG_ALGOS}
658+
${jit_lto_files}
669659
)
670660

671661
set_target_properties(
@@ -777,12 +767,8 @@ if(NOT BUILD_CPU_ONLY)
777767
$<BUILD_LOCAL_INTERFACE:$<TARGET_NAME_IF_EXISTS:NCCL::NCCL>>
778768
$<BUILD_LOCAL_INTERFACE:$<TARGET_NAME_IF_EXISTS:hnswlib::hnswlib>>
779769
$<$<BOOL:${CUVS_NVTX}>:CUDA::nvtx3>
780-
PRIVATE
781-
$<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>
782-
$<COMPILE_ONLY:nvidia::cutlass::cutlass>
783-
$<COMPILE_ONLY:cuco::cuco>
784-
$<$<BOOL:${JIT_LTO_COMPILATION}>:CUDA::nvJitLink>
785-
$<$<BOOL:${JIT_LTO_COMPILATION}>:$<LINK_LIBRARY:WHOLE_ARCHIVE,cuvs::cuvs_jit_lto_kernels>>
770+
PRIVATE $<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX> $<COMPILE_ONLY:nvidia::cutlass::cutlass>
771+
$<COMPILE_ONLY:cuco::cuco> $<$<BOOL:${JIT_LTO_COMPILATION}>:CUDA::nvJitLink>
786772
)
787773

788774
# ensure CUDA symbols aren't relocated to the middle of the debug build binaries
@@ -839,13 +825,11 @@ SECTIONS
839825
${CUVS_CTK_MATH_DEPENDENCIES}
840826
$<TARGET_NAME_IF_EXISTS:NCCL::NCCL> # needs to be public for DT_NEEDED
841827
$<BUILD_LOCAL_INTERFACE:$<TARGET_NAME_IF_EXISTS:hnswlib::hnswlib>> # header only
842-
PRIVATE
843-
$<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>
844-
$<$<BOOL:${JIT_LTO_COMPILATION}>:CUDA::nvJitLink>
845-
$<$<BOOL:${CUVS_NVTX}>:CUDA::nvtx3>
846-
$<COMPILE_ONLY:nvidia::cutlass::cutlass>
847-
$<COMPILE_ONLY:cuco::cuco>
848-
$<$<BOOL:${JIT_LTO_COMPILATION}>:$<LINK_LIBRARY:WHOLE_ARCHIVE,cuvs::cuvs_jit_lto_kernels>>
828+
PRIVATE $<TARGET_NAME_IF_EXISTS:OpenMP::OpenMP_CXX>
829+
$<$<BOOL:${JIT_LTO_COMPILATION}>:CUDA::nvJitLink>
830+
$<$<BOOL:${CUVS_NVTX}>:CUDA::nvtx3>
831+
$<COMPILE_ONLY:nvidia::cutlass::cutlass>
832+
$<COMPILE_ONLY:cuco::cuco>
849833
)
850834
endif()
851835

@@ -886,11 +870,9 @@ target_compile_definitions(cuvs::cuvs INTERFACE $<$<BOOL:${CUVS_NVTX}>:NVTX_ENAB
886870
include(GNUInstallDirs)
887871
include(CPack)
888872

889-
set(target_names cuvs cuvs_static cuvs_jit_lto_kernels cuvs_cpp_headers cuvs_c)
890-
set(component_names cuvs_shared cuvs_static cuvs_static cuvs_cpp_headers c_api)
891-
set(export_names cuvs-shared-exports cuvs-static-exports cuvs-static-exports
892-
cuvs-cpp-headers-exports cuvs-c-exports
893-
)
873+
set(target_names cuvs cuvs_static cuvs_cpp_headers cuvs_c)
874+
set(component_names cuvs_shared cuvs_static cuvs_cpp_headers c_api)
875+
set(export_names cuvs-shared-exports cuvs-static-exports cuvs-cpp-headers-exports cuvs-c-exports)
894876
foreach(target component export IN ZIP_LISTS target_names component_names export_names)
895877
if(TARGET ${target})
896878
install(

cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,22 @@
1414
struct FragmentEntry;
1515

1616
struct AlgorithmPlanner {
17-
AlgorithmPlanner(std::string fragment_key, std::string entrypoint)
18-
: fragment_key(std::move(fragment_key)), entrypoint(std::move(entrypoint))
19-
{
20-
}
17+
AlgorithmPlanner(std::string entrypoint) : entrypoint(std::move(entrypoint)) {}
2118

2219
std::shared_ptr<AlgorithmLauncher> get_launcher();
2320

24-
std::string fragment_key;
2521
std::string entrypoint;
26-
std::vector<std::string> device_functions;
27-
std::vector<FragmentEntry*> fragments;
22+
std::vector<const FragmentEntry*> fragments;
23+
24+
void add_fragment(const FragmentEntry& fragment);
25+
26+
template <typename FragmentT>
27+
void add_fragment()
28+
{
29+
add_fragment(FragmentT{});
30+
}
2831

2932
private:
30-
void add_entrypoint();
31-
void add_device_functions();
32-
std::string get_device_functions_key() const;
33+
std::string get_fragments_key() const;
3334
std::shared_ptr<AlgorithmLauncher> build();
3435
};

cpp/include/cuvs/detail/jit_lto/FragmentDatabase.hpp

Lines changed: 0 additions & 45 deletions
This file was deleted.

cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,27 @@
1212

1313
#include <nvJitLink.h>
1414

15+
#include "nvjitlink_checker.hpp"
16+
1517
struct FragmentEntry {
16-
FragmentEntry(std::string const& key);
18+
virtual bool add_to(nvJitLinkHandle& handle) const = 0;
1719

18-
bool operator==(const FragmentEntry& rhs) const { return compute_key == rhs.compute_key; }
20+
virtual const char* get_key() const = 0;
21+
};
1922

20-
virtual bool add_to(nvJitLinkHandle& handle) const = 0;
23+
struct FatbinFragmentEntry : FragmentEntry {
24+
virtual const uint8_t* get_data() const = 0;
25+
26+
virtual size_t get_length() const = 0;
2127

22-
std::string compute_key{};
28+
bool add_to(nvJitLinkHandle& handle) const override final;
2329
};
2430

25-
struct FatbinFragmentEntry final : FragmentEntry {
26-
FatbinFragmentEntry(std::string const& key, unsigned char const* view, std::size_t size);
31+
template <typename FragmentT>
32+
struct StaticFatbinFragmentEntry : FatbinFragmentEntry {
33+
const uint8_t* get_data() const override final { return FragmentT::data; }
2734

28-
virtual bool add_to(nvJitLinkHandle& handle) const;
35+
size_t get_length() const override final { return FragmentT::length; }
2936

30-
std::size_t data_size = 0;
31-
unsigned char const* data_view = nullptr;
37+
const char* get_key() const override final { return typeid(FragmentT).name(); }
3238
};

cpp/include/cuvs/detail/jit_lto/RegisterKernelFragment.hpp

Lines changed: 0 additions & 24 deletions
This file was deleted.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
#pragma once
7+
8+
#include <cuvs/detail/jit_lto/FragmentEntry.hpp>
9+
10+
namespace cuvs::neighbors::ivf_flat::detail {
11+
12+
template <typename DataTag,
13+
typename AccTag,
14+
typename IdxTag,
15+
int Capacity,
16+
int Veclen,
17+
bool Ascending,
18+
bool ComputeNorm>
19+
struct InterleavedScanFragmentEntry final
20+
: StaticFatbinFragmentEntry<InterleavedScanFragmentEntry<DataTag,
21+
AccTag,
22+
IdxTag,
23+
Capacity,
24+
Veclen,
25+
Ascending,
26+
ComputeNorm>> {
27+
static const uint8_t* const data;
28+
static const size_t length;
29+
};
30+
31+
template <int Veclen, typename DataTag, typename AccTag, typename MetricTag>
32+
struct MetricFragmentEntry final
33+
: StaticFatbinFragmentEntry<MetricFragmentEntry<Veclen, DataTag, AccTag, MetricTag>> {
34+
static const uint8_t* const data;
35+
static const size_t length;
36+
};
37+
38+
template <typename IvfSampleFilterTag>
39+
struct FilterFragmentEntry final
40+
: StaticFatbinFragmentEntry<FilterFragmentEntry<IvfSampleFilterTag>> {
41+
static const uint8_t* const data;
42+
static const size_t length;
43+
};
44+
45+
template <typename PostLambdaTag>
46+
struct PostLambdaFragmentEntry final
47+
: StaticFatbinFragmentEntry<PostLambdaFragmentEntry<PostLambdaTag>> {
48+
static const uint8_t* const data;
49+
static const size_t length;
50+
};
51+
52+
} // namespace cuvs::neighbors::ivf_flat::detail

cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,7 @@ template <typename IdxTag, typename FilterImplTag>
7070
struct tag_filter {};
7171

7272
// Tag types for distance metrics with full template info
73-
template <int Veclen, typename TTag, typename AccTTag>
7473
struct tag_metric_euclidean {};
75-
76-
template <int Veclen, typename TTag, typename AccTTag>
7774
struct tag_metric_inner_product {};
7875

7976
// Tag types for post-processing
File renamed without changes.

cpp/src/detail/jit_lto/AlgorithmPlanner.cpp

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
#include "nvjitlink_checker.hpp"
7-
86
#include <chrono>
97
#include <iterator>
108
#include <memory>
@@ -14,51 +12,41 @@
1412
#include <vector>
1513

1614
#include <cuvs/detail/jit_lto/AlgorithmPlanner.hpp>
17-
#include <cuvs/detail/jit_lto/FragmentDatabase.hpp>
15+
#include <cuvs/detail/jit_lto/FragmentEntry.hpp>
16+
#include <cuvs/detail/jit_lto/nvjitlink_checker.hpp>
1817

1918
#include "cuda_runtime.h"
2019
#include "nvJitLink.h"
2120

2221
#include <raft/core/logger.hpp>
2322
#include <raft/util/cuda_rt_essentials.hpp>
2423

25-
void AlgorithmPlanner::add_entrypoint()
24+
void AlgorithmPlanner::add_fragment(const FragmentEntry& fragment)
2625
{
27-
auto entrypoint_fragment = fragment_database().get_fragment(this->fragment_key);
28-
this->fragments.push_back(entrypoint_fragment);
29-
}
30-
31-
void AlgorithmPlanner::add_device_functions()
32-
{
33-
for (const auto& device_function_key : this->device_functions) {
34-
auto device_function_fragment = fragment_database().get_fragment(device_function_key);
35-
this->fragments.push_back(device_function_fragment);
36-
}
26+
fragments.push_back(&fragment);
3727
}
3828

39-
std::string AlgorithmPlanner::get_device_functions_key() const
29+
std::string AlgorithmPlanner::get_fragments_key() const
4030
{
4131
std::string key = "";
42-
for (const auto& device_function : this->device_functions) {
43-
key += device_function;
32+
for (const auto* fragment : this->fragments) {
33+
key += fragment->get_key();
4434
}
4535
return key;
4636
}
4737

4838
std::shared_ptr<AlgorithmLauncher> AlgorithmPlanner::get_launcher()
4939
{
5040
auto& launchers = get_cached_launchers();
51-
auto launch_key = this->fragment_key + this->get_device_functions_key();
41+
auto launch_key = this->get_fragments_key();
5242

5343
static std::mutex cache_mutex;
5444
std::lock_guard<std::mutex> lock(cache_mutex);
5545
if (launchers.count(launch_key) == 0) {
56-
add_entrypoint();
57-
add_device_functions();
5846
std::string log_message =
59-
"JIT compiling launcher for kernel: " + this->fragment_key + " and device functions: ";
60-
for (const auto& device_function : this->device_functions) {
61-
log_message += device_function + ",";
47+
"JIT compiling launcher for kernel: " + this->entrypoint + " and device functions: ";
48+
for (const auto* fragment : this->fragments) {
49+
log_message += std::string{fragment->get_key()} + ",";
6250
}
6351
log_message.pop_back();
6452
RAFT_LOG_INFO("%s", log_message.c_str());

0 commit comments

Comments
 (0)