diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index 36ba1025f7..66d5c67c8a 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -187,7 +187,6 @@ pub struct MistralRsBuilder { no_prefix_cache: Option<bool>, prefix_cache_n: Option<usize>, disable_eos_stop: Option<bool>, - gemm_full_precision_f16: Option<bool>, throughput_logging_enabled: Option<()>, } @@ -202,7 +201,6 @@ impl MistralRsBuilder { no_prefix_cache: None, prefix_cache_n: None, disable_eos_stop: None, - gemm_full_precision_f16: None, throughput_logging_enabled: None, } } @@ -234,11 +232,6 @@ impl MistralRsBuilder { self.disable_eos_stop = Some(disable_eos_stop); self } - /// This setting is only applicable on CUDA. If set to false or not specified, this setting enables f16/bf16 reduced precision matmul for GPUs which support it. If set to true, this setting has no effect. - pub fn with_gemm_full_precision_f16(mut self, gemm_full_precision: bool) -> Self { - self.gemm_full_precision_f16 = Some(gemm_full_precision); - self - } pub fn with_throughput_logging(mut self) -> Self { self.throughput_logging_enabled = Some(()); self } @@ -249,42 +242,6 @@ impl MistralRsBuilder { } } -#[cfg(feature = "cuda")] -fn set_gemm_reduced_precision_f16(device: candle_core::Device) { - use mistralrs_quant::INHIBIT_GEMM_F16; - - use candle_core::{DType, Tensor}; - - let a = Tensor::zeros((2, 2), DType::BF16, &device).unwrap(); - candle_core::cuda::set_gemm_reduced_precision_bf16(true); - match a.matmul(&a) { - Ok(_) => tracing::info!("Enabling GEMM reduced precision in BF16."), - Err(e) => { - if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") { - tracing::info!("GEMM reduced precision in BF16 not supported."); - candle_core::cuda::set_gemm_reduced_precision_bf16(false); - INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed); - } - } - } - - let a = Tensor::zeros((2, 2), DType::F16, &device).unwrap(); - candle_core::cuda::set_gemm_reduced_precision_f16(true); - match a.matmul(&a) { - Ok(_) 
=> tracing::info!("Enabling GEMM reduced precision in F16."), - Err(e) => { - if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") { - tracing::info!("GEMM reduced precision in F16 not supported."); - candle_core::cuda::set_gemm_reduced_precision_f16(false); - INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed); - } - } - } -} - -#[cfg(not(feature = "cuda"))] -fn set_gemm_reduced_precision_f16(_device: candle_core::Device) {} - impl Drop for MistralRs { fn drop(&mut self) { ENGINE_INSTRUCTIONS @@ -305,19 +262,10 @@ impl MistralRs { no_prefix_cache, prefix_cache_n, disable_eos_stop, - gemm_full_precision_f16, throughput_logging_enabled, } = config; let category = pipeline.try_lock().unwrap().category(); - let model_supports_reduced_gemm = match category { - ModelCategory::Text => true, - ModelCategory::Vision { has_conv2d, .. } => !has_conv2d, - ModelCategory::Diffusion => true, - }; - if !gemm_full_precision_f16.unwrap_or(false) && model_supports_reduced_gemm { - set_gemm_reduced_precision_f16(get_mut_arcmutex!(pipeline).device()); - } setup_cublas_lt_wrapper(get_mut_arcmutex!(pipeline).device()); let truncate_sequence = truncate_sequence.unwrap_or(false); diff --git a/mistralrs-server/src/main.rs b/mistralrs-server/src/main.rs index 3cf375e8e6..ea2391cb6a 100644 --- a/mistralrs-server/src/main.rs +++ b/mistralrs-server/src/main.rs @@ -478,9 +478,7 @@ async fn main() -> Result<()> { .with_opt_log(args.log) .with_truncate_sequence(args.truncate_sequence) .with_no_kv_cache(args.no_kv_cache) - .with_prefix_cache_n(args.prefix_cache_n) - .with_gemm_full_precision_f16(args.cpu) - .with_gemm_full_precision_f16(args.cpu); // Required to allow `cuda` build to use `--cpu`, #1056 + .with_prefix_cache_n(args.prefix_cache_n); if args.interactive_mode { interactive_mode(builder.build(), args.throughput_log).await; diff --git a/mistralrs/src/anymoe.rs b/mistralrs/src/anymoe.rs index 0660f162fb..d2330559a9 100644 --- a/mistralrs/src/anymoe.rs +++ 
b/mistralrs/src/anymoe.rs @@ -112,7 +112,6 @@ impl AnyMoeModelBuilder { let mut runner = MistralRsBuilder::new(pipeline, scheduler_method) .with_no_kv_cache(self.base.no_kv_cache) - .with_gemm_full_precision_f16(true) .with_no_prefix_cache(self.base.prefix_cache_n.is_none()); if let Some(n) = self.base.prefix_cache_n { diff --git a/mistralrs/src/diffusion_model.rs b/mistralrs/src/diffusion_model.rs index e16cdad4c8..de9379b46a 100644 --- a/mistralrs/src/diffusion_model.rs +++ b/mistralrs/src/diffusion_model.rs @@ -102,8 +102,7 @@ impl DiffusionModelBuilder { method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?), }; - let runner = - MistralRsBuilder::new(pipeline, scheduler_method).with_gemm_full_precision_f16(true); + let runner = MistralRsBuilder::new(pipeline, scheduler_method); Ok(Model::new(runner.build())) } diff --git a/mistralrs/src/gguf.rs b/mistralrs/src/gguf.rs index c81bfec5af..ff5a4c1dc9 100644 --- a/mistralrs/src/gguf.rs +++ b/mistralrs/src/gguf.rs @@ -206,7 +206,6 @@ impl GgufModelBuilder { let mut runner = MistralRsBuilder::new(pipeline, scheduler_method) .with_no_kv_cache(self.no_kv_cache) - .with_gemm_full_precision_f16(true) .with_no_prefix_cache(self.prefix_cache_n.is_none()); if let Some(n) = self.prefix_cache_n { diff --git a/mistralrs/src/gguf_lora_model.rs b/mistralrs/src/gguf_lora_model.rs index c4caecd84f..25ffe0c311 100644 --- a/mistralrs/src/gguf_lora_model.rs +++ b/mistralrs/src/gguf_lora_model.rs @@ -79,7 +79,6 @@ impl GgufLoraModelBuilder { let mut runner = MistralRsBuilder::new(pipeline, scheduler_method) .with_no_kv_cache(self.gguf_model.no_kv_cache) - .with_gemm_full_precision_f16(true) .with_no_prefix_cache(self.gguf_model.prefix_cache_n.is_none()); if let Some(n) = self.gguf_model.prefix_cache_n { diff --git a/mistralrs/src/gguf_xlora_model.rs b/mistralrs/src/gguf_xlora_model.rs index d29662e18c..5ca80d6858 100644 --- a/mistralrs/src/gguf_xlora_model.rs +++ b/mistralrs/src/gguf_xlora_model.rs @@ -91,7 +91,6 @@ impl 
GgufXLoraModelBuilder { let mut runner = MistralRsBuilder::new(pipeline, scheduler_method) .with_no_kv_cache(self.gguf_model.no_kv_cache) - .with_gemm_full_precision_f16(true) .with_no_prefix_cache(self.gguf_model.prefix_cache_n.is_none()); if let Some(n) = self.gguf_model.prefix_cache_n { diff --git a/mistralrs/src/lora_model.rs b/mistralrs/src/lora_model.rs index aac1a21751..9bef37993a 100644 --- a/mistralrs/src/lora_model.rs +++ b/mistralrs/src/lora_model.rs @@ -85,7 +85,6 @@ impl LoraModelBuilder { let mut runner = MistralRsBuilder::new(pipeline, scheduler_method) .with_no_kv_cache(self.text_model.no_kv_cache) - .with_gemm_full_precision_f16(true) .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none()); if let Some(n) = self.text_model.prefix_cache_n { diff --git a/mistralrs/src/speculative.rs b/mistralrs/src/speculative.rs index ff51404a24..78eb5b0c97 100644 --- a/mistralrs/src/speculative.rs +++ b/mistralrs/src/speculative.rs @@ -92,8 +92,7 @@ impl TextSpeculativeBuilder { self.speculative_config, )?)); - let runner = - MistralRsBuilder::new(pipeline, scheduler_method).with_gemm_full_precision_f16(true); + let runner = MistralRsBuilder::new(pipeline, scheduler_method); Ok(Model::new(runner.build())) } diff --git a/mistralrs/src/text_model.rs b/mistralrs/src/text_model.rs index 86a0f56552..1a23df7c00 100644 --- a/mistralrs/src/text_model.rs +++ b/mistralrs/src/text_model.rs @@ -309,7 +309,6 @@ impl TextModelBuilder { let mut runner = MistralRsBuilder::new(pipeline, scheduler_method) .with_no_kv_cache(self.no_kv_cache) - .with_gemm_full_precision_f16(true) .with_no_prefix_cache(self.prefix_cache_n.is_none()); if let Some(n) = self.prefix_cache_n { diff --git a/mistralrs/src/vision_model.rs b/mistralrs/src/vision_model.rs index 41517156b2..403fd31f9b 100644 --- a/mistralrs/src/vision_model.rs +++ b/mistralrs/src/vision_model.rs @@ -210,7 +210,6 @@ impl VisionModelBuilder { let runner = MistralRsBuilder::new(pipeline, scheduler_method) 
.with_no_kv_cache(false) - .with_gemm_full_precision_f16(true) .with_no_prefix_cache(false); Ok(Model::new(runner.build())) diff --git a/mistralrs/src/xlora_model.rs b/mistralrs/src/xlora_model.rs index 1972919f7a..c642caf02d 100644 --- a/mistralrs/src/xlora_model.rs +++ b/mistralrs/src/xlora_model.rs @@ -97,7 +97,6 @@ impl XLoraModelBuilder { let mut runner = MistralRsBuilder::new(pipeline, scheduler_method) .with_no_kv_cache(self.text_model.no_kv_cache) - .with_gemm_full_precision_f16(true) .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none()); if let Some(n) = self.text_model.prefix_cache_n {