Skip to content

Commit 42068ad

Browse files
committed
download with progress
1 parent 58f6774 commit 42068ad

File tree

7 files changed

+198
-69
lines changed

7 files changed

+198
-69
lines changed

extensions/model-extension/package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
],
2727
"dependencies": {
2828
"@janhq/core": "../../core/package.tgz",
29+
"@tauri-apps/api": "^2.5.0",
2930
"ky": "^1.7.2",
3031
"p-queue": "^8.0.1"
3132
},

extensions/model-extension/src/index.ts

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import { invoke } from '@tauri-apps/api/core';
2+
import { listen } from '@tauri-apps/api/event';
13
import {
24
ModelExtension,
35
Model,
@@ -9,6 +11,7 @@ import {
911
ModelSource,
1012
extractInferenceParams,
1113
extractModelLoadParams,
14+
events,
1215
} from '@janhq/core'
1316
import { scanModelsFolder } from './legacy/model-json'
1417
import { deleteModelFiles } from './legacy/delete'
@@ -26,6 +29,14 @@ type Data<T> = {
2629
data: T[]
2730
}
2831

32+
type DownloadInfo = {
33+
task_id: string
34+
total_size: number
35+
downloaded_size: number
36+
download_type: string
37+
event_type: string
38+
}
39+
2940
/**
3041
* A extension for models
3142
*/
@@ -64,6 +75,29 @@ export default class JanModelExtension extends ModelExtension {
6475
this.updateCortexConfig({ huggingface_token: huggingfaceToken })
6576
}
6677

78+
// listen to tauri events
79+
// TODO: move this to core? i.e. forward tauri events to core events
80+
listen<DownloadInfo>('download', (event) => {
81+
let payload = event.payload
82+
let eventName = {
83+
Updated: 'onFileDownloadUpdate',
84+
Error: 'onFileDownloadError',
85+
Success: 'onFileDownloadSuccess',
86+
Stopped: 'onFileDownloadStopped',
87+
Started: 'onFileDownloadStarted',
88+
}[payload.event_type]
89+
90+
events.emit(eventName, {
91+
modelId: payload.task_id,
92+
percent: payload.downloaded_size / payload.total_size,
93+
size: {
94+
transferred: payload.downloaded_size,
95+
total: payload.total_size,
96+
},
97+
downloadType: payload.download_type,
98+
})
99+
})
100+
67101
// Sync with cortexsohub
68102
this.fetchCortexsoModels()
69103
}
@@ -92,6 +126,19 @@ export default class JanModelExtension extends ModelExtension {
92126
* @returns A Promise that resolves when the model is downloaded.
93127
*/
94128
async pullModel(model: string, id?: string, name?: string): Promise<void> {
129+
if (id == null && name == null) {
130+
let [modelName, branch] = model.split(":")
131+
return invoke<void>("get_jan_data_folder_path").then((path) => {
132+
// cortexso format
133+
return invoke<void>("download_hf_repo", {
134+
taskId: model,
135+
repoId: `cortexso/${modelName}`,
136+
branch: branch,
137+
saveDir: `${path}/models/cortex.so/${modelName}/${branch}`
138+
})
139+
}).catch(console.error)
140+
}
141+
95142
/**
96143
* Sending POST to /models/pull/{id} endpoint to pull the model
97144
*/

src-tauri/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ rand = "0.8"
3030
tauri-plugin-http = { version = "2", features = ["unsafe-headers"] }
3131
tauri-plugin-store = "2"
3232
hyper = { version = "0.14", features = ["server"] }
33-
reqwest = { version = "0.11", features = ["json", "blocking"] }
33+
reqwest = { version = "0.11", features = ["json", "blocking", "stream"] }
3434
tokio = { version = "1", features = ["full"] }
3535
rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "main", features = [
3636
"client",
@@ -40,6 +40,7 @@ rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "mai
4040
] }
4141
uuid = { version = "1.7", features = ["v4"] }
4242
env = "1.0.1"
43+
futures-util = "0.3.31"
4344

4445
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
4546
tauri-plugin-updater = "2"

src-tauri/src/core/mod.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,3 @@ pub mod setup;
66
pub mod state;
77
pub mod threads;
88
pub mod utils;
9-
pub mod download;
Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,44 @@
1-
use std::fs::File;
2-
use std::io::Write;
1+
use futures_util::StreamExt;
32
use std::path::Path;
43
use std::time::Duration;
4+
use tokio::fs::File;
5+
use tokio::io::AsyncWriteExt;
56

6-
pub async fn download(url: &str, save_path: &Path) -> Result<(), Box<dyn std::error::Error>> {
7+
// this is to emulate the current way of downloading files by Cortex + Jan
8+
// we can change this later
9+
#[derive(serde::Serialize, Clone, Debug)]
10+
pub enum DownloadEventType {
11+
Started,
12+
Updated,
13+
Success,
14+
Error,
15+
Stopped,
16+
}
17+
18+
#[derive(serde::Serialize, Clone, Debug)]
19+
pub struct DownloadEvent {
20+
pub task_id: String,
21+
pub total_size: u64,
22+
pub downloaded_size: u64,
23+
pub download_type: String, // TODO: make this an enum as well
24+
pub event_type: DownloadEventType,
25+
}
26+
27+
pub async fn download<F>(
28+
url: &str,
29+
save_path: &Path,
30+
mut callback: Option<F>,
31+
) -> Result<(), Box<dyn std::error::Error>>
32+
where
33+
// F: FnMut(u64) + Send + 'static,
34+
F: FnMut(u64),
35+
{
736
let client = reqwest::Client::builder()
837
.http2_keep_alive_timeout(Duration::from_secs(15))
938
.build()?;
1039

11-
let mut resp = client
12-
.get(url)
13-
.header("User-Agent", "rust-reqwest/huggingface-downloader")
14-
.send()
15-
.await?;
16-
17-
// Check if request was successful
40+
// NOTE: might want to add User-Agent header
41+
let resp = client.get(url).send().await?;
1842
if !resp.status().is_success() {
1943
return Err(format!(
2044
"Failed to download: HTTP status {}, {}",
@@ -23,17 +47,26 @@ pub async fn download(url: &str, save_path: &Path) -> Result<(), Box<dyn std::er
2347
)
2448
.into());
2549
}
50+
2651
// Create parent directories if they don't exist
2752
if let Some(parent) = save_path.parent() {
2853
if !parent.exists() {
2954
std::fs::create_dir_all(parent)?;
3055
}
3156
}
32-
let mut file = File::create(save_path)?;
57+
let mut file = File::create(save_path).await?;
3358

34-
while let Some(chunk) = resp.chunk().await? {
35-
file.write_all(&chunk)?;
36-
}
59+
// write chunk to file, and call callback if needed (e.g. download progress)
60+
let mut stream = resp.bytes_stream();
61+
while let Some(chunk) = stream.next().await {
62+
let chunk = chunk?;
63+
file.write_all(&chunk).await?;
3764

65+
// NOTE: might want to reduce frequency of callback e.g. every 1MB
66+
if let Some(cb) = callback.as_mut() {
67+
cb(chunk.len() as u64);
68+
}
69+
}
70+
file.flush().await?;
3871
Ok(())
3972
}

src-tauri/src/core/utils/hf.rs

Lines changed: 98 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,103 @@
1+
use crate::core::utils::download::{download, DownloadEvent, DownloadEventType};
12
use std::path::Path;
3+
use std::sync::{Arc, Mutex};
4+
use tauri::{AppHandle, Emitter};
25

3-
pub async fn download_file(
6+
// task_id is only for compatibility with current Jan+Cortex we can remove it later
7+
// should we disallow custom save_dir? i.e. all HF models will be downloaded to
8+
// <app_dir>/models/<repo_id>/<branch>?
9+
// (only cortexso uses branch. normally people don't use it)
10+
#[tauri::command]
11+
pub async fn download_hf_repo(
12+
app: AppHandle,
13+
task_id: &str,
414
repo_id: &str,
515
branch: &str,
6-
file_path: &str,
7-
save_path: &Path,
8-
) -> Result<(), Box<dyn std::error::Error>> {
9-
let url = format!(
10-
"https://huggingface.co/{}/resolve/{}/{}",
11-
repo_id, branch, file_path
12-
);
13-
super::download::download(&url, save_path).await
16+
save_dir: &Path,
17+
) -> Result<(), String> {
18+
// TODO: check if it has been downloaded
19+
let files = list_files(repo_id, branch)
20+
.await
21+
.map_err(|e| format!("Failed to list files {}", e))?;
22+
23+
// obtain total download size. emit download started event
24+
let info = DownloadEvent {
25+
task_id: task_id.to_string(),
26+
total_size: files.iter().map(|f| f.size).sum(),
27+
downloaded_size: 0,
28+
download_type: "Model".to_string(),
29+
event_type: DownloadEventType::Started,
30+
};
31+
let info_arc = Arc::new(Mutex::new(info));
32+
log::info!("Start download repo_id: {} branch: {}", repo_id, branch);
33+
app.emit("download", info_arc.lock().unwrap().clone())
34+
.unwrap();
35+
36+
let download_result = async {
37+
// NOTE: currently we are downloading sequentially. we can spawn tokio tasks
38+
// to download files in parallel.
39+
for file in files {
40+
let url = format!(
41+
"https://huggingface.co/{}/resolve/{}/{}",
42+
repo_id, branch, file.path
43+
);
44+
let full_path = save_dir.join(&file.path);
45+
46+
// update download progress. clone app handle and info_arc
47+
// to move them into the closure
48+
let callback = {
49+
let app = app.clone();
50+
let info_arc = Arc::clone(&info_arc);
51+
move |size| {
52+
let mut info = info_arc.lock().unwrap();
53+
info.event_type = DownloadEventType::Updated;
54+
info.downloaded_size += size;
55+
app.emit("download", info.clone()).unwrap();
56+
}
57+
};
58+
download(&url, &full_path, Some(callback))
59+
.await
60+
.map_err(|e| format!("Failed to download file {}: {}", file.path, e))?;
61+
}
62+
Ok(())
63+
}
64+
.await;
65+
match download_result {
66+
Ok(_) => {
67+
let mut info = info_arc.lock().unwrap();
68+
info.event_type = DownloadEventType::Success;
69+
log::info!("Finished download repo_id: {} branch: {}", repo_id, branch);
70+
app.emit("download", info.clone()).unwrap();
71+
Ok(())
72+
}
73+
Err(e) => {
74+
// on failure, remove the directory to restore the original state
75+
if save_dir.exists() {
76+
// NOTE: we don't check error here
77+
let _ = std::fs::remove_dir_all(save_dir);
78+
}
79+
log::info!("Failed to download repo_id: {} branch: {}", repo_id, branch);
80+
// TODO: check what cortex and Jan does on download error
81+
// app.emit("download", info.clone()).unwrap();
82+
Err(e)
83+
}
84+
}
1485
}
1586

16-
pub async fn list_files(
87+
#[derive(Debug)]
88+
struct FileInfo {
89+
path: String,
90+
size: u64,
91+
}
92+
93+
async fn list_files(
1794
repo_id: &str,
1895
branch: &str,
19-
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
96+
) -> Result<Vec<FileInfo>, Box<dyn std::error::Error>> {
2097
let mut files = vec![];
2198
let client = reqwest::Client::new();
2299

23-
// do recursion with a stack
100+
// DFS with a stack (similar to recursion)
24101
let mut stack = vec!["".to_string()];
25102
while let Some(directory) = stack.pop() {
26103
let url = format!(
@@ -38,14 +115,19 @@ pub async fn list_files(
38115
.into());
39116
}
40117

118+
// this struct is only used for internal deserialization
41119
#[derive(serde::Deserialize)]
42120
struct Item {
43121
r#type: String,
44122
path: String,
123+
size: u64,
45124
}
46125
for item in resp.json::<Vec<Item>>().await?.into_iter() {
47126
match item.r#type.as_str() {
48-
"file" => files.push(item.path),
127+
"file" => files.push(FileInfo {
128+
path: item.path,
129+
size: item.size,
130+
}),
49131
"directory" => stack.push(item.path),
50132
_ => {}
51133
}
@@ -59,44 +141,8 @@ pub async fn list_files(
59141
mod tests {
60142
use super::*;
61143

62-
#[tokio::test]
63-
async fn test_download_file() {
64-
let repo_id = "openai-community/gpt2";
65-
let branch = "main";
66-
let file_path = "config.json";
67-
let save_path = std::path::PathBuf::from("subdir/test_config.json");
68-
69-
if let Some(parent) = save_path.parent() {
70-
if parent.exists() {
71-
std::fs::remove_dir_all(parent).unwrap();
72-
}
73-
}
74-
75-
let result = download_file(repo_id, branch, file_path, &save_path).await;
76-
assert!(result.is_ok(), "{}", result.unwrap_err());
77-
assert!(save_path.exists());
78-
79-
// Read the file and verify its content
80-
let file_content = std::fs::read_to_string(&save_path).unwrap();
81-
let json_result: Result<serde_json::Value, _> = serde_json::from_str(&file_content);
82-
assert!(json_result.is_ok(), "Downloaded file is not valid JSON");
83-
84-
if let Ok(json) = json_result {
85-
assert!(json.is_object(), "JSON root should be an object");
86-
assert_eq!(
87-
json.get("model_type")
88-
.and_then(|v| v.as_str())
89-
.unwrap_or(""),
90-
"gpt2",
91-
"model_type should be gpt2"
92-
);
93-
}
94-
95-
// Clean up
96-
// NOTE: this will not run if there are errors
97-
// TODO: use tempfile crate instead
98-
std::fs::remove_dir_all(save_path.parent().unwrap()).unwrap();
99-
}
144+
// TODO: add test for download_hf_repo (need to find a small repo)
145+
// TODO: test when repo does not exist
100146

101147
#[tokio::test]
102148
async fn test_list_files() {
@@ -107,7 +153,7 @@ mod tests {
107153
assert!(result.is_ok(), "{}", result.unwrap_err());
108154
let files = result.unwrap();
109155
assert!(
110-
files.iter().any(|f| f == "config.json"),
156+
files.iter().any(|f| f.path == "config.json"),
111157
"config.json should be in the list"
112158
);
113159
}

src-tauri/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ pub fn run() {
5858
core::threads::delete_message,
5959
core::threads::get_thread_assistant,
6060
core::threads::create_thread_assistant,
61-
core::threads::modify_thread_assistant
61+
core::threads::modify_thread_assistant,
62+
// Download
63+
core::utils::hf::download_hf_repo,
6264
])
6365
.manage(AppState {
6466
app_token: Some(generate_app_token()),

0 commit comments

Comments
 (0)