-
Notifications
You must be signed in to change notification settings - Fork 184
Expand file tree
/
Copy pathAlgorithmPlanner.cpp
More file actions
110 lines (89 loc) · 3.44 KB
/
AlgorithmPlanner.cpp
File metadata and controls
110 lines (89 loc) · 3.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#include <chrono>
#include <iterator>
#include <memory>
#include <mutex>
#include <new>
#include <shared_mutex>
#include <string>
#include <vector>
#include <cuvs/detail/jit_lto/AlgorithmPlanner.hpp>
#include <cuvs/detail/jit_lto/nvjitlink_checker.hpp>
#include "cuda_runtime.h"
#include "nvJitLink.h"
#include <raft/core/logger.hpp>
#include <raft/util/cuda_rt_essentials.hpp>
std::string AlgorithmPlanner::get_fragments_key() const
{
std::string key = "";
for (const auto& fragment : this->fragments) {
key += fragment->get_key();
}
return key;
}
std::shared_ptr<AlgorithmLauncher> AlgorithmPlanner::read_cache(std::string const& launch_key) const
{
auto& launchers = jit_cache_.launchers;
std::shared_lock<std::shared_mutex> read_lock(jit_cache_.mutex);
if (auto it = launchers.find(launch_key); it != launchers.end()) { return it->second; }
return nullptr;
}
std::shared_ptr<AlgorithmLauncher> AlgorithmPlanner::get_launcher()
{
auto& launchers = jit_cache_.launchers;
auto launch_key = this->get_fragments_key();
if (auto hit = read_cache(launch_key)) { return hit; }
std::unique_lock<std::shared_mutex> write_lock(jit_cache_.mutex);
if (auto it = launchers.find(launch_key); it != launchers.end()) { return it->second; }
std::string log_message =
"JIT compiling launcher for kernel: " + this->entrypoint + " and device functions: ";
for (const auto& fragment : this->fragments) {
log_message += std::string{fragment->get_key()} + ",";
}
log_message.pop_back();
RAFT_LOG_DEBUG("%s", log_message.c_str());
auto launcher = this->build();
launchers[launch_key] = launcher;
return launcher;
}
std::shared_ptr<AlgorithmLauncher> AlgorithmPlanner::build()
{
int device = 0;
int major = 0;
int minor = 0;
RAFT_CUDA_TRY(cudaGetDevice(&device));
RAFT_CUDA_TRY(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device));
RAFT_CUDA_TRY(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device));
std::string archs = "-arch=sm_" + std::to_string((major * 10 + minor));
// Load the generated LTO IR and link them together
nvJitLinkHandle handle;
const char* lopts[] = {"-lto", archs.c_str(), "-maxrregcount=64"};
auto result = nvJitLinkCreate(&handle, 3, lopts);
check_nvjitlink_result(handle, result);
for (const auto& frag : this->fragments) {
frag->add_to(handle);
}
// Call to nvJitLinkComplete causes linker to link together all the LTO-IR
// modules perform any optimizations and generate cubin from it.
result = nvJitLinkComplete(handle);
check_nvjitlink_result(handle, result);
// get cubin from nvJitLink
size_t cubin_size;
result = nvJitLinkGetLinkedCubinSize(handle, &cubin_size);
check_nvjitlink_result(handle, result);
std::unique_ptr<char[]> cubin{new char[cubin_size]};
result = nvJitLinkGetLinkedCubin(handle, cubin.get());
check_nvjitlink_result(handle, result);
result = nvJitLinkDestroy(&handle);
RAFT_EXPECTS(result == NVJITLINK_SUCCESS, "nvJitLinkDestroy failed");
// cubin is linked, so now load it
cudaLibrary_t library;
RAFT_CUDA_TRY(
cudaLibraryLoadData(&library, cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0));
cudaKernel_t kernel;
RAFT_CUDA_TRY(cudaLibraryGetKernel(&kernel, library, this->entrypoint.c_str()));
return std::make_shared<AlgorithmLauncher>(kernel, library);
}