Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 40 additions & 16 deletions mistralrs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,23 +312,47 @@ impl MistralRs {
};

let engine_handler = thread::spawn(move || {
let rt = Runtime::new().unwrap();
rt.block_on(async move {
let engine = Engine::new(
rx,
pipeline,
method,
truncate_sequence,
no_kv_cache,
no_prefix_cache,
prefix_cache_n,
disable_eos_stop,
throughput_logging_enabled,
search_embedding_model,
)
.expect("Engine creation failed.");
Arc::new(engine).run().await;
#[cfg(feature = "metal")]
objc::rc::autoreleasepool(move || {
let rt = Runtime::new().unwrap();
rt.block_on(async move {
let engine = Engine::new(
rx,
pipeline,
method,
truncate_sequence,
no_kv_cache,
no_prefix_cache,
prefix_cache_n,
disable_eos_stop,
throughput_logging_enabled,
search_embedding_model,
)
.expect("Engine creation failed.");
Arc::new(engine).run().await;
})
});

#[cfg(not(feature = "metal"))]
{
let rt = Runtime::new().unwrap();
rt.block_on(async move {
let engine = Engine::new(
rx,
pipeline,
method,
truncate_sequence,
no_kv_cache,
no_prefix_cache,
prefix_cache_n,
disable_eos_stop,
throughput_logging_enabled,
search_embedding_model,
)
.expect("Engine creation failed.");
Arc::new(engine).run().await;
})
}
});

let engine_id = ENGINE_ID.fetch_add(1, atomic::Ordering::SeqCst);
Expand Down
69 changes: 39 additions & 30 deletions mistralrs-core/src/pipeline/cache_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,46 +486,55 @@ impl<T: CacheManagerMixin + MetadataMixin + ?Sized> CacheManager<T> for NormalCa
let mut new_k_cache = Vec::new();
let mut new_v_cache = Vec::new();

'outer: for layer in 0..pipeline.get_metadata().num_hidden_layers {
let mut k_vec = Vec::new();
let mut v_vec = Vec::new();
for seq in &mut *seqs {
for layer in 0..pipeline.get_metadata().num_hidden_layers {
// Preallocate combined k and v caches across all sequences, avoiding Tensor::cat copies
let batch_len = seqs.len();
// Use the first sequence as template
let (first_k, first_v) = {
let src_cache = if modify_draft_cache {
seq.normal_draft_cache()
seqs[0].normal_draft_cache()
} else {
seq.normal_cache()
seqs[0].normal_cache()
};
let cache = src_cache.get(layer).unwrap();
// This case for llama 3.2 vision cross attn
if cache.is_none() {
new_k_cache.push(None);
new_v_cache.push(None);
continue 'outer;
}
let cache = cache
.as_ref()
.expect("Not handling completions in `clone_in_cache`.");
let cache = src_cache.get(layer).unwrap().as_ref().unwrap();
match cache {
KvCache::Normal { k, v } => {
k_vec.push(k.all_data.clone().unwrap());
v_vec.push(v.all_data.clone().unwrap());
(k.all_data.clone().unwrap(), v.all_data.clone().unwrap())
}
KvCache::Rotating { k, v } => {
k_vec.push(k.all_data.clone().unwrap());
v_vec.push(v.all_data.clone().unwrap());
(k.all_data.clone().unwrap(), v.all_data.clone().unwrap())
}
}
};
// Build dims for batched cache
let mut dims_k = first_k.dims().to_vec();
let mut dims_v = first_v.dims().to_vec();
dims_k[0] *= batch_len;
dims_v[0] *= batch_len;
let batch_k = Tensor::zeros(dims_k.clone(), first_k.dtype(), first_k.device()).unwrap();
let batch_v = Tensor::zeros(dims_v.clone(), first_v.dtype(), first_v.device()).unwrap();
// Fill each sequence's cache slice
for (i, seq) in seqs.iter_mut().enumerate() {
let src_cache = if modify_draft_cache {
seq.normal_draft_cache()
} else {
seq.normal_cache()
};
let cache = src_cache.get(layer).unwrap().as_ref().unwrap();
let (src_k, src_v) = match cache {
KvCache::Normal { k, v } => {
(k.all_data.clone().unwrap(), v.all_data.clone().unwrap())
}
KvCache::Rotating { k, v } => {
(k.all_data.clone().unwrap(), v.all_data.clone().unwrap())
}
};
let offset = i * first_k.dims()[0];
batch_k.slice_set(&src_k, 0, offset).unwrap();
batch_v.slice_set(&src_v, 0, offset).unwrap();
}
new_k_cache.push(Some(if k_vec.len() > 1 {
Tensor::cat(&k_vec, 0).unwrap()
} else {
k_vec[0].clone()
}));
new_v_cache.push(Some(if v_vec.len() > 1 {
Tensor::cat(&v_vec, 0).unwrap()
} else {
v_vec[0].clone()
}));
new_k_cache.push(Some(batch_k));
new_v_cache.push(Some(batch_v));
}

let seq0_cache = if modify_draft_cache {
Expand Down
32 changes: 0 additions & 32 deletions mistralrs-core/src/pipeline/normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -999,38 +999,6 @@ impl Pipeline for NormalPipeline {
}
(None, None) => None,
};
#[cfg(feature = "metal")]
let logits = objc::rc::autoreleasepool(|| -> candle_core::Result<Tensor> {
match self.model.is_xlora() {
false => {
let paged_attn_meta = paged_attn_meta
.as_ref()
.map(|meta| (meta.0.get_kv_cache().clone(), meta.1.clone()));

self.model.forward(
&input_ids,
&seqlen_offsets,
context_lens,
position_ids,
paged_attn_meta.as_ref().map(|(a, b)| (a.clone(), b)),
&flash_meta,
)
}
true => self.model.xlora_forward(
&input_ids,
input_ids_full.as_ref().unwrap_or(&input_ids),
&seqlen_offsets,
seqlen_offsets_full.as_ref().unwrap_or(&seqlen_offsets),
self.no_kv_cache,
&self.non_granular_state,
context_lens,
position_ids,
&flash_meta,
flash_meta_full.as_ref().unwrap_or(&flash_meta),
),
}
})?;
#[cfg(not(feature = "metal"))]
let logits = match self.model.is_xlora() {
false => {
let paged_attn_meta = paged_attn_meta
Expand Down
14 changes: 0 additions & 14 deletions mistralrs-core/src/pipeline/vision.rs
Original file line number Diff line number Diff line change
Expand Up @@ -845,20 +845,6 @@ impl Pipeline for VisionPipeline {
}
(None, None) => None,
};
#[cfg(feature = "metal")]
let logits = objc::rc::autoreleasepool(|| {
self.model.forward(
&input_ids,
pixel_values,
&seqlen_offsets,
context_lens,
position_ids,
model_specific_args,
paged_attn_meta,
&flash_meta,
)
})?;
#[cfg(not(feature = "metal"))]
let logits = self.model.forward(
&input_ids,
pixel_values,
Expand Down
27 changes: 2 additions & 25 deletions mistralrs-quant/src/metal_kernels/blockwise_fp8.metal
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "float8.metal"
#include "utils.metal"

#include <metal_stdlib>

using namespace metal;
Expand All @@ -13,31 +15,6 @@ extern "C" struct DequantParams {
uint block_size_x; // tile width ( == weight_block_size_x )
};

/* ------------------ FP8-E4M3 → float32 helper ------------------ */
inline float fp8_e4m3_to_float(uchar v) {
const uint sign = v >> 7;
const uint exponent = (v >> 3) & 0xF;
const uint mantissa = v & 0x7;

/* special encodings ------------------------------------------------ */
if (exponent == 0) { // sub-normal / zero
if (mantissa == 0)
return 0.0f * (1.0f - 2.0f * sign); // signed zero
float m = float(mantissa) / 8.0f; // 2⁻³ scale for 3-bit frac
float val = ldexp(m, -6); // 2^(1-bias-fracbits) : bias=7
return sign ? -val : val;
}
if (exponent == 0xF) { // Inf / NaN
return sign ? -INFINITY : INFINITY;
}

/* normal numbers --------------------------------------------------- */
float m = 1.0f + float(mantissa) / 8.0f; // implicit leading 1
int exp = int(exponent) - 7; // remove bias (bias = 7)
float val = ldexp(m, exp);
return sign ? -val : val;
}

/* -------------------------- kernel bodies ---------------------------- */
template <typename OutT>
kernel void
Expand Down
112 changes: 112 additions & 0 deletions mistralrs-quant/src/metal_kernels/float8.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#include <metal_stdlib>

using namespace metal;

// ————————————————————————————————————————————————————————————————
// F8E4M3 (Sign=1, Exponent=4, Mantissa=3; bias=2^(4−1)−1 = 7)
// ————————————————————————————————————————————————————————————————

// Decodes one FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits; bias 7)
// into a float.
//
//   exponent == 0,  mantissa == 0 : signed zero
//   exponent == 0,  mantissa != 0 : subnormal, (mantissa / 8) * 2^(1 - 7)
//   exponent == 0xF               : signed infinity (mantissa is ignored)
//   otherwise                     : (1 + mantissa / 8) * 2^(exponent - 7)
//
// NOTE(review): exponent 0xF is treated as infinity for every mantissa value.
// The OCP "E4M3FN" flavour instead stores finite values up to 448 (and a
// single NaN) in that exponent; confirm which flavour the producers of these
// bytes use before relying on the top binade.
inline float fp8_e4m3_to_float(uchar v) {
  const uint sign = (v >> 7) & 0x1;
  const uint exp_bits = (v >> 3) & 0xF;
  const uint man_bits = v & 0x7;

  // Zero and subnormals share the all-zero exponent field.
  if (exp_bits == 0) {
    if (man_bits == 0) {
      return sign ? -0.0f : 0.0f;
    }
    // Subnormal: no implicit leading 1, exponent fixed at (1 - bias).
    // `m` already carries the 2^-3 mantissa scaling, so the ldexp exponent
    // must be (1 - bias) only; folding the mantissa width in a second time
    // would make every subnormal 8x too small.
    float m = float(man_bits) / float(1 << 3);
    float val = ldexp(m, 1 - 7);
    return sign ? -val : val;
  }
  // All-ones exponent: reported as signed infinity (see NOTE above).
  if (exp_bits == 0xF) {
    return sign ? -INFINITY : INFINITY;
  }
  // Normal numbers: implicit leading 1, exponent with the bias removed.
  float mant = 1.0f + float(man_bits) / float(1 << 3);
  int expn = int(exp_bits) - 7;
  float val = ldexp(mant, expn);
  return sign ? -val : val;
}

// Encodes a float as an FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits;
// bias 7).
//
//   * Inputs whose re-biased exponent exceeds 0xE are encoded as signed
//     infinity (exponent 0xF, mantissa 0). Float Inf and NaN have exponent
//     field 0xFF, so they take this path too (NaN is not preserved).
//   * Inputs whose re-biased exponent is <= 0 (zero and everything below the
//     smallest FP8 normal, 2^-6) are flushed to signed zero; FP8 subnormals
//     are never produced.
//   * Everything else is rounded to 3 mantissa bits using round-to-nearest,
//     ties-to-even. A mantissa carry bumps the exponent and may overflow to
//     infinity.
inline uchar float_to_fp8_e4m3(float f) {
  const uint bits = as_type<uint>(f);
  const uint sign = bits >> 31;
  int exp = int((bits >> 23) & 0xFF) - 127 + 7; // re-bias from 127 to 7
  const uint man = bits & 0x7FFFFF;

  // Overflow (and Inf/NaN inputs) -> signed infinity.
  if (exp > 0xE) {
    return uchar((sign << 7) | (0xF << 3));
  }
  // Zero, float subnormals and anything below the FP8 normal range -> zero.
  if (exp <= 0) {
    return uchar(sign << 7);
  }

  // Round-to-nearest, ties-to-even on the 20 discarded mantissa bits.
  // Adding (half - 1 + lsb) carries into the kept bits when the discarded
  // part is above half, or exactly half with an odd kept mantissa.
  const uint shift = 23 - 3;
  const uint lsb = (man >> shift) & 1u;
  uint mant_rounded = (man + ((1u << (shift - 1)) - 1u) + lsb) >> shift;
  if (mant_rounded == (1u << 3)) {
    // Mantissa carried out: value rounded up to the next power of two.
    mant_rounded = 0;
    exp += 1;
    if (exp >= 0xF) {
      // Rounded past the largest finite exponent -> infinity.
      return uchar((sign << 7) | (0xF << 3));
    }
  }
  return uchar((sign << 7) | (uint(exp) << 3) | (mant_rounded & 0x7));
}

// ————————————————————————————————————————————————————————————————
// F8E5M2 (Sign=1, Exponent=5, Mantissa=2; bias=2^(5−1)−1 = 15)
// ————————————————————————————————————————————————————————————————

// Decodes one FP8 E5M2 byte (1 sign, 5 exponent, 2 mantissa bits; bias 15)
// into a float.
//
//   exponent == 0,    mantissa == 0 : signed zero
//   exponent == 0,    mantissa != 0 : subnormal, (mantissa / 4) * 2^(1 - 15)
//   exponent == 0x1F, mantissa == 0 : signed infinity
//   exponent == 0x1F, mantissa != 0 : NaN
//   otherwise                       : (1 + mantissa / 4) * 2^(exponent - 15)
inline float fp8_e5m2_to_float(uchar v) {
  const uint sign = (v >> 7) & 0x1;
  const uint exp_bits = (v >> 2) & 0x1F;
  const uint man_bits = v & 0x3;

  // Zero and subnormals share the all-zero exponent field.
  if (exp_bits == 0) {
    if (man_bits == 0) {
      return sign ? -0.0f : 0.0f;
    }
    // Subnormal: no implicit leading 1, exponent fixed at (1 - bias).
    // `m` already carries the 2^-2 mantissa scaling, so the ldexp exponent
    // must be (1 - bias) only; folding the mantissa width in a second time
    // would make every subnormal 4x too small.
    float m = float(man_bits) / float(1 << 2);
    float val = ldexp(m, 1 - 15);
    return sign ? -val : val;
  }
  // All-ones exponent: infinity when the mantissa is zero, NaN otherwise.
  if (exp_bits == 0x1F) {
    if (man_bits != 0) {
      return NAN;
    }
    return sign ? -INFINITY : INFINITY;
  }
  // Normal numbers: implicit leading 1, exponent with the bias removed.
  float mant = 1.0f + float(man_bits) / float(1 << 2);
  int expn = int(exp_bits) - 15;
  float val = ldexp(mant, expn);
  return sign ? -val : val;
}

// Encodes a float as an FP8 E5M2 byte (1 sign, 5 exponent, 2 mantissa bits;
// bias 15).
//
//   * Inputs whose re-biased exponent exceeds 0x1E (the largest finite
//     exponent field) are encoded as signed infinity (exponent 0x1F,
//     mantissa 0). Float Inf and NaN have exponent field 0xFF, so they take
//     this path too (NaN is not preserved).
//   * Inputs whose re-biased exponent is <= 0 (zero and everything below the
//     smallest FP8 normal, 2^-14) are flushed to signed zero; FP8 subnormals
//     are never produced.
//   * Everything else is rounded to 2 mantissa bits using round-to-nearest,
//     ties-to-even. A mantissa carry bumps the exponent and may overflow to
//     infinity.
inline uchar float_to_fp8_e5m2(float f) {
  const uint bits = as_type<uint>(f);
  const uint sign = bits >> 31;
  int exp = int((bits >> 23) & 0xFF) - 127 + 15; // re-bias from 127 to 15
  const uint man = bits & 0x7FFFFF;

  // Overflow (and Inf/NaN inputs) -> signed infinity. Exponent field 0x1E is
  // still a finite binade (32768 .. 57344), so only values beyond it overflow.
  if (exp > 0x1E) {
    return uchar((sign << 7) | (0x1F << 2));
  }
  // Zero, float subnormals and anything below the FP8 normal range -> zero.
  if (exp <= 0) {
    return uchar(sign << 7);
  }

  // Round-to-nearest, ties-to-even on the 21 discarded mantissa bits.
  // Adding (half - 1 + lsb) carries into the kept bits when the discarded
  // part is above half, or exactly half with an odd kept mantissa.
  const uint shift = 23 - 2;
  const uint lsb = (man >> shift) & 1u;
  uint mant_rounded = (man + ((1u << (shift - 1)) - 1u) + lsb) >> shift;
  if (mant_rounded == (1u << 2)) {
    // Mantissa carried out: value rounded up to the next power of two.
    mant_rounded = 0;
    exp += 1;
    if (exp >= 0x1F) {
      // Rounded past the largest finite exponent -> infinity.
      return uchar((sign << 7) | (0x1F << 2));
    }
  }
  return uchar((sign << 7) | (uint(exp) << 2) | (mant_rounded & 0x3));
}
Loading