Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 90 additions & 60 deletions acvm-repo/acvm/src/compiler/optimizers/merge_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,36 @@ use acir::{

use crate::compiler::CircuitSimulator;

pub(crate) struct MergeExpressionsOptimizer {
pub(crate) struct MergeExpressionsOptimizer<F> {
resolved_blocks: HashMap<BlockId, BTreeSet<Witness>>,
modified_gates: HashMap<usize, Opcode<F>>,
deleted_gates: BTreeSet<usize>,
}

impl MergeExpressionsOptimizer {
impl<F: AcirField> MergeExpressionsOptimizer<F> {
pub(crate) fn new() -> Self {
MergeExpressionsOptimizer { resolved_blocks: HashMap::new() }
MergeExpressionsOptimizer {
resolved_blocks: HashMap::new(),
modified_gates: HashMap::new(),
deleted_gates: BTreeSet::new(),
}
}
/// This pass analyzes the circuit and identifies intermediate variables that are
/// only used in two gates. It then merges the gate that produces the
/// intermediate variable into the second one that uses it
/// Note: This pass is only relevant for backends that can handle unlimited width
pub(crate) fn eliminate_intermediate_variable<F: AcirField>(
pub(crate) fn eliminate_intermediate_variable(
&mut self,
circuit: &Circuit<F>,
acir_opcode_positions: Vec<usize>,
) -> (Vec<Opcode<F>>, Vec<usize>) {
// Initialisation
self.modified_gates.clear();
self.deleted_gates.clear();
self.resolved_blocks.clear();

// Keep track, for each witness, of the gates that use it
let circuit_inputs = circuit.circuit_arguments();
self.resolved_blocks = HashMap::new();
let mut used_witness: BTreeMap<Witness, BTreeSet<usize>> = BTreeMap::new();
for (i, opcode) in circuit.opcodes.iter().enumerate() {
let witnesses = self.witness_inputs(opcode);
Expand All @@ -46,61 +56,69 @@ impl MergeExpressionsOptimizer {
}
}

let mut modified_gates: HashMap<usize, Opcode<F>> = HashMap::new();
let mut new_circuit = Vec::new();
let mut new_acir_opcode_positions = Vec::new();
// For each opcode, try to get a target opcode to merge with
for (i, (opcode, opcode_position)) in
circuit.opcodes.iter().zip(acir_opcode_positions).enumerate()
{
for (i, opcode) in circuit.opcodes.iter().enumerate() {
if !matches!(opcode, Opcode::AssertZero(_)) {
new_circuit.push(opcode.clone());
new_acir_opcode_positions.push(opcode_position);
continue;
}
let opcode = modified_gates.get(&i).unwrap_or(opcode).clone();
let mut to_keep = true;
let input_witnesses = self.witness_inputs(&opcode);
for w in input_witnesses {
let Some(gates_using_w) = used_witness.get(&w) else {
continue;
};
// We only consider witness which are used in exactly two arithmetic gates
if gates_using_w.len() == 2 {
let first = *gates_using_w.first().expect("gates_using_w.len == 2");
let second = *gates_using_w.last().expect("gates_using_w.len == 2");
let b = if second == i {
first
} else {
// sanity check
assert!(i == first);
second
if let Some(opcode) = self.get_opcode(i, circuit) {
let input_witnesses = self.witness_inputs(&opcode);
for w in input_witnesses {
let Some(gates_using_w) = used_witness.get(&w) else {
continue;
};

let second_gate = modified_gates.get(&b).unwrap_or(&circuit.opcodes[b]);
if let (Opcode::AssertZero(expr_define), Opcode::AssertZero(expr_use)) =
(&opcode, second_gate)
{
// We cannot merge an expression into an earlier opcode, because this
// would break the 'execution ordering' of the opcodes
// This case can happen because a previous merge would change an opcode
// and eliminate a witness from it, giving new opportunities for this
// witness to be used in only two expressions
// TODO: the missed optimization for the i>b case can be handled by
// - doing this pass again until there is no change, or
// - merging 'b' into 'i' instead
if i < b {
if let Some(expr) = Self::merge(expr_use, expr_define, w) {
modified_gates.insert(b, Opcode::AssertZero(expr));
to_keep = false;
// Update the 'used_witness' map to account for the merge.
for w2 in CircuitSimulator::expr_wit(expr_define) {
if !circuit_inputs.contains(&w2) {
let v = used_witness.entry(w2).or_default();
v.insert(b);
v.remove(&i);
// We only consider witness which are used in exactly two arithmetic gates
if gates_using_w.len() == 2 {
let first = *gates_using_w.first().expect("gates_using_w.len == 2");
let second = *gates_using_w.last().expect("gates_using_w.len == 2");
let b = if second == i {
first
} else {
// sanity check
assert!(i == first);
second
};
// Merge source opcode into target opcode
// by updating modified_gates/deleted_gates/used_witness
let mut merge_opcodes = |target, source| -> bool {
assert!(source < target);
let source_opcode = self.get_opcode(source, circuit);
let target_opcode = self.get_opcode(target, circuit);
if let (
Some(Opcode::AssertZero(expr_use)),
Some(Opcode::AssertZero(expr_define)),
) = (target_opcode, source_opcode)
{
if let Some(expr) =
Self::merge_expression(&expr_use, &expr_define, w)
{
self.modified_gates.insert(target, Opcode::AssertZero(expr));
self.deleted_gates.insert(source);
// Update the 'used_witness' map to account for the merge.
let mut witness_list = CircuitSimulator::expr_wit(&expr_use);
Comment thread
aakoshh marked this conversation as resolved.
witness_list.extend(CircuitSimulator::expr_wit(&expr_define));
for w2 in witness_list {
if !circuit_inputs.contains(&w2) {
let mut v = used_witness[&w2].clone();
v.insert(b);
v.remove(&i);
Comment thread
guipublic marked this conversation as resolved.
Outdated
used_witness.insert(w2, v);
Comment thread
guipublic marked this conversation as resolved.
Outdated
}
}
return true;
}
}
false
};
if i < b {
if merge_opcodes(b, i) {
// We need to stop here and continue with the next opcode
// because the merge invalidates the current opcode.
break;
}
} else if i != b {
// Merge b into i
if merge_opcodes(i, b) {
Comment thread
aakoshh marked this conversation as resolved.
Outdated
// We need to stop here and continue with the next opcode
// because the merge invalidates the current opcode.
break;
Expand All @@ -109,17 +127,22 @@ impl MergeExpressionsOptimizer {
}
}
}
}

// Construct the new circuit from modified/deleted gates
let mut new_circuit = Vec::new();
let mut new_acir_opcode_positions = Vec::new();

if to_keep {
let opcode = modified_gates.get(&i).cloned().unwrap_or(opcode);
new_circuit.push(opcode);
new_acir_opcode_positions.push(opcode_position);
for (i, opcode_position) in acir_opcode_positions.iter().enumerate() {
if let Some(op) = self.get_opcode(i, circuit) {
new_circuit.push(op);
new_acir_opcode_positions.push(*opcode_position);
}
}
(new_circuit, new_acir_opcode_positions)
}

fn brillig_input_wit<F>(&self, input: &BrilligInputs<F>) -> BTreeSet<Witness> {
fn brillig_input_wit(&self, input: &BrilligInputs<F>) -> BTreeSet<Witness> {
Comment thread
aakoshh marked this conversation as resolved.
let mut result = BTreeSet::new();
match input {
BrilligInputs::Single(expr) => {
Expand Down Expand Up @@ -152,7 +175,7 @@ impl MergeExpressionsOptimizer {
}

// Returns the input witnesses used by the opcode
fn witness_inputs<F: AcirField>(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
fn witness_inputs(&self, opcode: &Opcode<F>) -> BTreeSet<Witness> {
Comment thread
aakoshh marked this conversation as resolved.
match opcode {
Opcode::AssertZero(expr) => CircuitSimulator::expr_wit(expr),
Opcode::BlackBoxFuncCall(bb_func) => {
Expand Down Expand Up @@ -198,7 +221,7 @@ impl MergeExpressionsOptimizer {

// Merge 'expr' into 'target' via Gaussian elimination on 'w'
// Returns None if the expressions cannot be merged
fn merge<F: AcirField>(
fn merge_expression(
target: &Expression<F>,
expr: &Expression<F>,
w: Witness,
Expand Down Expand Up @@ -226,6 +249,13 @@ impl MergeExpressionsOptimizer {
}
None
}

fn get_opcode(&self, g: usize, circuit: &Circuit<F>) -> Option<Opcode<F>> {
if self.deleted_gates.contains(&g) {
return None;
}
self.modified_gates.get(&g).or(circuit.opcodes.get(g)).cloned()
}
}

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn main(x: Field) {
value += term2;
value.assert_max_bit_size::<1>();

// Regression test for Aztec Packages issue #6451
// Regression test for #6447 (Aztec Packages issue #9488)
let y = unsafe { empty(x + 1) };
let z = y + x + 1;
let z1 = z + y;
Expand Down