@@ -34,6 +34,8 @@ pub enum MetalKernelError {
3434 FailedToCreatePipeline ( String ) ,
3535 #[ error( "dtype mismatch, got {got:?}, expected {expected:?}" ) ]
3636 DTypeMismatch { expected : Vec < DType > , got : DType } ,
37+ #[ error( "Failed to compile Metal shader: {0}" ) ]
38+ CompilationError ( String ) ,
3739}
3840
3941impl < T > From < std:: sync:: PoisonError < T > > for MetalKernelError {
@@ -70,15 +72,188 @@ impl Kernels {
7072 Ok ( lib. clone ( ) )
7173 } else {
7274 let source_data = KERNELS ;
73- let lib = device. new_library_with_data ( source_data) . map_err ( |e| {
74- MetalKernelError :: LoadLibraryError ( format ! (
75- "Metal requires macosx > 13.0 or higher, cannot load candle metal library: {e}"
76- ) )
77- } ) ?;
75+ // Check if the precompiled library is empty (which indicates runtime compilation is needed)
76+ let lib = if source_data. is_empty ( ) {
77+ // Runtime compilation path
78+ self . compile_kernels_at_runtime ( device) ?
79+ } else {
80+ // Precompiled path
81+ device. new_library_with_data ( source_data) . map_err ( |e| {
82+ MetalKernelError :: LoadLibraryError ( format ! (
83+ "Metal requires macosx > 13.0 or higher, cannot load candle metal library: {e}"
84+ ) )
85+ } ) ?
86+ } ;
7887 Ok ( LIBRARY . get_or_init ( || lib) . clone ( ) )
7988 }
8089 }
8190
91+ fn compile_kernels_at_runtime ( & self , device : & Device ) -> Result < Library , MetalKernelError > {
92+ use std:: collections:: { HashMap , HashSet } ;
93+
94+ // Create a virtual filesystem with all our Metal sources
95+ let mut file_system = HashMap :: new ( ) ;
96+ file_system. insert ( "copy_blocks.metal" , include_str ! ( "copy_blocks.metal" ) ) ;
97+ file_system. insert ( "pagedattention.metal" , include_str ! ( "pagedattention.metal" ) ) ;
98+ file_system. insert (
99+ "reshape_and_cache.metal" ,
100+ include_str ! ( "reshape_and_cache.metal" ) ,
101+ ) ;
102+ file_system. insert ( "utils.metal" , include_str ! ( "utils.metal" ) ) ;
103+ file_system. insert ( "float8.metal" , include_str ! ( "float8.metal" ) ) ;
104+
105+ // Recursive include preprocessor
106+ fn preprocess_includes (
107+ content : & str ,
108+ current_file : & str ,
109+ file_system : & HashMap < & str , & str > ,
110+ included_files : & mut HashSet < String > ,
111+ include_stack : & mut Vec < String > ,
112+ ) -> Result < String , String > {
113+ // Check for circular includes
114+ if include_stack. contains ( & current_file. to_string ( ) ) {
115+ return Err ( format ! (
116+ "Circular include detected: {} -> {}" ,
117+ include_stack. join( " -> " ) ,
118+ current_file
119+ ) ) ;
120+ }
121+
122+ include_stack. push ( current_file. to_string ( ) ) ;
123+
124+ let mut result = String :: new ( ) ;
125+ let mut lines = content. lines ( ) ;
126+
127+ while let Some ( line) = lines. next ( ) {
128+ let trimmed = line. trim ( ) ;
129+
130+ // Check for #include directive
131+ if trimmed. starts_with ( "#include" ) {
132+ // Extract the included filename
133+ if let Some ( start) = trimmed. find ( '"' ) {
134+ if let Some ( end) = trimmed[ start + 1 ..] . find ( '"' ) {
135+ let include_file = & trimmed[ start + 1 ..start + 1 + end] ;
136+
137+ // Check if this is one of our local files
138+ if let Some ( included_content) = file_system. get ( include_file) {
139+ // Only include each file once (like #pragma once)
140+ if !included_files. contains ( include_file) {
141+ included_files. insert ( include_file. to_string ( ) ) ;
142+
143+ // Recursively process the included file
144+ let processed = preprocess_includes (
145+ included_content,
146+ include_file,
147+ file_system,
148+ included_files,
149+ include_stack,
150+ ) ?;
151+
152+ result. push_str ( & format ! (
153+ "\n // ===== Start of {} =====\n " ,
154+ include_file
155+ ) ) ;
156+ result. push_str ( & processed) ;
157+ result. push_str ( & format ! (
158+ "\n // ===== End of {} =====\n " ,
159+ include_file
160+ ) ) ;
161+ }
162+ // Skip the original #include line
163+ continue ;
164+ } else if !trimmed. contains ( '<' ) {
165+ // This is a quoted include but not one of our files
166+ // Skip it to avoid "file not found" errors
167+ continue ;
168+ }
169+ }
170+ }
171+ // For system includes (with < >), keep them
172+ if trimmed. contains ( '<' ) {
173+ result. push_str ( line) ;
174+ result. push ( '\n' ) ;
175+ }
176+ } else if trimmed == "#pragma once" {
177+ // Skip #pragma once as we handle it differently
178+ continue ;
179+ } else {
180+ // Fix backslash-newline warnings by removing trailing spaces
181+ if line. ends_with ( "\\ " ) || line. ends_with ( "\\ \t " ) {
182+ let cleaned = line. trim_end ( ) ;
183+ let without_backslash = cleaned. trim_end_matches ( '\\' ) ;
184+ result. push_str ( without_backslash) ;
185+ result. push_str ( " \\ " ) ;
186+ } else {
187+ result. push_str ( line) ;
188+ }
189+ result. push ( '\n' ) ;
190+ }
191+ }
192+
193+ include_stack. pop ( ) ;
194+ Ok ( result)
195+ }
196+
197+ // Start with a clean slate
198+ let mut included_files = HashSet :: new ( ) ;
199+ let mut include_stack = Vec :: new ( ) ;
200+
201+ // Build the main source file
202+ let mut main_source = String :: new ( ) ;
203+
204+ // Add standard Metal includes first
205+ main_source. push_str ( "#include <metal_stdlib>\n " ) ;
206+ main_source. push_str ( "#include <metal_common>\n " ) ;
207+ main_source. push_str ( "#include <metal_math>\n " ) ;
208+ main_source. push_str ( "#include <metal_simdgroup>\n " ) ;
209+ main_source. push_str ( "\n using namespace metal;\n \n " ) ;
210+
211+ // Process all the main implementation files
212+ // Order matters - we need to ensure dependencies are included first
213+ let main_files = vec ! [
214+ "float8.metal" , // Float8 types
215+ "utils.metal" , // Utility functions (depends on float8)
216+ "copy_blocks.metal" , // Main implementations
217+ "pagedattention.metal" ,
218+ "reshape_and_cache.metal" ,
219+ ] ;
220+
221+ for file in main_files {
222+ if !included_files. contains ( file) {
223+ if let Some ( content) = file_system. get ( file) {
224+ match preprocess_includes (
225+ content,
226+ file,
227+ & file_system,
228+ & mut included_files,
229+ & mut include_stack,
230+ ) {
231+ Ok ( processed) => {
232+ main_source. push_str ( & format ! ( "\n // ===== {} =====\n " , file) ) ;
233+ main_source. push_str ( & processed) ;
234+ }
235+ Err ( e) => {
236+ return Err ( MetalKernelError :: CompilationError ( format ! (
237+ "Failed to preprocess {}: {}" ,
238+ file, e
239+ ) ) ) ;
240+ }
241+ }
242+ }
243+ }
244+ }
245+
246+ // Compile the preprocessed source
247+ let compile_options = metal:: CompileOptions :: new ( ) ;
248+ device
249+ . new_library_with_source ( & main_source, & compile_options)
250+ . map_err ( |e| {
251+ MetalKernelError :: CompilationError ( format ! (
252+ "Failed to compile Metal kernels at runtime: {e}"
253+ ) )
254+ } )
255+ }
256+
82257 fn load_function (
83258 & self ,
84259 device : & Device ,
0 commit comments