
Commit 7a4383d

refactor: Simplify Tauri plugin calls and enhance 'Flash Attention' setting
This commit introduces significant improvements to the llama.cpp extension, focusing on the 'Flash Attention' setting and refactoring Tauri plugin interactions for better code clarity and maintenance. The backend interaction is streamlined by removing the unnecessary `libraryPath` argument from the Tauri plugin commands for loading models and listing devices.

* **Simplified API Calls:** The `loadLlamaModel`, `unloadLlamaModel`, and `get_devices` functions in both the extension and the Tauri plugin now manage the library path internally, based on the backend executable's location.
* **Decoupled Logic:** The extension (`src/index.ts`) now uses the new, simplified Tauri plugin functions, which enhances modularity and reduces boilerplate code in the extension.
* **Type Consistency:** Added an `UnloadResult` interface to `guest-js/index.ts` for consistency.
* **Updated UI Control:** The 'Flash Attention' setting in `settings.json` is changed from a boolean checkbox to a string-based dropdown offering **'auto'**, **'on'**, and **'off'** options.
* **Improved Logic:** The extension logic in `src/index.ts` is updated to handle the new string-based `flash_attn` configuration: the string value (`'auto'`, `'on'`, or `'off'`) is passed directly as a command-line argument to the llama.cpp backend, and the version-checking logic previously required for older llama.cpp versions is removed (see the sketch below).

This refactoring cleans up the extension's codebase and moves environment and path setup concerns into the Tauri plugin, where they are most relevant.
1 parent 746dbc6 commit 7a4383d
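
For illustration, a minimal sketch of the new flag handling described above. The `FlashAttn` type and `buildFlashAttnArgs` helper are hypothetical names, not part of this commit; the actual change lives in `extensions/llamacpp-extension/src/index.ts` in the diff below.

```typescript
// Hypothetical sketch: how the string-valued setting maps to server arguments.
// The dropdown value from settings.json is forwarded to llama.cpp unchanged.
type FlashAttn = 'auto' | 'on' | 'off'

function buildFlashAttnArgs(flashAttn?: FlashAttn): string[] {
  // No value configured: let llama.cpp use its own default ('auto').
  if (flashAttn === undefined) return []
  // No version checks are needed; the value is passed straight through.
  return ['--flash-attn', flashAttn]
}

// buildFlashAttnArgs('off') -> ['--flash-attn', 'off']
```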

6 files changed: +45, -51 lines

extensions/llamacpp-extension/settings.json

Lines changed: 7 additions & 2 deletions
```diff
@@ -149,9 +149,14 @@
       "key": "flash_attn",
       "title": "Flash Attention",
       "description": "Enable Flash Attention for optimized performance.",
-      "controllerType": "checkbox",
+      "controllerType": "dropdown",
       "controllerProps": {
-        "value": false
+        "value": "auto",
+        "options": [
+          { "value": "auto", "name": "Auto" },
+          { "value": "on", "name": "ON" },
+          { "value": "off", "name": "OFF" }
+        ]
       }
     },
     {
```
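
Since the stored setting is now a plain string, the consuming extension may want to narrow it to the three dropdown options. A minimal, hedged sketch; the helper name and its fallback behaviour are assumptions, not part of this commit:

```typescript
// Hypothetical helper: narrow an arbitrary stored value to a dropdown option,
// falling back to the default declared in settings.json ('auto').
const FLASH_ATTN_OPTIONS = ['auto', 'on', 'off'] as const
type FlashAttnOption = (typeof FLASH_ATTN_OPTIONS)[number]

function normalizeFlashAttn(value: unknown): FlashAttnOption {
  return FLASH_ATTN_OPTIONS.includes(value as FlashAttnOption)
    ? (value as FlashAttnOption)
    : 'auto'
}
```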

extensions/llamacpp-extension/src/index.ts

Lines changed: 15 additions & 26 deletions
```diff
@@ -38,10 +38,12 @@ import { invoke } from '@tauri-apps/api/core'
 import { getProxyConfig } from './util'
 import { basename } from '@tauri-apps/api/path'
 import {
+  loadLlamaModel,
   readGgufMetadata,
   getModelSize,
   isModelSupported,
   planModelLoadInternal,
+  unloadLlamaModel,
 } from '@janhq/tauri-plugin-llamacpp-api'
 import { getSystemUsage, getSystemInfo } from '@janhq/tauri-plugin-hardware-api'
 
@@ -69,7 +71,7 @@ type LlamacppConfig = {
   device: string
   split_mode: string
   main_gpu: number
-  flash_attn: boolean
+  flash_attn: string
   cont_batching: boolean
   no_mmap: boolean
   mlock: boolean
@@ -1646,14 +1648,11 @@ export default class llamacpp_extension extends AIEngine {
       args.push('--split-mode', cfg.split_mode)
     if (cfg.main_gpu !== undefined && cfg.main_gpu != 0)
       args.push('--main-gpu', String(cfg.main_gpu))
+    // Note: Older llama.cpp versions are no longer supported
+    if (cfg.flash_attn !== undefined || cfg.flash_attn === '') args.push('--flash-attn', String(cfg.flash_attn)) //default: auto = ON when supported
 
     // Boolean flags
     if (cfg.ctx_shift) args.push('--context-shift')
-    if (Number(version.replace(/^b/, '')) >= 6325) {
-      if (!cfg.flash_attn) args.push('--flash-attn', 'off') //default: auto = ON when supported
-    } else {
-      if (cfg.flash_attn) args.push('--flash-attn')
-    }
     if (cfg.cont_batching) args.push('--cont-batching')
     args.push('--no-mmap')
     if (cfg.mlock) args.push('--mlock')
@@ -1688,20 +1687,9 @@ export default class llamacpp_extension extends AIEngine {
 
     logger.info('Calling Tauri command llama_load with args:', args)
     const backendPath = await getBackendExePath(backend, version)
-    const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])
 
     try {
-      // TODO: add LIBRARY_PATH
-      const sInfo = await invoke<SessionInfo>(
-        'plugin:llamacpp|load_llama_model',
-        {
-          backendPath,
-          libraryPath,
-          args,
-          envs,
-          isEmbedding,
-        }
-      )
+      const sInfo = await loadLlamaModel(backendPath, args, envs)
       return sInfo
     } catch (error) {
       logger.error('Error in load command:\n', error)
@@ -1717,12 +1705,7 @@ export default class llamacpp_extension extends AIEngine {
     const pid = sInfo.pid
     try {
       // Pass the PID as the session_id
-      const result = await invoke<UnloadResult>(
-        'plugin:llamacpp|unload_llama_model',
-        {
-          pid: pid,
-        }
-      )
+      const result = await unloadLlamaModel(pid)
 
       // If successful, remove from active sessions
       if (result.success) {
@@ -2042,7 +2025,10 @@ export default class llamacpp_extension extends AIEngine {
     if (sysInfo?.os_type === 'linux' && Array.isArray(sysInfo.gpus)) {
       const usage = await getSystemUsage()
       if (usage && Array.isArray(usage.gpus)) {
-        const uuidToUsage: Record<string, { total_memory: number; used_memory: number }> = {}
+        const uuidToUsage: Record<
+          string,
+          { total_memory: number; used_memory: number }
+        > = {}
         for (const u of usage.gpus as any[]) {
           if (u && typeof u.uuid === 'string') {
             uuidToUsage[u.uuid] = u
@@ -2082,7 +2068,10 @@ export default class llamacpp_extension extends AIEngine {
           typeof u.used_memory === 'number'
         ) {
           const total = Math.max(0, Math.floor(u.total_memory))
-          const free = Math.max(0, Math.floor(u.total_memory - u.used_memory))
+          const free = Math.max(
+            0,
+            Math.floor(u.total_memory - u.used_memory)
+          )
           return { ...dev, mem: total, free }
         }
       }
```

src-tauri/plugins/tauri-plugin-llamacpp/guest-js/index.ts

Lines changed: 15 additions & 11 deletions
```diff
@@ -2,11 +2,17 @@ import { invoke } from '@tauri-apps/api/core'
 
 // Types
 export interface SessionInfo {
-  pid: number
-  port: number
-  model_id: string
-  model_path: string
-  api_key: string
+  pid: number;
+  port: number;
+  model_id: string;
+  model_path: string;
+  api_key: string;
+  mmproj_path?: string;
+}
+
+export interface UnloadResult {
+  success: boolean;
+  error?: string;
 }
 
 export interface DeviceInfo {
@@ -29,19 +35,17 @@ export async function cleanupLlamaProcesses(): Promise<void> {
 // LlamaCpp server commands
 export async function loadLlamaModel(
   backendPath: string,
-  libraryPath?: string,
-  args: string[] = [],
-  isEmbedding: boolean = false
+  args: string[],
+  envs: Record<string, string>
 ): Promise<SessionInfo> {
   return await invoke('plugin:llamacpp|load_llama_model', {
     backendPath,
-    libraryPath,
     args,
-    isEmbedding,
+    envs
   })
 }
 
-export async function unloadLlamaModel(pid: number): Promise<void> {
+export async function unloadLlamaModel(pid: number): Promise<UnloadResult> {
   return await invoke('plugin:llamacpp|unload_llama_model', { pid })
 }
 
```
src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs

Lines changed: 4 additions & 6 deletions
```diff
@@ -41,7 +41,6 @@ pub struct UnloadResult {
 pub async fn load_llama_model<R: Runtime>(
     app_handle: tauri::AppHandle<R>,
     backend_path: &str,
-    library_path: Option<&str>,
     mut args: Vec<String>,
     envs: HashMap<String, String>,
     is_embedding: bool,
@@ -52,7 +51,7 @@
     log::info!("Attempting to launch server at path: {:?}", backend_path);
     log::info!("Using arguments: {:?}", args);
 
-    validate_binary_path(backend_path)?;
+    let bin_path = validate_binary_path(backend_path)?;
 
     let port = parse_port_from_args(&args);
     let model_path_pb = validate_model_path(&mut args)?;
@@ -83,11 +82,11 @@
     let model_id = extract_arg_value(&args, "-a");
 
     // Configure the command to run the server
-    let mut command = Command::new(backend_path);
+    let mut command = Command::new(&bin_path);
     command.args(args);
     command.envs(envs);
 
-    setup_library_path(library_path, &mut command);
+    setup_library_path(bin_path.parent().and_then(|p| p.to_str()), &mut command);
     command.stdout(Stdio::piped());
     command.stderr(Stdio::piped());
     setup_windows_process_flags(&mut command);
@@ -280,10 +279,9 @@ pub async fn unload_llama_model<R: Runtime>(
 #[tauri::command]
 pub async fn get_devices(
     backend_path: &str,
-    library_path: Option<&str>,
     envs: HashMap<String, String>,
 ) -> ServerResult<Vec<DeviceInfo>> {
-    get_devices_from_backend(backend_path, library_path, envs).await
+    get_devices_from_backend(backend_path, envs).await
 }
 
 /// Generate API key using HMAC-SHA256
```

src-tauri/plugins/tauri-plugin-llamacpp/src/device.rs

Lines changed: 4 additions & 5 deletions
```diff
@@ -19,20 +19,19 @@ pub struct DeviceInfo {
 
 pub async fn get_devices_from_backend(
     backend_path: &str,
-    library_path: Option<&str>,
     envs: HashMap<String, String>,
 ) -> ServerResult<Vec<DeviceInfo>> {
     log::info!("Getting devices from server at path: {:?}", backend_path);
 
-    validate_binary_path(backend_path)?;
+    let bin_path = validate_binary_path(backend_path)?;
 
     // Configure the command to run the server with --list-devices
-    let mut command = Command::new(backend_path);
+    let mut command = Command::new(&bin_path);
     command.arg("--list-devices");
     command.envs(envs);
 
     // Set up library path
-    setup_library_path(library_path, &mut command);
+    setup_library_path(bin_path.parent().and_then(|p| p.to_str()), &mut command);
 
     command.stdout(Stdio::piped());
     command.stderr(Stdio::piped());
@@ -410,4 +409,4 @@ AnotherInvalid
     assert_eq!(result[0].id, "Vulkan0");
     assert_eq!(result[1].id, "CUDA0");
     }
-}
+}
```

src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/utils.rs

Lines changed: 0 additions & 1 deletion
```diff
@@ -62,7 +62,6 @@ pub async fn estimate_kv_cache_internal(
     ctx_size: Option<u64>,
 ) -> Result<KVCacheEstimate, KVCacheError> {
     log::info!("Received ctx_size parameter: {:?}", ctx_size);
-    log::info!("Received model metadata:\n{:?}", &meta);
     let arch = meta
         .get("general.architecture")
         .ok_or(KVCacheError::ArchitectureNotFound)?;
```
