Skip to content

Commit cb939a8

Browse files
authored
Add immediate isq predicates for qwen3 (#1358)
* Add immediate isq predicates for qwen3 * Fix parsing of "parse_isq_value" dependent on device * Typo
1 parent 504401f commit cb939a8

File tree

9 files changed

+56
-29
lines changed

9 files changed

+56
-29
lines changed

mistralrs-bench/src/main.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ use cli_table::{format::Justify, print_stdout, Cell, CellStruct, Style, Table};
44
use mistralrs_core::{
55
get_auto_device_map_params, get_model_dtype, initialize_logging, paged_attn_supported,
66
parse_isq_value, Constraint, DefaultSchedulerMethod, DeviceLayerMapMetadata, DeviceMapMetadata,
7-
DeviceMapSetting, DrySamplingParams, IsqType, Loader, LoaderBuilder, MemoryGpuConfig,
8-
MistralRs, MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, Request,
9-
RequestMessage, Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
7+
DeviceMapSetting, DrySamplingParams, Loader, LoaderBuilder, MemoryGpuConfig, MistralRs,
8+
MistralRsBuilder, ModelSelected, NormalRequest, PagedAttentionConfig, Request, RequestMessage,
9+
Response, SamplingParams, SchedulerConfig, TokenSource, Usage,
1010
};
1111
use std::sync::Arc;
1212
use std::{fmt::Display, num::NonZeroUsize};
@@ -300,8 +300,8 @@ struct Args {
300300
num_device_layers: Option<Vec<String>>,
301301

302302
/// In-situ quantization to apply.
303-
#[arg(long = "isq", value_parser = parse_isq_value)]
304-
in_situ_quant: Option<IsqType>,
303+
#[arg(long = "isq")]
304+
in_situ_quant: Option<String>,
305305

306306
/// GPU memory to allocate for KV cache with PagedAttention in MBs.
307307
/// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
@@ -490,14 +490,19 @@ fn main() -> anyhow::Result<()> {
490490
(_, _, _, _, _, _) => None,
491491
};
492492

493+
let isq = args
494+
.in_situ_quant
495+
.as_ref()
496+
.and_then(|isq| parse_isq_value(isq, Some(&device)).ok());
497+
493498
let pipeline = loader.load_model_from_hf(
494499
None,
495500
token_source,
496501
&dtype,
497502
&device,
498503
false,
499504
mapper,
500-
args.in_situ_quant,
505+
isq,
501506
cache_config,
502507
)?;
503508
info!("Model loaded.");

mistralrs-core/src/pipeline/isq.rs

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,20 @@ pub const UQFF_MULTI_FILE_DELIMITER: &str = ";";
5959
/// - `AFQ4`
6060
/// - `AFQ6`
6161
/// - `AFQ8`
62-
pub fn parse_isq_value(s: &str) -> Result<IsqType, String> {
62+
pub fn parse_isq_value(s: &str, device: Option<&Device>) -> Result<IsqType, String> {
63+
let is_metal = device.map(|device| device.is_metal()).unwrap_or(false);
6364
let tp = match s.to_lowercase().as_str() {
64-
"2" if cfg!(feature = "metal") => IsqType::AFQ2,
65-
"2" if !cfg!(feature = "metal") => IsqType::Q2K,
66-
"3" if cfg!(feature = "metal") => IsqType::AFQ3,
67-
"3" if !cfg!(feature = "metal") => IsqType::Q3K,
68-
"4" if cfg!(feature = "metal") => IsqType::AFQ4,
69-
"4" if !cfg!(feature = "metal") => IsqType::Q4K,
65+
"2" if is_metal => IsqType::AFQ2,
66+
"2" if !is_metal => IsqType::Q2K,
67+
"3" if is_metal => IsqType::AFQ3,
68+
"3" if !is_metal => IsqType::Q3K,
69+
"4" if is_metal => IsqType::AFQ4,
70+
"4" if !is_metal => IsqType::Q4K,
7071
"5" => IsqType::Q5K,
71-
"6" if cfg!(feature = "metal") => IsqType::AFQ6,
72-
"6" if !cfg!(feature = "metal") => IsqType::Q6K,
73-
"8" if cfg!(feature = "metal") => IsqType::AFQ8,
74-
"8" if !cfg!(feature = "metal") => IsqType::Q8_0,
72+
"6" if is_metal => IsqType::AFQ6,
73+
"6" if !is_metal => IsqType::Q6K,
74+
"8" if is_metal => IsqType::AFQ8,
75+
"8" if !is_metal => IsqType::Q8_0,
7576
"q4_0" => IsqType::Q4_0,
7677
"q4_1" => IsqType::Q4_1,
7778
"q5_0" => IsqType::Q5_0,

mistralrs-core/src/pipeline/loaders/normal_loaders.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3004,8 +3004,8 @@ impl IsqModelLoader for Qwen3Loader {
30043004
Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
30053005
])
30063006
}
3007-
fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3008-
self.isq_layer_regexes_moqe(config)
3007+
fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3008+
self.isq_layer_regexes(config)
30093009
}
30103010
}
30113011

@@ -3189,6 +3189,9 @@ impl IsqModelLoader for Qwen3MoELoader {
31893189
Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
31903190
])
31913191
}
3192+
fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3193+
self.isq_layer_regexes(config)
3194+
}
31923195
fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
31933196
self.isq_layer_regexes_moqe(config)
31943197
}

