Skip to content

Commit 901f1ff

Browse files
committed
add cancel download
1 parent 42068ad commit 901f1ff

File tree

6 files changed

+81
-4
lines changed

6 files changed

+81
-4
lines changed

extensions/model-extension/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ export default class JanModelExtension extends ModelExtension {
160160
* @returns {Promise<void>} A promise that resolves when the download has been cancelled.
161161
*/
162162
async cancelModelPull(model: string): Promise<void> {
163+
return invoke<void>("cancel_download_task", {taskId: model}).catch(console.error)
163164
/**
164165
* Sending DELETE to /models/pull/{id} endpoint to cancel a model pull
165166
*/

src-tauri/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "mai
4141
uuid = { version = "1.7", features = ["v4"] }
4242
env = "1.0.1"
4343
futures-util = "0.3.31"
44+
tokio-util = "0.7.14"
4445

4546
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
4647
tauri-plugin-updater = "2"

src-tauri/src/core/state.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::{collections::HashMap, sync::Arc};
22

3+
use crate::core::utils::download::DownloadManagerState;
34
use rand::{distributions::Alphanumeric, Rng};
45
use rmcp::{service::RunningService, RoleClient};
56
use tokio::sync::Mutex;
@@ -8,6 +9,7 @@ use tokio::sync::Mutex;
89
pub struct AppState {
910
pub app_token: Option<String>,
1011
pub mcp_servers: Arc<Mutex<HashMap<String, RunningService<RoleClient, ()>>>>,
12+
pub download_manager: Arc<Mutex<DownloadManagerState>>,
1113
}
1214
pub fn generate_app_token() -> String {
1315
rand::thread_rng()

src-tauri/src/core/utils/download.rs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,17 @@
1+
use crate::core::state::AppState;
12
use futures_util::StreamExt;
3+
use std::collections::HashMap;
24
use std::path::Path;
35
use std::time::Duration;
6+
use tauri::State;
47
use tokio::fs::File;
58
use tokio::io::AsyncWriteExt;
9+
use tokio_util::sync::CancellationToken;
10+
11+
#[derive(Default)]
12+
pub struct DownloadManagerState {
13+
pub cancel_tokens: HashMap<String, CancellationToken>,
14+
}
615

716
// this is to emulate the current way of downloading files by Cortex + Jan
817
// we can change this later
@@ -27,10 +36,10 @@ pub struct DownloadEvent {
2736
pub async fn download<F>(
2837
url: &str,
2938
save_path: &Path,
39+
cancel_token: Option<CancellationToken>,
3040
mut callback: Option<F>,
3141
) -> Result<(), Box<dyn std::error::Error>>
3242
where
33-
// F: FnMut(u64) + Send + 'static,
3443
F: FnMut(u64),
3544
{
3645
let client = reqwest::Client::builder()
@@ -58,7 +67,16 @@ where
5867

5968
// write chunk to file, and call callback if needed (e.g. download progress)
6069
let mut stream = resp.bytes_stream();
70+
let mut is_cancelled = false;
6171
while let Some(chunk) = stream.next().await {
72+
if let Some(token) = cancel_token.as_ref() {
73+
if token.is_cancelled() {
74+
log::info!("Download cancelled: {}", url);
75+
is_cancelled = true;
76+
break;
77+
}
78+
}
79+
6280
let chunk = chunk?;
6381
file.write_all(&chunk).await?;
6482

@@ -67,6 +85,23 @@ where
6785
cb(chunk.len() as u64);
6886
}
6987
}
88+
89+
// cleanup
7090
file.flush().await?;
91+
if is_cancelled {
92+
// NOTE: we don't check error here
93+
let _ = std::fs::remove_file(save_path);
94+
}
7195
Ok(())
7296
}
97+
98+
#[tauri::command]
99+
pub async fn cancel_download_task(state: State<'_, AppState>, task_id: &str) -> Result<(), String> {
100+
let mut download_manager = state.download_manager.lock().await;
101+
if let Some(token) = download_manager.cancel_tokens.remove(task_id) {
102+
token.cancel();
103+
Ok(())
104+
} else {
105+
Err(format!("No download task with id {}", task_id))
106+
}
107+
}

src-tauri/src/core/utils/hf.rs

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
use crate::core::state::AppState;
12
use crate::core::utils::download::{download, DownloadEvent, DownloadEventType};
23
use std::path::Path;
34
use std::sync::{Arc, Mutex};
4-
use tauri::{AppHandle, Emitter};
5+
use tauri::{AppHandle, Emitter, State};
6+
use tokio_util::sync::CancellationToken;
57

68
// task_id is only for compatibility with current Jan+Cortex we can remove it later
79
// should we disallow custom save_dir? i.e. all HF models will be downloaded to
@@ -10,12 +12,22 @@ use tauri::{AppHandle, Emitter};
1012
#[tauri::command]
1113
pub async fn download_hf_repo(
1214
app: AppHandle,
15+
state: State<'_, AppState>,
1316
task_id: &str,
1417
repo_id: &str,
1518
branch: &str,
1619
save_dir: &Path,
1720
) -> Result<(), String> {
18-
// TODO: check if it has been downloaded
21+
// TODO: check if it has been/is being downloaded
22+
23+
// check if task_id already exists
24+
{
25+
let download_manager = state.download_manager.lock().await;
26+
if download_manager.cancel_tokens.contains_key(task_id) {
27+
return Err(format!("Task ID {} already exists", task_id));
28+
}
29+
}
30+
1931
let files = list_files(repo_id, branch)
2032
.await
2133
.map_err(|e| format!("Failed to list files {}", e))?;
@@ -33,6 +45,15 @@ pub async fn download_hf_repo(
3345
app.emit("download", info_arc.lock().unwrap().clone())
3446
.unwrap();
3547

48+
// insert cancel tokens
49+
let cancel_token = CancellationToken::new();
50+
{
51+
let mut download_manager = state.download_manager.lock().await;
52+
download_manager
53+
.cancel_tokens
54+
.insert(task_id.to_string(), cancel_token.clone());
55+
}
56+
3657
let download_result = async {
3758
// NOTE: currently we are downloading sequentially. we can spawn tokio tasks
3859
// to download files in parallel.
@@ -55,13 +76,27 @@ pub async fn download_hf_repo(
5576
app.emit("download", info.clone()).unwrap();
5677
}
5778
};
58-
download(&url, &full_path, Some(callback))
79+
download(&url, &full_path, Some(cancel_token.clone()), Some(callback))
5980
.await
6081
.map_err(|e| format!("Failed to download file {}: {}", file.path, e))?;
6182
}
6283
Ok(())
6384
}
6485
.await;
86+
87+
// cleanup
88+
{
89+
let mut download_manager = state.download_manager.lock().await;
90+
download_manager.cancel_tokens.remove(task_id);
91+
92+
if cancel_token.is_cancelled() && save_dir.exists() {
93+
// NOTE: we don't check error here
94+
let _ = std::fs::remove_dir_all(save_dir);
95+
}
96+
}
97+
98+
// report results
99+
// TODO: what if it is cancelled? do we still emit success?
65100
match download_result {
66101
Ok(_) => {
67102
let mut info = info_arc.lock().unwrap();

src-tauri/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use core::{
33
cmd::get_jan_data_folder_path,
44
setup::{self, setup_engine_binaries, setup_mcp, setup_sidecar},
55
state::{generate_app_token, AppState},
6+
utils::download::DownloadManagerState,
67
};
78
use std::{collections::HashMap, sync::Arc};
89

@@ -61,10 +62,12 @@ pub fn run() {
6162
core::threads::modify_thread_assistant,
6263
// Download
6364
core::utils::hf::download_hf_repo,
65+
core::utils::download::cancel_download_task,
6466
])
6567
.manage(AppState {
6668
app_token: Some(generate_app_token()),
6769
mcp_servers: Arc::new(Mutex::new(HashMap::new())),
70+
download_manager: Arc::new(Mutex::new(DownloadManagerState::default())),
6871
})
6972
.setup(|app| {
7073
app.handle().plugin(

0 commit comments

Comments
 (0)