-
Notifications
You must be signed in to change notification settings - Fork 472
Add auto loader for vision/text detection #1402
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,288 @@ | ||
| use super::{ | ||
| NormalLoaderBuilder, VisionLoaderBuilder, NormalSpecificConfig, VisionSpecificConfig, | ||
| Loader, ModelPaths, TokenSource, ModelKind, NormalLoaderType, VisionLoaderType, | ||
| Ordering, | ||
|
Check failure on line 4 in mistralrs-core/src/pipeline/auto.rs
|
||
| }; | ||
| use crate::utils::tokens::get_token; | ||
| use hf_hub::{api::sync::ApiBuilder, Repo, RepoType, Cache}; | ||
| use anyhow::Result; | ||
| use std::sync::Arc; | ||
| use tokio::sync::Mutex; | ||
| use candle_core::Device; | ||
| use crate::{PagedAttentionConfig, Pipeline, TryIntoDType, DeviceMapSetting, IsqType}; | ||
| use std::path::PathBuf; | ||
| use std::path::Path; | ||
| use serde::Deserialize; | ||
|
|
||
| /// Automatically selects between a normal or vision loader based on the `architectures` field. | ||
| /// | ||
| /// Both candidate builders are stored up front; the concrete loader is created later, once a | ||
| /// model config has been read and its architecture detected. | ||
| /// | ||
| // NOTE(review): this file imports `tokio::sync::Mutex`, whose `lock()` is async, yet the | ||
| // methods below call `.lock().unwrap()` on these fields. That call shape matches | ||
| // `std::sync::Mutex` -- presumably the cause of the CI check failures; confirm. | ||
| pub struct AutoLoader { | ||
| // Model identifier; returned by `get_id` and used as the Hugging Face repo id. | ||
| model_id: String, | ||
| // Builder consumed (taken) when the detected architecture is a normal model. | ||
| normal_builder: Mutex<Option<NormalLoaderBuilder>>, | ||
| // Builder consumed (taken) when the detected architecture is a vision model. | ||
| vision_builder: Mutex<Option<VisionLoaderBuilder>>, | ||
| // Concrete loader; `None` until `ensure_loader` has run successfully. | ||
| loader: Mutex<Option<Box<dyn Loader>>>, | ||
| // Optional Hugging Face cache directory used when fetching `config.json`. | ||
| hf_cache_path: Option<PathBuf>, | ||
| } | ||
|
|
||
| /// Builder for [`AutoLoader`]. | ||
| /// | ||
| /// Collects the settings needed to construct both a `NormalLoaderBuilder` and a | ||
| /// `VisionLoaderBuilder`; `build` creates both and defers the choice to load time. | ||
| pub struct AutoLoaderBuilder { | ||
| // Passed to `NormalLoaderBuilder::new`. | ||
| normal_cfg: NormalSpecificConfig, | ||
| // Passed to `VisionLoaderBuilder::new`. | ||
| vision_cfg: VisionSpecificConfig, | ||
| // Forwarded to both underlying builders. | ||
| chat_template: Option<String>, | ||
| // Forwarded to both underlying builders. | ||
| tokenizer_json: Option<String>, | ||
| // Required by `build`, which panics when this is `None`. | ||
| model_id: Option<String>, | ||
| // Forwarded to both underlying builders. | ||
| jinja_explicit: Option<String>, | ||
| // Forwarded to the normal builder only; overwritten by `with_xlora`. | ||
| no_kv_cache: bool, | ||
| // X-LoRA settings; applied to the normal builder only, and only when both the | ||
| // model id and the ordering are set. | ||
| xlora_model_id: Option<String>, | ||
| xlora_order: Option<Ordering>, | ||
| tgt_non_granular_index: Option<usize>, | ||
| // LoRA adapter ids; applied to both underlying builders. | ||
| lora_adapter_ids: Option<Vec<String>>, | ||
| // Optional Hugging Face cache directory; applied to both builders and kept on the loader. | ||
| hf_cache_path: Option<PathBuf>, | ||
| } | ||
|
|
||
| impl AutoLoaderBuilder { | ||
| /// Creates a builder from the settings shared by the normal and vision loader builders. | ||
| /// | ||
| /// X-LoRA, LoRA and Hugging Face cache settings start out unset. | ||
| #[allow(clippy::too_many_arguments)] | ||
| pub fn new( | ||
| normal_cfg: NormalSpecificConfig, | ||
| vision_cfg: VisionSpecificConfig, | ||
| chat_template: Option<String>, | ||
| tokenizer_json: Option<String>, | ||
| model_id: Option<String>, | ||
| no_kv_cache: bool, | ||
| jinja_explicit: Option<String>, | ||
| ) -> Self { | ||
| Self { | ||
| // Optional extras are configured through the `with_*` / `hf_cache_path` setters. | ||
| xlora_model_id: None, | ||
| xlora_order: None, | ||
| tgt_non_granular_index: None, | ||
| lora_adapter_ids: None, | ||
| hf_cache_path: None, | ||
| normal_cfg, | ||
| vision_cfg, | ||
| chat_template, | ||
| tokenizer_json, | ||
| model_id, | ||
| jinja_explicit, | ||
| no_kv_cache, | ||
| } | ||
| } | ||
|
|
||
| /// Records X-LoRA settings. Note that this also overwrites the builder's `no_kv_cache`. | ||
| pub fn with_xlora(self, model_id: String, order: Ordering, no_kv_cache: bool, tgt_non_granular_index: Option<usize>) -> Self { | ||
| Self { | ||
| xlora_model_id: Some(model_id), | ||
| xlora_order: Some(order), | ||
| no_kv_cache, | ||
| tgt_non_granular_index, | ||
| ..self | ||
| } | ||
| } | ||
|
|
||
| /// Records the LoRA adapter ids to apply to the underlying builders. | ||
| pub fn with_lora(self, adapters: Vec<String>) -> Self { | ||
| Self { | ||
| lora_adapter_ids: Some(adapters), | ||
| ..self | ||
| } | ||
| } | ||
|
|
||
| /// Records the Hugging Face cache directory to use. | ||
| pub fn hf_cache_path(self, path: PathBuf) -> Self { | ||
| Self { | ||
| hf_cache_path: Some(path), | ||
| ..self | ||
| } | ||
| } | ||
|
|
||
| /// Builds an [`AutoLoader`] holding one normal and one vision builder. | ||
| /// | ||
| /// # Panics | ||
| /// | ||
| /// Panics if no model id was supplied. | ||
| pub fn build(self) -> Box<dyn Loader> { | ||
| let Self { | ||
| normal_cfg, | ||
| vision_cfg, | ||
| chat_template, | ||
| tokenizer_json, | ||
| model_id, | ||
| jinja_explicit, | ||
| no_kv_cache, | ||
| xlora_model_id, | ||
| xlora_order, | ||
| tgt_non_granular_index, | ||
| lora_adapter_ids, | ||
| hf_cache_path, | ||
| } = self; | ||
| let model_id = model_id.expect("model id required"); | ||
|
|
||
| // Normal (text) builder: the only one that receives the X-LoRA and KV-cache settings. | ||
| let mut normal_builder = NormalLoaderBuilder::new( | ||
| normal_cfg, | ||
| chat_template.clone(), | ||
| tokenizer_json.clone(), | ||
| Some(model_id.clone()), | ||
| no_kv_cache, | ||
| jinja_explicit.clone(), | ||
| ); | ||
| if let (Some(xlora_id), Some(ordering)) = (xlora_model_id, xlora_order) { | ||
| normal_builder = normal_builder.with_xlora(xlora_id, ordering, no_kv_cache, tgt_non_granular_index); | ||
| } | ||
| if let Some(adapters) = lora_adapter_ids.as_ref() { | ||
| normal_builder = normal_builder.with_lora(adapters.clone()); | ||
| } | ||
| if let Some(cache_dir) = hf_cache_path.as_ref() { | ||
| normal_builder = normal_builder.hf_cache_path(cache_dir.clone()); | ||
| } | ||
|
|
||
| // Vision builder: shares the template, tokenizer, LoRA and cache settings. | ||
| let mut vision_builder = VisionLoaderBuilder::new( | ||
| vision_cfg, | ||
| chat_template, | ||
| tokenizer_json, | ||
| Some(model_id.clone()), | ||
| jinja_explicit, | ||
| ); | ||
| if let Some(adapters) = lora_adapter_ids.as_ref() { | ||
| vision_builder = vision_builder.with_lora(adapters.clone()); | ||
| } | ||
| if let Some(cache_dir) = hf_cache_path.as_ref() { | ||
| vision_builder = vision_builder.hf_cache_path(cache_dir.clone()); | ||
| } | ||
|
|
||
| let auto_loader = AutoLoader { | ||
| model_id, | ||
| normal_builder: Mutex::new(Some(normal_builder)), | ||
| vision_builder: Mutex::new(Some(vision_builder)), | ||
| loader: Mutex::new(None), | ||
| hf_cache_path, | ||
| }; | ||
| Box::new(auto_loader) | ||
| } | ||
| } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion: Improve error handling in the `build` method
The `build` method calls `expect("model id required")`, which panics when `model_id` is `None`. Return a `Result` so callers can handle the missing value:
-pub fn build(self) -> Box<dyn Loader> {
- let model_id = self.model_id.expect("model id required");
+pub fn build(self) -> Result<Box<dyn Loader>> {
+ let model_id = self.model_id.ok_or_else(|| anyhow::anyhow!("model_id is required"))?;
Also consider that both normal and vision builders are created even though only one will be used. While this works, it's somewhat inefficient.
🤖 Prompt for AI Agents |
||
|
|
||
| /// Minimal view of a model config: only the `architectures` list is deserialized, | ||
| /// which `AutoLoader::detect` uses to choose a loader type. | ||
| #[derive(Deserialize)] | ||
| struct AutoConfig { | ||
| // Architecture names from the config; `detect` requires exactly one entry. | ||
| architectures: Vec<String>, | ||
| } | ||
|
|
||
| /// Outcome of architecture detection: which loader family to use and its concrete type. | ||
| enum Detected { | ||
| // The architecture name resolved to a normal loader type. | ||
| Normal(NormalLoaderType), | ||
| // The architecture name resolved to a vision loader type. | ||
| Vision(VisionLoaderType), | ||
| } | ||
|
|
||
| impl AutoLoader { | ||
| /// Reads the model config file referenced by `paths` and returns its contents as a string. | ||
| fn read_config_from_path(&self, paths: &Box<dyn ModelPaths>) -> Result<String> { | ||
| let config_path = paths.get_config_filename(); | ||
| let contents = std::fs::read_to_string(config_path)?; | ||
| Ok(contents) | ||
| } | ||
|
|
||
| /// Fetches `config.json` for this model and returns its contents as a string. | ||
| /// | ||
| /// The Hugging Face API client is built from the configured cache path (or the default | ||
| /// cache); a set `HF_HUB_CACHE` environment variable overrides the cache directory. | ||
| /// `revision` falls back to `"main"` when `None`. | ||
| fn read_config_from_hf( | ||
| &self, | ||
| revision: Option<String>, | ||
| token_source: &TokenSource, | ||
| silent: bool, | ||
| ) -> Result<String> { | ||
| let hf_cache = match self.hf_cache_path.clone() { | ||
| Some(dir) => Cache::new(dir), | ||
| None => Cache::default(), | ||
| }; | ||
| let mut builder = ApiBuilder::from_cache(hf_cache) | ||
| .with_progress(!silent) | ||
| .with_token(get_token(token_source)?); | ||
| // The environment variable takes precedence over the cache chosen above. | ||
| if let Ok(cache_dir) = std::env::var("HF_HUB_CACHE") { | ||
| builder = builder.with_cache_dir(cache_dir.into()); | ||
| } | ||
| let client = builder.build()?; | ||
| let repo = Repo::with_revision( | ||
| self.model_id.clone(), | ||
| RepoType::Model, | ||
| revision.unwrap_or_else(|| "main".to_string()), | ||
| ); | ||
| let api = client.repo(repo); | ||
| let model_id = Path::new(&self.model_id); | ||
| let config_filename = api_get_file!(api, "config.json", model_id); | ||
|
Check failure on line 166 in mistralrs-core/src/pipeline/auto.rs
|
||
| let contents = std::fs::read_to_string(config_filename)?; | ||
| Ok(contents) | ||
| } | ||
|
|
||
| /// Parses `config` as JSON and maps its single architecture name to a loader type. | ||
| /// | ||
| /// Vision loader types are tried first; normal loader types are the fallback. | ||
| /// | ||
| /// # Errors | ||
| /// | ||
| /// Fails if the JSON cannot be parsed, if `architectures` does not hold exactly one | ||
| /// entry, or if the name matches neither a vision nor a normal loader type. | ||
| fn detect(&self, config: &str) -> Result<Detected> { | ||
| let parsed: AutoConfig = serde_json::from_str(config)?; | ||
| let arch = match parsed.architectures.as_slice() { | ||
| [only] => only, | ||
| _ => anyhow::bail!("Expected exactly one architecture in config"), | ||
| }; | ||
| match VisionLoaderType::from_causal_lm_name(arch) { | ||
| Ok(vision_tp) => Ok(Detected::Vision(vision_tp)), | ||
| Err(_) => { | ||
| let normal_tp = NormalLoaderType::from_causal_lm_name(arch)?; | ||
| Ok(Detected::Normal(normal_tp)) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /// Creates the concrete loader on first use, based on the architecture found in `config`. | ||
| /// | ||
| /// Does nothing when a loader is already present. Otherwise the matching builder is | ||
| /// taken out of its slot, built, and the result stored in `self.loader`. | ||
| /// | ||
| /// # Errors | ||
| /// | ||
| /// Fails if detection fails, if the matching builder was already consumed (for example | ||
| /// by an earlier attempt whose build failed), or if building the normal loader fails. | ||
| fn ensure_loader(&self, config: &str) -> Result<()> { | ||
| // NOTE(review): `.lock().unwrap()` matches `std::sync::Mutex`, but this file imports | ||
| // `tokio::sync::Mutex` -- presumably the cause of the CI failures flagged here; confirm. | ||
| let mut guard = self.loader.lock().unwrap(); | ||
|
Check failure on line 184 in mistralrs-core/src/pipeline/auto.rs
|
||
| if guard.is_some() { | ||
| return Ok(()); | ||
| } | ||
| match self.detect(config)? { | ||
| Detected::Normal(tp) => { | ||
| let builder = self | ||
| .normal_builder | ||
| .lock() | ||
| .unwrap() | ||
|
Check failure on line 193 in mistralrs-core/src/pipeline/auto.rs
|
||
| .take() | ||
| .ok_or_else(|| anyhow::anyhow!("normal loader builder was already consumed"))?; | ||
| // Propagate build failures to the caller instead of panicking. | ||
| let loader = builder.build(Some(tp))?; | ||
| *guard = Some(loader); | ||
| } | ||
| Detected::Vision(tp) => { | ||
| let builder = self | ||
| .vision_builder | ||
| .lock() | ||
| .unwrap() | ||
|
Check failure on line 203 in mistralrs-core/src/pipeline/auto.rs
|
||
| .take() | ||
| .ok_or_else(|| anyhow::anyhow!("vision loader builder was already consumed"))?; | ||
| let loader = builder.build(Some(tp)); | ||
| *guard = Some(loader); | ||
| } | ||
| } | ||
| Ok(()) | ||
| } | ||
| } | ||
|
|
||
| impl Loader for AutoLoader { | ||
| /// Fetches the model config from Hugging Face, makes sure the matching loader exists, | ||
| /// and delegates the actual load to it with the same arguments. | ||
| #[allow(clippy::type_complexity, clippy::too_many_arguments)] | ||
| fn load_model_from_hf( | ||
| &self, | ||
| revision: Option<String>, | ||
| token_source: TokenSource, | ||
| dtype: &dyn TryIntoDType, | ||
| device: &Device, | ||
| silent: bool, | ||
| mapper: DeviceMapSetting, | ||
| in_situ_quant: Option<IsqType>, | ||
| paged_attn_config: Option<PagedAttentionConfig>, | ||
| ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> { | ||
| let config_json = self.read_config_from_hf(revision.clone(), &token_source, silent)?; | ||
| self.ensure_loader(&config_json)?; | ||
| // The guard is held for the duration of the delegated load. | ||
| let guard = self.loader.lock().unwrap(); | ||
|
Check failure on line 231 in mistralrs-core/src/pipeline/auto.rs
|
||
| let inner = guard.as_ref().unwrap(); | ||
| inner.load_model_from_hf( | ||
| revision, | ||
| token_source, | ||
| dtype, | ||
| device, | ||
| silent, | ||
| mapper, | ||
| in_situ_quant, | ||
| paged_attn_config, | ||
| ) | ||
| } | ||
|
|
||
| /// Reads the model config from the local `paths`, makes sure the matching loader exists, | ||
| /// and delegates the actual load to it with the same arguments. | ||
| #[allow(clippy::type_complexity, clippy::too_many_arguments)] | ||
| fn load_model_from_path( | ||
| &self, | ||
| paths: &Box<dyn ModelPaths>, | ||
| dtype: &dyn TryIntoDType, | ||
| device: &Device, | ||
| silent: bool, | ||
| mapper: DeviceMapSetting, | ||
| in_situ_quant: Option<IsqType>, | ||
| paged_attn_config: Option<PagedAttentionConfig>, | ||
| ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> { | ||
| let config_json = self.read_config_from_path(paths)?; | ||
| self.ensure_loader(&config_json)?; | ||
| // The guard is held for the duration of the delegated load. | ||
| let guard = self.loader.lock().unwrap(); | ||
|
Check failure on line 261 in mistralrs-core/src/pipeline/auto.rs
|
||
| let inner = guard.as_ref().unwrap(); | ||
| inner.load_model_from_path( | ||
| paths, | ||
| dtype, | ||
| device, | ||
| silent, | ||
| mapper, | ||
| in_situ_quant, | ||
| paged_attn_config, | ||
| ) | ||
| } | ||
|
|
||
| /// Returns an owned copy of the model id this loader was built with. | ||
| fn get_id(&self) -> String { | ||
| self.model_id.to_string() | ||
| } | ||
|
|
||
| /// Returns the kind reported by the concrete loader. | ||
| /// | ||
| /// Until a loader has been created this falls back to `ModelKind::Normal`, even if the | ||
| /// model later turns out to be a vision model. | ||
| fn get_kind(&self) -> ModelKind { | ||
| let guard = self.loader.lock().unwrap(); | ||
|
Check failure on line 282 in mistralrs-core/src/pipeline/auto.rs
|
||
| match guard.as_ref() { | ||
| Some(inner) => inner.get_kind(), | ||
| None => ModelKind::Normal, | ||
| } | ||
| } | ||
|
Comment on lines
+292
to
+299
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Potentially incorrect default behavior in `get_kind()`
The `get_kind` method returns `ModelKind::Normal` whenever the inner loader has not been created yet, so a vision model is reported as a normal model until one of the load methods has run:
fn get_kind(&self) -> ModelKind {
+ // Ensure loader is initialized by reading a minimal config
+ // This is a limitation of the current design where we can't know
+ // the kind without first detecting from config
self.loader
.lock()
.unwrap()
.as_ref()
.map(|l| l.get_kind())
.unwrap_or(ModelKind::Normal)
}
Consider documenting this behavior or restructuring to avoid this ambiguity.
🤖 Prompt for AI Agents |
||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove parameter shadowing in `with_xlora`
The `no_kv_cache` parameter shadows the instance field, which can be confusing. Since you're already setting `self.no_kv_cache` from the parameter, consider whether this parameter is necessary or if it should use the value already set in the builder.
If you need to update `no_kv_cache` separately, consider adding a dedicated method like `with_no_kv_cache(mut self, no_kv_cache: bool) -> Self`.
📝 Committable suggestion
🤖 Prompt for AI Agents