Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 127 additions & 105 deletions cub/cub/thread/thread_store.cuh
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
// SPDX-FileCopyrightText: Copyright (c) 2011, Duane Merrill. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2011-2026, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: BSD-3

/**
* @file
* Thread utilities for writing memory using PTX cache modifiers.
*/
//! @file
//! Thread utilities for writing memory using PTX cache modifiers.

#pragma once

Expand All @@ -22,18 +20,17 @@
#include <cub/util_ptx.cuh>
#include <cub/util_type.cuh>

#include <cuda/std/__iterator/concepts.h>
#include <cuda/std/__memory/pointer_traits.h>
#include <cuda/std/__type_traits/integral_constant.h>
#include <cuda/std/__type_traits/is_pointer.h>

CUB_NAMESPACE_BEGIN

//-----------------------------------------------------------------------------
// Tags and constants
//-----------------------------------------------------------------------------

/**
* @brief Enumeration of cache modifiers for memory store operations.
*/
//! @brief Enumeration of cache modifiers for memory store operations.
enum CacheStoreModifier
{
STORE_DEFAULT, ///< Default (no modifier)
Expand All @@ -44,50 +41,46 @@ enum CacheStoreModifier
STORE_VOLATILE, ///< Volatile shared (any memory space)
};

/**
* @name Thread I/O (cache modified)
* @{
*/

/**
* @brief Thread utility for writing memory using cub::CacheStoreModifier cache modifiers.
* Can be used to store any data type.
*
* @par Example
* @code
* #include <cub/cub.cuh> // or equivalently <cub/thread/thread_store.cuh>
*
* // 32-bit store using cache-global modifier:
* int *d_out;
* int val;
* cub::ThreadStore<cub::STORE_CG>(d_out + threadIdx.x, val);
*
* // 16-bit store using default modifier
* short *d_out;
* short val;
* cub::ThreadStore<cub::STORE_DEFAULT>(d_out + threadIdx.x, val);
*
* // 128-bit store using write-through modifier
* float4 *d_out;
* float4 val;
* cub::ThreadStore<cub::STORE_WT>(d_out + threadIdx.x, val);
*
* // 96-bit store using cache-streaming cache modifier
* struct TestFoo { bool a; short b; };
* TestFoo *d_struct;
* TestFoo val;
* cub::ThreadStore<cub::STORE_CS>(d_out + threadIdx.x, val);
* @endcode
*
* @tparam MODIFIER
* <b>[inferred]</b> CacheStoreModifier enumeration
*
* @tparam InputIteratorT
* <b>[inferred]</b> Output iterator type \iterator
*
* @tparam T
* <b>[inferred]</b> Data type of output value
*/
//! @name Thread I/O (cache modified)
//! @{

//! @brief Thread utility for writing memory using cub::CacheStoreModifier cache modifiers.
//! Can be used to store any data type.
//!
//! @par Example
//! @code
//! #include <cub/cub.cuh> // or equivalently <cub/thread/thread_store.cuh>
//!
//! // 32-bit store using cache-global modifier:
//! int *d_out;
//! int val;
//! cub::ThreadStore<cub::STORE_CG>(d_out + threadIdx.x, val);
//!
//! // 16-bit store using default modifier
//! short *d_out;
//! short val;
//! cub::ThreadStore<cub::STORE_DEFAULT>(d_out + threadIdx.x, val);
//!
//! // 128-bit store using write-through modifier
//! float4 *d_out;
//! float4 val;
//! cub::ThreadStore<cub::STORE_WT>(d_out + threadIdx.x, val);
//!
//! // 96-bit store using cache-streaming cache modifier
//! struct TestFoo { bool a; short b; };
//! TestFoo *d_struct;
//! TestFoo val;
//! cub::ThreadStore<cub::STORE_CS>(d_out + threadIdx.x, val);
//! @endcode
//!
//! @tparam MODIFIER
//! <b>[inferred]</b> CacheStoreModifier enumeration
//!
//! @tparam OutputIteratorT
//! <b>[inferred]</b> Output iterator type \iterator
//!
//! @tparam T
//! <b>[inferred]</b> Data type of output value
template <CacheStoreModifier MODIFIER, typename OutputIteratorT, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore(OutputIteratorT itr, T val);

