Skip to content

Commit aad0da0

Browse files
TomAFrenchkevaundrayvezenovm
authored
feat: backpropagate constants in ACIR during optimization (#3926)
# Description ## Problem\* Resolves <!-- Link to GitHub Issue --> ## Summary\* This is a mildy bruteforce-y optimisation method where we just literally attempt to execute the circuit backwards. Any witnesses which we can determine from this can just be written into the circuit directly. A lot of the complexity here comes from the fact that memory opcodes, etc. require witnesses to be unassigned at the point at which the opcode is encountered so we need to "forget" certain witnesses so that we don't optimise them away. Draft as I'm just pushing this up to track effects. ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[Exceptional Case]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: kevaundray <kevtheappdev@gmail.com> Co-authored-by: Maxim Vezenov <mvezenov@gmail.com>
1 parent cb4c1c5 commit aad0da0

12 files changed

Lines changed: 403 additions & 35 deletions

File tree

acvm-repo/acir/src/circuit/opcodes/black_box_function_call.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,10 @@ impl BlackBoxFuncCall {
217217
| BlackBoxFuncCall::PedersenCommitment { inputs, .. }
218218
| BlackBoxFuncCall::PedersenHash { inputs, .. }
219219
| BlackBoxFuncCall::BigIntFromLeBytes { inputs, .. }
220-
| BlackBoxFuncCall::Poseidon2Permutation { inputs, .. }
221-
| BlackBoxFuncCall::Sha256Compression { inputs, .. } => inputs.to_vec(),
220+
| BlackBoxFuncCall::Poseidon2Permutation { inputs, .. } => inputs.to_vec(),
221+
BlackBoxFuncCall::Sha256Compression { inputs, hash_values, .. } => {
222+
inputs.iter().chain(hash_values).copied().collect()
223+
}
222224
BlackBoxFuncCall::AND { lhs, rhs, .. } | BlackBoxFuncCall::XOR { lhs, rhs, .. } => {
223225
vec![*lhs, *rhs]
224226
}
Lines changed: 331 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,331 @@
1+
use std::collections::{BTreeMap, BTreeSet, HashMap};
2+
3+
use crate::{
4+
compiler::optimizers::GeneralOptimizer,
5+
pwg::{
6+
arithmetic::ExpressionSolver, blackbox::solve_range_opcode, directives::solve_directives,
7+
BrilligSolver, BrilligSolverStatus,
8+
},
9+
};
10+
use acir::{
11+
circuit::{
12+
brillig::{Brillig, BrilligInputs, BrilligOutputs},
13+
directives::Directive,
14+
opcodes::BlackBoxFuncCall,
15+
Circuit, Opcode,
16+
},
17+
native_types::{Expression, Witness, WitnessMap},
18+
};
19+
use acvm_blackbox_solver::StubbedBlackBoxSolver;
20+
21+
/// `ConstantBackpropagationOptimizer` will attempt to determine any constant witnesses within the program.
22+
/// It does this by attempting to solve the program without any inputs (i.e. using an empty witness map),
23+
/// any values which it can determine are then enforced to be constant values.
24+
///
25+
/// The optimizer will then replace any witnesses wherever they appear within the circuit with these constant values.
26+
/// This is repeated until the circuit stabilizes.
27+
pub(crate) struct ConstantBackpropagationOptimizer {
28+
circuit: Circuit,
29+
}
30+
31+
impl ConstantBackpropagationOptimizer {
32+
/// Creates a new `ConstantBackpropagationOptimizer`
33+
pub(crate) fn new(circuit: Circuit) -> Self {
34+
Self { circuit }
35+
}
36+
37+
fn gather_known_witnesses(&self) -> (WitnessMap, BTreeSet<Witness>) {
38+
// We do not want to affect the circuit's interface so avoid optimizing away these witnesses.
39+
let mut required_witnesses: BTreeSet<Witness> = self
40+
.circuit
41+
.private_parameters
42+
.union(&self.circuit.public_parameters.0)
43+
.chain(&self.circuit.return_values.0)
44+
.copied()
45+
.collect();
46+
47+
for opcode in &self.circuit.opcodes {
48+
match &opcode {
49+
Opcode::BlackBoxFuncCall(func_call) => {
50+
required_witnesses.extend(
51+
func_call.get_inputs_vec().into_iter().map(|func_input| func_input.witness),
52+
);
53+
required_witnesses.extend(func_call.get_outputs_vec());
54+
}
55+
56+
Opcode::MemoryInit { init, .. } => {
57+
required_witnesses.extend(init);
58+
}
59+
60+
Opcode::MemoryOp { op, .. } => {
61+
required_witnesses.insert(op.index.to_witness().unwrap());
62+
required_witnesses.insert(op.value.to_witness().unwrap());
63+
}
64+
65+
_ => (),
66+
};
67+
}
68+
69+
let mut known_witnesses = WitnessMap::new();
70+
for opcode in self.circuit.opcodes.iter().rev() {
71+
if let Opcode::AssertZero(expr) = opcode {
72+
let solve_result = ExpressionSolver::solve(&mut known_witnesses, expr);
73+
// It doesn't matter what the result is. We expect most opcodes to not be solved successfully so we discard errors.
74+
// At the same time, if the expression can be solved then we track this by the updates to `known_witnesses`
75+
drop(solve_result);
76+
}
77+
}
78+
79+
// We want to retain any references to required witnesses so we "forget" these assignments.
80+
let known_witnesses: BTreeMap<_, _> = known_witnesses
81+
.into_iter()
82+
.filter(|(witness, _)| !required_witnesses.contains(witness))
83+
.collect();
84+
85+
(known_witnesses.into(), required_witnesses)
86+
}
87+
88+
/// Returns a `Circuit` where with any constant witnesses replaced with the constant they resolve to.
89+
#[tracing::instrument(level = "trace", skip_all)]
90+
pub(crate) fn backpropagate_constants(
91+
circuit: Circuit,
92+
order_list: Vec<usize>,
93+
) -> (Circuit, Vec<usize>) {
94+
let old_circuit_size = circuit.opcodes.len();
95+
96+
let optimizer = Self::new(circuit);
97+
let (circuit, order_list) = optimizer.backpropagate_constants_iteration(order_list);
98+
99+
let new_circuit_size = circuit.opcodes.len();
100+
if new_circuit_size < old_circuit_size {
101+
Self::backpropagate_constants(circuit, order_list)
102+
} else {
103+
(circuit, order_list)
104+
}
105+
}
106+
107+
/// Applies a single round of constant backpropagation to a `Circuit`.
108+
pub(crate) fn backpropagate_constants_iteration(
109+
mut self,
110+
order_list: Vec<usize>,
111+
) -> (Circuit, Vec<usize>) {
112+
let (mut known_witnesses, required_witnesses) = self.gather_known_witnesses();
113+
114+
let opcodes = std::mem::take(&mut self.circuit.opcodes);
115+
116+
fn remap_expression(known_witnesses: &WitnessMap, expression: Expression) -> Expression {
117+
GeneralOptimizer::optimize(ExpressionSolver::evaluate(&expression, known_witnesses))
118+
}
119+
120+
let mut new_order_list = Vec::with_capacity(order_list.len());
121+
let mut new_opcodes = Vec::with_capacity(opcodes.len());
122+
for (idx, opcode) in opcodes.into_iter().enumerate() {
123+
let new_opcode = match opcode {
124+
Opcode::AssertZero(expression) => {
125+
let new_expr = remap_expression(&known_witnesses, expression);
126+
if new_expr.is_zero() {
127+
continue;
128+
}
129+
130+
// Attempt to solve the opcode to see if we can determine the value of any witnesses in the expression.
131+
// We only do this _after_ we apply any simplifications to create the new opcode as we want to
132+
// keep the constraint on the witness which we are solving for here.
133+
let solve_result = ExpressionSolver::solve(&mut known_witnesses, &new_expr);
134+
// It doesn't matter what the result is. We expect most opcodes to not be solved successfully so we discard errors.
135+
// At the same time, if the expression can be solved then we track this by the updates to `known_witnesses`
136+
drop(solve_result);
137+
138+
Opcode::AssertZero(new_expr)
139+
}
140+
Opcode::Brillig(brillig) => {
141+
let remapped_inputs = brillig
142+
.inputs
143+
.into_iter()
144+
.map(|input| match input {
145+
BrilligInputs::Single(expr) => {
146+
BrilligInputs::Single(remap_expression(&known_witnesses, expr))
147+
}
148+
BrilligInputs::Array(expr_array) => {
149+
let new_input: Vec<_> = expr_array
150+
.into_iter()
151+
.map(|expr| remap_expression(&known_witnesses, expr))
152+
.collect();
153+
154+
BrilligInputs::Array(new_input)
155+
}
156+
input @ BrilligInputs::MemoryArray(_) => input,
157+
})
158+
.collect();
159+
160+
let remapped_predicate = brillig
161+
.predicate
162+
.map(|predicate| remap_expression(&known_witnesses, predicate));
163+
164+
let new_brillig = Brillig {
165+
inputs: remapped_inputs,
166+
predicate: remapped_predicate,
167+
..brillig
168+
};
169+
170+
let brillig_output_is_required_witness =
171+
new_brillig.outputs.iter().any(|output| match output {
172+
BrilligOutputs::Simple(witness) => required_witnesses.contains(witness),
173+
BrilligOutputs::Array(witness_array) => witness_array
174+
.iter()
175+
.any(|witness| required_witnesses.contains(witness)),
176+
});
177+
178+
if brillig_output_is_required_witness {
179+
// If one of the brillig opcode's outputs is a required witness then we can't remove the opcode. In this case we can't replace
180+
// all of the uses of this witness with the calculated constant so we'll be attempting to use an uninitialized witness.
181+
//
182+
// We then do not attempt execution of this opcode and just simplify the inputs.
183+
Opcode::Brillig(new_brillig)
184+
} else if let Ok(mut solver) = BrilligSolver::new(
185+
&known_witnesses,
186+
&HashMap::new(),
187+
&new_brillig,
188+
&StubbedBlackBoxSolver,
189+
idx,
190+
) {
191+
match solver.solve() {
192+
Ok(BrilligSolverStatus::Finished) => {
193+
// Write execution outputs
194+
match solver.finalize(&mut known_witnesses, &new_brillig) {
195+
Ok(()) => {
196+
// If we've managed to execute the brillig opcode at compile time, we can now just write in the
197+
// results as constants for the rest of the circuit.
198+
continue;
199+
}
200+
_ => Opcode::Brillig(new_brillig),
201+
}
202+
}
203+
Ok(BrilligSolverStatus::InProgress) => unreachable!(
204+
"Solver should either finish, block on foreign call, or error."
205+
),
206+
Ok(BrilligSolverStatus::ForeignCallWait(_)) | Err(_) => {
207+
Opcode::Brillig(new_brillig)
208+
}
209+
}
210+
} else {
211+
Opcode::Brillig(new_brillig)
212+
}
213+
}
214+
215+
Opcode::Directive(Directive::ToLeRadix { a, b, radix }) => {
216+
if b.iter().all(|output| known_witnesses.contains_key(output)) {
217+
continue;
218+
} else if b.iter().any(|witness| required_witnesses.contains(witness)) {
219+
// If one of the brillig opcode's outputs is a required witness then we can't remove the opcode. In this case we can't replace
220+
// all of the uses of this witness with the calculated constant so we'll be attempting to use an uninitialized witness.
221+
//
222+
// We then do not attempt execution of this opcode and just simplify the inputs.
223+
Opcode::Directive(Directive::ToLeRadix {
224+
a: remap_expression(&known_witnesses, a),
225+
b,
226+
radix,
227+
})
228+
} else {
229+
let directive = Directive::ToLeRadix {
230+
a: remap_expression(&known_witnesses, a),
231+
b,
232+
radix,
233+
};
234+
let result = solve_directives(&mut known_witnesses, &directive);
235+
236+
match result {
237+
Ok(()) => continue,
238+
Err(_) => Opcode::Directive(directive),
239+
}
240+
}
241+
}
242+
243+
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input }) => {
244+
if solve_range_opcode(&known_witnesses, &input).is_ok() {
245+
continue;
246+
} else {
247+
opcode
248+
}
249+
}
250+
251+
Opcode::BlackBoxFuncCall(_)
252+
| Opcode::MemoryOp { .. }
253+
| Opcode::MemoryInit { .. } => opcode,
254+
};
255+
256+
new_opcodes.push(new_opcode);
257+
new_order_list.push(order_list[idx]);
258+
}
259+
260+
self.circuit.opcodes = new_opcodes;
261+
262+
(self.circuit, new_order_list)
263+
}
264+
}
265+
266+
#[cfg(test)]
267+
mod tests {
268+
use std::collections::BTreeSet;
269+
270+
use crate::compiler::optimizers::constant_backpropagation::ConstantBackpropagationOptimizer;
271+
use acir::{
272+
brillig::MemoryAddress,
273+
circuit::{
274+
brillig::{Brillig, BrilligOutputs},
275+
opcodes::{BlackBoxFuncCall, FunctionInput},
276+
Circuit, ExpressionWidth, Opcode, PublicInputs,
277+
},
278+
native_types::Witness,
279+
};
280+
use brillig_vm::brillig::Opcode as BrilligOpcode;
281+
282+
fn test_circuit(opcodes: Vec<Opcode>) -> Circuit {
283+
Circuit {
284+
current_witness_index: 1,
285+
expression_width: ExpressionWidth::Bounded { width: 3 },
286+
opcodes,
287+
private_parameters: BTreeSet::new(),
288+
public_parameters: PublicInputs::default(),
289+
return_values: PublicInputs::default(),
290+
assert_messages: Default::default(),
291+
recursive: false,
292+
}
293+
}
294+
295+
#[test]
296+
fn retain_brillig_with_required_witness_outputs() {
297+
let brillig_opcode = Opcode::Brillig(Brillig {
298+
inputs: Vec::new(),
299+
outputs: vec![BrilligOutputs::Simple(Witness(1))],
300+
bytecode: vec![
301+
BrilligOpcode::Const {
302+
destination: MemoryAddress(0),
303+
bit_size: 32,
304+
value: 1u128.into(),
305+
},
306+
BrilligOpcode::Stop { return_data_offset: 0, return_data_size: 1 },
307+
],
308+
predicate: None,
309+
});
310+
let blackbox_opcode = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::AND {
311+
lhs: FunctionInput { witness: Witness(1), num_bits: 64 },
312+
rhs: FunctionInput { witness: Witness(2), num_bits: 64 },
313+
output: Witness(3),
314+
});
315+
316+
let opcodes = vec![brillig_opcode, blackbox_opcode];
317+
// The optimizer should keep the lowest bit size range constraint
318+
let circuit = test_circuit(opcodes);
319+
let acir_opcode_positions = circuit.opcodes.iter().enumerate().map(|(i, _)| i).collect();
320+
let optimizer = ConstantBackpropagationOptimizer::new(circuit);
321+
322+
let (optimized_circuit, _) =
323+
optimizer.backpropagate_constants_iteration(acir_opcode_positions);
324+
325+
assert_eq!(
326+
optimized_circuit.opcodes.len(),
327+
2,
328+
"The brillig opcode should not be removed as the output is needed as a witness"
329+
);
330+
}
331+
}

