3939#include < shared_mutex>
4040#include < unordered_map>
4141#include < cuda/atomic>
42+ #include < pthread.h>
4243
4344#include " matx/core/error.h"
4445
@@ -48,6 +49,24 @@ namespace detail {
4849static constexpr size_t MAX_CUDA_DEVICES_PER_SYSTEM = 16 ;
4950using CacheId = uint64_t ;
5051
52+ // Common cache parameters that every cache entry needs
53+ struct CacheCommonParamsKey {
54+ int device_id;
55+ pthread_t thread_id;
56+
57+ bool operator ==(const CacheCommonParamsKey& other) const {
58+ return device_id == other.device_id && pthread_equal (thread_id, other.thread_id );
59+ }
60+ };
61+
62+ struct CacheCommonParamsKeyHash {
63+ std::size_t operator ()(const CacheCommonParamsKey& key) const {
64+ std::size_t h1 = std::hash<int >{}(key.device_id );
65+ std::size_t h2 = std::hash<pthread_t >{}(key.thread_id );
66+ return h1 ^ (h2 << 1 );
67+ }
68+ };
69+
5170#ifndef DOXYGEN_ONLY
5271__attribute__ ((visibility (" default" )))
5372#endif
@@ -96,38 +115,39 @@ class matxCache_t {
96115 auto el = cache.find (id);
97116 MATX_ASSERT_STR (el != cache.end (), matxInvalidType, " Cache type not found" );
98117
99- for (int i = 0 ; i < static_cast <int >(MAX_CUDA_DEVICES_PER_SYSTEM ); i++) {
100- using CacheArray = cuda::std::array<CacheType, MAX_CUDA_DEVICES_PER_SYSTEM >;
101- std::any_cast<CacheArray&>(el->second )[i].clear ();
102- }
118+ using CacheMap = std::unordered_map<CacheCommonParamsKey, CacheType, CacheCommonParamsKeyHash>;
119+ std::any_cast<CacheMap&>(el->second ).clear ();
103120 }
104121
105122 template <typename CacheType, typename InParams, typename MakeFun, typename ExecFun, typename Executor>
106123 void LookupAndExec (const CacheId &id, const InParams ¶ms, const MakeFun &mfun, const ExecFun &efun, [[maybe_unused]] const Executor &exec) {
107124 // This mutex should eventually be finer-grained so each transform doesn't get blocked by others
108125 [[maybe_unused]] std::lock_guard<std::recursive_mutex> lock (cache_mtx);
109- using CacheArray = cuda:: std::array< CacheType, MAX_CUDA_DEVICES_PER_SYSTEM >;
126+ using CacheMap = std::unordered_map<CacheCommonParamsKey, CacheType, CacheCommonParamsKeyHash >;
110127
111128 // Create named cache if it doesn't exist
112- int device_id;
129+ CacheCommonParamsKey key;
130+ key.thread_id = pthread_self ();
131+
113132 auto el = cache.find (id);
114133 if (el == cache.end ()) {
115- cache[id] = CacheArray {};
134+ cache[id] = CacheMap {};
116135 }
117136
118137 auto &cval = cache[id];
119138 if constexpr (is_cuda_executor_v<Executor>) {
120- cudaGetDevice (&device_id);
139+ cudaGetDevice (&key. device_id );
121140 }
122141 else {
123- device_id = 0 ;
142+ key. device_id = 0 ;
124143 }
125144
126- auto &rmap = std::any_cast<CacheArray&>(cval)[device_id];
127- auto cache_el = rmap.find (params);
128- if (cache_el == rmap.end ()) {
145+ auto &rmap = std::any_cast<CacheMap&>(cval);
146+ auto &common_params_cache = rmap[key];
147+ auto cache_el = common_params_cache.find (params);
148+ if (cache_el == common_params_cache.end ()) {
129149 std::any tmp = mfun ();
130- rmap .insert ({params, tmp});
150+ common_params_cache .insert ({params, tmp});
131151 efun (std::any_cast<decltype (mfun ())>(tmp));
132152 }
133153 else {
@@ -137,11 +157,13 @@ class matxCache_t {
137157
138158 void * GetStreamAlloc (cudaStream_t stream, size_t size) {
139159 void *ptr = nullptr ;
140- int device_id;
141- cudaGetDevice (&device_id);
160+ CacheCommonParamsKey key;
161+ key.thread_id = pthread_self ();
162+ cudaGetDevice (&key.device_id );
142163
143- auto el = stream_alloc_cache[device_id].find (stream);
144- if (el == stream_alloc_cache[device_id].end ()) {
164+ auto &common_params_cache = stream_alloc_cache[key];
165+ auto el = common_params_cache.find (stream);
166+ if (el == common_params_cache.end ()) {
145167 StreamAllocation alloc;
146168
147169 // We allocate at least 2MB for workspace so we don't keep reallocating from small sizes
@@ -150,7 +172,7 @@ class matxCache_t {
150172
151173 alloc.size = size;
152174 alloc.ptr = ptr;
153- stream_alloc_cache[device_id] [stream] = alloc;
175+ common_params_cache [stream] = alloc;
154176 }
155177 else if (el->second .size < size) {
156178 // Free the old allocation and allocate a new one
@@ -168,7 +190,7 @@ class matxCache_t {
168190
169191private:
170192 std::unordered_map<CacheId, std::any> cache;
171- cuda:: std::array< std::unordered_map<cudaStream_t, StreamAllocation>, MAX_CUDA_DEVICES_PER_SYSTEM > stream_alloc_cache;
193+ std::unordered_map<CacheCommonParamsKey, std::unordered_map<cudaStream_t, StreamAllocation>, CacheCommonParamsKeyHash > stream_alloc_cache;
172194};
173195
174196/* *
0 commit comments