diff --git a/mistralrs-core/src/pipeline/macros.rs b/mistralrs-core/src/pipeline/macros.rs index 2fee07dde1..474ce6eceb 100644 --- a/mistralrs-core/src/pipeline/macros.rs +++ b/mistralrs-core/src/pipeline/macros.rs @@ -22,22 +22,65 @@ macro_rules! api_dir_list { .collect::>() .into_iter() } else { - $api.info() - .map(|repo| { - repo.siblings - .iter() - .map(|x| x.rfilename.clone()) - .collect::>() - }) - .unwrap_or_else(|e| { - if $should_panic { - panic!("Could not get directory listing from API: {:?}", e) - } else { - tracing::warn!("Could not get directory listing from API: {:?}", e); - Vec::::new() - } - }) - .into_iter() + let sanitized_id = std::path::Path::new($model_id) + .display() + .to_string() + .replace("/", "-"); + + let home_folder = if dirs::home_dir().is_some() { + let mut path = dirs::home_dir().unwrap(); + path.push(".cache/huggingface/hub/"); + if !path.exists() { + let _ = std::fs::create_dir_all(&path); + } + path + } else { + "./".into() + }; + + let cache_dir: std::path::PathBuf = std::env::var("HF_HUB_CACHE") + .map(std::path::PathBuf::from) + .unwrap_or(home_folder.into()); + let cache_file = cache_dir.join(format!("{sanitized_id}_repo_list.json")); + if std::path::Path::new(&cache_file).exists() { + use std::io::Read; + // Read from cache + let mut file = std::fs::File::open(&cache_file).expect("Could not open cache file"); + let mut contents = String::new(); + file.read_to_string(&mut contents) + .expect("Could not read cache file"); + let cache: $crate::pipeline::FileListCache = + serde_json::from_str(&contents).expect("Could not parse cache JSON"); + tracing::info!("Read from cache file {:?}", cache_file); + cache.files.into_iter() + } else { + $api.info() + .map(|repo| { + let files: Vec = repo + .siblings + .iter() + .map(|x| x.rfilename.clone()) + .collect::>(); + // Save to cache + let cache = $crate::pipeline::FileListCache { + files: files.clone(), + }; + let json = serde_json::to_string_pretty(&cache) + .expect("Could not serialize cache"); + let ret = std::fs::write(&cache_file, json); + tracing::info!("Write to cache file {:?}, {:?}", cache_file, ret); + files + }) + .unwrap_or_else(|e| { + if $should_panic { + panic!("Could not get directory listing from API: {:?}", e) + } else { + tracing::warn!("Could not get directory listing from API: {:?}", e); + Vec::::new() + } + }) + .into_iter() + } } }; } @@ -117,10 +160,9 @@ macro_rules! get_paths { revision.clone(), $this.xlora_order.as_ref(), )?; - let gen_conf = if $crate::api_dir_list!(api, model_id, false) - .collect::>() - .contains(&"generation_config.json".to_string()) - { + let dir_list = $crate::api_dir_list!(api, model_id, false).collect::>(); + + let gen_conf = if dir_list.contains(&"generation_config.json".to_string()) { info!("Loading `generation_config.json` at `{}`", $this.model_id); Some($crate::api_get_file!( api, @@ -130,10 +172,7 @@ macro_rules! get_paths { } else { None }; - let preprocessor_config = if $crate::api_dir_list!(api, model_id, false) - .collect::>() - .contains(&"preprocessor_config.json".to_string()) - { + let preprocessor_config = if dir_list.contains(&"preprocessor_config.json".to_string()) { info!("Loading `preprocessor_config.json` at `{}`", $this.model_id); Some($crate::api_get_file!( api, @@ -143,10 +182,7 @@ macro_rules! get_paths { } else { None }; - let processor_config = if $crate::api_dir_list!(api, model_id, false) - .collect::>() - .contains(&"processor_config.json".to_string()) - { + let processor_config = if dir_list.contains(&"processor_config.json".to_string()) { info!("Loading `processor_config.json` at `{}`", $this.model_id); Some($crate::api_get_file!( api, @@ -167,10 +203,7 @@ macro_rules! get_paths { model_id )) }; - let chat_template_json_filename = if $crate::api_dir_list!(api, model_id, false) - .collect::>() - .contains(&"chat_template.json".to_string()) - { + let chat_template_json_filename = if dir_list.contains(&"chat_template.json".to_string()) { info!("Loading `chat_template.json` at `{}`", $this.model_id); Some($crate::api_get_file!(api, "chat_template.json", model_id)) } else { diff --git a/mistralrs-core/src/pipeline/mod.rs b/mistralrs-core/src/pipeline/mod.rs index 8272f3559a..94f4ea3e18 100644 --- a/mistralrs-core/src/pipeline/mod.rs +++ b/mistralrs-core/src/pipeline/mod.rs @@ -360,6 +360,11 @@ impl ForwardInputsResult { } } +#[derive(serde::Serialize, serde::Deserialize)] +pub(crate) struct FileListCache { + files: Vec, +} + #[async_trait::async_trait] pub trait Pipeline: Send diff --git a/mistralrs-core/src/pipeline/paths.rs b/mistralrs-core/src/pipeline/paths.rs index f63ed71e5b..c25f05825a 100644 --- a/mistralrs-core/src/pipeline/paths.rs +++ b/mistralrs-core/src/pipeline/paths.rs @@ -78,9 +78,11 @@ pub fn get_xlora_paths( revision, )); let model_id = Path::new(&xlora_id); - + let dir_list = api_dir_list!(api, model_id, true).collect::>(); // Get the path for the xlora classifier - let xlora_classifier = &api_dir_list!(api, model_id, true) + let xlora_classifier = &dir_list + .clone() + .into_iter() .filter(|x| x.contains("xlora_classifier.safetensors")) .collect::>(); if xlora_classifier.len() > 1 { @@ -94,7 +96,9 @@ pub fn get_xlora_paths( // Get the path for the xlora config by checking all for valid versions. // NOTE(EricLBuehler): Remove this functionality because all configs should be deserializable - let xlora_configs = &api_dir_list!(api, model_id, true) + let xlora_configs = &dir_list + .clone() + .into_iter() .filter(|x| x.contains("xlora_config.json")) .collect::>(); if xlora_configs.len() > 1 { @@ -135,7 +139,8 @@ pub fn get_xlora_paths( }); // If there are adapters in the ordering file, get their names and remote paths - let adapter_files = api_dir_list!(api, model_id, true) + let adapter_files = dir_list + .into_iter() .filter_map(|name| { if let Some(ref adapters) = xlora_order.adapters { for adapter_name in adapters {