From 7ea3effef7dcaa3e06daec9e0f92ff2cc9826393 Mon Sep 17 00:00:00 2001 From: Adrian Hamelink Date: Thu, 14 Dec 2023 14:52:06 +0100 Subject: [PATCH] apply bound witness fix to ppsnark --- src/spartan/batched_ppsnark.rs | 80 +-------------------- src/spartan/ppsnark.rs | 128 +++++++++++++++++++++++++++++---- 2 files changed, 116 insertions(+), 92 deletions(-) diff --git a/src/spartan/batched_ppsnark.rs b/src/spartan/batched_ppsnark.rs index 2678838d..df3a180e 100644 --- a/src/spartan/batched_ppsnark.rs +++ b/src/spartan/batched_ppsnark.rs @@ -20,7 +20,7 @@ use crate::{ powers, ppsnark::{ InnerSumcheckInstance, MemorySumcheckInstance, OuterSumcheckInstance, - R1CSShapeSparkCommitment, R1CSShapeSparkRepr, SumcheckEngine, + R1CSShapeSparkCommitment, R1CSShapeSparkRepr, SumcheckEngine, WitnessBoundSumcheck, }, sumcheck::SumcheckProof, PolyEvalInstance, PolyEvalWitness, @@ -95,84 +95,6 @@ impl> DigestHelperTrait for VerifierK } } -/// The [WitnessBoundSumcheck] ensures that the witness polynomial W defined over n = log(N) variables, -/// is zero outside of the first `num_vars = 2^m` entries. -/// -/// # Details -/// -/// The `W` polynomial is padded with zeros to size N = 2^n. -/// The `masked_eq` polynomials is defined as with regards to a random challenge `tau` as -/// the eq(tau) polynomial, where the first 2^m evaluations to 0. -/// -/// The instance is given by -/// `0 = ∑_{0≤i<2^n} masked_eq[i] * W[i]`. -/// It is equivalent to the expression -/// `0 = ∑_{2^m≤i<2^n} eq[i] * W[i]` -/// Since `eq` is random, the instance is only satisfied if `W[2^{m}..] = 0`. -pub(in crate::spartan) struct WitnessBoundSumcheck { - poly_W: MultilinearPolynomial, - poly_masked_eq: MultilinearPolynomial, -} - -impl WitnessBoundSumcheck { - pub fn new(tau: E::Scalar, poly_W_padded: Vec, num_vars: usize) -> Self { - let num_vars_log = num_vars.log_2(); - // When num_vars = num_rounds, we shouldn't have to prove anything - // but we still want this instance to compute the evaluation of W - let num_rounds = poly_W_padded.len().log_2(); - assert!(num_vars_log < num_rounds); - - let tau_coords = PowPolynomial::new(&tau, num_rounds).coordinates(); - let poly_masked_eq_evals = - MaskedEqPolynomial::new(&EqPolynomial::new(tau_coords), num_vars_log).evals(); - - Self { - poly_W: MultilinearPolynomial::new(poly_W_padded), - poly_masked_eq: MultilinearPolynomial::new(poly_masked_eq_evals), - } - } -} -impl SumcheckEngine for WitnessBoundSumcheck { - fn initial_claims(&self) -> Vec { - vec![E::Scalar::ZERO] - } - - fn degree(&self) -> usize { - 3 - } - - fn size(&self) -> usize { - assert_eq!(self.poly_W.len(), self.poly_masked_eq.len()); - self.poly_W.len() - } - - fn evaluation_points(&self) -> Vec> { - let comb_func = |poly_A_comp: &E::Scalar, - poly_B_comp: &E::Scalar, - _: &E::Scalar| - -> E::Scalar { *poly_A_comp * *poly_B_comp }; - - let (eval_point_0, eval_point_2, eval_point_3) = SumcheckProof::::compute_eval_points_cubic( - &self.poly_masked_eq, - &self.poly_W, - &self.poly_W, // unused - &comb_func, - ); - - vec![vec![eval_point_0, eval_point_2, eval_point_3]] - } - - fn bound(&mut self, r: &E::Scalar) { - [&mut self.poly_W, &mut self.poly_masked_eq] - .par_iter_mut() - .for_each(|poly| poly.bind_poly_var_top(r)); - } - - fn final_claims(&self) -> Vec> { - vec![vec![self.poly_W[0], self.poly_masked_eq[0]]] - } -} - /// A succinct proof of knowledge of a witness to a relaxed R1CS instance /// The proof is produced using Spartan's combination of the sum-check and /// the commitment to a vector viewed as a polynomial commitment diff --git a/src/spartan/ppsnark.rs b/src/spartan/ppsnark.rs index 30c6bcc2..973d4c5d 100644 --- a/src/spartan/ppsnark.rs +++ b/src/spartan/ppsnark.rs @@ -37,6 +37,8 @@ use once_cell::sync::OnceCell; use rayon::prelude::*; use serde::{Deserialize, Serialize}; +use super::polys::masked_eq::MaskedEqPolynomial; + fn padded(v: &[E::Scalar], n: usize, e: &E::Scalar) -> Vec { let mut v_padded = vec![*e; n]; for (i, v_i) in v.iter().enumerate() { @@ -267,6 +269,84 @@ pub trait SumcheckEngine: Send + Sync { fn final_claims(&self) -> Vec>; } +/// The [WitnessBoundSumcheck] ensures that the witness polynomial W defined over n = log(N) variables, +/// is zero outside of the first `num_vars = 2^m` entries. +/// +/// # Details +/// +/// The `W` polynomial is padded with zeros to size N = 2^n. +/// The `masked_eq` polynomials is defined as with regards to a random challenge `tau` as +/// the eq(tau) polynomial, where the first 2^m evaluations to 0. +/// +/// The instance is given by +/// `0 = ∑_{0≤i<2^n} masked_eq[i] * W[i]`. +/// It is equivalent to the expression +/// `0 = ∑_{2^m≤i<2^n} eq[i] * W[i]` +/// Since `eq` is random, the instance is only satisfied if `W[2^{m}..] = 0`. +pub(in crate::spartan) struct WitnessBoundSumcheck { + poly_W: MultilinearPolynomial, + poly_masked_eq: MultilinearPolynomial, +} + +impl WitnessBoundSumcheck { + pub fn new(tau: E::Scalar, poly_W_padded: Vec, num_vars: usize) -> Self { + let num_vars_log = num_vars.log_2(); + // When num_vars = num_rounds, we shouldn't have to prove anything + // but we still want this instance to compute the evaluation of W + let num_rounds = poly_W_padded.len().log_2(); + assert!(num_vars_log < num_rounds); + + let tau_coords = PowPolynomial::new(&tau, num_rounds).coordinates(); + let poly_masked_eq_evals = + MaskedEqPolynomial::new(&EqPolynomial::new(tau_coords), num_vars_log).evals(); + + Self { + poly_W: MultilinearPolynomial::new(poly_W_padded), + poly_masked_eq: MultilinearPolynomial::new(poly_masked_eq_evals), + } + } +} +impl SumcheckEngine for WitnessBoundSumcheck { + fn initial_claims(&self) -> Vec { + vec![E::Scalar::ZERO] + } + + fn degree(&self) -> usize { + 3 + } + + fn size(&self) -> usize { + assert_eq!(self.poly_W.len(), self.poly_masked_eq.len()); + self.poly_W.len() + } + + fn evaluation_points(&self) -> Vec> { + let comb_func = |poly_A_comp: &E::Scalar, + poly_B_comp: &E::Scalar, + _: &E::Scalar| + -> E::Scalar { *poly_A_comp * *poly_B_comp }; + + let (eval_point_0, eval_point_2, eval_point_3) = SumcheckProof::::compute_eval_points_cubic( + &self.poly_masked_eq, + &self.poly_W, + &self.poly_W, // unused + &comb_func, + ); + + vec![vec![eval_point_0, eval_point_2, eval_point_3]] + } + + fn bound(&mut self, r: &E::Scalar) { + [&mut self.poly_W, &mut self.poly_masked_eq] + .par_iter_mut() + .for_each(|poly| poly.bind_poly_var_top(r)); + } + + fn final_claims(&self) -> Vec> { + vec![vec![self.poly_W[0], self.poly_masked_eq[0]]] + } +} + pub(in crate::spartan) struct MemorySumcheckInstance { // row w_plus_r_row: MultilinearPolynomial, @@ -890,10 +970,11 @@ impl> RelaxedR1CSSNARK where ::Repr: Abomonation, { - fn prove_helper( + fn prove_helper( mem: &mut T1, outer: &mut T2, inner: &mut T3, + witness: &mut T4, transcript: &mut E::TE, ) -> Result< ( @@ -902,6 +983,7 @@ where Vec>, Vec>, Vec>, + Vec>, ), NovaError, > @@ -909,12 +991,15 @@ where T1: SumcheckEngine, T2: SumcheckEngine, T3: SumcheckEngine, + T4: SumcheckEngine, { // sanity checks assert_eq!(mem.size(), outer.size()); assert_eq!(mem.size(), inner.size()); + assert_eq!(mem.size(), witness.size()); assert_eq!(mem.degree(), outer.degree()); assert_eq!(mem.degree(), inner.degree()); + assert_eq!(mem.degree(), witness.degree()); // these claims are already added to the transcript, so we do not need to add let claims = mem @@ -922,6 +1007,7 @@ where .into_iter() .chain(outer.initial_claims()) .chain(inner.initial_claims()) + .chain(witness.initial_claims()) .collect::>(); let s = transcript.squeeze(b"r")?; @@ -935,15 +1021,16 @@ where let mut cubic_polys: Vec> = Vec::new(); let num_rounds = mem.size().log_2(); for _ in 0..num_rounds { - let (evals_mem, (evals_outer, evals_inner)) = rayon::join( - || mem.evaluation_points(), - || rayon::join(|| outer.evaluation_points(), || inner.evaluation_points()), + let ((evals_mem, evals_outer), (evals_inner, evals_witness)) = rayon::join( + || rayon::join(|| mem.evaluation_points(), || outer.evaluation_points()), + || rayon::join(|| inner.evaluation_points(), || witness.evaluation_points()), ); let evals: Vec> = evals_mem .into_iter() .chain(evals_outer.into_iter()) .chain(evals_inner.into_iter()) + .chain(evals_witness.into_iter()) .collect::>>(); assert_eq!(evals.len(), claims.len()); @@ -967,8 +1054,8 @@ where r.push(r_i); let _ = rayon::join( - || mem.bound(&r_i), - || rayon::join(|| outer.bound(&r_i), || inner.bound(&r_i)), + || rayon::join(|| mem.bound(&r_i), || outer.bound(&r_i)), + || rayon::join(|| inner.bound(&r_i), || witness.bound(&r_i)), ); e = poly.evaluate(&r_i); @@ -978,6 +1065,7 @@ where let mem_claims = mem.final_claims(); let outer_claims = outer.final_claims(); let inner_claims = inner.final_claims(); + let witness_claims = witness.final_claims(); Ok(( SumcheckProof::new(cubic_polys), @@ -985,6 +1073,7 @@ where mem_claims, outer_claims, inner_claims, + witness_claims, )) } } @@ -1225,10 +1314,13 @@ where let (mut mem_sc_inst, comm_mem_oracles, mem_oracles) = mem_res?; - let (sc, rand_sc, claims_mem, claims_outer, claims_inner) = Self::prove_helper( + let mut witness_sc_inst = WitnessBoundSumcheck::new(tau, W.clone(), S.num_vars); + + let (sc, rand_sc, claims_mem, claims_outer, claims_inner, claims_witness) = Self::prove_helper( &mut mem_sc_inst, &mut outer_sc_inst, &mut inner_sc_inst, + &mut witness_sc_inst, &mut transcript, )?; @@ -1246,11 +1338,11 @@ where let eval_t_plus_r_inv_col = claims_mem[1][0]; let eval_w_plus_r_inv_col = claims_mem[1][1]; let eval_ts_col = claims_mem[1][2]; + let eval_W = claims_witness[0][0]; // compute the remaining claims that did not come for free from the sum-check prover - let (eval_W, eval_Cz, eval_E, eval_val_A, eval_val_B, eval_val_C, eval_row, eval_col) = { + let (eval_Cz, eval_E, eval_val_A, eval_val_B, eval_val_C, eval_row, eval_col) = { let e = [ - &W, &Cz, &E, &pk.S_repr.val_A, @@ -1262,7 +1354,7 @@ where .into_par_iter() .map(|p| MultilinearPolynomial::evaluate_with(p, &rand_sc)) .collect::>(); - (e[0], e[1], e[2], e[3], e[4], e[5], e[6], e[7]) + (e[0], e[1], e[2], e[3], e[4], e[5], e[6]) }; // all the evaluations are at rand_sc, we can fold them into one claim @@ -1447,7 +1539,7 @@ where let rho = transcript.squeeze(b"r")?; - let num_claims = 9; + let num_claims = 10; let s = transcript.squeeze(b"r")?; let coeffs = powers::(&s, num_claims); let claim = (coeffs[7] + coeffs[8]) * claim; // rest are zeros @@ -1461,7 +1553,12 @@ where let poly_eq_coords = PowPolynomial::new(&rho, num_rounds_sc).coordinates(); EqPolynomial::new(poly_eq_coords).evaluate(&rand_sc) }; - let taus_bound_rand_sc = PowPolynomial::new(&tau, num_rounds_sc).evaluate(&rand_sc); + let taus_coords = PowPolynomial::new(&tau, num_rounds_sc).coordinates(); + let eq_tau = EqPolynomial::new(taus_coords); + + let taus_bound_rand_sc = eq_tau.evaluate(&rand_sc); + let taus_masked_bound_rand_sc = + MaskedEqPolynomial::new(&eq_tau, vk.num_vars.log_2()).evaluate(&rand_sc); let eval_t_plus_r_row = { let eval_addr_row = IdentityPolynomial::new(num_rounds_sc).evaluate(&rand_sc); @@ -1546,7 +1643,12 @@ where * self.eval_L_col * (self.eval_val_A + c * self.eval_val_B + c * c * self.eval_val_C); - claim_mem_final_expected + claim_outer_final_expected + claim_inner_final_expected + let claim_witness_final_expected = coeffs[9] * taus_masked_bound_rand_sc * self.eval_W; + + claim_mem_final_expected + + claim_outer_final_expected + + claim_inner_final_expected + + claim_witness_final_expected }; if claim_sc_final_expected != claim_sc_final {