Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 0 additions & 52 deletions mistralrs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ pub struct MistralRsBuilder {
no_prefix_cache: Option<bool>,
prefix_cache_n: Option<usize>,
disable_eos_stop: Option<bool>,
gemm_full_precision_f16: Option<bool>,
throughput_logging_enabled: Option<()>,
}

Expand All @@ -202,7 +201,6 @@ impl MistralRsBuilder {
no_prefix_cache: None,
prefix_cache_n: None,
disable_eos_stop: None,
gemm_full_precision_f16: None,
throughput_logging_enabled: None,
}
}
Expand Down Expand Up @@ -234,11 +232,6 @@ impl MistralRsBuilder {
self.disable_eos_stop = Some(disable_eos_stop);
self
}
/// This setting is only applicable on CUDA. If set to `false` or not specified, f16/bf16
/// reduced-precision matmul is probed and enabled for GPUs which support it. If set to
/// `true`, reduced-precision GEMM is left disabled (full-precision matmul is used).
pub fn with_gemm_full_precision_f16(mut self, gemm_full_precision: bool) -> Self {
self.gemm_full_precision_f16 = Some(gemm_full_precision);
self
}
pub fn with_throughput_logging(mut self) -> Self {
self.throughput_logging_enabled = Some(());
self
Expand All @@ -249,42 +242,6 @@ impl MistralRsBuilder {
}
}

#[cfg(feature = "cuda")]
/// Probe the GPU for reduced-precision GEMM support in BF16 and F16.
///
/// For each dtype: allocate a tiny tensor, enable the reduced-precision mode,
/// and attempt a matmul. If cuBLAS rejects it with `CUBLAS_STATUS_NOT_SUPPORTED`,
/// roll the mode back and raise the global `INHIBIT_GEMM_F16` flag.
fn set_gemm_reduced_precision_f16(device: candle_core::Device) {
    use mistralrs_quant::INHIBIT_GEMM_F16;

    use candle_core::{DType, Tensor};

    // Shared probe: `set` toggles the candle reduced-precision switch for one dtype.
    let probe = |dtype: DType, label: &str, set: fn(bool)| {
        let t = Tensor::zeros((2, 2), dtype, &device).unwrap();
        set(true);
        match t.matmul(&t) {
            Ok(_) => {
                tracing::info!("{}", format!("Enabling GEMM reduced precision in {label}."))
            }
            Err(e) => {
                // Only treat the specific "not supported" cuBLAS status as a rollback
                // signal; any other error is ignored here, matching best-effort probing.
                if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") {
                    tracing::info!("{}", format!("GEMM reduced precision in {label} not supported."));
                    set(false);
                    INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed);
                }
            }
        }
    };

    probe(
        DType::BF16,
        "BF16",
        candle_core::cuda::set_gemm_reduced_precision_bf16,
    );
    probe(
        DType::F16,
        "F16",
        candle_core::cuda::set_gemm_reduced_precision_f16,
    );
}

/// No-op fallback when the `cuda` feature is disabled: reduced-precision GEMM
/// configuration is only meaningful for the CUDA backend.
#[cfg(not(feature = "cuda"))]
fn set_gemm_reduced_precision_f16(_device: candle_core::Device) {}

