Merged
1 change: 1 addition & 0 deletions cpp/cmake/modules/generate_jit_lto_kernels.cmake
@@ -104,6 +104,7 @@ function(process_matrix_entry source_list_var)
EMBEDDED_HEADER_FILE "${embedded_header_file}"
LINK_LIBRARIES ${_JIT_LTO_KERNEL_LINK_LIBRARIES}
)

list(APPEND ${source_list_var} "${embedded_header_file}" "${embedded_file}")
set(${source_list_var}
"${${source_list_var}}"
7 changes: 6 additions & 1 deletion cpp/include/cuvs/detail/jit_lto/AlgorithmPlanner.hpp
@@ -6,17 +6,22 @@
#pragma once

#include <string>
#include <utility>
#include <vector>

#include "AlgorithmLauncher.hpp"

struct FragmentEntry;

struct AlgorithmPlanner {
AlgorithmPlanner(std::string const& n, std::string const& p) : entrypoint(n + "_" + p) {}
AlgorithmPlanner(std::string&& fragment_key, std::string&& entrypoint)
: fragment_key(std::move(fragment_key)), entrypoint(std::move(entrypoint))
{
}

std::shared_ptr<AlgorithmLauncher> get_launcher();

std::string fragment_key;
std::string entrypoint;
std::vector<std::string> device_functions;
std::vector<FragmentEntry*> fragments;
4 changes: 2 additions & 2 deletions cpp/include/cuvs/detail/jit_lto/FragmentDatabase.hpp
@@ -22,6 +22,8 @@ class FragmentDatabase {

FragmentEntry* get_fragment(std::string const& key);

std::unordered_map<std::string, std::unique_ptr<FragmentEntry>> cache;

private:
FragmentDatabase();

@@ -33,8 +35,6 @@
std::string const& params,
unsigned char const* blob,
std::size_t size);

std::unordered_map<std::string, std::unique_ptr<FragmentEntry>> cache;
};

FragmentDatabase& fragment_database();
@@ -22,6 +22,45 @@ struct tag_acc_ui {};
// Tag types for index types
struct tag_idx_l {};

template <typename T>
struct tag_abbrev;
Member:
The fragment key is now assembled in C++ from tag_abbrev<>, while the generated embedded file key comes from CMake/JSON via @kernel_name@, no?

If so, that'd mean the planner side and generator side must stay perfectly synchronized across:

  • interleaved_scan_tags.hpp
  • interleaved_scan_planner.hpp
  • the CMake NAME_FORMAT
  • the JSON matrix abbrevs

If any one of those drifts, the failure mode might be late and opaque.

I'd strongly suggest either deriving both names from one source, or adding a smoke test that exercises one generated interleaved scan kernel through the full registration and cudaLibraryGetKernel path.

This is not a blocking comment; if I'm correct, it can be addressed as a follow-up, so I'd just request opening an issue to track it.

Member:

I don't know if it's possible to derive those from one source. CMake NAME_FORMAT is done at configure/build time, while the C++ implementation is done at runtime.

The way we're doing things here is not new, and any time the naming conventions have drifted, it's failed very loudly due to failure of either nvjitlink or the fragment database to find the appropriate fragment.

Member (Author):

We could at least inject the NAME_FORMAT string, with placeholders, into the Planner class, so that at runtime developers can just substitute the placeholders instead of constructing the string piece by piece?

Member:

Wouldn't we then have to generate a whole matrix of files that instantiate Planner for each possible combination?

Member (Author):

We wouldn't substitute the real types/values at build time, only the string with placeholders for them, so developers don't have to know how to construct and match the string.

Member:

How would that work? There are hundreds of possible strings. How do you substitute in @kernel_name@ without generating another matrix of hundreds of files?

Member (KyleFromNVIDIA, Mar 16, 2026):

Oh, I see. You're thinking of having it look like:

this->set_name_format("some_kernel_@param1@_@param2@");

and then doing the substitution of @param1@ and @param2@ at runtime.

Yes, that's a good idea. We should do that in a follow-up.

Member (Author):

Yes that's exactly what I was thinking. Follow-up is perfect 👍 will merge this PR now

template <>
struct tag_abbrev<tag_f> {
static constexpr char const* value = "f";
};
template <>
struct tag_abbrev<tag_h> {
static constexpr char const* value = "h";
};
template <>
struct tag_abbrev<tag_sc> {
static constexpr char const* value = "sc";
};
template <>
struct tag_abbrev<tag_uc> {
static constexpr char const* value = "uc";
};
template <>
struct tag_abbrev<tag_acc_f> {
static constexpr char const* value = "f";
};
template <>
struct tag_abbrev<tag_acc_h> {
static constexpr char const* value = "h";
};
template <>
struct tag_abbrev<tag_acc_i> {
static constexpr char const* value = "i";
};
template <>
struct tag_abbrev<tag_acc_ui> {
static constexpr char const* value = "ui";
};
template <>
struct tag_abbrev<tag_idx_l> {
static constexpr char const* value = "l";
};

// Tag types for filter subtypes
struct tag_filter_bitset_impl {};
struct tag_filter_none_impl {};
44 changes: 6 additions & 38 deletions cpp/src/detail/jit_lto/AlgorithmPlanner.cu
@@ -24,7 +24,7 @@

void AlgorithmPlanner::add_entrypoint()
{
auto entrypoint_fragment = fragment_database().get_fragment(this->entrypoint);
auto entrypoint_fragment = fragment_database().get_fragment(this->fragment_key);
this->fragments.push_back(entrypoint_fragment);
}

@@ -48,15 +48,15 @@ std::string AlgorithmPlanner::get_device_functions_key() const
std::shared_ptr<AlgorithmLauncher> AlgorithmPlanner::get_launcher()
{
auto& launchers = get_cached_launchers();
auto launch_key = this->entrypoint + this->get_device_functions_key();
auto launch_key = this->fragment_key + this->get_device_functions_key();

static std::mutex cache_mutex;
std::lock_guard<std::mutex> lock(cache_mutex);
if (launchers.count(launch_key) == 0) {
add_entrypoint();
add_device_functions();
std::string log_message =
"JIT compiling launcher for entrypoint: " + this->entrypoint + " and device functions: ";
"JIT compiling launcher for fragment: " + this->fragment_key + " and device functions: ";
for (const auto& device_function : this->device_functions) {
log_message += device_function + ",";
}
@@ -110,40 +110,8 @@ std::shared_ptr<AlgorithmLauncher> AlgorithmPlanner::build()
RAFT_CUDA_TRY(
cudaLibraryLoadData(&library, cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0));

unsigned int kernel_count = 0;
RAFT_CUDA_TRY(cudaLibraryGetKernelCount(&kernel_count, library));

// NOTE: cudaKernel_t does not need to be freed explicitly
std::unique_ptr<cudaKernel_t[]> kernels{new cudaKernel_t[kernel_count]};
RAFT_CUDA_TRY(cudaLibraryEnumerateKernels(kernels.get(), kernel_count, library));

// Filter out EmptyKernel by checking kernel names using cudaFuncGetName
const char* empty_kernel_name = "_ZN3cub6detail11EmptyKernelIvEEvv";
std::vector<cudaKernel_t> valid_kernels;
valid_kernels.reserve(kernel_count);

for (unsigned int i = 0; i < kernel_count; ++i) {
// cudaFuncGetName can be used with cudaKernel_t by casting to void*
const void* func_ptr = reinterpret_cast<const void*>(kernels[i]);
const char* func_name = nullptr;
RAFT_CUDA_TRY(cudaFuncGetName(&func_name, func_ptr));

bool is_empty_kernel = false;
if (func_name != nullptr) {
std::string kernel_name(func_name);
// Check if this is EmptyKernel
if (kernel_name.find(empty_kernel_name) != std::string::npos ||
kernel_name == empty_kernel_name) {
is_empty_kernel = true;
}
}

// Only keep the kernel if it's not EmptyKernel
if (!is_empty_kernel) { valid_kernels.push_back(kernels[i]); }
}

RAFT_EXPECTS(
valid_kernels.size() == 1, "Expected 1 valid JIT kernel, got %zu", valid_kernels.size());
cudaKernel_t kernel;
RAFT_CUDA_TRY(cudaLibraryGetKernel(&kernel, library, this->entrypoint.c_str()));

return std::make_shared<AlgorithmLauncher>(valid_kernels[0], library);
return std::make_shared<AlgorithmLauncher>(kernel, library);
}
@@ -6,21 +6,13 @@
// This file is auto-generated. Do not edit manually.

#include <cuvs/detail/jit_lto/RegisterKernelFragment.hpp>
#include <cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp>
#include "@embedded_header_file@"

using namespace cuvs::neighbors::ivf_flat::detail;

namespace {

__attribute__((__constructor__)) void register_kernel()
{
registerAlgorithm<tag_@type_abbrev@,
tag_acc_@acc_abbrev@,
tag_idx_@idx_abbrev@>(
"interleaved_scan_kernel_capacity_@capacity@_veclen_@veclen@_@ascending_descending@_@compute_norm_name@",
embedded_fatbin,
sizeof(embedded_fatbin));
registerAlgorithm("@kernel_name@", embedded_fatbin, sizeof(embedded_fatbin));
Member:

One thing to note is that the kernel_name variable is currently an implementation detail of process_matrix_entry(). We could document in generate_jit_lto_kernels() that this variable is set and available for use inside the source file.

}

}
@@ -9,10 +9,29 @@

namespace cuvs::neighbors::ivf_flat::detail {

// Instantiate the kernel template
template __global__ void interleaved_scan_kernel<@capacity@, @veclen@, @ascending_value@, @compute_norm_value@, @data_type@, @acc_type@, @idx_type@>(
const uint32_t, const @data_type@*, const uint32_t*, const @data_type@* const*, const uint32_t*,
const uint32_t, const uint32_t, const uint32_t, const uint32_t, const uint32_t*, const uint32_t,
@idx_type@* const* const, uint32_t*, @idx_type@, @idx_type@, uint32_t*, float*);
extern "C" __global__ __launch_bounds__(kThreadsPerBlock) void interleaved_scan(
const uint32_t query_smem_elems,
const @data_type@* query,
const uint32_t* coarse_index,
const @data_type@* const* list_data_ptrs,
const uint32_t* list_sizes,
const uint32_t queries_offset,
const uint32_t n_probes,
const uint32_t k,
const uint32_t max_samples,
const uint32_t* chunk_indices,
const uint32_t dim,
@idx_type@* const* const inds_ptrs,
uint32_t* bitset_ptr,
@idx_type@ bitset_len,
@idx_type@ original_nbits,
uint32_t* neighbors,
float* distances)
{
interleaved_scan_kernel_impl<@capacity@, @veclen@, @ascending_value@, @compute_norm_value@, @data_type@, @acc_type@, @idx_type@>(
query_smem_elems, query, coarse_index, list_data_ptrs, list_sizes, queries_offset, n_probes,
k, max_samples, chunk_indices, dim, inds_ptrs, bitset_ptr, bitset_len, original_nbits,
neighbors, distances);
}

} // namespace cuvs::neighbors::ivf_flat::detail
@@ -8,16 +8,21 @@
#include <cuvs/detail/jit_lto/AlgorithmPlanner.hpp>
#include <cuvs/detail/jit_lto/FragmentDatabase.hpp>
#include <cuvs/detail/jit_lto/MakeFragmentKey.hpp>
#include <cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp>
#include <iostream>
#include <string>

template <typename... Args>
template <typename DataTypeTag, typename AccTypeTag, typename IdxTypeTag>
struct InterleavedScanPlanner : AlgorithmPlanner {
InterleavedScanPlanner(int Capacity, int Veclen, bool Ascending, bool ComputeNorm)
: AlgorithmPlanner("interleaved_scan_kernel_capacity_" + std::to_string(Capacity) + "_veclen_" +
: AlgorithmPlanner("interleaved_scan_capacity_" + std::to_string(Capacity) + "_veclen_" +
std::to_string(Veclen) + "_" + (Ascending ? "ascending" : "descending") +
"_" + (ComputeNorm ? "compute_norm" : "no_compute_norm"),
make_fragment_key<Args...>())
"_" + (ComputeNorm ? "compute_norm" : "no_compute_norm") + "_data_" +
cuvs::neighbors::ivf_flat::detail::tag_abbrev<DataTypeTag>::value +
"_acc_" +
cuvs::neighbors::ivf_flat::detail::tag_abbrev<AccTypeTag>::value +
"_idx_" + cuvs::neighbors::ivf_flat::detail::tag_abbrev<IdxTypeTag>::value,
"interleaved_scan")
{
}

@@ -764,24 +764,23 @@ template <int Capacity,
typename T,
typename AccT,
typename IdxT>
RAFT_KERNEL __launch_bounds__(kThreadsPerBlock)
interleaved_scan_kernel(const uint32_t query_smem_elems,
const T* query,
const uint32_t* coarse_index,
const T* const* list_data_ptrs,
const uint32_t* list_sizes,
const uint32_t queries_offset,
const uint32_t n_probes,
const uint32_t k,
const uint32_t max_samples,
const uint32_t* chunk_indices,
const uint32_t dim,
IdxT* const* const inds_ptrs,
uint32_t* bitset_ptr,
IdxT bitset_len,
IdxT original_nbits,
uint32_t* neighbors,
float* distances)
__device__ __forceinline__ void interleaved_scan_kernel_impl(const uint32_t query_smem_elems,
const T* query,
const uint32_t* coarse_index,
const T* const* list_data_ptrs,
const uint32_t* list_sizes,
const uint32_t queries_offset,
const uint32_t n_probes,
const uint32_t k,
const uint32_t max_samples,
const uint32_t* chunk_indices,
const uint32_t dim,
IdxT* const* const inds_ptrs,
uint32_t* bitset_ptr,
IdxT bitset_len,
IdxT original_nbits,
uint32_t* neighbors,
float* distances)
{
extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[];
constexpr bool kManageLocalTopK = Capacity > 0;