@@ -6,9 +6,13 @@ use std::ffi::c_void;
66mod ffi {
77 use super :: * ;
/// Opaque handle type for the context object that `m40llm_create_context`
/// returns a pointer to.
///
/// Rust code only ever holds this type behind a raw pointer; it has no
/// readable fields.
#[repr(C)]
pub struct M40llmCudaContext {
    // Zero-sized private field: keeps the type from being constructed outside this module.
    _private: [u8; 0],
    // A raw-pointer / `PhantomPinned` marker opts the type out of the auto
    // traits `Send`, `Sync` and `Unpin`, which the compiler would otherwise
    // derive for a struct made only of `[u8; 0]`. Size and alignment are unchanged.
    _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
}
/// Opaque handle type for the KV-cache object; `CudaContext::create_kvcache`
/// hands out a `*mut` pointer to it.
///
/// Rust code only ever holds this type behind a raw pointer; it has no
/// readable fields.
#[repr(C)]
pub struct M40llmKVCache {
    // Zero-sized private field: keeps the type from being constructed outside this module.
    _private: [u8; 0],
    // A raw-pointer / `PhantomPinned` marker opts the type out of the auto
    // traits `Send`, `Sync` and `Unpin`, which the compiler would otherwise
    // derive for a struct made only of `[u8; 0]`. Size and alignment are unchanged.
    _marker: core::marker::PhantomData<(*mut u8, core::marker::PhantomPinned)>,
}
1216
1317 extern "C" {
1418 pub fn m40llm_create_context ( device_id : i32 ) -> * mut M40llmCudaContext ;
@@ -64,7 +68,10 @@ impl CudaContext {
6468 if ptr. is_null ( ) {
6569 return Err ( anyhow ! ( "m40llm_create_context returned null" ) ) ;
6670 }
67- Ok ( Self { device_id, raw : ptr } )
71+ Ok ( Self {
72+ device_id,
73+ raw : ptr,
74+ } )
6875 }
6976 #[ cfg( not( feature = "cuda" ) ) ]
7077 {
@@ -73,18 +80,37 @@ impl CudaContext {
7380 }
7481
/// Asks the native library for a KV cache with the given dimensions.
///
/// Forwards `self.raw` and the four size parameters to
/// `ffi::m40llm_kvcache_create` and returns the resulting raw pointer.
/// This function never frees the pointer it returns.
///
/// # Errors
///
/// Returns an error if the native call yields a null pointer.
#[cfg(feature = "cuda")]
pub fn create_kvcache(
    &self,
    max_seq_len: u32,
    max_batch_size: u32,
    num_heads: u32,
    head_dim: u32,
) -> Result<*mut ffi::M40llmKVCache> {
    // SAFETY: NOTE(review): assumes `self.raw` is still the live context
    // pointer obtained at construction time — confirm no other path frees it.
    let handle = unsafe {
        ffi::m40llm_kvcache_create(self.raw, max_seq_len, max_batch_size, num_heads, head_dim)
    };
    if handle.is_null() {
        Err(anyhow!("m40llm_kvcache_create returned null"))
    } else {
        Ok(handle)
    }
}
8198
8299 pub fn upload_weights ( & self , data : & [ u8 ] ) -> Result < * mut c_void > {
83100 #[ cfg( feature = "cuda" ) ]
84101 {
85102 let mut d_ptr: * mut c_void = std:: ptr:: null_mut ( ) ;
86- let rc = unsafe { ffi:: m40llm_upload_weights ( self . raw , data. as_ptr ( ) as * const _ , data. len ( ) , & mut d_ptr as * mut _ ) } ;
87- if rc != 0 { return Err ( anyhow ! ( "m40llm_upload_weights failed: {rc}" ) ) ; }
103+ let rc = unsafe {
104+ ffi:: m40llm_upload_weights (
105+ self . raw ,
106+ data. as_ptr ( ) as * const _ ,
107+ data. len ( ) ,
108+ & mut d_ptr as * mut _ ,
109+ )
110+ } ;
111+ if rc != 0 {
112+ return Err ( anyhow ! ( "m40llm_upload_weights failed: {rc}" ) ) ;
113+ }
88114 Ok ( d_ptr)
89115 }
90116 #[ cfg( not( feature = "cuda" ) ) ]
@@ -104,8 +130,12 @@ impl CudaContext {
104130 ) -> Result < ( ) > {
105131 #[ cfg( feature = "cuda" ) ]
106132 {
107- let rc = unsafe { ffi:: m40llm_gemm_f16_storage_f32_compute ( self . raw , d_a, d_b, d_c, m, n, k) } ;
108- if rc != 0 { return Err ( anyhow ! ( "m40llm_gemm_f16_storage_f32_compute failed: {rc}" ) ) ; }
133+ let rc = unsafe {
134+ ffi:: m40llm_gemm_f16_storage_f32_compute ( self . raw , d_a, d_b, d_c, m, n, k)
135+ } ;
136+ if rc != 0 {
137+ return Err ( anyhow ! ( "m40llm_gemm_f16_storage_f32_compute failed: {rc}" ) ) ;
138+ }
109139 Ok ( ( ) )
110140 }
111141 #[ cfg( not( feature = "cuda" ) ) ]
@@ -119,15 +149,19 @@ impl CudaContext {
119149 #[ cfg( feature = "cuda" ) ]
120150 {
121151 let rc = unsafe { ffi:: m40llm_start_persistent_decode ( self . raw ) } ;
122- if rc != 0 { return Err ( anyhow ! ( "m40llm_start_persistent_decode failed: {rc}" ) ) ; }
152+ if rc != 0 {
153+ return Err ( anyhow ! ( "m40llm_start_persistent_decode failed: {rc}" ) ) ;
154+ }
123155 }
124156 Ok ( ( ) )
125157 }
126158 pub fn stop_persistent_decode ( & self ) -> Result < ( ) > {
127159 #[ cfg( feature = "cuda" ) ]
128160 {
129161 let rc = unsafe { ffi:: m40llm_stop_persistent_decode ( self . raw ) } ;
130- if rc != 0 { return Err ( anyhow ! ( "m40llm_stop_persistent_decode failed: {rc}" ) ) ; }
162+ if rc != 0 {
163+ return Err ( anyhow ! ( "m40llm_stop_persistent_decode failed: {rc}" ) ) ;
164+ }
131165 }
132166 Ok ( ( ) )
133167 }
@@ -151,16 +185,33 @@ pub struct KVCache {
151185}
152186
153187impl KVCache {
154- pub fn new_with_context ( ctx : & CudaContext , max_seq_len : u32 , max_batch_size : u32 , num_heads : u32 , head_dim : u32 ) -> Result < Self > {
188+ pub fn new_with_context (
189+ ctx : & CudaContext ,
190+ max_seq_len : u32 ,
191+ max_batch_size : u32 ,
192+ num_heads : u32 ,
193+ head_dim : u32 ,
194+ ) -> Result < Self > {
155195 #[ cfg( feature = "cuda" ) ]
156196 {
157197 let raw = ctx. create_kvcache ( max_seq_len, max_batch_size, num_heads, head_dim) ?;
158- Ok ( KVCache { max_seq_len, max_batch_size, num_heads, head_dim, raw } )
198+ Ok ( KVCache {
199+ max_seq_len,
200+ max_batch_size,
201+ num_heads,
202+ head_dim,
203+ raw,
204+ } )
159205 }
160206 #[ cfg( not( feature = "cuda" ) ) ]
161207 {
162208 let _ = ctx;
163- Ok ( KVCache { max_seq_len, max_batch_size, num_heads, head_dim } )
209+ Ok ( KVCache {
210+ max_seq_len,
211+ max_batch_size,
212+ num_heads,
213+ head_dim,
214+ } )
164215 }
165216 }
166217}
@@ -185,4 +236,4 @@ impl<T> SharedRing<T> {
185236 std:: mem:: forget ( v) ; // leak capacity; fine for test stubs
186237 Ok ( Self { ptr, len : count } )
187238 }
188- }
239+ }
0 commit comments