Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
d138a03
Add support for CUMSUM and TRI for CUDA.
pwilkin Nov 28, 2025
67207d2
Minor optimizations.
pwilkin Nov 28, 2025
fab0029
Correct warp_prefix_inclusive_sum in float2 variant to return float2
pwilkin Nov 28, 2025
51c40a5
Optimize TRI
pwilkin Dec 1, 2025
c30f565
Whitespace
pwilkin Dec 1, 2025
31b55fa
Fix strides.
pwilkin Dec 1, 2025
d1ca1c2
Implement double loop
pwilkin Dec 1, 2025
5289b53
Whitespace
pwilkin Dec 1, 2025
f422ba8
Fix HIP compilation bugs
pwilkin Dec 1, 2025
df917cc
Optimizations + big case performance tests
pwilkin Dec 2, 2025
76382d7
Implement using CUB with fallback to custom kernel
pwilkin Dec 2, 2025
01d4033
Remove error message.
pwilkin Dec 2, 2025
10a2ea9
Fixes from code review
pwilkin Dec 3, 2025
7a83b05
Comment out CPU-unsupported F16/BF16 cases to fix CI
pwilkin Dec 3, 2025
bbe3743
Fine, you win :P
pwilkin Dec 4, 2025
069413a
Fix last cast, use NO_DEVICE_CODE and GGML_UNUSED_VARS
pwilkin Dec 4, 2025
5aa7438
Vary warp-size based on physical warp size
pwilkin Dec 4, 2025
579eba6
Add GGML_UNUSED_VARS in tri as well
pwilkin Dec 4, 2025
08b3f2d
Use constexpr and call prefix_inclusive with warp_size template param
pwilkin Dec 4, 2025
9cd0eff
Update ggml/src/ggml-cuda/cumsum.cu
pwilkin Dec 4, 2025
9574264
Apply suggestions from code review
pwilkin Dec 4, 2025
efd619a
Change to tid % warp_size
pwilkin Dec 4, 2025
86a0853
Fix strides; hardcode mask; add ggml_lane_mask_t
pwilkin Dec 4, 2025
de45c63
Missing renames, remove unused get_warp_mask(), explicit calls to ggm…
pwilkin Dec 4, 2025
8a7375c
Too hasty...
pwilkin Dec 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,15 @@ static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
}

// Host-side counterpart of ggml_cuda_get_physical_warp_size(): returns the
// physical warp (wavefront) width the host code should assume when sizing
// kernel launches (used e.g. by cumsum_cuda to compute num_warps/block_size).
// NOTE(review): __GFX9__/__GFX8__ look like device-compilation arch macros;
// presumably they are also defined during the host pass for these targets —
// confirm, otherwise this would return 32 on GFX8/GFX9 HIP builds while the
// device-side ggml_cuda_get_physical_warp_size() returns 64. TODO confirm.
static constexpr __host__ int ggml_cuda_get_physical_warp_size_host() {
#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
    // 64-lane wavefronts on these HIP targets.
    return 64;
#else
    // 32-lane warps everywhere else.
    return 32;
#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
}


// Maximum number of bytes that can be copied in a single instruction.
static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {
#ifdef GGML_USE_HIP
Expand Down
27 changes: 12 additions & 15 deletions ggml/src/ggml-cuda/cumsum.cu
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
#include <algorithm>
#include "cumsum.cuh"
#include "convert.cuh"
#include "ggml-cuda/common.cuh"
#include "ggml.h"

#if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
# define CUMSUM_WARP_SIZE 64
#else
# define CUMSUM_WARP_SIZE 32
#endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))

#ifdef GGML_CUDA_USE_CUB
# include <cub/device/device_scan.cuh>
#endif
Expand Down Expand Up @@ -85,9 +80,10 @@ static __global__ void cumsum_kernel(
GGML_UNUSED_VARS(nb00, nb0);

const int tid = threadIdx.x;
const int lane = tid & (CUMSUM_WARP_SIZE - 1);
const int warp = tid / CUMSUM_WARP_SIZE;
const int warps_per_block = blockDim.x / CUMSUM_WARP_SIZE;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int lane = tid & (warp_size - 1);
const int warp = tid / warp_size;
const int warps_per_block = blockDim.x / warp_size;

extern __shared__ float smem[];
float* s_vals = smem;
Expand Down Expand Up @@ -116,19 +112,19 @@ static __global__ void cumsum_kernel(
float val = (idx < ne00) ? ggml_cuda_cast<float, T>(src_row[idx]) : 0.0f;

// 1. Warp inclusive scan
val = warp_prefix_inclusive_sum(val);
val = warp_prefix_inclusive_sum<T, warp_size>(val);
s_vals[tid] = val;

// Store warp total
if (lane == CUMSUM_WARP_SIZE - 1) {
if (lane == warp_size - 1) {
s_warp_sums[warp] = val;
}
__syncthreads();

// 2. Exclusive scan of warp sums (warp 0 only)
if (warp == 0) {
float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
float inc = warp_prefix_inclusive_sum(w);
float inc = warp_prefix_inclusive_sum<T, warp_size>(w);
if (tid < warps_per_block) {
s_warp_sums[tid] = inc - w; // exclusive sum
}
Expand Down Expand Up @@ -172,11 +168,12 @@ static void cumsum_cuda(
}
#endif // GGML_CUDA_USE_CUB
dim3 grid_dims(ne01, ne02, ne03);
const int num_warps = (ne00 + CUMSUM_WARP_SIZE - 1) / CUMSUM_WARP_SIZE;
int block_size = num_warps * CUMSUM_WARP_SIZE;
constexpr int warp_size = ggml_cuda_get_physical_warp_size_host();
const int num_warps = (ne00 + warp_size - 1) / warp_size;
int block_size = num_warps * warp_size;
block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
dim3 block_dims(block_size, 1, 1);
const int warps_per_block = block_size / CUMSUM_WARP_SIZE;
const int warps_per_block = block_size / warp_size;
const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);

if (use_cub) {
Expand Down
Loading