diff --git a/mistralrs-core/src/attention.rs b/mistralrs-core/src/attention.rs
index 66692e39e0..057636b0c0 100644
--- a/mistralrs-core/src/attention.rs
+++ b/mistralrs-core/src/attention.rs
@@ -8,7 +8,7 @@ use crate::{
 };
 
 use candle_core::{Device, Result, Tensor};
-use mistralrs_quant::{get_use_matmul_via_f16, MatMul};
+use mistralrs_quant::MatMul;
 
 #[cfg(feature = "metal")]
 /// Initial, sentinel value is usize::MAX
@@ -321,69 +321,64 @@ impl Sdpa {
         // TODO: bench?
         #[allow(unused)]
         if let (Device::Cuda(_), Some(cublaslt)) = (q.device(), *CUBLASLT_HANDLE.lock().unwrap()) {
-            if !get_use_matmul_via_f16() {
-                #[cfg(feature = "cuda")]
-                {
-                    // cuBLASLt batch matmul implementation requires inputs to be dims3
-                    let k = k.flatten(0, 1)?;
-                    let q = q.flatten(0, 1)?;
-                    let v = v.flatten(0, 1)?;
-                    let attention_bias = match mask {
-                        Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
-                            Some(mask.repeat((n_attn_heads, 1, 1))?)
-                        }
-                        Some(mask) if mask.rank() == 3 => Some(mask.clone()),
-                        Some(mask) if mask.rank() == 4 => Some(mask.flatten(0, 1)?),
-                        Some(mask) => {
-                            candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
-                        }
-                        None => None,
-                    };
-
-                    // If attention_bias is set, we fuse the add by giving it as the output matrix
-                    // and setting beta to 1.0
-                    let beta = match attention_bias.is_some() {
-                        true => Some(1.0),
-                        false => None,
-                    };
-
-                    // Batch matrix multiplication
-                    // Fuse softmax scale and attention_bias add
-                    let mut attention_scores = cublaslt.batch_matmul(
-                        &k,
-                        &q,
-                        attention_bias.as_ref(),
-                        Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
-                        beta,
-                        None,
-                        None,
-                    )?;
-                    if let Some(softcap) = sdpa_params.softcap {
-                        attention_scores = (attention_scores.tanh()? * softcap as f64)?;
-                    }
-                    candle_nn::ops::inplace_softmax_last_dim(&mut attention_scores)?;
-
-                    let context_layer = cublaslt.batch_matmul(
-                        &v.t()?.contiguous().unwrap(),
-                        &attention_scores,
-                        // We save one allocation
-                        Some(&q),
-                        None,
-                        None,
-                        None,
-                        None,
-                    )?;
-
-                    // Reshape to dims4
-                    context_layer.reshape((b_sz, n_attn_heads, seq_len, v_head_dim))
-                }
-                #[cfg(not(feature = "cuda"))]
-                {
-                    candle_core::bail!("`cuda` feature is not enabled")
-                }
-            } else {
-                // Use the f16 kernels here if quantized (ISQ or GGML), and a large enough prompt
-                naive_sdpa(q, &k, &v, mask, sdpa_params)
-            }
+            #[cfg(feature = "cuda")]
+            {
+                // cuBLASLt batch matmul implementation requires inputs to be dims3
+                let k = k.flatten(0, 1)?;
+                let q = q.flatten(0, 1)?;
+                let v = v.flatten(0, 1)?;
+                let attention_bias = match mask {
+                    Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
+                        Some(mask.repeat((n_attn_heads, 1, 1))?)
+                    }
+                    Some(mask) if mask.rank() == 3 => Some(mask.clone()),
+                    Some(mask) if mask.rank() == 4 => Some(mask.flatten(0, 1)?),
+                    Some(mask) => {
+                        candle_core::bail!("cublaslt attn mask: rank must be 3 or 4")
+                    }
+                    None => None,
+                };
+
+                // If attention_bias is set, we fuse the add by giving it as the output matrix
+                // and setting beta to 1.0
+                let beta = match attention_bias.is_some() {
+                    true => Some(1.0),
+                    false => None,
+                };
+
+                // Batch matrix multiplication
+                // Fuse softmax scale and attention_bias add
+                let mut attention_scores = cublaslt.batch_matmul(
+                    &k,
+                    &q,
+                    attention_bias.as_ref(),
+                    Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
+                    beta,
+                    None,
+                    None,
+                )?;
+                if let Some(softcap) = sdpa_params.softcap {
+                    attention_scores = (attention_scores.tanh()? * softcap as f64)?;
+                }
+                candle_nn::ops::inplace_softmax_last_dim(&mut attention_scores)?;
+
+                let context_layer = cublaslt.batch_matmul(
+                    &v.t()?.contiguous().unwrap(),
+                    &attention_scores,
+                    // We save one allocation
+                    Some(&q),
+                    None,
+                    None,
+                    None,
+                    None,
+                )?;
+
+                // Reshape to dims4
+                context_layer.reshape((b_sz, n_attn_heads, seq_len, v_head_dim))
+            }
+            #[cfg(not(feature = "cuda"))]
+            {
+                candle_core::bail!("`cuda` feature is not enabled")
+            }
         } else {
             naive_sdpa(q, &k, &v, mask, sdpa_params)
diff --git a/mistralrs-core/src/layers.rs b/mistralrs-core/src/layers.rs
index 298af00c1f..15fe1bbde2 100644
--- a/mistralrs-core/src/layers.rs
+++ b/mistralrs-core/src/layers.rs
@@ -12,8 +12,7 @@ use candle_nn::{
 use float8::F8E4M3;
 use half::{bf16, f16};
 use mistralrs_quant::{
-    get_use_matmul_via_f16, ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer,
-    ShardedVarBuilder,
+    ColumnParallelLayer, QuantMethod, QuantizedConfig, RowParallelLayer, ShardedVarBuilder,
 };
 use serde::{Deserialize, Serialize};
 
@@ -1300,17 +1299,13 @@ impl Module for QLinear {
         } else {
             xs.clone()
         };
-        let forward_fn = if !get_use_matmul_via_f16() {
-            QMatMul::forward
-        } else {
-            QMatMul::forward_via_f16
-        };
         if let Some(bias) = &self.bias {
-            forward_fn(&self.inner, &xs)?
+            self.inner
+                .forward(&xs)?
                 .broadcast_add(bias)?
                 .to_dtype(self.dtype)
         } else {
-            forward_fn(&self.inner, &xs)?.to_dtype(self.dtype)
+            self.inner.forward(&xs)?.to_dtype(self.dtype)
         }
     }
 }
