Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cpp/include/rmm/mr/device/detail/free_list.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <rmm/detail/export.hpp>

#include <algorithm>
#include <mutex>
#ifdef RMM_DEBUG_PRINT
#include <iostream>
#endif
Expand Down Expand Up @@ -138,6 +139,12 @@ class free_list {
}
#endif

/**
 * @brief Returns a reference to the mutex used for synchronizing this free list.
 *
 * Callers lock this mutex (e.g. via `std::lock_guard`) before mutating or reading
 * the list concurrently; the free_list itself does not lock internally.
 *
 * @return Reference to the mutex guarding concurrent access to this free list.
 */
[[nodiscard]] std::mutex& get_mutex() { return mtx_; }

protected:
/**
* @brief Insert a block in the free list before the specified position
Expand Down Expand Up @@ -182,6 +189,7 @@ class free_list {

private:
list_type blocks; // The internal container of blocks
std::mutex mtx_; // The mutex for each free list
};

} // namespace mr::detail
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <cstddef>
#include <map>
#include <mutex>
#include <shared_mutex>
#include <unordered_map>
#ifdef RMM_DEBUG_PRINT
#include <iostream>
Expand Down Expand Up @@ -87,9 +88,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
stream_ordered_memory_resource& operator=(stream_ordered_memory_resource&&) = delete;

protected:
using free_list = FreeListType;
using block_type = typename free_list::block_type;
using lock_guard = std::lock_guard<std::mutex>;
using free_list = FreeListType;
using block_type = typename free_list::block_type;
using lock_guard = std::lock_guard<std::mutex>;
using read_lock_guard = std::shared_lock<std::shared_mutex>;
using write_lock_guard = std::unique_lock<std::shared_mutex>;

// Derived classes must implement these four methods

Expand Down Expand Up @@ -204,12 +207,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
*/
void* do_allocate(std::size_t size, cuda_stream_view stream) override
{
RMM_FUNC_RANGE();
RMM_LOG_TRACE("[A][stream %s][%zuB]", rmm::detail::format_stream(stream), size);

if (size <= 0) { return nullptr; }

lock_guard lock(mtx_);

auto stream_event = get_event(stream);

size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT);
Expand All @@ -224,7 +226,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
size,
block.pointer());

log_summary_trace();
// TODO(jigao): this logging is not protected by mutex!
// log_summary_trace();

