Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions crates/cubecl-macros-core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[package]
authors = [
"nathanielsimard <nathaniel.simard.42@gmail.com>",
"louisfd <louisfd94@gmail.com",
]
categories = ["science"]
description = "Expansion logic for the CubeCL proc-macros (cubecl-macros is the proc-macro shim over this)."
edition.workspace = true
keywords = []
license.workspace = true
name = "cubecl-macros-core"
readme.workspace = true
repository = "https://github.com/tracel-ai/cubecl/tree/main/crates/cubecl-macros-core"
version.workspace = true


[lints]
workspace = true


[features]
debug_symbols = []
default = []
std = []

tracing = ["cubecl-common/tracing"]


[dependencies]
darling = { workspace = true }
derive-new = { workspace = true }
derive_more = { workspace = true }
ident_case = { workspace = true }
inflections = "1"
prettyplease = "0.2"
proc-macro2 = { workspace = true }
quote = { workspace = true }
syn = { workspace = true }

cubecl-common = { path = "../cubecl-common", version = "=0.11.0-pre.1", default-features = false }
17 changes: 17 additions & 0 deletions crates/cubecl-macros-core/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use std::env;

// Allow overriding nightly macro features on the end user side without having
// to propagate the feature everywhere
fn main() {
println!("cargo::rustc-check-cfg=cfg(debug_symbols)");
println!("cargo:rerun-if-env-changed=CUBECL_DEBUG");

let debug_feature_enabled = env::var("CARGO_FEATURE_DEBUG_SYMBOLS").is_ok();
let debug_override_env = env::var("CUBECL_DEBUG").unwrap_or_default();
let debug_override_enabled = matches!(debug_override_env.to_lowercase().as_str(), "1" | "true");
let debug_enabled = debug_feature_enabled || debug_override_enabled;

if debug_enabled {
println!("cargo:rustc-cfg=debug_symbols");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,9 @@ impl KernelFn {
let (debug_source, debug_params) = if cfg_debug || self.args.debug_symbols.is_present() {
let debug_source = frontend_type("debug_source_expand");
let cube_debug = frontend_type("CubeDebug");
// Span-based source-file inference needs the compiler-only `proc_macro::Span`, absent
// in this normal lib, so rely on the explicit `src_file` arg (and `file!()` below).
let src_file = self.args.src_file.as_ref().map(|file| file.value());
let src_file = src_file.or_else(|| {
let span: proc_macro::Span = self.span.unwrap();
let source_path = span.local_file();
let source_file = source_path.as_ref().and_then(|path| path.file_name());
source_file.map(|file| file.to_string_lossy().into())
});
let source_text = match src_file {
Some(file) => quote![include_str!(#file)],
None => quote![""],
Expand Down
156 changes: 156 additions & 0 deletions crates/cubecl-macros-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
//! `#[cube]` expansion logic, kept in a normal library (not the `cubecl-macros` proc-macro crate)
//! so it can be tested and fuzzed.

#![allow(clippy::large_enum_variant)]

mod error;
mod expression;
mod generate;
mod operator;
mod parse;
mod paths;
mod scope;
mod statement;

use error::error_into_token_stream;
use generate::{
assign::generate_cube_type_mut, autotune::generate_autotune_key,
into_runtime::generate_into_runtime,
};
use parse::{
cube_impl::CubeImpl,
cube_trait::{CubeTrait, CubeTraitImpl},
cube_type::generate_cube_type,
derive_expand::generate_derive_expand,
helpers::{RemoveHelpers, ReplaceDefines},
kernel::{Launch, from_tokens},
};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Item, visit_mut::VisitMut};

/// Expand a `#[cube]` item, returning a `compile_error!` stream on failure.
pub fn cube(args: TokenStream, input: TokenStream) -> TokenStream {
match cube_impl(args, input.clone()) {
Ok(tokens) => tokens,
Err(e) => error_into_token_stream(e, input),
}
}

/// Fallible core of `#[cube]` expansion, and the fuzz entry point: must return `Ok`/`Err`, never panic.
pub fn cube_impl(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
let mut item: Item = syn::parse2(input)?;
let args = from_tokens(args)?;

let tokens = match item.clone() {
Item::Fn(kernel) => {
let kernel = Launch::from_item_fn(kernel, args)?;
RemoveHelpers.visit_item_mut(&mut item);
ReplaceDefines.visit_item_mut(&mut item);

let extra_allow = match kernel.func.context.is_intrinsic {
true => quote![#[allow(unused_variables)]],
false => quote![],
};

return Ok(quote! {
#[allow(dead_code, clippy::too_many_arguments)]
#extra_allow
#item
#kernel
});
}
Item::Trait(kernel_trait) => {
let is_debug = args.debug.is_present();
let expand_trait = CubeTrait::from_item_trait(kernel_trait, args)?;

let tokens = quote! {
#expand_trait
};
if is_debug {
panic!("{tokens}");
}
return Ok(tokens);
}
Item::Impl(item_impl) => {
if item_impl.trait_.is_some() {
let mut expand_impl = CubeTraitImpl::from_item_impl(item_impl, &args)?;
let expand_impl = expand_impl.to_tokens_mut();

Ok(quote! {
#expand_impl
})
} else {
let mut expand_impl = CubeImpl::from_item_impl(item_impl, &args)?;
let expand_impl = expand_impl.to_tokens_mut();

Ok(quote! {
#expand_impl
})
}
}
item => Err(syn::Error::new_spanned(
item,
"`#[cube]` is only supported on traits and functions",
))?,
};

if args.debug.is_present() {
match tokens {
Ok(tokens) => panic!("{tokens}"),
Err(err) => panic!("{err}"),
};
}

tokens
}

/// Expand the `CubeLaunch` / `CubeType` derives.
pub fn cube_type(input: TokenStream, with_launch: bool) -> TokenStream {
let parsed = syn::parse2(input);

let input = match &parsed {
Ok(val) => val,
Err(err) => return err.to_compile_error(),
};

match generate_cube_type(input, with_launch) {
Ok(val) => val,
Err(err) => err.to_compile_error(),
}
}

/// Expand the `derive_expand` attribute.
pub fn derive_expand(metadata: TokenStream, input: TokenStream) -> TokenStream {
match generate_derive_expand(input, metadata) {
Ok(val) => val,
Err(err) => err.to_compile_error(),
}
}

/// Expand the `AutotuneKey` derive.
pub fn autotune_key(input: TokenStream) -> TokenStream {
let input = syn::parse2(input).unwrap();
match generate_autotune_key(input) {
Ok(tokens) => tokens,
Err(e) => e.into_compile_error(),
}
}

/// Expand the `IntoRuntime` derive.
pub fn into_runtime(input: TokenStream) -> TokenStream {
let input = syn::parse2(input).unwrap();
match generate_into_runtime(&input) {
Ok(tokens) => tokens,
Err(e) => e.into_compile_error(),
}
}

/// Expand the `CubeTypeMut` derive.
pub fn cube_type_mut(input: TokenStream) -> TokenStream {
let input = syn::parse2(input).unwrap();
match generate_cube_type_mut(&input) {
Ok(tokens) => tokens,
Err(e) => e.into_compile_error(),
}
}
15 changes: 4 additions & 11 deletions crates/cubecl-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,15 @@ workspace = true


[features]
debug_symbols = []
debug_symbols = ["cubecl-macros-core/debug_symbols"]
default = []
std = []
std = ["cubecl-macros-core/std"]

tracing = ["cubecl-common/tracing"]
tracing = ["cubecl-macros-core/tracing"]


[dependencies]
darling = { workspace = true }
derive-new = { workspace = true }
derive_more = { workspace = true }
ident_case = { workspace = true }
inflections = "1"
prettyplease = "0.2"
proc-macro2 = { workspace = true }
quote = { workspace = true }
syn = { workspace = true }

cubecl-common = { path = "../cubecl-common", version = "=0.11.0-pre.1", default-features = false }
cubecl-macros-core = { path = "../cubecl-macros-core", version = "=0.11.0-pre.1" }
Loading
Loading