Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 26 additions & 46 deletions acvm-repo/acvm/src/compiler/optimizers/redundant_range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use acir::{
opcodes::{BlackBoxFuncCall, FunctionInput},
Circuit, Opcode,
},
native_types::{Expression, Witness},
FieldElement,
native_types::Witness,
};
use std::collections::{BTreeMap, HashSet};

Expand Down Expand Up @@ -105,9 +104,11 @@ impl RangeOptimizer {
let mut new_order_list = Vec::with_capacity(order_list.len());
let mut optimized_opcodes = Vec::with_capacity(self.circuit.opcodes.len());
for (idx, opcode) in self.circuit.opcodes.into_iter().enumerate() {
let (witness, num_bits) = match extract_range_opcode(&opcode) {
Some(range_opcode) => range_opcode,
None => {
let (witness, num_bits) = match &opcode {
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input }) => {
(input.witness, input.num_bits)
}
_ => {
// If its not the range opcode, add it to the opcode
// list and continue;
optimized_opcodes.push(opcode);
Expand All @@ -131,44 +132,19 @@ impl RangeOptimizer {
if is_lowest_bit_size {
already_seen_witness.insert(witness);
new_order_list.push(order_list[idx]);
optimized_opcodes.push(optimized_range_opcode(witness, num_bits));
optimized_opcodes.push(opcode);
}
}

(Circuit { opcodes: optimized_opcodes, ..self.circuit }, new_order_list)
}
}

/// Extract the range opcode from the `Opcode` enum
/// Returns None, if `Opcode` is not the range opcode.
fn extract_range_opcode(opcode: &Opcode) -> Option<(Witness, u32)> {
match opcode {
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE { input }) => {
Some((input.witness, input.num_bits))
}
_ => None,
}
}

fn optimized_range_opcode(witness: Witness, num_bits: u32) -> Opcode {
if num_bits == 1 {
Opcode::AssertZero(Expression {
mul_terms: vec![(FieldElement::one(), witness, witness)],
linear_combinations: vec![(-FieldElement::one(), witness)],
q_c: FieldElement::zero(),
})
} else {
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE {
input: FunctionInput { witness, num_bits },
})
}
}

#[cfg(test)]
mod tests {
use std::collections::BTreeSet;

use crate::compiler::optimizers::redundant_range::{extract_range_opcode, RangeOptimizer};
use crate::compiler::optimizers::redundant_range::RangeOptimizer;
use acir::{
circuit::{
opcodes::{BlackBoxFuncCall, FunctionInput},
Expand Down Expand Up @@ -218,11 +194,12 @@ mod tests {
let (optimized_circuit, _) = optimizer.replace_redundant_ranges(acir_opcode_positions);
assert_eq!(optimized_circuit.opcodes.len(), 1);

let (witness, num_bits) =
extract_range_opcode(&optimized_circuit.opcodes[0]).expect("expected one range opcode");

assert_eq!(witness, Witness(1));
assert_eq!(num_bits, 16);
assert_eq!(
optimized_circuit.opcodes[0],
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE {
input: FunctionInput { witness: Witness(1), num_bits: 16 }
})
);
}

#[test]
Expand All @@ -240,15 +217,18 @@ mod tests {
let (optimized_circuit, _) = optimizer.replace_redundant_ranges(acir_opcode_positions);
assert_eq!(optimized_circuit.opcodes.len(), 2);

let (witness_a, num_bits_a) =
extract_range_opcode(&optimized_circuit.opcodes[0]).expect("expected two range opcode");
let (witness_b, num_bits_b) =
extract_range_opcode(&optimized_circuit.opcodes[1]).expect("expected two range opcode");

assert_eq!(witness_a, Witness(1));
assert_eq!(witness_b, Witness(2));
assert_eq!(num_bits_a, 16);
assert_eq!(num_bits_b, 23);
assert_eq!(
optimized_circuit.opcodes[0],
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE {
input: FunctionInput { witness: Witness(1), num_bits: 16 }
})
);
assert_eq!(
optimized_circuit.opcodes[1],
Opcode::BlackBoxFuncCall(BlackBoxFuncCall::RANGE {
input: FunctionInput { witness: Witness(2), num_bits: 23 }
})
);
}

#[test]
Expand Down