mistralrs-core/src/pipeline/normal.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,10 @@ impl Loader for NormalLoader {
478478
} else {
479479
self.inner.immediate_isq_predicates(&config)?
480480
};
481+
info!("Applying ISQ to {in_situ_quant:?}");
482+
if predicates.is_empty() {
483+
warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
484+
}
481485
mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
482486
false
483487
} else {

mistralrs-core/src/pipeline/vision.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,10 @@ impl Loader for VisionLoader {
409409
&& self.config.write_uqff.is_none()
410410
{
411411
let predicates = self.inner.immediate_isq_predicates(&config)?;
412+
info!("Applying ISQ to {in_situ_quant:?}");
413+
if predicates.is_empty() {
414+
warn!("No predicates for this model and ISQ setting detected. ISQ will not be applied to any weights!");
415+
}
412416
mistralrs_quant::set_immediate_isq(in_situ_quant, predicates);
413417
false
414418
} else {

mistralrs-core/src/topology/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ impl Topology {
107107
}
108108
let range = CustomRange { start, end };
109109
let isq = if let Some(isq) = isq {
110-
Some(parse_isq_value(&isq).map_err(anyhow::Error::msg)?)
110+
Some(parse_isq_value(&isq, None).map_err(anyhow::Error::msg)?)
111111
} else {
112112
None
113113
};

mistralrs-pyo3/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ impl Runner {
682682

683683
let device = get_device(seed).as_ref().map_err(PyApiErr::from)?;
684684
let isq = if let Some(isq) = in_situ_quant {
685-
Some(parse_isq_value(&isq).map_err(PyApiErr::from)?)
685+
Some(parse_isq_value(&isq, Some(device)).map_err(PyApiErr::from)?)
686686
} else {
687687
None
688688
};
@@ -1416,7 +1416,7 @@ impl Runner {
14161416
/// Send a request to re-ISQ the model. If the model was loaded as GGUF or GGML
14171417
/// then nothing will happen.
14181418
fn send_re_isq(&self, dtype: String) -> PyApiResult<()> {
1419-
let request = _Request::ReIsq(parse_isq_value(&dtype)?);
1419+
let request = _Request::ReIsq(parse_isq_value(&dtype, None)?);
14201420
self.runner.get_sender()?.blocking_send(request).unwrap();
14211421
Ok(())
14221422
}

mistralrs-server/src/main.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use clap::Parser;
1010
use mistralrs_core::{
1111
get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, initialize_logging,
1212
paged_attn_supported, parse_isq_value, BertEmbeddingModel, DefaultSchedulerMethod,
13-
DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, IsqType, Loader, LoaderBuilder,
13+
DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, Loader, LoaderBuilder,
1414
MemoryGpuConfig, MistralRs, MistralRsBuilder, ModelSelected, PagedAttentionConfig, Request,
1515
SchedulerConfig, TokenSource,
1616
};
@@ -119,8 +119,8 @@ struct Args {
119119
num_device_layers: Option<Vec<String>>,
120120

121121
/// In-situ quantization to apply.
122-
#[arg(long = "isq", value_parser = parse_isq_value)]
123-
in_situ_quant: Option<IsqType>,
122+
#[arg(long = "isq")]
123+
in_situ_quant: Option<String>,
124124

125125
/// GPU memory to allocate for KV cache with PagedAttention in MBs.
126126
/// PagedAttention is supported on CUDA and Metal. It is automatically activated on CUDA but not on Metal.
@@ -223,7 +223,7 @@ async fn re_isq(
223223
) -> Result<String, String> {
224224
let repr = format!("Re ISQ: {:?}", request.ggml_type);
225225
MistralRs::maybe_log_request(state.clone(), repr.clone());
226-
let request = Request::ReIsq(parse_isq_value(&request.ggml_type)?);
226+
let request = Request::ReIsq(parse_isq_value(&request.ggml_type, None)?);
227227
state.get_sender().unwrap().send(request).await.unwrap();
228228
Ok(repr)
229229
}
@@ -300,7 +300,12 @@ async fn main() -> Result<()> {
300300
.build()?;
301301

302302
#[cfg(feature = "metal")]
303-
let device = Device::new_metal(0)?;
303+
let device = if args.cpu {
304+
args.no_paged_attn = true;
305+
Device::Cpu
306+
} else {
307+
Device::new_metal(0)?
308+
};
304309
#[cfg(not(feature = "metal"))]
305310
let device = if args.cpu {
306311
args.no_paged_attn = true;
@@ -426,14 +431,19 @@ async fn main() -> Result<()> {
426431
(_, _, _, _, _, _) => None,
427432
};
428433

434+
let isq = args
435+
.in_situ_quant
436+
.as_ref()
437+
.and_then(|isq| parse_isq_value(isq, Some(&device)).ok());
438+
429439
let pipeline = loader.load_model_from_hf(
430440
None,
431441
args.token_source,
432442
&dtype,
433443
&device,
434444
false,
435445
mapper,
436-
args.in_situ_quant,
446+
isq,
437447
cache_config,
438448
)?;
439449
info!("Model loaded.");

mistralrs/examples/perplexity/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ async fn main() -> Result<()> {
7474
let args = Args::parse();
7575

7676
let quant = if let Some(isq) = &args.isq {
77-
Some(parse_isq_value(isq).map_err(anyhow::Error::msg)?)
77+
Some(parse_isq_value(isq, None).map_err(anyhow::Error::msg)?)
7878
} else {
7979
None
8080
};

0 commit comments

Comments
 (0)