forked from rapidsai/cuvs
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAlgorithmPlanner.cu
More file actions
117 lines (96 loc) · 3.54 KB
/
AlgorithmPlanner.cu
File metadata and controls
117 lines (96 loc) · 3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#include "nvjitlink_checker.hpp"
#include <chrono>
#include <iterator>
#include <memory>
#include <mutex>
#include <new>
#include <string>
#include <vector>
#include <cuvs/detail/jit_lto/AlgorithmPlanner.hpp>
#include <cuvs/detail/jit_lto/FragmentDatabase.hpp>
#include "cuda_runtime.h"
#include "nvJitLink.h"
#include <raft/core/logger.hpp>
#include <raft/util/cuda_rt_essentials.hpp>
/**
 * @brief Append the entrypoint fragment to the list of fragments to link.
 *
 * Looks up the fragment stored under `fragment_key` in the fragment database
 * and pushes it onto `fragments`.
 */
void AlgorithmPlanner::add_entrypoint()
{
  this->fragments.push_back(fragment_database().get_fragment(this->fragment_key));
}
/**
 * @brief Append one fragment per requested device function to the link list.
 *
 * Walks `device_functions` in order, resolves each key through the fragment
 * database and pushes the resulting fragment onto `fragments`.
 */
void AlgorithmPlanner::add_device_functions()
{
  const auto keys_end = this->device_functions.end();
  for (auto key_it = this->device_functions.begin(); key_it != keys_end; ++key_it) {
    this->fragments.push_back(fragment_database().get_fragment(*key_it));
  }
}
/**
 * @brief Build a single string identifying the set of device functions.
 *
 * The entries of `device_functions` are concatenated in their stored order
 * with no separator between them.
 *
 * NOTE(review): without a separator, different lists can yield the same string
 * (e.g. {"ab","c"} and {"a","bc"}) -- confirm the names in use cannot collide.
 *
 * @return The concatenation of all device function names; empty if there are none.
 */
std::string AlgorithmPlanner::get_device_functions_key() const
{
  std::string joined;
  const auto names_end = this->device_functions.end();
  for (auto name_it = this->device_functions.begin(); name_it != names_end; ++name_it) {
    joined.append(*name_it);
  }
  return joined;
}
/**
 * @brief Return the launcher for this planner, JIT-linking it on first use.
 *
 * The launcher cache returned by `get_cached_launchers()` is keyed by
 * `fragment_key` followed by `get_device_functions_key()`. On a cache miss the
 * entrypoint and device-function fragments are collected, the link is logged,
 * and the result of `build()` is stored in the cache. Cache lookup, build and
 * insertion all happen while holding a function-local static mutex.
 *
 * If collecting the fragments or building fails with an exception, any
 * fragments appended by this call are removed again before the exception is
 * rethrown, so that a retry on the same planner does not link duplicates.
 *
 * NOTE(review): the cache key does not include the CUDA device, while `build()`
 * links for the current device's architecture -- confirm this is intended for
 * systems with mixed GPU architectures.
 *
 * @return Shared pointer to the cached (or newly built) launcher.
 */
std::shared_ptr<AlgorithmLauncher> AlgorithmPlanner::get_launcher()
{
  auto& launchers       = get_cached_launchers();
  const auto launch_key = this->fragment_key + this->get_device_functions_key();

  static std::mutex cache_mutex;
  std::lock_guard<std::mutex> lock(cache_mutex);

  // Single lookup instead of count() followed by repeated operator[].
  auto cached = launchers.find(launch_key);
  if (cached != launchers.end()) { return cached->second; }

  // Remember how many fragments were present so a failed build can be undone.
  const auto fragments_before = this->fragments.size();
  try {
    add_entrypoint();
    add_device_functions();

    std::string log_message =
      "JIT compiling launcher for fragment: " + this->fragment_key + " and device functions: ";
    // Put the comma *between* names so nothing has to be trimmed afterwards;
    // an empty device-function list leaves the message prefix intact.
    const char* separator = "";
    for (const auto& device_function : this->device_functions) {
      log_message += separator;
      log_message += device_function;
      separator = ",";
    }
    RAFT_LOG_INFO("%s", log_message.c_str());

    return launchers.emplace(launch_key, this->build()).first->second;
  } catch (...) {
    // Drop the fragments appended above; otherwise a later retry would add
    // them a second time and link duplicated fragments.
    while (this->fragments.size() > fragments_before) {
      this->fragments.pop_back();
    }
    throw;
  }
}
std::shared_ptr<AlgorithmLauncher> AlgorithmPlanner::build()
{
int device = 0;
int major = 0;
int minor = 0;
RAFT_CUDA_TRY(cudaGetDevice(&device));
RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
RAFT_CUDA_TRY(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
std::string archs = "-arch=sm_" + std::to_string((major * 10 + minor));
// Load the generated LTO IR and link them together
nvJitLinkHandle handle;
const char* lopts[] = {"-lto", archs.c_str()};
auto result = nvJitLinkCreate(&handle, 2, lopts);
check_nvjitlink_result(handle, result);
for (auto& frag : this->fragments) {
frag->add_to(handle);
}
// Call to nvJitLinkComplete causes linker to link together all the LTO-IR
// modules perform any optimizations and generate cubin from it.
result = nvJitLinkComplete(handle);
check_nvjitlink_result(handle, result);
// get cubin from nvJitLink
size_t cubin_size;
result = nvJitLinkGetLinkedCubinSize(handle, &cubin_size);
check_nvjitlink_result(handle, result);
std::unique_ptr<char[]> cubin{new char[cubin_size]};
result = nvJitLinkGetLinkedCubin(handle, cubin.get());
check_nvjitlink_result(handle, result);
result = nvJitLinkDestroy(&handle);
RAFT_EXPECTS(result == NVJITLINK_SUCCESS, "nvJitLinkDestroy failed");
// cubin is linked, so now load it
cudaLibrary_t library;
RAFT_CUDA_TRY(
cudaLibraryLoadData(&library, cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0));
cudaKernel_t kernel;
RAFT_CUDA_TRY(cudaLibraryGetKernel(&kernel, library, this->entrypoint.c_str()));
return std::make_shared<AlgorithmLauncher>(kernel, library);
}