diff --git a/mistralrs-core/src/layers_masker.rs b/mistralrs-core/src/layers_masker.rs
index 5598517d17..d292b290ce 100644
--- a/mistralrs-core/src/layers_masker.rs
+++ b/mistralrs-core/src/layers_masker.rs
@@ -3,7 +3,6 @@ use std::ops::Add;
 
 use candle_core::{DType, Device, Result, Tensor, WithDType};
-use mistralrs_quant::get_use_matmul_via_f16;
 
 use crate::pipeline::KvCache;
 
@@ -230,7 +229,7 @@ impl CausalMasker {
         };
 
         // IMPORTANT: this must match the logic in attention.rs. Assume the cublaslt handle will be initialized
-        if causal_mask.device().is_cuda() && !get_use_matmul_via_f16() {
+        if causal_mask.device().is_cuda() {
             causal_mask = causal_mask.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))?;
         }
 
@@ -290,7 +289,7 @@ impl CausalMasker {
         };
 
         // IMPORTANT: this must match the logic in attention.rs. Assume the cublaslt handle will be initialized
-        if causal_mask.device().is_cuda() && !get_use_matmul_via_f16() {
+        if causal_mask.device().is_cuda() {
             causal_mask = causal_mask.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))?;
         }
 
diff --git a/mistralrs-core/src/pipeline/inputs_processor.rs b/mistralrs-core/src/pipeline/inputs_processor.rs
index cfd3eae961..8487050625 100644
--- a/mistralrs-core/src/pipeline/inputs_processor.rs
+++ b/mistralrs-core/src/pipeline/inputs_processor.rs
@@ -57,7 +57,6 @@ pub mod text_models_inputs_processor {
 
     use anyhow::Result;
     use candle_core::{DType, Device, DeviceLocation, Tensor, WithDType};
-    use mistralrs_quant::set_use_matmul_via_f16;
    use tokenizers::Tokenizer;
 
     use crate::{
@@ -68,8 +67,6 @@ pub mod text_models_inputs_processor {
 
     use super::{InputProcessorOutput, InputsProcessor, InputsProcessorType};
 
-    const VIA_F16_TOK_THRESHOLD: usize = 512;
-
     fn _make_tensor_with_pad<D: WithDType>(
         x: Vec<Vec<D>>,
         max_len: usize,
@@ -267,12 +264,6 @@ pub mod text_models_inputs_processor {
         }
 
         let input = Tensor::cat(&seqs_tensors, 0).unwrap();
-        // Only use matmul via f16 if prompt and seqlen > 512
-        if input.dim(1)? > VIA_F16_TOK_THRESHOLD {
-            set_use_matmul_via_f16(true);
-        } else {
-            set_use_matmul_via_f16(false);
-        }
 
         let paged_attn_meta = if paged_attn_metadata.is_some() {
             let max_slot_mapping_len = slot_mappings.iter().map(|x| x.len()).max().unwrap();
@@ -445,8 +436,6 @@ pub mod text_models_inputs_processor {
             seqlens_k_map.insert(device.location(), seqlens_k.to_device(&device)?);
         }
 
-        set_use_matmul_via_f16(false);
-
         let paged_attn_meta = if paged_attn_metadata.is_some() {
             let slot_mappings = _make_tensor_with_pad(slot_mappings, 1, _PAD_SLOT_ID, device)?;
 
diff --git a/mistralrs-quant/src/gguf/mod.rs b/mistralrs-quant/src/gguf/mod.rs
index 1842126820..81c834879b 100644
--- a/mistralrs-quant/src/gguf/mod.rs
+++ b/mistralrs-quant/src/gguf/mod.rs
@@ -57,15 +57,6 @@ impl QuantMethod for GgufMatMul {
         }
     }
 
-    fn forward_via_half(&self, a: &Tensor) -> Result<Tensor> {
-        let x = self.w.forward_via_f16(a)?;
-        if let Some(ref b) = self.b {
-            x.broadcast_add(b)
-        } else {
-            Ok(x)
-        }
-    }
-
     fn quantized_act_type(&self) -> Option<DType> {
         Some(DType::F32)
     }
diff --git a/mistralrs-quant/src/lib.rs b/mistralrs-quant/src/lib.rs
index 0197f4a7d5..e09488d3f7 100644
--- a/mistralrs-quant/src/lib.rs
+++ b/mistralrs-quant/src/lib.rs
@@ -2,10 +2,7 @@ use std::{
     borrow::Cow,
     fmt::{Debug, Display},
     num::NonZeroUsize,
-    sync::{
-        atomic::{AtomicBool, AtomicUsize, Ordering},
-        Arc,
-    },
+    sync::{atomic::AtomicUsize, Arc},
 };
 
 use blockwise_fp8::blockwise_fp8_linear_b;
@@ -160,26 +157,11 @@ pub enum QuantMethodConfig {
 }
 
 /// Device/configurable intelligent matrix multiplication
-/// - Configurable to be via f16 (to use the faster GEMM kernels) optionally.
 /// - Handles limitation of `accelerate` which requires f32
 pub struct MatMul;
 
-pub static INHIBIT_GEMM_F16: AtomicBool = AtomicBool::new(false);
-
-/// Set the matmuls to go via f16
-pub(crate) static USE_MATMUL_VIA_F16: AtomicBool = AtomicBool::new(false);
-
-pub fn set_use_matmul_via_f16(via_f16: bool) {
-    if !INHIBIT_GEMM_F16.load(Ordering::Relaxed) {
-        USE_MATMUL_VIA_F16.store(via_f16, Ordering::Relaxed)
-    }
-}
-pub fn get_use_matmul_via_f16() -> bool {
-    USE_MATMUL_VIA_F16.load(Ordering::Relaxed)
-}
-
 impl MatMul {
-    /// Compute matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
+    /// Compute matrix-matrix product.
     pub fn matmul(&self, a: &Tensor, b: &Tensor) -> Result<Tensor> {
         #[cfg(feature = "accelerate")]
         {
@@ -192,50 +174,37 @@ impl MatMul {
         {
             if a.device().is_cpu() {
                 let original_dtype = a.dtype();
-                return a
-                    .to_dtype(DType::F16)?
+                a.to_dtype(DType::F16)?
                     .matmul(&b.to_dtype(DType::F16)?)?
-                    .to_dtype(original_dtype);
-            } else if !get_use_matmul_via_f16() {
-                return a.matmul(b);
+                    .to_dtype(original_dtype)
+            } else {
+                a.matmul(b)
             }
-            let original_dtype = a.dtype();
-            a.to_dtype(DType::F16)?
-                .matmul(&b.to_dtype(DType::F16)?)?
-                .to_dtype(original_dtype)
         }
     }
 
-    /// Compute matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
+    /// Compute matrix-matrix product.
     /// The result will be divided by the `scale` parameter in an affine division.
     pub fn matmul_affine_div(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
         // TODO(EricLBuehler): Optimize this by using the gemm parameter?
         self.matmul(a, b)? / scale
     }
 
-    /// Compute matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
+    /// Compute matrix-matrix product.
     /// The result will be divided by the `scale` parameter in an affine multiplication.
     pub fn matmul_affine_mul(&self, a: &Tensor, b: &Tensor, scale: f64) -> Result<Tensor> {
         // TODO(EricLBuehler): Optimize this by using the gemm parameter?
         self.matmul(a, b)? * scale
     }
 
-    /// Compute quantized matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
+    /// Compute quantized matrix-matrix product.
     pub fn qmatmul(&self, x: &Tensor, matmul: &QMatMul) -> Result<Tensor> {
-        if get_use_matmul_via_f16() {
-            matmul.forward_via_f16(x)
-        } else {
-            matmul.forward(x)
-        }
+        matmul.forward(x)
     }
 
-    /// Compute quantized matrix-matrix product, optionally casting to f16 to use specialized GEMM kernels.
+    /// Compute quantized matrix-matrix product.
     pub fn qmethod_matmul(&self, x: &Tensor, matmul: &dyn QuantMethod) -> Result<Tensor> {
-        if get_use_matmul_via_f16() {
-            matmul.forward_via_half(x)
-        } else {
-            matmul.forward(x)
-        }
+        matmul.forward(x)
     }
 }
@@ -438,12 +407,6 @@ pub trait QuantMethod: Send + Sync + Debug + QuantizedSerde {
     /// Compute matmul of `self` and `a`. `self` should contain the weights.
     fn forward(&self, a: &Tensor) -> Result<Tensor>;
 
-    /// Compute matmul of `self` and `a`. `self` should contain the weights.
-    /// This may go via half precision if it is supported.
-    fn forward_via_half(&self, a: &Tensor) -> Result<Tensor> {
-        self.forward(a)
-    }
-
     /// If a quantized method, return the activation dtype.
     fn quantized_act_type(&self) -> Option<DType>;