From 24193e194dae75e7d351eb5515788cfe7070636d Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Mon, 16 Jun 2025 11:39:28 +0800 Subject: [PATCH 1/4] Remove duplicate calls for api_dir_list --- mistralrs-core/src/pipeline/macros.rs | 22 ++++++---------------- mistralrs-core/src/pipeline/paths.rs | 13 +++++++++---- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index 2fee07dde1..b12ac48ce2 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -117,10 +117,9 @@ macro_rules! get_paths { revision.clone(), $this.xlora_order.as_ref(), )?; - let gen_conf = if $crate::api_dir_list!(api, model_id, false) - .collect::>() - .contains(&"generation_config.json".to_string()) - { + let dir_list = $crate::api_dir_list!(api, model_id, false).collect::>(); + + let gen_conf = if dir_list.contains(&"generation_config.json".to_string()) { info!("Loading `generation_config.json` at `{}`", $this.model_id); Some($crate::api_get_file!( api, @@ -130,10 +129,7 @@ macro_rules! get_paths { } else { None }; - let preprocessor_config = if $crate::api_dir_list!(api, model_id, false) - .collect::>() - .contains(&"preprocessor_config.json".to_string()) - { + let preprocessor_config = if dir_list.contains(&"preprocessor_config.json".to_string()) { info!("Loading `preprocessor_config.json` at `{}`", $this.model_id); Some($crate::api_get_file!( api, @@ -143,10 +139,7 @@ macro_rules! get_paths { } else { None }; - let processor_config = if $crate::api_dir_list!(api, model_id, false) - .collect::>() - .contains(&"processor_config.json".to_string()) - { + let processor_config = if dir_list.contains(&"processor_config.json".to_string()) { info!("Loading `processor_config.json` at `{}`", $this.model_id); Some($crate::api_get_file!( api, @@ -167,10 +160,7 @@ macro_rules! get_paths { model_id )) }; - let chat_template_json_filename = if $crate::api_dir_list!(api, model_id, false) - .collect::>() - .contains(&"chat_template.json".to_string()) - { + let chat_template_json_filename = if dir_list.contains(&"chat_template.json".to_string()) { info!("Loading `chat_template.json` at `{}`", $this.model_id); Some($crate::api_get_file!(api, "chat_template.json", model_id)) } else { diff --git a/mistralrs-core/src/pipeline/paths.rs b/mistralrs-core/src/pipeline/paths.rs index f63ed71e5b..c25f05825a 100644 --- a/mistralrs-core/src/pipeline/paths.rs +++ b/mistralrs-core/src/pipeline/paths.rs @@ -78,9 +78,11 @@ pub fn get_xlora_paths( revision, )); let model_id = Path::new(&xlora_id); - + let dir_list = api_dir_list!(api, model_id, true).collect::>(); // Get the path for the xlora classifier - let xlora_classifier = &api_dir_list!(api, model_id, true) + let xlora_classifier = &dir_list + .clone() + .into_iter() .filter(|x| x.contains("xlora_classifier.safetensors")) .collect::>(); if xlora_classifier.len() > 1 { @@ -94,7 +96,9 @@ pub fn get_xlora_paths( // Get the path for the xlora config by checking all for valid versions. // NOTE(EricLBuehler): Remove this functionality because all configs should be deserializable - let xlora_configs = &api_dir_list!(api, model_id, true) + let xlora_configs = &dir_list + .clone() + .into_iter() .filter(|x| x.contains("xlora_config.json")) .collect::>(); if xlora_configs.len() > 1 { @@ -135,7 +139,8 @@ pub fn get_xlora_paths( }); // If there are adapters in the ordering file, get their names and remote paths - let adapter_files = api_dir_list!(api, model_id, true) + let adapter_files = dir_list + .into_iter() .filter_map(|name| { if let Some(ref adapters) = xlora_order.adapters { for adapter_name in adapters { From c7402f66dc5e4a4789bc6275eae542ffad1e371a Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 17 Jun 2025 11:46:00 +0800 Subject: [PATCH 2/4] Support local cache for api_dir_list --- mistralrs-core/src/pipeline/macros.rs | 63 ++++++++++++++++++++------- mistralrs-core/src/pipeline/mod.rs | 5 +++ 2 files changed, 52 insertions(+), 16 deletions(-) diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index b12ac48ce2..f993b226a9 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -22,22 +22,53 @@ macro_rules! api_dir_list { .collect::>() .into_iter() } else { - $api.info() - .map(|repo| { - repo.siblings - .iter() - .map(|x| x.rfilename.clone()) - .collect::>() - }) - .unwrap_or_else(|e| { - if $should_panic { - panic!("Could not get directory listing from API: {:?}", e) - } else { - tracing::warn!("Could not get directory listing from API: {:?}", e); - Vec::::new() - } - }) - .into_iter() + let sanitized_id = std::path::Path::new($model_id) + .display() + .to_string() + .replace("/", "-"); + let cache_dir: std::path::PathBuf = std::env::var("HF_HUB_CACHE") + .map(std::path::PathBuf::from) + .unwrap_or("~/.cache/huggingface/hub/".into()); + let cache_file = cache_dir.join(format!("{sanitized_id}_repo_list.json")); + if std::path::Path::new(&cache_file).exists() { + use std::io::Read; + // Read from cache + let mut file = std::fs::File::open(&cache_file).expect("Could not open cache file"); + let mut contents = String::new(); + file.read_to_string(&mut contents) + .expect("Could not read cache file"); + let cache: $crate::pipeline::FileListCache = + serde_json::from_str(&contents).expect("Could not parse cache JSON"); + tracing::info!("read from cache file {:?}", cache_file); + cache.files.into_iter() + } else { + $api.info() + .map(|repo| { + let files: Vec = repo + .siblings + .iter() + .map(|x| x.rfilename.clone()) + .collect::>(); + // Save to cache + let cache = $crate::pipeline::FileListCache { + files: files.clone(), + }; + let json = serde_json::to_string_pretty(&cache) + .expect("Could not serialize cache"); + let _ = std::fs::write(&cache_file, json); + tracing::info!("write to cache file {:?}", cache_file); + files + }) + .unwrap_or_else(|e| { + if $should_panic { + panic!("Could not get directory listing from API: {:?}", e) + } else { + tracing::warn!("Could not get directory listing from API: {:?}", e); + Vec::::new() + } + }) + .into_iter() + } } }; } diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 8272f3559a..94f4ea3e18 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -360,6 +360,11 @@ impl ForwardInputsResult { } } +#[derive(serde::Serialize, serde::Deserialize)] +pub(crate) struct FileListCache { + files: Vec, +} + #[async_trait::async_trait] pub trait Pipeline: Send From d6fdf01e199231f29e3f8880cc00a3ff3d1e469a Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 18 Jun 2025 21:30:24 +0800 Subject: [PATCH 3/4] Fix home folder for metal --- mistralrs-core/src/pipeline/macros.rs | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index f993b226a9..29516799fe 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -26,9 +26,21 @@ macro_rules! api_dir_list { .display() .to_string() .replace("/", "-"); + + let home_folder = if dirs::home_dir().is_some() { + let mut path = dirs::home_dir().unwrap(); + path.push(".cache/huggingface/hub/"); + if !path.exists() { + let _ = std::fs::create_dir_all(&path); + } + path + } else { + "./".into() + }; + let cache_dir: std::path::PathBuf = std::env::var("HF_HUB_CACHE") .map(std::path::PathBuf::from) - .unwrap_or("~/.cache/huggingface/hub/".into()); + .unwrap_or(home_folder.into()); let cache_file = cache_dir.join(format!("{sanitized_id}_repo_list.json")); if std::path::Path::new(&cache_file).exists() { use std::io::Read; @@ -55,8 +67,8 @@ macro_rules! api_dir_list { }; let json = serde_json::to_string_pretty(&cache) .expect("Could not serialize cache"); - let _ = std::fs::write(&cache_file, json); - tracing::info!("write to cache file {:?}", cache_file); + let ret = std::fs::write(&cache_file, json); + tracing::info!("write to cache file {:?}, {:?}", cache_file, ret); files }) .unwrap_or_else(|e| { From c13f9c2e104fdb5acdeb55ef4db4ab0c72a37c05 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 18 Jun 2025 21:58:37 +0800 Subject: [PATCH 4/4] Capitalized --- mistralrs-core/src/pipeline/macros.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index 29516799fe..474ce6eceb 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -51,7 +51,7 @@ macro_rules! api_dir_list { .expect("Could not read cache file"); let cache: $crate::pipeline::FileListCache = serde_json::from_str(&contents).expect("Could not parse cache JSON"); - tracing::info!("read from cache file {:?}", cache_file); + tracing::info!("Read from cache file {:?}", cache_file); cache.files.into_iter() } else { $api.info() @@ -68,7 +68,7 @@ macro_rules! api_dir_list { let json = serde_json::to_string_pretty(&cache) .expect("Could not serialize cache"); let ret = std::fs::write(&cache_file, json); - tracing::info!("write to cache file {:?}, {:?}", cache_file, ret); + tracing::info!("Write to cache file {:?}, {:?}", cache_file, ret); files }) .unwrap_or_else(|e| {