Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mistralrs-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ pub use pipeline::{
MistralLoader, MixtralLoader, ModelKind, ModelPaths, NormalLoader, NormalLoaderBuilder,
NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader,
SpeculativeConfig, SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline,
AutoLoader, AutoLoaderBuilder,
Starcoder2Loader, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
VisionPromptPrefixer, VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
};
Expand Down
288 changes: 288 additions & 0 deletions mistralrs-core/src/pipeline/auto.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
use super::{
NormalLoaderBuilder, VisionLoaderBuilder, NormalSpecificConfig, VisionSpecificConfig,
Loader, ModelPaths, TokenSource, ModelKind, NormalLoaderType, VisionLoaderType,
Ordering,

Check failure on line 4 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Clippy

unresolved import `super::Ordering`

Check failure on line 4 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

unresolved import `super::Ordering`

Check failure on line 4 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

unresolved import `super::Ordering`

Check failure on line 4 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Docs

unresolved import `super::Ordering`
};
use crate::utils::tokens::get_token;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType, Cache};
use anyhow::Result;
use std::sync::Arc;
use tokio::sync::Mutex;
use candle_core::Device;
use crate::{PagedAttentionConfig, Pipeline, TryIntoDType, DeviceMapSetting, IsqType};
use std::path::PathBuf;
use std::path::Path;
use serde::Deserialize;

/// Automatically selects between a normal or vision loader based on the `architectures` field.
pub struct AutoLoader {
model_id: String,
normal_builder: Mutex<Option<NormalLoaderBuilder>>,
vision_builder: Mutex<Option<VisionLoaderBuilder>>,
loader: Mutex<Option<Box<dyn Loader>>>,
hf_cache_path: Option<PathBuf>,
}

pub struct AutoLoaderBuilder {
normal_cfg: NormalSpecificConfig,
vision_cfg: VisionSpecificConfig,
chat_template: Option<String>,
tokenizer_json: Option<String>,
model_id: Option<String>,
jinja_explicit: Option<String>,
no_kv_cache: bool,
xlora_model_id: Option<String>,
xlora_order: Option<Ordering>,
tgt_non_granular_index: Option<usize>,
lora_adapter_ids: Option<Vec<String>>,
hf_cache_path: Option<PathBuf>,
}

