From 8e740778536d183d7956aee1af46cd85da050f88 Mon Sep 17 00:00:00 2001 From: porcuquine Date: Mon, 6 Nov 2023 17:21:41 -0800 Subject: [PATCH 1/6] Remove circuit_index from prove_step. --- src/supernova/mod.rs | 2 +- src/supernova/test.rs | 14 ++------------ 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/src/supernova/mod.rs b/src/supernova/mod.rs index c8a94e68..d8f1f778 100644 --- a/src/supernova/mod.rs +++ b/src/supernova/mod.rs @@ -577,10 +577,10 @@ where pub fn prove_step, C2: StepCircuit>( &mut self, pp: &PublicParams, - circuit_index: usize, c_primary: &C1, c_secondary: &C2, ) -> Result<(), SuperNovaError> { + let circuit_index = c_primary.circuit_index(); // First step was already done in the constructor if self.i == 0 { self.i = 1; diff --git a/src/supernova/test.rs b/src/supernova/test.rs index 93970f36..4c226bbe 100644 --- a/src/supernova/test.rs +++ b/src/supernova/test.rs @@ -505,12 +505,7 @@ where let circuit_primary = test_rom.primary_circuit(augmented_circuit_index); let circuit_secondary = test_rom.secondary_circuit(); recursive_snark - .prove_step( - &pp, - augmented_circuit_index, - &circuit_primary, - &circuit_secondary, - ) + .prove_step(&pp, &circuit_primary, &circuit_secondary) .unwrap(); recursive_snark .verify(&pp, augmented_circuit_index, &z0_primary, &z0_secondary) @@ -989,12 +984,7 @@ where .unwrap(); for circuit_primary in roots.iter().take(num_steps) { - let res = recursive_snark.prove_step( - &pp, - circuit_primary.circuit_index(), - circuit_primary, - &circuit_secondary, - ); + let res = recursive_snark.prove_step(&pp, circuit_primary, &circuit_secondary); assert!(res .map_err(|err| { print_constraints_name_on_error_index( From 67c30b3866abfa0d9ee57506f9fc810f3312737e Mon Sep 17 00:00:00 2001 From: porcuquine Date: Mon, 6 Nov 2023 17:41:46 -0800 Subject: [PATCH 2/6] Simplify API. --- src/supernova/mod.rs | 33 +++++++++++++++++++++------------ src/supernova/test.rs | 19 +++++++------------ 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/src/supernova/mod.rs b/src/supernova/mod.rs index d8f1f778..b1acc44c 100644 --- a/src/supernova/mod.rs +++ b/src/supernova/mod.rs @@ -416,17 +416,21 @@ where { /// iterate base step to get new instance of recursive SNARK #[allow(clippy::too_many_arguments)] - pub fn iter_base_step, C2: StepCircuit>( + pub fn new< + C0: NonUniformCircuit, + C1: StepCircuit, + C2: StepCircuit, + >( pp: &PublicParams, - circuit_index: usize, + non_uniform_circuit: &C0, c_primary: &C1, c_secondary: &C2, - initial_program_counter: Option, - first_augmented_circuit_index: usize, - num_augmented_circuits: usize, z0_primary: &[G1::Scalar], z0_secondary: &[G2::Scalar], ) -> Result { + let num_augmented_circuits = non_uniform_circuit.num_circuits(); + let circuit_index = non_uniform_circuit.initial_circuit_index(); + if z0_primary.len() != pp[circuit_index].F_arity || z0_secondary.len() != pp.circuit_shape_secondary.F_arity { @@ -446,14 +450,14 @@ where None, None, None, - initial_program_counter, + Some(G1::Scalar::from(circuit_index as u64)), G1::Scalar::ZERO, // set augmented circuit index selector to 0 in base case ); let circuit_primary: SuperNovaAugmentedCircuit<'_, G2, C1> = SuperNovaAugmentedCircuit::new( &pp.augmented_circuit_params_primary, Some(inputs_primary), - c_primary, + &c_primary, pp.ro_consts_circuit_primary.clone(), num_augmented_circuits, ); @@ -545,11 +549,11 @@ where // handle the base case by initialize U_next in next round let r_W_primary_initial_list = (0..num_augmented_circuits) - .map(|i| (i == first_augmented_circuit_index).then(|| r_W_primary.clone())) + .map(|i| (i == circuit_index).then(|| r_W_primary.clone())) .collect::>>>(); let r_U_primary_initial_list = (0..num_augmented_circuits) - .map(|i| (i == first_augmented_circuit_index).then(|| r_U_primary.clone())) + .map(|i| (i == circuit_index).then(|| r_U_primary.clone())) .collect::>>>(); Ok(Self { @@ -566,7 +570,7 @@ where zi_primary, zi_secondary, program_counter: zi_primary_pc_next, - augmented_circuit_index: first_augmented_circuit_index, + augmented_circuit_index: circuit_index, num_augmented_circuits, }) } @@ -938,9 +942,14 @@ where C1: StepCircuit, C2: StepCircuit, { - /// Initial program counter, defaults to zero. + /// Initial program counter is the initial circuit index as a `Scalar`. fn initial_program_counter(&self) -> G1::Scalar { - G1::Scalar::ZERO + G1::Scalar::from(self.initial_circuit_index() as u64) + } + + /// Initial circuit index, defaults to zero. + fn initial_circuit_index(&self) -> usize { + 0 } /// How many circuits are provided? diff --git a/src/supernova/test.rs b/src/supernova/test.rs index 4c226bbe..93c46747 100644 --- a/src/supernova/test.rs +++ b/src/supernova/test.rs @@ -395,8 +395,8 @@ where Default::default() } - fn initial_program_counter(&self) -> G1::Scalar { - G1::Scalar::from(self.rom[0] as u64) + fn initial_circuit_index(&self) -> usize { + self.rom[0] } } @@ -484,14 +484,11 @@ where let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| match augmented_circuit_index { - OPCODE_0 | OPCODE_1 => RecursiveSNARK::iter_base_step( + OPCODE_0 | OPCODE_1 => RecursiveSNARK::new( &pp, - augmented_circuit_index, + &test_rom, &test_rom.primary_circuit(augmented_circuit_index), &test_rom.secondary_circuit(), - Some(program_counter), - augmented_circuit_index, - test_rom.num_circuits(), &z0_primary, &z0_secondary, ) @@ -956,14 +953,12 @@ where // produce a recursive SNARK let circuit_primary = &roots[0]; - let mut recursive_snark = RecursiveSNARK::::iter_base_step( + + let mut recursive_snark = RecursiveSNARK::::new( &pp, - circuit_primary.circuit_index(), + circuit_primary, circuit_primary, &circuit_secondary, - Some(G1::Scalar::from(circuit_primary.circuit_index() as u64)), - circuit_primary.circuit_index(), - 2, &z0_primary, &z0_secondary, ) From 1b32c1445b129eedf6f6ed4ea56f8bcbae34adef Mon Sep 17 00:00:00 2001 From: porcuquine Date: Tue, 7 Nov 2023 10:50:20 -0800 Subject: [PATCH 3/6] Clippy and bench. --- benches/recursive-snark-supernova.rs | 52 ++++++---------------------- src/supernova/mod.rs | 2 +- 2 files changed, 11 insertions(+), 43 deletions(-) diff --git a/benches/recursive-snark-supernova.rs b/benches/recursive-snark-supernova.rs index 0bedc9f3..a42ba633 100644 --- a/benches/recursive-snark-supernova.rs +++ b/benches/recursive-snark-supernova.rs @@ -118,36 +118,23 @@ fn bench_one_augmented_circuit_recursive_snark(c: &mut Criterion) { let num_warmup_steps = 10; let z0_primary = vec![::Scalar::from(2u64)]; let z0_secondary = vec![::Scalar::from(2u64)]; - let initial_program_counter = ::Scalar::from(0); let mut recursive_snark_option: Option> = None; for _ in 0..num_warmup_steps { - let program_counter = recursive_snark_option.as_ref().map_or_else( - || initial_program_counter, - |recursive_snark| recursive_snark.get_program_counter(), - ); - let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| { - RecursiveSNARK::iter_base_step( + RecursiveSNARK::new( &pp, - 0, + &bench, &bench.primary_circuit(0), &bench.secondary_circuit(), - Some(program_counter), - 0, - 1, &z0_primary, &z0_secondary, ) .unwrap() }); - let res = recursive_snark.prove_step( - &pp, - 0, - &bench.primary_circuit(0), - &bench.secondary_circuit(), - ); + let res = + recursive_snark.prove_step(&pp, &bench.primary_circuit(0), &bench.secondary_circuit()); if let Err(e) = &res { println!("res failed {:?}", e); } @@ -170,7 +157,6 @@ fn bench_one_augmented_circuit_recursive_snark(c: &mut Criterion) { assert!(black_box(&mut recursive_snark.clone()) .prove_step( black_box(&pp), - black_box(0), &bench.primary_circuit(0), &bench.secondary_circuit(), ) @@ -224,25 +210,16 @@ fn bench_two_augmented_circuit_recursive_snark(c: &mut Criterion) { let num_warmup_steps = 10; let z0_primary = vec![::Scalar::from(2u64)]; let z0_secondary = vec![::Scalar::from(2u64)]; - let initial_program_counter = ::Scalar::from(0); let mut recursive_snark_option: Option> = None; let mut selected_augmented_circuit = 0; for _ in 0..num_warmup_steps { - let program_counter = recursive_snark_option.as_ref().map_or_else( - || initial_program_counter, - |recursive_snark| recursive_snark.get_program_counter(), - ); - let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| { - RecursiveSNARK::iter_base_step( + RecursiveSNARK::new( &pp, - 0, + &bench, &bench.primary_circuit(0), &bench.secondary_circuit(), - Some(program_counter), - 0, - 2, &z0_primary, &z0_secondary, ) @@ -250,12 +227,8 @@ fn bench_two_augmented_circuit_recursive_snark(c: &mut Criterion) { }); if selected_augmented_circuit == 0 { - let res = recursive_snark.prove_step( - &pp, - 0, - &bench.primary_circuit(0), - &bench.secondary_circuit(), - ); + let res = + recursive_snark.prove_step(&pp, &bench.primary_circuit(0), &bench.secondary_circuit()); if let Err(e) = &res { println!("res failed {:?}", e); } @@ -266,12 +239,8 @@ fn bench_two_augmented_circuit_recursive_snark(c: &mut Criterion) { } assert!(res.is_ok()); } else if selected_augmented_circuit == 1 { - let res = recursive_snark.prove_step( - &pp, - 1, - &bench.primary_circuit(1), - &bench.secondary_circuit(), - ); + let res = + recursive_snark.prove_step(&pp, &bench.primary_circuit(1), &bench.secondary_circuit()); if let Err(e) = &res { println!("res failed {:?}", e); } @@ -298,7 +267,6 @@ fn bench_two_augmented_circuit_recursive_snark(c: &mut Criterion) { assert!(black_box(&mut recursive_snark.clone()) .prove_step( black_box(&pp), - black_box(0), &bench.primary_circuit(0), &bench.secondary_circuit(), ) diff --git a/src/supernova/mod.rs b/src/supernova/mod.rs index b1acc44c..0b4d7d95 100644 --- a/src/supernova/mod.rs +++ b/src/supernova/mod.rs @@ -457,7 +457,7 @@ where let circuit_primary: SuperNovaAugmentedCircuit<'_, G2, C1> = SuperNovaAugmentedCircuit::new( &pp.augmented_circuit_params_primary, Some(inputs_primary), - &c_primary, + c_primary, pp.ro_consts_circuit_primary.clone(), num_augmented_circuits, ); From 80db2297bf72c08569710a6871422a020230995f Mon Sep 17 00:00:00 2001 From: porcuquine Date: Tue, 7 Nov 2023 11:14:47 -0800 Subject: [PATCH 4/6] Remove get_program_counter. --- src/supernova/mod.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/supernova/mod.rs b/src/supernova/mod.rs index 0b4d7d95..91d38b77 100644 --- a/src/supernova/mod.rs +++ b/src/supernova/mod.rs @@ -925,11 +925,6 @@ where Ok((self.zi_primary.clone(), self.zi_secondary.clone())) } - - /// get program counter - pub fn get_program_counter(&self) -> G1::Scalar { - self.program_counter - } } /// SuperNova helper trait, for implementors that provide sets of sub-circuits to be proved via NIVC. `C1` must be a From c8d65ba554a70ea377fcf1a82c0781813524bc74 Mon Sep 17 00:00:00 2001 From: porcuquine Date: Tue, 7 Nov 2023 11:15:12 -0800 Subject: [PATCH 5/6] Assert circuit_index and program_counter match. --- src/supernova/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/supernova/mod.rs b/src/supernova/mod.rs index 91d38b77..9b6db254 100644 --- a/src/supernova/mod.rs +++ b/src/supernova/mod.rs @@ -584,7 +584,6 @@ where c_primary: &C1, c_secondary: &C2, ) -> Result<(), SuperNovaError> { - let circuit_index = c_primary.circuit_index(); // First step was already done in the constructor if self.i == 0 { self.i = 1; @@ -595,6 +594,9 @@ where return Err(NovaError::ProofVerifyError.into()); } + let circuit_index = c_primary.circuit_index(); + assert_eq!(self.program_counter, G1::Scalar::from(circuit_index as u64)); + // fold the secondary circuit's instance let (nifs_secondary, (r_U_secondary_folded, r_W_secondary_folded)) = NIFS::prove( &pp.ck_secondary, From 9d1458a32b34f939de9cf4a7576dcca02f828de7 Mon Sep 17 00:00:00 2001 From: porcuquine Date: Tue, 7 Nov 2023 11:33:49 -0800 Subject: [PATCH 6/6] Add InitialProgramCounter trait. --- src/supernova/mod.rs | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/src/supernova/mod.rs b/src/supernova/mod.rs index 9b6db254..e3892e30 100644 --- a/src/supernova/mod.rs +++ b/src/supernova/mod.rs @@ -939,11 +939,6 @@ where C1: StepCircuit, C2: StepCircuit, { - /// Initial program counter is the initial circuit index as a `Scalar`. - fn initial_program_counter(&self) -> G1::Scalar { - G1::Scalar::from(self.initial_circuit_index() as u64) - } - /// Initial circuit index, defaults to zero. fn initial_circuit_index(&self) -> usize { 0 @@ -959,6 +954,30 @@ where fn secondary_circuit(&self) -> C2; } +/// Extension trait to simplify getting scalar form of initial circuit index. +pub trait InitialProgramCounter: NonUniformCircuit +where + G1: Group::Scalar>, + G2: Group::Scalar>, + C1: StepCircuit, + C2: StepCircuit, +{ + /// Initial program counter is the initial circuit index as a `Scalar`. + fn initial_program_counter(&self) -> G1::Scalar { + G1::Scalar::from(self.initial_circuit_index() as u64) + } +} + +impl> InitialProgramCounter + for T +where + G1: Group::Scalar>, + G2: Group::Scalar>, + C1: StepCircuit, + C2: StepCircuit, +{ +} + /// Compute the circuit digest of a supernova [StepCircuit]. /// /// Note for callers: This function should be called with its performance characteristics in mind.