Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,12 @@ nix = "=0.30.1"

[target.'cfg(windows)'.dependencies]
libc = "0.2.172"
windows-sys = { version = "0.60.2", features = ["Win32_Storage_FileSystem"] }

[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
tauri-plugin-updater = "2"
once_cell = "1.18"
tauri-plugin-single-instance = { version = "2.0.0", features = ["deep-link"] }

[target.'cfg(windows)'.dev-dependencies]
tempfile = "3.20.0"
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,39 @@ pub struct DeviceInfo {
pub free: i32,
}

#[cfg(windows)]
use std::os::windows::ffi::OsStrExt;

#[cfg(windows)]
use std::ffi::OsStr;

#[cfg(windows)]
use windows_sys::Win32::Storage::FileSystem::GetShortPathNameW;

/// Returns the Windows short (8.3) form of `path`, or `None` if the lookup fails.
///
/// The path is converted to a NUL-terminated UTF-16 string and passed to
/// `GetShortPathNameW`. The result is decoded lossily back into a `String`.
///
/// `None` is returned when the API reports failure (a return value of 0), for
/// example when the path does not exist.
#[cfg(windows)]
pub fn get_short_path<P: AsRef<std::path::Path>>(path: P) -> Option<String> {
    // NUL-terminated wide string, as required by the Win32 "W" APIs.
    let wide: Vec<u16> = OsStr::new(path.as_ref())
        .encode_wide()
        .chain(Some(0))
        .collect();

    // Start with MAX_PATH (260) code units; grow below if the API asks for more.
    let mut buffer = vec![0u16; 260];

    loop {
        let capacity = u32::try_from(buffer.len()).ok()?;

        // SAFETY: `wide` is a valid, NUL-terminated UTF-16 buffer that outlives the
        // call, and `buffer` is a writable allocation of exactly `capacity` u16s.
        let len = unsafe { GetShortPathNameW(wide.as_ptr(), buffer.as_mut_ptr(), capacity) }
            as usize;

        if len == 0 {
            return None;
        }

        if len <= buffer.len() {
            // Success: `len` is the number of code units written, excluding the NUL.
            return Some(String::from_utf16_lossy(&buffer[..len]));
        }

        // Buffer too small: `len` is the required size including the terminating
        // NUL. Slicing with it would be out of bounds, so grow and retry instead.
        buffer.resize(len, 0);
    }
}

// --- Load Command ---
#[tauri::command]
pub async fn load_llama_model(
state: State<'_, AppState>,
backend_path: &str,
library_path: Option<&str>,
args: Vec<String>,
mut args: Vec<String>,
) -> ServerResult<SessionInfo> {
let mut process_map = state.llama_server_process.lock().await;

Expand Down Expand Up @@ -105,13 +131,38 @@ pub async fn load_llama_model(
8080
}
};

let model_path = args
// FOR MODEL PATH; TODO: DO SIMILARLY FOR MMPROJ PATH
let model_path_index = args
.iter()
.position(|arg| arg == "-m")
.and_then(|i| args.get(i + 1))
.cloned()
.unwrap_or_default();
.ok_or(ServerError::LlamacppError("Missing `-m` flag".into()))?;

let model_path = args
.get(model_path_index + 1)
.ok_or(ServerError::LlamacppError("Missing path after `-m`".into()))?
.clone();

let model_path_pb = PathBuf::from(model_path);
if !model_path_pb.exists() {
return Err(ServerError::LlamacppError(format!(
"Invalid or inaccessible model path: {}",
model_path_pb.display().to_string(),
)));
}
#[cfg(windows)]
{
// use short path on Windows
if let Some(short) = get_short_path(&model_path_pb) {
args[model_path_index + 1] = short;
} else {
args[model_path_index + 1] = model_path_pb.display().to_string();
}
}
#[cfg(not(windows))]
{
args[model_path_index + 1] = model_path_pb.display().to_string();
}
// -----------------------------------------------------------------

let api_key = args
.iter()
Expand Down Expand Up @@ -181,7 +232,6 @@ pub async fn load_llama_model(

// Create channels for communication between tasks
let (ready_tx, mut ready_rx) = mpsc::channel::<bool>(1);
let (error_tx, mut error_rx) = mpsc::channel::<String>(1);

// Spawn task to monitor stdout for readiness
let _stdout_task = tokio::spawn(async move {
Expand Down Expand Up @@ -228,20 +278,10 @@ pub async fn load_llama_model(

// Check for critical error indicators that should stop the process
let line_lower = line.to_string().to_lowercase();
if line_lower.contains("error loading model")
|| line_lower.contains("unknown model architecture")
|| line_lower.contains("fatal")
|| line_lower.contains("cuda error")
|| line_lower.contains("out of memory")
|| line_lower.contains("error")
|| line_lower.contains("failed")
{
let _ = error_tx.send(line.to_string()).await;
}
// Check for readiness indicator - llama-server outputs this when ready
else if line.contains("server is listening on")
|| line.contains("starting the main loop")
|| line.contains("server listening on")
if line_lower.contains("server is listening on")
|| line_lower.contains("starting the main loop")
|| line_lower.contains("server listening on")
{
log::info!("Server appears to be ready based on stderr: '{}'", line);
let _ = ready_tx.send(true).await;
Expand Down Expand Up @@ -279,26 +319,6 @@ pub async fn load_llama_model(
log::info!("Server is ready to accept requests!");
break;
}
// Error occurred
Some(error_msg) = error_rx.recv() => {
log::error!("Server encountered an error: {}", error_msg);

// Give process a moment to exit naturally
tokio::time::sleep(Duration::from_millis(100)).await;

// Check if process already exited
if let Some(status) = child.try_wait()? {
log::info!("Process exited with code {:?}", status);
return Err(ServerError::LlamacppError(error_msg));
} else {
log::info!("Process still running, killing it...");
let _ = child.kill().await;
}

// Get full stderr output
let stderr_output = stderr_task.await.unwrap_or_default();
return Err(ServerError::LlamacppError(format!("Error: {}\n\nFull stderr:\n{}", error_msg, stderr_output)));
}
// Check for process exit more frequently
_ = tokio::time::sleep(Duration::from_millis(50)) => {
// Check if process exited
Expand Down Expand Up @@ -332,7 +352,7 @@ pub async fn load_llama_model(
pid: pid.clone(),
port: port,
model_id: model_id,
model_path: model_path,
model_path: model_path_pb.display().to_string(),
api_key: api_key,
};

Expand Down Expand Up @@ -714,6 +734,9 @@ pub fn is_port_available(port: u16) -> bool {
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
#[cfg(windows)]
use tempfile;

#[test]
fn test_parse_multiple_devices() {
Expand Down Expand Up @@ -899,4 +922,41 @@ Vulkan1: AMD Radeon Graphics (RADV GFX1151) (87722 MiB, 87722 MiB free)"#;
let (_start, content) = result.unwrap();
assert_eq!(content, "8128 MiB, 8128 MiB free");
}
// Checks how a path containing non-ASCII directory names is handled per platform:
// on Windows it must convert to an existing, ASCII-only short path; elsewhere the
// path must round-trip through `display()` unchanged.
#[test]
fn test_path_with_uncommon_dir_names() {
    const UNCOMMON_DIR_NAME: &str = "тест-你好-éàç-🚀";
    #[cfg(windows)]
    {
        let dir = tempfile::tempdir().expect("Failed to create temp dir");
        let long_path = dir.path().join(UNCOMMON_DIR_NAME);
        std::fs::create_dir(&long_path)
            .expect("Failed to create test directory with non-ASCII name");
        // `get_short_path` returns `Option<String>`; unwrap it so the assertions
        // below operate on the string itself rather than on the `Option`.
        let short_path = get_short_path(&long_path)
            .expect("get_short_path should succeed for an existing directory");
        assert!(
            short_path.is_ascii(),
            "The resulting short path must be composed of only ASCII characters. Got: {}",
            short_path
        );
        assert!(
            PathBuf::from(&short_path).exists(),
            "The returned short path must exist on the filesystem"
        );
        assert_ne!(
            short_path,
            long_path.to_str().unwrap(),
            "Short path should not be the same as the long path"
        );
    }
    #[cfg(not(windows))]
    {
        // On Unix, paths are typically UTF-8 and there's no "short path" concept.
        let long_path_str = format!("/tmp/{}", UNCOMMON_DIR_NAME);
        let path_buf = PathBuf::from(&long_path_str);
        let displayed_path = path_buf.display().to_string();
        assert_eq!(
            displayed_path, long_path_str,
            "Path with non-ASCII characters should be preserved exactly on non-Windows platforms"
        );
    }
}
}
Loading