
Commit b2a8efd

refactor: Simplify Tauri plugin calls and update 'FA' setting (#6779)
* refactor: Simplify Tauri plugin calls and enhance 'Flash Attention' setting

  This commit improves the llama.cpp extension, focusing on the 'Flash Attention' setting and refactoring Tauri plugin interactions for better code clarity and maintainability. Backend interaction is streamlined by removing the unnecessary `libraryPath` argument from the Tauri plugin commands for loading models and listing devices.

  * **Simplified API Calls:** The `loadLlamaModel`, `unloadLlamaModel`, and `get_devices` functions in both the extension and the Tauri plugin now resolve the library path internally from the backend executable's location.
  * **Decoupled Logic:** The extension (`src/index.ts`) now uses the new, simplified Tauri plugin functions, which improves modularity and reduces boilerplate in the extension.
  * **Type Consistency:** Added an `UnloadResult` interface to `guest-js/index.ts` for consistency.
  * **Updated UI Control:** The 'Flash Attention' setting in `settings.json` changes from a boolean checkbox to a string-based dropdown offering **'auto'**, **'on'**, and **'off'**.
  * **Improved Logic:** The extension logic in `src/index.ts` now handles the string-based `flash_attn` configuration, passing the value (`'auto'`, `'on'`, or `'off'`) directly as a command-line argument to the llama.cpp backend. The complex version-checking previously required for older llama.cpp releases is removed.

  This refactoring cleans up the extension's codebase and moves environment and path setup concerns into the Tauri plugin, where they are most relevant.

* feat: Simplify backend architecture

  This commit adds a functional flag for embedding models and refactors the backend detection logic for a cleaner implementation. Key changes:

  - Embedding support: The `loadLlamaModel` API and `SessionInfo` now include an `isEmbedding: boolean` flag, allowing the core process to differentiate and correctly initialize models intended for embedding tasks.
  - Backend naming simplification (refactor): Consolidated the CPU-specific backend tags (e.g., `win-noavx-x64`, `win-avx2-x64`) into generic `*-common_cpus-x64` variants (e.g., `win-common_cpus-x64`), streamlining supported-backend detection.
  - File structure update: CUDA runtime libraries (cudart) are now downloaded into the specific backend's directory (`/build/bin/`) rather than a shared `lib` folder, improving asset isolation.

* fix: compare

* fix mmap settings and adjust flash attention

* fix: correct flash_attn and main_gpu flag checks in llamacpp extension

  Previously the condition for `flash_attn` was always truthy, causing unnecessary or incorrect `--flash-attn` arguments to be added, and the `main_gpu` check used a loose inequality that could match unintended values. The updated logic uses strict comparison and correctly handles the empty-string case, so the command-line arguments are generated only when appropriate.
1 parent 7b634f0 commit b2a8efd
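
Before the per-file diffs, here is a minimal sketch of how an extension-side caller loads a model through the plugin after this change, assuming the `loadLlamaModel` signature shown in `guest-js/index.ts` below; the backend path, model path, and arguments are placeholders, not Jan's real defaults:

```ts
import { loadLlamaModel } from '@janhq/tauri-plugin-llamacpp-api'

async function startModel(): Promise<number> {
  // Placeholder path; the extension resolves this via getBackendExePath().
  const backendPath = '/path/to/backends/<version>/build/bin/llama-server'
  const args = ['-m', '/path/to/model.gguf', '--flash-attn', 'auto', '--port', '3100']
  const envs: Record<string, string> = {}

  // No libraryPath argument anymore: the plugin derives the library path
  // from the backend binary's parent directory.
  const sInfo = await loadLlamaModel(backendPath, args, envs, /* isEmbedding */ false)
  return sInfo.pid
}
```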

File tree

7 files changed: +71, -95 lines


extensions/llamacpp-extension/settings.json

Lines changed: 7 additions & 2 deletions
@@ -149,9 +149,14 @@
       "key": "flash_attn",
       "title": "Flash Attention",
       "description": "Enable Flash Attention for optimized performance.",
-      "controllerType": "checkbox",
+      "controllerType": "dropdown",
       "controllerProps": {
-        "value": false
+        "value": "auto",
+        "options": [
+          { "value": "auto", "name": "Auto" },
+          { "value": "on", "name": "ON" },
+          { "value": "off", "name": "OFF" }
+        ]
       }
     },
     {
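
The dropdown value is stored as a plain string and forwarded verbatim to the backend as the `--flash-attn` argument (see the `src/index.ts` hunks below). A minimal sketch of that shape, with the union type assumed from the options above:

```ts
// Assumed narrowing of the setting's value; the config field itself is just `string`.
type FlashAttn = 'auto' | 'on' | 'off'

// The selected value ends up verbatim on the llama.cpp command line,
// e.g. flashAttnArgs('auto') -> ['--flash-attn', 'auto'].
const flashAttnArgs = (value: FlashAttn): string[] => ['--flash-attn', value]
```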

extensions/llamacpp-extension/src/backend.ts

Lines changed: 12 additions & 38 deletions
@@ -102,50 +102,27 @@ export async function listSupportedBackends(): Promise<
   // TODO: fetch versions from the server?
   // TODO: select CUDA version based on driver version
   if (sysType == 'windows-x86_64') {
-    // NOTE: if a machine supports AVX2, should we include noavx and avx?
-    supportedBackends.push('win-noavx-x64')
-    if (features.avx) supportedBackends.push('win-avx-x64')
-    if (features.avx2) supportedBackends.push('win-avx2-x64')
-    if (features.avx512) supportedBackends.push('win-avx512-x64')
+    supportedBackends.push('win-common_cpus-x64')
     if (features.cuda11) {
-      if (features.avx512) supportedBackends.push('win-avx512-cuda-cu11.7-x64')
-      else if (features.avx2) supportedBackends.push('win-avx2-cuda-cu11.7-x64')
-      else if (features.avx) supportedBackends.push('win-avx-cuda-cu11.7-x64')
-      else supportedBackends.push('win-noavx-cuda-cu11.7-x64')
+      supportedBackends.push('win-cuda-11-common_cpus-x64')
     }
     if (features.cuda12) {
-      if (features.avx512) supportedBackends.push('win-avx512-cuda-cu12.0-x64')
-      else if (features.avx2) supportedBackends.push('win-avx2-cuda-cu12.0-x64')
-      else if (features.avx) supportedBackends.push('win-avx-cuda-cu12.0-x64')
-      else supportedBackends.push('win-noavx-cuda-cu12.0-x64')
+      supportedBackends.push('win-cuda-12-common_cpus-x64')
     }
-    if (features.vulkan) supportedBackends.push('win-vulkan-x64')
+    if (features.vulkan) supportedBackends.push('win-vulkan-common_cpus-x64')
   }
   // not available yet, placeholder for future
   else if (sysType === 'windows-aarch64' || sysType === 'windows-arm64') {
     supportedBackends.push('win-arm64')
   } else if (sysType === 'linux-x86_64' || sysType === 'linux-x86') {
-    supportedBackends.push('linux-noavx-x64')
-    if (features.avx) supportedBackends.push('linux-avx-x64')
-    if (features.avx2) supportedBackends.push('linux-avx2-x64')
-    if (features.avx512) supportedBackends.push('linux-avx512-x64')
+    supportedBackends.push('linux-common_cpus-x64')
     if (features.cuda11) {
-      if (features.avx512)
-        supportedBackends.push('linux-avx512-cuda-cu11.7-x64')
-      else if (features.avx2)
-        supportedBackends.push('linux-avx2-cuda-cu11.7-x64')
-      else if (features.avx) supportedBackends.push('linux-avx-cuda-cu11.7-x64')
-      else supportedBackends.push('linux-noavx-cuda-cu11.7-x64')
+      supportedBackends.push('linux-cuda-11-common_cpus-x64')
     }
     if (features.cuda12) {
-      if (features.avx512)
-        supportedBackends.push('linux-avx512-cuda-cu12.0-x64')
-      else if (features.avx2)
-        supportedBackends.push('linux-avx2-cuda-cu12.0-x64')
-      else if (features.avx) supportedBackends.push('linux-avx-cuda-cu12.0-x64')
-      else supportedBackends.push('linux-noavx-cuda-cu12.0-x64')
+      supportedBackends.push('linux-cuda-12-common_cpus-x64')
    }
-    if (features.vulkan) supportedBackends.push('linux-vulkan-x64')
+    if (features.vulkan) supportedBackends.push('linux-vulkan-common_cpus-x64')
   }
   // not available yet, placeholder for future
   else if (sysType === 'linux-aarch64' || sysType === 'linux-arm64') {
@@ -230,10 +207,7 @@ export async function downloadBackend(
   version: string,
   source: 'github' | 'cdn' = 'github'
 ): Promise<void> {
-  const janDataFolderPath = await getJanDataFolderPath()
-  const llamacppPath = await joinPath([janDataFolderPath, 'llamacpp'])
   const backendDir = await getBackendDir(backend, version)
-  const libDir = await joinPath([llamacppPath, 'lib'])

   const downloadManager = window.core.extensionManager.getByName(
     '@janhq/download-extension'
@@ -265,7 +239,7 @@ export async function downloadBackend(
       source === 'github'
         ? `https://github.com/janhq/llama.cpp/releases/download/${version}/cudart-llama-bin-${platformName}-cu11.7-x64.tar.gz`
         : `https://catalog.jan.ai/llama.cpp/releases/${version}/cudart-llama-bin-${platformName}-cu11.7-x64.tar.gz`,
-      save_path: await joinPath([libDir, 'cuda11.tar.gz']),
+      save_path: await joinPath([backendDir, 'build', 'bin', 'cuda11.tar.gz']),
       proxy: proxyConfig,
     })
   } else if (backend.includes('cu12.0') && !(await _isCudaInstalled('12.0'))) {
@@ -274,7 +248,7 @@ export async function downloadBackend(
       source === 'github'
         ? `https://github.com/janhq/llama.cpp/releases/download/${version}/cudart-llama-bin-${platformName}-cu12.0-x64.tar.gz`
         : `https://catalog.jan.ai/llama.cpp/releases/${version}/cudart-llama-bin-${platformName}-cu12.0-x64.tar.gz`,
-      save_path: await joinPath([libDir, 'cuda12.tar.gz']),
+      save_path: await joinPath([backendDir, 'build', 'bin', 'cuda12.tar.gz']),
       proxy: proxyConfig,
     })
   }
@@ -344,8 +318,8 @@ async function _getSupportedFeatures() {
   }

   // https://docs.nvidia.com/deploy/cuda-compatibility/#cuda-11-and-later-defaults-to-minor-version-compatibility
-  let minCuda11DriverVersion
-  let minCuda12DriverVersion
+  let minCuda11DriverVersion: string
+  let minCuda12DriverVersion: string
   if (sysInfo.os_type === 'linux') {
     minCuda11DriverVersion = '450.80.02'
     minCuda12DriverVersion = '525.60.13'
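
For illustration, with the consolidated tags a Windows x64 machine that reports CUDA 12 and Vulkan support now ends up with a much shorter candidate list than before. A sketch of the values pushed into `supportedBackends` for that hypothetical machine (assuming `features.cuda12` and `features.vulkan` are true and CUDA 11 is not detected):

```ts
// Roughly what the first hunk above produces for such a machine:
const supportedBackends: string[] = [
  'win-common_cpus-x64',          // one CPU build replaces the noavx/avx/avx2/avx512 variants
  'win-cuda-12-common_cpus-x64',  // one CUDA 12 build instead of four AVX-specific ones
  'win-vulkan-common_cpus-x64',   // Vulkan build, likewise CPU-generic
]
```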

extensions/llamacpp-extension/src/index.ts

Lines changed: 26 additions & 32 deletions
@@ -38,10 +38,12 @@ import { invoke } from '@tauri-apps/api/core'
 import { getProxyConfig } from './util'
 import { basename } from '@tauri-apps/api/path'
 import {
+  loadLlamaModel,
   readGgufMetadata,
   getModelSize,
   isModelSupported,
   planModelLoadInternal,
+  unloadLlamaModel,
 } from '@janhq/tauri-plugin-llamacpp-api'
 import { getSystemUsage, getSystemInfo } from '@janhq/tauri-plugin-hardware-api'

@@ -69,7 +71,7 @@ type LlamacppConfig = {
   device: string
   split_mode: string
   main_gpu: number
-  flash_attn: boolean
+  flash_attn: string
   cont_batching: boolean
   no_mmap: boolean
   mlock: boolean
@@ -549,9 +551,9 @@ export default class llamacpp_extension extends AIEngine {

     // Helper to map backend string to a priority category
     const getBackendCategory = (backendString: string): string | undefined => {
-      if (backendString.includes('cu12.0')) return 'cuda-cu12.0'
-      if (backendString.includes('cu11.7')) return 'cuda-cu11.7'
-      if (backendString.includes('vulkan')) return 'vulkan'
+      if (backendString.includes('cuda-12-common_cpus')) return 'cuda-cu12.0'
+      if (backendString.includes('cuda-11-common_cpus')) return 'cuda-cu11.7'
+      if (backendString.includes('vulkan-common_cpus')) return 'vulkan'
       if (backendString.includes('avx512')) return 'avx512'
       if (backendString.includes('avx2')) return 'avx2'
       if (
@@ -1644,18 +1646,20 @@ export default class llamacpp_extension extends AIEngine {
     if (cfg.device.length > 0) args.push('--device', cfg.device)
     if (cfg.split_mode.length > 0 && cfg.split_mode != 'layer')
       args.push('--split-mode', cfg.split_mode)
-    if (cfg.main_gpu !== undefined && cfg.main_gpu != 0)
+    if (cfg.main_gpu !== undefined && cfg.main_gpu !== 0)
       args.push('--main-gpu', String(cfg.main_gpu))
+    // Note: Older llama.cpp versions are no longer supported
+    if (
+      cfg.flash_attn !== undefined ||
+      !cfg.flash_attn ||
+      cfg.flash_attn !== ''
+    )
+      args.push('--flash-attn', String(cfg.flash_attn)) //default: auto = ON when supported

     // Boolean flags
     if (cfg.ctx_shift) args.push('--context-shift')
-    if (Number(version.replace(/^b/, '')) >= 6325) {
-      if (!cfg.flash_attn) args.push('--flash-attn', 'off') //default: auto = ON when supported
-    } else {
-      if (cfg.flash_attn) args.push('--flash-attn')
-    }
     if (cfg.cont_batching) args.push('--cont-batching')
-    args.push('--no-mmap')
+    if (cfg.no_mmap) args.push('--no-mmap')
     if (cfg.mlock) args.push('--mlock')
     if (cfg.no_kv_offload) args.push('--no-kv-offload')
     if (isEmbedding) {
@@ -1667,7 +1671,7 @@ export default class llamacpp_extension extends AIEngine {
     if (cfg.cache_type_k && cfg.cache_type_k != 'f16')
       args.push('--cache-type-k', cfg.cache_type_k)
     if (
-      cfg.flash_attn &&
+      cfg.flash_attn !== 'on' &&
       cfg.cache_type_v != 'f16' &&
       cfg.cache_type_v != 'f32'
     ) {
@@ -1688,20 +1692,9 @@ export default class llamacpp_extension extends AIEngine {

     logger.info('Calling Tauri command llama_load with args:', args)
     const backendPath = await getBackendExePath(backend, version)
-    const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])

     try {
-      // TODO: add LIBRARY_PATH
-      const sInfo = await invoke<SessionInfo>(
-        'plugin:llamacpp|load_llama_model',
-        {
-          backendPath,
-          libraryPath,
-          args,
-          envs,
-          isEmbedding,
-        }
-      )
+      const sInfo = await loadLlamaModel(backendPath, args, envs, isEmbedding)
       return sInfo
     } catch (error) {
       logger.error('Error in load command:\n', error)
@@ -1717,12 +1710,7 @@ export default class llamacpp_extension extends AIEngine {
     const pid = sInfo.pid
     try {
       // Pass the PID as the session_id
-      const result = await invoke<UnloadResult>(
-        'plugin:llamacpp|unload_llama_model',
-        {
-          pid: pid,
-        }
-      )
+      const result = await unloadLlamaModel(pid)

       // If successful, remove from active sessions
       if (result.success) {
@@ -2042,7 +2030,10 @@ export default class llamacpp_extension extends AIEngine {
     if (sysInfo?.os_type === 'linux' && Array.isArray(sysInfo.gpus)) {
       const usage = await getSystemUsage()
       if (usage && Array.isArray(usage.gpus)) {
-        const uuidToUsage: Record<string, { total_memory: number; used_memory: number }> = {}
+        const uuidToUsage: Record<
+          string,
+          { total_memory: number; used_memory: number }
+        > = {}
         for (const u of usage.gpus as any[]) {
           if (u && typeof u.uuid === 'string') {
             uuidToUsage[u.uuid] = u
@@ -2082,7 +2073,10 @@ export default class llamacpp_extension extends AIEngine {
           typeof u.used_memory === 'number'
         ) {
           const total = Math.max(0, Math.floor(u.total_memory))
-          const free = Math.max(0, Math.floor(u.total_memory - u.used_memory))
+          const free = Math.max(
+            0,
+            Math.floor(u.total_memory - u.used_memory)
+          )
           return { ...dev, mem: total, free }
         }
       }
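
Condensed, the flag-building changes in this file amount to roughly the sketch below. The `Cfg` type and the plain truthiness check on `flash_attn` are simplifications for illustration, not the exact guard in the hunk above:

```ts
type Cfg = {
  main_gpu?: number
  flash_attn: string // 'auto' | 'on' | 'off'
  no_mmap: boolean
  mlock: boolean
}

function buildArgs(cfg: Cfg): string[] {
  const args: string[] = []
  // Strict comparison: only emit --main-gpu for a non-zero GPU index.
  if (cfg.main_gpu !== undefined && cfg.main_gpu !== 0)
    args.push('--main-gpu', String(cfg.main_gpu))
  // flash_attn is forwarded as-is; 'auto' lets llama.cpp decide per device.
  if (cfg.flash_attn) args.push('--flash-attn', cfg.flash_attn)
  // --no-mmap is now gated on the setting instead of always being pushed.
  if (cfg.no_mmap) args.push('--no-mmap')
  if (cfg.mlock) args.push('--mlock')
  return args
}

// buildArgs({ main_gpu: 0, flash_attn: 'auto', no_mmap: true, mlock: false })
//   -> ['--flash-attn', 'auto', '--no-mmap']
```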

src-tauri/plugins/tauri-plugin-llamacpp/guest-js/index.ts

Lines changed: 18 additions & 11 deletions
@@ -2,11 +2,18 @@ import { invoke } from '@tauri-apps/api/core'

 // Types
 export interface SessionInfo {
-  pid: number
-  port: number
-  model_id: string
-  model_path: string
-  api_key: string
+  pid: number;
+  port: number;
+  model_id: string;
+  model_path: string;
+  is_embedding: boolean
+  api_key: string;
+  mmproj_path?: string;
+}
+
+export interface UnloadResult {
+  success: boolean;
+  error?: string;
 }

 export interface DeviceInfo {
@@ -29,19 +36,19 @@ export async function cleanupLlamaProcesses(): Promise<void> {
 // LlamaCpp server commands
 export async function loadLlamaModel(
   backendPath: string,
-  libraryPath?: string,
-  args: string[] = [],
-  isEmbedding: boolean = false
+  args: string[],
+  envs: Record<string, string>,
+  isEmbedding: boolean
 ): Promise<SessionInfo> {
   return await invoke('plugin:llamacpp|load_llama_model', {
     backendPath,
-    libraryPath,
     args,
-    isEmbedding,
+    envs,
+    isEmbedding
   })
 }

-export async function unloadLlamaModel(pid: number): Promise<void> {
+export async function unloadLlamaModel(pid: number): Promise<UnloadResult> {
   return await invoke('plugin:llamacpp|unload_llama_model', { pid })
 }
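
Since `unloadLlamaModel` now resolves to a typed `UnloadResult` instead of `void`, callers can branch on the outcome. A small usage sketch, assuming `UnloadResult` is exported from the same package and with purely illustrative session bookkeeping:

```ts
import { unloadLlamaModel, UnloadResult } from '@janhq/tauri-plugin-llamacpp-api'

// Hypothetical map of live sessions keyed by server PID.
const activeSessions = new Map<number, { modelId: string }>()

async function stopSession(pid: number): Promise<boolean> {
  const result: UnloadResult = await unloadLlamaModel(pid)
  if (result.success) {
    activeSessions.delete(pid) // mirrors how the extension prunes its session list
  } else {
    console.error(`Failed to unload llama.cpp server ${pid}:`, result.error)
  }
  return result.success
}
```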

src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs

Lines changed: 4 additions & 6 deletions
@@ -41,7 +41,6 @@ pub struct UnloadResult {
 pub async fn load_llama_model<R: Runtime>(
     app_handle: tauri::AppHandle<R>,
     backend_path: &str,
-    library_path: Option<&str>,
     mut args: Vec<String>,
     envs: HashMap<String, String>,
     is_embedding: bool,
@@ -52,7 +51,7 @@
     log::info!("Attempting to launch server at path: {:?}", backend_path);
     log::info!("Using arguments: {:?}", args);

-    validate_binary_path(backend_path)?;
+    let bin_path = validate_binary_path(backend_path)?;

     let port = parse_port_from_args(&args);
     let model_path_pb = validate_model_path(&mut args)?;
@@ -83,11 +82,11 @@
     let model_id = extract_arg_value(&args, "-a");

     // Configure the command to run the server
-    let mut command = Command::new(backend_path);
+    let mut command = Command::new(&bin_path);
     command.args(args);
     command.envs(envs);

-    setup_library_path(library_path, &mut command);
+    setup_library_path(bin_path.parent().and_then(|p| p.to_str()), &mut command);
     command.stdout(Stdio::piped());
     command.stderr(Stdio::piped());
     setup_windows_process_flags(&mut command);
@@ -280,10 +279,9 @@ pub async fn unload_llama_model<R: Runtime>(
 #[tauri::command]
 pub async fn get_devices(
     backend_path: &str,
-    library_path: Option<&str>,
     envs: HashMap<String, String>,
 ) -> ServerResult<Vec<DeviceInfo>> {
-    get_devices_from_backend(backend_path, library_path, envs).await
+    get_devices_from_backend(backend_path, envs).await
 }

 /// Generate API key using HMAC-SHA256

src-tauri/plugins/tauri-plugin-llamacpp/src/device.rs

Lines changed: 4 additions & 5 deletions
@@ -19,20 +19,19 @@ pub struct DeviceInfo {

 pub async fn get_devices_from_backend(
     backend_path: &str,
-    library_path: Option<&str>,
     envs: HashMap<String, String>,
 ) -> ServerResult<Vec<DeviceInfo>> {
     log::info!("Getting devices from server at path: {:?}", backend_path);

-    validate_binary_path(backend_path)?;
+    let bin_path = validate_binary_path(backend_path)?;

     // Configure the command to run the server with --list-devices
-    let mut command = Command::new(backend_path);
+    let mut command = Command::new(&bin_path);
     command.arg("--list-devices");
     command.envs(envs);

     // Set up library path
-    setup_library_path(library_path, &mut command);
+    setup_library_path(bin_path.parent().and_then(|p| p.to_str()), &mut command);

     command.stdout(Stdio::piped());
     command.stderr(Stdio::piped());
@@ -410,4 +409,4 @@ AnotherInvalid
         assert_eq!(result[0].id, "Vulkan0");
         assert_eq!(result[1].id, "CUDA0");
     }
-}
+}

src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/utils.rs

Lines changed: 0 additions & 1 deletion
@@ -62,7 +62,6 @@ pub async fn estimate_kv_cache_internal(
     ctx_size: Option<u64>,
 ) -> Result<KVCacheEstimate, KVCacheError> {
     log::info!("Received ctx_size parameter: {:?}", ctx_size);
-    log::info!("Received model metadata:\n{:?}", &meta);
     let arch = meta
         .get("general.architecture")
         .ok_or(KVCacheError::ArchitectureNotFound)?;