impl Drop for MistralRs {
fn drop(&mut self) {
ENGINE_INSTRUCTIONS
Expand All @@ -305,19 +262,10 @@ impl MistralRs {
no_prefix_cache,
prefix_cache_n,
disable_eos_stop,
gemm_full_precision_f16,
throughput_logging_enabled,
} = config;

let category = pipeline.try_lock().unwrap().category();
let model_supports_reduced_gemm = match category {
ModelCategory::Text => true,
ModelCategory::Vision { has_conv2d, .. } => !has_conv2d,
ModelCategory::Diffusion => true,
};
if !gemm_full_precision_f16.unwrap_or(false) && model_supports_reduced_gemm {
set_gemm_reduced_precision_f16(get_mut_arcmutex!(pipeline).device());
}
setup_cublas_lt_wrapper(get_mut_arcmutex!(pipeline).device());

let truncate_sequence = truncate_sequence.unwrap_or(false);
Expand Down
4 changes: 1 addition & 3 deletions mistralrs-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,9 +478,7 @@ async fn main() -> Result<()> {
.with_opt_log(args.log)
.with_truncate_sequence(args.truncate_sequence)
.with_no_kv_cache(args.no_kv_cache)
.with_prefix_cache_n(args.prefix_cache_n)
.with_gemm_full_precision_f16(args.cpu)
.with_gemm_full_precision_f16(args.cpu); // Required to allow `cuda` build to use `--cpu`, #1056
.with_prefix_cache_n(args.prefix_cache_n);

if args.interactive_mode {
interactive_mode(builder.build(), args.throughput_log).await;
Expand Down
1 change: 0 additions & 1 deletion mistralrs/src/anymoe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ impl AnyMoeModelBuilder {

let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
.with_no_kv_cache(self.base.no_kv_cache)
.with_gemm_full_precision_f16(true)
.with_no_prefix_cache(self.base.prefix_cache_n.is_none());

if let Some(n) = self.base.prefix_cache_n {
Expand Down
3 changes: 1 addition & 2 deletions mistralrs/src/diffusion_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ impl DiffusionModelBuilder {
method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
};

let runner =
MistralRsBuilder::new(pipeline, scheduler_method).with_gemm_full_precision_f16(true);
let runner = MistralRsBuilder::new(pipeline, scheduler_method);

Ok(Model::new(runner.build()))
}
Expand Down
1 change: 0 additions & 1 deletion mistralrs/src/gguf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,6 @@ impl GgufModelBuilder {

let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
.with_no_kv_cache(self.no_kv_cache)
.with_gemm_full_precision_f16(true)
.with_no_prefix_cache(self.prefix_cache_n.is_none());

if let Some(n) = self.prefix_cache_n {
Expand Down
1 change: 0 additions & 1 deletion mistralrs/src/gguf_lora_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ impl GgufLoraModelBuilder {

let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
.with_no_kv_cache(self.gguf_model.no_kv_cache)
.with_gemm_full_precision_f16(true)
.with_no_prefix_cache(self.gguf_model.prefix_cache_n.is_none());

if let Some(n) = self.gguf_model.prefix_cache_n {
Expand Down
1 change: 0 additions & 1 deletion mistralrs/src/gguf_xlora_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ impl GgufXLoraModelBuilder {

let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
.with_no_kv_cache(self.gguf_model.no_kv_cache)
.with_gemm_full_precision_f16(true)
.with_no_prefix_cache(self.gguf_model.prefix_cache_n.is_none());

if let Some(n) = self.gguf_model.prefix_cache_n {
Expand Down
1 change: 0 additions & 1 deletion mistralrs/src/lora_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ impl LoraModelBuilder {

let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
.with_no_kv_cache(self.text_model.no_kv_cache)
.with_gemm_full_precision_f16(true)
.with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());

if let Some(n) = self.text_model.prefix_cache_n {
Expand Down
3 changes: 1 addition & 2 deletions mistralrs/src/speculative.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ impl TextSpeculativeBuilder {
self.speculative_config,
)?));

let runner =
MistralRsBuilder::new(pipeline, scheduler_method).with_gemm_full_precision_f16(true);
let runner = MistralRsBuilder::new(pipeline, scheduler_method);

Ok(Model::new(runner.build()))
}
Expand Down
1 change: 0 additions & 1 deletion mistralrs/src/text_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,6 @@ impl TextModelBuilder {

let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
.with_no_kv_cache(self.no_kv_cache)
.with_gemm_full_precision_f16(true)
.with_no_prefix_cache(self.prefix_cache_n.is_none());

if let Some(n) = self.prefix_cache_n {
Expand Down
1 change: 0 additions & 1 deletion mistralrs/src/vision_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ impl VisionModelBuilder {

let runner = MistralRsBuilder::new(pipeline, scheduler_method)
.with_no_kv_cache(false)
.with_gemm_full_precision_f16(true)
.with_no_prefix_cache(false);

Ok(Model::new(runner.build()))
Expand Down
1 change: 0 additions & 1 deletion mistralrs/src/xlora_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ impl XLoraModelBuilder {

let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
.with_no_kv_cache(self.text_model.no_kv_cache)
.with_gemm_full_precision_f16(true)
.with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());

if let Some(n) = self.text_model.prefix_cache_n {
Expand Down
Loading