Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,7 @@ impl<'b, B: BlackBoxFunctionSolver<F>, F: AcirField> BrilligSolver<'b, F, B> {
let memory_block = memory
.get(block_id)
.ok_or(OpcodeNotSolvable::MissingMemoryBlock(block_id.0))?;
for memory_index in 0..memory_block.block_len {
let memory_value = memory_block
.block_value
.get(&memory_index)
.expect("All memory is initialized on creation");
calldata.push(*memory_value);
}
calldata.extend(&memory_block.block_value);
}
}
}
Expand Down
80 changes: 38 additions & 42 deletions acvm-repo/acvm/src/pwg/memory_op.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::collections::HashMap;

use acir::{
AcirField,
circuit::opcodes::MemOp,
Expand All @@ -14,16 +12,31 @@ use super::{
type MemoryIndex = u32;

/// Maintains the state for solving [`MemoryInit`][`acir::circuit::Opcode::MemoryInit`] and [`MemoryOp`][`acir::circuit::Opcode::MemoryOp`] opcodes.
#[derive(Default)]
pub(crate) struct MemoryOpSolver<F> {
/// Known values of the memory block, based on the index
/// This map evolves as we process the opcodes
pub(super) block_value: HashMap<MemoryIndex, F>,
/// Length of the block, i.e the number of elements stored into the memory block.
pub(super) block_len: u32,
/// This vec starts as big as it needs to, when initialized,
/// then evolves as we process the opcodes.
pub(super) block_value: Vec<F>,
}

impl<F: AcirField> MemoryOpSolver<F> {
/// Creates a new MemoryOpSolver with the values given in `init`.
pub(crate) fn new(
init: &[Witness],
initial_witness: &WitnessMap<F>,
) -> Result<Self, OpcodeResolutionError<F>> {
Ok(Self {
block_value: init
.iter()
.map(|witness| witness_to_value(initial_witness, *witness).copied())
.collect::<Result<Vec<_>, _>>()?,
})
}

fn len(&self) -> u32 {
self.block_value.len() as u32
}

/// Convert a field element into a memory index
/// Only 32 bits values are valid memory indices
fn index_from_field(&self, index: F) -> Result<MemoryIndex, OpcodeResolutionError<F>> {
Expand All @@ -34,7 +47,7 @@ impl<F: AcirField> MemoryOpSolver<F> {
Err(OpcodeResolutionError::IndexOutOfBounds {
opcode_location: ErrorLocation::Unresolved,
index,
array_size: self.block_len,
array_size: self.len(),
})
}
}
Expand All @@ -46,41 +59,28 @@ impl<F: AcirField> MemoryOpSolver<F> {
index: MemoryIndex,
value: F,
) -> Result<(), OpcodeResolutionError<F>> {
if index >= self.block_len {
if index >= self.len() {
return Err(OpcodeResolutionError::IndexOutOfBounds {
opcode_location: ErrorLocation::Unresolved,
index: F::from(u128::from(index)),
array_size: self.block_len,
array_size: self.len(),
});
}
self.block_value.insert(index, value);

self.block_value[index as usize] = value;
Ok(())
}

/// Returns the value stored in the 'block_value' map for the provided index
/// Returns an 'IndexOutOfBounds' error if the index is not in the map.
fn read_memory_index(&self, index: MemoryIndex) -> Result<F, OpcodeResolutionError<F>> {
self.block_value.get(&index).copied().ok_or(OpcodeResolutionError::IndexOutOfBounds {
opcode_location: ErrorLocation::Unresolved,
index: F::from(u128::from(index)),
array_size: self.block_len,
})
}

/// Set the block_value from a MemoryInit opcode
pub(crate) fn init(
&mut self,
init: &[Witness],
initial_witness: &WitnessMap<F>,
) -> Result<(), OpcodeResolutionError<F>> {
self.block_len = init.len() as u32;
for (memory_index, witness) in init.iter().enumerate() {
self.write_memory_index(
memory_index as MemoryIndex,
*witness_to_value(initial_witness, *witness)?,
)?;
}
Ok(())
self.block_value.get(index as usize).copied().ok_or(
OpcodeResolutionError::IndexOutOfBounds {
opcode_location: ErrorLocation::Unresolved,
index: F::from(u128::from(index)),
array_size: self.len(),
},
)
}

/// Update the 'block_values' by processing the provided Memory opcode
Expand Down Expand Up @@ -192,8 +192,7 @@ mod tests {
MemOp::read_at_mem_index(FieldElement::one().into(), Witness(4)),
];

let mut block_solver = MemoryOpSolver::default();
block_solver.init(&init, &initial_witness).unwrap();
let mut block_solver = MemoryOpSolver::new(&init, &initial_witness).unwrap();

for op in trace {
block_solver
Expand All @@ -218,8 +217,7 @@ mod tests {
MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
MemOp::read_at_mem_index(FieldElement::from(2u128).into(), Witness(4)),
];
let mut block_solver = MemoryOpSolver::default();
block_solver.init(&init, &initial_witness).unwrap();
let mut block_solver = MemoryOpSolver::new(&init, &initial_witness).unwrap();
let mut err = None;
for op in invalid_trace {
if err.is_none() {
Expand All @@ -240,7 +238,7 @@ mod tests {
}

#[test]
// TODO: to review after the serialisation changes are merged because it will remove the predicate.
// TODO: to review after the serialization changes are merged because it will remove the predicate.
fn test_predicate_on_read() {
let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
(Witness(1), FieldElement::from(1u128)),
Expand All @@ -254,8 +252,7 @@ mod tests {
MemOp::write_to_mem_index(FieldElement::from(1u128).into(), Witness(3).into()),
MemOp::read_at_mem_index(FieldElement::from(2u128).into(), Witness(4)),
];
let mut block_solver = MemoryOpSolver::default();
block_solver.init(&init, &initial_witness).unwrap();
let mut block_solver = MemoryOpSolver::new(&init, &initial_witness).unwrap();
let mut err = None;
for op in invalid_trace {
if err.is_none() {
Expand All @@ -277,7 +274,7 @@ mod tests {
}

#[test]
// TODO: to review after the serialisation changes are merged because it will remove the predicate.
// TODO: to review after the serialization changes are merged because it will remove the predicate.
fn test_predicate_on_write() {
let mut initial_witness = WitnessMap::from(BTreeMap::from_iter([
(Witness(1), FieldElement::from(1u128)),
Expand All @@ -292,8 +289,7 @@ mod tests {
MemOp::read_at_mem_index(FieldElement::from(0u128).into(), Witness(4)),
MemOp::read_at_mem_index(FieldElement::from(1u128).into(), Witness(5)),
];
let mut block_solver = MemoryOpSolver::default();
block_solver.init(&init, &initial_witness).unwrap();
let mut block_solver = MemoryOpSolver::new(&init, &initial_witness).unwrap();
let mut err = None;
for op in invalid_trace {
if err.is_none() {
Expand Down
32 changes: 16 additions & 16 deletions acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@
//! - `assertion_payloads`: additional information used to provide feedback to the user when an assertion fails.
//!
//! Returns: ACVM Status
//! - `Solved`: all witness have been sucessfully computed, execution is complete.
//! - `Solved`: all witness have been successfully computed, execution is complete.
//! - `InProgress`: The ACVM is processing the circuit, i.e solving the opcodes. This status is used to resume execution after it has been paused.
//! - `Failure(OpcodeResolutionError<F>)`: Error, execution is stopped.
//! - `RequiresForeignCall(ForeignCallWaitInfo<F>)`: Execution is paused until the result of a foreign call is provided
//! - `RequiresAcirCall(AcirCallWaitInfo<F>)`: Execution is paused until the result of an ACIR call is provided
//!
//! Each opcode is solved independently. In general we require its inputs to be already known, i.e previoulsy solved,
//! Each opcode is solved independently. In general we require its inputs to be already known, i.e previously solved,
//! and the output is simply computed from the inputs, and then the output becomes 'known' for the subsequent opcodes.
//!
//! - AssertZero opcode: The arithmetic expression of the opcode is solved for one unknwon witness.
//! It will fail if there is more than one unkwnown witness in the expression.
//! - AssertZero opcode: The arithmetic expression of the opcode is solved for one unknown witness.
//! It will fail if there is more than one unknown witness in the expression.
//!
//! - BlackBoxFuncCall opcode: The blackbox module knows how to compute the result of the function when all its input are known.
//!
Expand Down Expand Up @@ -68,7 +68,7 @@
// ASSERT w0 - w2 - w9 = 0
//!
//! This ACIR program defines the 'main' function and indicates it is 'non-transformed'.
//! Indeed, some ACIR pass can transform the ACIR program in order to apply optimisations,
//! Indeed, some ACIR pass can transform the ACIR program in order to apply optimizations,
//! or to make it compatible with a specific proving system.
//! However, ACIR execution is expected to work on any ACIR program (transformed or not).
//! Then the program indicates the 'current witness', which is the lasted witness used in the program.
Expand All @@ -81,13 +81,13 @@
//! Solving this black-box simply means to validate that the values (from `initial_witness`) are indeed 32 bits for w0, w1, w2, w3, w4
//! If `initial_witness` does not have values for w0, w1, w2, w3, w4, or if the values are over 32 bits, the execution will fail.
//! The next opcode is an AssertZero opcode: ASSERT w0 - w1 - w6 = 0, which indicates that `w0 - w1 - w6` should be equal to 0.
//! Since we know the values of `w0, w1` from `initial_witness`, we can compute `w6 = w0 + w1` so that the AssertZero is satified.
//! Since we know the values of `w0, w1` from `initial_witness`, we can compute `w6 = w0 + w1` so that the AssertZero is satisfied.
//! Solving AssertZero means computing the unknown witness and adding the result to `initial_witness`, which now contains the value for `w6`.
//! The next opcode is a Brillig Call where input is `w6` and output is `w7`. From the function id of the opcode, the solver will retrieve the
//! corresponding Brillig bytecode and instantiate a Brillig VM with the value of the input. This value was just computed before.
//! Executing the Brillig VM on this input will give us the output which is the value for `w7`, that we add to `initial_witness`.
//! The next opcode is again an AssertZero: `w6 * w7 + w8 - 1 = 0`, which computes the value of `w8`.
//! The two next opcode are AssertZero without any unkwown witness: `w6 * w8 = 0` and `w1 * w8 = 0`
//! The two next opcode are AssertZero without any unknown witness: `w6 * w8 = 0` and `w1 * w8 = 0`
//! Solving such opcodes means that we compute `w6 * w8 ` and `w1 * w8` using the known values, and check that it is 0.
//! If not, we would return an error.
//! Finally, the last AssertZero computes `w9` which is the last witness. All the witness have now been computed; execution is complete.
Expand Down Expand Up @@ -514,11 +514,16 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> ACVM<'a, F, B> {
blackbox::solve(self.backend, &mut self.witness_map, bb_func)
}
Opcode::MemoryInit { block_id, init, .. } => {
let solver = self.block_solvers.entry(*block_id).or_default();
solver.init(init, &self.witness_map)
MemoryOpSolver::new(init, &self.witness_map).map(|solver| {
let existing_block_id = self.block_solvers.insert(*block_id, solver);
assert!(existing_block_id.is_none(), "Memory block already initialized");
})
}
Opcode::MemoryOp { block_id, op } => {
let solver = self.block_solvers.entry(*block_id).or_default();
let solver = self
.block_solvers
.get_mut(block_id)
.expect("Memory block should have been initialized before use");
solver.solve_memory_op(
op,
&mut self.witness_map,
Expand Down Expand Up @@ -604,12 +609,7 @@ impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> ACVM<'a, F, B> {
}
ExpressionOrMemory::Memory(block_id) => {
let memory_block = self.block_solvers.get(block_id)?;
fields.extend((0..memory_block.block_len).map(|memory_index| {
*memory_block
.block_value
.get(&memory_index)
.expect("All memory is initialized on creation")
}));
fields.extend(&memory_block.block_value);
}
}
}
Expand Down
Loading