Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 65 additions & 32 deletions mistralrs-core/src/pipeline/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,65 @@ macro_rules! api_dir_list {
.collect::<Vec<String>>()
.into_iter()
} else {
$api.info()
.map(|repo| {
repo.siblings
.iter()
.map(|x| x.rfilename.clone())
.collect::<Vec<String>>()
})
.unwrap_or_else(|e| {
if $should_panic {
panic!("Could not get directory listing from API: {:?}", e)
} else {
tracing::warn!("Could not get directory listing from API: {:?}", e);
Vec::<String>::new()
}
})
.into_iter()
let sanitized_id = std::path::Path::new($model_id)
.display()
.to_string()
.replace("/", "-");

let home_folder = if dirs::home_dir().is_some() {
let mut path = dirs::home_dir().unwrap();
path.push(".cache/huggingface/hub/");
if !path.exists() {
let _ = std::fs::create_dir_all(&path);
}
path
} else {
"./".into()
};

let cache_dir: std::path::PathBuf = std::env::var("HF_HUB_CACHE")
.map(std::path::PathBuf::from)
.unwrap_or(home_folder.into());
let cache_file = cache_dir.join(format!("{sanitized_id}_repo_list.json"));
if std::path::Path::new(&cache_file).exists() {
use std::io::Read;
// Read from cache
let mut file = std::fs::File::open(&cache_file).expect("Could not open cache file");
let mut contents = String::new();
file.read_to_string(&mut contents)
.expect("Could not read cache file");
let cache: $crate::pipeline::FileListCache =
serde_json::from_str(&contents).expect("Could not parse cache JSON");
tracing::info!("Read from cache file {:?}", cache_file);
cache.files.into_iter()
} else {
$api.info()
.map(|repo| {
let files: Vec<String> = repo
.siblings
.iter()
.map(|x| x.rfilename.clone())
.collect::<Vec<String>>();
// Save to cache
let cache = $crate::pipeline::FileListCache {
files: files.clone(),
};
let json = serde_json::to_string_pretty(&cache)
.expect("Could not serialize cache");
let ret = std::fs::write(&cache_file, json);
tracing::info!("Write to cache file {:?}, {:?}", cache_file, ret);
files
})
.unwrap_or_else(|e| {
if $should_panic {
panic!("Could not get directory listing from API: {:?}", e)
} else {
tracing::warn!("Could not get directory listing from API: {:?}", e);
Vec::<String>::new()
}
})
.into_iter()
}
}
};
}
Expand Down Expand Up @@ -117,10 +160,9 @@ macro_rules! get_paths {
revision.clone(),
$this.xlora_order.as_ref(),
)?;
let gen_conf = if $crate::api_dir_list!(api, model_id, false)
.collect::<Vec<_>>()
.contains(&"generation_config.json".to_string())
{
let dir_list = $crate::api_dir_list!(api, model_id, false).collect::<Vec<_>>();

let gen_conf = if dir_list.contains(&"generation_config.json".to_string()) {
info!("Loading `generation_config.json` at `{}`", $this.model_id);
Some($crate::api_get_file!(
api,
Expand All @@ -130,10 +172,7 @@ macro_rules! get_paths {
} else {
None
};
let preprocessor_config = if $crate::api_dir_list!(api, model_id, false)
.collect::<Vec<_>>()
.contains(&"preprocessor_config.json".to_string())
{
let preprocessor_config = if dir_list.contains(&"preprocessor_config.json".to_string()) {
info!("Loading `preprocessor_config.json` at `{}`", $this.model_id);
Some($crate::api_get_file!(
api,
Expand All @@ -143,10 +182,7 @@ macro_rules! get_paths {
} else {
None
};
let processor_config = if $crate::api_dir_list!(api, model_id, false)
.collect::<Vec<_>>()
.contains(&"processor_config.json".to_string())
{
let processor_config = if dir_list.contains(&"processor_config.json".to_string()) {
info!("Loading `processor_config.json` at `{}`", $this.model_id);
Some($crate::api_get_file!(
api,
Expand All @@ -167,10 +203,7 @@ macro_rules! get_paths {
model_id
))
};
let chat_template_json_filename = if $crate::api_dir_list!(api, model_id, false)
.collect::<Vec<_>>()
.contains(&"chat_template.json".to_string())
{
let chat_template_json_filename = if dir_list.contains(&"chat_template.json".to_string()) {
info!("Loading `chat_template.json` at `{}`", $this.model_id);
Some($crate::api_get_file!(api, "chat_template.json", model_id))
} else {
Expand Down
5 changes: 5 additions & 0 deletions mistralrs-core/src/pipeline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,11 @@ impl ForwardInputsResult {
}
}

/// JSON-serializable record of a model repository's file listing.
///
/// The `api_dir_list!` macro writes one of these to
/// `<cache dir>/<sanitized model id>_repo_list.json` after a successful API
/// directory listing, and reads it back on later calls instead of querying
/// the API again.
#[derive(serde::Serialize, serde::Deserialize)]
pub(crate) struct FileListCache {
// File names taken from the `rfilename` of each sibling the API reported.
files: Vec<String>,
}

#[async_trait::async_trait]
pub trait Pipeline:
Send
Expand Down
13 changes: 9 additions & 4 deletions mistralrs-core/src/pipeline/paths.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ pub fn get_xlora_paths(
revision,
));
let model_id = Path::new(&xlora_id);

let dir_list = api_dir_list!(api, model_id, true).collect::<Vec<_>>();
// Get the path for the xlora classifier
let xlora_classifier = &api_dir_list!(api, model_id, true)
let xlora_classifier = &dir_list
.clone()
.into_iter()
.filter(|x| x.contains("xlora_classifier.safetensors"))
.collect::<Vec<_>>();
if xlora_classifier.len() > 1 {
Comment on lines +81 to 88
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Avoid cloning the full directory listing multiple times

`dir_list.clone().into_iter()` copies the whole `Vec<String>` each time it feeds a filter.
With large repos this recreates thousands of strings on every clone (once for `xlora_classifier` and again for `xlora_configs`).

-let xlora_classifier = &dir_list
-    .clone()
-    .into_iter()
-    .filter(|x| x.contains("xlora_classifier.safetensors"))
-    .collect::<Vec<_>>();
+let xlora_classifier: Vec<&str> = dir_list
+    .iter()
+    .map(String::as_str)
+    .filter(|x| x.contains("xlora_classifier.safetensors"))
+    .collect();

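The same cloning pattern appears for xlora_configs (lines 99-103); adapter_files (lines 142-154) consumes dir_list with into_iter() and does not clone, so it has to remain the final use of the listing.
Switching the two filters to iter() (or doing a single pass that extracts all three groups) removes the redundant allocations without altering behaviour.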

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
let dir_list = api_dir_list!(api, model_id, true).collect::<Vec<_>>();
// Get the path for the xlora classifier
let xlora_classifier = &api_dir_list!(api, model_id, true)
let xlora_classifier = &dir_list
.clone()
.into_iter()
.filter(|x| x.contains("xlora_classifier.safetensors"))
.collect::<Vec<_>>();
if xlora_classifier.len() > 1 {
let dir_list = api_dir_list!(api, model_id, true).collect::<Vec<_>>();
// Get the path for the xlora classifier
- let xlora_classifier = &dir_list
- .clone()
- .into_iter()
- .filter(|x| x.contains("xlora_classifier.safetensors"))
- .collect::<Vec<_>>();
+ let xlora_classifier: Vec<&str> = dir_list
+ .iter()
+ .map(String::as_str)
+ .filter(|x| x.contains("xlora_classifier.safetensors"))
+ .collect();
if xlora_classifier.len() > 1 {
🤖 Prompt for AI Agents
In mistralrs-core/src/pipeline/paths.rs around lines 81 to 88, avoid cloning the
entire dir_list Vec multiple times for filtering, as it causes unnecessary
memory allocations. Replace the clone and into_iter calls with iter() to iterate
over references without copying the Vec or its strings. Apply the same change to
the filtering logic for xlora_configs (lines 99-103) and adapter_files (lines
142-154) to eliminate redundant allocations while preserving behavior.

Expand All @@ -94,7 +96,9 @@ pub fn get_xlora_paths(

// Get the path for the xlora config by checking all for valid versions.
// NOTE(EricLBuehler): Remove this functionality because all configs should be deserializable
let xlora_configs = &api_dir_list!(api, model_id, true)
let xlora_configs = &dir_list
.clone()
.into_iter()
.filter(|x| x.contains("xlora_config.json"))
.collect::<Vec<_>>();
if xlora_configs.len() > 1 {
Expand Down Expand Up @@ -135,7 +139,8 @@ pub fn get_xlora_paths(
});

// If there are adapters in the ordering file, get their names and remote paths
let adapter_files = api_dir_list!(api, model_id, true)
let adapter_files = dir_list
.into_iter()
.filter_map(|name| {
if let Some(ref adapters) = xlora_order.adapters {
for adapter_name in adapters {
Expand Down
Loading