diff --git a/tooling/ast_fuzzer/src/program/expr.rs b/tooling/ast_fuzzer/src/program/expr.rs index 786c46c0194..12a298d38f1 100644 --- a/tooling/ast_fuzzer/src/program/expr.rs +++ b/tooling/ast_fuzzer/src/program/expr.rs @@ -92,7 +92,7 @@ pub fn gen_literal(u: &mut Unstructured, typ: &Type) -> arbitrary::Result unreachable!("unexpected literal type: {typ}"), + _ => unreachable!("unexpected type to generate a literal for: {typ}"), }; Ok(expr) } @@ -378,6 +378,14 @@ pub fn callees(expr: &Expression) -> HashSet { callees.insert(func_id); } } + // Consider functions passed as arguments as at least callable. + for arg in &call.arguments { + if let Expression::Ident(ident) = arg { + if let Definition::Function(func_id) = ident.definition { + callees.insert(func_id); + } + } + } } true }); diff --git a/tooling/ast_fuzzer/src/program/func.rs b/tooling/ast_fuzzer/src/program/func.rs index defa6975af9..bf3e1f53e0b 100644 --- a/tooling/ast_fuzzer/src/program/func.rs +++ b/tooling/ast_fuzzer/src/program/func.rs @@ -53,14 +53,6 @@ impl FunctionDeclaration { (param_types, return_type) } - - fn is_acir(&self) -> bool { - !self.unconstrained - } - - fn is_brillig(&self) -> bool { - self.unconstrained - } } /// HIR representation of a function parameter. @@ -85,6 +77,43 @@ pub(crate) fn hir_param( (pat, typ, vis) } +/// Help avoid infinite recursion by limiting which function can call which other one. +pub(super) fn can_call( + caller_id: FuncId, + caller_unconstrained: bool, + callee_id: FuncId, + callee_unconstrained: bool, +) -> bool { + // Nobody should call `main`. + if callee_id == Program::main_id() { + return false; + } + + // From an ACIR function we can call any Brillig function, + // but we avoid creating infinite recursive ACIR calls by + // only calling functions with lower IDs than ours, + // otherwise the inliner could get stuck. + if !caller_unconstrained && !callee_unconstrained { + // Higher calls lower, so we can use this rule to pick function parameters + // as we create the declarations: we can pass functions already declared. + return callee_id < caller_id; + } + + // From a Brillig function we restrict ourselves to only call + // other Brillig functions. That's because the `Monomorphizer` + // would make an unconstrained copy of any ACIR function called + // from Brillig, and this is expected by the inliner for example, + // but if we did similarly in the generator after we know who + // calls who, we would incur two drawbacks: + // 1) it would make programs bigger for little benefit + // 2) it would skew calibration frequencies as ACIR freqs would overlay Brillig ones + if caller_unconstrained { + return callee_unconstrained; + } + + true +} + /// Control what kind of expressions we can generate, depending on the surrounding context. #[derive(Debug, Clone, Copy)] struct Flags { @@ -156,35 +185,14 @@ impl<'a> FunctionContext<'a> { ); // Collect all the functions we can call from this one. + // TODO(#8484): Look for call targets in function-valued arguments as well. let call_targets = ctx .function_declarations .iter() .filter_map(|(callee_id, callee_decl)| { - // We can't call `main`. - if *callee_id == Program::main_id() { - return None; - } - - // From an ACIR function we can call any Brillig function, - // but we avoid creating infinite recursive ACIR calls by - // only calling functions with higher IDs than ours, - // otherwise the inliner could get stuck. - if decl.is_acir() && callee_decl.is_acir() && *callee_id <= id { - return None; - } - - // From a Brillig function we restrict ourselves to only call - // other Brillig functions. That's because the `Monomorphizer` - // would make an unconstrained copy of any ACIR function called - // from Brillig, and this is expected by the inliner for example, - // but if we did similarly in the generator after we know who - // calls who, we would incur two drawbacks: - // 1) it would make programs bigger for little benefit - // 2) it would skew calibration frequencies as ACIR freqs would overlay Brillig ones - if decl.is_brillig() && !callee_decl.is_brillig() { + if !can_call(id, decl.unconstrained, *callee_id, callee_decl.unconstrained) { return None; } - Some((*callee_id, types::types_produced(&callee_decl.return_type))) }) .collect(); @@ -308,6 +316,14 @@ impl<'a> FunctionContext<'a> { max_depth: usize, flags: Flags, ) -> arbitrary::Result { + // For now if we need a function, return one without further nesting, + // e.g. avoid `if { func_1 } else { func_2 }`, because it makes rewriting + // harder when we need to deal with proxies. + if matches!(typ, Type::Function(_, _, _, _)) { + // Local variables we should consider in `gen_expr_from_vars`, so here we just look through global functions. + return self.find_function_with_signature(u, typ); + } + let mut freq = Freq::new(u, &self.ctx.config.expr_freqs)?; // Stop nesting if we reached the bottom. @@ -952,28 +968,19 @@ impl<'a> FunctionContext<'a> { self.has_call = true; let callee_id = *u.choose_iter(opts)?; + let callee_ident = self.function_ident(callee_id); + let callee = self.ctx.function_decl(callee_id).clone(); let param_types = callee.params.iter().map(|p| p.3.clone()).collect::>(); + // Generate an expression for each argument. let mut args = Vec::new(); for typ in ¶m_types { args.push(self.gen_expr(u, typ, max_depth, Flags::CALL)?); } let call_expr = Expression::Call(Call { - func: Box::new(Expression::Ident(Ident { - location: None, - definition: Definition::Function(callee_id), - mutable: false, - name: callee.name.clone(), - typ: Type::Function( - param_types, - Box::new(callee.return_type.clone()), - Box::new(Type::Unit), - callee.unconstrained, - ), - id: self.next_ident_id(), - })), + func: Box::new(callee_ident), arguments: args, return_type: callee.return_type.clone(), location: Location::dummy(), @@ -1154,6 +1161,62 @@ impl<'a> FunctionContext<'a> { } Ok(None) } + + /// Find a global function matching a type signature. + fn find_function_with_signature( + &mut self, + u: &mut Unstructured, + typ: &Type, + ) -> arbitrary::Result { + let Type::Function(param_types, return_type, _, unconstrained) = typ else { + unreachable!( + "find_function_with_signature should only be called with Type::Function; got {typ}" + ); + }; + + // TODO(#8484): Take the callee ID into account, so we don't create a problem inlining ACIR. + let candidates = self + .ctx + .function_declarations + .iter() + .skip(1) // Can't call main. + .filter_map(|(func_id, func)| { + let matches = func.return_type == *return_type.as_ref() + && func.unconstrained == *unconstrained + && func.params.len() == param_types.len() + && func.params.iter().zip(param_types).all(|((_, _, _, a, _), b)| a == b); + + matches.then_some(*func_id) + }) + .collect::>(); + + if candidates.is_empty() { + panic!("No candidate found for function type: {typ}"); + } + + let callee_id = u.choose_iter(candidates)?; + + Ok(self.function_ident(callee_id)) + } + + /// Generate an identifier for calling a global function. + fn function_ident(&mut self, callee_id: FuncId) -> Expression { + let callee = self.ctx.function_decl(callee_id).clone(); + let param_types = callee.params.iter().map(|p| p.3.clone()).collect::>(); + Expression::Ident(Ident { + location: None, + definition: Definition::Function(callee_id), + mutable: false, + name: callee.name.clone(), + typ: Type::Function( + param_types, + Box::new(callee.return_type.clone()), + Box::new(Type::Unit), + callee.unconstrained, + ), + id: self.next_ident_id(), + }) + } } #[test] diff --git a/tooling/ast_fuzzer/src/program/mod.rs b/tooling/ast_fuzzer/src/program/mod.rs index 9eb86406fe8..aaff1335304 100644 --- a/tooling/ast_fuzzer/src/program/mod.rs +++ b/tooling/ast_fuzzer/src/program/mod.rs @@ -1,7 +1,7 @@ //! Module responsible for generating arbitrary [Program] ASTs. use std::collections::{BTreeMap, BTreeSet}; // Using BTree for deterministic enumeration, for repeatability. -use func::{FunctionContext, FunctionDeclaration}; +use func::{FunctionContext, FunctionDeclaration, can_call}; use strum::IntoEnumIterator; use arbitrary::{Arbitrary, Unstructured}; @@ -166,6 +166,7 @@ impl Context { u: &mut Unstructured, i: usize, ) -> arbitrary::Result { + let id = FuncId(i as u32); let is_main = i == 0; let num_params = u.int_in_range(0..=self.config.max_function_args)?; @@ -179,19 +180,49 @@ impl Context { .unwrap_or_default()) || bool::arbitrary(u)?; + // Which existing functions we could receive as parameters. + let func_param_candidates: Vec = if is_main { + // Main cannot receive function parameters from outside. + vec![] + } else { + self.function_declarations + .iter() + .filter_map(|(callee_id, callee)| { + can_call(id, unconstrained, *callee_id, callee.unconstrained) + .then_some(*callee_id) + }) + .collect() + }; + + // Choose parameter types. let mut params = Vec::new(); for p in 0..num_params { let id = LocalId(p as u32); let name = make_name(p, false); let is_mutable = !is_main && bool::arbitrary(u)?; - let typ = self.gen_type( - u, - self.config.max_depth, - false, - is_main, - false, - self.config.comptime_friendly, - )?; + + let typ = if func_param_candidates.is_empty() || u.ratio(7, 10)? { + // Take some kind of data type. + self.gen_type( + u, + self.config.max_depth, + false, + is_main, + false, + self.config.comptime_friendly, + )? + } else { + // Take a function type. + let callee_id = u.choose_iter(&func_param_candidates)?; + let callee = &self.function_declarations[callee_id]; + let param_types = callee.params.iter().map(|p| p.3.clone()).collect::>(); + Type::Function( + param_types, + Box::new(callee.return_type.clone()), + Box::new(Type::Unit), + callee.unconstrained, + ) + }; let visibility = if is_main { match u.choose_index(5)? { @@ -206,6 +237,7 @@ impl Context { params.push((id, is_mutable, name, typ, visibility)); } + // We could return a function as well. let return_type = self.gen_type( u, self.config.max_depth, diff --git a/tooling/ast_fuzzer/src/program/rewrite.rs b/tooling/ast_fuzzer/src/program/rewrite.rs index d29a6d552fc..3f4617354a2 100644 --- a/tooling/ast_fuzzer/src/program/rewrite.rs +++ b/tooling/ast_fuzzer/src/program/rewrite.rs @@ -56,6 +56,18 @@ pub(crate) fn add_recursion_limit( let mut proxy_functions = HashMap::new(); let mut next_func_id = FuncId(ctx.functions.len() as u32); + /// Decide how to pass the limit to function valued parameter passed to a function. + fn limit_type_for_func_param(callee_unconstrained: bool, param_unconstrained: bool) -> Type { + // If the function receiving the parameter is ACIR, and the function we pass + // to it is Brillig, it will have to pass the limit by value. + // Otherwise by ref should work. We don't pass ACIR to Brillig. + if !callee_unconstrained && param_unconstrained { + types::U32 + } else { + types::ref_mut(types::U32) + } + } + for (func_id, func) in &ctx.functions { if !func.unconstrained || *func_id == Program::main_id() @@ -225,6 +237,8 @@ pub(crate) fn add_recursion_limit( } // Update calls to pass along the limit and call the proxy if necessary. + // Also find places where we are passing a function pointer, and change + // it into the proxy version if necessary. visit_expr_mut(&mut func.body, &mut |expr: &mut Expression| { if let Expression::Call(call) = expr { let Expression::Ident(ident) = call.func.as_mut() else { @@ -237,6 +251,7 @@ pub(crate) fn add_recursion_limit( unreachable!("function type expected"); }; let callee_unconstrained = unconstrained_functions.contains(&callee_id); + if callee_unconstrained && !func.unconstrained { // Calling Brillig from ACIR: call the proxy. let Some(proxy) = proxy_functions.get(&callee_id) else { @@ -293,6 +308,44 @@ pub(crate) fn add_recursion_limit( param_types.push(types::U32); call.arguments.push(limit_expr); } + + // Now go through all the parameters: if they pass a function pointer, + // change the proxy or the original based on the caller. + for i in 0..param_types.len() { + let param_type = &mut param_types[i]; + if let Type::Function(param_types, _, _, param_unconstrained) = param_type { + let typ = + limit_type_for_func_param(callee_unconstrained, *param_unconstrained); + + // If we need to pass by value, then it's going to the proxy. + // We don't have to update when the value we pass on is an input parameter, + // but I don't know yet what that will look like. + if !types::is_reference(&typ) { + let arg = &mut call.arguments[i]; + let Expression::Ident(func_param_ident) = arg else { + unreachable!("functions are passed by ident; got {arg}"); + }; + let Definition::Function(func_param_id) = func_param_ident.definition + else { + unreachable!( + "function definition expected; got {}", + func_param_ident.definition + ); + }; + let Some(proxy) = proxy_functions.get(&func_param_id) else { + unreachable!( + "expected to have a proxy for the function pointer: {func_param_id}; got some for {:?}", + proxy_functions.keys().collect::>() + ); + }; + func_param_ident.name = proxy.name.clone(); + func_param_ident.definition = Definition::Function(proxy.id); + } + + // Add the limit to the function described in the parameter. + param_types.push(typ); + } + } } true }); @@ -303,6 +356,16 @@ pub(crate) fn add_recursion_limit( ctx.functions.insert(proxy.id, proxy); } + // Rewrite function valued parameters to take the limit. + for func in ctx.functions.values_mut() { + for param in func.parameters.iter_mut() { + if let Type::Function(param_types, _, _, param_unconstrained) = &mut param.3 { + let typ = limit_type_for_func_param(func.unconstrained, *param_unconstrained); + param_types.push(typ); + } + } + } + Ok(()) } diff --git a/tooling/ast_fuzzer/src/program/types.rs b/tooling/ast_fuzzer/src/program/types.rs index b8ad7a158d5..31c0d87480b 100644 --- a/tooling/ast_fuzzer/src/program/types.rs +++ b/tooling/ast_fuzzer/src/program/types.rs @@ -1,6 +1,7 @@ use std::collections::HashSet; use acir::FieldElement; +use iter_extended::vecmap; use noirc_frontend::{ ast::{BinaryOpKind, IntegerBitSize}, hir_def, @@ -154,10 +155,14 @@ pub(crate) fn to_hir_type(typ: &Type) -> hir_def::types::Type { Type::String(size) => HirType::String(size_const(*size)), Type::Array(size, typ) => HirType::Array(size_const(*size), Box::new(to_hir_type(typ))), Type::Tuple(items) => HirType::Tuple(items.iter().map(to_hir_type).collect()), - Type::FmtString(_, _) - | Type::Slice(_) - | Type::Reference(_, _) - | Type::Function(_, _, _, _) => { + Type::Function(param_types, return_type, env_type, unconstrained) => HirType::Function( + vecmap(param_types, to_hir_type), + Box::new(to_hir_type(return_type)), + Box::new(to_hir_type(env_type)), + *unconstrained, + ), + Type::Reference(typ, mutable) => HirType::Reference(Box::new(to_hir_type(typ)), *mutable), + Type::FmtString(_, _) | Type::Slice(_) => { unreachable!("unexpected type converting to HIR: {}", typ) } } @@ -178,6 +183,10 @@ pub(crate) fn is_bool(typ: &Type) -> bool { matches!(typ, Type::Bool) } +pub(crate) fn is_reference(typ: &Type) -> bool { + matches!(typ, Type::Reference(_, _)) +} + /// Can the type be returned by some `UnaryOp`. pub(crate) fn can_unary_return(typ: &Type) -> bool { match typ {