Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 68 additions & 90 deletions mlx/backend/metal/kernels/sdpa_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ constant bool do_causal [[function_constant(22)]];
constant bool bool_mask [[function_constant(23)]];
constant bool float_mask [[function_constant(24)]];
constant bool has_sinks [[function_constant(25)]];
constant int blocks [[function_constant(26)]];

template <typename T, int D, int V = D>
[[kernel]] void sdpa_vector(
Expand Down Expand Up @@ -180,10 +181,9 @@ template <typename T, int D, int V = D>
const device T* queries [[buffer(0)]],
const device T* keys [[buffer(1)]],
const device T* values [[buffer(2)]],
device float* out [[buffer(3)]],
device T* out [[buffer(3)]],
device float* sums [[buffer(4)]],
device float* maxs [[buffer(5)]],
const constant int& gqa_factor [[buffer(6)]],
const constant int& N [[buffer(7)]],
const constant size_t& k_head_stride [[buffer(8)]],
const constant size_t& k_seq_stride [[buffer(9)]],
Expand All @@ -199,94 +199,81 @@ template <typename T, int D, int V = D>
const constant int& mask_head_stride
[[buffer(17), function_constant(has_mask)]],
const device T* sinks [[buffer(18), function_constant(has_sinks)]],
const constant int& num_q_heads
[[buffer(19), function_constant(has_sinks)]],
uint3 tptg [[threads_per_threadgroup]],
uint3 tidtg [[thread_position_in_threadgroup]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 tpg [[threadgroups_per_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int BN = 8;
constexpr int BD = 32;
constexpr int qk_per_thread = D / BD;
constexpr int v_per_thread = V / BD;
int inner_k_stride = BN * int(k_seq_stride);
int inner_v_stride = BN * int(v_seq_stride);
constexpr int blocks = 32;

typedef float U;

thread U q[qk_per_thread];
thread U k[qk_per_thread];
thread U o[v_per_thread];

threadgroup U outputs[BN * BD];
threadgroup U max_scores[BN];
threadgroup U sum_exp_scores[BN];
thread U o[v_per_thread] = {0};

// Adjust positions
const int kv_head_idx = tid.x;
const int batch_idx = tid.y;
const int block_idx = tid.z;
const int q_batch_head_idx = tid.x;
const int q_seq_idx = tid.y;
const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx;
const int gqa_factor = tptg.y;
const int q_seq_len = tptg.z;
const int q_seq_idx = tidtg.z;
const int q_head_idx = gqa_factor * kv_head_idx + tidtg.y;
const int num_kv_heads = tpg.x;
const int num_q_heads = num_kv_heads * gqa_factor;
const int q_batch_head_idx = (batch_idx * num_q_heads + q_head_idx);
const int o_offset = q_batch_head_idx * q_seq_len + q_seq_idx;
const int q_offset =
query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset;
const int kv_head_idx = q_batch_head_idx / gqa_factor;
query_transposed ? num_q_heads * q_seq_idx + q_batch_head_idx : o_offset;

queries += q_offset * D + simd_lid * qk_per_thread;
keys += kv_head_idx * k_head_stride +
(block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread;
values += kv_head_idx * v_head_stride +
(block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread;

const int kv_batch_head_idx = batch_idx * num_kv_heads + kv_head_idx;
keys += kv_batch_head_idx * k_head_stride + block_idx * k_seq_stride +
simd_lid * qk_per_thread;
values += kv_batch_head_idx * v_head_stride + block_idx * v_seq_stride +
simd_lid * v_per_thread;
out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread;
if (bool_mask) {
bmask += q_batch_head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
block_idx * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
}
if (float_mask) {
fmask += q_batch_head_idx * mask_head_stride +
(block_idx * BN + simd_gid) * mask_kv_seq_stride +
q_seq_idx * mask_q_seq_stride;
block_idx * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride;
}
sums += o_offset * blocks + block_idx;
maxs += o_offset * blocks + block_idx;

// Read the query and 0 the output accumulator
// Read the query
for (int i = 0; i < qk_per_thread; i++) {
q[i] = static_cast<U>(scale) * queries[i];
}
for (int i = 0; i < v_per_thread; i++) {
o[i] = 0;
}

U max_score = Limits<U>::finite_min;
U sum_exp_score = 0;
if (has_sinks && block_idx == 0 && simd_gid == 0) {
int q_head_idx = q_batch_head_idx % num_q_heads;
if (has_sinks && block_idx == 0) {
max_score = static_cast<U>(sinks[q_head_idx]);
sum_exp_score = 1;
}

// For each key
for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) {
for (int i = block_idx; i < N; i += blocks) {
bool use_key = true;
if (do_causal) {
use_key = i <= (N - int(tpg.y) + int(q_seq_idx));
use_key = i <= (N - q_seq_len + int(q_seq_idx));
} else if (bool_mask) {
use_key = bmask[0];
} else if (float_mask) {
use_key = (fmask[0] >= Limits<T>::finite_min);
}
if (use_key) {
// Read the key
for (int i = 0; i < qk_per_thread; i++) {
k[i] = keys[i];
}

// Compute the i-th score
U score = 0;
for (int i = 0; i < qk_per_thread; i++) {
score += q[i] * k[i];
score += q[i] * keys[i];
}
score = simd_sum(score);

Expand All @@ -309,57 +296,30 @@ template <typename T, int D, int V = D>
}

// Move the pointers to the next kv
keys += blocks * inner_k_stride;
values += blocks * inner_v_stride;
keys += blocks * int(k_seq_stride);
values += blocks * int(v_seq_stride);
if (bool_mask) {
bmask += BN * blocks * mask_kv_seq_stride;
bmask += blocks * mask_kv_seq_stride;
}
if (float_mask) {
fmask += BN * blocks * mask_kv_seq_stride;
fmask += blocks * mask_kv_seq_stride;
}
}

// Each thread has a partial part of the output so we need to combine them.

// First let's communicate the max and sum_exp
// Write the sum and max and outputs
if (simd_lid == 0) {
max_scores[simd_gid] = max_score;
sum_exp_scores[simd_gid] = sum_exp_score;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9;
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0;
sum_exp_score = simd_sum(sum_exp_score * factor);

// Write the sum and new max
if (simd_gid == 0) {
sums[0] = sum_exp_score;
maxs[0] = new_max;
maxs[0] = max_score;
}

// Now we need to aggregate all the outputs
for (int i = 0; i < v_per_thread; i++) {
outputs[simd_lid * BN + simd_gid] =
o[i] * fast::exp(max_scores[simd_gid] - new_max);
threadgroup_barrier(mem_flags::mem_threadgroup);

// And write the output
if (simd_gid == 0) {
U output = outputs[simd_lid * BN];
for (int j = 1; j < BN; j++) {
output += outputs[simd_lid * BN + j];
}
out[i] = static_cast<T>(output);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
out[i] = static_cast<T>(o[i]);
}
}

template <typename T, int D>
[[kernel]] void sdpa_vector_2pass_2(
const device float* partials [[buffer(0)]],
const device T* partials [[buffer(0)]],
const device float* sums [[buffer(1)]],
const device float* maxs [[buffer(2)]],
device T* out [[buffer(3)]],
Expand All @@ -370,38 +330,56 @@ template <typename T, int D>
constexpr int BN = 32;
constexpr int BD = 32;
constexpr int elem_per_thread = D / BD;
constexpr int blocks = 32;

typedef float U;

thread U o[elem_per_thread];
thread U o[elem_per_thread] = {0};
threadgroup U outputs[BN * BD];

// Adjust positions
const int head_idx = tid.x;
const int q_seq_idx = tid.y;
const int q_offset = head_idx * tpg.y + q_seq_idx;
;
partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread;
sums += q_offset * blocks;
maxs += q_offset * blocks;
out += q_offset * D + simd_gid * elem_per_thread;

// First everybody reads the max and sum_exp
U max_score = maxs[simd_lid];
U new_max = simd_max(max_score);
U factor = fast::exp(max_score - new_max);
U sum_exp_score = simd_sum(sums[simd_lid] * factor);
// Set defaults
U sum_exp_score = 0.0;
U max_score = Limits<U>::finite_min;

// Now read the block into registers and then use shared memory to transpose
// it
for (int i = 0; i < elem_per_thread; i++) {
o[i] = partials[i];
// Reduce the max
for (int b = 0; b < blocks / BN; ++b) {
max_score = max(max_score, maxs[simd_lid + BN * b]);
}
max_score = simd_max(max_score);

// Reduce the d
for (int b = 0; b < blocks / BN; ++b) {
U factor = fast::exp(maxs[simd_lid + BN * b] - max_score);
sum_exp_score += factor * sums[simd_lid + BN * b];
}
sum_exp_score = simd_sum(sum_exp_score);

// Reduce the sum exp and partials
for (int b = 0; b < blocks / BN; ++b) {
U factor = fast::exp(maxs[simd_gid] - max_score);

// Update the output accumulator
for (int i = 0; i < elem_per_thread; i++) {
o[i] += factor * partials[i];
}
maxs += BN;
sums += BN;
partials += BN * D;
}

// Use shared memory to transpose and reduce the final block
for (int i = 0; i < elem_per_thread; i++) {
outputs[simd_lid * BD + simd_gid] = o[i];
threadgroup_barrier(mem_flags::mem_threadgroup);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor);
o[i] = simd_sum(outputs[simd_gid * BD + simd_lid]);
o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
Expand Down
Loading
Loading