@@ -36,11 +36,7 @@ use acvm::acir::circuit::brillig::BrilligBytecode;
3636use acvm:: acir:: circuit:: { AssertionPayload , ErrorSelector , OpcodeLocation } ;
3737use acvm:: acir:: native_types:: Witness ;
3838use acvm:: acir:: BlackBoxFunc ;
39- use acvm:: {
40- acir:: AcirField ,
41- acir:: { circuit:: opcodes:: BlockId , native_types:: Expression } ,
42- FieldElement ,
43- } ;
39+ use acvm:: { acir:: circuit:: opcodes:: BlockId , acir:: AcirField , FieldElement } ;
4440use fxhash:: FxHashMap as HashMap ;
4541use im:: Vector ;
4642use iter_extended:: { try_vecmap, vecmap} ;
@@ -330,38 +326,10 @@ impl Ssa {
330326 bytecode : brillig. byte_code ,
331327 } ) ;
332328
333- let runtime_types = self . functions . values ( ) . map ( |function| function. runtime ( ) ) ;
334- for ( acir, runtime_type) in acirs. iter_mut ( ) . zip ( runtime_types) {
335- if matches ! ( runtime_type, RuntimeType :: Acir ( _) ) {
336- generate_distinct_return_witnesses ( acir) ;
337- }
338- }
339-
340329 Ok ( ( acirs, brillig, self . error_selector_to_type ) )
341330 }
342331}
343332
344- fn generate_distinct_return_witnesses ( acir : & mut GeneratedAcir ) {
345- // Create a witness for each return witness we have to guarantee that the return witnesses match the standard
346- // layout for serializing those types as if they were being passed as inputs.
347- //
348- // This is required for recursion as otherwise in situations where we cannot make use of the program's ABI
349- // (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're
350- // working with rather than following the standard ABI encoding rules.
351- //
352- // TODO: We're being conservative here by generating a new witness for every expression.
353- // This means that we're likely to get a number of constraints which are just renumbering witnesses.
354- // This can be tackled by:
355- // - Tracking the last assigned public input witness and only renumbering a witness if it is below this value.
356- // - Modifying existing constraints to rearrange their outputs so they are suitable
357- // - See: https://github.com/noir-lang/noir/pull/4467
358- let distinct_return_witness = vecmap ( acir. return_witnesses . clone ( ) , |return_witness| {
359- acir. create_witness_for_expression ( & Expression :: from ( return_witness) )
360- } ) ;
361-
362- acir. return_witnesses = distinct_return_witness;
363- }
364-
365333impl < ' a > Context < ' a > {
366334 fn new ( shared_context : & ' a mut SharedContext ) -> Context < ' a > {
367335 let mut acir_context = AcirContext :: default ( ) ;
@@ -422,15 +390,45 @@ impl<'a> Context<'a> {
422390 let dfg = & main_func. dfg ;
423391 let entry_block = & dfg[ main_func. entry_block ( ) ] ;
424392 let input_witness = self . convert_ssa_block_params ( entry_block. parameters ( ) , dfg) ?;
393+ let num_return_witnesses =
394+ self . get_num_return_witnesses ( entry_block. unwrap_terminator ( ) , dfg) ;
395+
396+ // Create a witness for each return witness we have to guarantee that the return witnesses match the standard
397+ // layout for serializing those types as if they were being passed as inputs.
398+ //
399+ // This is required for recursion as otherwise in situations where we cannot make use of the program's ABI
400+ // (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're
401+ // working with rather than following the standard ABI encoding rules.
402+ //
403+ // We allocate these witnesses now before performing ACIR gen for the rest of the program as the location of
404+ // the function's return values can then be determined through knowledge of its ABI alone.
405+ let return_witness_vars =
406+ vecmap ( 0 ..num_return_witnesses, |_| self . acir_context . add_variable ( ) ) ;
407+
408+ let return_witnesses = vecmap ( & return_witness_vars, |return_var| {
409+ let expr = self . acir_context . var_to_expression ( * return_var) . unwrap ( ) ;
410+ expr. to_witness ( ) . expect ( "return vars should be witnesses" )
411+ } ) ;
425412
426413 self . data_bus = dfg. data_bus . to_owned ( ) ;
427414 let mut warnings = Vec :: new ( ) ;
428415 for instruction_id in entry_block. instructions ( ) {
429416 warnings. extend ( self . convert_ssa_instruction ( * instruction_id, dfg, ssa, brillig) ?) ;
430417 }
431418
432- warnings. extend ( self . convert_ssa_return ( entry_block. unwrap_terminator ( ) , dfg) ?) ;
433- Ok ( self . acir_context . finish ( input_witness, warnings) )
419+ let ( return_vars, return_warnings) =
420+ self . convert_ssa_return ( entry_block. unwrap_terminator ( ) , dfg) ?;
421+
422+ // TODO: This is a naive method of assigning the return values to their witnesses as
423+ // we're likely to get a number of constraints which are asserting one witness to be equal to another.
424+ //
425+ // We should search through the program and relabel these witnesses so we can remove this constraint.
426+ for ( witness_var, return_var) in return_witness_vars. iter ( ) . zip ( return_vars) {
427+ self . acir_context . assert_eq_var ( * witness_var, return_var, None ) ?;
428+ }
429+
430+ warnings. extend ( return_warnings) ;
431+ Ok ( self . acir_context . finish ( input_witness, return_witnesses, warnings) )
434432 }
435433
436434 fn convert_brillig_main (
@@ -468,17 +466,13 @@ impl<'a> Context<'a> {
468466 ) ?;
469467 self . shared_context . insert_generated_brillig ( main_func. id ( ) , arguments, 0 , code) ;
470468
471- let output_vars : Vec < _ > = output_values
469+ let return_witnesses : Vec < Witness > = output_values
472470 . iter ( )
473471 . flat_map ( |value| value. clone ( ) . flatten ( ) )
474- . map ( |value| value . 0 )
475- . collect ( ) ;
472+ . map ( |( value, _ ) | self . acir_context . var_to_witness ( value ) )
473+ . collect :: < Result < _ , _ > > ( ) ? ;
476474
477- for acir_var in output_vars {
478- self . acir_context . return_var ( acir_var) ?;
479- }
480-
481- let generated_acir = self . acir_context . finish ( witness_inputs, Vec :: new ( ) ) ;
475+ let generated_acir = self . acir_context . finish ( witness_inputs, return_witnesses, Vec :: new ( ) ) ;
482476
483477 assert_eq ! (
484478 generated_acir. opcodes( ) . len( ) ,
@@ -1724,12 +1718,39 @@ impl<'a> Context<'a> {
17241718 self . define_result ( dfg, instruction, AcirValue :: Var ( result, typ) ) ;
17251719 }
17261720
1721+ /// Converts an SSA terminator's return values into their ACIR representations
1722+ fn get_num_return_witnesses (
1723+ & mut self ,
1724+ terminator : & TerminatorInstruction ,
1725+ dfg : & DataFlowGraph ,
1726+ ) -> usize {
1727+ let return_values = match terminator {
1728+ TerminatorInstruction :: Return { return_values, .. } => return_values,
1729+ // TODO(https://github.com/noir-lang/noir/issues/4616): Enable recursion on foldable/non-inlined ACIR functions
1730+ _ => unreachable ! ( "ICE: Program must have a singular return" ) ,
1731+ } ;
1732+
1733+ return_values. iter ( ) . fold ( 0 , |acc, value_id| {
1734+ let is_databus = self
1735+ . data_bus
1736+ . return_data
1737+ . map_or ( false , |return_databus| dfg[ * value_id] == dfg[ return_databus] ) ;
1738+
1739+ if is_databus {
1740+ // We do not return value for the data bus.
1741+ acc
1742+ } else {
1743+ acc + dfg. type_of_value ( * value_id) . flattened_size ( )
1744+ }
1745+ } )
1746+ }
1747+
17271748 /// Converts an SSA terminator's return values into their ACIR representations
17281749 fn convert_ssa_return (
17291750 & mut self ,
17301751 terminator : & TerminatorInstruction ,
17311752 dfg : & DataFlowGraph ,
1732- ) -> Result < Vec < SsaReport > , RuntimeError > {
1753+ ) -> Result < ( Vec < AcirVar > , Vec < SsaReport > ) , RuntimeError > {
17331754 let ( return_values, call_stack) = match terminator {
17341755 TerminatorInstruction :: Return { return_values, call_stack } => {
17351756 ( return_values, call_stack. clone ( ) )
@@ -1739,6 +1760,7 @@ impl<'a> Context<'a> {
17391760 } ;
17401761
17411762 let mut has_constant_return = false ;
1763+ let mut return_vars: Vec < AcirVar > = Vec :: new ( ) ;
17421764 for value_id in return_values {
17431765 let is_databus = self
17441766 . data_bus
@@ -1759,7 +1781,7 @@ impl<'a> Context<'a> {
17591781 dfg,
17601782 ) ?;
17611783 } else {
1762- self . acir_context . return_var ( acir_var) ? ;
1784+ return_vars . push ( acir_var) ;
17631785 }
17641786 }
17651787 }
@@ -1770,7 +1792,7 @@ impl<'a> Context<'a> {
17701792 Vec :: new ( )
17711793 } ;
17721794
1773- Ok ( warnings)
1795+ Ok ( ( return_vars , warnings) )
17741796 }
17751797
17761798 /// Gets the cached `AcirVar` that was converted from the corresponding `ValueId`. If it does
@@ -3079,8 +3101,8 @@ mod test {
30793101 check_call_opcode (
30803102 & func_with_nested_call_opcodes[ 1 ] ,
30813103 2 ,
3082- vec ! [ Witness ( 2 ) , Witness ( 1 ) ] ,
3083- vec ! [ Witness ( 3 ) ] ,
3104+ vec ! [ Witness ( 3 ) , Witness ( 1 ) ] ,
3105+ vec ! [ Witness ( 4 ) ] ,
30843106 ) ;
30853107 }
30863108
@@ -3100,13 +3122,13 @@ mod test {
31003122 for ( expected_input, input) in expected_inputs. iter ( ) . zip ( inputs) {
31013123 assert_eq ! (
31023124 expected_input, input,
3103- "Expected witness {expected_input:?} but got {input:?}"
3125+ "Expected input witness {expected_input:?} but got {input:?}"
31043126 ) ;
31053127 }
31063128 for ( expected_output, output) in expected_outputs. iter ( ) . zip ( outputs) {
31073129 assert_eq ! (
31083130 expected_output, output,
3109- "Expected witness {expected_output:?} but got {output:?}"
3131+ "Expected output witness {expected_output:?} but got {output:?}"
31103132 ) ;
31113133 }
31123134 }
0 commit comments