From 34223784875e40b0904939a128d3470e506f7f9c Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 21 Aug 2024 07:09:26 -0400 Subject: [PATCH 1/4] Add gemma2 paged attn support --- mistralrs-core/src/models/gemma.rs | 1 + mistralrs-core/src/models/gemma2.rs | 117 +++++++++++++----- mistralrs-core/src/models/llama.rs | 1 + mistralrs-core/src/models/mistral.rs | 1 + mistralrs-core/src/models/mixtral.rs | 1 + mistralrs-core/src/models/phi2.rs | 1 + mistralrs-core/src/models/phi3.rs | 1 + mistralrs-core/src/models/quantized_llama.rs | 1 + mistralrs-core/src/models/quantized_phi2.rs | 1 + mistralrs-core/src/models/quantized_phi3.rs | 1 + .../src/models/quantized_starcoder2.rs | 1 + mistralrs-core/src/models/qwen2.rs | 1 + mistralrs-core/src/models/starcoder2.rs | 1 + .../paged_attention/layers/paged_attention.rs | 7 ++ .../vision_models/llava/llava_llm/llama.rs | 1 + .../vision_models/llava/llava_llm/mistral.rs | 1 + mistralrs-core/src/vision_models/phi3.rs | 1 + .../src/backend/paged_attention.rs | 7 ++ mistralrs-paged-attn/src/ffi.rs | 2 + mistralrs-paged-attn/src/pagedattention.cu | 35 +++++- 20 files changed, 148 insertions(+), 35 deletions(-) diff --git a/mistralrs-core/src/models/gemma.rs b/mistralrs-core/src/models/gemma.rs index 751e6addb5..c0f75c6026 100644 --- a/mistralrs-core/src/models/gemma.rs +++ b/mistralrs-core/src/models/gemma.rs @@ -334,6 +334,7 @@ impl Attention { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/models/gemma2.rs b/mistralrs-core/src/models/gemma2.rs index 438cc16df7..1a62ac5cb2 100644 --- a/mistralrs-core/src/models/gemma2.rs +++ b/mistralrs-core/src/models/gemma2.rs @@ -14,7 +14,7 @@ use crate::{ device_map::DeviceMapper, get_delta_from_lora_ab, layers::{repeat_kv, CausalMasker, MatMul}, - paged_attention::{AttentionImplementation, ModelConfigMetadata}, + paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention}, pipeline::{ extract_logits, text_models_inputs_processor::PagedAttentionInputMetadata, Cache, IsqModel, NormalLoadingMetadata, NormalModel, @@ -199,7 +199,6 @@ impl MlpLayer for MLP { } } -#[derive(Clone)] struct Attention { q_proj: Arc, k_proj: Arc, @@ -214,6 +213,7 @@ struct Attention { attn_logit_softcapping: Option, use_sliding_window: bool, sliding_window: Option, + paged_attn: Option, } impl Attention { @@ -222,6 +222,7 @@ impl Attention { cfg: &Config, layer_idx: usize, vb: VarBuilder, + paged_attn: Option, ) -> Result { let hidden_sz = cfg.hidden_size; let num_heads = cfg.num_attention_heads; @@ -276,9 +277,11 @@ impl Attention { } else { None }, + paged_attn, }) } + #[allow(clippy::too_many_arguments)] fn forward( &self, xs: &Tensor, @@ -287,6 +290,7 @@ impl Attention { seqlen_offsets: &[usize], start_offsets_kernel: Tensor, kv_cache: &mut Option<(Tensor, Tensor)>, + metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>, ) -> Result { let (b_sz, q_len, _) = xs.dims3()?; @@ -342,38 +346,55 @@ impl Attention { attention_mask }; - // self.sliding_window is None if !self.use_sliding_window - let (k, v, mask) = Cache::update_kv_cache_sliding_window( - kv_cache, - k, - v, - mask, - self.sliding_window, - false, - )?; - - let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?; - let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?; - - let mut att = MatMul.matmul_affine_div( - &q.contiguous()?, - &k.t()?.contiguous()?, - (self.query_pre_attn_scalar as f64).sqrt(), - )?; - - if let Some(attn_logit_softcapping) = self.attn_logit_softcapping { - att = (att / attn_logit_softcapping)?; - att = att.tanh()?; - att = (att * attn_logit_softcapping)?; - } + let mut attn_output = match &self.paged_attn { + Some(paged_attn) => { + let ((key_cache, value_cache), input_metadata) = metadata.unwrap(); + paged_attn.forward( + &q, + &k, + &v, + attention_mask, + Some(key_cache), + Some(value_cache), + input_metadata, + self.attn_logit_softcapping, + )? + } + None => { + // self.sliding_window is None if !self.use_sliding_window + let (k, v, mask) = Cache::update_kv_cache_sliding_window( + kv_cache, + k, + v, + mask, + self.sliding_window, + false, + )?; + + let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?; + let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?; + + let mut att = MatMul.matmul_affine_div( + &q.contiguous()?, + &k.t()?.contiguous()?, + (self.query_pre_attn_scalar as f64).sqrt(), + )?; + + if let Some(attn_logit_softcapping) = self.attn_logit_softcapping { + att = (att / attn_logit_softcapping)?; + att = att.tanh()?; + att = (att * attn_logit_softcapping)?; + } - let att = match mask { - Some(m) => att.broadcast_add(&m)?, - None => att, + let att = match mask { + Some(m) => att.broadcast_add(&m)?, + None => att, + }; + let att = candle_nn::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + MatMul.matmul(&att, &v.contiguous()?)? + } }; - let att = candle_nn::ops::softmax_last_dim(&att)?; - // Convert to contiguous as matmul doesn't support strided vs for now. - let mut attn_output = MatMul.matmul(&att, &v.contiguous()?)?; if let Some(t) = self.q_proj.quantized_act_type() { attn_output = attn_output.to_dtype(t)?; @@ -408,12 +429,14 @@ impl DecoderLayer { mapper: &dyn DeviceMapper, layer_idx: usize, loading_isq: bool, + paged_attn: Option, ) -> Result { let self_attn = Attention::new( rotary_emb, cfg, layer_idx, mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq), + paged_attn, )?; let mlp = MLP::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?; let input_layernorm = RmsNorm::new( @@ -446,6 +469,7 @@ impl DecoderLayer { }) } + #[allow(clippy::too_many_arguments)] fn forward( &self, xs: &Tensor, @@ -454,6 +478,7 @@ impl DecoderLayer { seqlen_offsets: &[usize], start_offsets_kernel: Tensor, kv_cache: &mut Option<(Tensor, Tensor)>, + metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>, ) -> Result { let residual = xs; let xs = self.input_layernorm.forward(xs)?; @@ -466,6 +491,7 @@ impl DecoderLayer { seqlen_offsets, start_offsets_kernel, kv_cache, + metadata, )? .apply(&self.post_attention_layernorm)?; let xs = (xs + residual)?; @@ -535,6 +561,25 @@ impl Model { is_gptx, vb.dtype(), )?); + let head_dim = cfg.head_dim; + let sliding_window = if layer_idx % 2 == 0 { + // ^ Order is SWA, global, SWA + Some(cfg.sliding_window) + } else { + None + }; + let paged_attn = match &attention_mechanism { + AttentionImplementation::Eager => None, + AttentionImplementation::PagedAttention => Some(PagedAttention::new( + cfg.num_attention_heads, + head_dim, + (1.0 / (cfg.query_pre_attn_scalar as f64).sqrt()) as f32, + Some(cfg.num_key_value_heads), + sliding_window, + &normal_loading_metadata.real_device, + None, + )?), + }; let layer = DecoderLayer::new( rotary_emb.clone(), cfg, @@ -542,6 +587,7 @@ impl Model { &*mapper, layer_idx, normal_loading_metadata.loading_isq, + paged_attn, )?; layers.push(layer) } @@ -584,6 +630,7 @@ impl Model { seqlen_offsets: &[usize], start_offsets_kernel: Tensor, context_lens: Vec<(usize, usize)>, + mut metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, ) -> Result { let xs = self.embed_tokens.forward(input_ids)?; let mut xs = (xs * (self.hidden_size as f64).sqrt())?; @@ -617,6 +664,9 @@ impl Model { seqlen_offsets, start_offsets_kernel.clone(), &mut cache[i], + metadata + .as_mut() + .map(|(kv_cache, metadata)| (kv_cache[i].clone(), &mut **metadata)), )?; } let xs = xs.to_device(&self.device)?; @@ -672,13 +722,14 @@ impl NormalModel for Model { start_offsets_kernel: Tensor, context_lens: Vec<(usize, usize)>, _position_ids: Vec, - _metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, + metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>, ) -> Result { self.forward( input_ids, seqlen_offsets, start_offsets_kernel, context_lens, + metadata, ) } fn xlora_forward( diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index 1a0b84e374..f2cfd8a58f 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -127,6 +127,7 @@ impl CausalSelfAttention { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs index 4746a0e76d..2d31b7a964 100644 --- a/mistralrs-core/src/models/mistral.rs +++ b/mistralrs-core/src/models/mistral.rs @@ -287,6 +287,7 @@ impl Attention { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs index 7f1b0afe4a..a4af17fda5 100644 --- a/mistralrs-core/src/models/mixtral.rs +++ b/mistralrs-core/src/models/mixtral.rs @@ -177,6 +177,7 @@ impl Attention { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/models/phi2.rs b/mistralrs-core/src/models/phi2.rs index 050a099b2f..de1925c187 100644 --- a/mistralrs-core/src/models/phi2.rs +++ b/mistralrs-core/src/models/phi2.rs @@ -300,6 +300,7 @@ impl Attention { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index 908da10a00..bea1e58320 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -182,6 +182,7 @@ impl Attention { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/models/quantized_llama.rs b/mistralrs-core/src/models/quantized_llama.rs index f3051d1e36..da91186618 100644 --- a/mistralrs-core/src/models/quantized_llama.rs +++ b/mistralrs-core/src/models/quantized_llama.rs @@ -189,6 +189,7 @@ impl LayerWeights { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/models/quantized_phi2.rs b/mistralrs-core/src/models/quantized_phi2.rs index e10bf8e3c0..ac453af190 100644 --- a/mistralrs-core/src/models/quantized_phi2.rs +++ b/mistralrs-core/src/models/quantized_phi2.rs @@ -104,6 +104,7 @@ impl LayerWeights { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/models/quantized_phi3.rs b/mistralrs-core/src/models/quantized_phi3.rs index 67c066d191..971c4e8c8f 100644 --- a/mistralrs-core/src/models/quantized_phi3.rs +++ b/mistralrs-core/src/models/quantized_phi3.rs @@ -124,6 +124,7 @@ impl LayerWeights { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/models/quantized_starcoder2.rs b/mistralrs-core/src/models/quantized_starcoder2.rs index ff645cff3f..7d3d3c8dc6 100644 --- a/mistralrs-core/src/models/quantized_starcoder2.rs +++ b/mistralrs-core/src/models/quantized_starcoder2.rs @@ -119,6 +119,7 @@ impl LayerWeights { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/models/qwen2.rs b/mistralrs-core/src/models/qwen2.rs index c624ba1edb..6e63bbbbed 100644 --- a/mistralrs-core/src/models/qwen2.rs +++ b/mistralrs-core/src/models/qwen2.rs @@ -274,6 +274,7 @@ impl Attention { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/models/starcoder2.rs b/mistralrs-core/src/models/starcoder2.rs index d089ddd962..3e0cabd0cd 100644 --- a/mistralrs-core/src/models/starcoder2.rs +++ b/mistralrs-core/src/models/starcoder2.rs @@ -267,6 +267,7 @@ impl Attention { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/paged_attention/layers/paged_attention.rs b/mistralrs-core/src/paged_attention/layers/paged_attention.rs index c3f64235a8..78707d350a 100644 --- a/mistralrs-core/src/paged_attention/layers/paged_attention.rs +++ b/mistralrs-core/src/paged_attention/layers/paged_attention.rs @@ -64,6 +64,7 @@ impl PagedAttention { mut key_cache: Option, mut value_cache: Option, input_metadata: &mut PagedAttentionInputMetadata, + softcapping: Option, ) -> Result { let dims = input_metadata.slot_mappings.dims(); let slot_mapping = if dims.len() > 1 { @@ -93,6 +94,10 @@ impl PagedAttention { } else { (query.matmul(&key.t()?)? * self.scale as f64)? }; + let att = match softcapping { + None => att, + Some(sc) => ((att / sc)?.tanh()? * sc)?, + }; let att = att.broadcast_add(mask)?; let att = candle_nn::ops::softmax_last_dim(&att)?; @@ -163,6 +168,7 @@ impl PagedAttention { // input_metadata: metadata for paged attention. // // alibi_slopes: shape = [num_heads] + #[allow(clippy::cast_possible_truncation)] paged_attention( &query, key_cache.as_ref().unwrap(), @@ -171,6 +177,7 @@ impl PagedAttention { input_metadata.context_lens.as_ref().unwrap(), input_metadata.max_context_len.unwrap(), self.scale, + softcapping.unwrap_or(1.0f64) as f32, ) } } diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs b/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs index 9c08ad18c7..b59446f81d 100644 --- a/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs +++ b/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs @@ -95,6 +95,7 @@ impl CausalSelfAttention { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs b/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs index 60386bfc7e..54e3fb5cdc 100644 --- a/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs +++ b/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs @@ -239,6 +239,7 @@ impl Attention { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-core/src/vision_models/phi3.rs b/mistralrs-core/src/vision_models/phi3.rs index f0560bbcf8..62d0dd8246 100644 --- a/mistralrs-core/src/vision_models/phi3.rs +++ b/mistralrs-core/src/vision_models/phi3.rs @@ -253,6 +253,7 @@ impl Attention { Some(key_cache), Some(value_cache), input_metadata, + None, )? } None => { diff --git a/mistralrs-paged-attn/src/backend/paged_attention.rs b/mistralrs-paged-attn/src/backend/paged_attention.rs index 032bf31b41..07b31b67b7 100644 --- a/mistralrs-paged-attn/src/backend/paged_attention.rs +++ b/mistralrs-paged-attn/src/backend/paged_attention.rs @@ -10,6 +10,7 @@ use std::ffi::c_int; struct PagedAttention { softmax_scale: f32, + softcapping: f32, key_cache: Tensor, value_cache: Tensor, @@ -174,6 +175,7 @@ impl PagedAttention { vc_ptr, num_kv_heads as c_int, self.softmax_scale, + self.softcapping, bt_ptr, cl_ptr, block_size as c_int, @@ -210,6 +212,7 @@ impl PagedAttention { vc_ptr, num_kv_heads as c_int, self.softmax_scale, + self.softcapping, bt_ptr, cl_ptr, block_size as c_int, @@ -266,8 +269,10 @@ impl candle::CustomOp1 for PagedAttention { /// * `context_lens` - Tensor associating lengths to each sequence of shape `(num_sequences)` /// * `max_context_len` - Max of `context_len` /// * `softmax_scale` - scaling factor +/// * `softcapping`- Softcapping value as in Gemma 2. Using 1.0 means do nothing. /// /// The resulting tensor has dimensions `(num_sequences, num_heads_q, head_size)`. +#[allow(clippy::too_many_arguments)] pub fn paged_attention( q: &Tensor, key_cache: &Tensor, @@ -276,6 +281,7 @@ pub fn paged_attention( context_lens: &Tensor, max_context_len: usize, softmax_scale: f32, + softcapping: f32, ) -> Result { let op = PagedAttention { softmax_scale, @@ -284,6 +290,7 @@ pub fn paged_attention( block_tables: block_tables.clone(), context_lens: context_lens.clone(), max_context_len, + softcapping, }; q.apply_op1(op) } diff --git a/mistralrs-paged-attn/src/ffi.rs b/mistralrs-paged-attn/src/ffi.rs index 423db2c668..69d7e48bfe 100644 --- a/mistralrs-paged-attn/src/ffi.rs +++ b/mistralrs-paged-attn/src/ffi.rs @@ -26,6 +26,7 @@ extern "C" { value_cache: *const c_void, num_kv_heads: c_int, scale: f32, + softcapping: f32, block_tables: *const c_int, context_lens: *const c_int, block_size: c_int, @@ -52,6 +53,7 @@ extern "C" { value_cache: *const c_void, num_kv_heads: c_int, scale: f32, + softcapping: f32, block_tables: *const c_int, context_lens: *const c_int, block_size: c_int, diff --git a/mistralrs-paged-attn/src/pagedattention.cu b/mistralrs-paged-attn/src/pagedattention.cu index f3a50827ca..eaa98941c0 100644 --- a/mistralrs-paged-attn/src/pagedattention.cu +++ b/mistralrs-paged-attn/src/pagedattention.cu @@ -73,6 +73,20 @@ inline __device__ float block_sum(float* red_smem, float sum) { return VLLM_SHFL_SYNC(sum, 0); } +inline __device__ float fast_tanh(float x) { + #if defined(__CUDA_ARCH__) + #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) + float y; + asm volatile ( "tanh.approx.f32 %0, %1; " : "=f"(y) : "f"(x)); + return y; + #else + return ::tanhf(x); + #endif + #else + return std::tanh(x); + #endif +} + // TODO(woosuk): Merge the last two dimensions of the grid. // Grid: (num_heads, num_seqs, max_num_partitions). template< @@ -90,6 +104,7 @@ __device__ void paged_attention_kernel( const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const int num_kv_heads, // [num_heads] const float scale, + const float softcapping, const uint32_t* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const uint32_t* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, @@ -212,6 +227,12 @@ __device__ void paged_attention_kernel( // Compute dot product. // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + + // Apply softcapping + if (softcapping != 1.0) { + qk = fast_tanh(qk / softcapping) * softcapping; + } + // Add the ALiBi bias if slopes are given. qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; @@ -403,6 +424,7 @@ __global__ void paged_attention_v1_kernel( const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const int num_kv_heads, // [num_heads] const float scale, + const float softcapping, const uint32_t* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const uint32_t* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, @@ -412,7 +434,7 @@ __global__ void paged_attention_v1_kernel( const int kv_head_stride) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, + out, q, k_cache, v_cache, num_kv_heads, scale, softcapping, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); } @@ -432,6 +454,7 @@ __global__ void paged_attention_v2_kernel( const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] const int num_kv_heads, // [num_heads] const float scale, + const float softcapping, const uint32_t* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const uint32_t* __restrict__ context_lens, // [num_seqs] const int max_num_blocks_per_seq, @@ -440,7 +463,7 @@ __global__ void paged_attention_v2_kernel( const int kv_block_stride, const int kv_head_stride) { paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, softcapping, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); } @@ -558,6 +581,7 @@ __global__ void paged_attention_v2_reduce_kernel( reinterpret_cast(value_cache), \ num_kv_heads, \ scale, \ + softcapping, \ block_tables, \ context_lens, \ max_num_blocks_per_seq, \ @@ -578,6 +602,7 @@ void paged_attention_v1_launcher( void *value_cache, int num_kv_heads, float scale, + float softcapping, uint32_t *block_tables, uint32_t *context_lens, int max_context_len, @@ -643,6 +668,7 @@ void paged_attention_v1_launcher( value_cache, \ num_kv_heads, \ scale, \ + softcapping, \ block_tables, \ context_lens, \ max_context_len, \ @@ -678,6 +704,7 @@ extern "C" void paged_attention_v1( void *value_cache, // [num_blocks, num_heads, head_size, block_size] int32_t num_kv_heads, // [num_heads] float scale, + float softcapping, uint32_t *block_tables, // [num_seqs, max_num_blocks_per_seq] uint32_t *context_lens, // [num_seqs] int32_t block_size, @@ -713,6 +740,7 @@ extern "C" void paged_attention_v1( reinterpret_cast(value_cache), \ num_kv_heads, \ scale, \ + softcapping, \ block_tables, \ context_lens, \ max_num_blocks_per_seq, \ @@ -744,6 +772,7 @@ void paged_attention_v2_launcher( void *value_cache, int num_kv_heads, float scale, + float softcapping, uint32_t *block_tables, uint32_t *context_lens, int max_context_len, @@ -816,6 +845,7 @@ void paged_attention_v2_launcher( value_cache, \ num_kv_heads, \ scale, \ + softcapping, \ block_tables, \ context_lens, \ max_context_len, \ @@ -854,6 +884,7 @@ extern "C" void paged_attention_v2( void *value_cache, // [num_blocks, num_heads, head_size, block_size] int32_t num_kv_heads, float scale, + float softcapping, uint32_t *block_tables, // [num_seqs, max_num_blocks_per_seq] uint32_t *context_lens, // [num_seqs] int32_t block_size, From 93957c7d1ef24f13c1af42340d41dc6b16cb53cb Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 21 Aug 2024 07:23:05 -0400 Subject: [PATCH 2/4] Non cuda support? --- .../src/dummy_paged_attention/layers/paged_attention.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/mistralrs-core/src/dummy_paged_attention/layers/paged_attention.rs b/mistralrs-core/src/dummy_paged_attention/layers/paged_attention.rs index 9bc2d9199c..7cc29b1020 100644 --- a/mistralrs-core/src/dummy_paged_attention/layers/paged_attention.rs +++ b/mistralrs-core/src/dummy_paged_attention/layers/paged_attention.rs @@ -62,6 +62,7 @@ impl PagedAttention { _key_cache: Option, _value_cache: Option, _input_metadata: &mut PagedAttentionInputMetadata, + _softcapping: Option, ) -> Result { unreachable!(); } From fdf1424f727e6ca1eb1be45388660b4995859444 Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 21 Aug 2024 07:24:37 -0400 Subject: [PATCH 3/4] Remove error --- mistralrs-core/src/models/gemma2.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/mistralrs-core/src/models/gemma2.rs b/mistralrs-core/src/models/gemma2.rs index 1a62ac5cb2..a460cc9140 100644 --- a/mistralrs-core/src/models/gemma2.rs +++ b/mistralrs-core/src/models/gemma2.rs @@ -544,10 +544,6 @@ impl Model { )?; let mut layers = Vec::with_capacity(cfg.num_hidden_layers); let vb_l = vb_m.pp("layers"); - if matches!(attention_mechanism, AttentionImplementation::PagedAttention) { - // TODO softcapping in paged attn - candle_core::bail!("Gemma 2 does not support PagedAttention."); - } for layer_idx in NiceProgressBar::<_, 'b'>(0..cfg.num_hidden_layers, "Loading repeating layers") { From ce9d833b84e14de5ff36fdd7a71237181f0f8e2d Mon Sep 17 00:00:00 2001 From: Eric Buehler Date: Wed, 21 Aug 2024 07:54:51 -0400 Subject: [PATCH 4/4] It works --- mistralrs-core/src/dummy_paged_attention/config.rs | 8 ++++++++ mistralrs-core/src/models/gemma.rs | 1 + mistralrs-core/src/models/gemma2.rs | 1 + mistralrs-core/src/models/llama.rs | 1 + mistralrs-core/src/models/mistral.rs | 1 + mistralrs-core/src/models/mixtral.rs | 1 + mistralrs-core/src/models/phi2.rs | 1 + mistralrs-core/src/models/phi3.rs | 1 + mistralrs-core/src/models/qwen2.rs | 1 + mistralrs-core/src/models/starcoder2.rs | 1 + mistralrs-core/src/paged_attention/cache_engine.rs | 4 ++-- mistralrs-core/src/paged_attention/config.rs | 8 ++++++++ mistralrs-core/src/vision_models/llava/llava_llm/llama.rs | 1 + .../src/vision_models/llava/llava_llm/mistral.rs | 1 + mistralrs-core/src/vision_models/phi3.rs | 1 + mistralrs-core/src/xlora_models/gemma.rs | 1 + mistralrs-core/src/xlora_models/gemma2.rs | 1 + mistralrs-core/src/xlora_models/llama.rs | 1 + mistralrs-core/src/xlora_models/mistral.rs | 1 + mistralrs-core/src/xlora_models/mixtral.rs | 1 + mistralrs-core/src/xlora_models/phi2.rs | 1 + mistralrs-core/src/xlora_models/phi3.rs | 1 + mistralrs-core/src/xlora_models/starcoder2.rs | 1 + 23 files changed, 38 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/dummy_paged_attention/config.rs b/mistralrs-core/src/dummy_paged_attention/config.rs index 7ab0bd751c..19889e7319 100644 --- a/mistralrs-core/src/dummy_paged_attention/config.rs +++ b/mistralrs-core/src/dummy_paged_attention/config.rs @@ -3,6 +3,9 @@ pub trait ModelConfigLike { fn hidden_size(&self) -> usize; fn num_kv_heads(&self) -> usize; fn num_attn_heads(&self) -> usize; + fn head_dim(&self) -> usize { + self.hidden_size() / self.num_attn_heads() + } } pub struct ModelConfigMetadata { @@ -11,6 +14,7 @@ pub struct ModelConfigMetadata { pub num_kv_heads: usize, pub num_attn_heads: usize, pub sliding_window: Option, + pub head_dim: Option, } impl ModelConfigLike for ModelConfigMetadata { @@ -26,4 +30,8 @@ impl ModelConfigLike for ModelConfigMetadata { fn num_layers(&self) -> usize { self.num_layers } + fn head_dim(&self) -> usize { + self.head_dim + .unwrap_or(self.hidden_size() / self.num_attn_heads()) + } } diff --git a/mistralrs-core/src/models/gemma.rs b/mistralrs-core/src/models/gemma.rs index c0f75c6026..dd3b68642d 100644 --- a/mistralrs-core/src/models/gemma.rs +++ b/mistralrs-core/src/models/gemma.rs @@ -545,6 +545,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: None, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/models/gemma2.rs b/mistralrs-core/src/models/gemma2.rs index a460cc9140..a22f939a19 100644 --- a/mistralrs-core/src/models/gemma2.rs +++ b/mistralrs-core/src/models/gemma2.rs @@ -616,6 +616,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: None, + head_dim: Some(cfg.head_dim), }, }) } diff --git a/mistralrs-core/src/models/llama.rs b/mistralrs-core/src/models/llama.rs index f2cfd8a58f..323ffd6de0 100644 --- a/mistralrs-core/src/models/llama.rs +++ b/mistralrs-core/src/models/llama.rs @@ -522,6 +522,7 @@ impl Llama { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: None, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/models/mistral.rs b/mistralrs-core/src/models/mistral.rs index 2d31b7a964..edeab43ccb 100644 --- a/mistralrs-core/src/models/mistral.rs +++ b/mistralrs-core/src/models/mistral.rs @@ -524,6 +524,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: cfg.sliding_window, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/models/mixtral.rs b/mistralrs-core/src/models/mixtral.rs index a4af17fda5..bbe988df2a 100644 --- a/mistralrs-core/src/models/mixtral.rs +++ b/mistralrs-core/src/models/mixtral.rs @@ -552,6 +552,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: cfg.sliding_window, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/models/phi2.rs b/mistralrs-core/src/models/phi2.rs index de1925c187..f2122f3be6 100644 --- a/mistralrs-core/src/models/phi2.rs +++ b/mistralrs-core/src/models/phi2.rs @@ -501,6 +501,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads(), num_attn_heads: cfg.num_attention_heads, sliding_window: None, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/models/phi3.rs b/mistralrs-core/src/models/phi3.rs index bea1e58320..4dc7164bd4 100644 --- a/mistralrs-core/src/models/phi3.rs +++ b/mistralrs-core/src/models/phi3.rs @@ -485,6 +485,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: cfg.sliding_window, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/models/qwen2.rs b/mistralrs-core/src/models/qwen2.rs index 6e63bbbbed..7da2241f02 100644 --- a/mistralrs-core/src/models/qwen2.rs +++ b/mistralrs-core/src/models/qwen2.rs @@ -485,6 +485,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: Some(cfg.sliding_window), + head_dim: None, }, }) } diff --git a/mistralrs-core/src/models/starcoder2.rs b/mistralrs-core/src/models/starcoder2.rs index 3e0cabd0cd..234d607d91 100644 --- a/mistralrs-core/src/models/starcoder2.rs +++ b/mistralrs-core/src/models/starcoder2.rs @@ -485,6 +485,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: cfg.sliding_window, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/paged_attention/cache_engine.rs b/mistralrs-core/src/paged_attention/cache_engine.rs index 4aa814925d..3faeb24811 100644 --- a/mistralrs-core/src/paged_attention/cache_engine.rs +++ b/mistralrs-core/src/paged_attention/cache_engine.rs @@ -137,7 +137,7 @@ impl CacheEngine { let x = 16 / element_size; ( model_config.num_kv_heads(), - model_config.hidden_size() / model_config.num_attn_heads() / x, + model_config.head_dim() / x, block_size, x, ) @@ -149,7 +149,7 @@ impl CacheEngine { ) -> (usize, usize, usize) { ( model_config.num_kv_heads(), - model_config.hidden_size() / model_config.num_attn_heads(), + model_config.head_dim(), block_size, ) } diff --git a/mistralrs-core/src/paged_attention/config.rs b/mistralrs-core/src/paged_attention/config.rs index 7ab0bd751c..19889e7319 100644 --- a/mistralrs-core/src/paged_attention/config.rs +++ b/mistralrs-core/src/paged_attention/config.rs @@ -3,6 +3,9 @@ pub trait ModelConfigLike { fn hidden_size(&self) -> usize; fn num_kv_heads(&self) -> usize; fn num_attn_heads(&self) -> usize; + fn head_dim(&self) -> usize { + self.hidden_size() / self.num_attn_heads() + } } pub struct ModelConfigMetadata { @@ -11,6 +14,7 @@ pub struct ModelConfigMetadata { pub num_kv_heads: usize, pub num_attn_heads: usize, pub sliding_window: Option, + pub head_dim: Option, } impl ModelConfigLike for ModelConfigMetadata { @@ -26,4 +30,8 @@ impl ModelConfigLike for ModelConfigMetadata { fn num_layers(&self) -> usize { self.num_layers } + fn head_dim(&self) -> usize { + self.head_dim + .unwrap_or(self.hidden_size() / self.num_attn_heads()) + } } diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs b/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs index b59446f81d..ecb1e6df71 100644 --- a/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs +++ b/mistralrs-core/src/vision_models/llava/llava_llm/llama.rs @@ -452,6 +452,7 @@ impl Llama { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: None, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs b/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs index 54e3fb5cdc..89808d1918 100644 --- a/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs +++ b/mistralrs-core/src/vision_models/llava/llava_llm/mistral.rs @@ -466,6 +466,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: cfg.sliding_window, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/vision_models/phi3.rs b/mistralrs-core/src/vision_models/phi3.rs index 62d0dd8246..466e0508dd 100644 --- a/mistralrs-core/src/vision_models/phi3.rs +++ b/mistralrs-core/src/vision_models/phi3.rs @@ -945,6 +945,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: cfg.sliding_window, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/xlora_models/gemma.rs b/mistralrs-core/src/xlora_models/gemma.rs index 335305bff1..b0a3cd912d 100644 --- a/mistralrs-core/src/xlora_models/gemma.rs +++ b/mistralrs-core/src/xlora_models/gemma.rs @@ -606,6 +606,7 @@ impl XLoraModel { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: None, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/xlora_models/gemma2.rs b/mistralrs-core/src/xlora_models/gemma2.rs index e7572bee6e..a1ff2c6143 100644 --- a/mistralrs-core/src/xlora_models/gemma2.rs +++ b/mistralrs-core/src/xlora_models/gemma2.rs @@ -660,6 +660,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: None, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/xlora_models/llama.rs b/mistralrs-core/src/xlora_models/llama.rs index be3b132d49..4051554d6f 100644 --- a/mistralrs-core/src/xlora_models/llama.rs +++ b/mistralrs-core/src/xlora_models/llama.rs @@ -692,6 +692,7 @@ impl XLoraLlama { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: None, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/xlora_models/mistral.rs b/mistralrs-core/src/xlora_models/mistral.rs index a1dd8cc465..82f85cb29d 100644 --- a/mistralrs-core/src/xlora_models/mistral.rs +++ b/mistralrs-core/src/xlora_models/mistral.rs @@ -572,6 +572,7 @@ impl XLoraModel { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: cfg.sliding_window, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/xlora_models/mixtral.rs b/mistralrs-core/src/xlora_models/mixtral.rs index f7f3259ea6..fb6dcdec71 100644 --- a/mistralrs-core/src/xlora_models/mixtral.rs +++ b/mistralrs-core/src/xlora_models/mixtral.rs @@ -709,6 +709,7 @@ impl XLoraModel { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: cfg.sliding_window, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/xlora_models/phi2.rs b/mistralrs-core/src/xlora_models/phi2.rs index bd8e81dcd6..51d32c4e33 100644 --- a/mistralrs-core/src/xlora_models/phi2.rs +++ b/mistralrs-core/src/xlora_models/phi2.rs @@ -542,6 +542,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads(), num_attn_heads: cfg.num_attention_heads, sliding_window: None, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/xlora_models/phi3.rs b/mistralrs-core/src/xlora_models/phi3.rs index 84cb4afd51..8c377c67a0 100644 --- a/mistralrs-core/src/xlora_models/phi3.rs +++ b/mistralrs-core/src/xlora_models/phi3.rs @@ -504,6 +504,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: cfg.sliding_window, + head_dim: None, }, }) } diff --git a/mistralrs-core/src/xlora_models/starcoder2.rs b/mistralrs-core/src/xlora_models/starcoder2.rs index 92d6bae199..de35af03dd 100644 --- a/mistralrs-core/src/xlora_models/starcoder2.rs +++ b/mistralrs-core/src/xlora_models/starcoder2.rs @@ -554,6 +554,7 @@ impl Model { num_kv_heads: cfg.num_key_value_heads, num_attn_heads: cfg.num_attention_heads, sliding_window: cfg.sliding_window, + head_dim: None, }, }) }