-
Notifications
You must be signed in to change notification settings - Fork 247
[WIP] [Performance Improvement] Fine-granularity locking in stream_ordered_memory_resource #1912
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
e0633c4
25ddff2
4af208a
f010620
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,6 +29,7 @@ | |
| #include <cstddef> | ||
| #include <map> | ||
| #include <mutex> | ||
| #include <shared_mutex> | ||
| #include <unordered_map> | ||
| #ifdef RMM_DEBUG_PRINT | ||
| #include <iostream> | ||
|
|
@@ -87,9 +88,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| stream_ordered_memory_resource& operator=(stream_ordered_memory_resource&&) = delete; | ||
|
|
||
| protected: | ||
| using free_list = FreeListType; | ||
| using block_type = typename free_list::block_type; | ||
| using lock_guard = std::lock_guard<std::mutex>; | ||
| using free_list = FreeListType; | ||
| using block_type = typename free_list::block_type; | ||
| using lock_guard = std::lock_guard<std::mutex>; | ||
| using read_lock_guard = std::shared_lock<std::shared_mutex>; | ||
| using write_lock_guard = std::unique_lock<std::shared_mutex>; | ||
|
|
||
| // Derived classes must implement these four methods | ||
|
|
||
|
|
@@ -204,12 +207,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| */ | ||
| void* do_allocate(std::size_t size, cuda_stream_view stream) override | ||
| { | ||
| RMM_FUNC_RANGE(); | ||
| RMM_LOG_TRACE("[A][stream %s][%zuB]", rmm::detail::format_stream(stream), size); | ||
|
|
||
| if (size <= 0) { return nullptr; } | ||
|
|
||
| lock_guard lock(mtx_); | ||
|
|
||
| auto stream_event = get_event(stream); | ||
|
|
||
| size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT); | ||
|
|
@@ -224,7 +226,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| size, | ||
| block.pointer()); | ||
|
|
||
| log_summary_trace(); | ||
| // TODO(jigao): this logging is not protected by mutex! | ||
| // log_summary_trace(); | ||
|
|
||
| return block.pointer(); | ||
| } | ||
|
|
@@ -238,11 +241,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| */ | ||
| void do_deallocate(void* ptr, std::size_t size, cuda_stream_view stream) override | ||
| { | ||
| RMM_FUNC_RANGE(); | ||
| RMM_LOG_TRACE("[D][stream %s][%zuB][%p]", rmm::detail::format_stream(stream), size, ptr); | ||
|
|
||
| if (size <= 0 || ptr == nullptr) { return; } | ||
|
|
||
| lock_guard lock(mtx_); | ||
| auto stream_event = get_event(stream); | ||
|
|
||
| size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT); | ||
|
|
@@ -253,9 +256,21 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| // streams allows stealing from deleted streams. | ||
| RMM_ASSERT_CUDA_SUCCESS(cudaEventRecord(stream_event.event, stream.value())); | ||
|
|
||
| stream_free_blocks_[stream_event].insert(block); | ||
|
|
||
| log_summary_trace(); | ||
| read_lock_guard rlock(stream_free_blocks_mtx_); | ||
| // Try to find a satisfactory block in free list for the same stream (no sync required) | ||
| auto iter = stream_free_blocks_.find(stream_event); | ||
| if (iter != stream_free_blocks_.end()) { | ||
| // Hot path | ||
| lock_guard free_list_lock(iter->second.get_mutex()); | ||
| iter->second.insert(block); | ||
| } else { | ||
| rlock.unlock(); | ||
| // Cold path | ||
| write_lock_guard wlock(stream_free_blocks_mtx_); | ||
| stream_free_blocks_[stream_event].insert(block); // TODO(jigao): is it thread-safe? | ||
| } | ||
|
Comment on lines
-258
to
+291
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a potential race here, I think. Suppose that two threads are using the same stream, they both enter line 259 to search for the Both unlock the mutex, and go ahead and try to grab the write lock. Only one of the threads can win, call this thread-A. Meanwhile thread-B waits to acquire the lock. Now thread-A inserts its Now thread-B comes in to insert its block, but the key already exists, and so So we drop the reference to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @wence- Thanks for reviewing.
Could you explain why thread B does nothing in this case? I had the same case in mind and had the same concern. But what will happen in my mind is:
The conclusion I have is: Once a write lock is held on the map, operating on the map and its contained free_lists is thread-safe. Is there a flaw in this? If the following code is preferable to line 270, I'd be happy to replace:
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah sorry, I thought this
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @wence- Thanks for the confirmation! But I will still replace line 270 as discussed for better code style. The write locks on I'll replace line 270 with this code section and add clarifying comments. I've already run unit tests with 64 threads on my end, and the results are stable 🚀
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @wence- I've refactored |
||
| // TODO(jigao): this logging is not protected by mutex! | ||
| // log_summary_trace(); | ||
| } | ||
|
|
||
| private: | ||
|
|
@@ -271,7 +286,9 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| */ | ||
| stream_event_pair get_event(cuda_stream_view stream) | ||
| { | ||
| RMM_FUNC_RANGE(); | ||
| if (stream.is_per_thread_default()) { | ||
| // Hot path | ||
| // Create a thread-local event for each device. These events are | ||
| // deliberately leaked since the destructor needs to call into | ||
| // the CUDA runtime and thread_local destructors (can) run below | ||
|
|
@@ -289,6 +306,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| }(); | ||
| return stream_event_pair{stream.value(), event}; | ||
| } | ||
| write_lock_guard wlock(stream_events_mtx_); | ||
| // Cold path | ||
| // We use cudaStreamLegacy as the event map key for the default stream for consistency between | ||
| // PTDS and non-PTDS mode. In PTDS mode, the cudaStreamLegacy map key will only exist if the | ||
| // user explicitly passes it, so it is used as the default location for the free list | ||
|
|
@@ -319,6 +338,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| */ | ||
| block_type allocate_and_insert_remainder(block_type block, std::size_t size, free_list& blocks) | ||
| { | ||
| RMM_FUNC_RANGE(); | ||
| auto const [allocated, remainder] = this->underlying().allocate_from_block(block, size); | ||
| if (remainder.is_valid()) { blocks.insert(remainder); } | ||
| return allocated; | ||
|
|
@@ -333,15 +353,30 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| */ | ||
| block_type get_block(std::size_t size, stream_event_pair stream_event) | ||
| { | ||
| // Try to find a satisfactory block in free list for the same stream (no sync required) | ||
| auto iter = stream_free_blocks_.find(stream_event); | ||
| if (iter != stream_free_blocks_.end()) { | ||
| block_type const block = iter->second.get_block(size); | ||
| if (block.is_valid()) { return allocate_and_insert_remainder(block, size, iter->second); } | ||
| RMM_FUNC_RANGE(); | ||
| { | ||
| // The hot path of get_block: | ||
| // 1. Read-lock the map for lookup | ||
| // 2. then exclusively lock the free_list to get a block locally. | ||
| read_lock_guard rlock(stream_free_blocks_mtx_); | ||
| // Try to find a satisfactory block in free list for the same stream (no sync required) | ||
| auto iter = stream_free_blocks_.find(stream_event); | ||
| if (iter != stream_free_blocks_.end()) { | ||
| lock_guard free_list_lock(iter->second.get_mutex()); | ||
| block_type const block = iter->second.get_block(size); | ||
| if (block.is_valid()) { return allocate_and_insert_remainder(block, size, iter->second); } | ||
| } | ||
| } | ||
|
|
||
| // The cold path of get_block: | ||
| // Write lock the map to safely perform another lookup and possibly modify entries. | ||
| // This exclusive lock ensures no other threads can access the map and all free lists in the | ||
| // map. | ||
| write_lock_guard wlock(stream_free_blocks_mtx_); | ||
| auto iter = stream_free_blocks_.find(stream_event); | ||
| free_list& blocks = | ||
| (iter != stream_free_blocks_.end()) ? iter->second : stream_free_blocks_[stream_event]; | ||
| lock_guard free_list_lock(blocks.get_mutex()); | ||
|
|
||
| // Try to find an existing block in another stream | ||
| { | ||
|
|
@@ -382,6 +417,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| free_list& blocks, | ||
| bool merge_first) | ||
| { | ||
| RMM_FUNC_RANGE(); | ||
| auto find_block = [&](auto iter) { | ||
| auto other_event = iter->first.event; | ||
| auto& other_blocks = iter->second; | ||
|
|
@@ -415,6 +451,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| ++next_iter; // Points to element after `iter` to allow erasing `iter` in the loop body | ||
|
|
||
| if (iter->first.event != stream_event.event) { | ||
| lock_guard free_list_lock(iter->second.get_mutex()); | ||
| block_type const block = find_block(iter); | ||
|
|
||
| if (block.is_valid()) { | ||
|
|
@@ -435,6 +472,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| cudaEvent_t other_event, | ||
| free_list&& other_blocks) | ||
| { | ||
| RMM_FUNC_RANGE(); | ||
| // Since we found a block associated with a different stream, we have to insert a wait | ||
| // on the stream's associated event into the allocating stream. | ||
| RMM_CUDA_TRY(cudaStreamWaitEvent(stream_event.stream, other_event, 0)); | ||
|
|
@@ -450,7 +488,10 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| */ | ||
| void release() | ||
| { | ||
| lock_guard lock(mtx_); | ||
| RMM_FUNC_RANGE(); | ||
| // lock_guard lock(mtx_); TODO(jigao): rethink mtx_ | ||
| write_lock_guard stream_event_lock(stream_events_mtx_); | ||
| write_lock_guard wlock(stream_free_blocks_mtx_); | ||
|
|
||
| for (auto s_e : stream_events_) { | ||
| RMM_ASSERT_CUDA_SUCCESS(cudaEventSynchronize(s_e.second.event)); | ||
|
|
@@ -464,6 +505,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| void log_summary_trace() | ||
| { | ||
| #if (RMM_LOG_ACTIVE_LEVEL <= RMM_LOG_LEVEL_TRACE) | ||
| RMM_FUNC_RANGE(); | ||
| std::size_t num_blocks{0}; | ||
| std::size_t max_block{0}; | ||
| std::size_t free_mem{0}; | ||
|
|
@@ -491,8 +533,17 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_ | |
| // bidirectional mapping between non-default streams and events | ||
| std::unordered_map<cudaStream_t, stream_event_pair> stream_events_; | ||
|
|
||
| // TODO(jigao): think about get_mutex function? | ||
| std::mutex mtx_; // mutex for thread-safe access | ||
|
|
||
| // mutex for thread-safe access to stream_free_blocks_ | ||
| // Used in the writing part of get_block, get_block_from_other_stream | ||
| std::shared_mutex stream_free_blocks_mtx_; | ||
|
|
||
| // mutex for thread-safe access to stream_events_ | ||
| // Used in the NON-PTDS part of get_event | ||
| std::shared_mutex stream_events_mtx_; | ||
|
|
||
| rmm::cuda_device_id device_id_{rmm::get_current_cuda_device()}; | ||
| }; // namespace detail | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This adds a lot of complexity to `do_deallocate` that I think should be in the free list implementation. I originally designed this to be portable to other free list implementations, which is why this function was originally so simple -- it more or less just called `insert` on the stream's free list.
This allocator is already quite fast, and I think in your exploration ultimately you found that it's not the actual bottleneck? Is it worth adding so much complexity? If the complexity can be put in the free list it might be better.