
Commit afb15a5

Remove API for matmul_via_f16 (#1197)
1 parent 515bd1c commit afb15a5
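
In short: this removes the `gemm_full_precision_f16` option from `MistralRsBuilder`, the CUDA probe behind it, and every `.with_gemm_full_precision_f16(...)` call site; reduced-precision GEMM is no longer toggled automatically when the engine is built. A minimal migration sketch modeled on the builder chains in this diff (`pipeline`, `scheduler_method`, and the flag values are placeholders for whatever the caller already has):

    // Before this commit:
    let runner = MistralRsBuilder::new(pipeline, scheduler_method)
        .with_no_kv_cache(no_kv_cache)
        .with_gemm_full_precision_f16(true) // setter removed by this commit
        .with_no_prefix_cache(prefix_cache_n.is_none());

    // After this commit: the setter no longer exists, so drop it from the chain.
    let runner = MistralRsBuilder::new(pipeline, scheduler_method)
        .with_no_kv_cache(no_kv_cache)
        .with_no_prefix_cache(prefix_cache_n.is_none());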

File tree

12 files changed (+3 −67 lines)

mistralrs-core/src/lib.rs

Lines changed: 0 additions & 52 deletions
@@ -187,7 +187,6 @@ pub struct MistralRsBuilder {
     no_prefix_cache: Option<bool>,
     prefix_cache_n: Option<usize>,
     disable_eos_stop: Option<bool>,
-    gemm_full_precision_f16: Option<bool>,
     throughput_logging_enabled: Option<()>,
 }

@@ -202,7 +201,6 @@ impl MistralRsBuilder {
             no_prefix_cache: None,
             prefix_cache_n: None,
             disable_eos_stop: None,
-            gemm_full_precision_f16: None,
             throughput_logging_enabled: None,
         }
     }
@@ -234,11 +232,6 @@ impl MistralRsBuilder {
         self.disable_eos_stop = Some(disable_eos_stop);
         self
     }
-    /// This setting is only applicable on CUDA. If set to false or not specified, this setting enables f16/bf16 reduced precision matmul for GPUs which support it. If set to true, this setting has no effect.
-    pub fn with_gemm_full_precision_f16(mut self, gemm_full_precision: bool) -> Self {
-        self.gemm_full_precision_f16 = Some(gemm_full_precision);
-        self
-    }
     pub fn with_throughput_logging(mut self) -> Self {
         self.throughput_logging_enabled = Some(());
         self
@@ -249,42 +242,6 @@ impl MistralRsBuilder {
     }
 }

-#[cfg(feature = "cuda")]
-fn set_gemm_reduced_precision_f16(device: candle_core::Device) {
-    use mistralrs_quant::INHIBIT_GEMM_F16;
-
-    use candle_core::{DType, Tensor};
-
-    let a = Tensor::zeros((2, 2), DType::BF16, &device).unwrap();
-    candle_core::cuda::set_gemm_reduced_precision_bf16(true);
-    match a.matmul(&a) {
-        Ok(_) => tracing::info!("Enabling GEMM reduced precision in BF16."),
-        Err(e) => {
-            if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") {
-                tracing::info!("GEMM reduced precision in BF16 not supported.");
-                candle_core::cuda::set_gemm_reduced_precision_bf16(false);
-                INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed);
-            }
-        }
-    }
-
-    let a = Tensor::zeros((2, 2), DType::F16, &device).unwrap();
-    candle_core::cuda::set_gemm_reduced_precision_f16(true);
-    match a.matmul(&a) {
-        Ok(_) => tracing::info!("Enabling GEMM reduced precision in F16."),
-        Err(e) => {
-            if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") {
-                tracing::info!("GEMM reduced precision in F16 not supported.");
-                candle_core::cuda::set_gemm_reduced_precision_f16(false);
-                INHIBIT_GEMM_F16.store(true, std::sync::atomic::Ordering::Relaxed);
-            }
-        }
-    }
-}
-
-#[cfg(not(feature = "cuda"))]
-fn set_gemm_reduced_precision_f16(_device: candle_core::Device) {}
-
 impl Drop for MistralRs {
     fn drop(&mut self) {
         ENGINE_INSTRUCTIONS
@@ -305,19 +262,10 @@ impl MistralRs {
             no_prefix_cache,
             prefix_cache_n,
             disable_eos_stop,
-            gemm_full_precision_f16,
             throughput_logging_enabled,
         } = config;

         let category = pipeline.try_lock().unwrap().category();
-        let model_supports_reduced_gemm = match category {
-            ModelCategory::Text => true,
-            ModelCategory::Vision { has_conv2d, .. } => !has_conv2d,
-            ModelCategory::Diffusion => true,
-        };
-        if !gemm_full_precision_f16.unwrap_or(false) && model_supports_reduced_gemm {
-            set_gemm_reduced_precision_f16(get_mut_arcmutex!(pipeline).device());
-        }
         setup_cublas_lt_wrapper(get_mut_arcmutex!(pipeline).device());

         let truncate_sequence = truncate_sequence.unwrap_or(false);
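
With the probe removed, `MistralRs::new` no longer calls candle's `set_gemm_reduced_precision_f16`/`set_gemm_reduced_precision_bf16` for any model category. A downstream crate that still wants the old behavior can drive candle directly. The following is a sketch only, assuming a direct `candle_core` dependency built with the `cuda` feature (as the deleted code was); the helper name is hypothetical, and the `CUBLAS_STATUS_NOT_SUPPORTED` string match is the same heuristic the removed probe used:

    use candle_core::{DType, Device, Tensor};

    // Hypothetical downstream helper; not part of mistral.rs after this commit.
    // Mirrors the deleted f16 half of the probe: enable reduced-precision GEMM,
    // try a tiny matmul, and back off if cuBLAS reports the mode as unsupported.
    fn try_enable_reduced_precision_f16(device: &Device) -> candle_core::Result<()> {
        candle_core::cuda::set_gemm_reduced_precision_f16(true);
        let a = Tensor::zeros((2, 2), DType::F16, device)?;
        if let Err(e) = a.matmul(&a) {
            if format!("{e:?}").contains("CUBLAS_STATUS_NOT_SUPPORTED") {
                candle_core::cuda::set_gemm_reduced_precision_f16(false);
            }
        }
        Ok(())
    }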

mistralrs-server/src/main.rs

Lines changed: 1 addition & 3 deletions
@@ -478,9 +478,7 @@ async fn main() -> Result<()> {
         .with_opt_log(args.log)
         .with_truncate_sequence(args.truncate_sequence)
         .with_no_kv_cache(args.no_kv_cache)
-        .with_prefix_cache_n(args.prefix_cache_n)
-        .with_gemm_full_precision_f16(args.cpu)
-        .with_gemm_full_precision_f16(args.cpu); // Required to allow `cuda` build to use `--cpu`, #1056
+        .with_prefix_cache_n(args.prefix_cache_n);

     if args.interactive_mode {
         interactive_mode(builder.build(), args.throughput_log).await;
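
Note that the two deleted `.with_gemm_full_precision_f16(args.cpu)` lines were themselves a workaround: as the inline comment says, they let a `cuda` build honor `--cpu` (#1056). With the API gone, the workaround goes with it.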

mistralrs/src/anymoe.rs

Lines changed: 0 additions & 1 deletion
@@ -112,7 +112,6 @@ impl AnyMoeModelBuilder {

         let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
             .with_no_kv_cache(self.base.no_kv_cache)
-            .with_gemm_full_precision_f16(true)
             .with_no_prefix_cache(self.base.prefix_cache_n.is_none());

         if let Some(n) = self.base.prefix_cache_n {

mistralrs/src/diffusion_model.rs

Lines changed: 1 addition & 2 deletions
@@ -102,8 +102,7 @@ impl DiffusionModelBuilder {
             method: DefaultSchedulerMethod::Fixed(self.max_num_seqs.try_into()?),
         };

-        let runner =
-            MistralRsBuilder::new(pipeline, scheduler_method).with_gemm_full_precision_f16(true);
+        let runner = MistralRsBuilder::new(pipeline, scheduler_method);

         Ok(Model::new(runner.build()))
     }

mistralrs/src/gguf.rs

Lines changed: 0 additions & 1 deletion
@@ -206,7 +206,6 @@ impl GgufModelBuilder {

         let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
             .with_no_kv_cache(self.no_kv_cache)
-            .with_gemm_full_precision_f16(true)
             .with_no_prefix_cache(self.prefix_cache_n.is_none());

         if let Some(n) = self.prefix_cache_n {

mistralrs/src/gguf_lora_model.rs

Lines changed: 0 additions & 1 deletion
@@ -79,7 +79,6 @@ impl GgufLoraModelBuilder {

         let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
             .with_no_kv_cache(self.gguf_model.no_kv_cache)
-            .with_gemm_full_precision_f16(true)
             .with_no_prefix_cache(self.gguf_model.prefix_cache_n.is_none());

         if let Some(n) = self.gguf_model.prefix_cache_n {

mistralrs/src/gguf_xlora_model.rs

Lines changed: 0 additions & 1 deletion
@@ -91,7 +91,6 @@ impl GgufXLoraModelBuilder {

         let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
             .with_no_kv_cache(self.gguf_model.no_kv_cache)
-            .with_gemm_full_precision_f16(true)
             .with_no_prefix_cache(self.gguf_model.prefix_cache_n.is_none());

         if let Some(n) = self.gguf_model.prefix_cache_n {

mistralrs/src/lora_model.rs

Lines changed: 0 additions & 1 deletion
@@ -85,7 +85,6 @@ impl LoraModelBuilder {

         let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
             .with_no_kv_cache(self.text_model.no_kv_cache)
-            .with_gemm_full_precision_f16(true)
             .with_no_prefix_cache(self.text_model.prefix_cache_n.is_none());

         if let Some(n) = self.text_model.prefix_cache_n {

mistralrs/src/speculative.rs

Lines changed: 1 addition & 2 deletions
@@ -92,8 +92,7 @@ impl TextSpeculativeBuilder {
             self.speculative_config,
         )?));

-        let runner =
-            MistralRsBuilder::new(pipeline, scheduler_method).with_gemm_full_precision_f16(true);
+        let runner = MistralRsBuilder::new(pipeline, scheduler_method);

         Ok(Model::new(runner.build()))
     }

mistralrs/src/text_model.rs

Lines changed: 0 additions & 1 deletion
@@ -309,7 +309,6 @@ impl TextModelBuilder {

         let mut runner = MistralRsBuilder::new(pipeline, scheduler_method)
             .with_no_kv_cache(self.no_kv_cache)
-            .with_gemm_full_precision_f16(true)
             .with_no_prefix_cache(self.prefix_cache_n.is_none());

         if let Some(n) = self.prefix_cache_n {
