diff --git a/mlx/backend/metal/kernels/sdpa_vector.h b/mlx/backend/metal/kernels/sdpa_vector.h index 96d22d8e41..cccd6dce1c 100644 --- a/mlx/backend/metal/kernels/sdpa_vector.h +++ b/mlx/backend/metal/kernels/sdpa_vector.h @@ -10,6 +10,7 @@ constant bool do_causal [[function_constant(22)]]; constant bool bool_mask [[function_constant(23)]]; constant bool float_mask [[function_constant(24)]]; constant bool has_sinks [[function_constant(25)]]; +constant int blocks [[function_constant(26)]]; template [[kernel]] void sdpa_vector( @@ -180,10 +181,9 @@ template const device T* queries [[buffer(0)]], const device T* keys [[buffer(1)]], const device T* values [[buffer(2)]], - device float* out [[buffer(3)]], + device T* out [[buffer(3)]], device float* sums [[buffer(4)]], device float* maxs [[buffer(5)]], - const constant int& gqa_factor [[buffer(6)]], const constant int& N [[buffer(7)]], const constant size_t& k_head_stride [[buffer(8)]], const constant size_t& k_seq_stride [[buffer(9)]], @@ -199,94 +199,81 @@ template const constant int& mask_head_stride [[buffer(17), function_constant(has_mask)]], const device T* sinks [[buffer(18), function_constant(has_sinks)]], - const constant int& num_q_heads - [[buffer(19), function_constant(has_sinks)]], + uint3 tptg [[threads_per_threadgroup]], + uint3 tidtg [[thread_position_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], uint3 tpg [[threadgroups_per_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int BN = 8; constexpr int BD = 32; constexpr int qk_per_thread = D / BD; constexpr int v_per_thread = V / BD; - int inner_k_stride = BN * int(k_seq_stride); - int inner_v_stride = BN * int(v_seq_stride); - constexpr int blocks = 32; typedef float U; thread U q[qk_per_thread]; - thread U k[qk_per_thread]; - thread U o[v_per_thread]; - - threadgroup U outputs[BN * BD]; - threadgroup U max_scores[BN]; - threadgroup U sum_exp_scores[BN]; + thread U o[v_per_thread] = {0}; // Adjust positions + const int kv_head_idx = tid.x; + const int batch_idx = tid.y; const int block_idx = tid.z; - const int q_batch_head_idx = tid.x; - const int q_seq_idx = tid.y; - const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx; + const int gqa_factor = tptg.y; + const int q_seq_len = tptg.z; + const int q_seq_idx = tidtg.z; + const int q_head_idx = gqa_factor * kv_head_idx + tidtg.y; + const int num_kv_heads = tpg.x; + const int num_q_heads = num_kv_heads * gqa_factor; + const int q_batch_head_idx = (batch_idx * num_q_heads + q_head_idx); + const int o_offset = q_batch_head_idx * q_seq_len + q_seq_idx; const int q_offset = - query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset; - const int kv_head_idx = q_batch_head_idx / gqa_factor; + query_transposed ? num_q_heads * q_seq_idx + q_batch_head_idx : o_offset; queries += q_offset * D + simd_lid * qk_per_thread; - keys += kv_head_idx * k_head_stride + - (block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread; - values += kv_head_idx * v_head_stride + - (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread; + + const int kv_batch_head_idx = batch_idx * num_kv_heads + kv_head_idx; + keys += kv_batch_head_idx * k_head_stride + block_idx * k_seq_stride + + simd_lid * qk_per_thread; + values += kv_batch_head_idx * v_head_stride + block_idx * v_seq_stride + + simd_lid * v_per_thread; out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread; if (bool_mask) { bmask += q_batch_head_idx * mask_head_stride + - (block_idx * BN + simd_gid) * mask_kv_seq_stride + - q_seq_idx * mask_q_seq_stride; + block_idx * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } if (float_mask) { fmask += q_batch_head_idx * mask_head_stride + - (block_idx * BN + simd_gid) * mask_kv_seq_stride + - q_seq_idx * mask_q_seq_stride; + block_idx * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; } sums += o_offset * blocks + block_idx; maxs += o_offset * blocks + block_idx; - // Read the query and 0 the output accumulator + // Read the query for (int i = 0; i < qk_per_thread; i++) { q[i] = static_cast(scale) * queries[i]; } - for (int i = 0; i < v_per_thread; i++) { - o[i] = 0; - } U max_score = Limits::finite_min; U sum_exp_score = 0; - if (has_sinks && block_idx == 0 && simd_gid == 0) { - int q_head_idx = q_batch_head_idx % num_q_heads; + if (has_sinks && block_idx == 0) { max_score = static_cast(sinks[q_head_idx]); sum_exp_score = 1; } // For each key - for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { + for (int i = block_idx; i < N; i += blocks) { bool use_key = true; if (do_causal) { - use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); + use_key = i <= (N - q_seq_len + int(q_seq_idx)); } else if (bool_mask) { use_key = bmask[0]; } else if (float_mask) { use_key = (fmask[0] >= Limits::finite_min); } if (use_key) { - // Read the key - for (int i = 0; i < qk_per_thread; i++) { - k[i] = keys[i]; - } - // Compute the i-th score U score = 0; for (int i = 0; i < qk_per_thread; i++) { - score += q[i] * k[i]; + score += q[i] * keys[i]; } score = simd_sum(score); @@ -309,57 +296,30 @@ template } // Move the pointers to the next kv - keys += blocks * inner_k_stride; - values += blocks * inner_v_stride; + keys += blocks * int(k_seq_stride); + values += blocks * int(v_seq_stride); if (bool_mask) { - bmask += BN * blocks * mask_kv_seq_stride; + bmask += blocks * mask_kv_seq_stride; } if (float_mask) { - fmask += BN * blocks * mask_kv_seq_stride; + fmask += blocks * mask_kv_seq_stride; } } - // Each thread has a partial part of the output so we need to combine them. - - // First let's communicate the max and sum_exp + // Write the sum and max and outputs if (simd_lid == 0) { - max_scores[simd_gid] = max_score; - sum_exp_scores[simd_gid] = sum_exp_score; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9; - U new_max = simd_max(max_score); - U factor = fast::exp(max_score - new_max); - sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0; - sum_exp_score = simd_sum(sum_exp_score * factor); - - // Write the sum and new max - if (simd_gid == 0) { sums[0] = sum_exp_score; - maxs[0] = new_max; + maxs[0] = max_score; } - // Now we need to aggregate all the outputs for (int i = 0; i < v_per_thread; i++) { - outputs[simd_lid * BN + simd_gid] = - o[i] * fast::exp(max_scores[simd_gid] - new_max); - threadgroup_barrier(mem_flags::mem_threadgroup); - - // And write the output - if (simd_gid == 0) { - U output = outputs[simd_lid * BN]; - for (int j = 1; j < BN; j++) { - output += outputs[simd_lid * BN + j]; - } - out[i] = static_cast(output); - } - threadgroup_barrier(mem_flags::mem_threadgroup); + out[i] = static_cast(o[i]); } } template [[kernel]] void sdpa_vector_2pass_2( - const device float* partials [[buffer(0)]], + const device T* partials [[buffer(0)]], const device float* sums [[buffer(1)]], const device float* maxs [[buffer(2)]], device T* out [[buffer(3)]], @@ -370,38 +330,56 @@ template constexpr int BN = 32; constexpr int BD = 32; constexpr int elem_per_thread = D / BD; - constexpr int blocks = 32; typedef float U; - thread U o[elem_per_thread]; + thread U o[elem_per_thread] = {0}; threadgroup U outputs[BN * BD]; // Adjust positions const int head_idx = tid.x; const int q_seq_idx = tid.y; const int q_offset = head_idx * tpg.y + q_seq_idx; - ; partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; sums += q_offset * blocks; maxs += q_offset * blocks; out += q_offset * D + simd_gid * elem_per_thread; - // First everybody reads the max and sum_exp - U max_score = maxs[simd_lid]; - U new_max = simd_max(max_score); - U factor = fast::exp(max_score - new_max); - U sum_exp_score = simd_sum(sums[simd_lid] * factor); + // Set defaults + U sum_exp_score = 0.0; + U max_score = Limits::finite_min; - // Now read the block into registers and then use shared memory to transpose - // it - for (int i = 0; i < elem_per_thread; i++) { - o[i] = partials[i]; + // Reduce the max + for (int b = 0; b < blocks / BN; ++b) { + max_score = max(max_score, maxs[simd_lid + BN * b]); + } + max_score = simd_max(max_score); + + // Reduce the d + for (int b = 0; b < blocks / BN; ++b) { + U factor = fast::exp(maxs[simd_lid + BN * b] - max_score); + sum_exp_score += factor * sums[simd_lid + BN * b]; } + sum_exp_score = simd_sum(sum_exp_score); + + // Reduce the sum exp and partials + for (int b = 0; b < blocks / BN; ++b) { + U factor = fast::exp(maxs[simd_gid] - max_score); + + // Update the output accumulator + for (int i = 0; i < elem_per_thread; i++) { + o[i] += factor * partials[i]; + } + maxs += BN; + sums += BN; + partials += BN * D; + } + + // Use shared memory to transpose and reduce the final block for (int i = 0; i < elem_per_thread; i++) { outputs[simd_lid * BD + simd_gid] = o[i]; threadgroup_barrier(mem_flags::mem_threadgroup); - o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid]); o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); threadgroup_barrier(mem_flags::mem_threadgroup); } diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 8dd645b664..f09cacf1c7 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -438,16 +438,48 @@ void sdpa_vector_2pass( // Compute the necessary sizes int gqa_factor = q.shape(1) / k.shape(1); - int N = k.shape(2); - int blocks = 32; - int B = q.shape(0) * q.shape(1); + int n_simds = gqa_factor * q.shape(2); + char devc = d.get_architecture().back(); + int N = k.shape(2); + int blocks; + if (devc == 's') { + blocks = 64; + if (N > 1024 && n_simds > 4) { + if (N <= 8192) { + blocks = 128; + } else if (N <= 32768) { + blocks = 256; + } else if (N <= 65536) { + blocks = 512; + } else { + blocks = 1024; + } + } + } else if (devc == 'd') { + blocks = 128; + if (n_simds <= 2 && N > 8192) { + blocks = 256; + } else if (n_simds >= 6) { + if (N >= 16384 && N < 65536) { + blocks = 512; + } else if (N >= 65536) { + blocks = 1024; + } + } + } else { + if (n_simds >= 4) { + blocks = 64; + } else { + blocks = 32; + } + } size_t k_head_stride = k.shape(1) == 1 ? k.strides(0) : k.strides(1); size_t k_seq_stride = k.strides()[2]; size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); size_t v_seq_stride = v.strides()[2]; - MTL::Size group_dims(8 * 32, 1, 1); - MTL::Size grid_dims(B, q.shape(2), blocks); + MTL::Size group_dims(32, gqa_factor, q.shape(2)); + MTL::Size grid_dims(k.shape(1), q.shape(0), blocks); // Allocate the intermediates Shape intermediate_shape; @@ -456,7 +488,7 @@ void sdpa_vector_2pass( intermediate_shape.end(), out.shape().begin(), out.shape().end() - 1); intermediate_shape.push_back(blocks); intermediate_shape.push_back(out.shape().back()); - array intermediate(intermediate_shape, float32, nullptr, {}); + array intermediate(intermediate_shape, q.dtype(), nullptr, {}); intermediate_shape.pop_back(); array sums(intermediate_shape, float32, nullptr, {}); array maxs(std::move(intermediate_shape), float32, nullptr, {}); @@ -479,12 +511,14 @@ void sdpa_vector_2pass( {&bool_mask, MTL::DataType::DataTypeBool, 23}, {&float_mask, MTL::DataType::DataTypeBool, 24}, {&has_sinks, MTL::DataType::DataTypeBool, 25}, + {&blocks, MTL::DataType::DataTypeInt, 26}, }; std::string hash_name = kname; hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; hash_name += query_transposed ? "_qt" : "_qnt"; hash_name += do_causal ? "_c" : "_nc"; - hash_name += has_sinks ? "_sinks" : "_nosinks"; + hash_name += has_sinks ? "_sinks_" : "_nosinks_"; + hash_name += std::to_string(blocks); // Get the kernel auto& compute_encoder = d.get_command_encoder(s.index); @@ -499,7 +533,6 @@ void sdpa_vector_2pass( compute_encoder.set_output_array(intermediate, 3); compute_encoder.set_output_array(sums, 4); compute_encoder.set_output_array(maxs, 5); - compute_encoder.set_bytes(gqa_factor, 6); compute_encoder.set_bytes(N, 7); compute_encoder.set_bytes(k_head_stride, 8); compute_encoder.set_bytes(k_seq_stride, 9); @@ -519,7 +552,6 @@ void sdpa_vector_2pass( } if (has_sinks) { compute_encoder.set_input_array(*sinks, 18); - compute_encoder.set_bytes(q.shape(1), 19); } // Launch @@ -527,13 +559,18 @@ void sdpa_vector_2pass( // Final pass kname.clear(); - kname += "sdpa_vector_2pass_2_"; + kname = "sdpa_vector_2pass_2_"; kname += get_type_string(q.dtype()); kname += "_"; kname += std::to_string(v.shape(-1)); + func_consts = { + {&blocks, MTL::DataType::DataTypeInt, 26}, + }; + hash_name = kname + "_" + std::to_string(blocks); + // Get the kernel - kernel = d.get_kernel(kname); + kernel = d.get_kernel(kname, hash_name, func_consts); compute_encoder.set_compute_pipeline_state(kernel); // Set its arguments @@ -544,7 +581,7 @@ void sdpa_vector_2pass( // Launch group_dims = MTL::Size(1024, 1, 1); - grid_dims = MTL::Size(B, q.shape(2), 1); + grid_dims = MTL::Size(q.shape(0) * q.shape(1), q.shape(2), 1); compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } @@ -576,6 +613,9 @@ bool ScaledDotProductAttention::use_fallback( const int query_head_dim = q.shape(-1); const int query_sequence_length = q.shape(2); const int key_sequence_length = k.shape(2); + const int num_query_heads = q.shape(1); + const int num_kv_heads = k.shape(1); + const int gqa_factor = num_query_heads / num_kv_heads; const bool sdpa_vector_supported_head_dim = query_head_dim == value_head_dim && @@ -592,7 +632,8 @@ bool ScaledDotProductAttention::use_fallback( const bool supports_sdpa_vector = (query_sequence_length <= 8) && (query_sequence_length <= key_sequence_length) && - sdpa_vector_supported_head_dim; + sdpa_vector_supported_head_dim && + (query_sequence_length * gqa_factor) <= 32; return !(supports_sdpa_full || supports_sdpa_vector); } @@ -699,7 +740,7 @@ void ScaledDotProductAttention::eval_gpu( // - The sequence length is even longer and we have gqa bool do_causal = do_causal_ && q.shape(2) > 1; char devc = d.get_architecture().back(); - if ((devc == 'd' && k.shape(2) >= 1024) || + if (((devc == 'd' || devc == 's') && k.shape(2) >= 1024) || (k.shape(1) < q.shape(1) && k.shape(2) >= 4096)) { sdpa_vector_2pass(s, d, q, k, v, o, scale_, do_causal, mask, sinks); } else {