return block.pointer();
}
Expand All @@ -238,11 +241,11 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
*/
void do_deallocate(void* ptr, std::size_t size, cuda_stream_view stream) override
{
RMM_FUNC_RANGE();
RMM_LOG_TRACE("[D][stream %s][%zuB][%p]", rmm::detail::format_stream(stream), size, ptr);

if (size <= 0 || ptr == nullptr) { return; }

lock_guard lock(mtx_);
auto stream_event = get_event(stream);

size = rmm::align_up(size, rmm::CUDA_ALLOCATION_ALIGNMENT);
Expand All @@ -253,9 +256,21 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
// streams allows stealing from deleted streams.
RMM_ASSERT_CUDA_SUCCESS(cudaEventRecord(stream_event.event, stream.value()));

stream_free_blocks_[stream_event].insert(block);

log_summary_trace();
read_lock_guard rlock(stream_free_blocks_mtx_);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This adds a lot of complexity to do_deallocate that I think should be in the free list implementation. I originally designed this to be portable to other free list implementations, which is why this function was originally so simple -- it more or less just called insert on the stream's free list.

This allocator is already quite fast, and I think in your exploration ultimately you found that it's not the actual bottleneck? Is it worth adding so much complexity? If the complexity can be put in the free list it might be better.

// Try to find a satisfactory block in free list for the same stream (no sync required)
auto iter = stream_free_blocks_.find(stream_event);
if (iter != stream_free_blocks_.end()) {
// Hot path
lock_guard free_list_lock(iter->second.get_mutex());
iter->second.insert(block);
} else {
rlock.unlock();
// Cold path
write_lock_guard wlock(stream_free_blocks_mtx_);
stream_free_blocks_[stream_event].insert(block); // TODO(jigao): is it thread-safe?
}
Comment on lines -258 to +291
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a potential race here, I think.

Suppose that two threads are using the same stream, they both enter line 259 to search for the stream_event in stream_free_blocks_. That key doesn't exist so they go down the cold path on line 266.

Both unlock the mutex, and go ahead and try to grab the write lock. Only one of the threads can win, call this thread-A. Meanwhile thread-B waits to acquire the lock.

Now thread-A inserts its free_block, and continues, unlocking the mutex.

Now thread-B comes in to insert its block, but the key already exists, and so stream_free_blocks_[stream_event].insert(block) does nothing.

So we drop the reference to block from thread-B without returning to the pool, breaking the contract that free_block has:

  /**
   * @brief Finds, frees and returns the block associated with pointer `ptr`.
   *
   * @param ptr The pointer to the memory to free.
   * @param size The size of the memory to free. Must be equal to the original allocation size.
   * @return The (now freed) block associated with `p`. The caller is expected to return the block
                                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   * to the pool.
   */
  block_type free_block(void* ptr, std::size_t size) noexcept

Copy link
Contributor Author

@JigaoLuo JigaoLuo May 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wence- Thanks for reviewing.

, and so stream_free_blocks_[stream_event].insert(block) does nothing.

Could you explain why thread B does nothing in this case? I had the same case in mind and had the same concern. But what will happen in my mind is:

  • (Thread A and thread B must have different block from different pointers.)
  • After thread B acquires the write lock, it must see thread A's memory writes to stream_free_blocks_. The map has a new entry inserted by thread A.
  • This implies that stream_free_blocks_[stream_event] in thread B should return a valid iterator pointing to the created free_list.
  • Finally, thread B inserts its block into this list.

The conclusion I have is: Once a write lock is held on the map, operating on the map and its contained free_lists is thread-safe. Is there a flaw in this?


If the following code is preferable to line 270, I'd be happy to replace:

    auto iter = stream_free_blocks_.find(stream_event);
    free_list& blocks =
      (iter != stream_free_blocks_.end()) ? iter->second : stream_free_blocks_[stream_event];
    lock_guard free_list_lock(blocks.get_mutex());
    blocks.insert(block);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry, I thought this insert was calling the insert method on the std::map that is stream_free_blocks_. However, you're right, it's calling insert on the free_list that you get back from the map lookup. So I think this implementation is thread safe.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wence- — Thanks for the confirmation! But I will still replace line 270 as discussed for better code style. The write locks on free_list are acquired in a consistent order, so there should be no deadlock.

I'll replace line 270 with this code section and add clarifying comments. I've already run unit tests with 64 threads on my end, and the results are stable 🚀

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wence- I've refactored do_deallocate for better readability using early returns while keeping the same logic as we discussed. I also added more comments explaining the hot/cold path ideas. My commit passes local unit tests with 64 threads.

// TODO(jigao): this logging is not protected by mutex!
// log_summary_trace();
}

