Skip to content

Commit 1252b5f

Browse files
TomAFrenchjfecher
andauthored
feat: place return value witnesses directly after function arguments (#5142)
# Description ## Problem\* Resolves #5104 ## Summary\* This PR preallocates some witnesses to hold the return values at the beginning of ACIR gen and then adds assertions to fill these witnesses with the return values. This ensures that the return values will be placed in the witness map directly after any function inputs (reasons for this being desirable are laid out in #5104) ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: jfecher <jake@aztecprotocol.com>
1 parent d6122eb commit 1252b5f

3 files changed

Lines changed: 75 additions & 68 deletions

File tree

compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ impl AcirContext {
255255
}
256256

257257
/// Converts an [`AcirVar`] to a [`Witness`]
258-
fn var_to_witness(&mut self, var: AcirVar) -> Result<Witness, InternalError> {
258+
pub(crate) fn var_to_witness(&mut self, var: AcirVar) -> Result<Witness, InternalError> {
259259
let expression = self.var_to_expression(var)?;
260260
let witness = if let Some(constant) = expression.to_const() {
261261
// Check if a witness has been assigned this value already, if so reuse it.
@@ -1027,15 +1027,6 @@ impl AcirContext {
10271027
Ok(remainder)
10281028
}
10291029

1030-
/// Converts the `AcirVar` to a `Witness` if it hasn't been already, and appends it to the
1031-
/// `GeneratedAcir`'s return witnesses.
1032-
pub(crate) fn return_var(&mut self, acir_var: AcirVar) -> Result<(), InternalError> {
1033-
let return_var = self.get_or_create_witness_var(acir_var)?;
1034-
let witness = self.var_to_witness(return_var)?;
1035-
self.acir_ir.push_return_witness(witness);
1036-
Ok(())
1037-
}
1038-
10391030
/// Constrains the `AcirVar` variable to be of type `NumericType`.
10401031
pub(crate) fn range_constrain_var(
10411032
&mut self,
@@ -1538,9 +1529,11 @@ impl AcirContext {
15381529
pub(crate) fn finish(
15391530
mut self,
15401531
inputs: Vec<Witness>,
1532+
return_values: Vec<Witness>,
15411533
warnings: Vec<SsaReport>,
15421534
) -> GeneratedAcir {
15431535
self.acir_ir.input_witnesses = inputs;
1536+
self.acir_ir.return_witnesses = return_values;
15441537
self.acir_ir.warnings = warnings;
15451538
self.acir_ir
15461539
}

compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,6 @@ pub(crate) struct GeneratedAcir {
4545
opcodes: Vec<AcirOpcode<FieldElement>>,
4646

4747
/// All witness indices that comprise the final return value of the program
48-
///
49-
/// Note: This may contain repeated indices, which is necessary for later mapping into the
50-
/// abi's return type.
5148
pub(crate) return_witnesses: Vec<Witness>,
5249

5350
/// All witness indices which are inputs to the main function
@@ -164,11 +161,6 @@ impl GeneratedAcir {
164161

165162
fresh_witness
166163
}
167-
168-
/// Adds a witness index to the program's return witnesses.
169-
pub(crate) fn push_return_witness(&mut self, witness: Witness) {
170-
self.return_witnesses.push(witness);
171-
}
172164
}
173165

174166
impl GeneratedAcir {

compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs

Lines changed: 72 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,7 @@ use acvm::acir::circuit::brillig::BrilligBytecode;
3636
use acvm::acir::circuit::{AssertionPayload, ErrorSelector, OpcodeLocation};
3737
use acvm::acir::native_types::Witness;
3838
use acvm::acir::BlackBoxFunc;
39-
use acvm::{
40-
acir::AcirField,
41-
acir::{circuit::opcodes::BlockId, native_types::Expression},
42-
FieldElement,
43-
};
39+
use acvm::{acir::circuit::opcodes::BlockId, acir::AcirField, FieldElement};
4440
use fxhash::FxHashMap as HashMap;
4541
use im::Vector;
4642
use iter_extended::{try_vecmap, vecmap};
@@ -330,38 +326,10 @@ impl Ssa {
330326
bytecode: brillig.byte_code,
331327
});
332328

333-
let runtime_types = self.functions.values().map(|function| function.runtime());
334-
for (acir, runtime_type) in acirs.iter_mut().zip(runtime_types) {
335-
if matches!(runtime_type, RuntimeType::Acir(_)) {
336-
generate_distinct_return_witnesses(acir);
337-
}
338-
}
339-
340329
Ok((acirs, brillig, self.error_selector_to_type))
341330
}
342331
}
343332

344-
fn generate_distinct_return_witnesses(acir: &mut GeneratedAcir) {
345-
// Create a witness for each return witness we have to guarantee that the return witnesses match the standard
346-
// layout for serializing those types as if they were being passed as inputs.
347-
//
348-
// This is required for recursion as otherwise in situations where we cannot make use of the program's ABI
349-
// (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're
350-
// working with rather than following the standard ABI encoding rules.
351-
//
352-
// TODO: We're being conservative here by generating a new witness for every expression.
353-
// This means that we're likely to get a number of constraints which are just renumbering witnesses.
354-
// This can be tackled by:
355-
// - Tracking the last assigned public input witness and only renumbering a witness if it is below this value.
356-
// - Modifying existing constraints to rearrange their outputs so they are suitable
357-
// - See: https://github.com/noir-lang/noir/pull/4467
358-
let distinct_return_witness = vecmap(acir.return_witnesses.clone(), |return_witness| {
359-
acir.create_witness_for_expression(&Expression::from(return_witness))
360-
});
361-
362-
acir.return_witnesses = distinct_return_witness;
363-
}
364-
365333
impl<'a> Context<'a> {
366334
fn new(shared_context: &'a mut SharedContext) -> Context<'a> {
367335
let mut acir_context = AcirContext::default();
@@ -422,15 +390,45 @@ impl<'a> Context<'a> {
422390
let dfg = &main_func.dfg;
423391
let entry_block = &dfg[main_func.entry_block()];
424392
let input_witness = self.convert_ssa_block_params(entry_block.parameters(), dfg)?;
393+
let num_return_witnesses =
394+
self.get_num_return_witnesses(entry_block.unwrap_terminator(), dfg);
395+
396+
// Create a witness for each return witness we have to guarantee that the return witnesses match the standard
397+
// layout for serializing those types as if they were being passed as inputs.
398+
//
399+
// This is required for recursion as otherwise in situations where we cannot make use of the program's ABI
400+
// (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're
401+
// working with rather than following the standard ABI encoding rules.
402+
//
403+
// We allocate these witnesses now before performing ACIR gen for the rest of the program as the location of
404+
// the function's return values can then be determined through knowledge of its ABI alone.
405+
let return_witness_vars =
406+
vecmap(0..num_return_witnesses, |_| self.acir_context.add_variable());
407+
408+
let return_witnesses = vecmap(&return_witness_vars, |return_var| {
409+
let expr = self.acir_context.var_to_expression(*return_var).unwrap();
410+
expr.to_witness().expect("return vars should be witnesses")
411+
});
425412

426413
self.data_bus = dfg.data_bus.to_owned();
427414
let mut warnings = Vec::new();
428415
for instruction_id in entry_block.instructions() {
429416
warnings.extend(self.convert_ssa_instruction(*instruction_id, dfg, ssa, brillig)?);
430417
}
431418

432-
warnings.extend(self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?);
433-
Ok(self.acir_context.finish(input_witness, warnings))
419+
let (return_vars, return_warnings) =
420+
self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?;
421+
422+
// TODO: This is a naive method of assigning the return values to their witnesses as
423+
// we're likely to get a number of constraints which are asserting one witness to be equal to another.
424+
//
425+
// We should search through the program and relabel these witnesses so we can remove this constraint.
426+
for (witness_var, return_var) in return_witness_vars.iter().zip(return_vars) {
427+
self.acir_context.assert_eq_var(*witness_var, return_var, None)?;
428+
}
429+
430+
warnings.extend(return_warnings);
431+
Ok(self.acir_context.finish(input_witness, return_witnesses, warnings))
434432
}
435433

436434
fn convert_brillig_main(
@@ -468,17 +466,13 @@ impl<'a> Context<'a> {
468466
)?;
469467
self.shared_context.insert_generated_brillig(main_func.id(), arguments, 0, code);
470468

471-
let output_vars: Vec<_> = output_values
469+
let return_witnesses: Vec<Witness> = output_values
472470
.iter()
473471
.flat_map(|value| value.clone().flatten())
474-
.map(|value| value.0)
475-
.collect();
472+
.map(|(value, _)| self.acir_context.var_to_witness(value))
473+
.collect::<Result<_, _>>()?;
476474

477-
for acir_var in output_vars {
478-
self.acir_context.return_var(acir_var)?;
479-
}
480-
481-
let generated_acir = self.acir_context.finish(witness_inputs, Vec::new());
475+
let generated_acir = self.acir_context.finish(witness_inputs, return_witnesses, Vec::new());
482476

483477
assert_eq!(
484478
generated_acir.opcodes().len(),
@@ -1724,12 +1718,39 @@ impl<'a> Context<'a> {
17241718
self.define_result(dfg, instruction, AcirValue::Var(result, typ));
17251719
}
17261720

1721+
/// Converts an SSA terminator's return values into their ACIR representations
1722+
fn get_num_return_witnesses(
1723+
&mut self,
1724+
terminator: &TerminatorInstruction,
1725+
dfg: &DataFlowGraph,
1726+
) -> usize {
1727+
let return_values = match terminator {
1728+
TerminatorInstruction::Return { return_values, .. } => return_values,
1729+
// TODO(https://github.com/noir-lang/noir/issues/4616): Enable recursion on foldable/non-inlined ACIR functions
1730+
_ => unreachable!("ICE: Program must have a singular return"),
1731+
};
1732+
1733+
return_values.iter().fold(0, |acc, value_id| {
1734+
let is_databus = self
1735+
.data_bus
1736+
.return_data
1737+
.map_or(false, |return_databus| dfg[*value_id] == dfg[return_databus]);
1738+
1739+
if is_databus {
1740+
// We do not return value for the data bus.
1741+
acc
1742+
} else {
1743+
acc + dfg.type_of_value(*value_id).flattened_size()
1744+
}
1745+
})
1746+
}
1747+
17271748
/// Converts an SSA terminator's return values into their ACIR representations
17281749
fn convert_ssa_return(
17291750
&mut self,
17301751
terminator: &TerminatorInstruction,
17311752
dfg: &DataFlowGraph,
1732-
) -> Result<Vec<SsaReport>, RuntimeError> {
1753+
) -> Result<(Vec<AcirVar>, Vec<SsaReport>), RuntimeError> {
17331754
let (return_values, call_stack) = match terminator {
17341755
TerminatorInstruction::Return { return_values, call_stack } => {
17351756
(return_values, call_stack.clone())
@@ -1739,6 +1760,7 @@ impl<'a> Context<'a> {
17391760
};
17401761

17411762
let mut has_constant_return = false;
1763+
let mut return_vars: Vec<AcirVar> = Vec::new();
17421764
for value_id in return_values {
17431765
let is_databus = self
17441766
.data_bus
@@ -1759,7 +1781,7 @@ impl<'a> Context<'a> {
17591781
dfg,
17601782
)?;
17611783
} else {
1762-
self.acir_context.return_var(acir_var)?;
1784+
return_vars.push(acir_var);
17631785
}
17641786
}
17651787
}
@@ -1770,7 +1792,7 @@ impl<'a> Context<'a> {
17701792
Vec::new()
17711793
};
17721794

1773-
Ok(warnings)
1795+
Ok((return_vars, warnings))
17741796
}
17751797

17761798
/// Gets the cached `AcirVar` that was converted from the corresponding `ValueId`. If it does
@@ -3079,8 +3101,8 @@ mod test {
30793101
check_call_opcode(
30803102
&func_with_nested_call_opcodes[1],
30813103
2,
3082-
vec![Witness(2), Witness(1)],
3083-
vec![Witness(3)],
3104+
vec![Witness(3), Witness(1)],
3105+
vec![Witness(4)],
30843106
);
30853107
}
30863108

@@ -3100,13 +3122,13 @@ mod test {
31003122
for (expected_input, input) in expected_inputs.iter().zip(inputs) {
31013123
assert_eq!(
31023124
expected_input, input,
3103-
"Expected witness {expected_input:?} but got {input:?}"
3125+
"Expected input witness {expected_input:?} but got {input:?}"
31043126
);
31053127
}
31063128
for (expected_output, output) in expected_outputs.iter().zip(outputs) {
31073129
assert_eq!(
31083130
expected_output, output,
3109-
"Expected witness {expected_output:?} but got {output:?}"
3131+
"Expected output witness {expected_output:?} but got {output:?}"
31103132
);
31113133
}
31123134
}

0 commit comments

Comments
 (0)