81 changes: 51 additions & 30 deletions mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -38,6 +38,39 @@ struct AttnParams {
   int64_t O_strides[3];
 };
 
+template <int N, typename T, typename U>
+__device__ void load(const T* src, U* dst, int idx) {
+  if constexpr (N % 2 == 0) {
+    auto local = load_vector<N>(src, idx);
+    PRAGMA_LOOP_UNROLL
+    for (int i = 0; i < N; i++) {
+      dst[i] = static_cast<U>(local[i]);
+    }
+  } else {
+    PRAGMA_LOOP_UNROLL
+    for (int i = 0; i < N; i++) {
+      dst[i] = static_cast<U>(src[N * idx + i]);
+    }
+  }
+}
+
+template <int N, typename T, typename U>
+__device__ void store(T* src, U* dst, int idx) {
+  if constexpr (N % 2 == 0) {
+    AlignedVector<U, N> local;
+    PRAGMA_LOOP_UNROLL
+    for (int i = 0; i < N; i++) {
+      local[i] = static_cast<U>(src[i]);
+    }
+    store_vector<N>(dst, idx, local);
+  } else {
+    PRAGMA_LOOP_UNROLL
+    for (int i = 0; i < N; i++) {
+      dst[N * idx + i] = static_cast<U>(src[i]);
+    }
+  }
+}
+
 template <typename T, bool do_causal, int D>
 __global__ void kernel_sdpav_1pass(
     const T* Q,
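
Note on the new helpers: when the per-thread element count N is even, `load` and `store` go through `load_vector`/`store_vector`, so N scalar global-memory accesses collapse into one aligned wide transaction; odd N keeps the scalar loop. `AlignedVector`, `load_vector`, and `store_vector` are defined elsewhere in the CUDA backend and do not appear in this diff; the sketch below is only an assumption of their rough shape, not the actual mlx definitions. Also worth noting for the call sites further down: in `store<v_per_thread>(o, O, warp_idx)` the deduced T is the float accumulator type and U the low-precision output type, the reverse of how the parameters read in `load`.

    // Sketch only: assumed shape of the vector-access helpers this diff
    // relies on; the real mlx definitions may differ.
    template <typename T, int N>
    struct alignas(sizeof(T) * N) AlignedVector {
      T val[N];
      __device__ T& operator[](int i) { return val[i]; }
      __device__ T operator[](int i) const { return val[i]; }
    };

    // Read the idx-th N-wide vector from src as a single transaction.
    template <int N, typename T>
    __device__ AlignedVector<T, N> load_vector(const T* src, int idx) {
      return reinterpret_cast<const AlignedVector<T, N>*>(src)[idx];
    }

    // Write an N-wide vector to the idx-th slot of dst in one transaction.
    template <int N, typename T>
    __device__ void store_vector(T* dst, int idx, const AlignedVector<T, N>& v) {
      reinterpret_cast<AlignedVector<T, N>*>(dst)[idx] = v;
    }

The vector path requires `src`/`dst` to be aligned to `sizeof(T) * N`, which the even-N dispatch presumably relies on for the head dimensions in use.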
@@ -58,6 +91,7 @@ __global__ void kernel_sdpav_1pass(
 
   U q[v_per_thread];
   U k[v_per_thread];
+  U v[v_per_thread];
   U o[v_per_thread];
 
   __shared__ U outputs[BN][BD + 1];
@@ -97,9 +131,10 @@ __global__ void kernel_sdpav_1pass(
       q_seq_idx * params.O_strides[2]; // Sequence
 
   // Read the query and 0 the output accumulator
+  load<v_per_thread>(Q, q, lane_idx);
   PRAGMA_LOOP_UNROLL
   for (int i = 0; i < v_per_thread; i++) {
-    q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
+    q[i] *= scale_log2;
   }
 
   PRAGMA_LOOP_UNROLL
@@ -123,10 +158,7 @@ __global__ void kernel_sdpav_1pass(
 
     if (use_key) {
       // Read the key
-      PRAGMA_LOOP_UNROLL
-      for (int j = 0; j < v_per_thread; j++) {
-        k[j] = K[v_per_thread * lane_idx + j];
-      }
+      load<v_per_thread>(K, k, lane_idx);
 
       // Compute the i-th score
       U score = 0.f;
@@ -146,11 +178,12 @@ __global__ void kernel_sdpav_1pass(
       max_score = new_max;
       sum_exp_score = sum_exp_score * factor + exp_score;
 
+      load<v_per_thread>(V, v, lane_idx);
+
       // Update the output accumulator
       PRAGMA_LOOP_UNROLL
      for (int j = 0; j < v_per_thread; j++) {
-        o[j] = o[j] * factor +
-            exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
+        o[j] = o[j] * factor + exp_score * v[j];
      }
    }
 
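Note: the `factor` rescaling in this hunk is the streaming (online) softmax recurrence. Each new key updates a running maximum, and the previously accumulated sum and output are scaled by exp(old_max - new_max), so no exponential ever sees a large positive argument. The kernel does this in base 2 (`scale_log2` with exp2), which is equivalent up to a log2(e) constant folded into the scale. A scalar, base-e reference sketch for comparison; hypothetical standalone code, not part of this PR:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Reference-only: the per-thread recurrence kernel_sdpav_1pass applies,
    // written scalar for one query over all keys.
    std::vector<float> sdpa_reference(
        const std::vector<float>& scores,                // q . k_i, pre-scaled
        const std::vector<std::vector<float>>& values) { // one D-wide v per key
      float max_score = -INFINITY;
      float sum_exp_score = 0.f;
      std::vector<float> o(values[0].size(), 0.f);
      for (size_t i = 0; i < scores.size(); i++) {
        float new_max = std::max(max_score, scores[i]);
        float factor = std::exp(max_score - new_max);    // rescale old terms
        float exp_score = std::exp(scores[i] - new_max); // current key weight
        max_score = new_max;
        sum_exp_score = sum_exp_score * factor + exp_score;
        for (size_t j = 0; j < o.size(); j++)
          o[j] = o[j] * factor + exp_score * values[i][j];
      }
      for (float& x : o)
        x /= sum_exp_score; // the normalization the kernel does after its loop
      return o;
    }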
@@ -184,10 +217,7 @@ __global__ void kernel_sdpav_1pass(
 
   // And write the output
   if (lane_idx == 0) {
-    PRAGMA_LOOP_UNROLL
-    for (int i = 0; i < v_per_thread; i++) {
-      O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
-    }
+    store<v_per_thread>(o, O, warp_idx);
   }
 }
 
Expand Down Expand Up @@ -254,9 +284,10 @@ __global__ void kernel_sdpav_2pass_1(
maxs += p_offset;

// Read the query
load<v_per_thread>(Q, q, lane_idx);
PRAGMA_LOOP_UNROLL
for (int i = 0; i < v_per_thread; i++) {
q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
q[i] *= scale_log2;
}

U max_score = Limits<U>::finite_min();
@@ -275,16 +306,12 @@ __global__ void kernel_sdpav_2pass_1(
       use_key = i <= (params.kL - params.qL + q_seq_idx);
     }
     // Load keys and values into shared memory
-    if (warp.any(use_key)) {
-      block.sync();
-      if (threadIdx.y == 0 && threadIdx.z == 0) {
-        for (int j = 0; j < v_per_thread; j++) {
-          k[j] = static_cast<U>(K[v_per_thread * lane_idx + j]);
-          v[j] = static_cast<U>(V[v_per_thread * lane_idx + j]);
-        }
-      }
-      block.sync();
-    }
+    block.sync();
+    if (threadIdx.y == 0 && threadIdx.z == 0) {
+      load<v_per_thread>(K, k, lane_idx);
+      load<v_per_thread>(V, v, lane_idx);
+    }
+    block.sync();
 
     if (use_key) {
       // Compute the i-th score
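
Note: beyond enabling the vectorized `load`, dropping the `warp.any(use_key)` guard also removes a block-wide barrier from divergent control flow. `block.sync()` must be reached by every thread in the block; if `use_key` could differ across warps of one block near the tail of the key sequence, some warps would skip a barrier others were waiting at. Syncing unconditionally keeps every barrier uniform. A minimal hypothetical illustration of that hazard, not code from this PR:

    // Hypothetical: a block-wide barrier under divergent control flow is
    // undefined behavior, typically a hang or corrupted shared memory.
    __global__ void divergent_barrier(const int* flags) {
      if (flags[threadIdx.x / 32]) { // uniform per warp, not per block
        __syncthreads();             // UB if warps disagree: some never arrive
      }
    }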
@@ -322,10 +349,7 @@ __global__ void kernel_sdpav_2pass_1(
     maxs[0] = max_score;
   }
 
-  PRAGMA_LOOP_UNROLL
-  for (int i = 0; i < v_per_thread; i++) {
-    partials[v_per_thread * lane_idx + i] = static_cast<T>(o[i]);
-  }
+  store<v_per_thread>(o, partials, lane_idx);
 }
 
 template <typename T, bool do_causal, int D>
@@ -412,10 +436,7 @@ __global__ void kernel_sdpav_2pass_2(
 
   // And write the output
   if (lane_idx == 0) {
-    PRAGMA_LOOP_UNROLL
-    for (int i = 0; i < v_per_thread; i++) {
-      O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
-    }
+    store<v_per_thread>(o, O, warp_idx);
   }
 }
 
@@ -527,7 +548,7 @@ void sdpa_vector_2pass_fallback(
   // Allocate the intermediates
   int n_simds = params.gqa_factor * params.qL;
   // TODO tune on different machines
-  int blocks = 256;
+  int blocks = 512;
 
   Shape intermediate_shape;
   intermediate_shape.reserve(o.ndim() + 1);
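
Note: raising `blocks` from 256 to 512 doubles the pass-1 parallelism and doubles the partial results pass 2 must reduce. A hypothetical back-of-envelope for the added memory, with names and layout assumed rather than taken from the allocation code here:

    #include <cstdio>

    // Assumed sizing: one D-wide float partial per (simd, block) pair; the
    // real intermediate_shape construction may lay this out differently.
    int main() {
      int gqa_factor = 8, qL = 1, batch = 1, D = 128; // typical decode shapes
      int n_simds = gqa_factor * qL;
      int blocks = 512; // was 256 before this change
      size_t partial_floats = (size_t)batch * n_simds * blocks * D;
      std::printf("partials: %.2f MiB\n", partial_floats * 4.0 / (1 << 20));
      // ~2 MiB here: double the old footprint, still negligible next to the
      // KV cache, while pass 1 gets twice as many blocks in flight.
      return 0;
    }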