diff --git a/benches/recursive-snark-supernova.rs b/benches/recursive-snark-supernova.rs index 0bedc9f3..a42ba633 100644 --- a/benches/recursive-snark-supernova.rs +++ b/benches/recursive-snark-supernova.rs @@ -118,36 +118,23 @@ fn bench_one_augmented_circuit_recursive_snark(c: &mut Criterion) { let num_warmup_steps = 10; let z0_primary = vec![::Scalar::from(2u64)]; let z0_secondary = vec![::Scalar::from(2u64)]; - let initial_program_counter = ::Scalar::from(0); let mut recursive_snark_option: Option> = None; for _ in 0..num_warmup_steps { - let program_counter = recursive_snark_option.as_ref().map_or_else( - || initial_program_counter, - |recursive_snark| recursive_snark.get_program_counter(), - ); - let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| { - RecursiveSNARK::iter_base_step( + RecursiveSNARK::new( &pp, - 0, + &bench, &bench.primary_circuit(0), &bench.secondary_circuit(), - Some(program_counter), - 0, - 1, &z0_primary, &z0_secondary, ) .unwrap() }); - let res = recursive_snark.prove_step( - &pp, - 0, - &bench.primary_circuit(0), - &bench.secondary_circuit(), - ); + let res = + recursive_snark.prove_step(&pp, &bench.primary_circuit(0), &bench.secondary_circuit()); if let Err(e) = &res { println!("res failed {:?}", e); } @@ -170,7 +157,6 @@ fn bench_one_augmented_circuit_recursive_snark(c: &mut Criterion) { assert!(black_box(&mut recursive_snark.clone()) .prove_step( black_box(&pp), - black_box(0), &bench.primary_circuit(0), &bench.secondary_circuit(), ) @@ -224,25 +210,16 @@ fn bench_two_augmented_circuit_recursive_snark(c: &mut Criterion) { let num_warmup_steps = 10; let z0_primary = vec![::Scalar::from(2u64)]; let z0_secondary = vec![::Scalar::from(2u64)]; - let initial_program_counter = ::Scalar::from(0); let mut recursive_snark_option: Option> = None; let mut selected_augmented_circuit = 0; for _ in 0..num_warmup_steps { - let program_counter = recursive_snark_option.as_ref().map_or_else( - || initial_program_counter, - |recursive_snark| recursive_snark.get_program_counter(), - ); - let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| { - RecursiveSNARK::iter_base_step( + RecursiveSNARK::new( &pp, - 0, + &bench, &bench.primary_circuit(0), &bench.secondary_circuit(), - Some(program_counter), - 0, - 2, &z0_primary, &z0_secondary, ) @@ -250,12 +227,8 @@ fn bench_two_augmented_circuit_recursive_snark(c: &mut Criterion) { }); if selected_augmented_circuit == 0 { - let res = recursive_snark.prove_step( - &pp, - 0, - &bench.primary_circuit(0), - &bench.secondary_circuit(), - ); + let res = + recursive_snark.prove_step(&pp, &bench.primary_circuit(0), &bench.secondary_circuit()); if let Err(e) = &res { println!("res failed {:?}", e); } @@ -266,12 +239,8 @@ fn bench_two_augmented_circuit_recursive_snark(c: &mut Criterion) { } assert!(res.is_ok()); } else if selected_augmented_circuit == 1 { - let res = recursive_snark.prove_step( - &pp, - 1, - &bench.primary_circuit(1), - &bench.secondary_circuit(), - ); + let res = + recursive_snark.prove_step(&pp, &bench.primary_circuit(1), &bench.secondary_circuit()); if let Err(e) = &res { println!("res failed {:?}", e); } @@ -298,7 +267,6 @@ fn bench_two_augmented_circuit_recursive_snark(c: &mut Criterion) { assert!(black_box(&mut recursive_snark.clone()) .prove_step( black_box(&pp), - black_box(0), &bench.primary_circuit(0), &bench.secondary_circuit(), ) diff --git a/src/supernova/mod.rs b/src/supernova/mod.rs index c8a94e68..e3892e30 100644 --- a/src/supernova/mod.rs +++ b/src/supernova/mod.rs @@ -416,17 +416,21 @@ where { /// iterate base step to get new instance of recursive SNARK #[allow(clippy::too_many_arguments)] - pub fn iter_base_step, C2: StepCircuit>( + pub fn new< + C0: NonUniformCircuit, + C1: StepCircuit, + C2: StepCircuit, + >( pp: &PublicParams, - circuit_index: usize, + non_uniform_circuit: &C0, c_primary: &C1, c_secondary: &C2, - initial_program_counter: Option, - first_augmented_circuit_index: usize, - num_augmented_circuits: usize, z0_primary: &[G1::Scalar], z0_secondary: &[G2::Scalar], ) -> Result { + let num_augmented_circuits = non_uniform_circuit.num_circuits(); + let circuit_index = non_uniform_circuit.initial_circuit_index(); + if z0_primary.len() != pp[circuit_index].F_arity || z0_secondary.len() != pp.circuit_shape_secondary.F_arity { @@ -446,7 +450,7 @@ where None, None, None, - initial_program_counter, + Some(G1::Scalar::from(circuit_index as u64)), G1::Scalar::ZERO, // set augmented circuit index selector to 0 in base case ); @@ -545,11 +549,11 @@ where // handle the base case by initialize U_next in next round let r_W_primary_initial_list = (0..num_augmented_circuits) - .map(|i| (i == first_augmented_circuit_index).then(|| r_W_primary.clone())) + .map(|i| (i == circuit_index).then(|| r_W_primary.clone())) .collect::>>>(); let r_U_primary_initial_list = (0..num_augmented_circuits) - .map(|i| (i == first_augmented_circuit_index).then(|| r_U_primary.clone())) + .map(|i| (i == circuit_index).then(|| r_U_primary.clone())) .collect::>>>(); Ok(Self { @@ -566,7 +570,7 @@ where zi_primary, zi_secondary, program_counter: zi_primary_pc_next, - augmented_circuit_index: first_augmented_circuit_index, + augmented_circuit_index: circuit_index, num_augmented_circuits, }) } @@ -577,7 +581,6 @@ where pub fn prove_step, C2: StepCircuit>( &mut self, pp: &PublicParams, - circuit_index: usize, c_primary: &C1, c_secondary: &C2, ) -> Result<(), SuperNovaError> { @@ -591,6 +594,9 @@ where return Err(NovaError::ProofVerifyError.into()); } + let circuit_index = c_primary.circuit_index(); + assert_eq!(self.program_counter, G1::Scalar::from(circuit_index as u64)); + // fold the secondary circuit's instance let (nifs_secondary, (r_U_secondary_folded, r_W_secondary_folded)) = NIFS::prove( &pp.ck_secondary, @@ -921,11 +927,6 @@ where Ok((self.zi_primary.clone(), self.zi_secondary.clone())) } - - /// get program counter - pub fn get_program_counter(&self) -> G1::Scalar { - self.program_counter - } } /// SuperNova helper trait, for implementors that provide sets of sub-circuits to be proved via NIVC. `C1` must be a @@ -938,9 +939,9 @@ where C1: StepCircuit, C2: StepCircuit, { - /// Initial program counter, defaults to zero. - fn initial_program_counter(&self) -> G1::Scalar { - G1::Scalar::ZERO + /// Initial circuit index, defaults to zero. + fn initial_circuit_index(&self) -> usize { + 0 } /// How many circuits are provided? @@ -953,6 +954,30 @@ where fn secondary_circuit(&self) -> C2; } +/// Extension trait to simplify getting scalar form of initial circuit index. +pub trait InitialProgramCounter: NonUniformCircuit +where + G1: Group::Scalar>, + G2: Group::Scalar>, + C1: StepCircuit, + C2: StepCircuit, +{ + /// Initial program counter is the initial circuit index as a `Scalar`. + fn initial_program_counter(&self) -> G1::Scalar { + G1::Scalar::from(self.initial_circuit_index() as u64) + } +} + +impl> InitialProgramCounter + for T +where + G1: Group::Scalar>, + G2: Group::Scalar>, + C1: StepCircuit, + C2: StepCircuit, +{ +} + /// Compute the circuit digest of a supernova [StepCircuit]. /// /// Note for callers: This function should be called with its performance characteristics in mind. diff --git a/src/supernova/test.rs b/src/supernova/test.rs index 93970f36..93c46747 100644 --- a/src/supernova/test.rs +++ b/src/supernova/test.rs @@ -395,8 +395,8 @@ where Default::default() } - fn initial_program_counter(&self) -> G1::Scalar { - G1::Scalar::from(self.rom[0] as u64) + fn initial_circuit_index(&self) -> usize { + self.rom[0] } } @@ -484,14 +484,11 @@ where let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| match augmented_circuit_index { - OPCODE_0 | OPCODE_1 => RecursiveSNARK::iter_base_step( + OPCODE_0 | OPCODE_1 => RecursiveSNARK::new( &pp, - augmented_circuit_index, + &test_rom, &test_rom.primary_circuit(augmented_circuit_index), &test_rom.secondary_circuit(), - Some(program_counter), - augmented_circuit_index, - test_rom.num_circuits(), &z0_primary, &z0_secondary, ) @@ -505,12 +502,7 @@ where let circuit_primary = test_rom.primary_circuit(augmented_circuit_index); let circuit_secondary = test_rom.secondary_circuit(); recursive_snark - .prove_step( - &pp, - augmented_circuit_index, - &circuit_primary, - &circuit_secondary, - ) + .prove_step(&pp, &circuit_primary, &circuit_secondary) .unwrap(); recursive_snark .verify(&pp, augmented_circuit_index, &z0_primary, &z0_secondary) @@ -961,14 +953,12 @@ where // produce a recursive SNARK let circuit_primary = &roots[0]; - let mut recursive_snark = RecursiveSNARK::::iter_base_step( + + let mut recursive_snark = RecursiveSNARK::::new( &pp, - circuit_primary.circuit_index(), + circuit_primary, circuit_primary, &circuit_secondary, - Some(G1::Scalar::from(circuit_primary.circuit_index() as u64)), - circuit_primary.circuit_index(), - 2, &z0_primary, &z0_secondary, ) @@ -989,12 +979,7 @@ where .unwrap(); for circuit_primary in roots.iter().take(num_steps) { - let res = recursive_snark.prove_step( - &pp, - circuit_primary.circuit_index(), - circuit_primary, - &circuit_secondary, - ); + let res = recursive_snark.prove_step(&pp, circuit_primary, &circuit_secondary); assert!(res .map_err(|err| { print_constraints_name_on_error_index(