diff --git a/mlx/backend/cuda/scaled_dot_product_attention.cu b/mlx/backend/cuda/scaled_dot_product_attention.cu index 64ac63180e..838e890b80 100644 --- a/mlx/backend/cuda/scaled_dot_product_attention.cu +++ b/mlx/backend/cuda/scaled_dot_product_attention.cu @@ -38,6 +38,39 @@ struct AttnParams { int64_t O_strides[3]; }; +template +__device__ void load(const T *src, U *dst, int idx) { + if constexpr (N % 2 == 0) { + auto local = load_vector(src, idx); + PRAGMA_LOOP_UNROLL + for (int i = 0; i < N; i++) { + dst[i] = static_cast(local[i]); + } + } else { + PRAGMA_LOOP_UNROLL + for (int i = 0; i < N; i++) { + dst[i] = static_cast(src[N * idx + i]); + } + } +} + +template +__device__ void store(T *src, U *dst, int idx) { + if constexpr (N % 2 == 0) { + AlignedVector local; + PRAGMA_LOOP_UNROLL + for (int i = 0; i < N; i++) { + local[i] = static_cast(src[i]); + } + store_vector(dst, idx, local); + } else { + PRAGMA_LOOP_UNROLL + for (int i = 0; i < N; i++) { + dst[i] = static_cast(src[N * idx + i]); + } + } +} + template __global__ void kernel_sdpav_1pass( const T* Q, @@ -58,6 +91,7 @@ __global__ void kernel_sdpav_1pass( U q[v_per_thread]; U k[v_per_thread]; + U v[v_per_thread]; U o[v_per_thread]; __shared__ U outputs[BN][BD + 1]; @@ -97,9 +131,10 @@ __global__ void kernel_sdpav_1pass( q_seq_idx * params.O_strides[2]; // Sequence // Read the query and 0 the output accumulator + load(Q, q, lane_idx); PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { - q[i] = scale_log2 * static_cast(Q[v_per_thread * lane_idx + i]); + q[i] *= scale_log2; } PRAGMA_LOOP_UNROLL @@ -123,10 +158,7 @@ __global__ void kernel_sdpav_1pass( if (use_key) { // Read the key - PRAGMA_LOOP_UNROLL - for (int j = 0; j < v_per_thread; j++) { - k[j] = K[v_per_thread * lane_idx + j]; - } + load(K, k, lane_idx); // Compute the i-th score U score = 0.f; @@ -146,11 +178,12 @@ __global__ void kernel_sdpav_1pass( max_score = new_max; sum_exp_score = sum_exp_score * factor + exp_score; + load(V, v, lane_idx); + // Update the output accumulator PRAGMA_LOOP_UNROLL for (int j = 0; j < v_per_thread; j++) { - o[j] = o[j] * factor + - exp_score * static_cast(V[v_per_thread * lane_idx + j]); + o[j] = o[j] * factor + exp_score * v[j]; } } @@ -184,10 +217,7 @@ __global__ void kernel_sdpav_1pass( // And write the output if (lane_idx == 0) { - PRAGMA_LOOP_UNROLL - for (int i = 0; i < v_per_thread; i++) { - O[v_per_thread * warp_idx + i] = static_cast(o[i]); - } + store(o, O, warp_idx); } } @@ -254,9 +284,10 @@ __global__ void kernel_sdpav_2pass_1( maxs += p_offset; // Read the query + load(Q, q, lane_idx); PRAGMA_LOOP_UNROLL for (int i = 0; i < v_per_thread; i++) { - q[i] = scale_log2 * static_cast(Q[v_per_thread * lane_idx + i]); + q[i] *= scale_log2; } U max_score = Limits::finite_min(); @@ -275,16 +306,12 @@ __global__ void kernel_sdpav_2pass_1( use_key = i <= (params.kL - params.qL + q_seq_idx); } // Load keys and values into shared memory - if (warp.any(use_key)) { - block.sync(); - if (threadIdx.y == 0 && threadIdx.z == 0) { - for (int j = 0; j < v_per_thread; j++) { - k[j] = static_cast(K[v_per_thread * lane_idx + j]); - v[j] = static_cast(V[v_per_thread * lane_idx + j]); - } - } - block.sync(); + block.sync(); + if (threadIdx.y == 0 && threadIdx.z == 0) { + load(K, k, lane_idx); + load(V, v, lane_idx); } + block.sync(); if (use_key) { // Compute the i-th score @@ -322,10 +349,7 @@ __global__ void kernel_sdpav_2pass_1( maxs[0] = max_score; } - PRAGMA_LOOP_UNROLL - for (int i = 0; i < v_per_thread; i++) { - partials[v_per_thread * lane_idx + i] = static_cast(o[i]); - } + store(o, partials, lane_idx); } template @@ -412,10 +436,7 @@ __global__ void kernel_sdpav_2pass_2( // And write the output if (lane_idx == 0) { - PRAGMA_LOOP_UNROLL - for (int i = 0; i < v_per_thread; i++) { - O[v_per_thread * warp_idx + i] = static_cast(o[i]); - } + store(o, O, warp_idx); } } @@ -527,7 +548,7 @@ void sdpa_vector_2pass_fallback( // Allocate the intermediates int n_simds = params.gqa_factor * params.qL; // TODO tune on different machines - int blocks = 256; + int blocks = 512; Shape intermediate_shape; intermediate_shape.reserve(o.ndim() + 1);