diff --git a/src-tauri/Cargo.toml b/src-tauri/Cargo.toml
index 0f334d1788..ca1a54bbac 100644
--- a/src-tauri/Cargo.toml
+++ b/src-tauri/Cargo.toml
@@ -63,8 +63,12 @@ nix = "=0.30.1"
 
 [target.'cfg(windows)'.dependencies]
 libc = "0.2.172"
+windows-sys = { version = "0.60.2", features = ["Win32_Storage_FileSystem"] }
 
 [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
 tauri-plugin-updater = "2"
 once_cell = "1.18"
 tauri-plugin-single-instance = { version = "2.0.0", features = ["deep-link"] }
+
+[target.'cfg(windows)'.dev-dependencies]
+tempfile = "3.20.0"
diff --git a/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs b/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
index ffa6cfe920..191eb5c6ee 100644
--- a/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
+++ b/src-tauri/src/core/utils/extensions/inference_llamacpp_extension/server.rs
@@ -67,13 +67,56 @@ pub struct DeviceInfo {
     pub free: i32,
 }
 
+#[cfg(windows)]
+use std::os::windows::ffi::OsStrExt;
+
+#[cfg(windows)]
+use std::ffi::OsStr;
+
+#[cfg(windows)]
+use windows_sys::Win32::Storage::FileSystem::GetShortPathNameW;
+
+/// Returns the short path name reported by `GetShortPathNameW` for `path`, or
+/// `None` when the call fails.
+///
+/// The first attempt uses a 260-unit buffer; if the API reports that a larger
+/// buffer is required, the call is retried once with the reported size.
+#[cfg(windows)]
+pub fn get_short_path<P: AsRef<std::path::Path>>(path: P) -> Option<String> {
+    // NUL-terminated UTF-16 copy of the path for the wide-character API.
+    let wide: Vec<u16> = OsStr::new(path.as_ref())
+        .encode_wide()
+        .chain(Some(0))
+        .collect();
+
+    // SAFETY: `wide` is NUL-terminated and outlives the call; `buf` is writable
+    // for exactly the number of UTF-16 units passed as its capacity.
+    let query = |buf: &mut Vec<u16>| unsafe {
+        GetShortPathNameW(wide.as_ptr(), buf.as_mut_ptr(), buf.len() as u32) as usize
+    };
+
+    let mut buffer = vec![0u16; 260];
+    let mut len = query(&mut buffer);
+    if len > buffer.len() {
+        // Buffer too small: the return value is the required size, NUL included.
+        buffer.resize(len, 0);
+        len = query(&mut buffer);
+    }
+
+    // 0 signals failure; a value above the buffer length means nothing usable was copied.
+    if len == 0 || len > buffer.len() {
+        return None;
+    }
+    Some(String::from_utf16_lossy(&buffer[..len]))
+}
+
 // --- Load Command ---
 #[tauri::command]
 pub async fn load_llama_model(
     state: State<'_, AppState>,
     backend_path: &str,
     library_path: Option<&str>,
-    args: Vec<String>,
+    mut args: Vec<String>,
 ) -> ServerResult<SessionInfo> {
     let mut process_map = state.llama_server_process.lock().await;
 
@@ -105,13 +148,38 @@ pub async fn load_llama_model(
             8080
         }
     };
-
-    let model_path = args
+    // Resolve the model path given after `-m`. TODO: handle the mmproj path the same way.
+    let model_path_index = args
         .iter()
         .position(|arg| arg == "-m")
-        .and_then(|i| args.get(i + 1))
-        .cloned()
-        .unwrap_or_default();
+        .ok_or_else(|| ServerError::LlamacppError("Missing `-m` flag".into()))?;
+
+    let model_path = args
+        .get(model_path_index + 1)
+        .ok_or_else(|| ServerError::LlamacppError("Missing path after `-m`".into()))?
+        .clone();
+
+    let model_path_pb = PathBuf::from(model_path);
+    if !model_path_pb.exists() {
+        return Err(ServerError::LlamacppError(format!(
+            "Invalid or inaccessible model path: {}",
+            model_path_pb.display(),
+        )));
+    }
+    #[cfg(windows)]
+    {
+        // use short path on Windows
+        if let Some(short) = get_short_path(&model_path_pb) {
+            args[model_path_index + 1] = short;
+        } else {
+            args[model_path_index + 1] = model_path_pb.display().to_string();
+        }
+    }
+    #[cfg(not(windows))]
+    {
+        args[model_path_index + 1] = model_path_pb.display().to_string();
+    }
+    // -----------------------------------------------------------------
 
     let api_key = args
         .iter()
@@ -181,7 +249,6 @@ pub async fn load_llama_model(
 
     // Create channels for communication between tasks
     let (ready_tx, mut ready_rx) = mpsc::channel::<bool>(1);
-    let (error_tx, mut error_rx) = mpsc::channel::<String>(1);
 
     // Spawn task to monitor stdout for readiness
     let _stdout_task = tokio::spawn(async move {
@@ -228,20 +295,10 @@ pub async fn load_llama_model(
 
-                    // Check for critical error indicators that should stop the process
+                    // Lower-case once so the readiness check below is case-insensitive
                     let line_lower = line.to_string().to_lowercase();
-                    if line_lower.contains("error loading model")
-                        || line_lower.contains("unknown model architecture")
-                        || line_lower.contains("fatal")
-                        || line_lower.contains("cuda error")
-                        || line_lower.contains("out of memory")
-                        || line_lower.contains("error")
-                        || line_lower.contains("failed")
-                    {
-                        let _ = error_tx.send(line.to_string()).await;
-                    }
                     // Check for readiness indicator - llama-server outputs this when ready
-                    else if line.contains("server is listening on")
-                        || line.contains("starting the main loop")
-                        || line.contains("server listening on")
+                    if line_lower.contains("server is listening on")
+                        || line_lower.contains("starting the main loop")
+                        || line_lower.contains("server listening on")
                     {
                         log::info!("Server appears to be ready based on stderr: '{}'", line);
                         let _ = ready_tx.send(true).await;
@@ -279,26 +336,6 @@ pub async fn load_llama_model(
                 log::info!("Server is ready to accept requests!");
                 break;
             }
-            // Error occurred
-            Some(error_msg) = error_rx.recv() => {
-                log::error!("Server encountered an error: {}", error_msg);
-
-                // Give process a moment to exit naturally
-                tokio::time::sleep(Duration::from_millis(100)).await;
-
-                // Check if process already exited
-                if let Some(status) = child.try_wait()? {
-                    log::info!("Process exited with code {:?}", status);
-                    return Err(ServerError::LlamacppError(error_msg));
-                } else {
-                    log::info!("Process still running, killing it...");
-                    let _ = child.kill().await;
-                }
-
-                // Get full stderr output
-                let stderr_output = stderr_task.await.unwrap_or_default();
-                return Err(ServerError::LlamacppError(format!("Error: {}\n\nFull stderr:\n{}", error_msg, stderr_output)));
-            }
             // Check for process exit more frequently
             _ = tokio::time::sleep(Duration::from_millis(50)) => {
                 // Check if process exited
@@ -332,7 +369,7 @@ pub async fn load_llama_model(
         pid: pid.clone(),
         port: port,
         model_id: model_id,
-        model_path: model_path,
+        model_path: model_path_pb.display().to_string(),
         api_key: api_key,
     };
 
@@ -714,6 +751,9 @@ pub fn is_port_available(port: u16) -> bool {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use std::path::PathBuf;
+    #[cfg(windows)]
+    use tempfile;
 
     #[test]
     fn test_parse_multiple_devices() {
@@ -899,4 +939,42 @@ Vulkan1: AMD Radeon Graphics (RADV GFX1151) (87722 MiB, 87722 MiB free)"#;
         let (_start, content) = result.unwrap();
         assert_eq!(content, "8128 MiB, 8128 MiB free");
     }
+    #[test]
+    fn test_path_with_uncommon_dir_names() {
+        const UNCOMMON_DIR_NAME: &str = "тест-你好-éàç-🚀";
+        #[cfg(windows)]
+        {
+            let dir = tempfile::tempdir().expect("Failed to create temp dir");
+            let long_path = dir.path().join(UNCOMMON_DIR_NAME);
+            std::fs::create_dir(&long_path)
+                .expect("Failed to create test directory with non-ASCII name");
+            let short_path = get_short_path(&long_path)
+                .expect("get_short_path should succeed for an existing directory");
+            assert!(
+                short_path.is_ascii(),
+                "The resulting short path must be composed of only ASCII characters. Got: {}",
+                short_path
+            );
+            assert!(
+                PathBuf::from(&short_path).exists(),
+                "The returned short path must exist on the filesystem"
+            );
+            assert_ne!(
+                short_path,
+                long_path.to_str().unwrap(),
+                "Short path should not be the same as the long path"
+            );
+        }
+        #[cfg(not(windows))]
+        {
+            // On Unix, paths are typically UTF-8 and there's no "short path" concept.
+            let long_path_str = format!("/tmp/{}", UNCOMMON_DIR_NAME);
+            let path_buf = PathBuf::from(&long_path_str);
+            let displayed_path = path_buf.display().to_string();
+            assert_eq!(
+                displayed_path, long_path_str,
+                "Path with non-ASCII characters should be preserved exactly on non-Windows platforms"
+            );
+        }
+    }
 }