
Commit f1877f8

refactor: Simplify Tauri plugin calls and enhance 'Flash Attention' setting
This commit improves the llama.cpp extension, focusing on the 'Flash Attention' setting and refactoring Tauri plugin interactions for better code clarity and maintainability. Backend interaction is streamlined by removing the unnecessary `libraryPath` argument from the Tauri plugin commands for loading models and listing devices.

* **Simplified API calls:** The `loadLlamaModel`, `unloadLlamaModel`, and `get_devices` functions in both the extension and the Tauri plugin now manage the library path internally, based on the backend executable's location.
* **Decoupled logic:** The extension (`src/index.ts`) now uses the new, simplified Tauri plugin functions, which improves modularity and reduces boilerplate in the extension.
* **Type consistency:** Added an `UnloadResult` interface to `guest-js/index.ts` for consistency.
* **Updated UI control:** The 'Flash Attention' setting in `settings.json` changes from a boolean checkbox to a string-based dropdown offering **'auto'**, **'on'**, and **'off'**.
* **Improved logic:** The extension logic in `src/index.ts` now handles the string-based `flash_attn` configuration, passing the value (`'auto'`, `'on'`, or `'off'`) directly as a command-line argument to the llama.cpp backend. The complex version-checking logic previously required for older llama.cpp versions is removed.

This refactoring cleans up the extension's codebase and moves environment and path setup concerns into the Tauri plugin, where they are most relevant.
1 parent: 7b5060c
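For context, here is a minimal sketch of how a caller would use the simplified plugin API after this change. The import path and function signatures come from the `guest-js/index.ts` diff below; the wrapper function, flag values, and logging are illustrative and not part of the commit.

```typescript
import {
  loadLlamaModel,
  unloadLlamaModel,
} from '@janhq/tauri-plugin-llamacpp-api'

// Illustrative wrapper; names and argument values are placeholders.
async function runOnce(backendPath: string, modelPath: string): Promise<void> {
  // No libraryPath argument anymore: the plugin derives the library
  // path from the backend executable's parent directory.
  const args = ['-m', modelPath, '--flash-attn', 'auto']
  const envs: Record<string, string> = {}

  // Resolves to a SessionInfo carrying the server's pid, port, and API key.
  const session = await loadLlamaModel(backendPath, args, envs)
  console.log(`llama.cpp server started on port ${session.port} (pid ${session.pid})`)

  // unloadLlamaModel now resolves to an UnloadResult instead of void.
  const result = await unloadLlamaModel(session.pid)
  if (!result.success) {
    console.error('Unload failed:', result.error)
  }
}
```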

6 files changed (+49, -54 lines)


extensions/llamacpp-extension/settings.json

Lines changed: 7 additions & 2 deletions
@@ -149,9 +149,14 @@
       "key": "flash_attn",
       "title": "Flash Attention",
       "description": "Enable Flash Attention for optimized performance.",
-      "controllerType": "checkbox",
+      "controllerType": "dropdown",
       "controllerProps": {
-        "value": false
+        "value": "auto",
+        "options": [
+          { "value": "auto", "name": "Auto" },
+          { "value": "on", "name": "ON" },
+          { "value": "off", "name": "OFF" }
+        ]
       }
     },
     {
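A condensed sketch of how the extension maps this dropdown value onto the server command line; the real wiring is inline in the `src/index.ts` diff below, and the helper here is illustrative (its guard is slightly tightened compared to the committed code).

```typescript
// flash_attn is now a string setting: 'auto' | 'on' | 'off' ('auto' by default).
function pushFlashAttnFlag(args: string[], flashAttn?: string): void {
  // Pass the value straight through; llama.cpp treats 'auto' as
  // "on when supported", so no backend version check is needed.
  if (flashAttn !== undefined && flashAttn !== '') {
    args.push('--flash-attn', flashAttn)
  }
}

// e.g. pushFlashAttnFlag(args, cfg.flash_attn) → args: [..., '--flash-attn', 'auto']
```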

extensions/llamacpp-extension/src/index.ts

Lines changed: 19 additions & 31 deletions
@@ -38,10 +38,12 @@ import { invoke } from '@tauri-apps/api/core'
 import { getProxyConfig } from './util'
 import { basename } from '@tauri-apps/api/path'
 import {
+  loadLlamaModel,
   readGgufMetadata,
   getModelSize,
   isModelSupported,
   planModelLoadInternal,
+  unloadLlamaModel,
 } from '@janhq/tauri-plugin-llamacpp-api'
 import { getSystemUsage, getSystemInfo } from '@janhq/tauri-plugin-hardware-api'
 
@@ -69,7 +71,7 @@ type LlamacppConfig = {
   device: string
   split_mode: string
   main_gpu: number
-  flash_attn: boolean
+  flash_attn: string
   cont_batching: boolean
   no_mmap: boolean
   mlock: boolean
@@ -333,14 +335,12 @@ export default class llamacpp_extension extends AIEngine {
         )
         // Clear the invalid stored preference
         this.clearStoredBackendType()
-        bestAvailableBackendString = await this.determineBestBackend(
-          version_backends
-        )
+        bestAvailableBackendString =
+          await this.determineBestBackend(version_backends)
       }
     } else {
-      bestAvailableBackendString = await this.determineBestBackend(
-        version_backends
-      )
+      bestAvailableBackendString =
+        await this.determineBestBackend(version_backends)
     }
 
     let settings = structuredClone(SETTINGS)
@@ -1624,14 +1624,11 @@ export default class llamacpp_extension extends AIEngine {
       args.push('--split-mode', cfg.split_mode)
     if (cfg.main_gpu !== undefined && cfg.main_gpu != 0)
      args.push('--main-gpu', String(cfg.main_gpu))
+    // Note: Older llama.cpp versions are no longer supported
+    if (cfg.flash_attn !== undefined || cfg.flash_attn === '') args.push('--flash-attn', String(cfg.flash_attn)) //default: auto = ON when supported
 
     // Boolean flags
     if (cfg.ctx_shift) args.push('--context-shift')
-    if (Number(version.replace(/^b/, '')) >= 6325) {
-      if (!cfg.flash_attn) args.push('--flash-attn', 'off') //default: auto = ON when supported
-    } else {
-      if (cfg.flash_attn) args.push('--flash-attn')
-    }
     if (cfg.cont_batching) args.push('--cont-batching')
     args.push('--no-mmap')
     if (cfg.mlock) args.push('--mlock')
@@ -1666,19 +1663,9 @@ export default class llamacpp_extension extends AIEngine {
 
     logger.info('Calling Tauri command llama_load with args:', args)
     const backendPath = await getBackendExePath(backend, version)
-    const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])
 
     try {
-      // TODO: add LIBRARY_PATH
-      const sInfo = await invoke<SessionInfo>(
-        'plugin:llamacpp|load_llama_model',
-        {
-          backendPath,
-          libraryPath,
-          args,
-          envs,
-        }
-      )
+      const sInfo = await loadLlamaModel(backendPath, args, envs)
       return sInfo
     } catch (error) {
       logger.error('Error in load command:\n', error)
@@ -1694,12 +1681,7 @@ export default class llamacpp_extension extends AIEngine {
     const pid = sInfo.pid
     try {
       // Pass the PID as the session_id
-      const result = await invoke<UnloadResult>(
-        'plugin:llamacpp|unload_llama_model',
-        {
-          pid: pid,
-        }
-      )
+      const result = await unloadLlamaModel(pid)
 
       // If successful, remove from active sessions
       if (result.success) {
@@ -2019,7 +2001,10 @@ export default class llamacpp_extension extends AIEngine {
     if (sysInfo?.os_type === 'linux' && Array.isArray(sysInfo.gpus)) {
       const usage = await getSystemUsage()
       if (usage && Array.isArray(usage.gpus)) {
-        const uuidToUsage: Record<string, { total_memory: number; used_memory: number }> = {}
+        const uuidToUsage: Record<
+          string,
+          { total_memory: number; used_memory: number }
+        > = {}
         for (const u of usage.gpus as any[]) {
           if (u && typeof u.uuid === 'string') {
             uuidToUsage[u.uuid] = u
@@ -2059,7 +2044,10 @@ export default class llamacpp_extension extends AIEngine {
             typeof u.used_memory === 'number'
           ) {
             const total = Math.max(0, Math.floor(u.total_memory))
-            const free = Math.max(0, Math.floor(u.total_memory - u.used_memory))
+            const free = Math.max(
+              0,
+              Math.floor(u.total_memory - u.used_memory)
+            )
             return { ...dev, mem: total, free }
           }
         }

src-tauri/plugins/tauri-plugin-llamacpp/guest-js/index.ts

Lines changed: 15 additions & 9 deletions
@@ -2,11 +2,17 @@ import { invoke } from '@tauri-apps/api/core'
 
 // Types
 export interface SessionInfo {
-  pid: number
-  port: number
-  model_id: string
-  model_path: string
-  api_key: string
+  pid: number;
+  port: number;
+  model_id: string;
+  model_path: string;
+  api_key: string;
+  mmproj_path?: string;
+}
+
+export interface UnloadResult {
+  success: boolean;
+  error?: string;
 }
 
 export interface DeviceInfo {
@@ -29,17 +35,17 @@ export async function cleanupLlamaProcesses(): Promise<void> {
 // LlamaCpp server commands
 export async function loadLlamaModel(
   backendPath: string,
-  libraryPath?: string,
-  args: string[] = []
+  args: string[],
+  envs: Record<string, string>
 ): Promise<SessionInfo> {
   return await invoke('plugin:llamacpp|load_llama_model', {
     backendPath,
-    libraryPath,
     args,
+    envs
   })
 }
 
-export async function unloadLlamaModel(pid: number): Promise<void> {
+export async function unloadLlamaModel(pid: number): Promise<UnloadResult> {
   return await invoke('plugin:llamacpp|unload_llama_model', { pid })
 }

src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs

Lines changed: 4 additions & 6 deletions
@@ -41,7 +41,6 @@ pub struct UnloadResult {
 pub async fn load_llama_model<R: Runtime>(
     app_handle: tauri::AppHandle<R>,
     backend_path: &str,
-    library_path: Option<&str>,
     mut args: Vec<String>,
     envs: HashMap<String, String>,
 ) -> ServerResult<SessionInfo> {
@@ -51,7 +50,7 @@ pub async fn load_llama_model<R: Runtime>(
     log::info!("Attempting to launch server at path: {:?}", backend_path);
     log::info!("Using arguments: {:?}", args);
 
-    validate_binary_path(backend_path)?;
+    let bin_path = validate_binary_path(backend_path)?;
 
     let port = parse_port_from_args(&args);
     let model_path_pb = validate_model_path(&mut args)?;
@@ -82,11 +81,11 @@ pub async fn load_llama_model<R: Runtime>(
     let model_id = extract_arg_value(&args, "-a");
 
     // Configure the command to run the server
-    let mut command = Command::new(backend_path);
+    let mut command = Command::new(&bin_path);
     command.args(args);
     command.envs(envs);
 
-    setup_library_path(library_path, &mut command);
+    setup_library_path(bin_path.parent().and_then(|p| p.to_str()), &mut command);
     command.stdout(Stdio::piped());
     command.stderr(Stdio::piped());
     setup_windows_process_flags(&mut command);
@@ -278,10 +277,9 @@ pub async fn unload_llama_model<R: Runtime>(
 #[tauri::command]
 pub async fn get_devices(
     backend_path: &str,
-    library_path: Option<&str>,
     envs: HashMap<String, String>,
 ) -> ServerResult<Vec<DeviceInfo>> {
-    get_devices_from_backend(backend_path, library_path, envs).await
+    get_devices_from_backend(backend_path, envs).await
 }
 
 /// Generate API key using HMAC-SHA256

src-tauri/plugins/tauri-plugin-llamacpp/src/device.rs

Lines changed: 4 additions & 5 deletions
@@ -19,20 +19,19 @@ pub struct DeviceInfo {
 
 pub async fn get_devices_from_backend(
     backend_path: &str,
-    library_path: Option<&str>,
     envs: HashMap<String, String>,
 ) -> ServerResult<Vec<DeviceInfo>> {
     log::info!("Getting devices from server at path: {:?}", backend_path);
 
-    validate_binary_path(backend_path)?;
+    let bin_path = validate_binary_path(backend_path)?;
 
     // Configure the command to run the server with --list-devices
-    let mut command = Command::new(backend_path);
+    let mut command = Command::new(&bin_path);
     command.arg("--list-devices");
     command.envs(envs);
 
     // Set up library path
-    setup_library_path(library_path, &mut command);
+    setup_library_path(bin_path.parent().and_then(|p| p.to_str()), &mut command);
 
     command.stdout(Stdio::piped());
     command.stderr(Stdio::piped());
@@ -410,4 +409,4 @@ AnotherInvalid
     assert_eq!(result[0].id, "Vulkan0");
     assert_eq!(result[1].id, "CUDA0");
 }
-}
+}

src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/utils.rs

Lines changed: 0 additions & 1 deletion
@@ -62,7 +62,6 @@ pub async fn estimate_kv_cache_internal(
     ctx_size: Option<u64>,
 ) -> Result<KVCacheEstimate, KVCacheError> {
     log::info!("Received ctx_size parameter: {:?}", ctx_size);
-    log::info!("Received model metadata:\n{:?}", &meta);
     let arch = meta
         .get("general.architecture")
         .ok_or(KVCacheError::ArchitectureNotFound)?;
