54 changes: 50 additions & 4 deletions mistralrs-core/src/layers.rs
@@ -282,6 +282,7 @@ pub struct PhiRopeConfig {
pub original_max_position_embeddings: usize,
pub rope_theta: f64,
pub head_dim: usize,
pub partial_rotary_factor: Option<f64>,
}

impl PhiRotaryEmbedding {
@@ -294,7 +295,7 @@ impl PhiRotaryEmbedding {
dev: &Device,
) -> Result<Self> {
let max_seq_len = cfg.max_position_embeddings;
let dim = cfg.head_dim;
let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

// Calculate scale
let scale =
@@ -356,7 +357,7 @@ impl PhiRotaryEmbedding {

fn new_unscaled(cfg: &PhiRopeConfig, dtype: DType, dev: &Device) -> Result<Self> {
let max_seq_len = cfg.max_position_embeddings;
let dim = cfg.head_dim;
let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

let inv_freq: Vec<_> = (0..dim)
.step_by(2)
@@ -391,7 +392,7 @@ impl PhiRotaryEmbedding {
dev: &Device,
) -> Result<Self> {
let max_seq_len = cfg.max_position_embeddings;
let dim = cfg.head_dim;
let dim = (cfg.head_dim as f64 * cfg.partial_rotary_factor.unwrap_or(1.)) as usize;

if !matches!(scaling_type, ScaledRopeType::Su) {
candle_core::bail!("Scaled Phi3 RoPE (non-classic scaled, with mscales) must have type `su`/`longrope`.");
@@ -513,7 +514,52 @@ impl PhiRotaryEmbedding {
let (sin, cos) = self.get_long_or_short_sin_cos(position_ids);
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
let all_same = seqlen_offsets.iter().all(|&x| x == seqlen_offsets[0]);
if all_same {

let rot_dim = cos.dim(D::Minus1)? * 2;

// Case for Phi 3 / Phi 4 mini
if rot_dim != q.dim(D::Minus1)? {
let rot_dim = cos.dim(D::Minus1)? * 2;
let q_rot = q.narrow(D::Minus1, 0, rot_dim)?;
let q_pass = q.narrow(D::Minus1, rot_dim, q.dim(D::Minus1)? - rot_dim)?;
let k_rot = k.narrow(D::Minus1, 0, rot_dim)?;
let k_pass = k.narrow(D::Minus1, rot_dim, k.dim(D::Minus1)? - rot_dim)?;

let (q_rot, k_rot) = if all_same {
let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q_rot.contiguous()?, &cos, &sin)?;
let k_embed = candle_nn::rotary_emb::rope(&k_rot.contiguous()?, &cos, &sin)?;
(q_embed, k_embed)
} else {
let mut q_embeds = Vec::new();
let mut k_embeds = Vec::new();
for (i, offset) in seqlen_offsets.iter().enumerate() {
let cos = cos.narrow(0, *offset, seq_len)?;
let sin = sin.narrow(0, *offset, seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(
&q_rot.i(i)?.unsqueeze(0)?.contiguous()?,
&cos,
&sin,
)?;
let k_embed = candle_nn::rotary_emb::rope(
&k_rot.i(i)?.unsqueeze(0)?.contiguous()?,
&cos,
&sin,
)?;
q_embeds.push(q_embed);
k_embeds.push(k_embed);
}
let q_rot = Tensor::cat(&q_embeds, 0)?;
let k_rot = Tensor::cat(&k_embeds, 0)?;
(q_rot, k_rot)
};

Ok((
Tensor::cat(&[q_rot, q_pass], D::Minus1)?.contiguous()?,
Tensor::cat(&[k_rot, k_pass], D::Minus1)?.contiguous()?,
))
} else if all_same {
let cos = cos.narrow(0, seqlen_offsets[0], seq_len)?;
let sin = sin.narrow(0, seqlen_offsets[0], seq_len)?;
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
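The layers.rs changes above thread an optional `partial_rotary_factor` through the Phi RoPE setup: only the first `rot_dim = head_dim * factor` features of each head are rotated, and the trailing features are concatenated back untouched (the `q_pass`/`k_pass` path). A minimal sketch of that split, using plain vectors instead of candle tensors and an illustrative factor of 0.75:

```rust
// Sketch only: mirrors the rot_dim computation from the patch, but on plain
// slices rather than candle tensors. The 0.75 factor and head_dim of 8 are
// illustrative values, not taken from a specific checkpoint.
fn rotary_split(head: &[f32], partial_rotary_factor: Option<f64>) -> (Vec<f32>, Vec<f32>) {
    let head_dim = head.len();
    // Same expression as in the patched PhiRotaryEmbedding constructors.
    let rot_dim = (head_dim as f64 * partial_rotary_factor.unwrap_or(1.)) as usize;
    let (rot, pass) = head.split_at(rot_dim);
    (rot.to_vec(), pass.to_vec())
}

fn main() {
    let head: Vec<f32> = (0..8).map(|i| i as f32).collect();

    // Partial rotary: 6 of 8 features get RoPE, the last 2 pass through.
    let (rot, pass) = rotary_split(&head, Some(0.75));
    assert_eq!((rot.len(), pass.len()), (6, 2));

    // No factor in the config: the whole head is rotated, as before this change.
    let (rot, pass) = rotary_split(&head, None);
    assert_eq!((rot.len(), pass.len()), (8, 0));
    println!("rot = {rot:?}, pass = {pass:?}");
}
```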
17 changes: 15 additions & 2 deletions mistralrs-core/src/layers_masker.rs
@@ -161,12 +161,25 @@ impl CausalMasker {
return Ok(None);
}

let causal_mask = {
let mut causal_mask = {
let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let diagonal = past_kv_len as isize - sliding_window as isize - 1;
let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;

masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?.to_dtype(dtype)?
masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?
.to_dtype(DType::U8)?
};

let zero = Tensor::new(0.0f32, input_ids.device())?;
causal_mask = {
let mask = causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
// Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)

masked_fill(
&zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
&mask,
f32::NEG_INFINITY,
)?
};

Ok(Some(causal_mask))
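The layers_masker.rs hunk builds the sliding-window mask in two steps: first a u8 keep/mask grid combining the causal constraint with the window, then an additive mask in which allowed positions contribute 0.0 and everything else f32::NEG_INFINITY, ready to be added to the attention scores. A rough sketch of the resulting semantics on plain nested Vecs; the exact off-by-one window convention here is an assumption, not a line-for-line port of make_mask/apply_tril:

```rust
// Sketch: additive sliding-window causal mask on plain Vecs. tgt_len and
// sliding_window values below are illustrative.
fn sliding_window_additive_mask(tgt_len: usize, past_kv_len: usize, sliding_window: usize) -> Vec<Vec<f32>> {
    let src_len = tgt_len + past_kv_len;
    (0..tgt_len)
        .map(|i| {
            let q_pos = past_kv_len + i;
            (0..src_len)
                .map(|k_pos| {
                    let causal_ok = k_pos <= q_pos;                  // no attending to the future
                    let in_window = q_pos < k_pos + sliding_window;  // no attending past the window
                    if causal_ok && in_window { 0.0 } else { f32::NEG_INFINITY }
                })
                .collect()
        })
        .collect()
}

fn main() {
    // 4 new tokens, no cached prefix, window of 3: each row keeps at most the
    // 3 most recent positions up to and including itself.
    for row in sliding_window_additive_mask(4, 0, 3) {
        println!("{row:?}");
    }
}
```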
2 changes: 2 additions & 0 deletions mistralrs-core/src/models/phi3.rs
@@ -53,6 +53,7 @@ pub struct Config {
pub quantization_config: Option<QuantizedConfig>,
#[serde(default = "word_emb_default")]
pub tie_word_embeddings: bool,
pub partial_rotary_factor: Option<f64>,
}

impl From<Config> for PhiRopeConfig {
@@ -63,6 +64,7 @@ impl From<Config> for PhiRopeConfig {
original_max_position_embeddings: val.original_max_position_embeddings,
rope_theta: val.rope_theta,
head_dim: val.hidden_size / val.num_attention_heads,
partial_rotary_factor: val.partial_rotary_factor,
}
}
}
1 change: 1 addition & 0 deletions mistralrs-core/src/models/phi3_5_moe.rs
@@ -66,6 +66,7 @@ impl From<Config> for PhiRopeConfig {
original_max_position_embeddings: val.original_max_position_embeddings,
rope_theta: val.rope_theta,
head_dim: val.hidden_size / val.num_attention_heads,
partial_rotary_factor: None,
}
}
}
49 changes: 42 additions & 7 deletions mistralrs-core/src/pipeline/cache_manager.rs
@@ -116,6 +116,7 @@ impl SingleCache {
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad);
};

// Expand kv cache
if self.current_seq_len + seq_len > self.capacity_seq_len {
let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
@@ -134,7 +135,9 @@
ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
self.all_data = Some(ad);
}

let ad = self.all_data.as_mut().unwrap();

ad.slice_set(src, self.dim, self.current_seq_len)?;
self.current_seq_len += seq_len;
Ok(())
@@ -152,16 +155,18 @@ pub struct RotatingCache {
// max_seq_len is the size of the rotating buffer, it is actually allowed for the full
// sequence to grow past this limit.
pub max_seq_len: usize,
pub capacity_seq_len: usize,
}

impl RotatingCache {
pub fn new(dim: usize, max_seq_len: usize) -> Self {
pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
Self {
all_data: None,
dim,
offset: 0,
current_seq_len: 0,
max_seq_len,
capacity_seq_len,
}
}

@@ -224,10 +229,32 @@ impl RotatingCache {
// self.all_data.get_or_insert_with.
if self.all_data.is_none() {
let mut shape = src.dims().to_vec();
shape[self.dim] = self.max_seq_len;
shape[self.dim] = self.capacity_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
self.all_data = Some(ad)
};

// Expand kv cache, this case is a little more complex.
if self.current_seq_len + seq_len > self.capacity_seq_len
&& self.current_seq_len + seq_len < self.max_seq_len
{
let diff = self.current_seq_len + seq_len - self.capacity_seq_len;
let n_blocks_needed = diff.div_ceil(NormalCache::CACHE_GROW_SIZE);
self.capacity_seq_len += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
if self.capacity_seq_len > self.max_seq_len {
candle_core::bail!(
"kv-cache: requested capacity ({}) above max seq len ({})",
self.capacity_seq_len,
self.max_seq_len
)
}
let mut shape = src.dims().to_vec();
shape[self.dim] = self.capacity_seq_len;
let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
ad.slice_set(self.all_data.as_ref().unwrap(), self.dim, 0)?;
self.all_data = Some(ad);
}

let ad = self.all_data.as_mut().unwrap();

self.current_seq_len += seq_len;
@@ -278,9 +305,9 @@ impl KvCache {
Self::Normal { k, v }
}

pub fn new_rotating(dim: usize, sliding_window: usize) -> Self {
let k = RotatingCache::new(dim, sliding_window);
let v = RotatingCache::new(dim, sliding_window);
pub fn new_rotating(dim: usize, sliding_window: usize, capacity_seq_len: usize) -> Self {
let k = RotatingCache::new(dim, sliding_window, capacity_seq_len);
let v = RotatingCache::new(dim, sliding_window, capacity_seq_len);
Self::Rotating { k, v }
}

@@ -409,7 +436,8 @@ impl NormalCache {
Some(sliding_window) => Arc::new(Mutex::new(Self(vec![
KvCache::new_rotating(
2,
sliding_window
sliding_window,
Self::CACHE_GROW_SIZE
);
len
]))),
@@ -432,7 +460,7 @@
caches.push(KvCache::new_normal(2, max_seq_len, Self::CACHE_GROW_SIZE));
}
NormalCacheType::SlidingWindow { window } => {
caches.push(KvCache::new_rotating(2, window));
caches.push(KvCache::new_rotating(2, window, Self::CACHE_GROW_SIZE));
}
}
}
@@ -532,6 +560,7 @@ impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for NormalCa
let template_cache_csl = old_k.current_seq_len;
let template_cache_msl = old_k.max_seq_len;
let template_cache_offset = old_k.offset;
let template_cache_capsl = old_k.capacity_seq_len;

caches.push(KvCache::Rotating {
k: RotatingCache {
@@ -540,13 +569,15 @@
current_seq_len: template_cache_csl,
max_seq_len: template_cache_msl,
offset: template_cache_offset,
capacity_seq_len: template_cache_capsl,
},
v: RotatingCache {
all_data: v_cache.map(|x| x.contiguous().unwrap()),
dim: template_cache_dim,
current_seq_len: template_cache_csl,
max_seq_len: template_cache_msl,
offset: template_cache_offset,
capacity_seq_len: template_cache_capsl,
},
});
}
@@ -620,13 +651,15 @@ impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for NormalCa
current_seq_len: cache_k.current_seq_len,
max_seq_len: cache_k.max_seq_len,
offset: cache_k.offset,
capacity_seq_len: cache_k.capacity_seq_len,
},
v: RotatingCache {
all_data: Some(v),
dim: cache_v.dim,
current_seq_len: cache_v.current_seq_len,
max_seq_len: cache_v.max_seq_len,
offset: cache_v.offset,
capacity_seq_len: cache_v.capacity_seq_len,
},
});
}
@@ -731,13 +764,15 @@ impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for NormalCa
current_seq_len: 0,
max_seq_len: template_cache_msl,
offset: 0,
capacity_seq_len: 0,
},
v: RotatingCache {
all_data: None,
dim: template_cache_dim,
current_seq_len: 0,
max_seq_len: template_cache_msl,
offset: 0,
capacity_seq_len: 0,
},
};
*layer = cache;
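The cache_manager.rs changes give RotatingCache the same lazy growth SingleCache already uses: the backing buffer starts at NormalCache::CACHE_GROW_SIZE tokens and is enlarged in whole CACHE_GROW_SIZE blocks until the sequence reaches the sliding-window size (max_seq_len), after which the cache rotates in place instead of growing. A small sketch of just the capacity arithmetic; the CACHE_GROW_SIZE value of 512 is assumed for illustration, not read from the crate:

```rust
// Sketch of the rotating-cache growth rule on plain integers.
const CACHE_GROW_SIZE: usize = 512; // assumed value, see NormalCache::CACHE_GROW_SIZE

fn grow_capacity(
    current_seq_len: usize,
    seq_len: usize,
    capacity_seq_len: usize,
    max_seq_len: usize,
) -> Result<usize, String> {
    let needed = current_seq_len + seq_len;
    // Only grow while the total stays below the rotating-buffer (window) size.
    if needed > capacity_seq_len && needed < max_seq_len {
        let diff = needed - capacity_seq_len;
        let n_blocks = diff.div_ceil(CACHE_GROW_SIZE);
        let new_capacity = capacity_seq_len + n_blocks * CACHE_GROW_SIZE;
        if new_capacity > max_seq_len {
            return Err(format!(
                "kv-cache: requested capacity ({new_capacity}) above max seq len ({max_seq_len})"
            ));
        }
        return Ok(new_capacity);
    }
    Ok(capacity_seq_len)
}

fn main() {
    // Appending 700 tokens to an empty cache with a 4096-token window needs
    // one extra 512 block on top of the initial 512, so capacity becomes 1024.
    assert_eq!(grow_capacity(0, 700, 512, 4096), Ok(1024));
    // Already large enough: capacity is left alone.
    assert_eq!(grow_capacity(900, 100, 1024, 4096), Ok(1024));
    println!("rotating cache growth ok");
}
```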
2 changes: 2 additions & 0 deletions mistralrs-core/src/pipeline/loaders/normal_loaders.rs
@@ -1546,6 +1546,7 @@ struct Phi3BasicConfig {
quantization_config: Option<QuantizedConfig>,
#[serde(default = "word_emb_default")]
tie_word_embeddings: bool,
partial_rotary_factor: Option<f64>,
}

impl Phi3BasicConfig {
@@ -1570,6 +1571,7 @@ impl Phi3BasicConfig {
sliding_window: basic_config.sliding_window,
quantization_config: basic_config.quantization_config,
tie_word_embeddings: basic_config.tie_word_embeddings,
partial_rotary_factor: basic_config.partial_rotary_factor,
})
}
}
2 changes: 2 additions & 0 deletions mistralrs-core/src/prefix_cacher.rs
@@ -117,13 +117,15 @@ impl PrefixCacheManagerV2 {
current_seq_len: k.current_seq_len,
max_seq_len: k.max_seq_len,
offset: k.offset,
capacity_seq_len: k.capacity_seq_len,
},
v: RotatingCache {
all_data: v.all_data.as_ref().map(|x| x.to_device(device).unwrap()),
dim: v.dim,
current_seq_len: v.current_seq_len,
max_seq_len: v.max_seq_len,
offset: v.offset,
capacity_seq_len: v.capacity_seq_len,
},
}
}
1 change: 1 addition & 0 deletions mistralrs-core/src/vision_models/phi3/mod.rs
@@ -90,6 +90,7 @@ impl From<Config> for PhiRopeConfig {
original_max_position_embeddings: val.original_max_position_embeddings,
rope_theta: val.rope_theta,
head_dim: val.hidden_size / val.num_attention_heads,
partial_rotary_factor: None,
}
}
}