Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 10 additions & 42 deletions benches/recursive-snark-supernova.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,36 +118,23 @@ fn bench_one_augmented_circuit_recursive_snark(c: &mut Criterion) {
let num_warmup_steps = 10;
let z0_primary = vec![<G1 as Group>::Scalar::from(2u64)];
let z0_secondary = vec![<G2 as Group>::Scalar::from(2u64)];
let initial_program_counter = <G1 as Group>::Scalar::from(0);
let mut recursive_snark_option: Option<RecursiveSNARK<G1, G2>> = None;

for _ in 0..num_warmup_steps {
let program_counter = recursive_snark_option.as_ref().map_or_else(
|| initial_program_counter,
|recursive_snark| recursive_snark.get_program_counter(),
);

let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| {
RecursiveSNARK::iter_base_step(
RecursiveSNARK::new(
&pp,
0,
&bench,
&bench.primary_circuit(0),
&bench.secondary_circuit(),
Some(program_counter),
0,
1,
&z0_primary,
&z0_secondary,
)
.unwrap()
});

let res = recursive_snark.prove_step(
&pp,
0,
&bench.primary_circuit(0),
&bench.secondary_circuit(),
);
let res =
recursive_snark.prove_step(&pp, &bench.primary_circuit(0), &bench.secondary_circuit());
if let Err(e) = &res {
println!("res failed {:?}", e);
}
Expand All @@ -170,7 +157,6 @@ fn bench_one_augmented_circuit_recursive_snark(c: &mut Criterion) {
assert!(black_box(&mut recursive_snark.clone())
.prove_step(
black_box(&pp),
black_box(0),
&bench.primary_circuit(0),
&bench.secondary_circuit(),
)
Expand Down Expand Up @@ -224,38 +210,25 @@ fn bench_two_augmented_circuit_recursive_snark(c: &mut Criterion) {
let num_warmup_steps = 10;
let z0_primary = vec![<G1 as Group>::Scalar::from(2u64)];
let z0_secondary = vec![<G2 as Group>::Scalar::from(2u64)];
let initial_program_counter = <G1 as Group>::Scalar::from(0);
let mut recursive_snark_option: Option<RecursiveSNARK<G1, G2>> = None;
let mut selected_augmented_circuit = 0;

for _ in 0..num_warmup_steps {
let program_counter = recursive_snark_option.as_ref().map_or_else(
|| initial_program_counter,
|recursive_snark| recursive_snark.get_program_counter(),
);

let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| {
RecursiveSNARK::iter_base_step(
RecursiveSNARK::new(
&pp,
0,
&bench,
&bench.primary_circuit(0),
&bench.secondary_circuit(),
Some(program_counter),
0,
2,
&z0_primary,
&z0_secondary,
)
.unwrap()
});

if selected_augmented_circuit == 0 {
let res = recursive_snark.prove_step(
&pp,
0,
&bench.primary_circuit(0),
&bench.secondary_circuit(),
);
let res =
recursive_snark.prove_step(&pp, &bench.primary_circuit(0), &bench.secondary_circuit());
if let Err(e) = &res {
println!("res failed {:?}", e);
}
Expand All @@ -266,12 +239,8 @@ fn bench_two_augmented_circuit_recursive_snark(c: &mut Criterion) {
}
assert!(res.is_ok());
} else if selected_augmented_circuit == 1 {
let res = recursive_snark.prove_step(
&pp,
1,
&bench.primary_circuit(1),
&bench.secondary_circuit(),
);
let res =
recursive_snark.prove_step(&pp, &bench.primary_circuit(1), &bench.secondary_circuit());
if let Err(e) = &res {
println!("res failed {:?}", e);
}
Expand All @@ -298,7 +267,6 @@ fn bench_two_augmented_circuit_recursive_snark(c: &mut Criterion) {
assert!(black_box(&mut recursive_snark.clone())
.prove_step(
black_box(&pp),
black_box(0),
&bench.primary_circuit(0),
&bench.secondary_circuit(),
)
Expand Down
61 changes: 43 additions & 18 deletions src/supernova/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,17 +416,21 @@ where
{
/// iterate base step to get new instance of recursive SNARK
#[allow(clippy::too_many_arguments)]
pub fn iter_base_step<C1: StepCircuit<G1::Scalar>, C2: StepCircuit<G2::Scalar>>(
pub fn new<
C0: NonUniformCircuit<G1, G2, C1, C2>,
C1: StepCircuit<G1::Scalar>,
C2: StepCircuit<G2::Scalar>,
>(
pp: &PublicParams<G1, G2, C1, C2>,
circuit_index: usize,
non_uniform_circuit: &C0,
c_primary: &C1,
c_secondary: &C2,
initial_program_counter: Option<G1::Scalar>,
first_augmented_circuit_index: usize,
num_augmented_circuits: usize,
z0_primary: &[G1::Scalar],
z0_secondary: &[G2::Scalar],
) -> Result<Self, SuperNovaError> {
let num_augmented_circuits = non_uniform_circuit.num_circuits();
let circuit_index = non_uniform_circuit.initial_circuit_index();

if z0_primary.len() != pp[circuit_index].F_arity
|| z0_secondary.len() != pp.circuit_shape_secondary.F_arity
{
Expand All @@ -446,7 +450,7 @@ where
None,
None,
None,
initial_program_counter,
Some(G1::Scalar::from(circuit_index as u64)),
G1::Scalar::ZERO, // set augmented circuit index selector to 0 in base case
);

Expand Down Expand Up @@ -545,11 +549,11 @@ where

// handle the base case by initialize U_next in next round
let r_W_primary_initial_list = (0..num_augmented_circuits)
.map(|i| (i == first_augmented_circuit_index).then(|| r_W_primary.clone()))
.map(|i| (i == circuit_index).then(|| r_W_primary.clone()))
.collect::<Vec<Option<RelaxedR1CSWitness<G1>>>>();

let r_U_primary_initial_list = (0..num_augmented_circuits)
.map(|i| (i == first_augmented_circuit_index).then(|| r_U_primary.clone()))
.map(|i| (i == circuit_index).then(|| r_U_primary.clone()))
.collect::<Vec<Option<RelaxedR1CSInstance<G1>>>>();

Ok(Self {
Expand All @@ -566,7 +570,7 @@ where
zi_primary,
zi_secondary,
program_counter: zi_primary_pc_next,
augmented_circuit_index: first_augmented_circuit_index,
augmented_circuit_index: circuit_index,
num_augmented_circuits,
})
}
Expand All @@ -577,7 +581,6 @@ where
pub fn prove_step<C1: StepCircuit<G1::Scalar>, C2: StepCircuit<G2::Scalar>>(
&mut self,
pp: &PublicParams<G1, G2, C1, C2>,
circuit_index: usize,
c_primary: &C1,
c_secondary: &C2,
) -> Result<(), SuperNovaError> {
Expand All @@ -591,6 +594,9 @@ where
return Err(NovaError::ProofVerifyError.into());
}

let circuit_index = c_primary.circuit_index();
assert_eq!(self.program_counter, G1::Scalar::from(circuit_index as u64));

// fold the secondary circuit's instance
let (nifs_secondary, (r_U_secondary_folded, r_W_secondary_folded)) = NIFS::prove(
&pp.ck_secondary,
Expand Down Expand Up @@ -921,11 +927,6 @@ where

Ok((self.zi_primary.clone(), self.zi_secondary.clone()))
}

/// get program counter
pub fn get_program_counter(&self) -> G1::Scalar {
self.program_counter
}
}

/// SuperNova helper trait, for implementors that provide sets of sub-circuits to be proved via NIVC. `C1` must be a
Expand All @@ -938,9 +939,9 @@ where
C1: StepCircuit<G1::Scalar>,
C2: StepCircuit<G2::Scalar>,
{
/// Initial program counter, defaults to zero.
fn initial_program_counter(&self) -> G1::Scalar {
G1::Scalar::ZERO
/// Initial circuit index, defaults to zero.
fn initial_circuit_index(&self) -> usize {
0
}

/// How many circuits are provided?
Expand All @@ -953,6 +954,30 @@ where
fn secondary_circuit(&self) -> C2;
}

/// Extension trait to simplify getting scalar form of initial circuit index.
pub trait InitialProgramCounter<G1, G2, C1, C2>: NonUniformCircuit<G1, G2, C1, C2>
where
G1: Group<Base = <G2 as Group>::Scalar>,
G2: Group<Base = <G1 as Group>::Scalar>,
C1: StepCircuit<G1::Scalar>,
C2: StepCircuit<G2::Scalar>,
{
/// Initial program counter is the initial circuit index as a `Scalar`.
fn initial_program_counter(&self) -> G1::Scalar {
G1::Scalar::from(self.initial_circuit_index() as u64)
}
}

impl<G1, G2, C1, C2, T: NonUniformCircuit<G1, G2, C1, C2>> InitialProgramCounter<G1, G2, C1, C2>
for T
where
G1: Group<Base = <G2 as Group>::Scalar>,
G2: Group<Base = <G1 as Group>::Scalar>,
C1: StepCircuit<G1::Scalar>,
C2: StepCircuit<G2::Scalar>,
{
}

/// Compute the circuit digest of a supernova [StepCircuit].
///
/// Note for callers: This function should be called with its performance characteristics in mind.
Expand Down
33 changes: 9 additions & 24 deletions src/supernova/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,8 +395,8 @@ where
Default::default()
}

fn initial_program_counter(&self) -> G1::Scalar {
G1::Scalar::from(self.rom[0] as u64)
fn initial_circuit_index(&self) -> usize {
self.rom[0]
}
}

Expand Down Expand Up @@ -484,14 +484,11 @@ where

let mut recursive_snark =
recursive_snark_option.unwrap_or_else(|| match augmented_circuit_index {
OPCODE_0 | OPCODE_1 => RecursiveSNARK::iter_base_step(
OPCODE_0 | OPCODE_1 => RecursiveSNARK::new(
&pp,
augmented_circuit_index,
&test_rom,
&test_rom.primary_circuit(augmented_circuit_index),
&test_rom.secondary_circuit(),
Some(program_counter),
augmented_circuit_index,
test_rom.num_circuits(),
&z0_primary,
&z0_secondary,
)
Expand All @@ -505,12 +502,7 @@ where
let circuit_primary = test_rom.primary_circuit(augmented_circuit_index);
let circuit_secondary = test_rom.secondary_circuit();
recursive_snark
.prove_step(
&pp,
augmented_circuit_index,
&circuit_primary,
&circuit_secondary,
)
.prove_step(&pp, &circuit_primary, &circuit_secondary)
.unwrap();
recursive_snark
.verify(&pp, augmented_circuit_index, &z0_primary, &z0_secondary)
Expand Down Expand Up @@ -961,14 +953,12 @@ where
// produce a recursive SNARK

let circuit_primary = &roots[0];
let mut recursive_snark = RecursiveSNARK::<G1, G2>::iter_base_step(

let mut recursive_snark = RecursiveSNARK::<G1, G2>::new(
&pp,
circuit_primary.circuit_index(),
circuit_primary,
circuit_primary,
&circuit_secondary,
Some(G1::Scalar::from(circuit_primary.circuit_index() as u64)),
circuit_primary.circuit_index(),
2,
&z0_primary,
&z0_secondary,
)
Expand All @@ -989,12 +979,7 @@ where
.unwrap();

for circuit_primary in roots.iter().take(num_steps) {
let res = recursive_snark.prove_step(
&pp,
circuit_primary.circuit_index(),
circuit_primary,
&circuit_secondary,
);
let res = recursive_snark.prove_step(&pp, circuit_primary, &circuit_secondary);
assert!(res
.map_err(|err| {
print_constraints_name_on_error_index(
Expand Down