2 changes: 1 addition & 1 deletion csrc/attention/attention_kernels.cuh
@@ -24,7 +24,7 @@

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "cuda_compat.h"
#include "../cuda_compat.h"

#ifdef USE_ROCM
#include <hip/hip_bf16.h>
5 changes: 2 additions & 3 deletions csrc/attention/paged_attention_v1.cu
@@ -16,9 +16,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
 */
-
#include "attention_kernels.cuh"
-#include "cuda_compat.h"
+#include "../cuda_compat.h"

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -75,7 +74,7 @@ void paged_attention_v1_launcher(
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

-constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int padded_max_seq_len =
DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
int logits_size = padded_max_seq_len * sizeof(float);
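Note on the constexpr-to-const change above: with the reworked csrc/cuda_compat.h below, WARP_SIZE on ROCm expands to a host function call (Utils::get_warp_size()) rather than a literal, and a non-constexpr function call cannot appear in a constant expression, so NUM_WARPS must become a plain const. The same change is made in paged_attention_v2.cu below. A minimal sketch of the distinction, using a hypothetical stand-in for the warp-size query:

// Sketch only; get_warp_size_runtime() stands in for Utils::get_warp_size().
static int get_warp_size_runtime() { return 64; }  // runtime device query in the real code

constexpr int NUM_THREADS = 128;

void launcher_sketch() {
  constexpr int kWarpsLiteral = NUM_THREADS / 32;  // OK: 32 is a constant expression
  // constexpr int kWarps = NUM_THREADS / get_warp_size_runtime();  // error: not a constant expression
  const int kWarps = NUM_THREADS / get_warp_size_runtime();  // OK: computed at run time
  (void)kWarpsLiteral;
  (void)kWarps;
}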
5 changes: 2 additions & 3 deletions csrc/attention/paged_attention_v2.cu
@@ -16,9 +16,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
 */
-
#include "attention_kernels.cuh"
-#include "cuda_compat.h"
+#include "../cuda_compat.h"

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -79,7 +78,7 @@ void paged_attention_v2_launcher(
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

-constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
+const int NUM_WARPS = NUM_THREADS / WARP_SIZE;
int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
int logits_size = PARTITION_SIZE * sizeof(float);
int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
33 changes: 32 additions & 1 deletion csrc/cuda_compat.h
@@ -4,8 +4,39 @@
#include <hip/hip_runtime.h>
#endif

+struct Utils {
+  static __host__ int get_warp_size() {
+#if defined(USE_ROCM)
+    static bool is_cached = false;
+    static int result;
+
+    if (!is_cached) {
+      int device_id;
+      cudaDeviceProp deviceProp;
+      cudaGetDevice(&device_id);
+      cudaGetDeviceProperties(&deviceProp, device_id);
+
+      result = deviceProp.warpSize;
+      is_cached = true;
+    }
+
+    return result;
+#else
+    return 32;
+#endif
+  }
Contributor (inline review comment), severity: critical

The host-side get_warp_size caches the warp size in a static bool is_cached and a static int result. This caching is susceptible to a data race if multiple threads call get_warp_size concurrently before the value is cached, which can lead to an incorrect warp size or undefined behavior. Instead, initialize a function-local static with a lambda so the warp size is computed exactly once in a thread-safe manner:

  static __host__ int get_warp_size() {
#if defined(USE_ROCM)
    // C++11 guarantees that static local variable initialization is thread-safe.
    // This avoids a data race on is_cached and result.
    static const int result = [] {
      int device_id;
      cudaDeviceProp deviceProp;
      cudaGetDevice(&device_id);
      cudaGetDeviceProperties(&deviceProp, device_id);
      return deviceProp.warpSize;
    }();
    return result;
#else
    return 32;
#endif
  }


+  static __device__ constexpr int get_warp_size() {
+#if defined(USE_ROCM) && defined(__GFX9__)
+#define WARP_SIZE 64
+    return 64;
+#else
+    return 32;
+#endif
+  }
+};

#if defined(USE_ROCM)
+#define WARP_SIZE Utils::get_warp_size()
#else
#define WARP_SIZE 32
#endif
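For reference, a self-contained sketch of the thread-safe caching pattern the review comment above recommends: a function-local static initialized by an immediately invoked lambda (a C++11 "magic static"), whose initializer is guaranteed to run exactly once even under concurrent callers. The names query_device_warp_size and get_warp_size_cached are illustrative stand-ins, not part of this patch:

#include <iostream>
#include <thread>
#include <vector>

static int query_device_warp_size() {
  // Stand-in for the cudaGetDevice + cudaGetDeviceProperties(...).warpSize query.
  return 64;
}

int get_warp_size_cached() {
  // C++11 guarantees this initializer runs exactly once, thread-safely.
  static const int cached = [] { return query_device_warp_size(); }();
  return cached;
}

int main() {
  std::vector<std::thread> threads;
  for (int i = 0; i < 8; ++i) {
    threads.emplace_back([] { std::cout << get_warp_size_cached() << '\n'; });
  }
  for (auto& t : threads) t.join();  // every thread observes the same cached value
}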
47 changes: 29 additions & 18 deletions csrc/moe/topk_softmax_kernels.cu
@@ -190,8 +190,8 @@ __launch_bounds__(TPB) __global__ void moeTopK(
2) This implementation assumes k is small, but will work for any k.
*/

-template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
-__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, int WARP_SIZE_PARAM, typename IndType>
+__launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
int* source_rows, const int k, const int start_expert, const int end_expert)
{
@@ -209,12 +209,12 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__

// Restrictions based on previous section.
static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
-static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
+static_assert(WARP_SIZE_PARAM % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
-static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
+static_assert(THREADS_PER_ROW <= WARP_SIZE_PARAM, "THREADS_PER_ROW can be at most warp size");

// We have NUM_EXPERTS elements per row. We specialize for small #experts
-static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
+static constexpr int ELTS_PER_WARP = WARP_SIZE_PARAM * VPT;
static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;

@@ -393,41 +393,51 @@
namespace detail
{
// Constructs some constants needed to partition the work across threads at compile time.
-template <int EXPERTS, int BYTES_PER_LDG>
+template <int EXPERTS, int BYTES_PER_LDG, int WARP_SIZE_PARAM>
struct TopkConstants
{
static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
-static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
-static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
+static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE_PARAM) == 0, "");
+static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE_PARAM));
static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
-static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
+static const int ROWS_PER_WARP = WARP_SIZE_PARAM / THREADS_PER_ROW;
};
} // namespace detail

-template <int EXPERTS, int WARPS_PER_TB, typename IndType>
+template <int EXPERTS, int WARPS_PER_TB, int WARP_SIZE_PARAM, typename IndType>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
-using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
+using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG, WARP_SIZE_PARAM>;
static constexpr int VPT = Constants::VPT;
static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

-dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
-topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
+dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
+topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG, WARP_SIZE_PARAM><<<num_blocks, block_dim, 0, stream>>>(
input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
}

-#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \
-    topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>( \
-        gating_output, nullptr, topk_weights, topk_indices, \
-        token_expert_indices, num_tokens, topk, 0, num_experts, \
-        stream);
+#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                            \
+    switch (warpSize) {                                                      \
+        case 32:                                                             \
+            topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 32>(  \
+                gating_output, nullptr, topk_weights, topk_indices,          \
+                token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
+            break;                                                           \
+        case 64:                                                             \
+            topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB, 64>(  \
+                gating_output, nullptr, topk_weights, topk_indices,          \
+                token_expert_indices, num_tokens, topk, 0, num_experts, stream); \
+            break;                                                           \
+        default:                                                             \
+            TORCH_CHECK(false, "Unsupported warp size: ", warpSize);         \
+    }

template <typename IndType>
void topkGatingSoftmaxKernelLauncher(
@@ -441,6 +451,7 @@ void topkGatingSoftmaxKernelLauncher(
const int topk,
cudaStream_t stream) {
static constexpr int WARPS_PER_TB = 4;
+auto warpSize = WARP_SIZE;
switch (num_experts) {
case 1:
LAUNCH_SOFTMAX(1, WARPS_PER_TB);
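The rewritten LAUNCH_SOFTMAX is an instance of a standard CUDA/HIP dispatch pattern: a value known only at run time (the device's warp size) is switched into a compile-time template argument, so kernel-side code can keep using it in static_assert, __launch_bounds__, and constexpr arithmetic. A stripped-down sketch of the pattern, with hypothetical names:

#include <cstdio>
#include <stdexcept>

// Kernel-side code wants the warp size as a compile-time constant.
template <int WARP_SIZE_PARAM>
void launch_kernel(int num_tokens) {
  static_assert(WARP_SIZE_PARAM == 32 || WARP_SIZE_PARAM == 64, "unexpected warp size");
  std::printf("launch with warp size %d for %d tokens\n", WARP_SIZE_PARAM, num_tokens);
}

// Host-side dispatch folds the runtime value into a template argument.
void dispatch(int runtime_warp_size, int num_tokens) {
  switch (runtime_warp_size) {
    case 32: launch_kernel<32>(num_tokens); break;
    case 64: launch_kernel<64>(num_tokens); break;
    default: throw std::runtime_error("unsupported warp size");
  }
}

int main() { dispatch(64, 128); }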
2 changes: 1 addition & 1 deletion csrc/quantization/activation_kernels.cu
@@ -4,7 +4,7 @@

#include <cmath>
#include "core/math.hpp"
#include "cuda_compat.h"
#include "../cuda_compat.h"
#include "dispatch_utils.h"

#include "quantization/fp8/common.cuh"
2 changes: 1 addition & 1 deletion csrc/quantization/gguf/gguf_kernel.cu
@@ -4,7 +4,7 @@
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "../../cuda_compat.h"
#include "dispatch_utils.h"

#include "ggml-common.h"
2 changes: 1 addition & 1 deletion csrc/rocm/attention.cu
@@ -19,7 +19,7 @@
#include <c10/cuda/CUDAGuard.h>
#include <hip/hip_fp8.h>
#include <hip/hip_bf16.h>
#include "cuda_compat.h"
#include "../cuda_compat.h"

#include <algorithm>
#include "../attention/dtype_fp8.cuh"
2 changes: 1 addition & 1 deletion csrc/rocm/skinny_gemms.cu
@@ -9,7 +9,7 @@
#include <stdexcept>
#include <algorithm>

#include "cuda_compat.h"
#include "../cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/fp8/common.cuh"