Expand All @@ -97,7 +90,7 @@ _CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore(OutputIteratorT itr, T val);

namespace detail
{
/// Helper structure for templated store iteration (inductive case)
// TODO(bgruber): drop those in CCCL 4.0 when we remove the deprecated ThreadStore overloads
template <int COUNT, int MAX>
struct iterate_thread_store
{
Expand All @@ -116,7 +109,7 @@ struct iterate_thread_store
}
};

/// Helper structure for templated store iteration (termination case)
// TODO(bgruber): drop those in CCCL 4.0 when we remove the deprecated ThreadStore overloads
template <int MAX>
struct iterate_thread_store<MAX, MAX>
{
Expand All @@ -128,11 +121,21 @@ struct iterate_thread_store<MAX, MAX>
static _CCCL_DEVICE _CCCL_FORCEINLINE void Dereference(OutputIteratorT /*ptr*/, T* /*vals*/)
{}
};

template <CacheStoreModifier MODIFIER, typename T, size_t... Is>
_CCCL_DEVICE _CCCL_FORCEINLINE void store_helper(T* ptr, T* vals, ::cuda::std::index_sequence<Is...>)
{
(ThreadStore<MODIFIER>(ptr + Is, vals[Is]), ...);
}

template <typename T, size_t... Is>
_CCCL_DEVICE _CCCL_FORCEINLINE void dereference_helper(volatile T* ptr, T* vals, ::cuda::std::index_sequence<Is...>)
{
((ptr[Is] = vals[Is]), ...);
}
} // namespace detail

/**
* Define a uint4 (16B) ThreadStore specialization for the given Cache load modifier
*/
//! Define a uint4 (16B) ThreadStore specialization for the given Cache load modifier
# define _CUB_STORE_16(cub_modifier, ptx_modifier) \
template <> \
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, uint4*, uint4>(uint4 * ptr, uint4 val) \
Expand All @@ -148,9 +151,7 @@ struct iterate_thread_store<MAX, MAX>
asm volatile("st." #ptx_modifier ".v2.u64 [%0], {%1, %2};" : : "l"(ptr), "l"(val.x), "l"(val.y)); \
}

/**
* Define a uint2 (8B) ThreadStore specialization for the given Cache load modifier
*/
//! Define a uint2 (8B) ThreadStore specialization for the given Cache load modifier
# define _CUB_STORE_8(cub_modifier, ptx_modifier) \
template <> \
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, ushort4*, ushort4>(ushort4 * ptr, ushort4 val) \
Expand All @@ -171,9 +172,7 @@ struct iterate_thread_store<MAX, MAX>
asm volatile("st." #ptx_modifier ".u64 [%0], %1;" : : "l"(ptr), "l"(val)); \
}