impl AutoLoaderBuilder {
#[allow(clippy::too_many_arguments)]
pub fn new(
normal_cfg: NormalSpecificConfig,
vision_cfg: VisionSpecificConfig,
chat_template: Option<String>,
tokenizer_json: Option<String>,
model_id: Option<String>,
no_kv_cache: bool,
jinja_explicit: Option<String>,
) -> Self {
Self {
normal_cfg,
vision_cfg,
chat_template,
tokenizer_json,
model_id,
jinja_explicit,
no_kv_cache,
xlora_model_id: None,
xlora_order: None,
tgt_non_granular_index: None,
lora_adapter_ids: None,
hf_cache_path: None,
}
}

pub fn with_xlora(mut self, model_id: String, order: Ordering, no_kv_cache: bool, tgt_non_granular_index: Option<usize>) -> Self {
self.xlora_model_id = Some(model_id);
self.xlora_order = Some(order);
self.no_kv_cache = no_kv_cache;
self.tgt_non_granular_index = tgt_non_granular_index;
self
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove parameter shadowing in with_xlora

The no_kv_cache parameter shadows the instance field, which can be confusing. Since you're already setting self.no_kv_cache from the parameter, consider whether this parameter is necessary or if it should use the value already set in the builder.

-pub fn with_xlora(mut self, model_id: String, order: Ordering, no_kv_cache: bool, tgt_non_granular_index: Option<usize>) -> Self {
+pub fn with_xlora(mut self, model_id: String, order: Ordering, tgt_non_granular_index: Option<usize>) -> Self {
     self.xlora_model_id = Some(model_id);
     self.xlora_order = Some(order);
-    self.no_kv_cache = no_kv_cache;
     self.tgt_non_granular_index = tgt_non_granular_index;
     self
 }

If you need to update no_kv_cache separately, consider adding a dedicated method like with_no_kv_cache(mut self, no_kv_cache: bool) -> Self.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
pub fn with_xlora(mut self, model_id: String, order: Ordering, no_kv_cache: bool, tgt_non_granular_index: Option<usize>) -> Self {
self.xlora_model_id = Some(model_id);
self.xlora_order = Some(order);
self.no_kv_cache = no_kv_cache;
self.tgt_non_granular_index = tgt_non_granular_index;
self
}
pub fn with_xlora(mut self, model_id: String, order: Ordering, tgt_non_granular_index: Option<usize>) -> Self {
self.xlora_model_id = Some(model_id);
self.xlora_order = Some(order);
self.tgt_non_granular_index = tgt_non_granular_index;
self
}
🤖 Prompt for AI Agents
In mistralrs-core/src/pipeline/auto.rs around lines 68 to 74, the with_xlora
method has a parameter no_kv_cache that shadows the instance field of the same
name, which can cause confusion. To fix this, remove the no_kv_cache parameter
from with_xlora and rely on the existing value set in the builder, or
alternatively, create a separate method named with_no_kv_cache that takes a
boolean and sets self.no_kv_cache. This will clarify the API and avoid parameter
shadowing.


pub fn with_lora(mut self, adapters: Vec<String>) -> Self {
self.lora_adapter_ids = Some(adapters);
self
}

pub fn hf_cache_path(mut self, path: PathBuf) -> Self {
self.hf_cache_path = Some(path);
self
}

pub fn build(self) -> Box<dyn Loader> {
let model_id = self.model_id.expect("model id required");
let mut normal_builder = NormalLoaderBuilder::new(
self.normal_cfg,
self.chat_template.clone(),
self.tokenizer_json.clone(),
Some(model_id.clone()),
self.no_kv_cache,
self.jinja_explicit.clone(),
);
if let (Some(id), Some(ord)) = (self.xlora_model_id.clone(), self.xlora_order.clone()) {
normal_builder = normal_builder.with_xlora(id, ord, self.no_kv_cache, self.tgt_non_granular_index);
}
if let Some(ref adapters) = self.lora_adapter_ids {
normal_builder = normal_builder.with_lora(adapters.clone());
}
if let Some(ref path) = self.hf_cache_path {
normal_builder = normal_builder.hf_cache_path(path.clone());
}

let mut vision_builder = VisionLoaderBuilder::new(
self.vision_cfg,
self.chat_template,
self.tokenizer_json,
Some(model_id.clone()),
self.jinja_explicit,
);
if let Some(ref adapters) = self.lora_adapter_ids {
vision_builder = vision_builder.with_lora(adapters.clone());
}
if let Some(ref path) = self.hf_cache_path {
vision_builder = vision_builder.hf_cache_path(path.clone());
}

Box::new(AutoLoader {
model_id,
normal_builder: Mutex::new(Some(normal_builder)),
vision_builder: Mutex::new(Some(vision_builder)),
loader: Mutex::new(None),
hf_cache_path: self.hf_cache_path,
})
}
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve error handling in the build method

The build() method uses expect() which will panic if model_id is None. For a builder pattern, it's better to return a Result to allow graceful error handling.

-pub fn build(self) -> Box<dyn Loader> {
-    let model_id = self.model_id.expect("model id required");
+pub fn build(self) -> Result<Box<dyn Loader>> {
+    let model_id = self.model_id.ok_or_else(|| anyhow::anyhow!("model_id is required"))?;

Also consider that both normal and vision builders are created even though only one will be used. While this works, it's somewhat inefficient.

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In mistralrs-core/src/pipeline/auto.rs around lines 86 to 128, replace the use
of expect() on model_id with returning a Result to handle the missing model_id
error gracefully. Change the build method signature to return Result<Box<dyn
Loader>, ErrorType> and return an appropriate error if model_id is None.
Additionally, refactor the method to defer creation of normal_builder and
vision_builder until it is clear which one is needed, avoiding unnecessary
construction of both builders to improve efficiency.


#[derive(Deserialize)]
struct AutoConfig {
architectures: Vec<String>,
}

enum Detected {
Normal(NormalLoaderType),
Vision(VisionLoaderType),
}

impl AutoLoader {
fn read_config_from_path(&self, paths: &Box<dyn ModelPaths>) -> Result<String> {
Ok(std::fs::read_to_string(paths.get_config_filename())?)
}

fn read_config_from_hf(
&self,
revision: Option<String>,
token_source: &TokenSource,
silent: bool,
) -> Result<String> {
let cache = self.hf_cache_path.clone().map(Cache::new).unwrap_or_default();
let mut api = ApiBuilder::from_cache(cache)
.with_progress(!silent)
.with_token(get_token(token_source)?);
if let Ok(x) = std::env::var("HF_HUB_CACHE") {
api = api.with_cache_dir(x.into());
}
let api = api.build()?;
let revision = revision.unwrap_or_else(|| "main".to_string());
let api = api.repo(Repo::with_revision(
self.model_id.clone(),
RepoType::Model,
revision,
));
let model_id = Path::new(&self.model_id);
let config_filename = api_get_file!(api, "config.json", model_id);

Check failure on line 166 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Clippy

cannot find macro `api_get_file` in this scope

Check failure on line 166 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

cannot find macro `api_get_file` in this scope

Check failure on line 166 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

cannot find macro `api_get_file` in this scope

Check failure on line 166 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Docs

cannot find macro `api_get_file` in this scope
Ok(std::fs::read_to_string(config_filename)?)
}

fn detect(&self, config: &str) -> Result<Detected> {
let cfg: AutoConfig = serde_json::from_str(config)?;
if cfg.architectures.len() != 1 {
anyhow::bail!("Expected exactly one architecture in config");
}
let name = &cfg.architectures[0];
if let Ok(tp) = VisionLoaderType::from_causal_lm_name(name) {
return Ok(Detected::Vision(tp));
}
let tp = NormalLoaderType::from_causal_lm_name(name)?;
Ok(Detected::Normal(tp))
}

fn ensure_loader(&self, config: &str) -> Result<()> {
let mut guard = self.loader.lock().unwrap();

Check failure on line 184 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Clippy

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope

Check failure on line 184 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope

Check failure on line 184 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope

Check failure on line 184 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Docs

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope
if guard.is_some() {
return Ok(());
}
match self.detect(config)? {
Detected::Normal(tp) => {
let builder = self
.normal_builder
.lock()
.unwrap()

Check failure on line 193 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Clippy

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<NormalLoaderBuilder>>>` in the current scope

Check failure on line 193 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<NormalLoaderBuilder>>>` in the current scope

Check failure on line 193 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<NormalLoaderBuilder>>>` in the current scope

Check failure on line 193 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Docs

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<NormalLoaderBuilder>>>` in the current scope
.take()
.expect("builder taken");
let loader = builder.build(Some(tp)).expect("build normal");
*guard = Some(loader);
}
Detected::Vision(tp) => {
let builder = self
.vision_builder
.lock()
.unwrap()

Check failure on line 203 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Clippy

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<VisionLoaderBuilder>>>` in the current scope

Check failure on line 203 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<VisionLoaderBuilder>>>` in the current scope

Check failure on line 203 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<VisionLoaderBuilder>>>` in the current scope

Check failure on line 203 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Docs

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<VisionLoaderBuilder>>>` in the current scope
.take()
.expect("builder taken");
let loader = builder.build(Some(tp));
*guard = Some(loader);
}
}
Ok(())
}
}

impl Loader for AutoLoader {
#[allow(clippy::type_complexity, clippy::too_many_arguments)]
fn load_model_from_hf(
&self,
revision: Option<String>,
token_source: TokenSource,
dtype: &dyn TryIntoDType,
device: &Device,
silent: bool,
mapper: DeviceMapSetting,
in_situ_quant: Option<IsqType>,
paged_attn_config: Option<PagedAttentionConfig>,
) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
let config = self.read_config_from_hf(revision.clone(), &token_source, silent)?;
self.ensure_loader(&config)?;
self.loader
.lock()
.unwrap()

Check failure on line 231 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Clippy

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope

Check failure on line 231 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope

Check failure on line 231 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope

Check failure on line 231 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Docs

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope
.as_ref()
.unwrap()
.load_model_from_hf(
revision,
token_source,
dtype,
device,
silent,
mapper,
in_situ_quant,
paged_attn_config,
)
}

#[allow(clippy::type_complexity, clippy::too_many_arguments)]
fn load_model_from_path(
&self,
paths: &Box<dyn ModelPaths>,
dtype: &dyn TryIntoDType,
device: &Device,
silent: bool,
mapper: DeviceMapSetting,
in_situ_quant: Option<IsqType>,
paged_attn_config: Option<PagedAttentionConfig>,
) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>> {
let config = self.read_config_from_path(paths)?;
self.ensure_loader(&config)?;
self.loader
.lock()
.unwrap()

Check failure on line 261 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Clippy

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope

Check failure on line 261 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope

Check failure on line 261 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope

Check failure on line 261 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Docs

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope
.as_ref()
.unwrap()
.load_model_from_path(
paths,
dtype,
device,
silent,
mapper,
in_situ_quant,
paged_attn_config,
)
}

fn get_id(&self) -> String {
self.model_id.clone()
}

fn get_kind(&self) -> ModelKind {
self.loader
.lock()
.unwrap()

Check failure on line 282 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Clippy

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope

Check failure on line 282 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Test Suite (ubuntu-latest, stable)

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope

Check failure on line 282 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Check (macOS-latest, stable)

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope

Check failure on line 282 in mistralrs-core/src/pipeline/auto.rs

View workflow job for this annotation

GitHub Actions / Docs

no method named `unwrap` found for opaque type `impl Future<Output = MutexGuard<'_, Option<Box<dyn Loader>>>>` in the current scope
.as_ref()
.map(|l| l.get_kind())
.unwrap_or(ModelKind::Normal)
}
Comment on lines +292 to +299
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Potentially incorrect default behavior in get_kind()

The get_kind() method returns ModelKind::Normal as a default when no loader is initialized. This might not be correct - consider ensuring the loader is initialized first or returning an Option/Result.

 fn get_kind(&self) -> ModelKind {
+    // Ensure loader is initialized by reading a minimal config
+    // This is a limitation of the current design where we can't know
+    // the kind without first detecting from config
     self.loader
         .lock()
         .unwrap()
         .as_ref()
         .map(|l| l.get_kind())
         .unwrap_or(ModelKind::Normal)
 }

Consider documenting this behavior or restructuring to avoid this ambiguity.

🤖 Prompt for AI Agents
In mistralrs-core/src/pipeline/auto.rs around lines 279 to 286, the get_kind()
method currently returns ModelKind::Normal by default if the loader is not
initialized, which may be misleading. Modify the method to return an
Option<ModelKind> or a Result<ModelKind, Error> instead, reflecting the
possibility that the loader might be uninitialized. Update the method signature
accordingly and handle the uninitialized case explicitly by returning None or an
appropriate error. Additionally, document this behavior clearly to avoid
ambiguity.

}

2 changes: 2 additions & 0 deletions mistralrs-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod sampling;
mod speculative;
mod speech;
mod vision;
mod auto;

pub use super::diffusion_models::DiffusionGenerationParams;
use crate::amoe::{AnyMoeConfig, AnyMoeExpertType, AnyMoeTrainingInputs, AnyMoeTrainingResult};
Expand Down Expand Up @@ -60,6 +61,7 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use tokenizers::Tokenizer;
pub use vision::{VisionLoader, VisionLoaderBuilder, VisionSpecificConfig};
pub use auto::{AutoLoader, AutoLoaderBuilder};

use anyhow::Result;
use candle_core::{DType, Device, IndexOp, Tensor, Var};
Expand Down
Loading