Skip to content

Commit 2d542bb

Browse files
committed
Make cache entries per-thread since most CUDA libraries are not thread-safe when sharing plans/handles
1 parent a5b571e commit 2d542bb

1 file changed

Lines changed: 41 additions & 19 deletions

File tree

include/matx/core/cache.h

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include <shared_mutex>
4040
#include <unordered_map>
4141
#include <cuda/atomic>
42+
#include <pthread.h>
4243

4344
#include "matx/core/error.h"
4445

@@ -48,6 +49,24 @@ namespace detail {
4849
static constexpr size_t MAX_CUDA_DEVICES_PER_SYSTEM = 16;
4950
using CacheId = uint64_t;
5051

52+
// Common cache parameters that every cache entry needs
53+
struct CacheCommonParamsKey {
54+
int device_id;
55+
pthread_t thread_id;
56+
57+
bool operator==(const CacheCommonParamsKey& other) const {
58+
return device_id == other.device_id && pthread_equal(thread_id, other.thread_id);
59+
}
60+
};
61+
62+
struct CacheCommonParamsKeyHash {
63+
std::size_t operator()(const CacheCommonParamsKey& key) const {
64+
std::size_t h1 = std::hash<int>{}(key.device_id);
65+
std::size_t h2 = std::hash<pthread_t>{}(key.thread_id);
66+
return h1 ^ (h2 << 1);
67+
}
68+
};
69+
5170
#ifndef DOXYGEN_ONLY
5271
__attribute__ ((visibility ("default")))
5372
#endif
@@ -96,38 +115,39 @@ class matxCache_t {
96115
auto el = cache.find(id);
97116
MATX_ASSERT_STR(el != cache.end(), matxInvalidType, "Cache type not found");
98117

99-
for (int i = 0; i < static_cast<int>(MAX_CUDA_DEVICES_PER_SYSTEM); i++) {
100-
using CacheArray = cuda::std::array<CacheType, MAX_CUDA_DEVICES_PER_SYSTEM>;
101-
std::any_cast<CacheArray&>(el->second)[i].clear();
102-
}
118+
using CacheMap = std::unordered_map<CacheCommonParamsKey, CacheType, CacheCommonParamsKeyHash>;
119+
std::any_cast<CacheMap&>(el->second).clear();
103120
}
104121

105122
template <typename CacheType, typename InParams, typename MakeFun, typename ExecFun, typename Executor>
106123
void LookupAndExec(const CacheId &id, const InParams &params, const MakeFun &mfun, const ExecFun &efun, [[maybe_unused]] const Executor &exec) {
107124
// This mutex should eventually be finer-grained so each transform doesn't get blocked by others
108125
[[maybe_unused]] std::lock_guard<std::recursive_mutex> lock(cache_mtx);
109-
using CacheArray = cuda::std::array<CacheType, MAX_CUDA_DEVICES_PER_SYSTEM>;
126+
using CacheMap = std::unordered_map<CacheCommonParamsKey, CacheType, CacheCommonParamsKeyHash>;
110127

111128
// Create named cache if it doesn't exist
112-
int device_id;
129+
CacheCommonParamsKey key;
130+
key.thread_id = pthread_self();
131+
113132
auto el = cache.find(id);
114133
if (el == cache.end()) {
115-
cache[id] = CacheArray{};
134+
cache[id] = CacheMap{};
116135
}
117136

118137
auto &cval = cache[id];
119138
if constexpr (is_cuda_executor_v<Executor>) {
120-
cudaGetDevice(&device_id);
139+
cudaGetDevice(&key.device_id);
121140
}
122141
else {
123-
device_id = 0;
142+
key.device_id = 0;
124143
}
125144

126-
auto &rmap = std::any_cast<CacheArray&>(cval)[device_id];
127-
auto cache_el = rmap.find(params);
128-
if (cache_el == rmap.end()) {
145+
auto &rmap = std::any_cast<CacheMap&>(cval);
146+
auto &common_params_cache = rmap[key];
147+
auto cache_el = common_params_cache.find(params);
148+
if (cache_el == common_params_cache.end()) {
129149
std::any tmp = mfun();
130-
rmap.insert({params, tmp});
150+
common_params_cache.insert({params, tmp});
131151
efun(std::any_cast<decltype(mfun())>(tmp));
132152
}
133153
else {
@@ -137,11 +157,13 @@ class matxCache_t {
137157

138158
void* GetStreamAlloc(cudaStream_t stream, size_t size) {
139159
void *ptr = nullptr;
140-
int device_id;
141-
cudaGetDevice(&device_id);
160+
CacheCommonParamsKey key;
161+
key.thread_id = pthread_self();
162+
cudaGetDevice(&key.device_id);
142163

143-
auto el = stream_alloc_cache[device_id].find(stream);
144-
if (el == stream_alloc_cache[device_id].end()) {
164+
auto &common_params_cache = stream_alloc_cache[key];
165+
auto el = common_params_cache.find(stream);
166+
if (el == common_params_cache.end()) {
145167
StreamAllocation alloc;
146168

147169
// We allocate at least 2MB for workspace so we don't keep reallocating from small sizes
@@ -150,7 +172,7 @@ class matxCache_t {
150172

151173
alloc.size = size;
152174
alloc.ptr = ptr;
153-
stream_alloc_cache[device_id][stream] = alloc;
175+
common_params_cache[stream] = alloc;
154176
}
155177
else if (el->second.size < size) {
156178
// Free the old allocation and allocate a new one
@@ -168,7 +190,7 @@ class matxCache_t {
168190

169191
private:
170192
std::unordered_map<CacheId, std::any> cache;
171-
cuda::std::array<std::unordered_map<cudaStream_t, StreamAllocation>, MAX_CUDA_DEVICES_PER_SYSTEM> stream_alloc_cache;
193+
std::unordered_map<CacheCommonParamsKey, std::unordered_map<cudaStream_t, StreamAllocation>, CacheCommonParamsKeyHash> stream_alloc_cache;
172194
};
173195

174196
/**

0 commit comments

Comments
 (0)