diff --git a/mistralrs-core/src/attention.rs b/mistralrs-core/src/attention.rs index f602aaae2b..3d1f0c8822 100644 --- a/mistralrs-core/src/attention.rs +++ b/mistralrs-core/src/attention.rs @@ -318,8 +318,14 @@ impl Sdpa { mask.layout().broadcast_as(tgt_mask_shape.clone()).is_ok() && sdpa_params.softcap.is_none_or(|x| x == 1.0) }); + let valid_head_dims: &[usize] = if can_use_mask && mask.is_some() { + &[64, 80, 128] + } else { + &[32, 64, 96, 128, 256] + }; if [q, k, v].into_iter().all(|x| x.device().is_metal()) && all_head_dims_match + && valid_head_dims.contains(&head_dim) && can_use_mask { let mask = match mask { diff --git a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs index 27068b6475..c616d6aa8d 100644 --- a/mistralrs-core/src/pipeline/loaders/vision_loaders.rs +++ b/mistralrs-core/src/pipeline/loaders/vision_loaders.rs @@ -3145,13 +3145,17 @@ impl VisionModelLoader for Gemma3Loader { } fn get_processor( &self, - _model_config: &str, + config: &str, processor_config: Option, _preprocessor_config: PreProcessorConfig, _max_edge: Option, ) -> Arc { + let config: Gemma3Config = serde_json::from_str(config).unwrap(); // Handle the Gemma 3 1b case here - Arc::new(Gemma3Processor::new(processor_config.unwrap_or_default())) + Arc::new(Gemma3Processor::new( + processor_config.unwrap_or_default(), + matches!(config, Gemma3Config::WithVision { .. }), + )) } fn supports_paged_attention(&self) -> bool { true diff --git a/mistralrs-core/src/vision_models/gemma3/inputs_processor.rs b/mistralrs-core/src/vision_models/gemma3/inputs_processor.rs index 4e4b25b0f2..41dbdd27f2 100644 --- a/mistralrs-core/src/vision_models/gemma3/inputs_processor.rs +++ b/mistralrs-core/src/vision_models/gemma3/inputs_processor.rs @@ -31,6 +31,7 @@ use super::Gemma3SpecificArgs; struct Gemma3ImageProcessor { full_image_sequence: String, + supports_images: bool, } const IMAGE_TOKEN: &str = ""; @@ -39,16 +40,18 @@ const EOI_TOKEN: &str = ""; pub struct Gemma3Processor { full_image_sequence: String, + supports_images: bool, } impl Gemma3Processor { - pub fn new(processor_config: ProcessorConfig) -> Self { + pub fn new(processor_config: ProcessorConfig, supports_images: bool) -> Self { let image_tokens_expanded = vec![IMAGE_TOKEN.to_string(); processor_config.image_seq_len.unwrap_or(256)].join(""); let full_image_sequence = format!("\n\n{BOI_TOKEN}{image_tokens_expanded}{EOI_TOKEN}\n\n"); Self { full_image_sequence, + supports_images, } } } @@ -57,6 +60,7 @@ impl Processor for Gemma3Processor { fn inputs_processor(&self) -> Arc { Arc::new(Gemma3ImageProcessor { full_image_sequence: self.full_image_sequence.clone(), + supports_images: self.supports_images, }) } @@ -114,6 +118,12 @@ impl InputsProcessor for Gemma3ImageProcessor { let has_images = input_seqs.iter().all(|seq| seq.has_images()); let pixel_values = if has_images { + if !self.supports_images { + return Box::new(std::iter::once(Err(anyhow::Error::msg( + "This image processor does not support images.", + )))); + } + let mut pixel_values_accum = Vec::new(); let re = Regex::new(BOI_TOKEN).unwrap(); for seq in input_seqs.iter_mut() {