1+ use crate :: core:: utils:: download:: { download, DownloadEvent , DownloadEventType } ;
12use std:: path:: Path ;
3+ use std:: sync:: { Arc , Mutex } ;
4+ use tauri:: { AppHandle , Emitter } ;
25
3- pub async fn download_file (
6+ // task_id is only for compatibility with current Jan+Cortex we can remove it later
7+ // should we disallow custom save_dir? i.e. all HF models will be downloaded to
8+ // <app_dir>/models/<repo_id>/<branch>?
9+ // (only cortexso uses branch. normally people don't use it)
10+ #[ tauri:: command]
11+ pub async fn download_hf_repo (
12+ app : AppHandle ,
13+ task_id : & str ,
414 repo_id : & str ,
515 branch : & str ,
6- file_path : & str ,
7- save_path : & Path ,
8- ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
9- let url = format ! (
10- "https://huggingface.co/{}/resolve/{}/{}" ,
11- repo_id, branch, file_path
12- ) ;
13- super :: download:: download ( & url, save_path) . await
16+ save_dir : & Path ,
17+ ) -> Result < ( ) , String > {
18+ // TODO: check if it has been downloaded
19+ let files = list_files ( repo_id, branch)
20+ . await
21+ . map_err ( |e| format ! ( "Failed to list files {}" , e) ) ?;
22+
23+ // obtain total download size. emit download started event
24+ let info = DownloadEvent {
25+ task_id : task_id. to_string ( ) ,
26+ total_size : files. iter ( ) . map ( |f| f. size ) . sum ( ) ,
27+ downloaded_size : 0 ,
28+ download_type : "Model" . to_string ( ) ,
29+ event_type : DownloadEventType :: Started ,
30+ } ;
31+ let info_arc = Arc :: new ( Mutex :: new ( info) ) ;
32+ log:: info!( "Start download repo_id: {} branch: {}" , repo_id, branch) ;
33+ app. emit ( "download" , info_arc. lock ( ) . unwrap ( ) . clone ( ) )
34+ . unwrap ( ) ;
35+
36+ let download_result = async {
37+ // NOTE: currently we are downloading sequentially. we can spawn tokio tasks
38+ // to download files in parallel.
39+ for file in files {
40+ let url = format ! (
41+ "https://huggingface.co/{}/resolve/{}/{}" ,
42+ repo_id, branch, file. path
43+ ) ;
44+ let full_path = save_dir. join ( & file. path ) ;
45+
46+ // update download progress. clone app handle and info_arc
47+ // to move them into the closure
48+ let callback = {
49+ let app = app. clone ( ) ;
50+ let info_arc = Arc :: clone ( & info_arc) ;
51+ move |size| {
52+ let mut info = info_arc. lock ( ) . unwrap ( ) ;
53+ info. event_type = DownloadEventType :: Updated ;
54+ info. downloaded_size += size;
55+ app. emit ( "download" , info. clone ( ) ) . unwrap ( ) ;
56+ }
57+ } ;
58+ download ( & url, & full_path, Some ( callback) )
59+ . await
60+ . map_err ( |e| format ! ( "Failed to download file {}: {}" , file. path, e) ) ?;
61+ }
62+ Ok ( ( ) )
63+ }
64+ . await ;
65+ match download_result {
66+ Ok ( _) => {
67+ let mut info = info_arc. lock ( ) . unwrap ( ) ;
68+ info. event_type = DownloadEventType :: Success ;
69+ log:: info!( "Finished download repo_id: {} branch: {}" , repo_id, branch) ;
70+ app. emit ( "download" , info. clone ( ) ) . unwrap ( ) ;
71+ Ok ( ( ) )
72+ }
73+ Err ( e) => {
74+ // on failure, remove the directory to restore the original state
75+ if save_dir. exists ( ) {
76+ // NOTE: we don't check error here
77+ let _ = std:: fs:: remove_dir_all ( save_dir) ;
78+ }
79+ log:: info!( "Failed to download repo_id: {} branch: {}" , repo_id, branch) ;
80+ // TODO: check what cortex and Jan does on download error
81+ // app.emit("download", info.clone()).unwrap();
82+ Err ( e)
83+ }
84+ }
1485}
1586
16- pub async fn list_files (
87+ #[ derive( Debug ) ]
88+ struct FileInfo {
89+ path : String ,
90+ size : u64 ,
91+ }
92+
93+ async fn list_files (
1794 repo_id : & str ,
1895 branch : & str ,
19- ) -> Result < Vec < String > , Box < dyn std:: error:: Error > > {
96+ ) -> Result < Vec < FileInfo > , Box < dyn std:: error:: Error > > {
2097 let mut files = vec ! [ ] ;
2198 let client = reqwest:: Client :: new ( ) ;
2299
23- // do recursion with a stack
100+ // DFS with a stack (similar to recursion)
24101 let mut stack = vec ! [ "" . to_string( ) ] ;
25102 while let Some ( directory) = stack. pop ( ) {
26103 let url = format ! (
@@ -38,14 +115,19 @@ pub async fn list_files(
38115 . into ( ) ) ;
39116 }
40117
118+ // this struct is only used for internal deserialization
41119 #[ derive( serde:: Deserialize ) ]
42120 struct Item {
43121 r#type : String ,
44122 path : String ,
123+ size : u64 ,
45124 }
46125 for item in resp. json :: < Vec < Item > > ( ) . await ?. into_iter ( ) {
47126 match item. r#type . as_str ( ) {
48- "file" => files. push ( item. path ) ,
127+ "file" => files. push ( FileInfo {
128+ path : item. path ,
129+ size : item. size ,
130+ } ) ,
49131 "directory" => stack. push ( item. path ) ,
50132 _ => { }
51133 }
@@ -59,44 +141,8 @@ pub async fn list_files(
59141mod tests {
60142 use super :: * ;
61143
62- #[ tokio:: test]
63- async fn test_download_file ( ) {
64- let repo_id = "openai-community/gpt2" ;
65- let branch = "main" ;
66- let file_path = "config.json" ;
67- let save_path = std:: path:: PathBuf :: from ( "subdir/test_config.json" ) ;
68-
69- if let Some ( parent) = save_path. parent ( ) {
70- if parent. exists ( ) {
71- std:: fs:: remove_dir_all ( parent) . unwrap ( ) ;
72- }
73- }
74-
75- let result = download_file ( repo_id, branch, file_path, & save_path) . await ;
76- assert ! ( result. is_ok( ) , "{}" , result. unwrap_err( ) ) ;
77- assert ! ( save_path. exists( ) ) ;
78-
79- // Read the file and verify its content
80- let file_content = std:: fs:: read_to_string ( & save_path) . unwrap ( ) ;
81- let json_result: Result < serde_json:: Value , _ > = serde_json:: from_str ( & file_content) ;
82- assert ! ( json_result. is_ok( ) , "Downloaded file is not valid JSON" ) ;
83-
84- if let Ok ( json) = json_result {
85- assert ! ( json. is_object( ) , "JSON root should be an object" ) ;
86- assert_eq ! (
87- json. get( "model_type" )
88- . and_then( |v| v. as_str( ) )
89- . unwrap_or( "" ) ,
90- "gpt2" ,
91- "model_type should be gpt2"
92- ) ;
93- }
94-
95- // Clean up
96- // NOTE: this will not run if there are errors
97- // TODO: use tempfile crate instead
98- std:: fs:: remove_dir_all ( save_path. parent ( ) . unwrap ( ) ) . unwrap ( ) ;
99- }
144+ // TODO: add test for download_hf_repo (need to find a small repo)
145+ // TODO: test when repo does not exist
100146
101147 #[ tokio:: test]
102148 async fn test_list_files ( ) {
@@ -107,7 +153,7 @@ mod tests {
107153 assert ! ( result. is_ok( ) , "{}" , result. unwrap_err( ) ) ;
108154 let files = result. unwrap ( ) ;
109155 assert ! (
110- files. iter( ) . any( |f| f == "config.json" ) ,
156+ files. iter( ) . any( |f| f. path == "config.json" ) ,
111157 "config.json should be in the list"
112158 ) ;
113159 }
0 commit comments