Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion tooling/ast_fuzzer/src/program/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
ctx.set_function_decl(FuncId(1), decl_inner.clone());
ctx.gen_function(u, FuncId(1))?;

// Parameterless main declaration wrapping the inner "main"

Check warning on line 54 in tooling/ast_fuzzer/src/program/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (Parameterless)
// function call
let decl_main = FunctionDeclaration {
name: "main".into(),
Expand All @@ -63,7 +63,7 @@
};

ctx.set_function_decl(FuncId(0), decl_main);
ctx.gen_function_with_body(u, FuncId(0), |u, fctx| fctx.gen_body_with_lit_call(u, FuncId(1)))?;

Check warning on line 66 in tooling/ast_fuzzer/src/program/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (fctx)

Check warning on line 66 in tooling/ast_fuzzer/src/program/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (fctx)
ctx.rewrite_functions(u)?;

let program = ctx.finalize();
Expand Down Expand Up @@ -169,6 +169,16 @@
let is_main = i == 0;
let num_params = u.int_in_range(0..=self.config.max_function_args)?;

// If `main` is unconstrained, it won't call ACIR, so no point generating ACIR functions.
let unconstrained = self.config.force_brillig
|| (!is_main
&& self
.functions
.get(&Program::main_id())
.map(|func| func.unconstrained)
.unwrap_or_default())
|| bool::arbitrary(u)?;

let mut params = Vec::new();
for p in 0..num_params {
let id = LocalId(p as u32);
Expand Down Expand Up @@ -227,7 +237,7 @@
} else {
*u.choose(&[InlineType::Inline, InlineType::InlineAlways])?
},
unconstrained: self.config.force_brillig || bool::arbitrary(u)?,
unconstrained,
};

Ok(decl)
Expand All @@ -251,7 +261,7 @@

/// Generate random function body.
fn gen_function(&mut self, u: &mut Unstructured, id: FuncId) -> arbitrary::Result<()> {
self.gen_function_with_body(u, id, |u, fctx| fctx.gen_body(u))

Check warning on line 264 in tooling/ast_fuzzer/src/program/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (fctx)

Check warning on line 264 in tooling/ast_fuzzer/src/program/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (fctx)
}

