2727#endif // _CCCL_CUDA_COMPILER(CLANG)
2828
2929#include < cuda/__memory_resource/get_property.h>
30- #include < cuda/__memory_resource/memory_resource_base .h>
30+ #include < cuda/__memory_resource/memory_pool_base .h>
3131#include < cuda/__memory_resource/properties.h>
3232#include < cuda/__runtime/api_wrapper.h>
3333#include < cuda/std/__concepts/concept_macros.h>
3434
3535#include < cuda/std/__cccl/prologue.h>
3636
3737// ! @file
38- // ! The \c device_memory_pool class provides an asynchronous memory resource that allocates device memory in stream
39- // ! order.
38+ // ! The \c device_memory_pool class provides an asynchronous memory resource
39+ // ! that allocates device memory in stream order.
4040_CCCL_BEGIN_NAMESPACE_CUDA
4141
4242// ! @rst
@@ -45,30 +45,34 @@ _CCCL_BEGIN_NAMESPACE_CUDA
4545// ! Stream ordered memory pool
4646// ! ------------------------------
4747// !
48- // ! ``device_memory_pool_ref`` allocates device memory using `cudaMallocFromPoolAsync / cudaFreeAsync
49- // ! <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`__ for allocation/deallocation. A
50- // ! ``device_memory_pool_ref`` is a thin wrapper around a \c cudaMemPool_t with the location type set to \c
51- // ! cudaMemLocationTypeDevice.
48+ // ! ``device_memory_pool_ref`` allocates device memory using
49+ // ! `cudaMallocFromPoolAsync / cudaFreeAsync
50+ // ! <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`__
51+ // ! for allocation/deallocation. A
52+ // ! ``device_memory_pool_ref`` is a thin wrapper around a \c cudaMemPool_t with
53+ // ! the location type set to \c cudaMemLocationTypeDevice.
5254// !
5355// ! .. warning::
5456// !
55- // ! ``device_memory_pool_ref`` does not own the pool and it is the responsibility of the user to ensure that the
56- // ! lifetime of the pool exceeds the lifetime of the ``device_memory_pool_ref``.
57+ // ! ``device_memory_pool_ref`` does not own the pool and it is the
58+ // ! responsibility of the user to ensure that the lifetime of the pool
59+ // ! exceeds the lifetime of the ``device_memory_pool_ref``.
5760// !
5861// ! @endrst
59- class device_memory_pool_ref : public __memory_resource_base
62+ class device_memory_pool_ref : public __memory_pool_base
6063{
6164public:
6265 // ! @brief Constructs the device_memory_pool_ref from a \c cudaMemPool_t.
6366 // ! @param __pool The \c cudaMemPool_t used to allocate memory.
6467 _CCCL_HOST_API explicit device_memory_pool_ref (::cudaMemPool_t __pool) noexcept
65- : __memory_resource_base (__pool)
68+ : __memory_pool_base (__pool)
6669 {}
6770
6871 device_memory_pool_ref (int ) = delete ;
6972 device_memory_pool_ref (::cuda::std::nullptr_t ) = delete ;
7073
71- // ! @brief Enables the \c device_accessible property for \c device_memory_pool_ref.
74+ // ! @brief Enables the \c device_accessible property for \c
75+ // ! device_memory_pool_ref.
7276 // ! @relates device_memory_pool_ref
7377 _CCCL_HOST_API friend constexpr void
7478 get_property (device_memory_pool_ref const &, ::cuda::mr::device_accessible) noexcept
@@ -82,12 +86,9 @@ class device_memory_pool_ref : public __memory_resource_base
8286// ! @returns The default memory pool of the specified device.
8387[[nodiscard]] inline device_memory_pool_ref device_default_memory_pool (::cuda::device_ref __device)
8488{
85- ::cuda::__verify_device_supports_stream_ordered_allocations (__device.get());
86-
87- ::cudaMemPool_t __pool;
88- _CCCL_TRY_CUDA_API (
89- ::cudaDeviceGetDefaultMemPool, " Failed to call cudaDeviceGetDefaultMemPool" , &__pool, __device.get ());
90- return device_memory_pool_ref{__pool};
89+ static ::cudaMemPool_t __pool = ::cuda::__get_default_memory_pool (
90+ ::CUmemLocation{::CU_MEM_LOCATION_TYPE_DEVICE, __device.get ()}, ::CU_MEM_ALLOCATION_TYPE_PINNED);
91+ return device_memory_pool_ref (__pool);
9192}
9293
9394// ! @rst
@@ -96,22 +97,28 @@ class device_memory_pool_ref : public __memory_resource_base
9697// ! Stream ordered memory resource
9798// ! ------------------------------
9899// !
99- // ! ``device_memory_pool`` allocates device memory using `cudaMallocFromPoolAsync / cudaFreeAsync
100- // ! <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`__ for allocation/deallocation. A
101- // ! When constructed it creates an underlying \c cudaMemPool_t with the location type set to \c
102- // ! cudaMemLocationTypeDevice and owns it.
100+ // ! ``device_memory_pool`` allocates device memory using
101+ // ! `cudaMallocFromPoolAsync / cudaFreeAsync
102+ // ! <https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY__POOLS.html>`__
103+ // ! for allocation/deallocation. When constructed it creates an underlying \c
104+ // ! cudaMemPool_t with the location type set to \c cudaMemLocationTypeDevice and
105+ // ! owns it.
103106// !
104107// ! @endrst
105108struct device_memory_pool : device_memory_pool_ref
106109{
107110 using reference_type = device_memory_pool_ref;
108111
109- // ! @brief Constructs a \c device_memory_pool with the optionally specified initial pool size and release
110- // ! threshold. If the pool size grows beyond the release threshold, unused memory held by the pool will be released at
111- // ! the next synchronization event.
112- // ! @throws cuda_error if the CUDA version does not support ``cudaMallocAsync``.
113- // ! @param __device_id The device id of the device the stream pool is constructed on.
114- // ! @param __pool_properties Optional, additional properties of the pool to be created.
112+ // ! @brief Constructs a \c device_memory_pool with the optionally specified
113+ // ! initial pool size and release threshold. If the pool size grows beyond the
114+ // ! release threshold, unused memory held by the pool will be released at the
115+ // ! next synchronization event.
116+ // ! @throws cuda_error if the CUDA version does not support
117+ // ! ``cudaMallocAsync``.
118+ // ! @param __device_id The device id of the device the stream pool is
119+ // ! constructed on.
120+ // ! @param __properties Optional, additional properties of the pool to be
121+ // ! created.
115122 _CCCL_HOST_API device_memory_pool (::cuda::device_ref __device_id, memory_pool_properties __properties = {})
116123 : device_memory_pool_ref(__create_cuda_mempool(
117124 __properties,
0 commit comments