Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mistralrs-core/src/dummy_paged_attention/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ pub trait ModelConfigLike {
fn hidden_size(&self) -> usize;
fn num_kv_heads(&self) -> usize;
fn num_attn_heads(&self) -> usize;
fn head_dim(&self) -> usize {
self.hidden_size() / self.num_attn_heads()
}
}

pub struct ModelConfigMetadata {
Expand All @@ -11,6 +14,7 @@ pub struct ModelConfigMetadata {
pub num_kv_heads: usize,
pub num_attn_heads: usize,
pub sliding_window: Option<usize>,
pub head_dim: Option<usize>,
}

impl ModelConfigLike for ModelConfigMetadata {
Expand All @@ -26,4 +30,8 @@ impl ModelConfigLike for ModelConfigMetadata {
fn num_layers(&self) -> usize {
self.num_layers
}
fn head_dim(&self) -> usize {
self.head_dim
.unwrap_or(self.hidden_size() / self.num_attn_heads())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ impl PagedAttention {
_key_cache: Option<Tensor>,
_value_cache: Option<Tensor>,
_input_metadata: &mut PagedAttentionInputMetadata,
_softcapping: Option<f64>,
) -> Result<Tensor> {
unreachable!();
}
Expand Down
2 changes: 2 additions & 0 deletions mistralrs-core/src/models/gemma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,7 @@ impl Attention {
Some(key_cache),
Some(value_cache),
input_metadata,
None,
)?
}
None => {
Expand Down Expand Up @@ -544,6 +545,7 @@ impl Model {
num_kv_heads: cfg.num_key_value_heads,
num_attn_heads: cfg.num_attention_heads,
sliding_window: None,
head_dim: None,
},
})
}
Expand Down
122 changes: 85 additions & 37 deletions mistralrs-core/src/models/gemma2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
device_map::DeviceMapper,
get_delta_from_lora_ab,
layers::{repeat_kv, CausalMasker, MatMul},
paged_attention::{AttentionImplementation, ModelConfigMetadata},
paged_attention::{AttentionImplementation, ModelConfigMetadata, PagedAttention},
pipeline::{
extract_logits, text_models_inputs_processor::PagedAttentionInputMetadata, Cache, IsqModel,
NormalLoadingMetadata, NormalModel,
Expand Down Expand Up @@ -199,7 +199,6 @@ impl MlpLayer for MLP {
}
}

#[derive(Clone)]
struct Attention {
q_proj: Arc<dyn QuantMethod>,
k_proj: Arc<dyn QuantMethod>,
Expand All @@ -214,6 +213,7 @@ struct Attention {
attn_logit_softcapping: Option<f64>,
use_sliding_window: bool,
sliding_window: Option<usize>,
paged_attn: Option<PagedAttention>,
}

impl Attention {
Expand All @@ -222,6 +222,7 @@ impl Attention {
cfg: &Config,
layer_idx: usize,
vb: VarBuilder,
paged_attn: Option<PagedAttention>,
) -> Result<Self> {
let hidden_sz = cfg.hidden_size;
let num_heads = cfg.num_attention_heads;
Expand Down Expand Up @@ -276,9 +277,11 @@ impl Attention {
} else {
None
},
paged_attn,
})
}

#[allow(clippy::too_many_arguments)]
fn forward(
&self,
xs: &Tensor,
Expand All @@ -287,6 +290,7 @@ impl Attention {
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
kv_cache: &mut Option<(Tensor, Tensor)>,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
) -> Result<Tensor> {
let (b_sz, q_len, _) = xs.dims3()?;

Expand Down Expand Up @@ -342,38 +346,55 @@ impl Attention {
attention_mask
};

// self.sliding_window is None if !self.use_sliding_window
let (k, v, mask) = Cache::update_kv_cache_sliding_window(
kv_cache,
k,
v,
mask,
self.sliding_window,
false,
)?;

let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;

let mut att = MatMul.matmul_affine_div(
&q.contiguous()?,
&k.t()?.contiguous()?,
(self.query_pre_attn_scalar as f64).sqrt(),
)?;

if let Some(attn_logit_softcapping) = self.attn_logit_softcapping {
att = (att / attn_logit_softcapping)?;
att = att.tanh()?;
att = (att * attn_logit_softcapping)?;
}
let mut attn_output = match &self.paged_attn {
Some(paged_attn) => {
let ((key_cache, value_cache), input_metadata) = metadata.unwrap();
paged_attn.forward(
&q,
&k,
&v,
attention_mask,
Some(key_cache),
Some(value_cache),
input_metadata,
self.attn_logit_softcapping,
)?
}
None => {
// self.sliding_window is None if !self.use_sliding_window
let (k, v, mask) = Cache::update_kv_cache_sliding_window(
kv_cache,
k,
v,
mask,
self.sliding_window,
false,
)?;

let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;

let mut att = MatMul.matmul_affine_div(
&q.contiguous()?,
&k.t()?.contiguous()?,
(self.query_pre_attn_scalar as f64).sqrt(),
)?;

if let Some(attn_logit_softcapping) = self.attn_logit_softcapping {
att = (att / attn_logit_softcapping)?;
att = att.tanh()?;
att = (att * attn_logit_softcapping)?;
}

let att = match mask {
Some(m) => att.broadcast_add(&m)?,
None => att,
let att = match mask {
Some(m) => att.broadcast_add(&m)?,
None => att,
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
MatMul.matmul(&att, &v.contiguous()?)?
}
};
let att = candle_nn::ops::softmax_last_dim(&att)?;
// Convert to contiguous as matmul doesn't support strided vs for now.
let mut attn_output = MatMul.matmul(&att, &v.contiguous()?)?;

if let Some(t) = self.q_proj.quantized_act_type() {
attn_output = attn_output.to_dtype(t)?;
Expand Down Expand Up @@ -408,12 +429,14 @@ impl DecoderLayer {
mapper: &dyn DeviceMapper,
layer_idx: usize,
loading_isq: bool,
paged_attn: Option<PagedAttention>,
) -> Result<Self> {
let self_attn = Attention::new(
rotary_emb,
cfg,
layer_idx,
mapper.set_device(layer_idx, vb.pp("self_attn"), loading_isq),
paged_attn,
)?;
let mlp = MLP::new(cfg, mapper.set_device(layer_idx, vb.pp("mlp"), loading_isq))?;
let input_layernorm = RmsNorm::new(
Expand Down Expand Up @@ -446,6 +469,7 @@ impl DecoderLayer {
})
}

#[allow(clippy::too_many_arguments)]
fn forward(
&self,
xs: &Tensor,
Expand All @@ -454,6 +478,7 @@ impl DecoderLayer {
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
kv_cache: &mut Option<(Tensor, Tensor)>,
metadata: Option<((Tensor, Tensor), &mut PagedAttentionInputMetadata)>,
) -> Result<Tensor> {
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
Expand All @@ -466,6 +491,7 @@ impl DecoderLayer {
seqlen_offsets,
start_offsets_kernel,
kv_cache,
metadata,
)?
.apply(&self.post_attention_layernorm)?;
let xs = (xs + residual)?;
Expand Down Expand Up @@ -518,10 +544,6 @@ impl Model {
)?;
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
let vb_l = vb_m.pp("layers");
if matches!(attention_mechanism, AttentionImplementation::PagedAttention) {
// TODO softcapping in paged attn
candle_core::bail!("Gemma 2 does not support PagedAttention.");
}
for layer_idx in
NiceProgressBar::<_, 'b'>(0..cfg.num_hidden_layers, "Loading repeating layers")
{
Expand All @@ -535,13 +557,33 @@ impl Model {
is_gptx,
vb.dtype(),
)?);
let head_dim = cfg.head_dim;
let sliding_window = if layer_idx % 2 == 0 {
// ^ Order is SWA, global, SWA
Some(cfg.sliding_window)
} else {
None
};
let paged_attn = match &attention_mechanism {
AttentionImplementation::Eager => None,
AttentionImplementation::PagedAttention => Some(PagedAttention::new(
cfg.num_attention_heads,
head_dim,
(1.0 / (cfg.query_pre_attn_scalar as f64).sqrt()) as f32,
Some(cfg.num_key_value_heads),
sliding_window,
&normal_loading_metadata.real_device,
None,
)?),
};
let layer = DecoderLayer::new(
rotary_emb.clone(),
cfg,
vb_l.pp(layer_idx),
&*mapper,
layer_idx,
normal_loading_metadata.loading_isq,
paged_attn,
)?;
layers.push(layer)
}
Expand Down Expand Up @@ -574,6 +616,7 @@ impl Model {
num_kv_heads: cfg.num_key_value_heads,
num_attn_heads: cfg.num_attention_heads,
sliding_window: None,
head_dim: Some(cfg.head_dim),
},
})
}
Expand All @@ -584,6 +627,7 @@ impl Model {
seqlen_offsets: &[usize],
start_offsets_kernel: Tensor,
context_lens: Vec<(usize, usize)>,
mut metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>,
) -> Result<Tensor> {
let xs = self.embed_tokens.forward(input_ids)?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
Expand Down Expand Up @@ -617,6 +661,9 @@ impl Model {
seqlen_offsets,
start_offsets_kernel.clone(),
&mut cache[i],
metadata
.as_mut()
.map(|(kv_cache, metadata)| (kv_cache[i].clone(), &mut **metadata)),
)?;
}
let xs = xs.to_device(&self.device)?;
Expand Down Expand Up @@ -672,13 +719,14 @@ impl NormalModel for Model {
start_offsets_kernel: Tensor,
context_lens: Vec<(usize, usize)>,
_position_ids: Vec<usize>,
_metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>,
metadata: Option<(Vec<(Tensor, Tensor)>, &mut PagedAttentionInputMetadata)>,
) -> Result<Tensor> {
self.forward(
input_ids,
seqlen_offsets,
start_offsets_kernel,
context_lens,
metadata,
)
}
fn xlora_forward(
Expand Down
2 changes: 2 additions & 0 deletions mistralrs-core/src/models/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ impl CausalSelfAttention {
Some(key_cache),
Some(value_cache),
input_metadata,
None,
)?
}
None => {
Expand Down Expand Up @@ -521,6 +522,7 @@ impl Llama {
num_kv_heads: cfg.num_key_value_heads,
num_attn_heads: cfg.num_attention_heads,
sliding_window: None,
head_dim: None,
},
})
}
Expand Down
2 changes: 2 additions & 0 deletions mistralrs-core/src/models/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ impl Attention {
Some(key_cache),
Some(value_cache),
input_metadata,
None,
)?
}
None => {
Expand Down Expand Up @@ -523,6 +524,7 @@ impl Model {
num_kv_heads: cfg.num_key_value_heads,
num_attn_heads: cfg.num_attention_heads,
sliding_window: cfg.sliding_window,
head_dim: None,
},
})
}
Expand Down
2 changes: 2 additions & 0 deletions mistralrs-core/src/models/mixtral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ impl Attention {
Some(key_cache),
Some(value_cache),
input_metadata,
None,
)?
}
None => {
Expand Down Expand Up @@ -551,6 +552,7 @@ impl Model {
num_kv_heads: cfg.num_key_value_heads,
num_attn_heads: cfg.num_attention_heads,
sliding_window: cfg.sliding_window,
head_dim: None,
},
})
}
Expand Down
2 changes: 2 additions & 0 deletions mistralrs-core/src/models/phi2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ impl Attention {
Some(key_cache),
Some(value_cache),
input_metadata,
None,
)?
}
None => {
Expand Down Expand Up @@ -500,6 +501,7 @@ impl Model {
num_kv_heads: cfg.num_key_value_heads(),
num_attn_heads: cfg.num_attention_heads,
sliding_window: None,
head_dim: None,
},
})
}
Expand Down
2 changes: 2 additions & 0 deletions mistralrs-core/src/models/phi3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ impl Attention {
Some(key_cache),
Some(value_cache),
input_metadata,
None,
)?
}
None => {
Expand Down Expand Up @@ -484,6 +485,7 @@ impl Model {
num_kv_heads: cfg.num_key_value_heads,
num_attn_heads: cfg.num_attention_heads,
sliding_window: cfg.sliding_window,
head_dim: None,
},
})
}
Expand Down
Loading