/// Generate function with a specified body generator.
Expand All @@ -261,7 +271,7 @@
id: FuncId,
f: impl Fn(&mut Unstructured, FunctionContext) -> arbitrary::Result<Expression>,
) -> arbitrary::Result<()> {
let fctx = FunctionContext::new(self, id);

Check warning on line 274 in tooling/ast_fuzzer/src/program/mod.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (fctx)
let body = f(u, fctx)?;
let decl = self.function_decl(id);
let func = Function {
Expand Down
92 changes: 61 additions & 31 deletions tooling/ast_fuzzer/src/program/rewrite.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::collections::BTreeMap;
use std::collections::{HashMap, HashSet};

use arbitrary::Unstructured;
use im::HashMap;
use nargo::errors::Location;
use noirc_frontend::{
ast::BinaryOpKind,
Expand All @@ -16,44 +15,66 @@ use super::{
visitor::{visit_expr, visit_expr_mut},
};

/// Find recursive functions and add a `ctx_limit: &mut u32` parameter to them,
/// which we use to limit the number of recursive calls. This is complicated by
/// the fact that we cannot pass mutable references from ACIR to Brillig. To
/// overcome that, we create a proxy function for unconstrained functions that
/// take `mut ctx_limit: u32` instead, and pass it on as a mutable ref.
/// To avoid the potential of infinite recursion at runtime, add a `ctx_limit: &mut u32`
/// parameter to all functions, which we use to limit the number of recursive calls.
///
/// This is complicated by the fact that we cannot pass mutable references from ACIR to Brillig.
/// To overcome that, we create a proxy function for unconstrained functions that take
/// `mut ctx_limit: u32` instead, and pass it on as a mutable ref.
///
/// Originally only actually recursive functions (ie. one that called something else)
/// received this extra parameters, but in order to support higher order functions
/// which can be passed a recursive or a non-recursive function as an argument,
/// all functions get the extra parameter.
pub(crate) fn add_recursion_limit(
ctx: &mut Context,
u: &mut Unstructured,
) -> arbitrary::Result<()> {
// Collect recursive functions, ie. the ones which call other functions.
// Remember if they are unconstrained; those need proxies as well.
let recursive_functions = ctx
.functions
.iter()
.filter_map(|(id, func)| expr::has_call(&func.body).then_some((*id, func.unconstrained)))
.collect::<BTreeMap<_, _>>();
.filter_map(|(id, func)| expr::has_call(&func.body).then_some(*id))
.collect::<HashSet<_>>();

// Collect functions called from ACIR; they will need proxy functions.
let called_from_acir = ctx.functions.values().filter(|func| !func.unconstrained).fold(
HashSet::<FuncId>::new(),
|mut acc, func| {
acc.extend(expr::callees(&func.body));
acc
},
);

// Create proxies for unconstrained recursive functions.
// We could check whether they are called from ACIR, but that would require further traversals.
let unconstrained_functions = ctx
.functions
.iter()
.filter_map(|(id, func)| func.unconstrained.then_some(*id))
.collect::<HashSet<_>>();

// Create proxies for unconstrained functions called from ACIR.
let mut proxy_functions = HashMap::new();
let mut next_func_id = FuncId(ctx.functions.len() as u32);

for (func_id, unconstrained) in &recursive_functions {
if !*unconstrained || *func_id == Program::main_id() {
for (func_id, func) in &ctx.functions {
if !func.unconstrained
|| *func_id == Program::main_id()
|| !called_from_acir.contains(func_id)
{
continue;
}
let mut proxy = ctx.functions[func_id].clone();
let mut proxy = func.clone();
proxy.id = next_func_id;
proxy.name = format!("{}_proxy", proxy.name);
// We will replace the body and update the params later.
// We will replace the body, update the params, and append the function later.
proxy_functions.insert(*func_id, proxy);
next_func_id = FuncId(next_func_id.0 + 1);
}

// Rewrite recursive functions.
for (func_id, unconstrained) in recursive_functions.iter() {
let func = ctx.functions.get_mut(func_id).unwrap();
// Rewrite functions.
for (func_id, func) in ctx.functions.iter_mut() {
let is_main = *func_id == Program::main_id();
let is_recursive = recursive_functions.contains(func_id);

// We'll need a new ID for variables or parameters. We could speed this up by
// 1) caching this value in a "function meta" construct, or
Expand Down Expand Up @@ -89,7 +110,7 @@ pub(crate) fn add_recursion_limit(
expr::u32_literal(ctx.config.max_recursive_calls as u32),
);
expr::prepend(&mut func.body, init_limit);
} else {
} else if is_recursive {
// In non-main we look at the limit and return a random value if it's zero,
// otherwise decrease it by one and continue with the original body.
let limit_type = types::ref_mut(types::U32);
Expand Down Expand Up @@ -132,6 +153,18 @@ pub(crate) fn add_recursion_limit(
func.return_type.clone(),
)
});
} else {
// For non-recursive functions just add an unused parameter.
// In non-main we look at the limit and return a random value if it's zero,
// otherwise decrease it by one and continue with the original body.
let limit_type = types::ref_mut(types::U32);
func.parameters.push((
limit_id,
false,
format!("_{limit_name}"),
limit_type.clone(),
Visibility::Private,
));
}

// Add the non-reference version of the parameter to the proxy function.
Expand Down Expand Up @@ -194,26 +227,23 @@ pub(crate) fn add_recursion_limit(
// Update calls to pass along the limit and call the proxy if necessary.
visit_expr_mut(&mut func.body, &mut |expr: &mut Expression| {
if let Expression::Call(call) = expr {
let Expression::Ident(func) = call.func.as_mut() else {
let Expression::Ident(ident) = call.func.as_mut() else {
unreachable!("functions are called by ident");
};
let Definition::Function(func_id) = func.definition else {
let Definition::Function(callee_id) = ident.definition else {
unreachable!("function definition expected");
};
// If the callee isn't recursive, it won't have the extra parameter.
let Some(callee_unconstrained) = recursive_functions.get(&func_id) else {
return true;
};
let Type::Function(param_types, _, _, _) = &mut func.typ else {
let Type::Function(param_types, _, _, _) = &mut ident.typ else {
unreachable!("function type expected");
};
if *callee_unconstrained && !unconstrained {
let callee_unconstrained = unconstrained_functions.contains(&callee_id);
if callee_unconstrained && !func.unconstrained {
// Calling Brillig from ACIR: call the proxy.
let Some(proxy) = proxy_functions.get(&func_id) else {
let Some(proxy) = proxy_functions.get(&callee_id) else {
unreachable!("expected to have a proxy");
};
func.name = proxy.name.clone();
func.definition = Definition::Function(proxy.id);
ident.name = proxy.name.clone();
ident.definition = Definition::Function(proxy.id);
// Pass the limit by value.
let limit_expr = if is_main {
expr::ident(
Expand Down
9 changes: 3 additions & 6 deletions tooling/ast_fuzzer/src/program/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ fn test_recursion_limit_rewrite() {
// - main passes the limit to foo by ref
// - foo passes the limit to bar_proxy by value
// - bar_proxy passes the limit to baz by ref
// - bar does not passes the limit to qux
// - bar passes the limit to qux, even though it's unused
// - baz passes the limit to itself by ref

let code = format!("{}", DisplayAstAsNoir(&program));
Expand All @@ -220,7 +220,7 @@ fn test_recursion_limit_rewrite() {
} else {
*ctx_limit = ((*ctx_limit) - 1);
baz(ctx_limit);
qux()
qux(ctx_limit)
}
}
unconstrained fn baz(ctx_limit: &mut u32) -> () {
Expand All @@ -231,13 +231,10 @@ fn test_recursion_limit_rewrite() {
baz(ctx_limit)
}
}
unconstrained fn qux() -> () {
unconstrained fn qux(_ctx_limit: &mut u32) -> () {
}
unconstrained fn bar_proxy(mut ctx_limit: u32) -> () {
bar((&mut ctx_limit))
}
unconstrained fn baz_proxy(mut ctx_limit: u32) -> () {
baz((&mut ctx_limit))
}
");
}
Loading