private:
Expand All @@ -271,7 +286,9 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
*/
stream_event_pair get_event(cuda_stream_view stream)
{
RMM_FUNC_RANGE();
if (stream.is_per_thread_default()) {
// Hot path
// Create a thread-local event for each device. These events are
// deliberately leaked since the destructor needs to call into
// the CUDA runtime and thread_local destructors (can) run below
Expand All @@ -289,6 +306,8 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
}();
return stream_event_pair{stream.value(), event};
}
write_lock_guard wlock(stream_events_mtx_);
// Cold path
// We use cudaStreamLegacy as the event map key for the default stream for consistency between
// PTDS and non-PTDS mode. In PTDS mode, the cudaStreamLegacy map key will only exist if the
// user explicitly passes it, so it is used as the default location for the free list
Expand Down Expand Up @@ -319,6 +338,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
*/
block_type allocate_and_insert_remainder(block_type block, std::size_t size, free_list& blocks)
{
RMM_FUNC_RANGE();
auto const [allocated, remainder] = this->underlying().allocate_from_block(block, size);
if (remainder.is_valid()) { blocks.insert(remainder); }
return allocated;
Expand All @@ -333,15 +353,30 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
*/
block_type get_block(std::size_t size, stream_event_pair stream_event)
{
// Try to find a satisfactory block in free list for the same stream (no sync required)
auto iter = stream_free_blocks_.find(stream_event);
if (iter != stream_free_blocks_.end()) {
block_type const block = iter->second.get_block(size);
if (block.is_valid()) { return allocate_and_insert_remainder(block, size, iter->second); }
RMM_FUNC_RANGE();
{
// The hot path of get_block:
// 1. Read-lock the map for lookup
// 2. then exclusively lock the free_list to get a block locally.
read_lock_guard rlock(stream_free_blocks_mtx_);
// Try to find a satisfactory block in free list for the same stream (no sync required)
auto iter = stream_free_blocks_.find(stream_event);
if (iter != stream_free_blocks_.end()) {
lock_guard free_list_lock(iter->second.get_mutex());
block_type const block = iter->second.get_block(size);
if (block.is_valid()) { return allocate_and_insert_remainder(block, size, iter->second); }
}
}

// The cold path of get_block:
// Write lock the map to safely perform another lookup and possibly modify entries.
// This exclusive lock ensures no other threads can access the map and all free lists in the
// map.
write_lock_guard wlock(stream_free_blocks_mtx_);
auto iter = stream_free_blocks_.find(stream_event);
free_list& blocks =
(iter != stream_free_blocks_.end()) ? iter->second : stream_free_blocks_[stream_event];
lock_guard free_list_lock(blocks.get_mutex());

// Try to find an existing block in another stream
{
Expand Down Expand Up @@ -382,6 +417,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
free_list& blocks,
bool merge_first)
{
RMM_FUNC_RANGE();
auto find_block = [&](auto iter) {
auto other_event = iter->first.event;
auto& other_blocks = iter->second;
Expand Down Expand Up @@ -415,6 +451,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
++next_iter; // Points to element after `iter` to allow erasing `iter` in the loop body

if (iter->first.event != stream_event.event) {
lock_guard free_list_lock(iter->second.get_mutex());
block_type const block = find_block(iter);

if (block.is_valid()) {
Expand All @@ -435,6 +472,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
cudaEvent_t other_event,
free_list&& other_blocks)
{
RMM_FUNC_RANGE();
// Since we found a block associated with a different stream, we have to insert a wait
// on the stream's associated event into the allocating stream.
RMM_CUDA_TRY(cudaStreamWaitEvent(stream_event.stream, other_event, 0));
Expand All @@ -450,7 +488,10 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
*/
void release()
{
lock_guard lock(mtx_);
RMM_FUNC_RANGE();
// lock_guard lock(mtx_); TODO(jigao): rethink mtx_
write_lock_guard stream_event_lock(stream_events_mtx_);
write_lock_guard wlock(stream_free_blocks_mtx_);

for (auto s_e : stream_events_) {
RMM_ASSERT_CUDA_SUCCESS(cudaEventSynchronize(s_e.second.event));
Expand All @@ -464,6 +505,7 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
void log_summary_trace()
{
#if (RMM_LOG_ACTIVE_LEVEL <= RMM_LOG_LEVEL_TRACE)
RMM_FUNC_RANGE();
std::size_t num_blocks{0};
std::size_t max_block{0};
std::size_t free_mem{0};
Expand Down Expand Up @@ -491,8 +533,17 @@ class stream_ordered_memory_resource : public crtp<PoolResource>, public device_
// bidirectional mapping between non-default streams and events
std::unordered_map<cudaStream_t, stream_event_pair> stream_events_;

// TODO(jigao): think about get_mutex function?
std::mutex mtx_; // mutex for thread-safe access

// mutex for thread-safe access to stream_free_blocks_
// Used in the writing part of get_block, get_block_from_other_stream
std::shared_mutex stream_free_blocks_mtx_;

// mutex for thread-safe access to stream_events_
// Used in the NON-PTDS part of get_event
std::shared_mutex stream_events_mtx_;

rmm::cuda_device_id device_id_{rmm::get_current_cuda_device()};
}; // class stream_ordered_memory_resource

Expand Down
1 change: 1 addition & 0 deletions cpp/include/rmm/mr/device/pool_memory_resource.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ class pool_memory_resource final
*/
block_type free_block(void* ptr, std::size_t size) noexcept
{
RMM_FUNC_RANGE();
#ifdef RMM_POOL_TRACK_ALLOCATIONS
if (ptr == nullptr) return block_type{};
auto const iter = allocated_blocks_.find(static_cast<char*>(ptr));
Expand Down
2 changes: 1 addition & 1 deletion cpp/tests/mr/device/mr_ref_multithreaded_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void spawn_n(std::size_t num_threads, Task task, Arguments&&... args)
/**
 * @brief Spawns a fixed number of threads each running @p task.
 *
 * Convenience wrapper around spawn_n with a thread count of 16 so the
 * multithreaded resource tests exercise higher contention.
 *
 * @param task The callable each thread executes.
 * @param args Arguments forwarded to the task.
 */
template <typename Task, typename... Arguments>
void spawn(Task task, Arguments&&... args)
{
  spawn_n(16, task, std::forward<Arguments>(args)...);
}

TEST(DefaultTest, UseCurrentDeviceResource_mt) { spawn(test_get_current_device_resource); }
Expand Down