/**
* Define a unsigned int (4B) ThreadStore specialization for the given Cache load modifier
*/
//! Define a unsigned int (4B) ThreadStore specialization for the given Cache load modifier
# define _CUB_STORE_4(cub_modifier, ptx_modifier) \
template <> \
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, unsigned int*, unsigned int>( \
Expand All @@ -182,9 +181,7 @@ struct iterate_thread_store<MAX, MAX>
asm volatile("st." #ptx_modifier ".u32 [%0], %1;" : : "l"(ptr), "r"(val)); \
}

/**
* Define a unsigned short (2B) ThreadStore specialization for the given Cache load modifier
*/
//! Define a unsigned short (2B) ThreadStore specialization for the given Cache load modifier
# define _CUB_STORE_2(cub_modifier, ptx_modifier) \
template <> \
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, unsigned short*, unsigned short>( \
Expand All @@ -193,9 +190,7 @@ struct iterate_thread_store<MAX, MAX>
asm volatile("st." #ptx_modifier ".u16 [%0], %1;" : : "l"(ptr), "h"(val)); \
}

/**
* Define a unsigned char (1B) ThreadStore specialization for the given Cache load modifier
*/
//! Define a unsigned char (1B) ThreadStore specialization for the given Cache load modifier
# define _CUB_STORE_1(cub_modifier, ptx_modifier) \
template <> \
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore<cub_modifier, unsigned char*, unsigned char>( \
Expand All @@ -211,19 +206,15 @@ struct iterate_thread_store<MAX, MAX>
: "l"(ptr), "h"((unsigned short) val)); \
}

/**
* Define powers-of-two ThreadStore specializations for the given Cache load modifier
*/
//! Define powers-of-two ThreadStore specializations for the given Cache load modifier
# define _CUB_STORE_ALL(cub_modifier, ptx_modifier) \
_CUB_STORE_16(cub_modifier, ptx_modifier) \
_CUB_STORE_8(cub_modifier, ptx_modifier) \
_CUB_STORE_4(cub_modifier, ptx_modifier) \
_CUB_STORE_2(cub_modifier, ptx_modifier) \
_CUB_STORE_1(cub_modifier, ptx_modifier)

/**
* Define ThreadStore specializations for the various Cache load modifiers
*/
//! Define ThreadStore specializations for the various Cache load modifiers
_CUB_STORE_ALL(STORE_WB, wb)
_CUB_STORE_ALL(STORE_CG, cg)
_CUB_STORE_ALL(STORE_CS, cs)
Expand All @@ -237,40 +228,38 @@ _CUB_STORE_ALL(STORE_WT, wt)
# undef _CUB_STORE_8
# undef _CUB_STORE_16

/**
* ThreadStore definition for STORE_DEFAULT modifier on iterator types
*/
//! ThreadStore definition for STORE_DEFAULT modifier on iterator types
//! deprecated [Since 3.3]
template <typename OutputIteratorT, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore(
CCCL_DEPRECATED_BECAUSE("Use *itr = val instead") _CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore(
OutputIteratorT itr, T val, detail::constant_t<STORE_DEFAULT> /*modifier*/, ::cuda::std::false_type /*is_pointer*/)
{
*itr = val;
}

/**
* ThreadStore definition for STORE_DEFAULT modifier on pointer types
*/
//! ThreadStore definition for STORE_DEFAULT modifier on pointer types
//! deprecated [Since 3.3]
template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void
CCCL_DEPRECATED_BECAUSE("Use *itr = val instead") _CCCL_DEVICE _CCCL_FORCEINLINE void
ThreadStore(T* ptr, T val, detail::constant_t<STORE_DEFAULT> /*modifier*/, ::cuda::std::true_type /*is_pointer*/)
{
*ptr = val;
}

/**
* ThreadStore definition for STORE_VOLATILE modifier on primitive pointer types
*/
//! ThreadStore definition for STORE_VOLATILE modifier on primitive pointer types
//! deprecated [Since 3.3]
template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStoreVolatilePtr(T* ptr, T val, ::cuda::std::true_type /*is_primitive*/)
CCCL_DEPRECATED_BECAUSE("Use ThreadStore<STORE_VOLATILE>(ptr, val) instead") _CCCL_DEVICE _CCCL_FORCEINLINE void
ThreadStoreVolatilePtr(T* ptr, T val, ::cuda::std::true_type /*is_primitive*/)
{
*reinterpret_cast<volatile T*>(ptr) = val;
}

/**
* ThreadStore definition for STORE_VOLATILE modifier on non-primitive pointer types
*/
//! ThreadStore definition for STORE_VOLATILE modifier on non-primitive pointer types
//! deprecated [Since 3.3]
template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStoreVolatilePtr(T* ptr, T val, ::cuda::std::false_type /*is_primitive*/)
CCCL_DEPRECATED_BECAUSE("Use ThreadStore<STORE_VOLATILE>(ptr, val) instead") _CCCL_DEVICE _CCCL_FORCEINLINE void
ThreadStoreVolatilePtr(T* ptr, T val, ::cuda::std::false_type /*is_primitive*/)
{
// Create a temporary using shuffle-words, then store using volatile-words
using VolatileWord = typename UnitWord<T>::VolatileWord;
Expand All @@ -290,21 +279,19 @@ _CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStoreVolatilePtr(T* ptr, T val, ::cuda
detail::iterate_thread_store<0, VOLATILE_MULTIPLE>::Dereference(reinterpret_cast<volatile VolatileWord*>(ptr), words);
}

/**
* ThreadStore definition for STORE_VOLATILE modifier on pointer types
*/
//! ThreadStore definition for STORE_VOLATILE modifier on pointer types
//! deprecated [Since 3.3]
template <typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void
CCCL_DEPRECATED_BECAUSE("Use ThreadStore<STORE_VOLATILE>(ptr, val) instead") _CCCL_DEVICE _CCCL_FORCEINLINE void
ThreadStore(T* ptr, T val, detail::constant_t<STORE_VOLATILE> /*modifier*/, ::cuda::std::true_type /*is_pointer*/)
{
ThreadStoreVolatilePtr(ptr, val, detail::bool_constant_v<detail::is_primitive<T>::value>);
}

/**
* ThreadStore definition for generic modifiers on pointer types
*/
//! ThreadStore definition for generic modifiers on pointer types
//! deprecated [Since 3.3]
template <typename T, CacheStoreModifier MODIFIER>
_CCCL_DEVICE _CCCL_FORCEINLINE void
CCCL_DEPRECATED_BECAUSE("Use ThreadStore<MODIFIER>(ptr, val) instead") _CCCL_DEVICE _CCCL_FORCEINLINE void
ThreadStore(T* ptr, T val, detail::constant_t<MODIFIER> /*modifier*/, ::cuda::std::true_type /*is_pointer*/)
{
// Create a temporary using shuffle-words, then store using device-words
Expand All @@ -326,14 +313,49 @@ ThreadStore(T* ptr, T val, detail::constant_t<MODIFIER> /*modifier*/, ::cuda::st
reinterpret_cast<DeviceWord*>(ptr), words);
}

/**
* ThreadStore definition for generic modifiers
*/
template <CacheStoreModifier MODIFIER, typename OutputIteratorT, typename T>
_CCCL_DEVICE _CCCL_FORCEINLINE void ThreadStore(OutputIteratorT itr, T val)
{
ThreadStore(
itr, val, detail::constant_v<MODIFIER>, detail::bool_constant_v<::cuda::std::is_pointer_v<OutputIteratorT>>);
if constexpr (!::cuda::std::contiguous_iterator<OutputIteratorT> || MODIFIER == STORE_DEFAULT)
{
*itr = val;
}
else if constexpr (MODIFIER == STORE_VOLATILE && detail::is_primitive_v<T>)
{
*reinterpret_cast<volatile T*>(::cuda::std::to_address(itr)) = val;
}
else
{
// Create a temporary using shuffle-words, then store using volatile/device-words
using StoreWord = ::cuda::std::
conditional_t<MODIFIER == STORE_VOLATILE, typename UnitWord<T>::VolatileWord, typename UnitWord<T>::DeviceWord>;
using ShuffleWord = typename UnitWord<T>::ShuffleWord;

constexpr int WORD_MULTIPLE = sizeof(T) / sizeof(StoreWord);
constexpr int SHUFFLE_MULTIPLE = sizeof(T) / sizeof(ShuffleWord);

StoreWord words[WORD_MULTIPLE];

_CCCL_PRAGMA_UNROLL_FULL()
for (int i = 0; i < SHUFFLE_MULTIPLE; ++i)
{
reinterpret_cast<ShuffleWord*>(words)[i] = reinterpret_cast<ShuffleWord*>(&val)[i];
}

if constexpr (MODIFIER == STORE_VOLATILE)
{
detail::dereference_helper(reinterpret_cast<volatile StoreWord*>(::cuda::std::to_address(itr)),
words,
::cuda::std::make_index_sequence<WORD_MULTIPLE>{});
}
else
{
detail::store_helper<MODIFIER>(
reinterpret_cast<StoreWord*>(::cuda::std::to_address(itr)),
words,
::cuda::std::make_index_sequence<WORD_MULTIPLE>{});
}
}
}

#endif // _CCCL_DOXYGEN_INVOKED
Expand Down