Skip to content

Commit 359f99c

Browse files
EricLBuehlerEric Buehler
andauthored
Allow disabling metal precompilation (#1518)
* Allow disabling metal precompilation * Simple preprocessor * Simple docs --------- Co-authored-by: Eric Buehler <[email protected]>
1 parent 5fbf607 commit 359f99c

File tree

5 files changed

+414
-10
lines changed

5 files changed

+414
-10
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,10 @@ If you want to add a new model, please contact us via an issue and we can coordi
717717
- Metal not found (error: unable to find utility "metal", not a developer tool or in PATH)
718718
1) Install Xcode: `xcode-select --install`
719719
2) Set the active developer directory: `sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer`
720+
- Disabling Metal kernel precompilation:
721+
- By default, Metal kernels are precompiled during build time for better performance
722+
- To skip Metal kernel precompilation (useful for CI or when Metal is not needed), set `MISTRALRS_METAL_PRECOMPILE=0` or `MISTRALRS_METAL_PRECOMPILE=false`
723+
- Example: `MISTRALRS_METAL_PRECOMPILE=0 cargo build --release --features metal`
720724
721725
## Credits
722726
This project would not be possible without the excellent work at [`candle`](https://github.com/huggingface/candle). Additionally, thank you to all contributors! Contributing can range from raising an issue or suggesting a feature to adding some new functionality.

mistralrs-paged-attn/build.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,24 @@ fn main() -> Result<(), String> {
151151
println!("cargo::rerun-if-changed=src/metal/kernels/float8.metal");
152152
println!("cargo::rerun-if-changed=build.rs");
153153

154+
// Check if precompilation should be skipped
155+
// https://github.com/EricLBuehler/mistral.rs/pull/1311#issuecomment-3001309885
156+
println!("cargo:rerun-if-env-changed=MISTRALRS_METAL_PRECOMPILE");
157+
let skip_precompile = env::var("MISTRALRS_METAL_PRECOMPILE")
158+
.map(|v| v == "0" || v.to_lowercase() == "false")
159+
.unwrap_or(false);
160+
161+
if skip_precompile {
162+
println!(
163+
"cargo:warning=Skipping Metal kernel precompilation (MISTRALRS_METAL_PRECOMPILE=0)"
164+
);
165+
// Write a dummy metallib file to satisfy the include_bytes! macro
166+
let out_dir = PathBuf::from(std::env::var("OUT_DIR").map_err(|_| "OUT_DIR not set")?);
167+
std::fs::write(out_dir.join("mistralrs_paged_attention.metallib"), &[]).unwrap();
168+
std::fs::write(out_dir.join("mistralrs_paged_attention_ios.metallib"), &[]).unwrap();
169+
return Ok(());
170+
}
171+
154172
enum Platform {
155173
MacOS,
156174
Ios,

mistralrs-paged-attn/src/metal/kernels/mod.rs

Lines changed: 180 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ pub enum MetalKernelError {
3434
FailedToCreatePipeline(String),
3535
#[error("dtype mismatch, got {got:?}, expected {expected:?}")]
3636
DTypeMismatch { expected: Vec<DType>, got: DType },
37+
#[error("Failed to compile Metal shader: {0}")]
38+
CompilationError(String),
3739
}
3840

3941
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
@@ -70,15 +72,188 @@ impl Kernels {
7072
Ok(lib.clone())
7173
} else {
7274
let source_data = KERNELS;
73-
let lib = device.new_library_with_data(source_data).map_err(|e| {
74-
MetalKernelError::LoadLibraryError(format!(
75-
"Metal requires macosx > 13.0 or higher, cannot load candle metal library: {e}"
76-
))
77-
})?;
75+
// Check if the precompiled library is empty (which indicates runtime compilation is needed)
76+
let lib = if source_data.is_empty() {
77+
// Runtime compilation path
78+
self.compile_kernels_at_runtime(device)?
79+
} else {
80+
// Precompiled path
81+
device.new_library_with_data(source_data).map_err(|e| {
82+
MetalKernelError::LoadLibraryError(format!(
83+
"Metal requires macosx > 13.0 or higher, cannot load candle metal library: {e}"
84+
))
85+
})?
86+
};
7887
Ok(LIBRARY.get_or_init(|| lib).clone())
7988
}
8089
}
8190

91+
fn compile_kernels_at_runtime(&self, device: &Device) -> Result<Library, MetalKernelError> {
92+
use std::collections::{HashMap, HashSet};
93+
94+
// Create a virtual filesystem with all our Metal sources
95+
let mut file_system = HashMap::new();
96+
file_system.insert("copy_blocks.metal", include_str!("copy_blocks.metal"));
97+
file_system.insert("pagedattention.metal", include_str!("pagedattention.metal"));
98+
file_system.insert(
99+
"reshape_and_cache.metal",
100+
include_str!("reshape_and_cache.metal"),
101+
);
102+
file_system.insert("utils.metal", include_str!("utils.metal"));
103+
file_system.insert("float8.metal", include_str!("float8.metal"));
104+
105+
// Recursive include preprocessor
106+
fn preprocess_includes(
107+
content: &str,
108+
current_file: &str,
109+
file_system: &HashMap<&str, &str>,
110+
included_files: &mut HashSet<String>,
111+
include_stack: &mut Vec<String>,
112+
) -> Result<String, String> {
113+
// Check for circular includes
114+
if include_stack.contains(&current_file.to_string()) {
115+
return Err(format!(
116+
"Circular include detected: {} -> {}",
117+
include_stack.join(" -> "),
118+
current_file
119+
));
120+
}
121+
122+
include_stack.push(current_file.to_string());
123+
124+
let mut result = String::new();
125+
let mut lines = content.lines();
126+
127+
while let Some(line) = lines.next() {
128+
let trimmed = line.trim();
129+
130+
// Check for #include directive
131+
if trimmed.starts_with("#include") {
132+
// Extract the included filename
133+
if let Some(start) = trimmed.find('"') {
134+
if let Some(end) = trimmed[start + 1..].find('"') {
135+
let include_file = &trimmed[start + 1..start + 1 + end];
136+
137+
// Check if this is one of our local files
138+
if let Some(included_content) = file_system.get(include_file) {
139+
// Only include each file once (like #pragma once)
140+
if !included_files.contains(include_file) {
141+
included_files.insert(include_file.to_string());
142+
143+
// Recursively process the included file
144+
let processed = preprocess_includes(
145+
included_content,
146+
include_file,
147+
file_system,
148+
included_files,
149+
include_stack,
150+
)?;
151+
152+
result.push_str(&format!(
153+
"\n// ===== Start of {} =====\n",
154+
include_file
155+
));
156+
result.push_str(&processed);
157+
result.push_str(&format!(
158+
"\n// ===== End of {} =====\n",
159+
include_file
160+
));
161+
}
162+
// Skip the original #include line
163+
continue;
164+
} else if !trimmed.contains('<') {
165+
// This is a quoted include but not one of our files
166+
// Skip it to avoid "file not found" errors
167+
continue;
168+
}
169+
}
170+
}
171+
// For system includes (with < >), keep them
172+
if trimmed.contains('<') {
173+
result.push_str(line);
174+
result.push('\n');
175+
}
176+
} else if trimmed == "#pragma once" {
177+
// Skip #pragma once as we handle it differently
178+
continue;
179+
} else {
180+
// Fix backslash-newline warnings by removing trailing spaces
181+
if line.ends_with("\\ ") || line.ends_with("\\\t") {
182+
let cleaned = line.trim_end();
183+
let without_backslash = cleaned.trim_end_matches('\\');
184+
result.push_str(without_backslash);
185+
result.push_str(" \\");
186+
} else {
187+
result.push_str(line);
188+
}
189+
result.push('\n');
190+
}
191+
}
192+
193+
include_stack.pop();
194+
Ok(result)
195+
}
196+
197+
// Start with a clean slate
198+
let mut included_files = HashSet::new();
199+
let mut include_stack = Vec::new();
200+
201+
// Build the main source file
202+
let mut main_source = String::new();
203+
204+
// Add standard Metal includes first
205+
main_source.push_str("#include <metal_stdlib>\n");
206+
main_source.push_str("#include <metal_common>\n");
207+
main_source.push_str("#include <metal_math>\n");
208+
main_source.push_str("#include <metal_simdgroup>\n");
209+
main_source.push_str("\nusing namespace metal;\n\n");
210+
211+
// Process all the main implementation files
212+
// Order matters - we need to ensure dependencies are included first
213+
let main_files = vec![
214+
"float8.metal", // Float8 types
215+
"utils.metal", // Utility functions (depends on float8)
216+
"copy_blocks.metal", // Main implementations
217+
"pagedattention.metal",
218+
"reshape_and_cache.metal",
219+
];
220+
221+
for file in main_files {
222+
if !included_files.contains(file) {
223+
if let Some(content) = file_system.get(file) {
224+
match preprocess_includes(
225+
content,
226+
file,
227+
&file_system,
228+
&mut included_files,
229+
&mut include_stack,
230+
) {
231+
Ok(processed) => {
232+
main_source.push_str(&format!("\n// ===== {} =====\n", file));
233+
main_source.push_str(&processed);
234+
}
235+
Err(e) => {
236+
return Err(MetalKernelError::CompilationError(format!(
237+
"Failed to preprocess {}: {}",
238+
file, e
239+
)));
240+
}
241+
}
242+
}
243+
}
244+
}
245+
246+
// Compile the preprocessed source
247+
let compile_options = metal::CompileOptions::new();
248+
device
249+
.new_library_with_source(&main_source, &compile_options)
250+
.map_err(|e| {
251+
MetalKernelError::CompilationError(format!(
252+
"Failed to compile Metal kernels at runtime: {e}"
253+
))
254+
})
255+
}
256+
82257
fn load_function(
83258
&self,
84259
device: &Device,

mistralrs-quant/build.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,24 @@ fn main() -> Result<(), String> {
174174
}
175175
println!("cargo::rerun-if-changed=build.rs");
176176

177+
// Check if precompilation should be skipped
178+
// https://github.com/EricLBuehler/mistral.rs/pull/1311#issuecomment-3001309885
179+
println!("cargo:rerun-if-env-changed=MISTRALRS_METAL_PRECOMPILE");
180+
let skip_precompile = env::var("MISTRALRS_METAL_PRECOMPILE")
181+
.map(|v| v == "0" || v.to_lowercase() == "false")
182+
.unwrap_or(false);
183+
184+
if skip_precompile {
185+
println!(
186+
"cargo:warning=Skipping Metal kernel precompilation (MISTRALRS_METAL_PRECOMPILE=0)"
187+
);
188+
// Write a dummy metallib file to satisfy the include_bytes! macro
189+
let out_dir = PathBuf::from(std::env::var("OUT_DIR").map_err(|_| "OUT_DIR not set")?);
190+
std::fs::write(out_dir.join("mistralrs_quant.metallib"), &[]).unwrap();
191+
std::fs::write(out_dir.join("mistralrs_quant_ios.metallib"), &[]).unwrap();
192+
return Ok(());
193+
}
194+
177195
enum Platform {
178196
MacOS,
179197
Ios,

0 commit comments

Comments
 (0)