From e4db8f08e4e9af17d36d63d501d03150560ec0a7 Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Tue, 1 Jul 2025 18:51:44 +0000
Subject: [PATCH 1/2] warpSize is being made non-constexpr in ROCm 7.0

Signed-off-by: Gregory Shtrasberg
---
 csrc/attention/attention_kernels.cuh | 6 +++---
 csrc/attention/paged_attention_v1.cu | 6 +++---
 csrc/attention/paged_attention_v2.cu | 6 +++---
 csrc/cuda_compat.h                   | 6 +++---
 4 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh
index 79a546554fa1..bee156469115 100644
--- a/csrc/attention/attention_kernels.cuh
+++ b/csrc/attention/attention_kernels.cuh
@@ -33,10 +33,10 @@ typedef __hip_bfloat16 __nv_bfloat16;
   #include "../quantization/fp8/nvidia/quant_utils.cuh"
 #endif
 
-#ifndef USE_ROCM
-  #define WARP_SIZE 32
+#if defined(USE_ROCM) && defined(__GFX9__)
+  #define WARP_SIZE 64
 #else
-  #define WARP_SIZE warpSize
+  #define WARP_SIZE 32
 #endif
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu
index 46108a32d719..f2937190151e 100644
--- a/csrc/attention/paged_attention_v1.cu
+++ b/csrc/attention/paged_attention_v1.cu
@@ -19,10 +19,10 @@
 
 #include "attention_kernels.cuh"
 
-#ifndef USE_ROCM
-  #define WARP_SIZE 32
+#if defined(USE_ROCM) && defined(__GFX9__)
+  #define WARP_SIZE 64
 #else
-  #define WARP_SIZE warpSize
+  #define WARP_SIZE 32
 #endif
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu
index 9358c0d9f6a2..e1c07bb8859a 100644
--- a/csrc/attention/paged_attention_v2.cu
+++ b/csrc/attention/paged_attention_v2.cu
@@ -19,10 +19,10 @@
 
 #include "attention_kernels.cuh"
 
-#ifndef USE_ROCM
-  #define WARP_SIZE 32
+#if defined(USE_ROCM) && defined(__GFX9__)
+  #define WARP_SIZE 64
 #else
-  #define WARP_SIZE warpSize
+  #define WARP_SIZE 32
 #endif
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h
index 82e55613d915..affa051c7595 100644
--- a/csrc/cuda_compat.h
+++ b/csrc/cuda_compat.h
@@ -4,10 +4,10 @@
   #include <hip/hip_runtime.h>
 #endif
 
-#ifndef USE_ROCM
-  #define WARP_SIZE 32
+#if defined(USE_ROCM) && defined(__GFX9__)
+  #define WARP_SIZE 64
 #else
-  #define WARP_SIZE warpSize
+  #define WARP_SIZE 32
 #endif
 
 #ifndef USE_ROCM

From 56d612311d497145f6dd464455eb50989130dffb Mon Sep 17 00:00:00 2001
From: Gregory Shtrasberg
Date: Tue, 1 Jul 2025 20:30:17 +0000
Subject: [PATCH 2/2] Using cuda_compat to define the WARP_SIZE once

Signed-off-by: Gregory Shtrasberg
---
 csrc/attention/attention_kernels.cuh | 8 +-------
 csrc/attention/paged_attention_v1.cu | 8 +-------
 csrc/attention/paged_attention_v2.cu | 8 +-------
 3 files changed, 3 insertions(+), 21 deletions(-)

diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh
index bee156469115..8f24be89578b 100644
--- a/csrc/attention/attention_kernels.cuh
+++ b/csrc/attention/attention_kernels.cuh
@@ -24,6 +24,7 @@
 #include "attention_dtypes.h"
 #include "attention_utils.cuh"
+#include "cuda_compat.h"
 
 #ifdef USE_ROCM
   #include <hip/hip_bf16.h>
 typedef __hip_bfloat16 __nv_bfloat16;
 #else
   #include "../quantization/fp8/nvidia/quant_utils.cuh"
 #endif
 
-#if defined(USE_ROCM) && defined(__GFX9__)
-  #define WARP_SIZE 64
-#else
-  #define WARP_SIZE 32
-#endif
-
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
@@ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel(
 
 } // namespace vllm
 
-#undef WARP_SIZE
 #undef MAX
 #undef MIN
 #undef DIVIDE_ROUND_UP
diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu
index f2937190151e..7a5ef10f8ef3 100644
--- a/csrc/attention/paged_attention_v1.cu
+++ b/csrc/attention/paged_attention_v1.cu
@@ -18,12 +18,7 @@
  */
 
 #include "attention_kernels.cuh"
-
-#if defined(USE_ROCM) && defined(__GFX9__)
-  #define WARP_SIZE 64
-#else
-  #define WARP_SIZE 32
-#endif
+#include "cuda_compat.h"
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -187,7 +182,6 @@ void paged_attention_v1(
       CALL_V1_LAUNCHER_BLOCK_SIZE)
 }
 
-#undef WARP_SIZE
 #undef MAX
 #undef MIN
 #undef DIVIDE_ROUND_UP
diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu
index e1c07bb8859a..b45b28dad05e 100644
--- a/csrc/attention/paged_attention_v2.cu
+++ b/csrc/attention/paged_attention_v2.cu
@@ -18,12 +18,7 @@
  */
 
 #include "attention_kernels.cuh"
-
-#if defined(USE_ROCM) && defined(__GFX9__)
-  #define WARP_SIZE 64
-#else
-  #define WARP_SIZE 32
-#endif
+#include "cuda_compat.h"
 
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -197,7 +192,6 @@ void paged_attention_v2(
       CALL_V2_LAUNCHER_BLOCK_SIZE)
 }
 
-#undef WARP_SIZE
 #undef MAX
 #undef MIN
 #undef DIVIDE_ROUND_UP
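
Background note (not part of the patch series): with ROCm 7.0 turning HIP's warpSize into a runtime value rather than a constexpr, the old `#define WARP_SIZE warpSize` fallback stops working anywhere the macro must be a constant expression. The sketch below is a minimal illustration of such uses, assuming USE_ROCM and __GFX9__ are supplied by the build as in the files above; kMaxThreadsPerBlock, example_kernel, and ReductionWorkspace are invented names for the example, not vLLM code.

// Illustrative sketch only. Shows why WARP_SIZE has to stay a compile-time
// constant (64 on gfx9 wavefronts, 32 elsewhere) once warpSize is no longer
// constexpr.
#ifdef USE_ROCM
  #include <hip/hip_runtime.h>
#endif

#if defined(USE_ROCM) && defined(__GFX9__)
  #define WARP_SIZE 64  // gfx9 wavefront size
#else
  #define WARP_SIZE 32  // NVIDIA warps and other targets
#endif

constexpr int kMaxThreadsPerBlock = 1024;

// 1) Shared-memory array extents must be constant expressions; a runtime
//    warpSize cannot be used here.
__global__ void example_kernel(float* out) {
  __shared__ float warp_partials[kMaxThreadsPerBlock / WARP_SIZE];
  int warp = threadIdx.x / WARP_SIZE;  // runtime uses like this stay fine
  int lane = threadIdx.x % WARP_SIZE;
  if (lane == 0) warp_partials[warp] = 1.0f;
  __syncthreads();
  if (threadIdx.x == 0) out[0] = warp_partials[0];
}

// 2) Template arguments must be constant expressions.
template <int NUM_WARPS>
struct ReductionWorkspace {
  float partial[NUM_WARPS];
};
using Workspace = ReductionWorkspace<kMaxThreadsPerBlock / WARP_SIZE>;

// 3) static_assert operands must be constant expressions.
static_assert(WARP_SIZE == 32 || WARP_SIZE == 64, "unexpected warp size");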