2929#include < cstddef>
3030#include < map>
3131#include < mutex>
32+ #include < shared_mutex>
3233#include < unordered_map>
3334#ifdef RMM_DEBUG_PRINT
3435#include < iostream>
@@ -87,9 +88,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
8788 stream_ordered_memory_resource& operator =(stream_ordered_memory_resource&&) = delete ;
8889
8990 protected:
90- using free_list = FreeListType;
91- using block_type = typename free_list::block_type;
92- using lock_guard = std::lock_guard<std::mutex>;
91+ using free_list = FreeListType;
92+ using block_type = typename free_list::block_type;
93+ using lock_guard = std::lock_guard<std::mutex>;
94+ using read_lock_guard = std::shared_lock<std::shared_mutex>;
95+ using write_lock_guard = std::unique_lock<std::shared_mutex>;
9396
9497 // Derived classes must implement these four methods
9598
@@ -204,12 +207,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
204207 */
205208 void * do_allocate (std::size_t size, cuda_stream_view stream) override
206209 {
210+ RMM_FUNC_RANGE ();
207211 RMM_LOG_TRACE (" [A][stream %s][%zuB]" , rmm::detail::format_stream (stream), size);
208212
209213 if (size <= 0 ) { return nullptr ; }
210214
211- lock_guard lock (mtx_);
212-
213215 auto stream_event = get_event (stream);
214216
215217 size = rmm::align_up (size, rmm::CUDA_ALLOCATION_ALIGNMENT);
@@ -224,7 +226,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
224226 size,
225227 block.pointer ());
226228
227- log_summary_trace ();
229+ // TODO(jigao): this logging is not protected by mutex!
230+ // log_summary_trace();
228231
229232 return block.pointer ();
230233 }
@@ -238,11 +241,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
238241 */
239242 void do_deallocate (void * ptr, std::size_t size, cuda_stream_view stream) override
240243 {
244+ RMM_FUNC_RANGE ();
241245 RMM_LOG_TRACE (" [D][stream %s][%zuB][%p]" , rmm::detail::format_stream (stream), size, ptr);
242246
243247 if (size <= 0 || ptr == nullptr ) { return ; }
244248
245- lock_guard lock (mtx_);
246249 auto stream_event = get_event (stream);
247250
248251 size = rmm::align_up (size, rmm::CUDA_ALLOCATION_ALIGNMENT);
@@ -253,9 +256,21 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
253256 // streams allows stealing from deleted streams.
254257 RMM_ASSERT_CUDA_SUCCESS (cudaEventRecord (stream_event.event , stream.value ()));
255258
256- stream_free_blocks_[stream_event].insert (block);
257-
258- log_summary_trace ();
259+ read_lock_guard rlock (stream_free_blocks_mtx_);
260+ // Try to find a satisfactory block in free list for the same stream (no sync required)
261+ auto iter = stream_free_blocks_.find (stream_event);
262+ if (iter != stream_free_blocks_.end ()) {
263+ // Hot path
264+ lock_guard free_list_lock (iter->second .get_mutex ());
265+ iter->second .insert (block);
266+ } else {
267+ rlock.unlock ();
268+ // Cold path
269+ write_lock_guard wlock (stream_free_blocks_mtx_);
270+ stream_free_blocks_[stream_event].insert (block); // TODO(jigao): is it thread-safe?
271+ }
272+ // TODO(jigao): this logging is not protected by mutex!
273+ // log_summary_trace();
259274 }
260275
261276 private:
@@ -271,7 +286,9 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
271286 */
272287 stream_event_pair get_event (cuda_stream_view stream)
273288 {
289+ RMM_FUNC_RANGE ();
274290 if (stream.is_per_thread_default ()) {
291+ // Hot path
275292 // Create a thread-local event for each device. These events are
276293 // deliberately leaked since the destructor needs to call into
277294 // the CUDA runtime and thread_local destructors (can) run below
@@ -289,6 +306,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
289306 }();
290307 return stream_event_pair{stream.value (), event};
291308 }
309+ write_lock_guard wlock (stream_events_mtx_);
310+ // Cold path
292311 // We use cudaStreamLegacy as the event map key for the default stream for consistency between
293312 // PTDS and non-PTDS mode. In PTDS mode, the cudaStreamLegacy map key will only exist if the
294313 // user explicitly passes it, so it is used as the default location for the free list
@@ -319,6 +338,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
319338 */
320339 block_type allocate_and_insert_remainder (block_type block, std::size_t size, free_list& blocks)
321340 {
341+ RMM_FUNC_RANGE ();
322342 auto const [allocated, remainder] = this ->underlying ().allocate_from_block (block, size);
323343 if (remainder.is_valid ()) { blocks.insert (remainder); }
324344 return allocated;
@@ -333,15 +353,30 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
333353 */
334354 block_type get_block (std::size_t size, stream_event_pair stream_event)
335355 {
336- // Try to find a satisfactory block in free list for the same stream (no sync required)
337- auto iter = stream_free_blocks_.find (stream_event);
338- if (iter != stream_free_blocks_.end ()) {
339- block_type const block = iter->second .get_block (size);
340- if (block.is_valid ()) { return allocate_and_insert_remainder (block, size, iter->second ); }
356+ RMM_FUNC_RANGE ();
357+ {
358+ // The hot path of get_block:
359+ // 1. Read-lock the map for lookup
360+ // 2. then exclusively lock the free_list to get a block locally.
361+ read_lock_guard rlock (stream_free_blocks_mtx_);
362+ // Try to find a satisfactory block in free list for the same stream (no sync required)
363+ auto iter = stream_free_blocks_.find (stream_event);
364+ if (iter != stream_free_blocks_.end ()) {
365+ lock_guard free_list_lock (iter->second .get_mutex ());
366+ block_type const block = iter->second .get_block (size);
367+ if (block.is_valid ()) { return allocate_and_insert_remainder (block, size, iter->second ); }
368+ }
341369 }
342370
371+ // The cold path of get_block:
372+ // Write lock the map to safely perform another lookup and possibly modify entries.
373+ // This exclusive lock ensures no other threads can access the map and all free lists in the
374+ // map.
375+ write_lock_guard wlock (stream_free_blocks_mtx_);
376+ auto iter = stream_free_blocks_.find (stream_event);
343377 free_list& blocks =
344378 (iter != stream_free_blocks_.end ()) ? iter->second : stream_free_blocks_[stream_event];
379+ lock_guard free_list_lock (blocks.get_mutex ());
345380
346381 // Try to find an existing block in another stream
347382 {
@@ -382,6 +417,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
382417 free_list& blocks,
383418 bool merge_first)
384419 {
420+ RMM_FUNC_RANGE ();
385421 auto find_block = [&](auto iter) {
386422 auto other_event = iter->first .event ;
387423 auto & other_blocks = iter->second ;
@@ -415,6 +451,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
415451 ++next_iter; // Points to element after `iter` to allow erasing `iter` in the loop body
416452
417453 if (iter->first .event != stream_event.event ) {
454+ lock_guard free_list_lock (iter->second .get_mutex ());
418455 block_type const block = find_block (iter);
419456
420457 if (block.is_valid ()) {
@@ -435,6 +472,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
435472 cudaEvent_t other_event,
436473 free_list&& other_blocks)
437474 {
475+ RMM_FUNC_RANGE ();
438476 // Since we found a block associated with a different stream, we have to insert a wait
439477 // on the stream's associated event into the allocating stream.
440478 RMM_CUDA_TRY (cudaStreamWaitEvent (stream_event.stream , other_event, 0 ));
@@ -450,7 +488,10 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
450488 */
451489 void release ()
452490 {
453- lock_guard lock (mtx_);
491+ RMM_FUNC_RANGE ();
492+ // lock_guard lock(mtx_); TOOD(jigao): rethink mtx_
493+ write_lock_guard stream_event_lock (stream_events_mtx_);
494+ write_lock_guard wlock (stream_free_blocks_mtx_);
454495
455496 for (auto s_e : stream_events_) {
456497 RMM_ASSERT_CUDA_SUCCESS (cudaEventSynchronize (s_e.second .event ));
@@ -464,6 +505,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
464505 void log_summary_trace ()
465506 {
466507#if (RMM_LOG_ACTIVE_LEVEL <= RMM_LOG_LEVEL_TRACE)
508+ RMM_FUNC_RANGE ();
467509 std::size_t num_blocks{0 };
468510 std::size_t max_block{0 };
469511 std::size_t free_mem{0 };
@@ -491,8 +533,17 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
491533 // bidirectional mapping between non-default streams and events
492534 std::unordered_map<cudaStream_t, stream_event_pair> stream_events_;
493535
536+ // TODO(jigao): think about get_mutex function?
494537 std::mutex mtx_; // mutex for thread-safe access
495538
539+ // mutex for thread-safe access to stream_free_blocks_
540+ // Used in the writing part of get_block, get_block_from_other_stream
541+ std::shared_mutex stream_free_blocks_mtx_;
542+
543+ // mutex for thread-safe access to stream_events_
544+ // Used in the NON-PTDS part of get_event
545+ std::shared_mutex stream_events_mtx_;
546+
496547 rmm::cuda_device_id device_id_{rmm::get_current_cuda_device ()};
497548}; // namespace detail
498549
0 commit comments