diff --git a/src/supernova/mod.rs b/src/supernova/mod.rs index 3c1a320f..8e636b34 100644 --- a/src/supernova/mod.rs +++ b/src/supernova/mod.rs @@ -12,8 +12,9 @@ use crate::{ }, scalar_as_base, traits::{ - circuit_supernova::StepCircuit, commitment::CommitmentTrait, AbsorbInROTrait, Group, - ROConstants, ROConstantsCircuit, ROTrait, + circuit_supernova::{StepCircuit, TrivialSecondaryCircuit}, + commitment::CommitmentTrait, + AbsorbInROTrait, Group, ROConstants, ROConstantsCircuit, ROTrait, }, Commitment, CommitmentKey, }; @@ -28,6 +29,7 @@ use crate::bellpepper::{ }; use bellpepper_core::ConstraintSystem; +use crate::compute_digest; use crate::nifs::NIFS; mod circuit; // declare the module first @@ -198,9 +200,12 @@ where } } - /// get augmented_circuit_index - pub fn get_augmented_circuit_index(&self) -> usize { - self.augmented_circuit_index + /// get primary/secondary circuit r1cs shape + pub fn get_r1cs_shape(&self) -> (&R1CSShape, &R1CSShape) { + ( + &self.params.r1cs_shape_primary, + &self.params.r1cs_shape_secondary, + ) } /// set primary/secondary commitment key @@ -213,18 +218,15 @@ where self.params.ck_secondary = Some(ck_secondary); } - /// get primary/secondary circuit r1cs shape - pub fn get_r1cs_shape(&self) -> (&R1CSShape, &R1CSShape) { - ( - &self.params.r1cs_shape_primary, - &self.params.r1cs_shape_secondary, - ) - } - - /// get augmented_circuit_index + /// Get the `PublicParams`. pub fn get_public_params(&self) -> &PublicParams { &self.params } + + /// Get this `RunningClaim`'s augmented circuit index. + pub fn get_circuit_index(&self) -> usize { + self.augmented_circuit_index + } } /// A SNARK that proves the correct execution of an non-uniform incremental computation @@ -265,7 +267,7 @@ where z0_primary: &[G1::Scalar], z0_secondary: &[G2::Scalar], ) -> Result { - let pp = &claim.params; + let pp = &claim.get_public_params(); let c_primary = &claim.c_primary; let c_secondary = &claim.c_secondary; // commitment key for primary & secondary circuit @@ -330,7 +332,7 @@ where Some(&u_primary), None, None, - G2::Scalar::from(claim.augmented_circuit_index as u64), + G2::Scalar::from(claim.get_circuit_index() as u64), ); let circuit_secondary: SuperNovaAugmentedCircuit<'_, G1, C2> = SuperNovaAugmentedCircuit::new( &pp.augmented_circuit_params_secondary, @@ -502,8 +504,8 @@ where // Split into `if let`/`else` statement // to avoid `returns a value referencing data owned by closure` error on `&RelaxedR1CSInstance::default` and `RelaxedR1CSWitness::default` let (nifs_primary, (r_U_primary_folded, r_W_primary_folded)) = match ( - self.r_U_primary.get(claim.get_augmented_circuit_index()), - self.r_W_primary.get(claim.get_augmented_circuit_index()), + self.r_U_primary.get(claim.get_circuit_index()), + self.r_W_primary.get(claim.get_circuit_index()), ) { (Some(Some(r_U_primary)), Some(Some(r_W_primary))) => NIFS::prove( ck_primary, @@ -542,7 +544,7 @@ where Some(&l_u_primary), Some(&binding), None, - G2::Scalar::from(claim.get_augmented_circuit_index() as u64), + G2::Scalar::from(claim.get_circuit_index() as u64), ); let circuit_secondary: SuperNovaAugmentedCircuit<'_, G1, C2> = SuperNovaAugmentedCircuit::new( @@ -592,8 +594,8 @@ where } // clone and updated running instance on respective circuit_index - self.r_U_primary[claim.get_augmented_circuit_index()] = Some(r_U_primary_folded); - self.r_W_primary[claim.get_augmented_circuit_index()] = Some(r_W_primary_folded); + self.r_U_primary[claim.get_circuit_index()] = Some(r_U_primary_folded); + self.r_W_primary[claim.get_circuit_index()] = Some(r_W_primary_folded); self.r_W_secondary = vec![Some(r_W_secondary_next)]; self.r_U_secondary = vec![Some(r_U_secondary_next)]; self.l_w_secondary = l_w_secondary_next; @@ -602,7 +604,7 @@ where self.zi_primary = zi_primary; self.zi_secondary = zi_secondary; self.program_counter = zi_primary_pc_next; - self.augmented_circuit_index = claim.get_augmented_circuit_index(); + self.augmented_circuit_index = claim.get_circuit_index(); Ok(()) } @@ -631,7 +633,7 @@ where let pp = &claim.params; let ck_primary = pp.ck_primary.as_ref().ok_or(SuperNovaError::MissingCK)?; - self.r_U_primary[claim.get_augmented_circuit_index()] + self.r_U_primary[claim.get_circuit_index()] .as_ref() .map_or(Ok(()), |U| { if U.X.len() != 2 { @@ -730,10 +732,10 @@ where || { pp.r1cs_shape_primary.is_sat_relaxed( pp.ck_primary.as_ref().unwrap(), - self.r_U_primary[claim.get_augmented_circuit_index()] + self.r_U_primary[claim.get_circuit_index()] .as_ref() .unwrap_or(&default_instance), - self.r_W_primary[claim.get_augmented_circuit_index()] + self.r_W_primary[claim.get_circuit_index()] .as_ref() .unwrap_or(&default_witness), ) @@ -813,13 +815,74 @@ where (ck_primary, ck_secondary) } -/// SuperNova helper trait, for implementors that provide sets of sub-circuits to be proved via NIVC. -pub trait CircuitSet { +/// SuperNova helper trait, for implementors that provide sets of sub-circuits to be proved via NIVC. `C1` must be a +/// type (likely an `Enum`) for which a potentially-distinct instance can be supplied for each `index` below +/// `self.num_circuits()`. +pub trait NonUniformCircuit +where + G1: Group::Scalar>, + G2: Group::Scalar>, + C1: StepCircuit, +{ /// Initial program counter, defaults to zero. - fn initial_program_counter(&self) -> G::Scalar { - G::Scalar::ZERO + fn initial_program_counter(&self) -> G1::Scalar { + G1::Scalar::ZERO + } + + /// Return the initial running claims for `NonUniformCircuit`'s sub-circuits. + fn initial_running_claims( + &self, + ) -> Vec>> { + (0..self.num_circuits()) + .map(|i| { + RunningClaim::new( + i, + self.primary_circuit(i), + self.secondary_circuit(), + self.num_circuits(), + ) + }) + .collect() } - /// How many augmented circuits are provided? - fn num_augmented_circuits(&self) -> usize; + /// Return digest and initial running claims. + fn compute_digest_and_initial_running_claims( + &self, + ) -> ( + G1::Scalar, + Vec>>, + ) { + let mut running_claims = self.initial_running_claims(); + + let running_claim_params = running_claims + .iter() + .map(|c| c.get_public_params()) + .collect::>(); + + let (ck_primary, ck_secondary) = compute_commitment_keys(&running_claim_params); + + for claim in &mut running_claims { + claim.set_commitment_key(ck_primary.clone(), ck_secondary.clone()); + } + + let public_params = running_claims + .iter() + .map(|c| c.get_public_params()) + .collect::>(); + + let digest = compute_digest::>(&public_params); + + (digest, running_claims) + } + + /// How many circuits are provided? + fn num_circuits(&self) -> usize; + + /// Return a new instance of the primary circuit at `index`. + fn primary_circuit(&self, circuit_index: usize) -> C1; + + /// Return a new instance of the secondary circuit. + fn secondary_circuit(&self) -> TrivialSecondaryCircuit { + Default::default() + } } diff --git a/src/supernova/test.rs b/src/supernova/test.rs index 020b7b90..80b0dd15 100644 --- a/src/supernova/test.rs +++ b/src/supernova/test.rs @@ -2,12 +2,9 @@ use crate::bellpepper::test_shape_cs::TestShapeCS; use crate::gadgets::utils::alloc_const; use crate::gadgets::utils::alloc_num_equals; use crate::gadgets::utils::conditionally_select; +use crate::gadgets::utils::{add_allocated_num, alloc_one, alloc_zero}; use crate::provider::poseidon::PoseidonConstantsCircuit; use crate::traits::circuit_supernova::{TrivialSecondaryCircuit, TrivialTestCircuit}; -use crate::{ - compute_digest, - gadgets::utils::{add_allocated_num, alloc_one, alloc_zero}, -}; use bellpepper::gadgets::boolean::Boolean; use bellpepper_core::num::AllocatedNum; use bellpepper_core::{ConstraintSystem, LinearCombination, SynthesisError}; @@ -276,26 +273,74 @@ fn print_constraints_name_on_error_index( const OPCODE_0: usize = 0; const OPCODE_1: usize = 1; -#[derive(Clone, Debug)] -struct TestROM { - op0: CubicCircuit, - op1: SquareCircuit, +struct TestROM +where + G1: Group::Scalar>, + G2: Group::Scalar>, + S: StepCircuit + Default, +{ rom: Vec, - _p: PhantomData, + _p: PhantomData<(G1, G2, S)>, +} + +#[derive(Debug, Clone)] +enum TestRomCircuit { + Cubic(CubicCircuit), + Square(SquareCircuit), +} + +impl StepCircuit for TestRomCircuit { + fn arity(&self) -> usize { + match self { + Self::Cubic(x) => x.arity(), + Self::Square(x) => x.arity(), + } + } + + fn synthesize>( + &self, + cs: &mut CS, + pc: Option<&AllocatedNum>, + z: &[AllocatedNum], + ) -> Result<(Option>, Vec>), SynthesisError> { + match self { + Self::Cubic(x) => x.synthesize(cs, pc, z), + Self::Square(x) => x.synthesize(cs, pc, z), + } + } } -impl CircuitSet for TestROM { - fn num_augmented_circuits(&self) -> usize { +impl NonUniformCircuit> + for TestROM> +where + G1: Group::Scalar>, + G2: Group::Scalar>, +{ + fn num_circuits(&self) -> usize { 2 } + + fn primary_circuit(&self, circuit_index: usize) -> TestRomCircuit { + match circuit_index { + 0 => TestRomCircuit::Cubic(CubicCircuit::new(circuit_index, self.rom.len())), + 1 => TestRomCircuit::Square(SquareCircuit::new(circuit_index, self.rom.len())), + _ => panic!("unsupported primary circuit index"), + } + } + + fn secondary_circuit(&self) -> TrivialSecondaryCircuit { + Default::default() + } } -impl TestROM { +impl TestROM +where + G1: Group::Scalar>, + G2: Group::Scalar>, + S: StepCircuit + Default, +{ fn new(rom: Vec) -> Self { - let rom_len = rom.len(); Self { - op0: CubicCircuit::new(0, rom_len), - op1: SquareCircuit::new(1, rom_len), rom, _p: Default::default(), } @@ -333,50 +378,13 @@ where OPCODE_1, OPCODE_1, OPCODE_0, OPCODE_0, OPCODE_1, OPCODE_1, OPCODE_0, OPCODE_0, OPCODE_1, OPCODE_1, ]; // Rom can be arbitrary length. - let circuit_secondary = TrivialSecondaryCircuit::default(); - - let test_rom = TestROM::::new(rom); - // Structuring running claims - let mut running_claim1 = RunningClaim::< - G1, - G2, - CubicCircuit<::Scalar>, - TrivialSecondaryCircuit<::Scalar>, - >::new( - OPCODE_0, - test_rom.op0.clone(), - circuit_secondary.clone(), - test_rom.num_augmented_circuits(), - ); - - let mut running_claim2 = RunningClaim::< - G1, - G2, - SquareCircuit<::Scalar>, - TrivialSecondaryCircuit<::Scalar>, - >::new( - OPCODE_1, - test_rom.op1.clone(), - circuit_secondary, - test_rom.num_augmented_circuits(), - ); - - // generate the commitkey based on max num of constraints and reused it for all other augmented circuit - let (ck_primary, ck_secondary) = - compute_commitment_keys(&[&running_claim1.params, &running_claim2.params]); - - // set unified ck_primary, ck_secondary and update digest - running_claim1.set_commitment_key(ck_primary.clone(), ck_secondary.clone()); - running_claim2.set_commitment_key(ck_primary, ck_secondary); - - let digest = compute_digest::>(&[ - running_claim1.get_public_params(), - running_claim2.get_public_params(), - ]); + let test_rom = TestROM::>::new(rom); let num_steps = test_rom.num_steps(); let initial_program_counter = test_rom.initial_program_counter(); + let (digest, running_claims) = test_rom.compute_digest_and_initial_running_claims(); + // extend z0_primary/secondary with rom content let mut z0_primary = vec![::Scalar::ONE]; z0_primary.extend( @@ -399,62 +407,47 @@ where program_counter.to_repr().as_ref()[0..4].try_into().unwrap(), ) as usize]; - let mut recursive_snark = recursive_snark_option.unwrap_or_else(|| { - if augmented_circuit_index == OPCODE_0 { - RecursiveSNARK::iter_base_step( - &running_claim1, - digest, - Some(program_counter), - augmented_circuit_index, - test_rom.num_augmented_circuits(), - &z0_primary, - &z0_secondary, - ) - .unwrap() - } else if augmented_circuit_index == OPCODE_1 { - RecursiveSNARK::iter_base_step( - &running_claim2, + let mut recursive_snark = + recursive_snark_option.unwrap_or_else(|| match augmented_circuit_index { + OPCODE_0 | OPCODE_1 => RecursiveSNARK::iter_base_step( + &running_claims[augmented_circuit_index], digest, Some(program_counter), augmented_circuit_index, - test_rom.num_augmented_circuits(), + test_rom.num_circuits(), &z0_primary, &z0_secondary, ) - .unwrap() - } else { - unimplemented!() - } - }); - - if augmented_circuit_index == OPCODE_0 { - recursive_snark - .prove_step(&running_claim1, &z0_primary, &z0_secondary) - .unwrap(); - recursive_snark - .verify(&running_claim1, &z0_primary, &z0_secondary) - .map_err(|err| { - print_constraints_name_on_error_index( - err, - &running_claim1, - test_rom.num_augmented_circuits(), + .unwrap(), + _ => { + unimplemented!() + } + }); + match augmented_circuit_index { + OPCODE_0 | OPCODE_1 => { + recursive_snark + .prove_step( + &running_claims[augmented_circuit_index], + &z0_primary, + &z0_secondary, ) - }) - .unwrap(); - } else if augmented_circuit_index == OPCODE_1 { - recursive_snark - .prove_step(&running_claim2, &z0_primary, &z0_secondary) - .unwrap(); - recursive_snark - .verify(&running_claim2, &z0_primary, &z0_secondary) - .map_err(|err| { - print_constraints_name_on_error_index( - err, - &running_claim2, - test_rom.num_augmented_circuits(), + .unwrap(); + recursive_snark + .verify( + &running_claims[augmented_circuit_index], + &z0_primary, + &z0_secondary, ) - }) - .unwrap(); + .map_err(|err| { + print_constraints_name_on_error_index( + err, + &running_claims[augmented_circuit_index], + test_rom.num_circuits(), + ) + }) + .unwrap(); + } + _ => (), } recursive_snark_option = Some(recursive_snark) }