acvm-repo/acvm/src/compiler/optimizers/general.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ impl GeneralOptimizer {
1313
pub(crate) fn optimize(opcode: Expression) -> Expression {
1414
// XXX: Perhaps this optimization can be done on the fly
1515
let opcode = remove_zero_coefficients(opcode);
16-
simplify_mul_terms(opcode)
16+
let opcode = simplify_mul_terms(opcode);
17+
simplify_linear_terms(opcode)
1718
}
1819
}
1920

@@ -42,3 +43,20 @@ fn simplify_mul_terms(mut gate: Expression) -> Expression {
4243
gate.mul_terms = hash_map.into_iter().map(|((w_l, w_r), scale)| (scale, w_l, w_r)).collect();
4344
gate
4445
}
46+
47+
// Simplifies all linear terms with the same variables
48+
fn simplify_linear_terms(mut gate: Expression) -> Expression {
49+
let mut hash_map: IndexMap<Witness, FieldElement> = IndexMap::new();
50+
51+
// Canonicalize the ordering of the terms, lets just order by variable name
52+
for (scale, witness) in gate.linear_combinations.into_iter() {
53+
*hash_map.entry(witness).or_insert_with(FieldElement::zero) += scale;
54+
}
55+
56+
gate.linear_combinations = hash_map
57+
.into_iter()
58+
.filter(|(_, scale)| scale != &FieldElement::zero())
59+
.map(|(witness, scale)| (scale, witness))
60+
.collect();
61+
gate
62+
}

0 commit comments

Comments
 (0)