119 changes: 57 additions & 62 deletions mistralrs-core/src/attention.rs
@@ -8,7 +8,7 @@ use crate::{
 };
 
 use candle_core::{Device, Result, Tensor};
-use mistralrs_quant::{get_use_matmul_via_f16, MatMul};
+use mistralrs_quant::MatMul;
 
 #[cfg(feature = "metal")]
 /// Initial, sentinel value is usize::MAX
@@ -321,69 +321,64 @@ impl Sdpa {
         // TODO: bench?
         #[allow(unused)]
         if let (Device::Cuda(_), Some(cublaslt)) = (q.device(), *CUBLASLT_HANDLE.lock().unwrap()) {
-            if !get_use_matmul_via_f16() {
-                #[cfg(feature = "cuda")]
-                {
-                    // cuBLASLt batch matmul implementation requires inputs to be dims3
-                    let k = k.flatten(0, 1)?;
-                    let q = q.flatten(0, 1)?;
-                    let v = v.flatten(0, 1)?;
-                    let attention_bias = match mask {
-                        Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
-                            Some(mask.repeat((n_attn_heads, 1, 1))?)
-                        }
-                        Some(mask) if mask.rank() == 3 => Some(mask.clone()),
-                        Some(mask) if mask.rank() == 4 => Some(mask.flatten(0, 1)?),
-                        Some(mask) => {
-                            candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
-                        }
-                        None => None,
-                    };
-
-                    // If attention_bias is set, we fuse the add by giving it as the output matrix
-                    // and setting beta to 1.0
-                    let beta = match attention_bias.is_some() {
-                        true => Some(1.0),
-                        false => None,
-                    };
-
-                    // Batch matrix multiplication
-                    // Fuse softmax scale and attention_bias add
-                    let mut attention_scores = cublaslt.batch_matmul(
-                        &k,
-                        &q,
-                        attention_bias.as_ref(),
-                        Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
-                        beta,
-                        None,
-                        None,
-                    )?;
-                    if let Some(softcap) = sdpa_params.softcap {
-                        attention_scores = (attention_scores.tanh()? * softcap as f64)?;
+            #[cfg(feature = "cuda")]
+            {
+                // cuBLASLt batch matmul implementation requires inputs to be dims3
+                let k = k.flatten(0, 1)?;
+                let q = q.flatten(0, 1)?;
+                let v = v.flatten(0, 1)?;
+                let attention_bias = match mask {
+                    Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
+                        Some(mask.repeat((n_attn_heads, 1, 1))?)
                     }
-                    candle_nn::ops::inplace_softmax_last_dim(&mut attention_scores)?;
-
-                    let context_layer = cublaslt.batch_matmul(
-                        &v.t()?.contiguous().unwrap(),
-                        &attention_scores,
-                        // We save one allocation
-                        Some(&q),
-                        None,
-                        None,
-                        None,
-                        None,
-                    )?;
-
-                    // Reshape to dims4
-                    context_layer.reshape((b_sz, n_attn_heads, seq_len, v_head_dim))
-                }
-                #[cfg(not(feature = "cuda"))]
-                {
-                    candle_core::bail!("`cuda` feature is not enabled")
+                    Some(mask) if mask.rank() == 3 => Some(mask.clone()),
+                    Some(mask) if mask.rank() == 4 => Some(mask.flatten(0, 1)?),
+                    Some(mask) => {
+                        candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
+                    }
+                    None => None,
+                };
+
+                // If attention_bias is set, we fuse the add by giving it as the output matrix
+                // and setting beta to 1.0
+                let beta = match attention_bias.is_some() {
+                    true => Some(1.0),
+                    false => None,
+                };
+
+                // Batch matrix multiplication
+                // Fuse softmax scale and attention_bias add
+                let mut attention_scores = cublaslt.batch_matmul(
+                    &k,
+                    &q,
+                    attention_bias.as_ref(),
+                    Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
+                    beta,
+                    None,
+                    None,
+                )?;
+                if let Some(softcap) = sdpa_params.softcap {
+                    attention_scores = (attention_scores.tanh()? * softcap as f64)?;
                 }
-            } else {
-                // Use the f16 kernels here if quantized (ISQ or GGML), and a large enough prompt
-                naive_sdpa(q, &k, &v, mask, sdpa_params)
+                candle_nn::ops::inplace_softmax_last_dim(&mut attention_scores)?;
+
+                let context_layer = cublaslt.batch_matmul(
+                    &v.t()?.contiguous().unwrap(),
+                    &attention_scores,
+                    // We save one allocation
+                    Some(&q),
+                    None,
+                    None,
+                    None,
+                    None,
+                )?;
+
+                // Reshape to dims4
+                context_layer.reshape((b_sz, n_attn_heads, seq_len, v_head_dim))
+            }
+            #[cfg(not(feature = "cuda"))]
+            {
+                candle_core::bail!("`cuda` feature is not enabled")
             }
         } else {
             naive_sdpa(q, &k, &v, mask, sdpa_params)
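For reference, the cuBLASLt path above fuses the softmax scale (pre-divided by the softcap when one is set) and the additive attention bias into the first batched matmul, applies tanh softcapping, softmaxes, and then multiplies by V. Below is a minimal standalone candle sketch of that same ordering on plain tensors — a hypothetical `sdpa_reference` helper, not the crate's `naive_sdpa` — assuming `[batch*heads, seq, head_dim]` inputs and an additive rank-3 mask.

```rust
use candle_core::{Result, Tensor};

/// Hypothetical reference (not part of mistral.rs): mirrors the
/// scale / softcap / bias / softmax ordering of the fused cuBLASLt path above.
fn sdpa_reference(
    q: &Tensor,            // [b*h, seq, head_dim]
    k: &Tensor,            // [b*h, seq, head_dim]
    v: &Tensor,            // [b*h, seq, v_head_dim]
    mask: Option<&Tensor>, // additive bias, [b*h, seq, seq]
    softmax_scale: f64,
    softcap: Option<f64>,
) -> Result<Tensor> {
    // Q K^T with the scale folded in; with softcapping, the scale is divided by
    // the cap first, matching `softmax_scale / softcap.unwrap_or(1.0)` above.
    let mut scores = (q.matmul(&k.t()?)? * (softmax_scale / softcap.unwrap_or(1.0)))?;
    if let Some(cap) = softcap {
        // Softcapping: cap * tanh(scores / cap); the division already happened above.
        scores = (scores.tanh()? * cap)?;
    }
    if let Some(mask) = mask {
        // The cuBLASLt path fuses this add via beta = 1.0; here it is explicit.
        scores = scores.broadcast_add(mask)?;
    }
    let probs = candle_nn::ops::softmax_last_dim(&scores)?;
    probs.matmul(v) // [b*h, seq, v_head_dim]
}
```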
13 changes: 4 additions & 9 deletions mistralrs-core/src/layers.rs
@@ -12,8 +12,7 @@ use candle_nn::{
 use float8::F8E4M3;
 use half::{bf16, f16};
 use mistralrs_quant::{
-    get_use_matmul_via_f16, ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer,
-    ShardedVarBuilder,
+    ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer, ShardedVarBuilder,
 };
 use serde::{Deserialize, Serialize};
 
@@ -1300,17 +1299,13 @@ impl Module for QLinear {
         } else {
             xs.clone()
         };
-        let forward_fn = if !get_use_matmul_via_f16() {
-            QMatMul::forward
-        } else {
-            QMatMul::forward_via_f16
-        };
         if let Some(bias) = &self.bias {
-            forward_fn(&self.inner, &xs)?
+            self.inner
+                .forward(&xs)?
                 .broadcast_add(bias)?
                 .to_dtype(self.dtype)
         } else {
-            forward_fn(&self.inner, &xs)?.to_dtype(self.dtype)
+            self.inner.forward(&xs)?.to_dtype(self.dtype)
         }
     }
 }
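With the f16 toggle gone, `QLinear` always goes through `QMatMul::forward`. A compact sketch of the same pattern — a hypothetical free function, not the crate's impl — assuming the quantized matmul runs in its own activation dtype and the result is cast back to the layer's dtype:

```rust
use candle_core::{quantized::QMatMul, DType, Result, Tensor};

/// Hypothetical helper mirroring the QLinear::forward path above.
fn qlinear_forward(
    inner: &QMatMul,
    bias: Option<&Tensor>,
    xs: &Tensor,
    out_dtype: DType,
) -> Result<Tensor> {
    let ys = inner.forward(xs)?;
    let ys = match bias {
        Some(b) => ys.broadcast_add(b)?,
        None => ys,
    };
    // Cast back to the layer's working dtype (e.g. bf16) after the quantized matmul.
    ys.to_dtype(out_dtype)
}
```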
5 changes: 2 additions & 3 deletions mistralrs-core/src/layers_masker.rs
@@ -3,7 +3,6 @@
 use std::ops::Add;
 
 use candle_core::{DType, Device, Result, Tensor, WithDType};
-use mistralrs_quant::get_use_matmul_via_f16;
 
 use crate::pipeline::KvCache;
 
@@ -230,7 +229,7 @@ impl CausalMasker {
         };
 
         // IMPORTANT: this must match the logic in attention.rs. Assume the cublaslt handle will be initialized
-        if causal_mask.device().is_cuda() && !get_use_matmul_via_f16() {
+        if causal_mask.device().is_cuda() {
             causal_mask = causal_mask.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))?;
         }
 
@@ -290,7 +289,7 @@ impl CausalMasker {
         };
 
         // IMPORTANT: this must match the logic in attention.rs. Assume the cublaslt handle will be initialized
-        if causal_mask.device().is_cuda() && !get_use_matmul_via_f16() {
+        if causal_mask.device().is_cuda() {
             causal_mask = causal_mask.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))?;
         }
 
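The surviving CUDA branch exists to satisfy the cuBLASLt path in attention.rs, which expects a rank-3 additive bias with the head dimension already expanded. A one-line sketch of that expansion — a hypothetical helper, assuming a `[seq, seq]` causal mask and `n_attn_heads` heads:

```rust
use candle_core::{Result, Tensor};

/// Hypothetical helper: expand a [seq, seq] causal mask to [n_heads, seq, seq]
/// so it can be passed directly as the cuBLASLt attention bias.
fn expand_mask_for_cublaslt(causal_mask: &Tensor, n_attn_heads: usize) -> Result<Tensor> {
    causal_mask.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))
}
```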
11 changes: 0 additions & 11 deletions mistralrs-core/src/pipeline/inputs_processor.rs
@@ -57,7 +57,6 @@ pub mod text_models_inputs_processor {
 
     use anyhow::Result;
     use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
-    use mistralrs_quant::set_use_matmul_via_f16;
     use tokenizers::Tokenizer;
 
     use crate::{
@@ -68,8 +67,6 @@ pub mod text_models_inputs_processor {
 
     use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};
 
-    const VIA_F16_TOK_THRESHOLD: usize = 512;
-
     fn _make_tensor_with_pad<D: WithDType>(
         x: Vec<Vec<D>>,
         max_len: usize,
@@ -267,12 +264,6 @@
             }
 
             let input = Tensor::cat(&seqs_tensors, 0).unwrap();
-            // Only use matmul via f16 if prompt and seqlen > 512
-            if input.dim(1)? > VIA_F16_TOK_THRESHOLD {
-                set_use_matmul_via_f16(true);
-            } else {
-                set_use_matmul_via_f16(false);
-            }
 
             let paged_attn_meta = if paged_attn_metadata.is_some() {
                 let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
@@ -445,8 +436,6 @@ pub mod text_models_inputs_processor {
                 seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
             }
 
-            set_use_matmul_via_f16(false);
-
             let paged_attn_meta = if paged_attn_metadata.is_some() {
                 let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;
 
9 changes: 0 additions & 9 deletions mistralrs-quant/src/gguf/mod.rs
@@ -57,15 +57,6 @@ impl QuantMethod for GgufMatMul {
         }
     }
 
-    fn forward_via_half(&self, a: &Tensor) -> Result<Tensor> {
-        let x = self.w.forward_via_f16(a)?;
-        if let Some(ref b) = self.b {
-            x.broadcast_add(b)
-        } else {
-            Ok(x)
-        }
-    }
-
     fn quantized_act_type(&self) -> Option<DType> {
         Some(DType::F32)
     }
61 changes: 12 additions & 49 deletions mistralrs-quant/src/lib.rs
@@ -2,10 +2,7 @@ use std::{
     borrow::Cow,
     fmt::{Debug, Display},
    num::NonZeroUsize,
-    sync::{
-        atomic::{AtomicBool, AtomicUsize, Ordering},
-        Arc,
-    },
+    sync::{atomic::AtomicUsize, Arc},
 };
 
 use blockwise_fp8::blockwise_fp8_linear_b;
@@ -160,26 +157,11 @@
 }
 
 /// Device/configurable intelligent matrix multiplication
-/// - Configurable to be via f16 (to use the faster GEMM kernels) optionally.
 /// - Handles limitation of `accelerate` which requires f32
 pub struct MatMul;
 
-pub static INHIBIT_GEMM_F16: AtomicBool = AtomicBool::new(false);
-
-/// Set the matmuls to go via f16
-pub(crate) static USE_MATMUL_VIA_F16: AtomicBool = AtomicBool::new(false);
-
-pub fn set_use_matmul_via_f16(via_f16: bool) {
-    if !INHIBIT_GEMM_F16.load(Ordering::Relaxed) {
-        USE_MATMUL_VIA_F16.store(via_f16, Ordering::Relaxed)
-    }
-}
-pub fn get_use_matmul_via_f16() -> bool {
-    USE_MATMUL_VIA_F16.load(Ordering::Relaxed)
-}
-
 impl MatMul {
-    /// Compute matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
+    /// Compute matrix-matrix product.
     pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
         #[cfg(feature = "accelerate")]
         {
@@ -192,50 +174,37 @@
         {
             if a.device().is_cpu() {
                 let original_dtype = a.dtype();
-                return a
-                    .to_dtype(DType::F16)?
+                a.to_dtype(DType::F16)?
                     .matmul(&b.to_dtype(DType::F16)?)?
-                    .to_dtype(original_dtype);
-            } else if !get_use_matmul_via_f16() {
-                return a.matmul(b);
+                    .to_dtype(original_dtype)
+            } else {
+                a.matmul(b)
             }
-            let original_dtype = a.dtype();
-            a.to_dtype(DType::F16)?
-                .matmul(&b.to_dtype(DType::F16)?)?
-                .to_dtype(original_dtype)
         }
     }
 
-    /// Compute matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
+    /// Compute matrix-matrix product.
     /// The result will be divided by the `scale` parameter in an affine division.
     pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
         // TODO(EricLBuehler): Optimize this by using the gemm parameter?
         self.matmul(a, b)? / scale
     }
 
-    /// Compute matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
+    /// Compute matrix-matrix product.
     /// The result will be divided by the `scale` parameter in an affine multiplication.
     pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
         // TODO(EricLBuehler): Optimize this by using the gemm parameter?
         self.matmul(a, b)? * scale
     }
 
-    /// Compute quantized matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
+    /// Compute quantized matrix-matrix product.
     pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
-        if get_use_matmul_via_f16() {
-            matmul.forward_via_f16(x)
-        } else {
-            matmul.forward(x)
-        }
+        matmul.forward(x)
     }
 
-    /// Compute quantized matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
+    /// Compute quantized matrix-matrix product.
     pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
-        if get_use_matmul_via_f16() {
-            matmul.forward_via_half(x)
-        } else {
-            matmul.forward(x)
-        }
+        matmul.forward(x)
     }
 }
 
@@ -438,12 +407,6 @@ pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
     /// Compute matmul of `self` and `a`. `self` should contain the weights.
     fn forward(&self, a: &Tensor) -> Result<Tensor>;
 
-    /// Compute matmul of `self` and `a`. `self` should contain the weights.
-    /// This may go via half precision if it is supported.
-    fn forward_via_half(&self, a: &Tensor) -> Result<Tensor> {
-        self.forward(a)
-    }
-
     /// If a quantized method, return the activation dtype.
     fn quantized_act_type(&self) -> Option<DType>;
 
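After this change, `MatMul` carries no global state to flip; callers just invoke it and any device- or feature-specific casting happens internally. A small usage sketch (shapes and values are illustrative only):

```rust
use candle_core::{Device, Result, Tensor};
use mistralrs_quant::MatMul;

fn demo() -> Result<()> {
    // Two batched operands: [batch, m, k] x [batch, k, n].
    let a = Tensor::randn(0f32, 1f32, (2, 16, 64), &Device::Cpu)?;
    let b = Tensor::randn(0f32, 1f32, (2, 64, 16), &Device::Cpu)?;

    // Plain product; CPU inputs may be routed through f16 internally.
    let c = MatMul.matmul(&a, &b)?;

    // Product with the result divided by a scale (e.g. sqrt(head_dim)).
    let scaled = MatMul.matmul_affine_div(&a, &b, 8.0)?;

    assert_eq!(c.dims(), scaled.dims());
    Ok(())
}
```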