Skip to content
Merged
20 changes: 20 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,14 @@ cc = "1.2.10"
cfg-if = "1.0.0"
clap = "4.5.9"
clap-cargo = "0.14.1"
fs4 = "0.12.0"
itertools = "0.13.0"
paste = "1.0.15"
pretty_assertions = "1.4.1"
proc-macro2 = "1.0.93"
quote = "1.0.38"
rustversion = "1.0.19"
scratch = "1.0"
serde = "1.0"
serde_json = "1.0"
syn = "2.0.87"
Expand Down
4 changes: 4 additions & 0 deletions crates/wdk-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@ categories = [
proc-macro = true

[dependencies]
fs4.workspace = true
itertools.workspace = true
proc-macro2.workspace = true
quote.workspace = true
scratch.workspace = true
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
syn = { workspace = true, features = ["full", "extra-traits"] }

[dev-dependencies]
Expand Down
157 changes: 151 additions & 6 deletions crates/wdk-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
//! A collection of macros that help make it easier to interact with
//! [`wdk-sys`]'s direct bindings to the Windows Driver Kit (WDK).

use std::path::PathBuf;
use std::{collections::HashMap, path::PathBuf, str::FromStr};

use fs4::fs_std::FileExt;
use itertools::Itertools;
use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use quote::{format_ident, quote};
use quote::{format_ident, quote, ToTokens};
use serde::{Deserialize, Serialize};
use syn::{
parse::{Parse, ParseStream},
parse::{Parse, ParseStream, Parser},
parse2,
parse_file,
parse_quote,
Expand Down Expand Up @@ -95,6 +97,14 @@ struct IntermediateOutputASTFragments {
inline_wdf_fn_invocation: ExprCall,
}

/// Struct storing string representations of the information we want to cache
/// from `types.rs`.
#[derive(Serialize, Deserialize)]
struct SavedFunctionInfo {
parameters: String,
return_type: String,
}

impl StringExt for String {
fn to_snake_case(&self) -> String {
// There will be, at max, 2 characters unhandled by the 3-char windows. It is
Expand Down Expand Up @@ -171,9 +181,27 @@ impl Inputs {
span = self.wdf_function_identifier.span()
);

let types_ast = parse_types_ast(&self.types_path)?;
let (parameters, return_type) =
generate_parameters_and_return_type(&types_ast, &function_pointer_type)?;
let function_name_to_info_map: HashMap<String, SavedFunctionInfo> =
self.get_fragment_info_map()?;
let function_info = function_name_to_info_map
.get(&self.wdf_function_identifier.to_string())
.ok_or_else(|| {
Error::new(
self.wdf_function_identifier.span(),
format!(
"Failed to find function info for {}",
self.wdf_function_identifier
),
)
})?;
let parameters_tokens = TokenStream2::from_str(&function_info.parameters)
.map_err(|e| map_to_syn_error(self.wdf_function_identifier.span(), e))?;
let return_type_tokens = TokenStream2::from_str(&function_info.return_type)
.map_err(|e| map_to_syn_error(self.wdf_function_identifier.span(), e))?;
let parameters =
Punctuated::<BareFnArg, Token![,]>::parse_terminated.parse2(parameters_tokens)?;
let return_type = ReturnType::parse.parse2(return_type_tokens)?;

let parameter_identifiers = parameters
.iter()
.cloned()
Expand Down Expand Up @@ -202,6 +230,119 @@ impl Inputs {
inline_wdf_fn_name,
})
}

// Motivation for this function is to reduce build time. Rather than parse
// types.rs for relevant information to construct function table call on each
// macro invocation, we store all possible function table information for each
// function on first invocation. We cache this using the `scratch` crate, and
// read from this cache for each subsequent invocation.
fn get_fragment_info_map(&self) -> Result<HashMap<String, SavedFunctionInfo>> {
let scratch_dir = scratch::path("ast_fragments");
let flock = std::fs::File::create(scratch_dir.join(".lock"))
.map_err(|e| map_to_syn_error(self.wdf_function_identifier.span(), e))?;

let fragment_info_map_path = scratch_dir.join("fragment_info_map.json");

if !fragment_info_map_path.exists() {
FileExt::lock_exclusive(&flock)
.map_err(|e| map_to_syn_error(self.wdf_function_identifier.span(), e))?;

if !fragment_info_map_path.exists() {
let generated_map = self.generate_fragment_info_map_from_types()?;
let generated_map_string = serde_json::to_string(&generated_map)
.map_err(|e| map_to_syn_error(self.wdf_function_identifier.span(), e))?;
std::fs::write(&fragment_info_map_path, generated_map_string)
.map_err(|e| map_to_syn_error(self.wdf_function_identifier.span(), e))?;
}
FileExt::unlock(&flock)
.map_err(|e| map_to_syn_error(self.wdf_function_identifier.span(), e))?;
}

let generated_map_string = std::fs::read_to_string(&fragment_info_map_path)
.map_err(|e| map_to_syn_error(self.wdf_function_identifier.span(), e))?;
let map: HashMap<String, SavedFunctionInfo> =
serde_json::from_str(&generated_map_string)
.map_err(|e| map_to_syn_error(self.wdf_function_identifier.span(), e))?;
Ok(map)
}

// To generate the cache we parse the types.rs file for `mod _WDFFUNCENUM`. This
// stores each function table name with the suffix "TableIndex", which we parse
// and trim. We then use this function name to parse for the function pointer
// type alias, which we then parse for the function parameters and return type.
fn generate_fragment_info_map_from_types(&self) -> Result<HashMap<String, SavedFunctionInfo>> {
let types_ast = parse_types_ast(&self.types_path)?;

let func_enum_mod: Option<&syn::ItemMod> = types_ast.items.iter().find_map(|item| {
if let Item::Mod(mod_alias) = item {
if mod_alias.ident == "_WDFFUNCENUM" {
return Some(mod_alias);
}
}
None
});

let func_enum_mod = func_enum_mod.ok_or_else(|| {
Error::new(
self.wdf_function_identifier.span(),
"Failed to find _WDFFUNCENUM module in types file",
)
})?;

let func_enum_mod_contents = &func_enum_mod
.content
.as_ref()
.ok_or_else(|| {
Error::new(
self.wdf_function_identifier.span(),
"Failed to find _WDFFUNCENUM module contents in types file",
)
})?
.1;

let mut const_func_enum_types: Vec<String> = vec![];
for func_enum_mod_item in func_enum_mod_contents {
if let Item::Const(const_alias) = func_enum_mod_item {
const_func_enum_types.push(const_alias.ident.to_string());
}
}

let mut function_name_to_info_map: HashMap<String, SavedFunctionInfo> = HashMap::new();
for const_func_enum_type in const_func_enum_types {
let Some(wdf_function_name) = const_func_enum_type.strip_suffix("TableIndex") else {
continue;
};
let function_pointer_type = format_ident!(
"PFN_{uppercase_c_function_name}",
uppercase_c_function_name = wdf_function_name.to_uppercase(),
span = self.wdf_function_identifier.span()
);

let (parameters, return_type) =
match generate_parameters_and_return_type(&types_ast, &function_pointer_type) {
Ok((parameters, return_type)) => (parameters, return_type),
Err(err) => {
if err
.to_string()
.contains("Failed to find type alias definition for")
{
continue;
}
return Err(err);
}
};

function_name_to_info_map.insert(
wdf_function_name.into(),
SavedFunctionInfo {
parameters: parameters.to_token_stream().to_string(),
return_type: return_type.to_token_stream().to_string(),
},
);
}

Ok(function_name_to_info_map)
}
}

impl DerivedASTFragments {
Expand Down Expand Up @@ -318,6 +459,10 @@ impl IntermediateOutputASTFragments {
}
}

fn map_to_syn_error<E: std::fmt::Display>(span: Span, error: E) -> Error {
Error::new(span, error.to_string())
}

fn call_unsafe_wdf_function_binding_impl(input_tokens: TokenStream2) -> TokenStream2 {
let inputs = match parse2::<Inputs>(input_tokens) {
Ok(syntax_tree) => syntax_tree,
Expand Down
6 changes: 6 additions & 0 deletions crates/wdk-macros/tests/unit-tests-input/generated-types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@

//! Snippet of a bindgen-generated file containing types information used by tests for [`wdk_macros::call_unsafe_wdf_function_binding!`]

pub mod _WDFFUNCENUM {
pub const WdfDriverCreateTableIndex: Type = 116;
pub const WdfVerifierDbgBreakPointTableIndex: Type = 367;
}


pub type PFN_WDFDRIVERCREATE = ::core::option::Option<
unsafe extern "C" fn(
DriverGlobals: PWDF_DRIVER_GLOBALS,
Expand Down
24 changes: 22 additions & 2 deletions examples/sample-kmdf-driver/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 22 additions & 2 deletions examples/sample-umdf-driver/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading