diff --git a/tooling/ast_fuzzer/src/lib.rs b/tooling/ast_fuzzer/src/lib.rs index 26fdd49a73c..46b620d43a8 100644 --- a/tooling/ast_fuzzer/src/lib.rs +++ b/tooling/ast_fuzzer/src/lib.rs @@ -60,6 +60,8 @@ pub struct Config { pub avoid_large_int_literals: bool, /// Avoid using loop control (break/continue). pub avoid_loop_control: bool, + /// Avoid using function pointers in parameters. + pub avoid_lambdas: bool, /// Only use comptime friendly expressions. pub comptime_friendly: bool, } @@ -117,6 +119,8 @@ impl Default for Config { avoid_large_int_literals: false, avoid_negative_int_literals: false, avoid_loop_control: false, + // TODO(#8543): Allow lambdas when ICE is fixed. + avoid_lambdas: true, comptime_friendly: false, } } diff --git a/tooling/ast_fuzzer/src/program/func.rs b/tooling/ast_fuzzer/src/program/func.rs index bf3e1f53e0b..56c19f115ab 100644 --- a/tooling/ast_fuzzer/src/program/func.rs +++ b/tooling/ast_fuzzer/src/program/func.rs @@ -19,7 +19,7 @@ use noirc_frontend::{ }; use super::{ - Context, VariableId, expr, + CallableId, Context, VariableId, expr, freq::Freq, make_name, scope::{Scope, ScopeStack, Variable}, @@ -161,7 +161,7 @@ pub(super) struct FunctionContext<'a> { in_loop: bool, /// All the functions callable from this one, with the types we can /// produce from their return value. - call_targets: BTreeMap>, + call_targets: BTreeMap>, /// Indicate that we have generated a `Call`. has_call: bool, } @@ -185,17 +185,25 @@ impl<'a> FunctionContext<'a> { ); // Collect all the functions we can call from this one. - // TODO(#8484): Look for call targets in function-valued arguments as well. - let call_targets = ctx - .function_declarations - .iter() - .filter_map(|(callee_id, callee_decl)| { - if !can_call(id, decl.unconstrained, *callee_id, callee_decl.unconstrained) { - return None; - } - Some((*callee_id, types::types_produced(&callee_decl.return_type))) - }) - .collect(); + let mut call_targets = BTreeMap::new(); + + // Consider calling any allowed global function. + for (callee_id, callee_decl) in &ctx.function_declarations { + if !can_call(id, decl.unconstrained, *callee_id, callee_decl.unconstrained) { + continue; + } + let produces = types::types_produced(&callee_decl.return_type); + call_targets.insert(CallableId::Global(*callee_id), produces); + } + + // Consider function pointers as callable; they are already filtered during construction. + for (callee_id, _, _, typ, _) in &decl.params { + let Type::Function(_, return_type, _, _) = typ else { + continue; + }; + let produces = types::types_produced(return_type); + call_targets.insert(CallableId::Local(*callee_id), produces); + } Self { ctx, @@ -316,12 +324,20 @@ impl<'a> FunctionContext<'a> { max_depth: usize, flags: Flags, ) -> arbitrary::Result { - // For now if we need a function, return one without further nesting, - // e.g. avoid `if { func_1 } else { func_2 }`, because it makes rewriting - // harder when we need to deal with proxies. + // For now if we need a function, return one without further nesting, e.g. avoid `if { func_1 } else { func_2 }`, + // because it makes it harder to rewrite functions to add recursion limit: we would need to replace functions in the + // expressions to proxy version if we call Brillig from ACIR, but we would also need to keep track whether we are calling a function, + // For example if we could return function pointers, we could have something like this: + // `acir_func_1(if c { brillig_func_2 } else { unsafe { brillig_func_3(brillig_func_4) } })` + // We could replace `brillig_func_2` with `brillig_func_2_proxy`, but we wouldn't replace `brillig_func_4` with `brillig_func_4_proxy` + // because that is a parameter of another call. But we would have to deal with the return value. + // For this reason we handle function parameters directly here. if matches!(typ, Type::Function(_, _, _, _)) { - // Local variables we should consider in `gen_expr_from_vars`, so here we just look through global functions. - return self.find_function_with_signature(u, typ); + // Prefer functions in variables over globals. + return match self.gen_expr_from_vars(u, typ, max_depth)? { + Some(expr) => Ok(expr), + None => self.find_global_function_with_signature(u, typ), + }; } let mut freq = Freq::new(u, &self.ctx.config.expr_freqs)?; @@ -969,9 +985,7 @@ impl<'a> FunctionContext<'a> { let callee_id = *u.choose_iter(opts)?; let callee_ident = self.function_ident(callee_id); - - let callee = self.ctx.function_decl(callee_id).clone(); - let param_types = callee.params.iter().map(|p| p.3.clone()).collect::>(); + let (param_types, return_type) = self.callable_signature(callee_id); // Generate an expression for each argument. let mut args = Vec::new(); @@ -982,12 +996,12 @@ impl<'a> FunctionContext<'a> { let call_expr = Expression::Call(Call { func: Box::new(callee_ident), arguments: args, - return_type: callee.return_type.clone(), + return_type: return_type.clone(), location: Location::dummy(), }); // Derive the final result from the call, e.g. by casting, or accessing a member. - self.gen_expr_from_source(u, call_expr, &callee.return_type, typ, self.max_depth()) + self.gen_expr_from_source(u, call_expr, &return_type, typ, self.max_depth()) } /// Generate a call to a specific function, with arbitrary literals @@ -1163,7 +1177,9 @@ impl<'a> FunctionContext<'a> { } /// Find a global function matching a type signature. - fn find_function_with_signature( + /// + /// For local functions we use `gen_expr_from_vars`. + fn find_global_function_with_signature( &mut self, u: &mut Unstructured, typ: &Type, @@ -1196,26 +1212,60 @@ impl<'a> FunctionContext<'a> { let callee_id = u.choose_iter(candidates)?; - Ok(self.function_ident(callee_id)) + Ok(self.function_ident(CallableId::Global(callee_id))) } /// Generate an identifier for calling a global function. - fn function_ident(&mut self, callee_id: FuncId) -> Expression { - let callee = self.ctx.function_decl(callee_id).clone(); - let param_types = callee.params.iter().map(|p| p.3.clone()).collect::>(); - Expression::Ident(Ident { - location: None, - definition: Definition::Function(callee_id), - mutable: false, - name: callee.name.clone(), - typ: Type::Function( - param_types, - Box::new(callee.return_type.clone()), - Box::new(Type::Unit), - callee.unconstrained, - ), - id: self.next_ident_id(), - }) + fn function_ident(&mut self, callee_id: CallableId) -> Expression { + match callee_id { + CallableId::Global(id) => { + let callee = self.ctx.function_decl(id).clone(); + let param_types = callee.params.iter().map(|p| p.3.clone()).collect::>(); + Expression::Ident(Ident { + location: None, + definition: Definition::Function(id), + mutable: false, + name: callee.name.clone(), + typ: Type::Function( + param_types, + Box::new(callee.return_type.clone()), + Box::new(Type::Unit), + callee.unconstrained, + ), + id: self.next_ident_id(), + }) + } + CallableId::Local(id) => { + let (mutable, name, typ) = self.locals.current().get_variable(&id); + Expression::Ident(Ident { + location: None, + definition: Definition::Local(id), + mutable: *mutable, + name: name.clone(), + typ: typ.clone(), + id: self.next_ident_id(), + }) + } + } + } + + /// Get the parameter types and return type of a callable function. + fn callable_signature(&self, callee_id: CallableId) -> (Vec, Type) { + match callee_id { + CallableId::Global(id) => { + let decl = self.ctx.function_decl(id); + let return_type = decl.return_type.clone(); + let param_types = decl.params.iter().map(|p| p.3.clone()).collect::>(); + (param_types, return_type) + } + CallableId::Local(id) => { + let (_, _, typ) = self.locals.current().get_variable(&id); + let Type::Function(param_types, return_type, _, _) = typ else { + unreachable!("function pointers should have function type; got {typ}") + }; + (param_types.clone(), return_type.as_ref().clone()) + } + } } } diff --git a/tooling/ast_fuzzer/src/program/mod.rs b/tooling/ast_fuzzer/src/program/mod.rs index aaff1335304..b0665cd655f 100644 --- a/tooling/ast_fuzzer/src/program/mod.rs +++ b/tooling/ast_fuzzer/src/program/mod.rs @@ -77,6 +77,14 @@ pub(crate) enum VariableId { Global(GlobalId), } +/// ID of a function we can call, either as a pointer in a local variable, +/// or directly as a global function. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub(crate) enum CallableId { + Local(LocalId), + Global(FuncId), +} + /// Name of a variable. type Name = String; @@ -181,7 +189,7 @@ impl Context { || bool::arbitrary(u)?; // Which existing functions we could receive as parameters. - let func_param_candidates: Vec = if is_main { + let func_param_candidates: Vec = if is_main || self.config.avoid_lambdas { // Main cannot receive function parameters from outside. vec![] } else { diff --git a/tooling/ast_fuzzer/src/program/rewrite.rs b/tooling/ast_fuzzer/src/program/rewrite.rs deleted file mode 100644 index 3f4617354a2..00000000000 --- a/tooling/ast_fuzzer/src/program/rewrite.rs +++ /dev/null @@ -1,409 +0,0 @@ -use std::collections::{HashMap, HashSet}; - -use arbitrary::Unstructured; -use nargo::errors::Location; -use noirc_frontend::{ - ast::BinaryOpKind, - monomorphization::ast::{ - Call, Definition, Expression, FuncId, Function, Ident, IdentId, LocalId, Program, Type, - }, - shared::Visibility, -}; - -use super::{ - Context, VariableId, expr, types, - visitor::{visit_expr, visit_expr_mut}, -}; - -/// To avoid the potential of infinite recursion at runtime, add a `ctx_limit: &mut u32` -/// parameter to all functions, which we use to limit the number of recursive calls. -/// -/// This is complicated by the fact that we cannot pass mutable references from ACIR to Brillig. -/// To overcome that, we create a proxy function for unconstrained functions that take -/// `mut ctx_limit: u32` instead, and pass it on as a mutable ref. -/// -/// Originally only actually recursive functions (ie. one that called something else) -/// received this extra parameters, but in order to support higher order functions -/// which can be passed a recursive or a non-recursive function as an argument, -/// all functions get the extra parameter. -pub(crate) fn add_recursion_limit( - ctx: &mut Context, - u: &mut Unstructured, -) -> arbitrary::Result<()> { - // Collect recursive functions, ie. the ones which call other functions. - let recursive_functions = ctx - .functions - .iter() - .filter_map(|(id, func)| expr::has_call(&func.body).then_some(*id)) - .collect::>(); - - // Collect functions called from ACIR; they will need proxy functions. - let called_from_acir = ctx.functions.values().filter(|func| !func.unconstrained).fold( - HashSet::::new(), - |mut acc, func| { - acc.extend(expr::callees(&func.body)); - acc - }, - ); - - let unconstrained_functions = ctx - .functions - .iter() - .filter_map(|(id, func)| func.unconstrained.then_some(*id)) - .collect::>(); - - // Create proxies for unconstrained functions called from ACIR. - let mut proxy_functions = HashMap::new(); - let mut next_func_id = FuncId(ctx.functions.len() as u32); - - /// Decide how to pass the limit to function valued parameter passed to a function. - fn limit_type_for_func_param(callee_unconstrained: bool, param_unconstrained: bool) -> Type { - // If the function receiving the parameter is ACIR, and the function we pass - // to it is Brillig, it will have to pass the limit by value. - // Otherwise by ref should work. We don't pass ACIR to Brillig. - if !callee_unconstrained && param_unconstrained { - types::U32 - } else { - types::ref_mut(types::U32) - } - } - - for (func_id, func) in &ctx.functions { - if !func.unconstrained - || *func_id == Program::main_id() - || !called_from_acir.contains(func_id) - { - continue; - } - let mut proxy = func.clone(); - proxy.id = next_func_id; - proxy.name = format!("{}_proxy", proxy.name); - // We will replace the body, update the params, and append the function later. - proxy_functions.insert(*func_id, proxy); - next_func_id = FuncId(next_func_id.0 + 1); - } - - // Rewrite functions. - for (func_id, func) in ctx.functions.iter_mut() { - let is_main = *func_id == Program::main_id(); - let is_recursive = recursive_functions.contains(func_id); - - // We'll need a new ID for variables or parameters. We could speed this up by - // 1) caching this value in a "function meta" construct, or - // 2) using `u32::MAX`, but then we would be in a worse situation next time - // 3) draw values from `Context` instead of `FunctionContext`, which breaks continuity, but saves an extra traversal. - // We wouldn't be able to add caching to `Program` without changing it, so eventually we'll need to look at the values - // to do random mutations, or we have to pass back some meta along with `Program` and look it up there. For now we - // traverse the AST to figure out what the next ID to use is. - let (mut next_local_id, mut next_ident_id) = next_local_and_ident_id(func); - - let mut next_local_id = || { - let id = next_local_id; - next_local_id += 1; - LocalId(id) - }; - - let mut next_ident_id = || { - let id = next_ident_id; - next_ident_id += 1; - IdentId(id) - }; - - let limit_name = "ctx_limit".to_string(); - let limit_id = next_local_id(); - let limit_var = VariableId::Local(limit_id); - - if is_main { - // In main we initialize the limit to its maximum value. - let init_limit = expr::let_var( - limit_id, - true, - limit_name.clone(), - expr::u32_literal(ctx.config.max_recursive_calls as u32), - ); - expr::prepend(&mut func.body, init_limit); - } else if is_recursive { - // In non-main we look at the limit and return a random value if it's zero, - // otherwise decrease it by one and continue with the original body. - let limit_type = types::ref_mut(types::U32); - func.parameters.push(( - limit_id, - false, - limit_name.clone(), - limit_type.clone(), - Visibility::Private, - )); - - // Generate a random value to return. - let default_return = expr::gen_literal(u, &func.return_type)?; - - let limit_ident = expr::ident_inner( - limit_var, - next_ident_id(), - false, - limit_name.clone(), - limit_type, - ); - let limit_expr = Expression::Ident(limit_ident.clone()); - - expr::replace(&mut func.body, |mut body| { - expr::prepend( - &mut body, - expr::assign_ref( - limit_ident, - expr::binary( - expr::deref(limit_expr.clone(), types::U32), - BinaryOpKind::Subtract, - expr::u32_literal(1), - ), - ), - ); - expr::if_else( - expr::equal(expr::deref(limit_expr.clone(), types::U32), expr::u32_literal(0)), - default_return, - body, - func.return_type.clone(), - ) - }); - } else { - // For non-recursive functions just add an unused parameter. - // In non-main we look at the limit and return a random value if it's zero, - // otherwise decrease it by one and continue with the original body. - let limit_type = types::ref_mut(types::U32); - func.parameters.push(( - limit_id, - false, - format!("_{limit_name}"), - limit_type.clone(), - Visibility::Private, - )); - } - - // Add the non-reference version of the parameter to the proxy function. - if let Some(proxy) = proxy_functions.get_mut(func_id) { - proxy.parameters.push(( - limit_id, - true, - limit_name.clone(), - types::U32, - Visibility::Private, - )); - // The body is just a call the the non-proxy function. - proxy.body = Expression::Call(Call { - func: Box::new(Expression::Ident(Ident { - location: None, - definition: Definition::Function(*func_id), - mutable: false, - name: func.name.clone(), - typ: Type::Function( - func.parameters.iter().map(|p| p.3.clone()).collect(), - Box::new(func.return_type.clone()), - Box::new(Type::Unit), - func.unconstrained, - ), - id: next_ident_id(), - })), - arguments: proxy - .parameters - .iter() - .map(|(id, mutable, name, typ, _visibility)| { - if *id == limit_id { - // Pass mutable reference to the limit. - expr::ref_mut( - expr::ident( - VariableId::Local(*id), - next_ident_id(), - *mutable, - name.clone(), - typ.clone(), - ), - typ.clone(), - ) - } else { - // Pass every other parameter as-is. - expr::ident( - VariableId::Local(*id), - next_ident_id(), - *mutable, - name.clone(), - typ.clone(), - ) - } - }) - .collect(), - return_type: proxy.return_type.clone(), - location: Location::dummy(), - }); - } - - // Update calls to pass along the limit and call the proxy if necessary. - // Also find places where we are passing a function pointer, and change - // it into the proxy version if necessary. - visit_expr_mut(&mut func.body, &mut |expr: &mut Expression| { - if let Expression::Call(call) = expr { - let Expression::Ident(ident) = call.func.as_mut() else { - unreachable!("functions are called by ident"); - }; - let Definition::Function(callee_id) = ident.definition else { - unreachable!("function definition expected"); - }; - let Type::Function(param_types, _, _, _) = &mut ident.typ else { - unreachable!("function type expected"); - }; - let callee_unconstrained = unconstrained_functions.contains(&callee_id); - - if callee_unconstrained && !func.unconstrained { - // Calling Brillig from ACIR: call the proxy. - let Some(proxy) = proxy_functions.get(&callee_id) else { - unreachable!("expected to have a proxy"); - }; - ident.name = proxy.name.clone(); - ident.definition = Definition::Function(proxy.id); - // Pass the limit by value. - let limit_expr = if is_main { - expr::ident( - limit_var, - next_ident_id(), - true, - limit_name.clone(), - types::U32, - ) - } else { - expr::deref( - expr::ident( - limit_var, - next_ident_id(), - false, - limit_name.clone(), - types::ref_mut(types::U32), - ), - types::U32, - ) - }; - param_types.push(types::U32); - call.arguments.push(limit_expr); - } else { - // Pass the limit by reference. - let limit_type = types::ref_mut(types::U32); - let limit_expr = if is_main { - expr::ref_mut( - expr::ident( - limit_var, - next_ident_id(), - true, - limit_name.clone(), - types::U32, - ), - limit_type, - ) - } else { - expr::ident( - limit_var, - next_ident_id(), - false, - limit_name.clone(), - limit_type, - ) - }; - param_types.push(types::U32); - call.arguments.push(limit_expr); - } - - // Now go through all the parameters: if they pass a function pointer, - // change the proxy or the original based on the caller. - for i in 0..param_types.len() { - let param_type = &mut param_types[i]; - if let Type::Function(param_types, _, _, param_unconstrained) = param_type { - let typ = - limit_type_for_func_param(callee_unconstrained, *param_unconstrained); - - // If we need to pass by value, then it's going to the proxy. - // We don't have to update when the value we pass on is an input parameter, - // but I don't know yet what that will look like. - if !types::is_reference(&typ) { - let arg = &mut call.arguments[i]; - let Expression::Ident(func_param_ident) = arg else { - unreachable!("functions are passed by ident; got {arg}"); - }; - let Definition::Function(func_param_id) = func_param_ident.definition - else { - unreachable!( - "function definition expected; got {}", - func_param_ident.definition - ); - }; - let Some(proxy) = proxy_functions.get(&func_param_id) else { - unreachable!( - "expected to have a proxy for the function pointer: {func_param_id}; got some for {:?}", - proxy_functions.keys().collect::>() - ); - }; - func_param_ident.name = proxy.name.clone(); - func_param_ident.definition = Definition::Function(proxy.id); - } - - // Add the limit to the function described in the parameter. - param_types.push(typ); - } - } - } - true - }); - } - - // Append proxy functions. - for (_, proxy) in proxy_functions { - ctx.functions.insert(proxy.id, proxy); - } - - // Rewrite function valued parameters to take the limit. - for func in ctx.functions.values_mut() { - for param in func.parameters.iter_mut() { - if let Type::Function(param_types, _, _, param_unconstrained) = &mut param.3 { - let typ = limit_type_for_func_param(func.unconstrained, *param_unconstrained); - param_types.push(typ); - } - } - } - - Ok(()) -} - -/// Find the next local ID and ident IDs (in that order) that we can use to add -/// variables to a [Function] during mutations. -fn next_local_and_ident_id(func: &Function) -> (u32, u32) { - let mut next_local_id = func.parameters.iter().map(|p| p.0.0 + 1).max().unwrap_or_default(); - let mut next_ident_id = 0; - - visit_expr(&func.body, &mut |expr| { - let local_id = match expr { - Expression::Let(let_) => Some(let_.id), - Expression::For(for_) => Some(for_.index_variable), - Expression::Ident(ident) => { - next_ident_id = next_ident_id.max(ident.id.0 + 1); - None - } - _ => None, - }; - if let Some(id) = local_id { - next_local_id = next_local_id.max(id.0 + 1); - } - true - }); - (next_local_id, next_ident_id) -} - -/// Turn all ACIR functions into Brillig functions. -/// -/// This is more involved than flipping the `unconstrained` property because of the -/// "ownership analysis", which can only run on a function once. -pub fn change_all_functions_into_unconstrained(mut program: Program) -> Program { - for f in program.functions.iter_mut() { - if f.unconstrained { - continue; - } - f.unconstrained = true; - f.handle_ownership(); - } - program -} diff --git a/tooling/ast_fuzzer/src/program/rewrite/limit.rs b/tooling/ast_fuzzer/src/program/rewrite/limit.rs new file mode 100644 index 00000000000..42858a3e3eb --- /dev/null +++ b/tooling/ast_fuzzer/src/program/rewrite/limit.rs @@ -0,0 +1,500 @@ +use std::collections::{HashMap, HashSet}; + +use arbitrary::Unstructured; +use nargo::errors::Location; +use noirc_frontend::{ + ast::BinaryOpKind, + monomorphization::ast::{ + Call, Definition, Expression, FuncId, Function, Ident, IdentId, LocalId, Program, Type, + }, + shared::Visibility, +}; + +use crate::program::{Context, VariableId, expr, types, visitor::visit_expr_mut}; + +use super::next_local_and_ident_id; + +const LIMIT_NAME: &str = "ctx_limit"; + +/// To avoid the potential of infinite recursion at runtime, add a `ctx_limit: &mut u32` +/// parameter to all functions, which we use to limit the number of recursive calls. +/// +/// This is complicated by the fact that we cannot pass mutable references from ACIR to Brillig. +/// To overcome that, we create a proxy function for unconstrained functions that take +/// `mut ctx_limit: u32` instead, and pass it on as a mutable ref. +/// +/// Originally only actually recursive functions (ie. one that called something else) +/// received this extra parameters, but in order to support higher order functions +/// which can be passed a recursive or a non-recursive function as an argument, +/// all functions get the extra parameter. +pub(crate) fn add_recursion_limit( + ctx: &mut Context, + u: &mut Unstructured, +) -> arbitrary::Result<()> { + // Collect functions called from ACIR; they will need proxy functions. + let called_from_acir = ctx.functions.values().filter(|func| !func.unconstrained).fold( + HashSet::::new(), + |mut acc, func| { + acc.extend(expr::callees(&func.body)); + acc + }, + ); + + // Create proxies for unconstrained functions called from ACIR. + let mut proxy_functions = HashMap::new(); + let mut next_func_id = FuncId(ctx.functions.len() as u32); + + for (func_id, func) in &ctx.functions { + if !func.unconstrained + || *func_id == Program::main_id() + || !called_from_acir.contains(func_id) + { + continue; + } + let mut proxy = func.clone(); + proxy.id = next_func_id; + proxy.name = format!("{}_proxy", proxy.name); + // We will replace the body, update the params, and append the function later. + proxy_functions.insert(*func_id, proxy); + next_func_id = FuncId(next_func_id.0 + 1); + } + + // Rewrite functions. + for (func_id, func) in ctx.functions.iter_mut() { + let mut ctx = LimitContext::new(*func_id, func, ctx.config.max_recursive_calls as u32); + + ctx.rewrite_functions(u, &mut proxy_functions)?; + } + + // Append proxy functions. + for (_, proxy) in proxy_functions { + ctx.functions.insert(proxy.id, proxy); + } + + Ok(()) +} + +/// Decide how to pass the recursion limit to function: by value or by ref. +fn ctx_limit_type_for_func_param(callee_unconstrained: bool, param_unconstrained: bool) -> Type { + // If the function receiving the parameter is ACIR, and the function we pass + // to it is Brillig, it will have to pass the limit by value. + // Otherwise by ref should work. We don't pass ACIR to Brillig. + if !callee_unconstrained && param_unconstrained { + types::U32 + } else { + types::ref_mut(types::U32) + } +} + +struct LimitContext<'a> { + func_id: FuncId, + func: &'a mut Function, + is_main: bool, + is_recursive: bool, + next_local_id: u32, + next_ident_id: u32, + max_recursive_calls: u32, +} + +impl<'a> LimitContext<'a> { + fn new(func_id: FuncId, func: &'a mut Function, max_recursive_calls: u32) -> Self { + let is_main = func_id == Program::main_id(); + + // Recursive functions are those that call another function. + let is_recursive = expr::has_call(&func.body); + + // We'll need a new ID for variables or parameters. We could speed this up by + // 1) caching this value in a "function meta" construct, or + // 2) using `u32::MAX`, but then we would be in a worse situation next time + // 3) draw values from `Context` instead of `FunctionContext`, which breaks continuity, but saves an extra traversal. + // We wouldn't be able to add caching to `Program` without changing it, so eventually we'll need to look at the values + // to do random mutations, or we have to pass back some meta along with `Program` and look it up there. For now we + // traverse the AST to figure out what the next ID to use is. + let (next_local_id, next_ident_id) = next_local_and_ident_id(func); + + Self { + func_id, + func, + is_main, + is_recursive, + next_local_id, + next_ident_id, + max_recursive_calls, + } + } + + /// Rewrite the function and its proxy (if it has one). + fn rewrite_functions( + &mut self, + u: &mut Unstructured, + proxy_functions: &mut HashMap, + ) -> arbitrary::Result<()> { + let limit_id = self.next_local_id(); + + // Limit variable operations in the body + if self.is_main { + self.modify_body_when_main(limit_id); + } else if self.is_recursive { + self.modify_body_when_recursive(u, limit_id)?; + } else { + self.modify_body_when_non_recursive(limit_id); + } + + // Call forwarding in the proxy + self.set_proxy_function(limit_id, proxy_functions); + + // Passing along the limit in calls + self.modify_calls(limit_id, proxy_functions); + + // Update function pointer types to have the extra parameter. + self.modify_function_pointer_param_types(proxy_functions); + + Ok(()) + } + + fn next_local_id(&mut self) -> LocalId { + let id = self.next_local_id; + self.next_local_id += 1; + LocalId(id) + } + + fn next_ident_id(&mut self) -> IdentId { + let id = self.next_ident_id; + self.next_ident_id += 1; + IdentId(id) + } + + /// In `main` we initialize the recursion limit. + fn modify_body_when_main(&mut self, limit_id: LocalId) { + let init_limit = expr::let_var( + limit_id, + true, + LIMIT_NAME.to_string(), + expr::u32_literal(self.max_recursive_calls), + ); + expr::prepend(&mut self.func.body, init_limit); + } + + /// In non-main we look at the limit and return a random value if it's zero, + /// otherwise decrease it by one and continue with the original body. + fn modify_body_when_recursive( + &mut self, + u: &mut Unstructured, + limit_id: LocalId, + ) -> arbitrary::Result<()> { + let limit_var = VariableId::Local(limit_id); + + let limit_type = types::ref_mut(types::U32); + self.func.parameters.push(( + limit_id, + false, + LIMIT_NAME.to_string(), + limit_type.clone(), + Visibility::Private, + )); + + // Generate a random value to return. + let default_return = expr::gen_literal(u, &self.func.return_type)?; + + let limit_ident = expr::ident_inner( + limit_var, + self.next_ident_id(), + false, + LIMIT_NAME.to_string(), + limit_type, + ); + let limit_expr = Expression::Ident(limit_ident.clone()); + + expr::replace(&mut self.func.body, |mut body| { + expr::prepend( + &mut body, + expr::assign_ref( + limit_ident, + expr::binary( + expr::deref(limit_expr.clone(), types::U32), + BinaryOpKind::Subtract, + expr::u32_literal(1), + ), + ), + ); + expr::if_else( + expr::equal(expr::deref(limit_expr.clone(), types::U32), expr::u32_literal(0)), + default_return, + body, + self.func.return_type.clone(), + ) + }); + + Ok(()) + } + + /// For non-recursive functions just add an unused parameter. + /// In non-main we look at the limit and return a random value if it's zero, + /// otherwise decrease it by one and continue with the original body. + fn modify_body_when_non_recursive(&mut self, limit_id: LocalId) { + let limit_type = types::ref_mut(types::U32); + self.func.parameters.push(( + limit_id, + false, + format!("_{LIMIT_NAME}"), + limit_type.clone(), + Visibility::Private, + )); + } + + /// Fill the body of a `func_{i}_proxy` with an expression to forward the call + /// to the original function. Add the `ctx_parameter` as well. + fn set_proxy_function( + &mut self, + limit_id: LocalId, + proxy_functions: &mut HashMap, + ) { + let Some(proxy) = proxy_functions.get_mut(&self.func_id) else { + return; + }; + + proxy.parameters.push(( + limit_id, + true, + LIMIT_NAME.to_string(), + types::U32, + Visibility::Private, + )); + + // The body is just a call the the non-proxy function. + proxy.body = Expression::Call(Call { + func: Box::new(Expression::Ident(Ident { + location: None, + definition: Definition::Function(self.func_id), + mutable: false, + name: self.func.name.clone(), + typ: Type::Function( + self.func.parameters.iter().map(|p| p.3.clone()).collect(), + Box::new(self.func.return_type.clone()), + Box::new(Type::Unit), + self.func.unconstrained, + ), + id: self.next_ident_id(), + })), + arguments: proxy + .parameters + .iter() + .map(|(id, mutable, name, typ, _visibility)| { + if *id == limit_id { + // Pass mutable reference to the limit. + expr::ref_mut( + expr::ident( + VariableId::Local(*id), + self.next_ident_id(), + *mutable, + name.clone(), + typ.clone(), + ), + typ.clone(), + ) + } else { + // Pass every other parameter as-is. + expr::ident( + VariableId::Local(*id), + self.next_ident_id(), + *mutable, + name.clone(), + typ.clone(), + ) + } + }) + .collect(), + return_type: proxy.return_type.clone(), + location: Location::dummy(), + }); + } + + /// Visit all the calls made by this function and pass along the limit. + fn modify_calls(&mut self, limit_id: LocalId, proxy_functions: &HashMap) { + let limit_var = VariableId::Local(limit_id); + + // Swap out the body because we need mutable access to self in the visitor. + let mut body = Expression::Break; + std::mem::swap(&mut self.func.body, &mut body); + + // Update calls to pass along the limit and call the proxy if necessary. + // Also find places where we are passing a function pointer, and change + // it into the proxy version if necessary. + visit_expr_mut(&mut body, &mut |expr: &mut Expression| { + if let Expression::Call(call) = expr { + let Expression::Ident(ident) = call.func.as_mut() else { + unreachable!("functions are called by ident"); + }; + + let proxy = match &ident.definition { + Definition::Function(id) => proxy_functions.get(id), + Definition::Local(_) => None, + other => unreachable!("function or local definition expected; got {}", other), + }; + + let Type::Function(param_types, _, _, callee_unconstrained) = &mut ident.typ else { + unreachable!("function type expected"); + }; + + if *callee_unconstrained && !self.func.unconstrained { + // Calling Brillig from ACIR: call the proxy if it's global. + if let Some(proxy) = proxy { + ident.name = proxy.name.clone(); + ident.definition = Definition::Function(proxy.id); + } + // Pass the limit by value. + let limit_expr = if self.is_main { + expr::ident( + limit_var, + self.next_ident_id(), + true, + LIMIT_NAME.to_string(), + types::U32, + ) + } else { + expr::deref( + expr::ident( + limit_var, + self.next_ident_id(), + false, + LIMIT_NAME.to_string(), + types::ref_mut(types::U32), + ), + types::U32, + ) + }; + param_types.push(types::U32); + call.arguments.push(limit_expr); + } else { + // Pass the limit by reference. + let limit_type = types::ref_mut(types::U32); + let limit_expr = if self.is_main { + // In main we take a mutable reference to the limit. + expr::ref_mut( + expr::ident( + limit_var, + self.next_ident_id(), + true, + LIMIT_NAME.to_string(), + types::U32, + ), + limit_type, + ) + } else { + // In non-main we just pass along the parameter. + expr::ident( + limit_var, + self.next_ident_id(), + false, + LIMIT_NAME.to_string(), + limit_type, + ) + }; + param_types.push(types::U32); + call.arguments.push(limit_expr); + } + + // Now go through all the parameters: if they are function pointer, + // change the signature type of the parameter based on the caller. + modify_function_pointer_param_types(param_types, *callee_unconstrained); + + // Go through the arguments of the call: if they point at a global + // function, they might need to point at the proxy instead. + modify_function_pointer_param_values( + &mut call.arguments, + param_types, + *callee_unconstrained, + proxy_functions, + ); + } + + // Continue the visiting expressions. + true + }); + + // Put the result back. + std::mem::swap(&mut self.func.body, &mut body); + } + + /// Update any function pointer and the function and its proxy's signature to take the limit. + fn modify_function_pointer_param_types( + &mut self, + proxy_functions: &mut HashMap, + ) { + for (_, _, _, param_type, _) in self.func.parameters.iter_mut() { + modify_function_pointer_param_type(param_type, self.func.unconstrained); + } + if let Some(proxy) = proxy_functions.get_mut(&self.func_id) { + for (_, _, _, param_type, _) in proxy.parameters.iter_mut() { + modify_function_pointer_param_type(param_type, self.func.unconstrained); + } + } + } +} + +/// Go through the types of each function parameter. If they are function pointers, +/// then they need the context, depending on the callee type. +fn modify_function_pointer_param_types(param_types: &mut [Type], callee_unconstrained: bool) { + for param_type in param_types.iter_mut() { + modify_function_pointer_param_type(param_type, callee_unconstrained); + } +} + +/// Recursively modify function pointers in the param type. +fn modify_function_pointer_param_type(param_type: &mut Type, callee_unconstrained: bool) { + let Type::Function(param_types, _, _, param_unconstrained) = param_type else { + return; + }; + + let limit_typ = ctx_limit_type_for_func_param(callee_unconstrained, *param_unconstrained); + + // Add the limit to the function described in the parameter. + param_types.push(limit_typ); + + // We need to recurse into the parameters of the function pointer. + modify_function_pointer_param_types(param_types, *param_unconstrained); +} + +/// Go through the call arguments and update global function pointers to their +/// proxy equivalents if necessary. +fn modify_function_pointer_param_values( + args: &mut [Expression], + param_types: &[Type], + callee_unconstrained: bool, + proxy_functions: &HashMap, +) { + for i in 0..param_types.len() { + let Type::Function(_, _, _, param_unconstrained) = ¶m_types[i] else { + continue; + }; + let limit_typ = ctx_limit_type_for_func_param(callee_unconstrained, *param_unconstrained); + + // If it's passed by reference we can leave it alone. + if types::is_reference(&limit_typ) { + continue; + } + + // If we need to pass by value, then it's going to the proxy, but only if it's a global function, + // and not a function parameter, which is we wouldn't know what to change to, and doing so is the + // happens when it's first passed as a global. + let arg = &mut args[i]; + let Expression::Ident(param_func_ident) = arg else { + unreachable!("functions are passed by ident; got {arg}"); + }; + let param_func_id = match ¶m_func_ident.definition { + Definition::Function(id) => id, + Definition::Local(_) => continue, + other => { + unreachable!("function definition expected; got {}", other); + } + }; + let Some(proxy) = proxy_functions.get(param_func_id) else { + unreachable!( + "expected to have a proxy for the function pointer: {param_func_id}; only have them for {:?}", + proxy_functions.keys().collect::>() + ); + }; + param_func_ident.name = proxy.name.clone(); + param_func_ident.definition = Definition::Function(proxy.id); + } +} diff --git a/tooling/ast_fuzzer/src/program/rewrite/mod.rs b/tooling/ast_fuzzer/src/program/rewrite/mod.rs new file mode 100644 index 00000000000..94db55f9b8d --- /dev/null +++ b/tooling/ast_fuzzer/src/program/rewrite/mod.rs @@ -0,0 +1,46 @@ +use noirc_frontend::monomorphization::ast::{Expression, Function, Program}; + +use super::visitor::visit_expr; + +mod limit; + +pub(crate) use limit::add_recursion_limit; + +/// Find the next local ID and ident IDs (in that order) that we can use to add +/// variables to a [Function] during mutations. +fn next_local_and_ident_id(func: &Function) -> (u32, u32) { + let mut next_local_id = func.parameters.iter().map(|p| p.0.0 + 1).max().unwrap_or_default(); + let mut next_ident_id = 0; + + visit_expr(&func.body, &mut |expr| { + let local_id = match expr { + Expression::Let(let_) => Some(let_.id), + Expression::For(for_) => Some(for_.index_variable), + Expression::Ident(ident) => { + next_ident_id = next_ident_id.max(ident.id.0 + 1); + None + } + _ => None, + }; + if let Some(id) = local_id { + next_local_id = next_local_id.max(id.0 + 1); + } + true + }); + (next_local_id, next_ident_id) +} + +/// Turn all ACIR functions into Brillig functions. +/// +/// This is more involved than flipping the `unconstrained` property because of the +/// "ownership analysis", which can only run on a function once. +pub fn change_all_functions_into_unconstrained(mut program: Program) -> Program { + for f in program.functions.iter_mut() { + if f.unconstrained { + continue; + } + f.unconstrained = true; + f.handle_ownership(); + } + program +} diff --git a/tooling/ast_fuzzer/tests/smoke.rs b/tooling/ast_fuzzer/tests/smoke.rs index 576a634e6d2..4b708ce3d6e 100644 --- a/tooling/ast_fuzzer/tests/smoke.rs +++ b/tooling/ast_fuzzer/tests/smoke.rs @@ -27,7 +27,8 @@ fn arb_program_can_be_executed() { let maybe_seed = seed_from_env(); let mut prop = arbtest(|u| { - let program = arb_program(u, Config::default())?; + let config = Config::default(); + let program = arb_program(u, config)?; let abi = program_abi(&program); let options = ssa::SsaEvaluatorOptions {